weechat-matrix/matrix/encryption.py
2018-05-16 18:18:02 +02:00

773 lines
25 KiB
Python

# -*- coding: utf-8 -*-
# Weechat Matrix Protocol Script
# Copyright © 2018 Damir Jelić <poljar@termina.org.uk>
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
# above copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
# SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF
# CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from __future__ import unicode_literals
import os
import json
import sqlite3
import pprint
# pylint: disable=redefined-builtin
from builtins import str, bytes
from collections import defaultdict
from functools import wraps
from future.moves.itertools import zip_longest
try:
FileNotFoundError
except NameError:
FileNotFoundError = IOError
import matrix.globals
try:
from olm.account import Account, OlmAccountError
from olm.session import (Session, InboundSession, OutboundSession,
OlmSessionError, OlmPreKeyMessage)
from olm.group_session import (
InboundGroupSession,
OutboundGroupSession,
OlmGroupSessionError
)
except ImportError:
matrix.globals.ENCRYPTION = False
from matrix.globals import W, SERVERS
from matrix.utils import sanitize_id
from matrix.utf import utf8_decode
def own_buffer_or_error(f):
    """Decorate a command callback so it only runs on a Matrix buffer.

    The returned wrapper takes the callback arguments ``(data, buffer, ...)``.
    It walks over SERVERS and, for the first server that owns ``buffer``
    (either as one of its ``buffers`` values or as its ``server_buffer``),
    calls ``f(server.name, buffer, ...)`` -- note that ``data`` is dropped and
    replaced by the server name.  When no server owns the buffer an error is
    printed and ``W.WEECHAT_RC_OK`` is returned.
    """
    @wraps(f)
    def wrapper(data, buffer, *args, **kwargs):
        for server in SERVERS.values():
            owns_buffer = (buffer in server.buffers.values()
                           or buffer == server.server_buffer)
            if owns_buffer:
                return f(server.name, buffer, *args, **kwargs)

        error_message = ("{prefix}matrix: command \"olm\" must be executed on a "
                         "matrix buffer (server or channel)").format(
                             prefix=W.prefix("error"))
        W.prnt("", error_message)
        return W.WEECHAT_RC_OK

    return wrapper
def encrypt_enabled(f):
    """Decorate ``f`` so it only runs while encryption support is enabled.

    The wrapper checks ``matrix.globals.ENCRYPTION`` on every call; when it is
    falsy the wrapped function is skipped and ``None`` is returned instead.
    """
    @wraps(f)
    def wrapper(*args, **kwds):
        if not matrix.globals.ENCRYPTION:
            return None
        return f(*args, **kwds)

    return wrapper
@encrypt_enabled
def matrix_hook_olm_command():
    """Register the ``olm`` command through ``W.hook_command``.

    Because of the ``encrypt_enabled`` decorator the hook is only installed
    when encryption support is enabled.
    """
    synopsis = ("info all|blacklisted|private|unverified|verified <filter>||"
                "blacklist <device-id> ||"
                "unverify <device-id> ||"
                "verify <device-id>")

    description = (" info: show info about known devices and their keys\n"
                   "blacklist: blacklist a device\n"
                   " unverify: unverify a device\n"
                   " verify: verify a device\n\n"
                   "Examples:\n")

    completions = ('info all|blacklisted|private|unverified|verified ||'
                   'blacklist %(device_ids) ||'
                   'unverify %(device_ids) ||'
                   'verify %(device_ids)')

    W.hook_command(
        "olm",                            # command name
        "Matrix olm encryption command",  # short description
        synopsis,
        description,
        completions,
        'matrix_olm_command_cb',          # callback function name
        '')
def olm_cmd_parse_args(args):
    """Split the raw ``olm`` argument string into ``(command, rest_args)``.

    The first whitespace-separated word is the command and the remaining
    words are returned as a list.  An empty argument string yields the
    default command ``"info"`` with no further arguments.
    """
    words = args.split()
    if not words:
        return "info", []
    return words[0], words[1:]
