encryption: Change the way olm sessions are stored.
This commit is contained in:
parent
658ec67ff4
commit
8f7dac4a0d
2 changed files with 43 additions and 32 deletions
|
@ -39,6 +39,7 @@ except ImportError:
|
||||||
matrix.globals.ENCRYPTION = False
|
matrix.globals.ENCRYPTION = False
|
||||||
|
|
||||||
from matrix.globals import W, SERVERS
|
from matrix.globals import W, SERVERS
|
||||||
|
from matrix.utils import sanitize_id
|
||||||
from matrix.utf import utf8_decode
|
from matrix.utf import utf8_decode
|
||||||
|
|
||||||
|
|
||||||
|
@ -221,7 +222,7 @@ class Olm():
|
||||||
self._insert_acc_to_db()
|
self._insert_acc_to_db()
|
||||||
|
|
||||||
if not sessions:
|
if not sessions:
|
||||||
sessions = defaultdict(list)
|
sessions = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
if not inbound_group_sessions:
|
if not inbound_group_sessions:
|
||||||
inbound_group_sessions = defaultdict(dict)
|
inbound_group_sessions = defaultdict(dict)
|
||||||
|
@ -233,8 +234,6 @@ class Olm():
|
||||||
W.prnt("", "matrix: Creating session for {}".format(sender))
|
W.prnt("", "matrix: Creating session for {}".format(sender))
|
||||||
session = InboundSession(self.account, message, sender_key)
|
session = InboundSession(self.account, message, sender_key)
|
||||||
W.prnt("", "matrix: Created session for {}".format(sender))
|
W.prnt("", "matrix: Created session for {}".format(sender))
|
||||||
self.sessions[sender].append(session)
|
|
||||||
self._store_session(sender, session)
|
|
||||||
self.account.remove_one_time_keys(session)
|
self.account.remove_one_time_keys(session)
|
||||||
self._update_acc_in_db()
|
self._update_acc_in_db()
|
||||||
|
|
||||||
|
@ -250,16 +249,21 @@ class Olm():
|
||||||
def decrypt(self, sender, sender_key, message):
|
def decrypt(self, sender, sender_key, message):
|
||||||
plaintext = None
|
plaintext = None
|
||||||
|
|
||||||
for session in self.sessions[sender]:
|
for device_id, session_list in self.sessions[sender].items():
|
||||||
try:
|
for session in session_list:
|
||||||
if isinstance(message, OlmPreKeyMessage):
|
W.prnt("", "Trying session for device {}".format(device_id))
|
||||||
if not session.matches(message):
|
try:
|
||||||
continue
|
if isinstance(message, OlmPreKeyMessage):
|
||||||
|
if not session.matches(message):
|
||||||
|
continue
|
||||||
|
|
||||||
plaintext = session.decrypt(message)
|
W.prnt("", "Decrypting using existing session")
|
||||||
return plaintext
|
plaintext = session.decrypt(message)
|
||||||
except OlmSessionError:
|
parsed_plaintext = json.loads(plaintext, encoding='utf-8')
|
||||||
pass
|
W.prnt("", "Decrypted using existing session")
|
||||||
|
return parsed_plaintext
|
||||||
|
except OlmSessionError:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session = self._create_session(sender, sender_key, message)
|
session = self._create_session(sender, sender_key, message)
|
||||||
|
@ -268,7 +272,12 @@ class Olm():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
plaintext = session.decrypt(message)
|
plaintext = session.decrypt(message)
|
||||||
return plaintext
|
parsed_plaintext = json.loads(plaintext, encoding='utf-8')
|
||||||
|
|
||||||
|
device_id = sanitize_id(parsed_plaintext["sender_device"])
|
||||||
|
self.sessions[sender][device_id].append(session)
|
||||||
|
self._store_session(sender, device_id, session)
|
||||||
|
return parsed_plaintext
|
||||||
except OlmSessionError:
|
except OlmSessionError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -300,7 +309,7 @@ class Olm():
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
account_pickle = row[0]
|
account_pickle = row[0]
|
||||||
|
|
||||||
cursor.execute("select user, pickle from olmsessions")
|
cursor.execute("select user, device_id, pickle from olmsessions")
|
||||||
db_sessions = cursor.fetchall()
|
db_sessions = cursor.fetchall()
|
||||||
|
|
||||||
cursor.execute("select room_id, pickle from inbound_group_sessions")
|
cursor.execute("select room_id, pickle from inbound_group_sessions")
|
||||||
|
@ -308,15 +317,15 @@ class Olm():
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
sessions = defaultdict(list)
|
sessions = defaultdict(lambda: defaultdict(list))
|
||||||
inbound_group_sessions = defaultdict(dict)
|
inbound_group_sessions = defaultdict(dict)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
account = Account.from_pickle(bytes(account_pickle, "utf-8"))
|
account = Account.from_pickle(bytes(account_pickle, "utf-8"))
|
||||||
|
|
||||||
for db_session in db_sessions:
|
for db_session in db_sessions:
|
||||||
sessions[db_session[0]].append(
|
sessions[db_session[0]][db_session[1]].append(
|
||||||
Session.from_pickle(bytes(db_session[1], "utf-8")))
|
Session.from_pickle(bytes(db_session[2], "utf-8")))
|
||||||
|
|
||||||
for db_session in db_inbound_group_sessions:
|
for db_session in db_inbound_group_sessions:
|
||||||
session = InboundGroupSession.from_pickle(
|
session = InboundGroupSession.from_pickle(
|
||||||
|
@ -338,11 +347,14 @@ class Olm():
|
||||||
def _update_sessions_in_db(self):
|
def _update_sessions_in_db(self):
|
||||||
cursor = self.database.cursor()
|
cursor = self.database.cursor()
|
||||||
|
|
||||||
for user, session_list in self.sessions.items():
|
for user, session_dict in self.sessions.items():
|
||||||
for session in session_list:
|
for device_id, session_list in session_dict.items():
|
||||||
cursor.execute("""update olmsessions set pickle=?
|
for session in session_list:
|
||||||
where user = ? and session_id = ?""",
|
cursor.execute("""update olmsessions set pickle=?
|
||||||
(session.pickle(), user, session.id()))
|
where user = ? and session_id = ? and
|
||||||
|
device_id = ?""",
|
||||||
|
(session.pickle(), user, session.id(),
|
||||||
|
device_id))
|
||||||
self.database.commit()
|
self.database.commit()
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
@ -359,11 +371,11 @@ class Olm():
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
def _store_session(self, user, session):
|
def _store_session(self, user, device_id, session):
|
||||||
cursor = self.database.cursor()
|
cursor = self.database.cursor()
|
||||||
|
|
||||||
cursor.execute("insert into olmsessions values(?,?,?)",
|
cursor.execute("insert into olmsessions values(?,?,?,?)",
|
||||||
(user, session.id(), session.pickle()))
|
(user, device_id, session.id(), session.pickle()))
|
||||||
|
|
||||||
self.database.commit()
|
self.database.commit()
|
||||||
|
|
||||||
|
@ -399,7 +411,7 @@ class Olm():
|
||||||
and name='olmsessions'""")
|
and name='olmsessions'""")
|
||||||
if not cursor.fetchone():
|
if not cursor.fetchone():
|
||||||
cursor.execute("""create table olmsessions (user text,
|
cursor.execute("""create table olmsessions (user text,
|
||||||
session_id text, pickle text)""")
|
device_id text, session_id text, pickle text)""")
|
||||||
database.commit()
|
database.commit()
|
||||||
|
|
||||||
cursor.execute("""select name from sqlite_master where type='table'
|
cursor.execute("""select name from sqlite_master where type='table'
|
||||||
|
|
|
@ -522,10 +522,9 @@ class MatrixSyncEvent(MatrixEvent):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# TODO check sender key
|
# TODO check sender key
|
||||||
parsed_plaintext = json.loads(plaintext, encoding='utf-8')
|
decrypted_sender = plaintext["sender"]
|
||||||
decrypted_sender = parsed_plaintext["sender"]
|
decrypted_recepient = plaintext["recipient"]
|
||||||
decrypted_recepient = parsed_plaintext["recipient"]
|
decrypted_recepient_key = plaintext["recipient_keys"]["ed25519"]
|
||||||
decrypted_recepient_key = parsed_plaintext["recipient_keys"]["ed25519"]
|
|
||||||
|
|
||||||
if (sender != decrypted_sender or
|
if (sender != decrypted_sender or
|
||||||
server.user_id != decrypted_recepient or
|
server.user_id != decrypted_recepient or
|
||||||
|
@ -536,10 +535,10 @@ class MatrixSyncEvent(MatrixEvent):
|
||||||
W.prnt("", error_message)
|
W.prnt("", error_message)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if parsed_plaintext["type"] != "m.room_key":
|
if plaintext["type"] != "m.room_key":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
MatrixSyncEvent._handle_key_event(server, sender_key, parsed_plaintext)
|
MatrixSyncEvent._handle_key_event(server, sender_key, plaintext)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_key_event(server, sender_key, parsed_dict):
|
def _handle_key_event(server, sender_key, parsed_dict):
|
||||||
|
|
Loading…
Add table
Reference in a new issue