encryption: Don't pass the server to the Olm class.

This commit is contained in:
poljar (Damir Jelić) 2018-04-07 11:30:36 +02:00
parent 1fd5bd637d
commit 8a162a7a80
2 changed files with 21 additions and 14 deletions

View file

@ -157,11 +157,17 @@ class Olm():
@encrypt_enabled
def __init__(
self,
user,
device_id,
session_path,
account=None,
sessions=defaultdict(list),
group_sessions=defaultdict(dict)
):
# type: (Account, Dict[str, List[Session]) -> None
# type: (str, str, str, Account, Dict[str, List[Session]) -> None
self.user = user
self.device_id = device_id
self.session_path = session_path
if account:
self.account = account
else:
@ -224,28 +230,25 @@ class Olm():
@classmethod
@encrypt_enabled
def from_session_dir(cls, server):
def from_session_dir(cls, user, device_id, session_path):
# type: (Server) -> Olm
account_file_name = "{}_{}.account".format(server.user,
server.device_id)
session_path = server.get_session_path()
account_file_name = "{}_{}.account".format(user, device_id)
path = os.path.join(session_path, account_file_name)
try:
with open(path, "rb") as f:
pickle = f.read()
account = Account.from_pickle(pickle)
return cls(account)
return cls(user, device_id, session_path, account)
except OlmAccountError as error:
raise EncryptionError(error)
@encrypt_enabled
def to_session_dir(self, server):
def to_session_dir(self):
# type: (Server) -> None
account_file_name = "{}_{}.account".format(server.user,
server.device_id)
session_path = server.get_session_path()
path = os.path.join(session_path, account_file_name)
account_file_name = "{}_{}.account".format(self.user,
self.device_id)
path = os.path.join(self.session_path, account_file_name)
try:
with open(path, "wb") as f:

View file

@ -141,7 +141,11 @@ class MatrixServer:
def _load_olm(self):
try:
self.olm = Olm.from_session_dir(self)
self.olm = Olm.from_session_dir(
self.user,
self.device_id,
self.get_session_path()
)
except FileNotFoundError:
pass
except EncryptionError as error:
@ -164,11 +168,11 @@ class MatrixServer:
server=self.name,
device=self.device_id)
W.prnt(self.server_buffer, message)
self.olm = Olm()
self.olm = Olm(self.user, self.device_id, self.get_session_path())
@encrypt_enabled
def store_olm(self):
self.olm.to_session_dir(self)
self.olm.to_session_dir()
def _create_options(self, config_file):
options = [