def grouper(iterable, n, fillvalue=None):
    """Collect data into fixed-length chunks or blocks.

    Example: ``grouper('ABCDEFG', 3, 'x')`` yields ABC, DEF, Gxx.
    """
    # The same iterator object is handed to zip_longest n times, so every
    # output tuple consumes n consecutive items from the input.
    shared = iter(iterable)
    return zip_longest(*(shared for _ in range(n)), fillvalue=fillvalue)
def partition_key(key):
    """Format ``key`` as space-separated groups of four characters.

    The last group is padded with spaces up to four characters, as done by
    ``grouper`` with a ``" "`` fill value.
    """
    chunks = []
    for group in grouper(key, 4, " "):
        chunks.append(''.join(group))
    return ' '.join(chunks)
@own_buffer_or_error
@utf8_decode
def matrix_olm_command_cb(server_name, buffer, args):
    """Callback for the ``olm`` command.

    Only the ``info`` subcommand is handled:

    * ``info`` / ``info private`` prints the identity keys of our own Olm
      account to the server buffer.
    * ``info all`` prints the identity keys of every device found in
      ``server.olm.device_keys``.

    Any other subcommand or ``info`` filter results in a
    "Command not implemented" error message.

    Args:
        server_name: Name of the server owning ``buffer`` (supplied by the
            ``own_buffer_or_error`` decorator).
        buffer: The buffer the command was executed on.
        args: The raw argument string of the command.

    Returns:
        ``W.WEECHAT_RC_OK`` in every case.
    """
    server = SERVERS[server_name]
    command, args = olm_cmd_parse_args(args)
    not_implemented = "{prefix}matrix: Command not implemented."

    if command and command != "info":
        W.prnt(server.server_buffer,
               not_implemented.format(prefix=W.prefix("error")))
        return W.WEECHAT_RC_OK

    olm = server.olm

    if not args or args[0] == "private":
        device_msg = (" - Device ID: {}\n".format(server.device_id)
                      if server.device_id else "")
        # Fetch the identity keys once instead of once per key type.
        identity_keys = olm.account.identity_keys()
        id_key = partition_key(identity_keys["curve25519"])
        fp_key = partition_key(identity_keys["ed25519"])

        message = ("{prefix}matrix: Identity keys:\n"
                   " - User: {user}\n"
                   "{device_msg}"
                   " - Identity key: {id_key}\n"
                   " - Fingerprint key: {fp_key}\n").format(
                       prefix=W.prefix("network"),
                       user=server.user,
                       device_msg=device_msg,
                       id_key=id_key,
                       fp_key=fp_key)
        W.prnt(server.server_buffer, message)

    elif args[0] == "all":
        for user, keys in olm.device_keys.items():
            message = ("{prefix}matrix: Identity keys:\n"
                       " - User: {user}\n").format(
                           prefix=W.prefix("network"),
                           user=user)
            W.prnt(server.server_buffer, message)

            for key in keys:
                id_key = partition_key(key.keys["curve25519"])
                fp_key = partition_key(key.keys["ed25519"])
                device_msg = (" - Device ID: {}\n".format(
                    key.device_id) if key.device_id else "")

                message = ("{device_msg}"
                           " - Identity key: {id_key}\n"
                           " - Fingerprint key: {fp_key}\n\n").format(
                               device_msg=device_msg,
                               id_key=id_key,
                               fp_key=fp_key)
                W.prnt(server.server_buffer, message)

    else:
        # The synopsis advertises more filters (blacklisted, unverified,
        # verified) than are handled here; tell the user instead of
        # silently printing nothing.
        W.prnt(server.server_buffer,
               not_implemented.format(prefix=W.prefix("error")))

    return W.WEECHAT_RC_OK
class EncryptionError(Exception):
    """Error raised by this module when an Olm operation fails."""
