matrix: mypy fixes.

This commit is contained in:
Damir Jelić 2018-08-29 20:57:12 +02:00
parent 05a413f7cb
commit 67141c980a
6 changed files with 82 additions and 67 deletions

View file

@ -20,7 +20,7 @@ from __future__ import unicode_literals
import time import time
from builtins import super from builtins import super
from functools import partial from functools import partial
from typing import NamedTuple from typing import Dict, List, NamedTuple, Set, Optional
from nio import ( from nio import (
Api, Api,
@ -46,8 +46,8 @@ from .globals import SCRIPT_NAME, SERVERS, W
from .utf import utf8_decode from .utf import utf8_decode
from .utils import server_ts_to_weechat, shorten_sender, string_strikethrough from .utils import server_ts_to_weechat, shorten_sender, string_strikethrough
OwnMessage = NamedTuple( OwnMessages = NamedTuple(
"OwnMessage", "OwnMessages",
[ [
("sender", str), ("sender", str),
("age", int), ("age", int),
@ -58,6 +58,10 @@ OwnMessage = NamedTuple(
) )
class OwnMessage(OwnMessages):
pass
class OwnAction(OwnMessage): class OwnAction(OwnMessage):
pass pass
@ -100,7 +104,7 @@ class WeechatUser(object):
self.prefix = prefix self.prefix = prefix
self.color = W.info_get("nick_color_name", nick) self.color = W.info_get("nick_color_name", nick)
self.join_time = join_time or time.time() self.join_time = join_time or time.time()
self.speaking_time = None # type: int self.speaking_time = None # type: Optional[int]
def update_speaking_time(self, new_time=None): def update_speaking_time(self, new_time=None):
self.speaking_time = new_time or time.time() self.speaking_time = new_time or time.time()
@ -434,7 +438,7 @@ class WeechatChannelBuffer(object):
return color return color
def _message_tags(self, user, message_type): def _message_tags(self, user, message_type):
# type: (str, RoomUser, str) -> List[str] # type: (WeechatUser, str) -> List[str]
tags = list(self.tags[message_type]) tags = list(self.tags[message_type])
tags.append("nick_{nick}".format(nick=user.nick)) tags.append("nick_{nick}".format(nick=user.nick))
@ -476,7 +480,7 @@ class WeechatChannelBuffer(object):
self.print_date_tags(data, date, tags) self.print_date_tags(data, date, tags)
def message(self, nick, message, date, extra_tags=None): def message(self, nick, message, date, extra_tags=None):
# type: (str, str, int, str) -> None # type: (str, str, int, List[str]) -> None
user = self._get_user(nick) user = self._get_user(nick)
tags = self._message_tags(user, "message") + (extra_tags or []) tags = self._message_tags(user, "message") + (extra_tags or [])
self._print_message(user, message, date, tags) self._print_message(user, message, date, tags)
@ -636,16 +640,16 @@ class WeechatChannelBuffer(object):
if message: if message:
tags = self._message_tags(user, "join") tags = self._message_tags(user, "join")
message = self._membership_message(user, "join") msg = self._membership_message(user, "join")
# TODO add a option to disable smart filters # TODO add a option to disable smart filters
tags.append(SCRIPT_NAME + "_smart_filter") tags.append(SCRIPT_NAME + "_smart_filter")
self.print_date_tags(message, date, tags) self.print_date_tags(msg, date, tags)
self.add_smart_filtered_nick(user.nick) self.add_smart_filtered_nick(user.nick)
def invite(self, nick, date, extra_tags=None): def invite(self, nick, date, extra_tags=None):
# type: (str, int, Optional[bool], Optional[List[str]]) -> None # type: (str, int, Optional[List[str]]) -> None
user = self._get_user(nick) user = self._get_user(nick)
tags = self._message_tags(user, "invite") tags = self._message_tags(user, "invite")
message = self._membership_message(user, "invite") message = self._membership_message(user, "invite")
@ -670,19 +674,19 @@ class WeechatChannelBuffer(object):
if not user.spoken_recently: if not user.spoken_recently:
tags.append(SCRIPT_NAME + "_smart_filter") tags.append(SCRIPT_NAME + "_smart_filter")
message = self._membership_message(user, leave_type) msg = self._membership_message(user, leave_type)
self.print_date_tags(message, date, tags + (extra_tags or [])) self.print_date_tags(msg, date, tags + (extra_tags or []))
self.remove_smart_filtered_nick(user.nick) self.remove_smart_filtered_nick(user.nick)
if user.nick in self.users: if user.nick in self.users:
del self.users[user.nick] del self.users[user.nick]
def part(self, nick, date, message=True, extra_tags=None): def part(self, nick, date, message=True, extra_tags=None):
# type: (str, int, Optional[bool], Optional[List[str]]) -> None # type: (str, int, bool, Optional[List[str]]) -> None
self._leave(nick, date, message, "part", extra_tags) self._leave(nick, date, message, "part", extra_tags)
def kick(self, nick, date, message=True, extra_tags=None): def kick(self, nick, date, message=True, extra_tags=None):
# type: (str, int, Optional[bool], Optional[List[str]]) -> None # type: (str, int, bool, Optional[List[str]]) -> None
self._leave(nick, date, message, "kick", extra_tags) self._leave(nick, date, message, "kick", extra_tags)
def _print_topic(self, nick, topic, date): def _print_topic(self, nick, topic, date):

View file

@ -25,6 +25,7 @@ import textwrap
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from builtins import str from builtins import str
from collections import namedtuple from collections import namedtuple
from typing import List
import webcolors import webcolors
from pygments import highlight from pygments import highlight
@ -45,7 +46,7 @@ except ImportError:
FormattedString = namedtuple("FormattedString", ["text", "attributes"]) FormattedString = namedtuple("FormattedString", ["text", "attributes"])
class Formatted: class Formatted(object):
def __init__(self, substrings): def __init__(self, substrings):
# type: (List[FormattedString]) -> None # type: (List[FormattedString]) -> None
self.substrings = substrings self.substrings = substrings
@ -247,7 +248,7 @@ class Formatted:
# TODO do we want at least some formatting using unicode # TODO do we want at least some formatting using unicode
# (strikethrough, quotes)? # (strikethrough, quotes)?
def to_plain(self): def to_plain(self):
# type: (List[FormattedString]) -> str # type: () -> str
def strip_atribute(string, _, __): def strip_atribute(string, _, __):
return string return string
@ -688,7 +689,7 @@ def color_html_to_weechat(color):
try: try:
rgb_color = webcolors.html5_parse_legacy_color(color) rgb_color = webcolors.html5_parse_legacy_color(color)
except ValueError: except ValueError:
return None return ""
if rgb_color in weechat_basic_colors: if rgb_color in weechat_basic_colors:
return weechat_basic_colors[rgb_color] return weechat_basic_colors[rgb_color]

View file

@ -478,7 +478,7 @@ def matrix_kick_command_cb(data, buffer, args):
def event_id_from_line(buf, target_number): def event_id_from_line(buf, target_number):
# type: (weechat.buffer, int) -> str # type: (str, int) -> str
own_lines = W.hdata_pointer(W.hdata_get("buffer"), buf, "own_lines") own_lines = W.hdata_pointer(W.hdata_get("buffer"), buf, "own_lines")
if own_lines: if own_lines:
line = W.hdata_pointer(W.hdata_get("lines"), own_lines, "last_line") line = W.hdata_pointer(W.hdata_get("lines"), own_lines, "last_line")

View file

@ -20,8 +20,11 @@ import sys
from .utf import WeechatWrapper from .utf import WeechatWrapper
from typing import Dict, Optional
if False: if False:
from typing import Dict from .server import MatrixServer
from .config import MatrixConfig
try: try:
@ -29,11 +32,11 @@ try:
W = weechat if sys.hexversion >= 0x3000000 else WeechatWrapper(weechat) W = weechat if sys.hexversion >= 0x3000000 else WeechatWrapper(weechat)
except ImportError: except ImportError:
import matrix._weechat as weechat import matrix._weechat as weechat # type: ignore
W = weechat W = weechat
SERVERS = dict() # type: Dict[str, MatrixServer] SERVERS = dict() # type: Dict[str, MatrixServer]
CONFIG = None # type: MatrixConfig CONFIG = None # type: Optional[MatrixConfig]
ENCRYPTION = True # type: bool ENCRYPTION = True # type: bool
SCRIPT_NAME = "matrix" # type: str SCRIPT_NAME = "matrix" # type: str

View file

@ -22,6 +22,10 @@ import socket
import ssl import ssl
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import Any, Deque, Dict, Optional
if False:
from .colors import Formatted
from nio import ( from nio import (
HttpClient, HttpClient,
@ -31,6 +35,8 @@ from nio import (
SyncRepsponse, SyncRepsponse,
TransportResponse, TransportResponse,
TransportType, TransportType,
Rooms,
Response
) )
from . import globals as G from . import globals as G
@ -41,7 +47,7 @@ from .utf import utf8_decode
from .utils import create_server_buffer, key_from_value, server_buffer_prnt from .utils import create_server_buffer, key_from_value, server_buffer_prnt
try: try:
FileNotFoundError FileNotFoundError # type: ignore
except NameError: except NameError:
FileNotFoundError = IOError FileNotFoundError = IOError
@ -51,7 +57,7 @@ class ServerConfig(ConfigSection):
# type: (str, str) -> None # type: (str, str) -> None
self._server_name = server_name self._server_name = server_name
self._config_ptr = config_ptr self._config_ptr = config_ptr
self._option_ptrs = {} self._option_ptrs = {} # type: Dict[str, str]
options = [ options = [
Option( Option(
@ -167,32 +173,29 @@ class ServerConfig(ConfigSection):
class MatrixServer(object): class MatrixServer(object):
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
def __init__(self, name, config_file): def __init__(self, name, config_ptr):
# type: (str, weechat.config) -> None # type: (str, str) -> None
# yapf: disable # yapf: disable
self.name = name # type: str self.name = name # type: str
self.user_id = "" self.user_id = ""
self.device_id = "" # type: str self.device_id = "" # type: str
self.olm = None # type: Olm self.room_buffers = dict() # type: Dict[str, RoomBuffer]
self.encryption_queue = defaultdict(deque) self.buffers = dict() # type: Dict[str, str]
self.server_buffer = None # type: Optional[str]
self.room_buffers = dict() # type: Dict[str, WeechatChannelBuffer] self.fd_hook = None # type: Optional[str]
self.buffers = dict() # type: Dict[str, weechat.buffer] self.ssl_hook = None # type: Optional[str]
self.server_buffer = None # type: weechat.buffer self.timer_hook = None # type: Optional[str]
self.fd_hook = None # type: weechat.hook self.numeric_address = "" # type: Optional[str]
self.ssl_hook = None # type: weechat.hook
self.timer_hook = None # type: weechat.hook
self.numeric_address = "" # type: str
self.connected = False # type: bool self.connected = False # type: bool
self.connecting = False # type: bool self.connecting = False # type: bool
self.reconnect_delay = 0 # type: int self.reconnect_delay = 0 # type: int
self.reconnect_time = None # type: float self.reconnect_time = None # type: Optional[float]
self.sync_time = None # type: Optional[float] self.sync_time = None # type: Optional[float]
self.socket = None # type: ssl.SSLSocket self.socket = None # type: Optional[ssl.SSLSocket]
self.ssl_context = ssl.create_default_context() # type: ssl.SSLContext self.ssl_context = ssl.create_default_context() # type: ssl.SSLContext
self.transport_type = None # type: Optional[nio.TransportType] self.transport_type = None # type: Optional[TransportType]
# Enable http2 negotiation on the ssl context. # Enable http2 negotiation on the ssl context.
self.ssl_context.set_alpn_protocols(["h2", "http/1.1"]) self.ssl_context.set_alpn_protocols(["h2", "http/1.1"])
@ -203,24 +206,19 @@ class MatrixServer(object):
pass pass
self.client = None self.client = None
self.access_token = None # type: str self.access_token = None # type: Optional[str]
self.next_batch = None # type: str self.next_batch = None # type: Optional[str]
self.transaction_id = 0 # type: int self.transaction_id = 0 # type: int
self.lag = 0 # type: int self.lag = 0 # type: int
self.lag_done = False # type: bool self.lag_done = False # type: bool
self.send_fd_hook = None # type: weechat.hook self.send_fd_hook = None # type: Optional[str]
self.send_buffer = b"" # type: bytes self.send_buffer = b"" # type: bytes
self.device_check_timestamp = None self.device_check_timestamp = None # type: Optional[int]
self.send_queue = deque() self.own_message_queue = dict() # type: Dict[str, OwnMessage]
self.own_message_queue = dict() # type: Dict[OwnMessage]
self.event_queue_timer = None self.config = ServerConfig(self.name, config_ptr)
self.event_queue = deque() # type: Deque[RoomInfo]
# self._create_options(config_file)
self.config = ServerConfig(self.name, config_file)
self._create_session_dir() self._create_session_dir()
# yapf: enable # yapf: enable
@ -287,13 +285,16 @@ class MatrixServer(object):
def send_or_queue(self, request): def send_or_queue(self, request):
# type: (bytes) -> None # type: (bytes) -> None
if not self.send(request): self.send(request)
self.send_queue.append(request)
def try_send(self, message): def try_send(self, message):
# type: (MatrixServer, bytes) -> bool # type: (MatrixServer, bytes) -> bool
sock = self.socket sock = self.socket
if not sock:
return False
total_sent = 0 total_sent = 0
message_length = len(message) message_length = len(message)
@ -353,7 +354,7 @@ class MatrixServer(object):
return True return True
def _abort_send(self): def _abort_send(self):
self.send_buffer = "" self.send_buffer = b""
def _finalize_send(self): def _finalize_send(self):
# type: (MatrixServer) -> None # type: (MatrixServer) -> None
@ -433,6 +434,7 @@ class MatrixServer(object):
self.send_buffer = b"" self.send_buffer = b""
self.transport_type = None self.transport_type = None
if self.client:
try: try:
self.client.disconnect() self.client.disconnect()
except LocalProtocolError: except LocalProtocolError:
@ -511,12 +513,18 @@ class MatrixServer(object):
def sync(self, timeout=None, sync_filter=None): def sync(self, timeout=None, sync_filter=None):
# type: (Optional[int], Optional[Dict[Any, Any]]) -> None # type: (Optional[int], Optional[Dict[Any, Any]]) -> None
if not self.client:
return
self.sync_time = None self.sync_time = None
_, request = self.client.sync(timeout, sync_filter) _, request = self.client.sync(timeout, sync_filter)
self.send_or_queue(request) self.send_or_queue(request)
def login(self): def login(self):
# type: () -> None # type: () -> None
if not self.client:
return
if self.client.logged_in: if self.client.logged_in:
msg = ( msg = (
"{prefix}{script_name}: Already logged in, " "syncing..." "{prefix}{script_name}: Already logged in, " "syncing..."
@ -576,6 +584,9 @@ class MatrixServer(object):
if room_buffer.room.encrypted: if room_buffer.room.encrypted:
return return
if not self.client:
return
if msgtype == "m.emote": if msgtype == "m.emote":
message_class = OwnAction message_class = OwnAction
else: else:
@ -720,7 +731,7 @@ class MatrixServer(object):
self.disconnect() self.disconnect()
def handle_response(self, response): def handle_response(self, response):
# type: (MatrixMessage) -> None # type: (Response) -> None
self.lag = response.elapsed * 1000 self.lag = response.elapsed * 1000
# If the response was a sync response and contained a timeout the # If the response was a sync response and contained a timeout the
@ -836,7 +847,7 @@ def matrix_config_server_write_cb(data, config_file, section_name):
@utf8_decode @utf8_decode
def matrix_config_server_change_cb(server_name, option): def matrix_config_server_change_cb(server_name, option):
# type: (str, weechat.config_option) -> int # type: (str, str) -> int
server = SERVERS[server_name] server = SERVERS[server_name]
option_name = None option_name = None
@ -881,14 +892,6 @@ def matrix_timer_cb(server_name, remaining_calls):
sync_filter = {"room": {"timeline": {"limit": 5000}}} sync_filter = {"room": {"timeline": {"limit": 5000}}}
server.sync(timeout, sync_filter) server.sync(timeout, sync_filter)
while server.send_queue:
message = server.send_queue.popleft()
if not server.send(message):
# We got an error while sending the last message return the message
# to the queue and exit the loop
server.send_queue.appendleft(message)
break
if not server.next_batch: if not server.next_batch:
return W.WEECHAT_RC_OK return W.WEECHAT_RC_OK
@ -926,6 +929,6 @@ def send_cb(server_name, file_descriptor):
server.send_fd_hook = None server.send_fd_hook = None
if server.send_buffer: if server.send_buffer:
server.try_send(server, server.send_buffer) server.try_send(server.send_buffer)
return W.WEECHAT_RC_OK return W.WEECHAT_RC_OK

View file

@ -17,6 +17,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import time import time
from typing import Any, Dict, List
if False:
from .server import MatrixServer
from .globals import W from .globals import W
@ -35,7 +39,7 @@ def server_buffer_prnt(server, string):
def tags_from_line_data(line_data): def tags_from_line_data(line_data):
# type: (weechat.hdata) -> List[str] # type: (str) -> List[str]
tags_count = W.hdata_get_var_array_size( tags_count = W.hdata_get_var_array_size(
W.hdata_get("line_data"), line_data, "tags_array" W.hdata_get("line_data"), line_data, "tags_array"
) )