encryption: Decrypt incomming messages.

This commit is contained in:
poljar (Damir Jelić) 2018-04-03 21:51:03 +02:00
parent ec2995fe52
commit 1fd5bd637d
3 changed files with 73 additions and 30 deletions

View file

@ -171,15 +171,20 @@ class Olm():
self.group_sessions = group_sessions self.group_sessions = group_sessions
def _create_session(self, sender, sender_key, message): def _create_session(self, sender, sender_key, message):
W.prnt("", "matrix: Creating session for {}".format(sender))
session = InboundSession(self.account, message, sender_key) session = InboundSession(self.account, message, sender_key)
W.prnt("", "matrix: Created session for {}".format(sender))
self.sessions[sender].append(session) self.sessions[sender].append(session)
self.account.remove_one_time_keys(session) # self.account.remove_one_time_keys(session)
# TODO store account here
return session return session
def create_group_session(self, room_id, session_id, session_key): def create_group_session(self, room_id, session_id, session_key):
W.prnt("", "matrix: Creating group session for {}".format(room_id))
session = InboundGroupSession(session_key) session = InboundGroupSession(session_key)
self.group_sessions[room_id][session_id] = session self.group_sessions[room_id][session_id] = session
# TODO store account here
@encrypt_enabled @encrypt_enabled
def decrypt(self, sender, sender_key, message): def decrypt(self, sender, sender_key, message):
@ -204,6 +209,19 @@ class Olm():
except OlmSessionError: except OlmSessionError:
return None return None
@encrypt_enabled
def group_decrypt(self, room_id, session_id, ciphertext):
if session_id not in self.group_sessions[room_id]:
return None
session = self.group_sessions[room_id][session_id]
try:
plaintext = session.decrypt(ciphertext)
except OlmGroupSessionError:
return None
return plaintext
@classmethod @classmethod
@encrypt_enabled @encrypt_enabled
def from_session_dir(cls, server): def from_session_dir(cls, server):

View file

@ -17,6 +17,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from builtins import str from builtins import str
import json
from collections import deque from collections import deque
from functools import partial from functools import partial
from operator import itemgetter from operator import itemgetter
@ -437,7 +439,7 @@ class MatrixSyncEvent(MatrixEvent):
MatrixEvent.__init__(self, server) MatrixEvent.__init__(self, server)
@staticmethod @staticmethod
def _infos_from_dict(parsed_dict): def _infos_from_dict(olm, parsed_dict):
join_infos = [] join_infos = []
invite_infos = [] invite_infos = []
@ -445,27 +447,28 @@ class MatrixSyncEvent(MatrixEvent):
if not room_id: if not room_id:
continue continue
join_infos.append(RoomInfo.from_dict(room_id, room_dict)) join_infos.append(RoomInfo.from_dict(olm, room_id, room_dict))
return (join_infos, invite_infos) return (join_infos, invite_infos)
@staticmethod @staticmethod
def _get_olm_device_event(server, parsed_dict): def _get_olm_device_event(server, parsed_dict):
device_key = server.olm.account.device_keys["curve25519"] device_key = server.olm.account.identity_keys()["curve25519"]
if device_key not in parsed_dict["content"]["ciphertext"]: if device_key not in parsed_dict["content"]["ciphertext"]:
return None return None
ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key)
sender = sanitize_id(parsed_dict["sender"]) sender = sanitize_id(parsed_dict["sender"])
sender_key = sanitize_id(parsed_dict["sender_key"]) sender_key = sanitize_id(parsed_dict["content"]["sender_key"])
ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key)
message = None message = None
if ciphertext["type"] == 0: if ciphertext["type"] == 0:
message = OlmPreKeyMessage(parsed_dict["body"]) message = OlmPreKeyMessage(ciphertext["body"])
elif ciphertext["type"] == 1: elif ciphertext["type"] == 1:
message = OlmMessage(parsed_dict["body"]) message = OlmMessage(ciphertext["body"])
else: else:
raise ValueError("Invalid Olm message type") raise ValueError("Invalid Olm message type")
@ -473,23 +476,24 @@ class MatrixSyncEvent(MatrixEvent):
plaintext = olm.decrypt(sender, sender_key, message) plaintext = olm.decrypt(sender, sender_key, message)
# TODO check sender key # TODO check sender key
decrypted_sender = sanitize_id(plaintext["sender"]) parsed_plaintext = json.loads(plaintext, encoding='utf-8')
decrypted_recepient = sanitize_id(plaintext["recipient"]) decrypted_sender = parsed_plaintext["sender"]
decrypted_recepient_key = sanitize_id( decrypted_recepient = parsed_plaintext["recipient"]
plaintext["recipient_keys"]["ed25519"]) decrypted_recepient_key = parsed_plaintext["recipient_keys"]["ed25519"]
if (sender != decrypted_sender or if (sender != decrypted_sender or
server.user_id != decrypted_recepient or server.user_id != decrypted_recepient or
device_key != decrypted_recepient_key): olm.account.identity_keys()["ed25519"] !=
error_message = ("{prefix}matrix: Mismatch in decrypted Olm" decrypted_recepient_key):
"message").format(W.prefix("error")) error_message = ("{prefix}matrix: Mismatch in decrypted Olm "
"message").format(prefix=W.prefix("error"))
W.prnt("", error_message) W.prnt("", error_message)
return None return None
if plaintext["type"] != "m.room.key": if parsed_plaintext["type"] != "m.room_key":
return None return None
MatrixSyncEvent._handle_key_event(server, sender_key, plaintext) MatrixSyncEvent._handle_key_event(server, sender_key, parsed_plaintext)
@staticmethod @staticmethod
def _handle_key_event(server, sender_key, parsed_dict): def _handle_key_event(server, sender_key, parsed_dict):
@ -497,12 +501,12 @@ class MatrixSyncEvent(MatrixEvent):
olm = server.olm olm = server.olm
content = parsed_dict.pop("content") content = parsed_dict.pop("content")
if content["type"] != "m.megolm.v1.aes-sha2": if content["algorithm"] != "m.megolm.v1.aes-sha2":
return return
room_id = sanitize_id(content["room_id"]) room_id = content["room_id"]
session_id = sanitize_id(content["session_id"]) session_id = content["session_id"]
session_key = sanitize_id(content["session_key"]) session_key = content["session_key"]
if session_id in olm.group_sessions[room_id]: if session_id in olm.group_sessions[room_id]:
return return
@ -516,7 +520,7 @@ class MatrixSyncEvent(MatrixEvent):
if event["type"] == "m.room.encrypted": if event["type"] == "m.room.encrypted":
if (event["content"]["algorithm"] == if (event["content"]["algorithm"] ==
'm.olm.v1.curve25519-aes-sha2'): 'm.olm.v1.curve25519-aes-sha2'):
MatrixSyncEvent._get_olm_device_event(server, parsed_dict) MatrixSyncEvent._get_olm_device_event(server, event)
@classmethod @classmethod
def from_dict(cls, server, parsed_dict): def from_dict(cls, server, parsed_dict):
@ -531,12 +535,12 @@ class MatrixSyncEvent(MatrixEvent):
parsed_dict["device_one_time_keys_count"]["signed_curve25519"]) parsed_dict["device_one_time_keys_count"]["signed_curve25519"])
MatrixSyncEvent._get_to_device_events( MatrixSyncEvent._get_to_device_events(
server.olm, parsed_dict.pop("to_device")) server, parsed_dict.pop("to_device"))
room_info_dict = parsed_dict["rooms"] room_info_dict = parsed_dict["rooms"]
join_infos, invite_infos = MatrixSyncEvent._infos_from_dict( join_infos, invite_infos = MatrixSyncEvent._infos_from_dict(
room_info_dict) server.olm, room_info_dict)
return cls(server, next_batch, join_infos, invite_infos, return cls(server, next_batch, join_infos, invite_infos,
one_time_key_count) one_time_key_count)

