diff --git a/matrix/encryption.py b/matrix/encryption.py index 23e4156..6af199b 100644 --- a/matrix/encryption.py +++ b/matrix/encryption.py @@ -23,6 +23,7 @@ import json # pylint: disable=redefined-builtin from builtins import str +from collections import defaultdict from functools import wraps from future.moves.itertools import zip_longest @@ -30,6 +31,9 @@ import matrix.globals try: from olm.account import Account, OlmAccountError + from olm.session import (InboundSession, OlmSessionError, OlmMessage, + OlmPreKeyMessage) + from olm.group_session import InboundGroupSession, OlmGroupSessionError except ImportError: matrix.globals.ENCRYPTION = False @@ -151,13 +155,55 @@ class EncryptionError(Exception): class Olm(): @encrypt_enabled - def __init__(self, account=None): - # type: (Server, Account) -> None + def __init__( + self, + account=None, + sessions=defaultdict(list), + group_sessions=defaultdict(dict) + ): + # type: (Account, Dict[str, List[Session]) -> None if account: self.account = account else: self.account = Account() + self.sessions = sessions + self.group_sessions = group_sessions + + def _create_session(self, sender, sender_key, message): + session = InboundSession(self.account, message, sender_key) + self.sessions[sender].append(session) + self.account.remove_one_time_keys(session) + + return session + + def create_group_session(self, room_id, session_id, session_key): + session = InboundGroupSession(session_key) + self.group_sessions[room_id][session_id] = session + + @encrypt_enabled + def decrypt(self, sender, sender_key, message): + plaintext = None + + for session in self.sessions[sender]: + try: + if isinstance(message, OlmPreKeyMessage): + if not session.matches(message): + continue + + plaintext = session.decrypt(message) + break + except OlmSessionError: + pass + + session = self._create_session(sender, sender_key, message) + + try: + plaintext = session.decrypt(message) + return plaintext + except OlmSessionError: + return None + @classmethod @encrypt_enabled def from_session_dir(cls, server): diff --git a/matrix/events.py b/matrix/events.py index 8cfdf06..10fb86e 100644 --- a/matrix/events.py +++ b/matrix/events.py @@ -28,6 +28,11 @@ from matrix.rooms import (matrix_create_room_buffer, RoomInfo, RoomMessageText, RoomMessageEvent, RoomRedactedMessageEvent, RoomMessageEmote) +try: + from olm.session import OlmMessage, OlmPreKeyMessage +except ImportError: + pass + class MatrixEvent(): @@ -444,6 +449,75 @@ class MatrixSyncEvent(MatrixEvent): return (join_infos, invite_infos) + @staticmethod + def _get_olm_device_event(server, parsed_dict): + device_key = server.olm.account.device_keys["curve25519"] + + if device_key not in parsed_dict["content"]["ciphertext"]: + return None + + ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key) + sender = sanitize_id(parsed_dict["sender"]) + sender_key = sanitize_id(parsed_dict["sender_key"]) + + message = None + + if ciphertext["type"] == 0: + message = OlmPreKeyMessage(parsed_dict["body"]) + elif ciphertext["type"] == 1: + message = OlmMessage(parsed_dict["body"]) + else: + raise ValueError("Invalid Olm message type") + + olm = server.olm + plaintext = olm.decrypt(sender, sender_key, message) + + # TODO check sender key + decrypted_sender = sanitize_id(plaintext["sender"]) + decrypted_recepient = sanitize_id(plaintext["recipient"]) + decrypted_recepient_key = sanitize_id( + plaintext["recipient_keys"]["ed25519"]) + + if (sender != decrypted_sender or + server.user_id != decrypted_recepient or + device_key != decrypted_recepient_key): + error_message = ("{prefix}matrix: Mismatch in decrypted Olm" + "message").format(W.prefix("error")) + W.prnt("", error_message) + return None + + if plaintext["type"] != "m.room.key": + return None + + MatrixSyncEvent._handle_key_event(server, sender_key, plaintext) + + @staticmethod + def _handle_key_event(server, sender_key, parsed_dict): + # type: (MatrixServer, str, Dict[Any, Any] -> None + olm = server.olm + content = parsed_dict.pop("content") + + if content["type"] != "m.megolm.v1.aes-sha2": + return + + room_id = sanitize_id(content["room_id"]) + session_id = sanitize_id(content["session_id"]) + session_key = sanitize_id(content["session_key"]) + + if session_id in olm.group_sessions[room_id]: + return + + olm.create_group_session(room_id, session_id, session_key) + + @staticmethod + def _get_to_device_events(server, parsed_dict): + # type: (MatrixServer, Dict[Any, Any]) -> None + for event in parsed_dict["events"]: + if event["type"] == "m.room.encrypted": + if (event["content"]["algorithm"] == + 'm.olm.v1.curve25519-aes-sha2'): + MatrixSyncEvent._get_olm_device_event(server, parsed_dict) + @classmethod def from_dict(cls, server, parsed_dict): try: @@ -456,6 +530,9 @@ class MatrixSyncEvent(MatrixEvent): one_time_key_count = ( parsed_dict["device_one_time_keys_count"]["signed_curve25519"]) + MatrixSyncEvent._get_to_device_events( + server.olm, parsed_dict.pop("to_device")) + room_info_dict = parsed_dict["rooms"] join_infos, invite_infos = MatrixSyncEvent._infos_from_dict(