encryption: Add a device store class for the trust database.

This commit is contained in:
poljar (Damir Jelić) 2018-05-16 18:18:02 +02:00
parent c8a7b4815d
commit dedff37a60

View file

@ -29,6 +29,11 @@ from collections import defaultdict
from functools import wraps from functools import wraps
from future.moves.itertools import zip_longest from future.moves.itertools import zip_longest
try:
FileNotFoundError
except NameError:
FileNotFoundError = IOError
import matrix.globals import matrix.globals
try: try:
@ -185,6 +190,159 @@ class EncryptionError(Exception):
pass pass
class DeviceStore(object):
def __init__(self, filename):
self._entries = []
self._filename = filename
self._load(filename)
def _load(self, filename):
# type: (str) -> None
try:
with open(filename, "r") as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
entry = StoreEntry.from_line(line)
if not entry:
continue
self._entries.append(entry)
except FileNotFoundError:
pass
def _save_store(f):
@wraps(f)
def decorated(*args, **kwargs):
self = args[0]
ret = f(*args, **kwargs)
self._save()
return ret
return decorated
def _save(self):
# type: (str) -> None
with open(self._filename, "w") as f:
for entry in self._entries:
line = entry.to_line()
f.write(line)
@_save_store
def add(self, device):
# type: (OlmDeviceKey) -> None
new_entries = StoreEntry.from_olmdevice(device)
self._entries += new_entries
# Remove duplicate entries
self._entries = list(set(self._entries))
self._save()
@_save_store
def remove(self, device):
# type: (OlmDeviceKey) -> int
removed = 0
entries = StoreEntry.from_olmdevice(device)
for entry in entries:
if entry in self._entries:
self._entries.remove(entry)
removed += 1
self._save()
return removed
def check(self, device):
# type: (OlmDeviceKey) -> bool
entries = StoreEntry.from_olmdevice(device)
result = map(lambda entry: entry in self._entries, entries)
if False in result:
return False
return True
class StoreEntry(object):
def __init__(self, user_id, device_id, key_type, key):
# type: (str, str, str, str) -> None
self.user_id = user_id
self.device_id = device_id
self.key_type = key_type
self.key = key
@classmethod
def from_line(cls, line):
# type: (str) -> StoreEntry
fields = line.split(' ')
if len(fields) < 4:
return None
user_id, device_id, key_type, key = fields[:4]
if key_type == "matrix-ed25519":
return cls(user_id, device_id, "ed25519", key)
else:
return None
@classmethod
def from_olmdevice(cls, device_key):
# type: (OlmDeviceKey) -> [StoreEntry]
entries = []
user_id = device_key.user_id
device_id = device_key.device_id
for key_type, key in device_key.keys.items():
if key_type == "ed25519":
entries.append(cls(user_id, device_id, "ed25519", key))
return entries
def to_line(self):
# type: () -> str
key_type = "matrix-{}".format(self.key_type)
line = "{} {} {} {}".format(
self.user_id,
self.device_id,
key_type,
self.key
)
return line
def __hash__(self):
# type: () -> int
return hash(str(self))
def __str__(self):
# type: () -> str
key_type = "matrix-{}".format(self.key_type)
line = "{} {} {} {}".format(
self.user_id,
self.device_id,
key_type,
self.key
)
return line
def __eq__(self, value):
# type: (StoreEntry) -> bool
if (self.user_id == value.user_id
and self.device_id == value.device_id
and self.key_type == value.key_type and self.key == value.key):
return True
return False
class OlmDeviceKey(): class OlmDeviceKey():
def __init__(self, user_id, device_id, key_dict): def __init__(self, user_id, device_id, key_dict):
# type: (str, str, Dict[str, str]) # type: (str, str, Dict[str, str])