View file

@ -17,6 +17,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from builtins import str from builtins import str
import json
from pprint import pformat from pprint import pformat
from collections import namedtuple, deque from collections import namedtuple, deque
@ -243,7 +245,7 @@ class RoomInfo():
return None, None return None, None
@staticmethod @staticmethod
def parse_event(event_dict): def parse_event(olm, room_id, event_dict):
# type: (Dict[Any, Any]) -> (RoomEvent, RoomEvent) # type: (Dict[Any, Any]) -> (RoomEvent, RoomEvent)
state_event = None state_event = None
message_event = None message_event = None
@ -270,11 +272,30 @@ class RoomInfo():
state_event = RoomAliasEvent.from_dict(event_dict) state_event = RoomAliasEvent.from_dict(event_dict)
elif event_dict["type"] == "m.room.encryption": elif event_dict["type"] == "m.room.encryption":
state_event = RoomEncryptionEvent.from_dict(event_dict) state_event = RoomEncryptionEvent.from_dict(event_dict)
elif event_dict["type"] == "m.room.encrypted":
state_event, message_event = RoomInfo._decrypt_event(olm, room_id,
event_dict)
return state_event, message_event return state_event, message_event
@staticmethod @staticmethod
def _parse_events(parsed_dict, messages=True, state=True): def _decrypt_event(olm, room_id, event_dict):
session_id = event_dict["content"]["session_id"]
ciphertext = event_dict["content"]["ciphertext"]
plaintext = olm.group_decrypt(room_id, session_id, ciphertext)
if not plaintext:
return None, None
parsed_plaintext = json.loads(plaintext, encoding="utf-8")
event_dict["content"] = parsed_plaintext["content"]
event_dict["type"] = parsed_plaintext["type"]
return RoomInfo.parse_event(olm, room_id, event_dict)
@staticmethod
def _parse_events(olm, room_id, parsed_dict, messages=True, state=True):
state_events = [] state_events = []
message_events = [] message_events = []
@ -283,7 +304,7 @@ class RoomInfo():
try: try:
for event in parsed_dict: for event in parsed_dict:
m_event, s_event = RoomInfo.parse_event(event) m_event, s_event = RoomInfo.parse_event(olm, room_id, event)
state_events.append(m_event) state_events.append(m_event)
message_events.append(s_event) message_events.append(s_event)
except (ValueError, TypeError, KeyError) as error: except (ValueError, TypeError, KeyError) as error:
@ -306,14 +327,14 @@ class RoomInfo():
return events return events
@classmethod @classmethod
def from_dict(cls, room_id, parsed_dict): def from_dict(cls, olm, room_id, parsed_dict):
prev_batch = sanitize_id(parsed_dict['timeline']['prev_batch']) prev_batch = sanitize_id(parsed_dict['timeline']['prev_batch'])
state_dict = parsed_dict['state']['events'] state_dict = parsed_dict['state']['events']
timeline_dict = parsed_dict['timeline']['events'] timeline_dict = parsed_dict['timeline']['events']
state_events = RoomInfo._parse_events(state_dict, messages=False) state_events = RoomInfo._parse_events(olm, room_id, state_dict, messages=False)
timeline_events = RoomInfo._parse_events(timeline_dict) timeline_events = RoomInfo._parse_events(olm, room_id, timeline_dict)
events = state_events + timeline_events events = state_events + timeline_events