encryption: Add inbound session creation.
This commit is contained in:
parent
f8c1f564ae
commit
ec2995fe52
2 changed files with 125 additions and 2 deletions
|
@ -23,6 +23,7 @@ import json
|
||||||
# pylint: disable=redefined-builtin
|
# pylint: disable=redefined-builtin
|
||||||
from builtins import str
|
from builtins import str
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from future.moves.itertools import zip_longest
|
from future.moves.itertools import zip_longest
|
||||||
|
|
||||||
|
@ -30,6 +31,9 @@ import matrix.globals
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from olm.account import Account, OlmAccountError
|
from olm.account import Account, OlmAccountError
|
||||||
|
from olm.session import (InboundSession, OlmSessionError, OlmMessage,
|
||||||
|
OlmPreKeyMessage)
|
||||||
|
from olm.group_session import InboundGroupSession, OlmGroupSessionError
|
||||||
except ImportError:
|
except ImportError:
|
||||||
matrix.globals.ENCRYPTION = False
|
matrix.globals.ENCRYPTION = False
|
||||||
|
|
||||||
|
@ -151,13 +155,55 @@ class EncryptionError(Exception):
|
||||||
class Olm():
|
class Olm():
|
||||||
|
|
||||||
@encrypt_enabled
|
@encrypt_enabled
|
||||||
def __init__(self, account=None):
|
def __init__(
|
||||||
# type: (Server, Account) -> None
|
self,
|
||||||
|
account=None,
|
||||||
|
sessions=defaultdict(list),
|
||||||
|
group_sessions=defaultdict(dict)
|
||||||
|
):
|
||||||
|
# type: (Account, Dict[str, List[Session]) -> None
|
||||||
if account:
|
if account:
|
||||||
self.account = account
|
self.account = account
|
||||||
else:
|
else:
|
||||||
self.account = Account()
|
self.account = Account()
|
||||||
|
|
||||||
|
self.sessions = sessions
|
||||||
|
self.group_sessions = group_sessions
|
||||||
|
|
||||||
|
def _create_session(self, sender, sender_key, message):
|
||||||
|
session = InboundSession(self.account, message, sender_key)
|
||||||
|
self.sessions[sender].append(session)
|
||||||
|
self.account.remove_one_time_keys(session)
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def create_group_session(self, room_id, session_id, session_key):
|
||||||
|
session = InboundGroupSession(session_key)
|
||||||
|
self.group_sessions[room_id][session_id] = session
|
||||||
|
|
||||||
|
@encrypt_enabled
|
||||||
|
def decrypt(self, sender, sender_key, message):
|
||||||
|
plaintext = None
|
||||||
|
|
||||||
|
for session in self.sessions[sender]:
|
||||||
|
try:
|
||||||
|
if isinstance(message, OlmPreKeyMessage):
|
||||||
|
if not session.matches(message):
|
||||||
|
continue
|
||||||
|
|
||||||
|
plaintext = session.decrypt(message)
|
||||||
|
break
|
||||||
|
except OlmSessionError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
session = self._create_session(sender, sender_key, message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
plaintext = session.decrypt(message)
|
||||||
|
return plaintext
|
||||||
|
except OlmSessionError:
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@encrypt_enabled
|
@encrypt_enabled
|
||||||
def from_session_dir(cls, server):
|
def from_session_dir(cls, server):
|
||||||
|
|
|
@ -28,6 +28,11 @@ from matrix.rooms import (matrix_create_room_buffer, RoomInfo, RoomMessageText,
|
||||||
RoomMessageEvent, RoomRedactedMessageEvent,
|
RoomMessageEvent, RoomRedactedMessageEvent,
|
||||||
RoomMessageEmote)
|
RoomMessageEmote)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from olm.session import OlmMessage, OlmPreKeyMessage
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MatrixEvent():
|
class MatrixEvent():
|
||||||
|
|
||||||
|
@ -444,6 +449,75 @@ class MatrixSyncEvent(MatrixEvent):
|
||||||
|
|
||||||
return (join_infos, invite_infos)
|
return (join_infos, invite_infos)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_olm_device_event(server, parsed_dict):
|
||||||
|
device_key = server.olm.account.device_keys["curve25519"]
|
||||||
|
|
||||||
|
if device_key not in parsed_dict["content"]["ciphertext"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key)
|
||||||
|
sender = sanitize_id(parsed_dict["sender"])
|
||||||
|
sender_key = sanitize_id(parsed_dict["sender_key"])
|
||||||
|
|
||||||
|
message = None
|
||||||
|
|
||||||
|
if ciphertext["type"] == 0:
|
||||||
|
message = OlmPreKeyMessage(parsed_dict["body"])
|
||||||
|
elif ciphertext["type"] == 1:
|
||||||
|
message = OlmMessage(parsed_dict["body"])
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid Olm message type")
|
||||||
|
|
||||||
|
olm = server.olm
|
||||||
|
plaintext = olm.decrypt(sender, sender_key, message)
|
||||||
|
|
||||||
|
# TODO check sender key
|
||||||
|
decrypted_sender = sanitize_id(plaintext["sender"])
|
||||||
|
decrypted_recepient = sanitize_id(plaintext["recipient"])
|
||||||
|
decrypted_recepient_key = sanitize_id(
|
||||||
|
plaintext["recipient_keys"]["ed25519"])
|
||||||
|
|
||||||
|
if (sender != decrypted_sender or
|
||||||
|
server.user_id != decrypted_recepient or
|
||||||
|
device_key != decrypted_recepient_key):
|
||||||
|
error_message = ("{prefix}matrix: Mismatch in decrypted Olm"
|
||||||
|
"message").format(W.prefix("error"))
|
||||||
|
W.prnt("", error_message)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if plaintext["type"] != "m.room.key":
|
||||||
|
return None
|
||||||
|
|
||||||
|
MatrixSyncEvent._handle_key_event(server, sender_key, plaintext)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _handle_key_event(server, sender_key, parsed_dict):
|
||||||
|
# type: (MatrixServer, str, Dict[Any, Any] -> None
|
||||||
|
olm = server.olm
|
||||||
|
content = parsed_dict.pop("content")
|
||||||
|
|
||||||
|
if content["type"] != "m.megolm.v1.aes-sha2":
|
||||||
|
return
|
||||||
|
|
||||||
|
room_id = sanitize_id(content["room_id"])
|
||||||
|
session_id = sanitize_id(content["session_id"])
|
||||||
|
session_key = sanitize_id(content["session_key"])
|
||||||
|
|
||||||
|
if session_id in olm.group_sessions[room_id]:
|
||||||
|
return
|
||||||
|
|
||||||
|
olm.create_group_session(room_id, session_id, session_key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_to_device_events(server, parsed_dict):
|
||||||
|
# type: (MatrixServer, Dict[Any, Any]) -> None
|
||||||
|
for event in parsed_dict["events"]:
|
||||||
|
if event["type"] == "m.room.encrypted":
|
||||||
|
if (event["content"]["algorithm"] ==
|
||||||
|
'm.olm.v1.curve25519-aes-sha2'):
|
||||||
|
MatrixSyncEvent._get_olm_device_event(server, parsed_dict)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, server, parsed_dict):
|
def from_dict(cls, server, parsed_dict):
|
||||||
try:
|
try:
|
||||||
|
@ -456,6 +530,9 @@ class MatrixSyncEvent(MatrixEvent):
|
||||||
one_time_key_count = (
|
one_time_key_count = (
|
||||||
parsed_dict["device_one_time_keys_count"]["signed_curve25519"])
|
parsed_dict["device_one_time_keys_count"]["signed_curve25519"])
|
||||||
|
|
||||||
|
MatrixSyncEvent._get_to_device_events(
|
||||||
|
server.olm, parsed_dict.pop("to_device"))
|
||||||
|
|
||||||
room_info_dict = parsed_dict["rooms"]
|
room_info_dict = parsed_dict["rooms"]
|
||||||
|
|
||||||
join_infos, invite_infos = MatrixSyncEvent._infos_from_dict(
|
join_infos, invite_infos = MatrixSyncEvent._infos_from_dict(
|
||||||
|
|
Loading…
Reference in a new issue