encryption: Add a device store class for the trust database.
This commit is contained in:
parent
c8a7b4815d
commit
dedff37a60
1 changed files with 158 additions and 0 deletions
|
@ -29,6 +29,11 @@ from collections import defaultdict
|
|||
from functools import wraps
|
||||
from future.moves.itertools import zip_longest
|
||||
|
||||
try:
|
||||
FileNotFoundError
|
||||
except NameError:
|
||||
FileNotFoundError = IOError
|
||||
|
||||
import matrix.globals
|
||||
|
||||
try:
|
||||
|
@ -185,6 +190,159 @@ class EncryptionError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class DeviceStore(object):
|
||||
def __init__(self, filename):
|
||||
self._entries = []
|
||||
self._filename = filename
|
||||
|
||||
self._load(filename)
|
||||
|
||||
def _load(self, filename):
|
||||
# type: (str) -> None
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
entry = StoreEntry.from_line(line)
|
||||
|
||||
if not entry:
|
||||
continue
|
||||
|
||||
self._entries.append(entry)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def _save_store(f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
self = args[0]
|
||||
ret = f(*args, **kwargs)
|
||||
self._save()
|
||||
return ret
|
||||
|
||||
return decorated
|
||||
|
||||
def _save(self):
|
||||
# type: (str) -> None
|
||||
with open(self._filename, "w") as f:
|
||||
for entry in self._entries:
|
||||
line = entry.to_line()
|
||||
f.write(line)
|
||||
|
||||
@_save_store
|
||||
def add(self, device):
|
||||
# type: (OlmDeviceKey) -> None
|
||||
new_entries = StoreEntry.from_olmdevice(device)
|
||||
self._entries += new_entries
|
||||
|
||||
# Remove duplicate entries
|
||||
self._entries = list(set(self._entries))
|
||||
|
||||
self._save()
|
||||
|
||||
@_save_store
|
||||
def remove(self, device):
|
||||
# type: (OlmDeviceKey) -> int
|
||||
removed = 0
|
||||
entries = StoreEntry.from_olmdevice(device)
|
||||
|
||||
for entry in entries:
|
||||
if entry in self._entries:
|
||||
self._entries.remove(entry)
|
||||
removed += 1
|
||||
|
||||
self._save()
|
||||
|
||||
return removed
|
||||
|
||||
def check(self, device):
|
||||
# type: (OlmDeviceKey) -> bool
|
||||
entries = StoreEntry.from_olmdevice(device)
|
||||
result = map(lambda entry: entry in self._entries, entries)
|
||||
|
||||
if False in result:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class StoreEntry(object):
|
||||
def __init__(self, user_id, device_id, key_type, key):
|
||||
# type: (str, str, str, str) -> None
|
||||
self.user_id = user_id
|
||||
self.device_id = device_id
|
||||
self.key_type = key_type
|
||||
self.key = key
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
# type: (str) -> StoreEntry
|
||||
fields = line.split(' ')
|
||||
|
||||
if len(fields) < 4:
|
||||
return None
|
||||
|
||||
user_id, device_id, key_type, key = fields[:4]
|
||||
|
||||
if key_type == "matrix-ed25519":
|
||||
return cls(user_id, device_id, "ed25519", key)
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_olmdevice(cls, device_key):
|
||||
# type: (OlmDeviceKey) -> [StoreEntry]
|
||||
entries = []
|
||||
|
||||
user_id = device_key.user_id
|
||||
device_id = device_key.device_id
|
||||
|
||||
for key_type, key in device_key.keys.items():
|
||||
if key_type == "ed25519":
|
||||
entries.append(cls(user_id, device_id, "ed25519", key))
|
||||
|
||||
return entries
|
||||
|
||||
def to_line(self):
|
||||
# type: () -> str
|
||||
key_type = "matrix-{}".format(self.key_type)
|
||||
line = "{} {} {} {}".format(
|
||||
self.user_id,
|
||||
self.device_id,
|
||||
key_type,
|
||||
self.key
|
||||
)
|
||||
return line
|
||||
|
||||
def __hash__(self):
|
||||
# type: () -> int
|
||||
return hash(str(self))
|
||||
|
||||
def __str__(self):
|
||||
# type: () -> str
|
||||
key_type = "matrix-{}".format(self.key_type)
|
||||
line = "{} {} {} {}".format(
|
||||
self.user_id,
|
||||
self.device_id,
|
||||
key_type,
|
||||
self.key
|
||||
)
|
||||
return line
|
||||
|
||||
def __eq__(self, value):
|
||||
# type: (StoreEntry) -> bool
|
||||
if (self.user_id == value.user_id
|
||||
and self.device_id == value.device_id
|
||||
and self.key_type == value.key_type and self.key == value.key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class OlmDeviceKey():
|
||||
def __init__(self, user_id, device_id, key_dict):
|
||||
# type: (str, str, Dict[str, str])
|
||||
|
|
Loading…
Reference in a new issue