encryption: Store the account and sessions in a sqlite db.
This commit is contained in:
parent
8a162a7a80
commit
0bd20cc333
3 changed files with 157 additions and 26 deletions
|
@ -19,9 +19,10 @@ from __future__ import unicode_literals
|
|||
|
||||
import os
|
||||
import json
|
||||
import sqlite3
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
from builtins import str
|
||||
from builtins import str, bytes
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
|
@ -31,7 +32,7 @@ import matrix.globals
|
|||
|
||||
try:
|
||||
from olm.account import Account, OlmAccountError
|
||||
from olm.session import (InboundSession, OlmSessionError, OlmMessage,
|
||||
from olm.session import (Session, InboundSession, OlmSessionError,
|
||||
OlmPreKeyMessage)
|
||||
from olm.group_session import InboundGroupSession, OlmGroupSessionError
|
||||
except ImportError:
|
||||
|
@ -160,37 +161,55 @@ class Olm():
|
|||
user,
|
||||
device_id,
|
||||
session_path,
|
||||
database=None,
|
||||
account=None,
|
||||
sessions=defaultdict(list),
|
||||
group_sessions=defaultdict(dict)
|
||||
sessions=None,
|
||||
inbound_group_sessions=None
|
||||
):
|
||||
# type: (str, str, str, Account, Dict[str, List[Session]) -> None
|
||||
self.user = user
|
||||
self.device_id = device_id
|
||||
self.session_path = session_path
|
||||
self.database = database
|
||||
|
||||
if not database:
|
||||
db_file = "{}_{}.db".format(user, device_id)
|
||||
db_path = os.path.join(session_path, db_file)
|
||||
self.database = sqlite3.connect(db_path)
|
||||
Olm._check_db_tables(self.database)
|
||||
|
||||
if account:
|
||||
self.account = account
|
||||
|
||||
else:
|
||||
self.account = Account()
|
||||
self._insert_acc_to_db()
|
||||
|
||||
if not sessions:
|
||||
sessions = defaultdict(list)
|
||||
|
||||
if not inbound_group_sessions:
|
||||
inbound_group_sessions = defaultdict(dict)
|
||||
|
||||
self.sessions = sessions
|
||||
self.group_sessions = group_sessions
|
||||
self.inbound_group_sessions = inbound_group_sessions
|
||||
|
||||
def _create_session(self, sender, sender_key, message):
|
||||
W.prnt("", "matrix: Creating session for {}".format(sender))
|
||||
session = InboundSession(self.account, message, sender_key)
|
||||
W.prnt("", "matrix: Created session for {}".format(sender))
|
||||
self.sessions[sender].append(session)
|
||||
# self.account.remove_one_time_keys(session)
|
||||
# TODO store account here
|
||||
self._store_session(sender, session)
|
||||
self.account.remove_one_time_keys(session)
|
||||
self._update_acc_in_db()
|
||||
|
||||
return session
|
||||
|
||||
def create_group_session(self, room_id, session_id, session_key):
|
||||
W.prnt("", "matrix: Creating group session for {}".format(room_id))
|
||||
session = InboundGroupSession(session_key)
|
||||
self.group_sessions[room_id][session_id] = session
|
||||
# TODO store account here
|
||||
self.inbound_group_sessions[room_id][session_id] = session
|
||||
self._store_inbound_group_session(room_id, session)
|
||||
|
||||
@encrypt_enabled
|
||||
def decrypt(self, sender, sender_key, message):
|
||||
|
@ -217,10 +236,10 @@ class Olm():
|
|||
|
||||
@encrypt_enabled
|
||||
def group_decrypt(self, room_id, session_id, ciphertext):
|
||||
if session_id not in self.group_sessions[room_id]:
|
||||
if session_id not in self.inbound_group_sessions[room_id]:
|
||||
return None
|
||||
|
||||
session = self.group_sessions[room_id][session_id]
|
||||
session = self.inbound_group_sessions[room_id][session_id]
|
||||
try:
|
||||
plaintext = session.decrypt(ciphertext)
|
||||
except OlmGroupSessionError:
|
||||
|
@ -232,28 +251,134 @@ class Olm():
|
|||
@encrypt_enabled
|
||||
def from_session_dir(cls, user, device_id, session_path):
|
||||
# type: (Server) -> Olm
|
||||
account_file_name = "{}_{}.account".format(user, device_id)
|
||||
path = os.path.join(session_path, account_file_name)
|
||||
db_file = "{}_{}.db".format(user, device_id)
|
||||
db_path = os.path.join(session_path, db_file)
|
||||
database = sqlite3.connect(db_path)
|
||||
Olm._check_db_tables(database)
|
||||
|
||||
cursor = database.cursor()
|
||||
|
||||
cursor.execute("select pickle from olmaccount where user = ?", (user,))
|
||||
row = cursor.fetchone()
|
||||
account_pickle = row[0]
|
||||
|
||||
cursor.execute("select user, pickle from olmsessions")
|
||||
db_sessions = cursor.fetchall()
|
||||
|
||||
cursor.execute("select room_id, pickle from inbound_group_sessions")
|
||||
db_inbound_group_sessions = cursor.fetchall()
|
||||
|
||||
cursor.close()
|
||||
|
||||
sessions = defaultdict(list)
|
||||
inbound_group_sessions = defaultdict(dict)
|
||||
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
pickle = f.read()
|
||||
account = Account.from_pickle(pickle)
|
||||
return cls(user, device_id, session_path, account)
|
||||
except OlmAccountError as error:
|
||||
account = Account.from_pickle(bytes(account_pickle, "utf-8"))
|
||||
|
||||
for db_session in db_sessions:
|
||||
sessions[db_session[0]].append(
|
||||
Session.from_pickle(bytes(db_session[1], "utf-8")))
|
||||
|
||||
for db_session in db_inbound_group_sessions:
|
||||
session = InboundGroupSession.from_pickle(
|
||||
bytes(db_session[1], "utf-8"))
|
||||
inbound_group_sessions[db_session[0]][session.id] = session
|
||||
|
||||
return cls(user, device_id, session_path, database, account,
|
||||
sessions, inbound_group_sessions)
|
||||
except (OlmAccountError, OlmSessionError) as error:
|
||||
raise EncryptionError(error)
|
||||
|
||||
def _update_acc_in_db(self):
|
||||
cursor = self.database.cursor()
|
||||
cursor.execute("update olmaccount set pickle=? where user = ?",
|
||||
(self.account.pickle(), self.user))
|
||||
self.database.commit()
|
||||
cursor.close()
|
||||
|
||||
def _update_sessions_in_db(self):
|
||||
cursor = self.database.cursor()
|
||||
|
||||
for user, session_list in self.sessions.items():
|
||||
for session in session_list:
|
||||
cursor.execute("""update olmsessions set pickle=?
|
||||
where user = ? and session_id = ?""",
|
||||
(session.pickle(), user, session.id()))
|
||||
self.database.commit()
|
||||
|
||||
cursor.close()
|
||||
|
||||
def _update_inbound_group_sessions(self):
|
||||
cursor = self.database.cursor()
|
||||
|
||||
for room_id, session_dict in self.inbound_group_sessions.items():
|
||||
for session in session_dict.values():
|
||||
cursor.execute("""update inbound_group_sessions set pickle=?
|
||||
where room_id = ? and session_id = ?""",
|
||||
(session.pickle(), room_id, session.id()))
|
||||
self.database.commit()
|
||||
|
||||
cursor.close()
|
||||
|
||||
def _store_session(self, user, session):
|
||||
cursor = self.database.cursor()
|
||||
|
||||
cursor.execute("insert into olmsessions values(?,?,?)",
|
||||
(user, session.id(), session.pickle()))
|
||||
|
||||
self.database.commit()
|
||||
|
||||
cursor.close()
|
||||
|
||||
def _store_inbound_group_session(self, room_id, session):
|
||||
cursor = self.database.cursor()
|
||||
|
||||
cursor.execute("insert into inbound_group_sessions values(?,?,?)",
|
||||
(room_id, session.id, session.pickle()))
|
||||
|
||||
self.database.commit()
|
||||
|
||||
cursor.close()
|
||||
|
||||
def _insert_acc_to_db(self):
|
||||
cursor = self.database.cursor()
|
||||
cursor.execute("insert into olmaccount values (?,?)",
|
||||
(self.user, self.account.pickle()))
|
||||
self.database.commit()
|
||||
cursor.close()
|
||||
|
||||
@staticmethod
|
||||
def _check_db_tables(database):
|
||||
cursor = database.cursor()
|
||||
cursor.execute("""select name from sqlite_master where type='table'
|
||||
and name='olmaccount'""")
|
||||
if not cursor.fetchone():
|
||||
cursor.execute("create table olmaccount (user text, pickle text)")
|
||||
database.commit()
|
||||
|
||||
cursor.execute("""select name from sqlite_master where type='table'
|
||||
and name='olmsessions'""")
|
||||
if not cursor.fetchone():
|
||||
cursor.execute("""create table olmsessions (user text,
|
||||
session_id text, pickle text)""")
|
||||
database.commit()
|
||||
|
||||
cursor.execute("""select name from sqlite_master where type='table'
|
||||
and name='inbound_group_sessions'""")
|
||||
if not cursor.fetchone():
|
||||
cursor.execute("""create table inbound_group_sessions
|
||||
(room_id text, session_id text, pickle text)""")
|
||||
database.commit()
|
||||
|
||||
cursor.close()
|
||||
|
||||
@encrypt_enabled
|
||||
def to_session_dir(self):
|
||||
# type: (Server) -> None
|
||||
account_file_name = "{}_{}.account".format(self.user,
|
||||
self.device_id)
|
||||
path = os.path.join(self.session_path, account_file_name)
|
||||
|
||||
try:
|
||||
with open(path, "wb") as f:
|
||||
pickle = self.account.pickle()
|
||||
f.write(pickle)
|
||||
self._update_acc_in_db()
|
||||
self._update_sessions_in_db()
|
||||
except OlmAccountError as error:
|
||||
raise EncryptionError(error)
|
||||
|
||||
|
|
|
@ -508,7 +508,7 @@ class MatrixSyncEvent(MatrixEvent):
|
|||
session_id = content["session_id"]
|
||||
session_key = content["session_key"]
|
||||
|
||||
if session_id in olm.group_sessions[room_id]:
|
||||
if session_id in olm.inbound_group_sessions[room_id]:
|
||||
return
|
||||
|
||||
olm.create_group_session(room_id, session_id, session_key)
|
||||
|
|
|
@ -146,6 +146,12 @@ class MatrixServer:
|
|||
self.device_id,
|
||||
self.get_session_path()
|
||||
)
|
||||
message = ("{prefix}matrix: Loaded Olm account for {user} (device:"
|
||||
"{device})").format(prefix=W.prefix("network"),
|
||||
user=self.user,
|
||||
device=self.device_id)
|
||||
W.prnt("", message)
|
||||
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except EncryptionError as error:
|
||||
|
|
Loading…
Add table
Reference in a new issue