encryption: Change the way olm sessions are stored.

This commit is contained in:
poljar (Damir Jelić) 2018-05-06 15:03:50 +02:00
parent 658ec67ff4
commit 8f7dac4a0d
2 changed files with 43 additions and 32 deletions

View file

@ -39,6 +39,7 @@ except ImportError:
matrix.globals.ENCRYPTION = False matrix.globals.ENCRYPTION = False
from matrix.globals import W, SERVERS from matrix.globals import W, SERVERS
from matrix.utils import sanitize_id
from matrix.utf import utf8_decode from matrix.utf import utf8_decode
@ -221,7 +222,7 @@ class Olm():
self._insert_acc_to_db() self._insert_acc_to_db()
if not sessions: if not sessions:
sessions = defaultdict(list) sessions = defaultdict(lambda: defaultdict(list))
if not inbound_group_sessions: if not inbound_group_sessions:
inbound_group_sessions = defaultdict(dict) inbound_group_sessions = defaultdict(dict)
@ -233,8 +234,6 @@ class Olm():
W.prnt("", "matrix: Creating session for {}".format(sender)) W.prnt("", "matrix: Creating session for {}".format(sender))
session = InboundSession(self.account, message, sender_key) session = InboundSession(self.account, message, sender_key)
W.prnt("", "matrix: Created session for {}".format(sender)) W.prnt("", "matrix: Created session for {}".format(sender))
self.sessions[sender].append(session)
self._store_session(sender, session)
self.account.remove_one_time_keys(session) self.account.remove_one_time_keys(session)
self._update_acc_in_db() self._update_acc_in_db()
@ -250,16 +249,21 @@ class Olm():
def decrypt(self, sender, sender_key, message): def decrypt(self, sender, sender_key, message):
plaintext = None plaintext = None
for session in self.sessions[sender]: for device_id, session_list in self.sessions[sender].items():
try: for session in session_list:
if isinstance(message, OlmPreKeyMessage): W.prnt("", "Trying session for device {}".format(device_id))
if not session.matches(message): try:
continue if isinstance(message, OlmPreKeyMessage):
if not session.matches(message):
continue
plaintext = session.decrypt(message) W.prnt("", "Decrypting using existing session")
return plaintext plaintext = session.decrypt(message)
except OlmSessionError: parsed_plaintext = json.loads(plaintext, encoding='utf-8')
pass W.prnt("", "Decrypted using existing session")
return parsed_plaintext
except OlmSessionError:
pass
try: try:
session = self._create_session(sender, sender_key, message) session = self._create_session(sender, sender_key, message)
@ -268,7 +272,12 @@ class Olm():
try: try:
plaintext = session.decrypt(message) plaintext = session.decrypt(message)
return plaintext parsed_plaintext = json.loads(plaintext, encoding='utf-8')
device_id = sanitize_id(parsed_plaintext["sender_device"])
self.sessions[sender][device_id].append(session)
self._store_session(sender, device_id, session)
return parsed_plaintext
except OlmSessionError: except OlmSessionError:
return None return None
@ -300,7 +309,7 @@ class Olm():
row = cursor.fetchone() row = cursor.fetchone()
account_pickle = row[0] account_pickle = row[0]
cursor.execute("select user, pickle from olmsessions") cursor.execute("select user, device_id, pickle from olmsessions")
db_sessions = cursor.fetchall() db_sessions = cursor.fetchall()
cursor.execute("select room_id, pickle from inbound_group_sessions") cursor.execute("select room_id, pickle from inbound_group_sessions")
@ -308,15 +317,15 @@ class Olm():
cursor.close() cursor.close()
sessions = defaultdict(list) sessions = defaultdict(lambda: defaultdict(list))
inbound_group_sessions = defaultdict(dict) inbound_group_sessions = defaultdict(dict)
try: try:
account = Account.from_pickle(bytes(account_pickle, "utf-8")) account = Account.from_pickle(bytes(account_pickle, "utf-8"))
for db_session in db_sessions: for db_session in db_sessions:
sessions[db_session[0]].append( sessions[db_session[0]][db_session[1]].append(
Session.from_pickle(bytes(db_session[1], "utf-8"))) Session.from_pickle(bytes(db_session[2], "utf-8")))
for db_session in db_inbound_group_sessions: for db_session in db_inbound_group_sessions:
session = InboundGroupSession.from_pickle( session = InboundGroupSession.from_pickle(
@ -338,11 +347,14 @@ class Olm():
def _update_sessions_in_db(self): def _update_sessions_in_db(self):
cursor = self.database.cursor() cursor = self.database.cursor()
for user, session_list in self.sessions.items(): for user, session_dict in self.sessions.items():
for session in session_list: for device_id, session_list in session_dict.items():
cursor.execute("""update olmsessions set pickle=? for session in session_list:
where user = ? and session_id = ?""", cursor.execute("""update olmsessions set pickle=?
(session.pickle(), user, session.id())) where user = ? and session_id = ? and
device_id = ?""",
(session.pickle(), user, session.id(),
device_id))
self.database.commit() self.database.commit()
cursor.close() cursor.close()
@ -359,11 +371,11 @@ class Olm():
cursor.close() cursor.close()
def _store_session(self, user, session): def _store_session(self, user, device_id, session):
cursor = self.database.cursor() cursor = self.database.cursor()
cursor.execute("insert into olmsessions values(?,?,?)", cursor.execute("insert into olmsessions values(?,?,?,?)",
(user, session.id(), session.pickle())) (user, device_id, session.id(), session.pickle()))
self.database.commit() self.database.commit()
@ -399,7 +411,7 @@ class Olm():
and name='olmsessions'""") and name='olmsessions'""")
if not cursor.fetchone(): if not cursor.fetchone():
cursor.execute("""create table olmsessions (user text, cursor.execute("""create table olmsessions (user text,
session_id text, pickle text)""") device_id text, session_id text, pickle text)""")
database.commit() database.commit()
cursor.execute("""select name from sqlite_master where type='table' cursor.execute("""select name from sqlite_master where type='table'

View file

@ -522,10 +522,9 @@ class MatrixSyncEvent(MatrixEvent):
return None return None
# TODO check sender key # TODO check sender key
parsed_plaintext = json.loads(plaintext, encoding='utf-8') decrypted_sender = plaintext["sender"]
decrypted_sender = parsed_plaintext["sender"] decrypted_recepient = plaintext["recipient"]
decrypted_recepient = parsed_plaintext["recipient"] decrypted_recepient_key = plaintext["recipient_keys"]["ed25519"]
decrypted_recepient_key = parsed_plaintext["recipient_keys"]["ed25519"]
if (sender != decrypted_sender or if (sender != decrypted_sender or
server.user_id != decrypted_recepient or server.user_id != decrypted_recepient or
@ -536,10 +535,10 @@ class MatrixSyncEvent(MatrixEvent):
W.prnt("", error_message) W.prnt("", error_message)
return None return None
if parsed_plaintext["type"] != "m.room_key": if plaintext["type"] != "m.room_key":
return None return None
MatrixSyncEvent._handle_key_event(server, sender_key, parsed_plaintext) MatrixSyncEvent._handle_key_event(server, sender_key, plaintext)
@staticmethod @staticmethod
def _handle_key_event(server, sender_key, parsed_dict): def _handle_key_event(server, sender_key, parsed_dict):