From dedff37a60e9433c327f2ad8a5f37fca77e16b68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?poljar=20=28Damir=20Jeli=C4=87=29?= Date: Wed, 16 May 2018 18:18:02 +0200 Subject: [PATCH] encryption: Add a device store class for the trust database. --- matrix/encryption.py | 158 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) diff --git a/matrix/encryption.py b/matrix/encryption.py index 70f460d..37111ab 100644 --- a/matrix/encryption.py +++ b/matrix/encryption.py @@ -29,6 +29,11 @@ from collections import defaultdict from functools import wraps from future.moves.itertools import zip_longest +try: + FileNotFoundError +except NameError: + FileNotFoundError = IOError + import matrix.globals try: @@ -185,6 +190,159 @@ class EncryptionError(Exception): pass +class DeviceStore(object): + def __init__(self, filename): + self._entries = [] + self._filename = filename + + self._load(filename) + + def _load(self, filename): + # type: (str) -> None + try: + with open(filename, "r") as f: + for line in f: + line = line.strip() + + if not line or line.startswith("#"): + continue + + entry = StoreEntry.from_line(line) + + if not entry: + continue + + self._entries.append(entry) + except FileNotFoundError: + pass + + def _save_store(f): + @wraps(f) + def decorated(*args, **kwargs): + self = args[0] + ret = f(*args, **kwargs) + self._save() + return ret + + return decorated + + def _save(self): + # type: (str) -> None + with open(self._filename, "w") as f: + for entry in self._entries: + line = entry.to_line() + f.write(line) + + @_save_store + def add(self, device): + # type: (OlmDeviceKey) -> None + new_entries = StoreEntry.from_olmdevice(device) + self._entries += new_entries + + # Remove duplicate entries + self._entries = list(set(self._entries)) + + self._save() + + @_save_store + def remove(self, device): + # type: (OlmDeviceKey) -> int + removed = 0 + entries = StoreEntry.from_olmdevice(device) + + for entry in entries: + if entry in self._entries: + self._entries.remove(entry) + removed += 1 + + self._save() + + return removed + + def check(self, device): + # type: (OlmDeviceKey) -> bool + entries = StoreEntry.from_olmdevice(device) + result = map(lambda entry: entry in self._entries, entries) + + if False in result: + return False + + return True + + +class StoreEntry(object): + def __init__(self, user_id, device_id, key_type, key): + # type: (str, str, str, str) -> None + self.user_id = user_id + self.device_id = device_id + self.key_type = key_type + self.key = key + + @classmethod + def from_line(cls, line): + # type: (str) -> StoreEntry + fields = line.split(' ') + + if len(fields) < 4: + return None + + user_id, device_id, key_type, key = fields[:4] + + if key_type == "matrix-ed25519": + return cls(user_id, device_id, "ed25519", key) + else: + return None + + @classmethod + def from_olmdevice(cls, device_key): + # type: (OlmDeviceKey) -> [StoreEntry] + entries = [] + + user_id = device_key.user_id + device_id = device_key.device_id + + for key_type, key in device_key.keys.items(): + if key_type == "ed25519": + entries.append(cls(user_id, device_id, "ed25519", key)) + + return entries + + def to_line(self): + # type: () -> str + key_type = "matrix-{}".format(self.key_type) + line = "{} {} {} {}".format( + self.user_id, + self.device_id, + key_type, + self.key + ) + return line + + def __hash__(self): + # type: () -> int + return hash(str(self)) + + def __str__(self): + # type: () -> str + key_type = "matrix-{}".format(self.key_type) + line = "{} {} {} {}".format( + self.user_id, + self.device_id, + key_type, + self.key + ) + return line + + def __eq__(self, value): + # type: (StoreEntry) -> bool + if (self.user_id == value.user_id + and self.device_id == value.device_id + and self.key_type == value.key_type and self.key == value.key): + return True + + return False + + class OlmDeviceKey(): def __init__(self, user_id, device_id, key_dict): # type: (str, str, Dict[str, str])