class DeviceStore(object):
    """A file-backed list of device key entries.

    The store keeps ``StoreEntry`` objects in memory and mirrors them to a
    text file with one entry per line.  Empty lines and lines starting with
    ``#`` are ignored when loading.
    """

    def __init__(self, filename):
        # type: (str) -> None
        """Create the store and load any entries already in ``filename``."""
        self._entries = []
        self._filename = filename
        self._load(filename)

    def _load(self, filename):
        # type: (str) -> None
        """Read entries from ``filename``; a missing file is an empty store."""
        try:
            with open(filename, "r") as f:
                for line in f:
                    line = line.strip()

                    if not line or line.startswith("#"):
                        continue

                    entry = StoreEntry.from_line(line)

                    # Lines that cannot be parsed are skipped.
                    if not entry:
                        continue

                    self._entries.append(entry)
        except FileNotFoundError:
            pass

    def _save_store(f):
        """Decorator for methods: write the store to disk after the call."""
        @wraps(f)
        def decorated(*args, **kwargs):
            self = args[0]
            ret = f(*args, **kwargs)
            self._save()
            return ret

        return decorated

    def _save(self):
        # type: () -> None
        """Write all entries to the store file, one entry per line."""
        with open(self._filename, "w") as f:
            for entry in self._entries:
                # to_line() has no line terminator; without the newline all
                # entries end up on a single line and cannot be loaded again.
                f.write(entry.to_line() + "\n")

    @_save_store
    def add(self, device):
        # type: (OlmDeviceKey) -> None
        """Add the keys of ``device`` to the store and save it.

        Duplicate entries are dropped while the order of the remaining
        entries is kept stable.
        """
        seen = set()
        unique = []

        for entry in self._entries + StoreEntry.from_olmdevice(device):
            if entry in seen:
                continue
            seen.add(entry)
            unique.append(entry)

        self._entries = unique

    @_save_store
    def remove(self, device):
        # type: (OlmDeviceKey) -> int
        """Remove the keys of ``device`` from the store and save it.

        Returns:
            The number of entries that were removed.
        """
        removed = 0

        for entry in StoreEntry.from_olmdevice(device):
            if entry in self._entries:
                self._entries.remove(entry)
                removed += 1

        return removed

    def check(self, device):
        # type: (OlmDeviceKey) -> bool
        """Tell whether every storable key of ``device`` is in the store.

        A device that yields no storable keys at all is never considered to
        be in the store.
        """
        entries = StoreEntry.from_olmdevice(device)

        if not entries:
            return False

        return all(entry in self._entries for entry in entries)
class StoreEntry(object):
    """A single key of a device as kept in a ``DeviceStore``.

    The textual form of an entry is
    ``<user_id> <device_id> matrix-<key_type> <key>``.
    """

    def __init__(self, user_id, device_id, key_type, key):
        # type: (str, str, str, str) -> None
        self.user_id = user_id
        self.device_id = device_id
        self.key_type = key_type
        self.key = key

    @classmethod
    def from_line(cls, line):
        # type: (str) -> Optional[StoreEntry]
        """Parse one line of a store file.

        Returns:
            The parsed entry, or ``None`` if the line has fewer than four
            fields or its key type is not ``matrix-ed25519``.
        """
        # split() without an argument tolerates tabs and repeated spaces.
        fields = line.split()

        if len(fields) < 4:
            return None

        user_id, device_id, key_type, key = fields[:4]

        if key_type != "matrix-ed25519":
            return None

        return cls(user_id, device_id, "ed25519", key)

    @classmethod
    def from_olmdevice(cls, device_key):
        # type: (OlmDeviceKey) -> List[StoreEntry]
        """Create entries for the ed25519 keys found in ``device_key.keys``."""
        user_id = device_key.user_id
        device_id = device_key.device_id

        return [
            cls(user_id, device_id, "ed25519", key)
            for key_type, key in device_key.keys.items()
            if key_type == "ed25519"
        ]

    def to_line(self):
        # type: () -> str
        """Return the textual form of the entry, without a line terminator."""
        return "{} {} matrix-{} {}".format(
            self.user_id,
            self.device_id,
            self.key_type,
            self.key
        )

    def __hash__(self):
        # type: () -> int
        # Hash the same data __eq__ compares so equal entries hash equally.
        return hash(self.to_line())

    def __str__(self):
        # type: () -> str
        return self.to_line()

    def __eq__(self, value):
        # type: (object) -> bool
        if not isinstance(value, StoreEntry):
            # Let Python try the reflected comparison instead of failing
            # with an AttributeError on foreign types.
            return NotImplemented

        return (self.user_id == value.user_id
                and self.device_id == value.device_id
                and self.key_type == value.key_type
                and self.key == value.key)

    def __ne__(self, value):
        # type: (object) -> bool
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(value)

        if result is NotImplemented:
            return result

        return not result
