diff --git a/matrix/api.py b/matrix/api.py index ae4539f..dfb775a 100644 --- a/matrix/api.py +++ b/matrix/api.py @@ -279,6 +279,20 @@ class MatrixClient: return HttpRequest(RequestType.POST, self.host, path, content) + def keys_query(self, users): + query_parameters = {"access_token": self.access_token} + + path = ("{api}/keys/query?" + "{query_parameters}").format( + api=MATRIX_API_PATH, + query_parameters=urlencode(query_parameters)) + + content = { + "device_keys": {user: {} for user in users} + } + + return HttpRequest(RequestType.POST, self.host, path, content) + def mxc_to_http(self, mxc): # type: (str) -> str url = urlparse(mxc) @@ -588,3 +602,19 @@ class MatrixKeyUploadMessage(MatrixMessage): server, self.device_keys) return self._decode(server, object_hook) + + +class MatrixKeyQueryMessage(MatrixMessage): + + def __init__(self, client, users): + data = { + "users": users, + } + + MatrixMessage.__init__(self, client.keys_query, data) + + def decode_body(self, server): + object_hook = partial(MatrixEvents.MatrixKeyQueryEvent.from_dict, + server) + + return self._decode(server, object_hook) diff --git a/matrix/encryption.py b/matrix/encryption.py index f2c826b..acaef9c 100644 --- a/matrix/encryption.py +++ b/matrix/encryption.py @@ -153,6 +153,14 @@ class EncryptionError(Exception): pass +class OlmDeviceKey(): + def __init__(self, user_id, device_id, key_dict): + # type: (str, str, Dict[str, str]) + self.user_id = user_id + self.device_id = device_id + self.keys = key_dict + + class Olm(): @encrypt_enabled @@ -171,6 +179,7 @@ class Olm(): self.device_id = device_id self.session_path = session_path self.database = database + self.device_keys = {} if not database: db_file = "{}_{}.db".format(user, device_id) diff --git a/matrix/events.py b/matrix/events.py index 4af350a..f79d048 100644 --- a/matrix/events.py +++ b/matrix/events.py @@ -19,7 +19,7 @@ from builtins import str import json -from collections import deque +from collections import deque, defaultdict from functools import partial from operator import itemgetter @@ -30,6 +30,8 @@ from matrix.rooms import (matrix_create_room_buffer, RoomInfo, RoomMessageText, RoomMessageEvent, RoomRedactedMessageEvent, RoomMessageEmote) +from matrix.encryption import OlmDeviceKey + try: from olm.session import OlmMessage, OlmPreKeyMessage except ImportError: @@ -307,6 +309,47 @@ class MatrixKickEvent(MatrixEvent): False, parsed_dict) +class MatrixKeyQueryEvent(MatrixEvent): + + def __init__(self, server, keys): + self.keys = keys + MatrixEvent.__init__(self, server) + + @staticmethod + def _get_keys(key_dict): + keys = {} + + for key_type, key in key_dict.items(): + key_type, _ = key_type.split(":") + keys[key_type] = key + + return keys + + @classmethod + def from_dict(cls, server, parsed_dict): + keys = defaultdict(list) + try: + for user_id, device_dict in parsed_dict["device_keys"].items(): + for device_id, key_dict in device_dict.items(): + device_keys = MatrixKeyQueryEvent._get_keys( + key_dict.pop("keys")) + keys[user_id].append(OlmDeviceKey(user_id, device_id, + device_keys)) + return cls(server, keys) + except KeyError: + return MatrixErrorEvent.from_dict(server, "Error kicking user", + False, parsed_dict) + + def execute(self): + olm = self.server.olm + + if olm.device_keys == self.keys: + return + + olm.device_keys = self.keys + # TODO invalidate megolm sessions for rooms that got new devices + + class MatrixBacklogEvent(MatrixEvent): def __init__(self, server, room_id, end_token, events): diff --git a/matrix/server.py b/matrix/server.py index f5517fa..f9cb46a 100644 --- a/matrix/server.py +++ b/matrix/server.py @@ -37,7 +37,8 @@ from matrix.api import ( MatrixClient, MatrixSyncMessage, MatrixLoginMessage, - MatrixKeyUploadMessage + MatrixKeyUploadMessage, + MatrixKeyQueryMessage ) from matrix.encryption import Olm, EncryptionError, encrypt_enabled @@ -92,6 +93,7 @@ class MatrixServer: self.send_fd_hook = None # type: weechat.hook self.send_buffer = b"" # type: bytes self.current_message = None # type: MatrixMessage + self.device_check_timestamp = None self.http_parser = HttpParser() # type: HttpParser self.http_buffer = [] # type: List[bytes] @@ -497,6 +499,21 @@ class MatrixServer: self.olm.account.generate_one_time_keys(key_count) self.upload_keys(device_keys=False, one_time_keys=True) + @encrypt_enabled + def query_keys(self): + users = [] + + for room in self.rooms.values(): + if not room.encrypted: + continue + users += list(room.users) + + if not users: + return + + message = MatrixKeyQueryMessage(self.client, users) + self.send_queue.append(message) + def login(self): # type: (MatrixServer) -> None message = MatrixLoginMessage(self.client, self.user, self.password, @@ -705,6 +722,20 @@ def matrix_timer_cb(server_name, remaining_calls): server.send_queue.appendleft(message) break + if not server.next_batch: + return W.WEECHAT_RC_OK + + # check for new devices by users in encrypted rooms periodically + if (not server.device_check_timestamp or + current_time - server.device_check_timestamp > 600): + + W.prnt(server.server_buffer, + "{prefix}matrix: Querying user devices.".format( + prefix=W.prefix("networ"))) + + server.query_keys() + server.device_check_timestamp = current_time + return W.WEECHAT_RC_OK