From 1fd5bd637d14d81414bdc016d295e01ddc25139f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?poljar=20=28Damir=20Jeli=C4=87=29?= Date: Tue, 3 Apr 2018 21:51:03 +0200 Subject: [PATCH] encryption: Decrypt incomming messages. --- matrix/encryption.py | 20 +++++++++++++++++- matrix/events.py | 50 ++++++++++++++++++++++++-------------------- matrix/rooms.py | 33 +++++++++++++++++++++++------ 3 files changed, 73 insertions(+), 30 deletions(-) diff --git a/matrix/encryption.py b/matrix/encryption.py index 6af199b..dd077fa 100644 --- a/matrix/encryption.py +++ b/matrix/encryption.py @@ -171,15 +171,20 @@ class Olm(): self.group_sessions = group_sessions def _create_session(self, sender, sender_key, message): + W.prnt("", "matrix: Creating session for {}".format(sender)) session = InboundSession(self.account, message, sender_key) + W.prnt("", "matrix: Created session for {}".format(sender)) self.sessions[sender].append(session) - self.account.remove_one_time_keys(session) + # self.account.remove_one_time_keys(session) + # TODO store account here return session def create_group_session(self, room_id, session_id, session_key): + W.prnt("", "matrix: Creating group session for {}".format(room_id)) session = InboundGroupSession(session_key) self.group_sessions[room_id][session_id] = session + # TODO store account here @encrypt_enabled def decrypt(self, sender, sender_key, message): @@ -204,6 +209,19 @@ class Olm(): except OlmSessionError: return None + @encrypt_enabled + def group_decrypt(self, room_id, session_id, ciphertext): + if session_id not in self.group_sessions[room_id]: + return None + + session = self.group_sessions[room_id][session_id] + try: + plaintext = session.decrypt(ciphertext) + except OlmGroupSessionError: + return None + + return plaintext + @classmethod @encrypt_enabled def from_session_dir(cls, server): diff --git a/matrix/events.py b/matrix/events.py index 10fb86e..ad576a7 100644 --- a/matrix/events.py +++ b/matrix/events.py @@ -17,6 +17,8 @@ from __future__ import unicode_literals from builtins import str +import json + from collections import deque from functools import partial from operator import itemgetter @@ -437,7 +439,7 @@ class MatrixSyncEvent(MatrixEvent): MatrixEvent.__init__(self, server) @staticmethod - def _infos_from_dict(parsed_dict): + def _infos_from_dict(olm, parsed_dict): join_infos = [] invite_infos = [] @@ -445,27 +447,28 @@ class MatrixSyncEvent(MatrixEvent): if not room_id: continue - join_infos.append(RoomInfo.from_dict(room_id, room_dict)) + join_infos.append(RoomInfo.from_dict(olm, room_id, room_dict)) return (join_infos, invite_infos) @staticmethod def _get_olm_device_event(server, parsed_dict): - device_key = server.olm.account.device_keys["curve25519"] + device_key = server.olm.account.identity_keys()["curve25519"] if device_key not in parsed_dict["content"]["ciphertext"]: return None - ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key) sender = sanitize_id(parsed_dict["sender"]) - sender_key = sanitize_id(parsed_dict["sender_key"]) + sender_key = sanitize_id(parsed_dict["content"]["sender_key"]) + + ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key) message = None if ciphertext["type"] == 0: - message = OlmPreKeyMessage(parsed_dict["body"]) + message = OlmPreKeyMessage(ciphertext["body"]) elif ciphertext["type"] == 1: - message = OlmMessage(parsed_dict["body"]) + message = OlmMessage(ciphertext["body"]) else: raise ValueError("Invalid Olm message type") @@ -473,23 +476,24 @@ class MatrixSyncEvent(MatrixEvent): plaintext = olm.decrypt(sender, sender_key, message) # TODO check sender key - decrypted_sender = sanitize_id(plaintext["sender"]) - decrypted_recepient = sanitize_id(plaintext["recipient"]) - decrypted_recepient_key = sanitize_id( - plaintext["recipient_keys"]["ed25519"]) + parsed_plaintext = json.loads(plaintext, encoding='utf-8') + decrypted_sender = parsed_plaintext["sender"] + decrypted_recepient = parsed_plaintext["recipient"] + decrypted_recepient_key = parsed_plaintext["recipient_keys"]["ed25519"] if (sender != decrypted_sender or server.user_id != decrypted_recepient or - device_key != decrypted_recepient_key): - error_message = ("{prefix}matrix: Mismatch in decrypted Olm" - "message").format(W.prefix("error")) + olm.account.identity_keys()["ed25519"] != + decrypted_recepient_key): + error_message = ("{prefix}matrix: Mismatch in decrypted Olm " + "message").format(prefix=W.prefix("error")) W.prnt("", error_message) return None - if plaintext["type"] != "m.room.key": + if parsed_plaintext["type"] != "m.room_key": return None - MatrixSyncEvent._handle_key_event(server, sender_key, plaintext) + MatrixSyncEvent._handle_key_event(server, sender_key, parsed_plaintext) @staticmethod def _handle_key_event(server, sender_key, parsed_dict): @@ -497,12 +501,12 @@ class MatrixSyncEvent(MatrixEvent): olm = server.olm content = parsed_dict.pop("content") - if content["type"] != "m.megolm.v1.aes-sha2": + if content["algorithm"] != "m.megolm.v1.aes-sha2": return - room_id = sanitize_id(content["room_id"]) - session_id = sanitize_id(content["session_id"]) - session_key = sanitize_id(content["session_key"]) + room_id = content["room_id"] + session_id = content["session_id"] + session_key = content["session_key"] if session_id in olm.group_sessions[room_id]: return @@ -516,7 +520,7 @@ class MatrixSyncEvent(MatrixEvent): if event["type"] == "m.room.encrypted": if (event["content"]["algorithm"] == 'm.olm.v1.curve25519-aes-sha2'): - MatrixSyncEvent._get_olm_device_event(server, parsed_dict) + MatrixSyncEvent._get_olm_device_event(server, event) @classmethod def from_dict(cls, server, parsed_dict): @@ -531,12 +535,12 @@ class MatrixSyncEvent(MatrixEvent): parsed_dict["device_one_time_keys_count"]["signed_curve25519"]) MatrixSyncEvent._get_to_device_events( - server.olm, parsed_dict.pop("to_device")) + server, parsed_dict.pop("to_device")) room_info_dict = parsed_dict["rooms"] join_infos, invite_infos = MatrixSyncEvent._infos_from_dict( - room_info_dict) + server.olm, room_info_dict) return cls(server, next_batch, join_infos, invite_infos, one_time_key_count) diff --git a/matrix/rooms.py b/matrix/rooms.py index b0dc06f..ff7f7c0 100644 --- a/matrix/rooms.py +++ b/matrix/rooms.py @@ -17,6 +17,8 @@ from __future__ import unicode_literals from builtins import str +import json + from pprint import pformat from collections import namedtuple, deque @@ -243,7 +245,7 @@ class RoomInfo(): return None, None @staticmethod - def parse_event(event_dict): + def parse_event(olm, room_id, event_dict): # type: (Dict[Any, Any]) -> (RoomEvent, RoomEvent) state_event = None message_event = None @@ -270,11 +272,30 @@ class RoomInfo(): state_event = RoomAliasEvent.from_dict(event_dict) elif event_dict["type"] == "m.room.encryption": state_event = RoomEncryptionEvent.from_dict(event_dict) + elif event_dict["type"] == "m.room.encrypted": + state_event, message_event = RoomInfo._decrypt_event(olm, room_id, + event_dict) return state_event, message_event @staticmethod - def _parse_events(parsed_dict, messages=True, state=True): + def _decrypt_event(olm, room_id, event_dict): + session_id = event_dict["content"]["session_id"] + ciphertext = event_dict["content"]["ciphertext"] + plaintext = olm.group_decrypt(room_id, session_id, ciphertext) + + if not plaintext: + return None, None + + parsed_plaintext = json.loads(plaintext, encoding="utf-8") + + event_dict["content"] = parsed_plaintext["content"] + event_dict["type"] = parsed_plaintext["type"] + + return RoomInfo.parse_event(olm, room_id, event_dict) + + @staticmethod + def _parse_events(olm, room_id, parsed_dict, messages=True, state=True): state_events = [] message_events = [] @@ -283,7 +304,7 @@ class RoomInfo(): try: for event in parsed_dict: - m_event, s_event = RoomInfo.parse_event(event) + m_event, s_event = RoomInfo.parse_event(olm, room_id, event) state_events.append(m_event) message_events.append(s_event) except (ValueError, TypeError, KeyError) as error: @@ -306,14 +327,14 @@ class RoomInfo(): return events @classmethod - def from_dict(cls, room_id, parsed_dict): + def from_dict(cls, olm, room_id, parsed_dict): prev_batch = sanitize_id(parsed_dict['timeline']['prev_batch']) state_dict = parsed_dict['state']['events'] timeline_dict = parsed_dict['timeline']['events'] - state_events = RoomInfo._parse_events(state_dict, messages=False) - timeline_events = RoomInfo._parse_events(timeline_dict) + state_events = RoomInfo._parse_events(olm, room_id, state_dict, messages=False) + timeline_events = RoomInfo._parse_events(olm, room_id, timeline_dict) events = state_events + timeline_events