class OlmDeviceKey(object):
    """The keys of a single device of a user."""

    def __init__(self, user_id, device_id, key_dict):
        # type: (str, str, Dict[str, str]) -> None
        self.user_id = user_id
        self.device_id = device_id
        # Other code in this file looks up "curve25519" and "ed25519" here.
        self.keys = key_dict
class OneTimeKey(object):
    """A one-time key belonging to a device of a user."""

    def __init__(self, user_id, device_id, key):
        # type: (str, str, str) -> None
        self.user_id = user_id
        self.device_id = device_id
        self.key = key
class Olm(object):
    """Olm account, Olm sessions and group sessions of one user/device.

    The account and the sessions are persisted in an SQLite database; unless
    a connection is passed in, the database file is
    ``<user>_<device_id>.db`` inside ``session_path``.
    """

    @encrypt_enabled
    def __init__(
        self,
        user,
        device_id,
        session_path,
        database=None,
        account=None,
        sessions=None,
        inbound_group_sessions=None
    ):
        # type: (str, str, str, Any, Any, Any, Any) -> None
        """Set up the Olm state.

        Args:
            user: The user the account belongs to.
            device_id: The id of our own device.
            session_path: Directory that holds the database file.
            database: An open sqlite3 connection; opened (and its tables
                created if needed) when omitted.
            account: An existing Olm ``Account``; a new one is created and
                inserted into the database when omitted.
            sessions: Mapping ``user -> device_id -> [Session]``.
            inbound_group_sessions: Mapping
                ``room_id -> session_id -> InboundGroupSession``.
        """
        self.user = user
        self.device_id = device_id
        self.session_path = session_path
        self.database = database
        # user -> iterable of device key objects; filled in from outside
        # this class.
        self.device_keys = {}

        if not database:
            db_file = "{}_{}.db".format(user, device_id)
            db_path = os.path.join(session_path, db_file)
            self.database = sqlite3.connect(db_path)
            Olm._check_db_tables(self.database)

        if account:
            self.account = account
        else:
            self.account = Account()
            self._insert_acc_to_db()

        if not sessions:
            sessions = defaultdict(lambda: defaultdict(list))

        if not inbound_group_sessions:
            inbound_group_sessions = defaultdict(dict)

        self.sessions = sessions
        self.inbound_group_sessions = inbound_group_sessions
        self.outbound_group_sessions = {}

    def _create_session(self, sender, sender_key, message):
        """Create an inbound Olm session from ``message``.

        The one-time keys used by the session are removed from the account
        and the account is saved.  The session itself is not stored here.
        """
        W.prnt("", "matrix: Creating session for {}".format(sender))
        session = InboundSession(self.account, message, sender_key)
        W.prnt("", "matrix: Created session for {}".format(sender))
        self.account.remove_one_time_keys(session)
        self._update_acc_in_db()

        return session

    def create_session(self, user_id, device_id, one_time_key):
        """Create and store an outbound Olm session for a device.

        The curve25519 identity key of the device is looked up in
        ``self.device_keys``; if it is not known an error is printed and no
        session is created.
        """
        W.prnt("", "matrix: Creating session for {}".format(user_id))
        id_key = None

        for key in self.device_keys.get(user_id, []):
            if key.device_id == device_id:
                id_key = key.keys["curve25519"]
                break

        if not id_key:
            # Without the identity key no session can be set up, so stop
            # here instead of handing None to OutboundSession.
            W.prnt("", "ERRR not found ID key")
            return

        W.prnt("", "Found id key {}".format(id_key))
        session = OutboundSession(self.account, id_key, one_time_key)
        self._update_acc_in_db()
        self.sessions[user_id][device_id].append(session)
        self._store_session(user_id, device_id, session)
        W.prnt("", "matrix: Created session for {}".format(user_id))

    def create_group_session(self, room_id, session_id, session_key):
        """Create and store an inbound group session for ``room_id``."""
        W.prnt("", "matrix: Creating group session for {}".format(room_id))
        session = InboundGroupSession(session_key)
        self.inbound_group_sessions[room_id][session_id] = session
        self._store_inbound_group_session(room_id, session)

    def create_outbound_group_session(self, room_id):
        """Create an outbound group session for ``room_id``.

        A matching inbound group session is created from the same session
        key via ``create_group_session``.
        """
        session = OutboundGroupSession()
        self.outbound_group_sessions[room_id] = session
        self.create_group_session(room_id, session.id, session.session_key)

    @encrypt_enabled
    def get_missing_sessions(self, users):
        # type: (List[str]) -> Dict[str, Dict[str, str]]
        """Find the devices of ``users`` we have no Olm session with.

        Users without known device keys are skipped.

        Returns:
            A mapping ``user -> {device_id: "signed_curve25519"}`` containing
            only users that have at least one device without a session.
        """
        missing = {}

        for user in users:
            devices = []

            # Unknown users simply have no devices, as in
            # share_group_session().
            for key in self.device_keys.get(user, []):
                # we don't need a session for our own device, skip it
                if key.device_id == self.device_id:
                    continue

                if not self.sessions[user][key.device_id]:
                    W.prnt("", "Missing session for device {}".format(
                        key.device_id))
                    devices.append(key.device_id)

            if devices:
                missing[user] = {device: "signed_curve25519" for
                                 device in devices}

        return missing

    @encrypt_enabled
    def decrypt(self, sender, sender_key, message):
        """Decrypt an Olm message and parse the plaintext as JSON.

        All known sessions of ``sender`` are tried first.  If none of them
        decrypts the message a new inbound session is created from it and
        stored under the ``sender_device`` found in the plaintext.

        Returns:
            The parsed plaintext, or ``None`` if the message could not be
            decrypted.
        """
        for device_id, session_list in self.sessions[sender].items():
            for session in session_list:
                W.prnt("", "Trying session for device {}".format(device_id))
                try:
                    if isinstance(message, OlmPreKeyMessage):
                        if not session.matches(message):
                            continue

                    W.prnt("", "Decrypting using existing session")
                    plaintext = session.decrypt(message)
                    # json.loads() has no `encoding` argument any more on
                    # Python >= 3.9; UTF-8 is what it assumes anyway.
                    parsed_plaintext = json.loads(plaintext)
                    W.prnt("", "Decrypted using existing session")
                    return parsed_plaintext
                except OlmSessionError:
                    # This session can't decrypt the message, try the next.
                    pass

        try:
            session = self._create_session(sender, sender_key, message)
        except OlmSessionError:
            return None

        try:
            plaintext = session.decrypt(message)
            parsed_plaintext = json.loads(plaintext)

            device_id = sanitize_id(parsed_plaintext["sender_device"])
            self.sessions[sender][device_id].append(session)
            self._store_session(sender, device_id, session)
            return parsed_plaintext
        except OlmSessionError:
            return None

    def group_encrypt(self, room_id, plaintext_dict, own_id, users):
        # type: (str, Dict[str, Any], str, List[str]) -> Tuple[Dict, Any]
        """Encrypt ``plaintext_dict`` with the group session of the room.

        ``plaintext_dict`` is modified in place: its ``room_id`` key is set.
        If the room has no outbound group session yet, one is created and
        shared with ``users``.

        Returns:
            A tuple ``(payload_dict, to_device_dict)``; ``to_device_dict`` is
            ``None`` unless a new group session was created by this call.
        """
        plaintext_dict["room_id"] = room_id
        to_device_dict = None

        if room_id not in self.outbound_group_sessions:
            self.create_outbound_group_session(room_id)
            to_device_dict = self.share_group_session(room_id, own_id, users)

        session = self.outbound_group_sessions[room_id]

        ciphertext = session.encrypt(Olm._to_json(plaintext_dict))

        payload_dict = {
            "algorithm": "m.megolm.v1.aes-sha2",
            "sender_key": self.account.identity_keys()["curve25519"],
            "ciphertext": ciphertext,
            "session_id": session.id,
            "device_id": self.device_id
        }

        return payload_dict, to_device_dict

    @encrypt_enabled
    def group_decrypt(self, room_id, session_id, ciphertext):
        """Decrypt ``ciphertext`` with a stored inbound group session.

        Returns:
            The result of the session's ``decrypt()``, or ``None`` if the
            session is unknown or decryption fails.
        """
        if session_id not in self.inbound_group_sessions[room_id]:
            return None

        session = self.inbound_group_sessions[room_id][session_id]

        try:
            plaintext = session.decrypt(ciphertext)
        except OlmGroupSessionError:
            return None

        return plaintext

    def share_group_session(self, room_id, own_id, users):
        """Build the to-device messages that share a room's group session.

        The ``m.room_key`` payload is encrypted separately for every known
        device of ``users`` we already have an Olm session with; our own
        device and devices without a session are skipped.

        Returns:
            A dict ``{"messages": {user: {device_id: olm_dict}}}``.
        """
        group_session = self.outbound_group_sessions[room_id]

        key_content = {
            "algorithm": "m.megolm.v1.aes-sha2",
            "room_id": room_id,
            "session_id": group_session.id,
            "session_key": group_session.session_key,
            "chain_index": group_session.message_index
        }

        payload_dict = {
            "type": "m.room_key",
            "content": key_content,
            # TODO we don't have the user_id in the Olm class
            "sender": own_id,
            "sender_device": self.device_id,
            "keys": {
                "ed25519": self.account.identity_keys()["ed25519"]
            }
        }

        to_device_dict = {
            "messages": {}
        }

        for user in users:
            if user not in self.device_keys:
                continue

            for key in self.device_keys[user]:
                if key.device_id == self.device_id:
                    continue

                if not self.sessions[user][key.device_id]:
                    continue

                # Shallow copy: only top-level keys are added per device.
                device_payload_dict = payload_dict.copy()
                # TODO sort the sessions
                session = self.sessions[user][key.device_id][0]
                device_payload_dict["recipient"] = user
                device_payload_dict["recipient_keys"] = {
                    "ed25519": key.keys["ed25519"]
                }

                olm_message = session.encrypt(
                    Olm._to_json(device_payload_dict)
                )

                olm_dict = {
                    "algorithm": "m.olm.v1.curve25519-aes-sha2",
                    "sender_key": self.account.identity_keys()["curve25519"],
                    "ciphertext": {
                        key.keys["curve25519"]: {
                            # 0 marks a pre-key message, 1 any other one.
                            "type": (0 if isinstance(
                                olm_message,
                                OlmPreKeyMessage
                            ) else 1),
                            "body": olm_message.ciphertext
                        }
                    }
                }

                if user not in to_device_dict["messages"]:
                    to_device_dict["messages"][user] = {}

                to_device_dict["messages"][user][key.device_id] = olm_dict

        # W.prnt("", pprint.pformat(to_device_dict))
        return to_device_dict

    @classmethod
    @encrypt_enabled
    def from_session_dir(cls, user, device_id, session_path):
        # type: (str, str, str) -> Olm
        """Restore the Olm state of ``user``/``device_id`` from its database.

        Raises:
            EncryptionError: If the database holds no account for ``user``
                or if the account or a session can't be unpickled.
        """
        db_file = "{}_{}.db".format(user, device_id)
        db_path = os.path.join(session_path, db_file)

        database = sqlite3.connect(db_path)
        Olm._check_db_tables(database)

        cursor = database.cursor()

        cursor.execute("select pickle from olmaccount where user = ?", (user,))
        row = cursor.fetchone()

        if row is None:
            # A fresh or foreign database has no account row; report that
            # instead of failing with a TypeError on row[0].
            cursor.close()
            raise EncryptionError(
                "no stored Olm account found for user {}".format(user))

        account_pickle = row[0]

        cursor.execute("select user, device_id, pickle from olmsessions")
        db_sessions = cursor.fetchall()

        cursor.execute("select room_id, pickle from inbound_group_sessions")
        db_inbound_group_sessions = cursor.fetchall()

        cursor.close()

        sessions = defaultdict(lambda: defaultdict(list))
        inbound_group_sessions = defaultdict(dict)

        try:
            account = Account.from_pickle(bytes(account_pickle, "utf-8"))

            for db_session in db_sessions:
                sessions[db_session[0]][db_session[1]].append(
                    Session.from_pickle(bytes(db_session[2], "utf-8")))

            for db_session in db_inbound_group_sessions:
                session = InboundGroupSession.from_pickle(
                    bytes(db_session[1], "utf-8"))
                inbound_group_sessions[db_session[0]][session.id] = session

            return cls(user, device_id, session_path, database, account,
                       sessions, inbound_group_sessions)
        except (OlmAccountError, OlmSessionError,
                OlmGroupSessionError) as error:
            raise EncryptionError(error)

    def _update_acc_in_db(self):
        """Write the current account pickle to the database."""
        cursor = self.database.cursor()
        cursor.execute("update olmaccount set pickle=? where user = ?",
                       (self.account.pickle(), self.user))
        self.database.commit()
        cursor.close()

    def _update_sessions_in_db(self):
        """Write the pickles of all Olm sessions to the database."""
        cursor = self.database.cursor()

        for user, session_dict in self.sessions.items():
            for device_id, session_list in session_dict.items():
                for session in session_list:
                    # NOTE(review): Session.id is called as a method here and
                    # in _store_session() while group sessions use `.id` as
                    # an attribute -- confirm against the olm bindings.
                    cursor.execute("""update olmsessions set pickle=?
                                   where user = ? and session_id = ? and
                                   device_id = ?""",
                                   (session.pickle(), user, session.id(),
                                    device_id))
        self.database.commit()
        cursor.close()

    def _update_inbound_group_sessions(self):
        """Write the pickles of all inbound group sessions to the database."""
        cursor = self.database.cursor()

        for room_id, session_dict in self.inbound_group_sessions.items():
            for session in session_dict.values():
                # Group sessions expose their id as an attribute everywhere
                # else in this class, so don't call it here either.
                cursor.execute("""update inbound_group_sessions set pickle=?
                               where room_id = ? and session_id = ?""",
                               (session.pickle(), room_id, session.id))
        self.database.commit()
        cursor.close()

    def _store_session(self, user, device_id, session):
        """Insert a new Olm session into the database."""
        cursor = self.database.cursor()

        cursor.execute("insert into olmsessions values(?,?,?,?)",
                       (user, device_id, session.id(), session.pickle()))

        self.database.commit()
        cursor.close()

    def _store_inbound_group_session(self, room_id, session):
        """Insert a new inbound group session into the database."""
        cursor = self.database.cursor()

        cursor.execute("insert into inbound_group_sessions values(?,?,?)",
                       (room_id, session.id, session.pickle()))

        self.database.commit()
        cursor.close()

    def _insert_acc_to_db(self):
        """Insert the account of ``self.user`` into the database."""
        cursor = self.database.cursor()
        cursor.execute("insert into olmaccount values (?,?)",
                       (self.user, self.account.pickle()))
        self.database.commit()
        cursor.close()

    @staticmethod
    def _check_db_tables(database):
        """Create the tables used by this class if they don't exist yet."""
        cursor = database.cursor()

        cursor.execute("""create table if not exists olmaccount
                       (user text, pickle text)""")
        cursor.execute("""create table if not exists olmsessions
                       (user text, device_id text, session_id text,
                       pickle text)""")
        cursor.execute("""create table if not exists inbound_group_sessions
                       (room_id text, session_id text, pickle text)""")

        database.commit()
        cursor.close()

    @encrypt_enabled
    def to_session_dir(self):
        # type: () -> None
        """Save the account and all Olm sessions to the database.

        Raises:
            EncryptionError: If the account or a session can't be pickled.
        """
        try:
            self._update_acc_in_db()
            self._update_sessions_in_db()
        except (OlmAccountError, OlmSessionError) as error:
            raise EncryptionError(error)

    def sign_json(self, json_dict):
        """Sign the compact, key-sorted JSON form of ``json_dict``.

        Returns:
            The signature produced by the account.
        """
        return self.account.sign(Olm._to_json(json_dict))

    @staticmethod
    def _to_json(json_dict):
        # type: (Dict[Any, Any]) -> str
        """Serialize ``json_dict`` compactly with sorted keys."""
        return json.dumps(
            json_dict,
            ensure_ascii=False,
            separators=(",", ":"),
            sort_keys=True
        )

    @encrypt_enabled
    def mark_keys_as_published(self):
        """Mark the one-time keys of the account as published."""
        self.account.mark_keys_as_published()