From 8f7dac4a0d655537330839dc031df4ea98c18146 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?poljar=20=28Damir=20Jeli=C4=87=29?= Date: Sun, 6 May 2018 15:03:50 +0200 Subject: [PATCH] encryption: Change the way olm sessions are stored. --- matrix/encryption.py | 64 ++++++++++++++++++++++++++------------------ matrix/events.py | 11 ++++---- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/matrix/encryption.py b/matrix/encryption.py index ff7d1d2..92088b1 100644 --- a/matrix/encryption.py +++ b/matrix/encryption.py @@ -39,6 +39,7 @@ except ImportError: matrix.globals.ENCRYPTION = False from matrix.globals import W, SERVERS +from matrix.utils import sanitize_id from matrix.utf import utf8_decode @@ -221,7 +222,7 @@ class Olm(): self._insert_acc_to_db() if not sessions: - sessions = defaultdict(list) + sessions = defaultdict(lambda: defaultdict(list)) if not inbound_group_sessions: inbound_group_sessions = defaultdict(dict) @@ -233,8 +234,6 @@ class Olm(): W.prnt("", "matrix: Creating session for {}".format(sender)) session = InboundSession(self.account, message, sender_key) W.prnt("", "matrix: Created session for {}".format(sender)) - self.sessions[sender].append(session) - self._store_session(sender, session) self.account.remove_one_time_keys(session) self._update_acc_in_db() @@ -250,16 +249,21 @@ class Olm(): def decrypt(self, sender, sender_key, message): plaintext = None - for session in self.sessions[sender]: - try: - if isinstance(message, OlmPreKeyMessage): - if not session.matches(message): - continue + for device_id, session_list in self.sessions[sender].items(): + for session in session_list: + W.prnt("", "Trying session for device {}".format(device_id)) + try: + if isinstance(message, OlmPreKeyMessage): + if not session.matches(message): + continue - plaintext = session.decrypt(message) - return plaintext - except OlmSessionError: - pass + W.prnt("", "Decrypting using existing session") + plaintext = session.decrypt(message) + parsed_plaintext = json.loads(plaintext, encoding='utf-8') + W.prnt("", "Decrypted using existing session") + return parsed_plaintext + except OlmSessionError: + pass try: session = self._create_session(sender, sender_key, message) @@ -268,7 +272,12 @@ class Olm(): try: plaintext = session.decrypt(message) - return plaintext + parsed_plaintext = json.loads(plaintext, encoding='utf-8') + + device_id = sanitize_id(parsed_plaintext["sender_device"]) + self.sessions[sender][device_id].append(session) + self._store_session(sender, device_id, session) + return parsed_plaintext except OlmSessionError: return None @@ -300,7 +309,7 @@ class Olm(): row = cursor.fetchone() account_pickle = row[0] - cursor.execute("select user, pickle from olmsessions") + cursor.execute("select user, device_id, pickle from olmsessions") db_sessions = cursor.fetchall() cursor.execute("select room_id, pickle from inbound_group_sessions") @@ -308,15 +317,15 @@ class Olm(): cursor.close() - sessions = defaultdict(list) + sessions = defaultdict(lambda: defaultdict(list)) inbound_group_sessions = defaultdict(dict) try: account = Account.from_pickle(bytes(account_pickle, "utf-8")) for db_session in db_sessions: - sessions[db_session[0]].append( - Session.from_pickle(bytes(db_session[1], "utf-8"))) + sessions[db_session[0]][db_session[1]].append( + Session.from_pickle(bytes(db_session[2], "utf-8"))) for db_session in db_inbound_group_sessions: session = InboundGroupSession.from_pickle( @@ -338,11 +347,14 @@ class Olm(): def _update_sessions_in_db(self): cursor = self.database.cursor() - for user, session_list in self.sessions.items(): - for session in session_list: - cursor.execute("""update olmsessions set pickle=? - where user = ? and session_id = ?""", - (session.pickle(), user, session.id())) + for user, session_dict in self.sessions.items(): + for device_id, session_list in session_dict.items(): + for session in session_list: + cursor.execute("""update olmsessions set pickle=? + where user = ? and session_id = ? and + device_id = ?""", + (session.pickle(), user, session.id(), + device_id)) self.database.commit() cursor.close() @@ -359,11 +371,11 @@ class Olm(): cursor.close() - def _store_session(self, user, session): + def _store_session(self, user, device_id, session): cursor = self.database.cursor() - cursor.execute("insert into olmsessions values(?,?,?)", - (user, session.id(), session.pickle())) + cursor.execute("insert into olmsessions values(?,?,?,?)", + (user, device_id, session.id(), session.pickle())) self.database.commit() @@ -399,7 +411,7 @@ class Olm(): and name='olmsessions'""") if not cursor.fetchone(): cursor.execute("""create table olmsessions (user text, - session_id text, pickle text)""") + device_id text, session_id text, pickle text)""") database.commit() cursor.execute("""select name from sqlite_master where type='table' diff --git a/matrix/events.py b/matrix/events.py index f79d048..2b96865 100644 --- a/matrix/events.py +++ b/matrix/events.py @@ -522,10 +522,9 @@ class MatrixSyncEvent(MatrixEvent): return None # TODO check sender key - parsed_plaintext = json.loads(plaintext, encoding='utf-8') - decrypted_sender = parsed_plaintext["sender"] - decrypted_recepient = parsed_plaintext["recipient"] - decrypted_recepient_key = parsed_plaintext["recipient_keys"]["ed25519"] + decrypted_sender = plaintext["sender"] + decrypted_recepient = plaintext["recipient"] + decrypted_recepient_key = plaintext["recipient_keys"]["ed25519"] if (sender != decrypted_sender or server.user_id != decrypted_recepient or @@ -536,10 +535,10 @@ class MatrixSyncEvent(MatrixEvent): W.prnt("", error_message) return None - if parsed_plaintext["type"] != "m.room_key": + if plaintext["type"] != "m.room_key": return None - MatrixSyncEvent._handle_key_event(server, sender_key, parsed_plaintext) + MatrixSyncEvent._handle_key_event(server, sender_key, plaintext) @staticmethod def _handle_key_event(server, sender_key, parsed_dict):