encryption: Decrypt incomming messages.

This commit is contained in:
poljar (Damir Jelić) 2018-04-03 21:51:03 +02:00
parent ec2995fe52
commit 1fd5bd637d
3 changed files with 73 additions and 30 deletions

View file

@ -171,15 +171,20 @@ class Olm():
self.group_sessions = group_sessions
def _create_session(self, sender, sender_key, message):
W.prnt("", "matrix: Creating session for {}".format(sender))
session = InboundSession(self.account, message, sender_key)
W.prnt("", "matrix: Created session for {}".format(sender))
self.sessions[sender].append(session)
self.account.remove_one_time_keys(session)
# self.account.remove_one_time_keys(session)
# TODO store account here
return session
def create_group_session(self, room_id, session_id, session_key):
W.prnt("", "matrix: Creating group session for {}".format(room_id))
session = InboundGroupSession(session_key)
self.group_sessions[room_id][session_id] = session
# TODO store account here
@encrypt_enabled
def decrypt(self, sender, sender_key, message):
@ -204,6 +209,19 @@ class Olm():
except OlmSessionError:
return None
@encrypt_enabled
def group_decrypt(self, room_id, session_id, ciphertext):
if session_id not in self.group_sessions[room_id]:
return None
session = self.group_sessions[room_id][session_id]
try:
plaintext = session.decrypt(ciphertext)
except OlmGroupSessionError:
return None
return plaintext
@classmethod
@encrypt_enabled
def from_session_dir(cls, server):

View file

@ -17,6 +17,8 @@
from __future__ import unicode_literals
from builtins import str
import json
from collections import deque
from functools import partial
from operator import itemgetter
@ -437,7 +439,7 @@ class MatrixSyncEvent(MatrixEvent):
MatrixEvent.__init__(self, server)
@staticmethod
def _infos_from_dict(parsed_dict):
def _infos_from_dict(olm, parsed_dict):
join_infos = []
invite_infos = []
@ -445,27 +447,28 @@ class MatrixSyncEvent(MatrixEvent):
if not room_id:
continue
join_infos.append(RoomInfo.from_dict(room_id, room_dict))
join_infos.append(RoomInfo.from_dict(olm, room_id, room_dict))
return (join_infos, invite_infos)
@staticmethod
def _get_olm_device_event(server, parsed_dict):
device_key = server.olm.account.device_keys["curve25519"]
device_key = server.olm.account.identity_keys()["curve25519"]
if device_key not in parsed_dict["content"]["ciphertext"]:
return None
ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key)
sender = sanitize_id(parsed_dict["sender"])
sender_key = sanitize_id(parsed_dict["sender_key"])
sender_key = sanitize_id(parsed_dict["content"]["sender_key"])
ciphertext = parsed_dict["content"]["ciphertext"].pop(device_key)
message = None
if ciphertext["type"] == 0:
message = OlmPreKeyMessage(parsed_dict["body"])
message = OlmPreKeyMessage(ciphertext["body"])
elif ciphertext["type"] == 1:
message = OlmMessage(parsed_dict["body"])
message = OlmMessage(ciphertext["body"])
else:
raise ValueError("Invalid Olm message type")
@ -473,23 +476,24 @@ class MatrixSyncEvent(MatrixEvent):
plaintext = olm.decrypt(sender, sender_key, message)
# TODO check sender key
decrypted_sender = sanitize_id(plaintext["sender"])
decrypted_recepient = sanitize_id(plaintext["recipient"])
decrypted_recepient_key = sanitize_id(
plaintext["recipient_keys"]["ed25519"])
parsed_plaintext = json.loads(plaintext, encoding='utf-8')
decrypted_sender = parsed_plaintext["sender"]
decrypted_recepient = parsed_plaintext["recipient"]
decrypted_recepient_key = parsed_plaintext["recipient_keys"]["ed25519"]
if (sender != decrypted_sender or
server.user_id != decrypted_recepient or
device_key != decrypted_recepient_key):
error_message = ("{prefix}matrix: Mismatch in decrypted Olm"
"message").format(W.prefix("error"))
olm.account.identity_keys()["ed25519"] !=
decrypted_recepient_key):
error_message = ("{prefix}matrix: Mismatch in decrypted Olm "
"message").format(prefix=W.prefix("error"))
W.prnt("", error_message)
return None
if plaintext["type"] != "m.room.key":
if parsed_plaintext["type"] != "m.room_key":
return None
MatrixSyncEvent._handle_key_event(server, sender_key, plaintext)
MatrixSyncEvent._handle_key_event(server, sender_key, parsed_plaintext)
@staticmethod
def _handle_key_event(server, sender_key, parsed_dict):
@ -497,12 +501,12 @@ class MatrixSyncEvent(MatrixEvent):
olm = server.olm
content = parsed_dict.pop("content")
if content["type"] != "m.megolm.v1.aes-sha2":
if content["algorithm"] != "m.megolm.v1.aes-sha2":
return
room_id = sanitize_id(content["room_id"])
session_id = sanitize_id(content["session_id"])
session_key = sanitize_id(content["session_key"])
room_id = content["room_id"]
session_id = content["session_id"]
session_key = content["session_key"]
if session_id in olm.group_sessions[room_id]:
return
@ -516,7 +520,7 @@ class MatrixSyncEvent(MatrixEvent):
if event["type"] == "m.room.encrypted":
if (event["content"]["algorithm"] ==
'm.olm.v1.curve25519-aes-sha2'):
MatrixSyncEvent._get_olm_device_event(server, parsed_dict)
MatrixSyncEvent._get_olm_device_event(server, event)
@classmethod
def from_dict(cls, server, parsed_dict):
@ -531,12 +535,12 @@ class MatrixSyncEvent(MatrixEvent):
parsed_dict["device_one_time_keys_count"]["signed_curve25519"])
MatrixSyncEvent._get_to_device_events(
server.olm, parsed_dict.pop("to_device"))
server, parsed_dict.pop("to_device"))
room_info_dict = parsed_dict["rooms"]
join_infos, invite_infos = MatrixSyncEvent._infos_from_dict(
room_info_dict)
server.olm, room_info_dict)
return cls(server, next_batch, join_infos, invite_infos,
one_time_key_count)

