encryption: Store the account and sessions in a sqlite db.

This commit is contained in:
poljar (Damir Jelić) 2018-04-11 14:00:37 +02:00
parent 8a162a7a80
commit 0bd20cc333
3 changed files with 157 additions and 26 deletions

View file

@ -19,9 +19,10 @@ from __future__ import unicode_literals
import os import os
import json import json
import sqlite3
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from builtins import str from builtins import str, bytes
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
@ -31,7 +32,7 @@ import matrix.globals
try: try:
from olm.account import Account, OlmAccountError from olm.account import Account, OlmAccountError
from olm.session import (InboundSession, OlmSessionError, OlmMessage, from olm.session import (Session, InboundSession, OlmSessionError,
OlmPreKeyMessage) OlmPreKeyMessage)
from olm.group_session import InboundGroupSession, OlmGroupSessionError from olm.group_session import InboundGroupSession, OlmGroupSessionError
except ImportError: except ImportError:
@ -160,37 +161,55 @@ class Olm():
user, user,
device_id, device_id,
session_path, session_path,
database=None,
account=None, account=None,
sessions=defaultdict(list), sessions=None,
group_sessions=defaultdict(dict) inbound_group_sessions=None
): ):
# type: (str, str, str, Account, Dict[str, List[Session]) -> None # type: (str, str, str, Account, Dict[str, List[Session]) -> None
self.user = user self.user = user
self.device_id = device_id self.device_id = device_id
self.session_path = session_path self.session_path = session_path
self.database = database
if not database:
db_file = "{}_{}.db".format(user, device_id)
db_path = os.path.join(session_path, db_file)
self.database = sqlite3.connect(db_path)
Olm._check_db_tables(self.database)
if account: if account:
self.account = account self.account = account
else: else:
self.account = Account() self.account = Account()
self._insert_acc_to_db()
if not sessions:
sessions = defaultdict(list)
if not inbound_group_sessions:
inbound_group_sessions = defaultdict(dict)
self.sessions = sessions self.sessions = sessions
self.group_sessions = group_sessions self.inbound_group_sessions = inbound_group_sessions
def _create_session(self, sender, sender_key, message): def _create_session(self, sender, sender_key, message):
W.prnt("", "matrix: Creating session for {}".format(sender)) W.prnt("", "matrix: Creating session for {}".format(sender))
session = InboundSession(self.account, message, sender_key) session = InboundSession(self.account, message, sender_key)
W.prnt("", "matrix: Created session for {}".format(sender)) W.prnt("", "matrix: Created session for {}".format(sender))
self.sessions[sender].append(session) self.sessions[sender].append(session)
# self.account.remove_one_time_keys(session) self._store_session(sender, session)
# TODO store account here self.account.remove_one_time_keys(session)
self._update_acc_in_db()
return session return session
def create_group_session(self, room_id, session_id, session_key): def create_group_session(self, room_id, session_id, session_key):
W.prnt("", "matrix: Creating group session for {}".format(room_id)) W.prnt("", "matrix: Creating group session for {}".format(room_id))
session = InboundGroupSession(session_key) session = InboundGroupSession(session_key)
self.group_sessions[room_id][session_id] = session self.inbound_group_sessions[room_id][session_id] = session
# TODO store account here self._store_inbound_group_session(room_id, session)
@encrypt_enabled @encrypt_enabled
def decrypt(self, sender, sender_key, message): def decrypt(self, sender, sender_key, message):
@ -217,10 +236,10 @@ class Olm():
@encrypt_enabled @encrypt_enabled
def group_decrypt(self, room_id, session_id, ciphertext): def group_decrypt(self, room_id, session_id, ciphertext):
if session_id not in self.group_sessions[room_id]: if session_id not in self.inbound_group_sessions[room_id]:
return None return None
session = self.group_sessions[room_id][session_id] session = self.inbound_group_sessions[room_id][session_id]
try: try:
plaintext = session.decrypt(ciphertext) plaintext = session.decrypt(ciphertext)
except OlmGroupSessionError: except OlmGroupSessionError:
@ -232,28 +251,134 @@ class Olm():
@encrypt_enabled @encrypt_enabled
def from_session_dir(cls, user, device_id, session_path): def from_session_dir(cls, user, device_id, session_path):
# type: (Server) -> Olm # type: (Server) -> Olm
account_file_name = "{}_{}.account".format(user, device_id) db_file = "{}_{}.db".format(user, device_id)
path = os.path.join(session_path, account_file_name) db_path = os.path.join(session_path, db_file)
database = sqlite3.connect(db_path)
Olm._check_db_tables(database)
cursor = database.cursor()
cursor.execute("select pickle from olmaccount where user = ?", (user,))
row = cursor.fetchone()
account_pickle = row[0]
cursor.execute("select user, pickle from olmsessions")
db_sessions = cursor.fetchall()
cursor.execute("select room_id, pickle from inbound_group_sessions")
db_inbound_group_sessions = cursor.fetchall()
cursor.close()
sessions = defaultdict(list)
inbound_group_sessions = defaultdict(dict)
try: try:
with open(path, "rb") as f: account = Account.from_pickle(bytes(account_pickle, "utf-8"))
pickle = f.read()
account = Account.from_pickle(pickle) for db_session in db_sessions:
return cls(user, device_id, session_path, account) sessions[db_session[0]].append(
except OlmAccountError as error: Session.from_pickle(bytes(db_session[1], "utf-8")))
for db_session in db_inbound_group_sessions:
session = InboundGroupSession.from_pickle(
bytes(db_session[1], "utf-8"))
inbound_group_sessions[db_session[0]][session.id] = session
return cls(user, device_id, session_path, database, account,
sessions, inbound_group_sessions)
except (OlmAccountError, OlmSessionError) as error:
raise EncryptionError(error) raise EncryptionError(error)
def _update_acc_in_db(self):
cursor = self.database.cursor()
cursor.execute("update olmaccount set pickle=? where user = ?",
(self.account.pickle(), self.user))
self.database.commit()
cursor.close()
def _update_sessions_in_db(self):
cursor = self.database.cursor()
for user, session_list in self.sessions.items():
for session in session_list:
cursor.execute("""update olmsessions set pickle=?
where user = ? and session_id = ?""",
(session.pickle(), user, session.id()))
self.database.commit()
cursor.close()
def _update_inbound_group_sessions(self):
cursor = self.database.cursor()
for room_id, session_dict in self.inbound_group_sessions.items():
for session in session_dict.values():
cursor.execute("""update inbound_group_sessions set pickle=?
where room_id = ? and session_id = ?""",
(session.pickle(), room_id, session.id()))
self.database.commit()
cursor.close()
def _store_session(self, user, session):
cursor = self.database.cursor()
cursor.execute("insert into olmsessions values(?,?,?)",
(user, session.id(), session.pickle()))
self.database.commit()
cursor.close()
def _store_inbound_group_session(self, room_id, session):
cursor = self.database.cursor()
cursor.execute("insert into inbound_group_sessions values(?,?,?)",
(room_id, session.id, session.pickle()))
self.database.commit()
cursor.close()
def _insert_acc_to_db(self):
cursor = self.database.cursor()
cursor.execute("insert into olmaccount values (?,?)",
(self.user, self.account.pickle()))
self.database.commit()
cursor.close()
@staticmethod
def _check_db_tables(database):
cursor = database.cursor()
cursor.execute("""select name from sqlite_master where type='table'
and name='olmaccount'""")
if not cursor.fetchone():
cursor.execute("create table olmaccount (user text, pickle text)")
database.commit()
cursor.execute("""select name from sqlite_master where type='table'
and name='olmsessions'""")
if not cursor.fetchone():
cursor.execute("""create table olmsessions (user text,
session_id text, pickle text)""")
database.commit()
cursor.execute("""select name from sqlite_master where type='table'
and name='inbound_group_sessions'""")
if not cursor.fetchone():
cursor.execute("""create table inbound_group_sessions
(room_id text, session_id text, pickle text)""")
database.commit()
cursor.close()
@encrypt_enabled @encrypt_enabled
def to_session_dir(self): def to_session_dir(self):
# type: (Server) -> None # type: (Server) -> None
account_file_name = "{}_{}.account".format(self.user,
self.device_id)
path = os.path.join(self.session_path, account_file_name)
try: try:
with open(path, "wb") as f: self._update_acc_in_db()
pickle = self.account.pickle() self._update_sessions_in_db()
f.write(pickle)
except OlmAccountError as error: except OlmAccountError as error:
raise EncryptionError(error) raise EncryptionError(error)

View file

@ -508,7 +508,7 @@ class MatrixSyncEvent(MatrixEvent):
session_id = content["session_id"] session_id = content["session_id"]
session_key = content["session_key"] session_key = content["session_key"]
if session_id in olm.group_sessions[room_id]: if session_id in olm.inbound_group_sessions[room_id]:
return return
olm.create_group_session(room_id, session_id, session_key) olm.create_group_session(room_id, session_id, session_key)

View file

@ -146,6 +146,12 @@ class MatrixServer:
self.device_id, self.device_id,
self.get_session_path() self.get_session_path()
) )
message = ("{prefix}matrix: Loaded Olm account for {user} (device:"
"{device})").format(prefix=W.prefix("network"),
user=self.user,
device=self.device_id)
W.prnt("", message)
except FileNotFoundError: except FileNotFoundError:
pass pass
except EncryptionError as error: except EncryptionError as error: