encryption: Don't pass the server to the Olm class.

This commit is contained in:
poljar (Damir Jelić) 2018-04-07 11:30:36 +02:00
parent 1fd5bd637d
commit 8a162a7a80
2 changed files with 21 additions and 14 deletions

View file

@ -157,11 +157,17 @@ class Olm():
@encrypt_enabled @encrypt_enabled
def __init__( def __init__(
self, self,
user,
device_id,
session_path,
account=None, account=None,
sessions=defaultdict(list), sessions=defaultdict(list),
group_sessions=defaultdict(dict) group_sessions=defaultdict(dict)
): ):
# type: (Account, Dict[str, List[Session]) -> None # type: (str, str, str, Account, Dict[str, List[Session]) -> None
self.user = user
self.device_id = device_id
self.session_path = session_path
if account: if account:
self.account = account self.account = account
else: else:
@ -224,28 +230,25 @@ class Olm():
@classmethod @classmethod
@encrypt_enabled @encrypt_enabled
def from_session_dir(cls, server): def from_session_dir(cls, user, device_id, session_path):
# type: (Server) -> Olm # type: (Server) -> Olm
account_file_name = "{}_{}.account".format(server.user, account_file_name = "{}_{}.account".format(user, device_id)
server.device_id)
session_path = server.get_session_path()
path = os.path.join(session_path, account_file_name) path = os.path.join(session_path, account_file_name)
try: try:
with open(path, "rb") as f: with open(path, "rb") as f:
pickle = f.read() pickle = f.read()
account = Account.from_pickle(pickle) account = Account.from_pickle(pickle)
return cls(account) return cls(user, device_id, session_path, account)
except OlmAccountError as error: except OlmAccountError as error:
raise EncryptionError(error) raise EncryptionError(error)
@encrypt_enabled @encrypt_enabled
def to_session_dir(self, server): def to_session_dir(self):
# type: (Server) -> None # type: (Server) -> None
account_file_name = "{}_{}.account".format(server.user, account_file_name = "{}_{}.account".format(self.user,
server.device_id) self.device_id)
session_path = server.get_session_path() path = os.path.join(self.session_path, account_file_name)
path = os.path.join(session_path, account_file_name)
try: try:
with open(path, "wb") as f: with open(path, "wb") as f:

View file

@ -141,7 +141,11 @@ class MatrixServer:
def _load_olm(self): def _load_olm(self):
try: try:
self.olm = Olm.from_session_dir(self) self.olm = Olm.from_session_dir(
self.user,
self.device_id,
self.get_session_path()
)
except FileNotFoundError: except FileNotFoundError:
pass pass
except EncryptionError as error: except EncryptionError as error:
@ -164,11 +168,11 @@ class MatrixServer:
server=self.name, server=self.name,
device=self.device_id) device=self.device_id)
W.prnt(self.server_buffer, message) W.prnt(self.server_buffer, message)
self.olm = Olm() self.olm = Olm(self.user, self.device_id, self.get_session_path())
@encrypt_enabled @encrypt_enabled
def store_olm(self): def store_olm(self):
self.olm.to_session_dir(self) self.olm.to_session_dir()
def _create_options(self, config_file): def _create_options(self, config_file):
options = [ options = [