View file

@ -17,6 +17,8 @@
from __future__ import unicode_literals
from builtins import str
import json
from pprint import pformat
from collections import namedtuple, deque
@ -243,7 +245,7 @@ class RoomInfo():
return None, None
@staticmethod
def parse_event(event_dict):
def parse_event(olm, room_id, event_dict):
# type: (Dict[Any, Any]) -> (RoomEvent, RoomEvent)
state_event = None
message_event = None
@ -270,11 +272,30 @@ class RoomInfo():
state_event = RoomAliasEvent.from_dict(event_dict)
elif event_dict["type"] == "m.room.encryption":
state_event = RoomEncryptionEvent.from_dict(event_dict)
elif event_dict["type"] == "m.room.encrypted":
state_event, message_event = RoomInfo._decrypt_event(olm, room_id,
event_dict)
return state_event, message_event
@staticmethod
def _parse_events(parsed_dict, messages=True, state=True):
def _decrypt_event(olm, room_id, event_dict):
session_id = event_dict["content"]["session_id"]
ciphertext = event_dict["content"]["ciphertext"]
plaintext = olm.group_decrypt(room_id, session_id, ciphertext)
if not plaintext:
return None, None
parsed_plaintext = json.loads(plaintext, encoding="utf-8")
event_dict["content"] = parsed_plaintext["content"]
event_dict["type"] = parsed_plaintext["type"]
return RoomInfo.parse_event(olm, room_id, event_dict)
@staticmethod
def _parse_events(olm, room_id, parsed_dict, messages=True, state=True):
state_events = []
message_events = []
@ -283,7 +304,7 @@ class RoomInfo():
try:
for event in parsed_dict:
m_event, s_event = RoomInfo.parse_event(event)
m_event, s_event = RoomInfo.parse_event(olm, room_id, event)
state_events.append(m_event)
message_events.append(s_event)
except (ValueError, TypeError, KeyError) as error:
@ -306,14 +327,14 @@ class RoomInfo():
return events
@classmethod
def from_dict(cls, room_id, parsed_dict):
def from_dict(cls, olm, room_id, parsed_dict):
prev_batch = sanitize_id(parsed_dict['timeline']['prev_batch'])
state_dict = parsed_dict['state']['events']
timeline_dict = parsed_dict['timeline']['events']
state_events = RoomInfo._parse_events(state_dict, messages=False)
timeline_events = RoomInfo._parse_events(timeline_dict)
state_events = RoomInfo._parse_events(olm, room_id, state_dict, messages=False)
timeline_events = RoomInfo._parse_events(olm, room_id, timeline_dict)
events = state_events + timeline_events