encryption: Add a device store class for the trust database.
This commit is contained in:
parent
c8a7b4815d
commit
dedff37a60
1 changed files with 158 additions and 0 deletions
|
@ -29,6 +29,11 @@ from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from future.moves.itertools import zip_longest
|
from future.moves.itertools import zip_longest
|
||||||
|
|
||||||
|
try:
|
||||||
|
FileNotFoundError
|
||||||
|
except NameError:
|
||||||
|
FileNotFoundError = IOError
|
||||||
|
|
||||||
import matrix.globals
|
import matrix.globals
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -185,6 +190,159 @@ class EncryptionError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceStore(object):
|
||||||
|
def __init__(self, filename):
|
||||||
|
self._entries = []
|
||||||
|
self._filename = filename
|
||||||
|
|
||||||
|
self._load(filename)
|
||||||
|
|
||||||
|
def _load(self, filename):
|
||||||
|
# type: (str) -> None
|
||||||
|
try:
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
entry = StoreEntry.from_line(line)
|
||||||
|
|
||||||
|
if not entry:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._entries.append(entry)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _save_store(f):
|
||||||
|
@wraps(f)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
self = args[0]
|
||||||
|
ret = f(*args, **kwargs)
|
||||||
|
self._save()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
def _save(self):
|
||||||
|
# type: (str) -> None
|
||||||
|
with open(self._filename, "w") as f:
|
||||||
|
for entry in self._entries:
|
||||||
|
line = entry.to_line()
|
||||||
|
f.write(line)
|
||||||
|
|
||||||
|
@_save_store
|
||||||
|
def add(self, device):
|
||||||
|
# type: (OlmDeviceKey) -> None
|
||||||
|
new_entries = StoreEntry.from_olmdevice(device)
|
||||||
|
self._entries += new_entries
|
||||||
|
|
||||||
|
# Remove duplicate entries
|
||||||
|
self._entries = list(set(self._entries))
|
||||||
|
|
||||||
|
self._save()
|
||||||
|
|
||||||
|
@_save_store
|
||||||
|
def remove(self, device):
|
||||||
|
# type: (OlmDeviceKey) -> int
|
||||||
|
removed = 0
|
||||||
|
entries = StoreEntry.from_olmdevice(device)
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
if entry in self._entries:
|
||||||
|
self._entries.remove(entry)
|
||||||
|
removed += 1
|
||||||
|
|
||||||
|
self._save()
|
||||||
|
|
||||||
|
return removed
|
||||||
|
|
||||||
|
def check(self, device):
|
||||||
|
# type: (OlmDeviceKey) -> bool
|
||||||
|
entries = StoreEntry.from_olmdevice(device)
|
||||||
|
result = map(lambda entry: entry in self._entries, entries)
|
||||||
|
|
||||||
|
if False in result:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class StoreEntry(object):
|
||||||
|
def __init__(self, user_id, device_id, key_type, key):
|
||||||
|
# type: (str, str, str, str) -> None
|
||||||
|
self.user_id = user_id
|
||||||
|
self.device_id = device_id
|
||||||
|
self.key_type = key_type
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_line(cls, line):
|
||||||
|
# type: (str) -> StoreEntry
|
||||||
|
fields = line.split(' ')
|
||||||
|
|
||||||
|
if len(fields) < 4:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_id, device_id, key_type, key = fields[:4]
|
||||||
|
|
||||||
|
if key_type == "matrix-ed25519":
|
||||||
|
return cls(user_id, device_id, "ed25519", key)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_olmdevice(cls, device_key):
|
||||||
|
# type: (OlmDeviceKey) -> [StoreEntry]
|
||||||
|
entries = []
|
||||||
|
|
||||||
|
user_id = device_key.user_id
|
||||||
|
device_id = device_key.device_id
|
||||||
|
|
||||||
|
for key_type, key in device_key.keys.items():
|
||||||
|
if key_type == "ed25519":
|
||||||
|
entries.append(cls(user_id, device_id, "ed25519", key))
|
||||||
|
|
||||||
|
return entries
|
||||||
|
|
||||||
|
def to_line(self):
|
||||||
|
# type: () -> str
|
||||||
|
key_type = "matrix-{}".format(self.key_type)
|
||||||
|
line = "{} {} {} {}".format(
|
||||||
|
self.user_id,
|
||||||
|
self.device_id,
|
||||||
|
key_type,
|
||||||
|
self.key
|
||||||
|
)
|
||||||
|
return line
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# type: () -> int
|
||||||
|
return hash(str(self))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# type: () -> str
|
||||||
|
key_type = "matrix-{}".format(self.key_type)
|
||||||
|
line = "{} {} {} {}".format(
|
||||||
|
self.user_id,
|
||||||
|
self.device_id,
|
||||||
|
key_type,
|
||||||
|
self.key
|
||||||
|
)
|
||||||
|
return line
|
||||||
|
|
||||||
|
def __eq__(self, value):
|
||||||
|
# type: (StoreEntry) -> bool
|
||||||
|
if (self.user_id == value.user_id
|
||||||
|
and self.device_id == value.device_id
|
||||||
|
and self.key_type == value.key_type and self.key == value.key):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class OlmDeviceKey():
|
class OlmDeviceKey():
|
||||||
def __init__(self, user_id, device_id, key_dict):
|
def __init__(self, user_id, device_id, key_dict):
|
||||||
# type: (str, str, Dict[str, str])
|
# type: (str, str, Dict[str, str])
|
||||||
|
|
Loading…
Reference in a new issue