From 0bd20cc3336985c158e1a19f0c78fe9e42590160 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?poljar=20=28Damir=20Jeli=C4=87=29?= Date: Wed, 11 Apr 2018 14:00:37 +0200 Subject: [PATCH] encryption: Store the account and sessions in a sqlite db. --- matrix/encryption.py | 175 ++++++++++++++++++++++++++++++++++++------- matrix/events.py | 2 +- matrix/server.py | 6 ++ 3 files changed, 157 insertions(+), 26 deletions(-) diff --git a/matrix/encryption.py b/matrix/encryption.py index aa1429a..28e684a 100644 --- a/matrix/encryption.py +++ b/matrix/encryption.py @@ -19,9 +19,10 @@ from __future__ import unicode_literals import os import json +import sqlite3 # pylint: disable=redefined-builtin -from builtins import str +from builtins import str, bytes from collections import defaultdict from functools import wraps @@ -31,7 +32,7 @@ import matrix.globals try: from olm.account import Account, OlmAccountError - from olm.session import (InboundSession, OlmSessionError, OlmMessage, + from olm.session import (Session, InboundSession, OlmSessionError, OlmPreKeyMessage) from olm.group_session import InboundGroupSession, OlmGroupSessionError except ImportError: @@ -160,37 +161,55 @@ class Olm(): user, device_id, session_path, + database=None, account=None, - sessions=defaultdict(list), - group_sessions=defaultdict(dict) + sessions=None, + inbound_group_sessions=None ): # type: (str, str, str, Account, Dict[str, List[Session]) -> None self.user = user self.device_id = device_id self.session_path = session_path + self.database = database + + if not database: + db_file = "{}_{}.db".format(user, device_id) + db_path = os.path.join(session_path, db_file) + self.database = sqlite3.connect(db_path) + Olm._check_db_tables(self.database) + if account: self.account = account + else: self.account = Account() + self._insert_acc_to_db() + + if not sessions: + sessions = defaultdict(list) + + if not inbound_group_sessions: + inbound_group_sessions = defaultdict(dict) self.sessions = sessions - self.group_sessions = group_sessions + self.inbound_group_sessions = inbound_group_sessions def _create_session(self, sender, sender_key, message): W.prnt("", "matrix: Creating session for {}".format(sender)) session = InboundSession(self.account, message, sender_key) W.prnt("", "matrix: Created session for {}".format(sender)) self.sessions[sender].append(session) - # self.account.remove_one_time_keys(session) - # TODO store account here + self._store_session(sender, session) + self.account.remove_one_time_keys(session) + self._update_acc_in_db() return session def create_group_session(self, room_id, session_id, session_key): W.prnt("", "matrix: Creating group session for {}".format(room_id)) session = InboundGroupSession(session_key) - self.group_sessions[room_id][session_id] = session - # TODO store account here + self.inbound_group_sessions[room_id][session_id] = session + self._store_inbound_group_session(room_id, session) @encrypt_enabled def decrypt(self, sender, sender_key, message): @@ -217,10 +236,10 @@ class Olm(): @encrypt_enabled def group_decrypt(self, room_id, session_id, ciphertext): - if session_id not in self.group_sessions[room_id]: + if session_id not in self.inbound_group_sessions[room_id]: return None - session = self.group_sessions[room_id][session_id] + session = self.inbound_group_sessions[room_id][session_id] try: plaintext = session.decrypt(ciphertext) except OlmGroupSessionError: @@ -232,28 +251,134 @@ class Olm(): @encrypt_enabled def from_session_dir(cls, user, device_id, session_path): # type: (Server) -> Olm - account_file_name = "{}_{}.account".format(user, device_id) - path = os.path.join(session_path, account_file_name) + db_file = "{}_{}.db".format(user, device_id) + db_path = os.path.join(session_path, db_file) + database = sqlite3.connect(db_path) + Olm._check_db_tables(database) + + cursor = database.cursor() + + cursor.execute("select pickle from olmaccount where user = ?", (user,)) + row = cursor.fetchone() + account_pickle = row[0] + + cursor.execute("select user, pickle from olmsessions") + db_sessions = cursor.fetchall() + + cursor.execute("select room_id, pickle from inbound_group_sessions") + db_inbound_group_sessions = cursor.fetchall() + + cursor.close() + + sessions = defaultdict(list) + inbound_group_sessions = defaultdict(dict) try: - with open(path, "rb") as f: - pickle = f.read() - account = Account.from_pickle(pickle) - return cls(user, device_id, session_path, account) - except OlmAccountError as error: + account = Account.from_pickle(bytes(account_pickle, "utf-8")) + + for db_session in db_sessions: + sessions[db_session[0]].append( + Session.from_pickle(bytes(db_session[1], "utf-8"))) + + for db_session in db_inbound_group_sessions: + session = InboundGroupSession.from_pickle( + bytes(db_session[1], "utf-8")) + inbound_group_sessions[db_session[0]][session.id] = session + + return cls(user, device_id, session_path, database, account, + sessions, inbound_group_sessions) + except (OlmAccountError, OlmSessionError) as error: raise EncryptionError(error) + def _update_acc_in_db(self): + cursor = self.database.cursor() + cursor.execute("update olmaccount set pickle=? where user = ?", + (self.account.pickle(), self.user)) + self.database.commit() + cursor.close() + + def _update_sessions_in_db(self): + cursor = self.database.cursor() + + for user, session_list in self.sessions.items(): + for session in session_list: + cursor.execute("""update olmsessions set pickle=? + where user = ? and session_id = ?""", + (session.pickle(), user, session.id())) + self.database.commit() + + cursor.close() + + def _update_inbound_group_sessions(self): + cursor = self.database.cursor() + + for room_id, session_dict in self.inbound_group_sessions.items(): + for session in session_dict.values(): + cursor.execute("""update inbound_group_sessions set pickle=? + where room_id = ? and session_id = ?""", + (session.pickle(), room_id, session.id())) + self.database.commit() + + cursor.close() + + def _store_session(self, user, session): + cursor = self.database.cursor() + + cursor.execute("insert into olmsessions values(?,?,?)", + (user, session.id(), session.pickle())) + + self.database.commit() + + cursor.close() + + def _store_inbound_group_session(self, room_id, session): + cursor = self.database.cursor() + + cursor.execute("insert into inbound_group_sessions values(?,?,?)", + (room_id, session.id, session.pickle())) + + self.database.commit() + + cursor.close() + + def _insert_acc_to_db(self): + cursor = self.database.cursor() + cursor.execute("insert into olmaccount values (?,?)", + (self.user, self.account.pickle())) + self.database.commit() + cursor.close() + + @staticmethod + def _check_db_tables(database): + cursor = database.cursor() + cursor.execute("""select name from sqlite_master where type='table' + and name='olmaccount'""") + if not cursor.fetchone(): + cursor.execute("create table olmaccount (user text, pickle text)") + database.commit() + + cursor.execute("""select name from sqlite_master where type='table' + and name='olmsessions'""") + if not cursor.fetchone(): + cursor.execute("""create table olmsessions (user text, + session_id text, pickle text)""") + database.commit() + + cursor.execute("""select name from sqlite_master where type='table' + and name='inbound_group_sessions'""") + if not cursor.fetchone(): + cursor.execute("""create table inbound_group_sessions + (room_id text, session_id text, pickle text)""") + database.commit() + + cursor.close() + @encrypt_enabled def to_session_dir(self): # type: (Server) -> None - account_file_name = "{}_{}.account".format(self.user, - self.device_id) - path = os.path.join(self.session_path, account_file_name) - try: - with open(path, "wb") as f: - pickle = self.account.pickle() - f.write(pickle) + self._update_acc_in_db() + self._update_sessions_in_db() except OlmAccountError as error: raise EncryptionError(error) diff --git a/matrix/events.py b/matrix/events.py index ad576a7..e3127e2 100644 --- a/matrix/events.py +++ b/matrix/events.py @@ -508,7 +508,7 @@ class MatrixSyncEvent(MatrixEvent): session_id = content["session_id"] session_key = content["session_key"] - if session_id in olm.group_sessions[room_id]: + if session_id in olm.inbound_group_sessions[room_id]: return olm.create_group_session(room_id, session_id, session_key) diff --git a/matrix/server.py b/matrix/server.py index deb1f80..f5517fa 100644 --- a/matrix/server.py +++ b/matrix/server.py @@ -146,6 +146,12 @@ class MatrixServer: self.device_id, self.get_session_path() ) + message = ("{prefix}matrix: Loaded Olm account for {user} (device:" + "{device})").format(prefix=W.prefix("network"), + user=self.user, + device=self.device_id) + W.prnt("", message) + except FileNotFoundError: pass except EncryptionError as error: