encryption: Change the way olm sessions are stored.

This commit is contained in:
poljar (Damir Jelić) 2018-05-06 15:03:50 +02:00
parent 658ec67ff4
commit 8f7dac4a0d
2 changed files with 43 additions and 32 deletions

View file

@ -39,6 +39,7 @@ except ImportError:
matrix.globals.ENCRYPTION = False
from matrix.globals import W, SERVERS
from matrix.utils import sanitize_id
from matrix.utf import utf8_decode
@ -221,7 +222,7 @@ class Olm():
self._insert_acc_to_db()
if not sessions:
sessions = defaultdict(list)
sessions = defaultdict(lambda: defaultdict(list))
if not inbound_group_sessions:
inbound_group_sessions = defaultdict(dict)
@ -233,8 +234,6 @@ class Olm():
W.prnt("", "matrix: Creating session for {}".format(sender))
session = InboundSession(self.account, message, sender_key)
W.prnt("", "matrix: Created session for {}".format(sender))
self.sessions[sender].append(session)
self._store_session(sender, session)
self.account.remove_one_time_keys(session)
self._update_acc_in_db()
@ -250,14 +249,19 @@ class Olm():
def decrypt(self, sender, sender_key, message):
plaintext = None
for session in self.sessions[sender]:
for device_id, session_list in self.sessions[sender].items():
for session in session_list:
W.prnt("", "Trying session for device {}".format(device_id))
try:
if isinstance(message, OlmPreKeyMessage):
if not session.matches(message):
continue
W.prnt("", "Decrypting using existing session")
plaintext = session.decrypt(message)
return plaintext
parsed_plaintext = json.loads(plaintext, encoding='utf-8')
W.prnt("", "Decrypted using existing session")
return parsed_plaintext
except OlmSessionError:
pass
@ -268,7 +272,12 @@ class Olm():
try:
plaintext = session.decrypt(message)
return plaintext
parsed_plaintext = json.loads(plaintext, encoding='utf-8')
device_id = sanitize_id(parsed_plaintext["sender_device"])
self.sessions[sender][device_id].append(session)
self._store_session(sender, device_id, session)
return parsed_plaintext
except OlmSessionError:
return None
@ -300,7 +309,7 @@ class Olm():
row = cursor.fetchone()
account_pickle = row[0]
cursor.execute("select user, pickle from olmsessions")
cursor.execute("select user, device_id, pickle from olmsessions")
db_sessions = cursor.fetchall()
cursor.execute("select room_id, pickle from inbound_group_sessions")
@ -308,15 +317,15 @@ class Olm():
cursor.close()
sessions = defaultdict(list)
sessions = defaultdict(lambda: defaultdict(list))
inbound_group_sessions = defaultdict(dict)
try:
account = Account.from_pickle(bytes(account_pickle, "utf-8"))
for db_session in db_sessions:
sessions[db_session[0]].append(
Session.from_pickle(bytes(db_session[1], "utf-8")))
sessions[db_session[0]][db_session[1]].append(
Session.from_pickle(bytes(db_session[2], "utf-8")))
for db_session in db_inbound_group_sessions:
session = InboundGroupSession.from_pickle(
@ -338,11 +347,14 @@ class Olm():
def _update_sessions_in_db(self):
cursor = self.database.cursor()
for user, session_list in self.sessions.items():
for user, session_dict in self.sessions.items():
for device_id, session_list in session_dict.items():
for session in session_list:
cursor.execute("""update olmsessions set pickle=?
where user = ? and session_id = ?""",
(session.pickle(), user, session.id()))
where user = ? and session_id = ? and
device_id = ?""",
(session.pickle(), user, session.id(),
device_id))
self.database.commit()
cursor.close()
@ -359,11 +371,11 @@ class Olm():
cursor.close()
def _store_session(self, user, session):
def _store_session(self, user, device_id, session):
cursor = self.database.cursor()
cursor.execute("insert into olmsessions values(?,?,?)",
(user, session.id(), session.pickle()))
cursor.execute("insert into olmsessions values(?,?,?,?)",
(user, device_id, session.id(), session.pickle()))
self.database.commit()
@ -399,7 +411,7 @@ class Olm():
and name='olmsessions'""")
if not cursor.fetchone():
cursor.execute("""create table olmsessions (user text,
session_id text, pickle text)""")
device_id text, session_id text, pickle text)""")
database.commit()
cursor.execute("""select name from sqlite_master where type='table'

View file

@ -522,10 +522,9 @@ class MatrixSyncEvent(MatrixEvent):
return None
# TODO check sender key
parsed_plaintext = json.loads(plaintext, encoding='utf-8')
decrypted_sender = parsed_plaintext["sender"]
decrypted_recepient = parsed_plaintext["recipient"]
decrypted_recepient_key = parsed_plaintext["recipient_keys"]["ed25519"]
decrypted_sender = plaintext["sender"]
decrypted_recepient = plaintext["recipient"]
decrypted_recepient_key = plaintext["recipient_keys"]["ed25519"]
if (sender != decrypted_sender or
server.user_id != decrypted_recepient or
@ -536,10 +535,10 @@ class MatrixSyncEvent(MatrixEvent):
W.prnt("", error_message)
return None
if parsed_plaintext["type"] != "m.room_key":
if plaintext["type"] != "m.room_key":
return None
MatrixSyncEvent._handle_key_event(server, sender_key, parsed_plaintext)
MatrixSyncEvent._handle_key_event(server, sender_key, plaintext)
@staticmethod
def _handle_key_event(server, sender_key, parsed_dict):