diff --git a/matrix/server.py b/matrix/server.py index feab432..354bd03 100644 --- a/matrix/server.py +++ b/matrix/server.py @@ -89,6 +89,7 @@ from .utils import create_server_buffer, key_from_value, server_buffer_prnt from .uploads import Upload from .colors import Formatted, FormattedString, DEFAULT_ATTRIBUTES +from .store import FileStore try: from urllib.parse import urlparse @@ -370,6 +371,7 @@ class MatrixServer(object): def _load_device_id(self, user=None): user = user or self.config.username + user = user.replace("/", "_") file_name = "{}{}".format(user, ".device_id") path = os.path.join(self.get_session_path(), file_name) @@ -384,6 +386,8 @@ class MatrixServer(object): def save_device_id(self): file_name = "{}{}".format(self.config.username or "main", ".device_id") + file_name = file_name.replace("/", "_") + path = os.path.join(self.get_session_path(), file_name) with atomic_write(path, overwrite=True) as device_file: @@ -410,7 +414,7 @@ class MatrixServer(object): self.address = homeserver.hostname self.homeserver = homeserver - config = ClientConfig(store_sync_tokens=True) + config = ClientConfig(store_sync_tokens=True, store=FileStore) self.client = HttpClient( homeserver.geturl(), diff --git a/matrix/store.py b/matrix/store.py new file mode 100644 index 0000000..2b1724f --- /dev/null +++ b/matrix/store.py @@ -0,0 +1,20 @@ +import os +from dataclasses import dataclass +from nio.store import DefaultStore +from nio.store.file_trustdb import KeyStore + + +@dataclass +class FileStore(DefaultStore): + + def __post_init__(self): + self.database_name = self.default_dbname("db") + super(DefaultStore, self).__post_init__() + self.trust_db = KeyStore(os.path.join(self.store_path, self.default_dbname("trusted_devices"))) + self.blacklist_db = KeyStore(os.path.join(self.store_path, self.default_dbname("blacklisted_devices"))) + self.ignore_db = KeyStore(os.path.join(self.store_path, self.default_dbname("ignored_devices"))) + + def default_dbname(self, postfix): + name = f"{self.user_id}_{self.device_id}.{postfix}" + name = name.replace("/", "_") + return name