encryption: Add inbound session creation.

This commit is contained in:
poljar (Damir Jelić) 2018-03-29 11:41:01 +02:00
parent f8c1f564ae
commit ec2995fe52
2 changed files with 125 additions and 2 deletions

View file

@ -23,6 +23,7 @@ import json
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from builtins import str from builtins import str
from collections import defaultdict
from functools import wraps from functools import wraps
from future.moves.itertools import zip_longest from future.moves.itertools import zip_longest
@ -30,6 +31,9 @@ import matrix.globals
try: try:
from olm.account import Account, OlmAccountError from olm.account import Account, OlmAccountError
from olm.session import (InboundSession, OlmSessionError, OlmMessage,
OlmPreKeyMessage)
from olm.group_session import InboundGroupSession, OlmGroupSessionError
except ImportError: except ImportError:
matrix.globals.ENCRYPTION = False matrix.globals.ENCRYPTION = False
@ -151,13 +155,55 @@ class EncryptionError(Exception):
class Olm(): class Olm():
@encrypt_enabled @encrypt_enabled
def __init__(self, account=None): def __init__(
# type: (Server, Account) -> None self,
account=None,
sessions=defaultdict(list),
group_sessions=defaultdict(dict)
):
# type: (Account, Dict[str, List[Session]) -> None
if account: if account:
self.account = account self.account = account
else: else:
self.account = Account() self.account = Account()
self.sessions = sessions
self.group_sessions = group_sessions
def _create_session(self, sender, sender_key, message):
session = InboundSession(self.account, message, sender_key)
self.sessions[sender].append(session)
self.account.remove_one_time_keys(session)
return session
def create_group_session(self, room_id, session_id, session_key):
session = InboundGroupSession(session_key)
self.group_sessions[room_id][session_id] = session
@encrypt_enabled
def decrypt(self, sender, sender_key, message):
plaintext = None
for session in self.sessions[sender]:
try:
if isinstance(message, OlmPreKeyMessage):
if not session.matches(message):
continue
plaintext = session.decrypt(message)
break
except OlmSessionError:
pass
session = self._create_session(sender, sender_key, message)
try:
plaintext = session.decrypt(message)
return plaintext
except OlmSessionError:
return None
@classmethod @classmethod
@encrypt_enabled @encrypt_enabled
def from_session_dir(cls, server): def from_session_dir(cls, server):

View file

@ -28,6 +28,11 @@ from matrix.rooms import (matrix_create_room_buffer, RoomInfo, RoomMessageText,
RoomMessageEvent, RoomRedactedMessageEvent, RoomMessageEvent, RoomRedactedMessageEvent,
RoomMessageEmote) RoomMessageEmote)
try:
from olm.session import OlmMessage, OlmPreKeyMessage
except ImportError:
pass
class MatrixEvent(): class MatrixEvent():
@ -444,6 +449,75 @@ class MatrixSyncEvent(MatrixEvent):
return (join_infos, invite_infos) return (join_infos, invite_infos)
@staticmethod
def _get_olm_device_event(server, parsed_dict):
device_key = server.olm.account.device_keys["curve25519"]
if device_key not in parsed_dict["content"]["ciphertext"]:
return None
ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key)
sender = sanitize_id(parsed_dict["sender"])
sender_key = sanitize_id(parsed_dict["sender_key"])
message = None
if ciphertext["type"] == 0:
message = OlmPreKeyMessage(parsed_dict["body"])
elif ciphertext["type"] == 1:
message = OlmMessage(parsed_dict["body"])
else:
raise ValueError("Invalid Olm message type")
olm = server.olm
plaintext = olm.decrypt(sender, sender_key, message)
# TODO check sender key
decrypted_sender = sanitize_id(plaintext["sender"])
decrypted_recepient = sanitize_id(plaintext["recipient"])
decrypted_recepient_key = sanitize_id(
plaintext["recipient_keys"]["ed25519"])
if (sender != decrypted_sender or
server.user_id != decrypted_recepient or
device_key != decrypted_recepient_key):
error_message = ("{prefix}matrix: Mismatch in decrypted Olm"
"message").format(W.prefix("error"))
W.prnt("", error_message)
return None
if plaintext["type"] != "m.room.key":
return None
MatrixSyncEvent._handle_key_event(server, sender_key, plaintext)
@staticmethod
def _handle_key_event(server, sender_key, parsed_dict):
# type: (MatrixServer, str, Dict[Any, Any] -> None
olm = server.olm
content = parsed_dict.pop("content")
if content["type"] != "m.megolm.v1.aes-sha2":
return
room_id = sanitize_id(content["room_id"])
session_id = sanitize_id(content["session_id"])
session_key = sanitize_id(content["session_key"])
if session_id in olm.group_sessions[room_id]:
return
olm.create_group_session(room_id, session_id, session_key)
@staticmethod
def _get_to_device_events(server, parsed_dict):
# type: (MatrixServer, Dict[Any, Any]) -> None
for event in parsed_dict["events"]:
if event["type"] == "m.room.encrypted":
if (event["content"]["algorithm"] ==
'm.olm.v1.curve25519-aes-sha2'):
MatrixSyncEvent._get_olm_device_event(server, parsed_dict)
@classmethod @classmethod
def from_dict(cls, server, parsed_dict): def from_dict(cls, server, parsed_dict):
try: try:
@ -456,6 +530,9 @@ class MatrixSyncEvent(MatrixEvent):
one_time_key_count = ( one_time_key_count = (
parsed_dict["device_one_time_keys_count"]["signed_curve25519"]) parsed_dict["device_one_time_keys_count"]["signed_curve25519"])
MatrixSyncEvent._get_to_device_events(
server.olm, parsed_dict.pop("to_device"))
room_info_dict = parsed_dict["rooms"] room_info_dict = parsed_dict["rooms"]
join_infos, invite_infos = MatrixSyncEvent._infos_from_dict( join_infos, invite_infos = MatrixSyncEvent._infos_from_dict(