encryption: Port the olm command parser to argparse.

This commit is contained in:
poljar (Damir Jelić) 2018-05-16 22:46:22 +02:00
parent dedff37a60
commit d5b768e7e0

View file

@ -21,6 +21,7 @@ import os
import json import json
import sqlite3 import sqlite3
import pprint import pprint
import argparse
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from builtins import str, bytes from builtins import str, bytes
@ -53,6 +54,22 @@ from matrix.utils import sanitize_id
from matrix.utf import utf8_decode from matrix.utf import utf8_decode
class ParseError(Exception):
pass
class WeechatArgParse(argparse.ArgumentParser):
def print_usage(self, file):
pass
def error(self, message):
m = ("{prefix}Error: {message} for command {command} "
"(see /help {command})").format(prefix=W.prefix("error"),
message=message, command=self.prog)
W.prnt("", m)
raise ParseError
def own_buffer_or_error(f): def own_buffer_or_error(f):
@wraps(f) @wraps(f)
@ -74,7 +91,6 @@ def own_buffer_or_error(f):
def encrypt_enabled(f): def encrypt_enabled(f):
@wraps(f) @wraps(f)
def wrapper(*args, **kwds): def wrapper(*args, **kwds):
if matrix.globals.ENCRYPTION: if matrix.globals.ENCRYPTION:
@ -92,9 +108,9 @@ def matrix_hook_olm_command():
"Matrix olm encryption command", "Matrix olm encryption command",
# Synopsis # Synopsis
("info all|blacklisted|private|unverified|verified <filter>||" ("info all|blacklisted|private|unverified|verified <filter>||"
"blacklist <device-id> ||" "blacklist <user-id> <device-id> ||"
"unverify <device-id> ||" "unverify <user-id> <device-id> ||"
"verify <device-id>"), "verify <user-id> <device-id>"),
# Description # Description
(" info: show info about known devices and their keys\n" (" info: show info about known devices and their keys\n"
"blacklist: blacklist a device\n" "blacklist: blacklist a device\n"
@ -104,21 +120,38 @@ def matrix_hook_olm_command():
# Completions # Completions
('info all|blacklisted|private|unverified|verified ||' ('info all|blacklisted|private|unverified|verified ||'
'blacklist %(device_ids) ||' 'blacklist %(device_ids) ||'
'unverify %(device_ids) ||' 'unverify %(user_ids) %(device_ids) ||'
'verify %(device_ids)'), 'verify %(user_ids) %(device_ids)'),
# Function name # Function name
'matrix_olm_command_cb', 'matrix_olm_command_cb',
'') '')
def olm_cmd_parse_args(args): def olm_cmd_parse_args(args):
split_args = args.split() parser = WeechatArgParse(prog="olm")
subparsers = parser.add_subparsers(dest="subcommand")
command = split_args.pop(0) if split_args else "info" info_parser = subparsers.add_parser("info")
info_parser.add_argument(
"category", nargs="?", default="private",
choices=[
"all",
"blacklisted",
"private",
"unverified",
"verified"
])
info_parser.add_argument("filter", nargs="?")
rest_args = split_args if split_args else [] verify_parser = subparsers.add_parser("verify")
verify_parser.add_argument("user_filter")
verify_parser.add_argument("device_filter", nargs="?")
return command, rest_args try:
parsed_args = parser.parse_args(args.split())
return parsed_args
except ParseError:
return None
def grouper(iterable, n, fillvalue=None): def grouper(iterable, n, fillvalue=None):
@ -133,16 +166,10 @@ def partition_key(key):
return ' '.join(''.join(g) for g in groups) return ' '.join(''.join(g) for g in groups)
@own_buffer_or_error def olm_info_command(server, args):
@utf8_decode
def matrix_olm_command_cb(server_name, buffer, args):
server = SERVERS[server_name]
command, args = olm_cmd_parse_args(args)
if not command or command == "info":
olm = server.olm olm = server.olm
if not args or args[0] == "private": if args.category == "private":
device_msg = (" - Device ID: {}\n".format(server.device_id) device_msg = (" - Device ID: {}\n".format(server.device_id)
if server.device_id else "") if server.device_id else "")
id_key = partition_key(olm.account.identity_keys()["curve25519"]) id_key = partition_key(olm.account.identity_keys()["curve25519"])
@ -158,7 +185,7 @@ def matrix_olm_command_cb(server_name, buffer, args):
id_key=id_key, id_key=id_key,
fp_key=fp_key) fp_key=fp_key)
W.prnt(server.server_buffer, message) W.prnt(server.server_buffer, message)
elif args[0] == "all": elif args.category == "all":
for user, keys in olm.device_keys.items(): for user, keys in olm.device_keys.items():
message = ("{prefix}matrix: Identity keys:\n" message = ("{prefix}matrix: Identity keys:\n"
" - User: {user}\n").format( " - User: {user}\n").format(
@ -178,6 +205,19 @@ def matrix_olm_command_cb(server_name, buffer, args):
id_key=id_key, id_key=id_key,
fp_key=fp_key) fp_key=fp_key)
W.prnt(server.server_buffer, message) W.prnt(server.server_buffer, message)
@own_buffer_or_error
@utf8_decode
def matrix_olm_command_cb(server_name, buffer, args):
server = SERVERS[server_name]
parsed_args = olm_cmd_parse_args(args)
if not parsed_args:
return W.WEECHAT_RC_OK
if not parsed_args.subcommand or parsed_args.subcommand == "info":
olm_info_command(server, parsed_args)
else: else:
message = ("{prefix}matrix: Command not implemented.".format( message = ("{prefix}matrix: Command not implemented.".format(
prefix=W.prefix("error"))) prefix=W.prefix("error")))
@ -192,6 +232,7 @@ class EncryptionError(Exception):
class DeviceStore(object): class DeviceStore(object):
def __init__(self, filename): def __init__(self, filename):
# type: (str) -> None
self._entries = [] self._entries = []
self._filename = filename self._filename = filename