Source code for nio.client.async_client

# -*- coding: utf-8 -*-

# Copyright © 2018, 2019 Damir Jelić <poljar@termina.org.uk>
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
# above copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
# SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF
# CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import asyncio
import warnings
from asyncio import Event
from functools import partial, wraps
from json.decoder import JSONDecodeError
from typing import (Any, AsyncIterable, BinaryIO, Coroutine, Dict, Iterable,
                    List, Optional, Tuple, Type, Union)
from uuid import uuid4

import attr
from aiohttp import ClientResponse, ClientSession, ContentTypeError
from aiohttp.client_exceptions import ClientConnectionError

from . import Client, ClientConfig
from .base_client import logged_in, store_loaded
from ..api import Api, MessageDirection, ResizingMethod
from ..exceptions import (GroupEncryptionError, LocalProtocolError,
                          MembersSyncError, SendRetryError)
from ..events import RoomKeyRequest, RoomKeyRequestCancellation
from ..messages import ToDeviceMessage
from ..responses import (ErrorResponse, FileResponse,
                         JoinResponse, JoinError,
                         JoinedMembersError, JoinedMembersResponse,
                         KeysClaimError, KeysClaimResponse, KeysQueryResponse,
                         KeysUploadResponse, LoginError, LoginResponse,
                         LogoutError, LogoutResponse,
                         ProfileGetAvatarResponse, ProfileGetAvatarError,
                         ProfileGetDisplayNameResponse,
                         ProfileGetDisplayNameError, ProfileGetResponse,
                         ProfileGetError, ProfileSetAvatarResponse,
                         ProfileSetAvatarError, ProfileSetDisplayNameResponse,
                         ProfileSetDisplayNameError, Response,
                         RoomContextError, RoomContextResponse,
                         RoomForgetResponse, RoomForgetError,
                         RoomKeyRequestError, RoomKeyRequestResponse,
                         RoomLeaveResponse, RoomLeaveError,
                         RoomMessagesError, RoomMessagesResponse,
                         RoomSendResponse, RoomTypingResponse, RoomTypingError,
                         ShareGroupSessionError,
                         ShareGroupSessionResponse, SyncError, SyncResponse,
                         PartialSyncResponse,
                         ThumbnailError, ThumbnailResponse,
                         ToDeviceError, ToDeviceResponse,
                         UploadError, UploadResponse)

if False:
    from ..events import MegolmEvent
    from .crypto import OlmDevice

_UploadDataT = Union[bytes, BinaryIO, AsyncIterable[bytes]]

_ShareGroupSessionT = Union[ShareGroupSessionError, ShareGroupSessionResponse]

_ProfileGetDisplayNameT = Union[
    ProfileGetDisplayNameResponse,
    ProfileGetDisplayNameError
]
_ProfileSetDisplayNameT = Union[
    ProfileSetDisplayNameResponse,
    ProfileSetDisplayNameError
]


@attr.s
class ResponseCb(object):
    """Response callback."""

    func = attr.ib()
    filter = attr.ib(default=None)


def client_session(func):
    """Ensure that the Async client has a valid client session."""
    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        if not self.client_session:
            self.client_session = ClientSession()
        return await func(self, *args, **kwargs)
    return wrapper


@attr.s(frozen=True)
class AsyncClientConfig(ClientConfig):
    """Async nio client configuration.

    Attributes:
        max_limit_exceeded (int, optional): How many 429 (Too many requests)
            errors can a request encounter before giving up and returning
            an ErrorResponse.
            Default is None for unlimited.

        max_timeouts (int, optional): How many timeout connection errors can
            a request encounter before giving up and raising the error:
            a ClientConnectionError, TimeoutError, or asyncio.TimeoutError.
            Default is None for unlimited.

        backoff_factor (float): A backoff factor to apply between retries
            for timeouts, starting from the second try.
            nio will sleep for `backoff_factor * (2 ** (total_retries - 1))`
            seconds.
            For example, with the default backoff_factor of 0.1,
            nio will sleep for 0.0, 0.2, 0.4, ... seconds between retries.

        max_timeout_retry_wait_time (float): The maximum time to wait between
            retries for timeouts, by default 60.
    """

    max_limit_exceeded = attr.ib(type=Optional[int], default=None)
    max_timeouts = attr.ib(type=Optional[int], default=None)
    backoff_factor = attr.ib(type=float, default=0.1)
    max_timeout_retry_wait_time = attr.ib(type=float, default=60)


[docs]class AsyncClient(Client): """An async IO matrix client. Args: homeserver (str): The URL of the homeserver which we want to connect to. user (str, optional): The user which will be used when we log in to the homeserver. device_id (str, optional): An unique identifier that distinguishes this client instance. If not set the server will provide one after log in. store_path (str, optional): The directory that should be used for state storage. config (AsyncClientConfig, optional): Configuration for the client. ssl (bool/ssl.SSLContext, optional): SSL validation mode. None for default SSL check (ssl.create_default_context() is used), False for skip SSL certificate validation connection. proxy (str, optional): The proxy that should be used for the HTTP connection. Attributes: synced (Event): An asyncio event that is fired every time the client successfully syncs with the server. A simple example can be found bellow. Example: >>> client = AsyncClient("https://example.org", "example") >>> login_response = loop.run_until_complete( >>> client.login("hunter1") >>> ) >>> asyncio.run(client.sync_forever(30000)) This example assumes a full sync on every run. If a sync token is provided for the `since` parameter of the `sync_forever` method `full_state` should be set to `True` as well. Example: >>> asyncio.run( >>> client.sync_forever(30000, since="token123", >>> full_state=True) >>> ) The client can also be configured to store and restore the sync token automatically. The `full_state` argument should be set to `True` in that case as well. Example: >>> config = ClientConfig(store_sync_tokens=True) >>> client = AsyncClient("https://example.org", "example", >>> store_path="/home/example", >>> config=config) >>> login_response = loop.run_until_complete( >>> client.login("hunter1") >>> ) >>> asyncio.run(client.sync_forever(30000, full_state=True)) """ def __init__( self, homeserver, # type: str user="", # type: str device_id="", # type: Optional[str] store_path="", # type: Optional[str] config=None, # type: Optional[AsyncClientConfig] ssl=None, # type: Optional[bool] proxy=None, # type: Optional[str] ): # type: (...) -> None self.homeserver = homeserver self.client_session = None # type: Optional[ClientSession] self.ssl = ssl self.proxy = proxy self.synced = Event() self.response_callbacks = [] # type: List[ResponseCb] self.sharing_session = dict() # type: Dict[str, Event] if isinstance(config, ClientConfig): warnings.warn( "Pass an AsyncClientConfig instead of ClientConfig.", DeprecationWarning ) config = AsyncClientConfig(**config.__dict__) self.config = config or AsyncClientConfig() # type: AsyncClientConfig super().__init__(user, device_id, store_path, self.config)
[docs] def add_response_callback( self, func, # type: Coroutine[Any, Any, Response] cb_filter=None # type: Union[Tuple[Type], Type, None] ): # type: (...) -> None """Add a coroutine that will be called if a response is received. Args: func (Coroutine): The coroutine that will be called with the response as the argument. cb_filter (Type, optional): A type or a tuple of types for which the callback should be called. Example: >>> # A callback that will be called every time our `sync_forever` >>> # method succesfully syncs with the server. >>> async def sync_cb(response): ... print(f"We synced, token: {response.next_batch}") ... >>> client.add_response_callback(sync_cb, SyncResponse) >>> await client.sync_forever(30000) """ cb = ResponseCb(func, cb_filter) self.response_callbacks.append(cb)
[docs] async def parse_body(self, transport_response): # type: (ClientResponse) -> Dict[Any, Any] """Parse the body of the response. Args: transport_response(ClientResponse): The transport response that contains the body of the response. Returns a dictionary representing the response. """ try: return await transport_response.json() except (JSONDecodeError, ContentTypeError): return {}
[docs] async def create_matrix_response( self, response_class, transport_response, data=None ): # type: (Type, ClientResponse, Tuple) -> Response """Transform a transport response into a nio matrix response. Args: response_class (Type): The class that the requests belongs to. transport_response (ClientResponse): The underlying transport response that contains our response body. data (Tuple, optional): Extra data that is required to instantiate the response class. Returns a subclass of `Response` depending on the type of the response_class argument. """ data = data or () content_type = transport_response.content_type is_json = content_type == "application/json" if issubclass(response_class, FileResponse) and is_json: parsed_dict = await self.parse_body(transport_response) resp = response_class.from_data(parsed_dict, content_type) elif issubclass(response_class, FileResponse): body = await transport_response.read() resp = response_class.from_data(body, content_type) else: parsed_dict = await self.parse_body(transport_response) resp = response_class.from_dict(parsed_dict, *data) resp.transport_response = transport_response return resp
async def _run_to_device_callbacks(self, event): for cb in self.to_device_callbacks: if (cb.filter is None or isinstance(event, cb.filter)): await asyncio.coroutine(cb.func)(event) async def _handle_to_device(self, response): decrypted_to_device = [] # type: ignore for index, to_device_event in enumerate(response.to_device_events): decrypted_event = self._handle_decrypt_to_device(to_device_event) if decrypted_event: decrypted_to_device.append((index, decrypted_event)) to_device_event = decrypted_event # Do not pass room key request events to our user here. We don't # want to notify them about requests that get automatically handled # or canceled right away. if isinstance( to_device_event, (RoomKeyRequest, RoomKeyRequestCancellation) ): continue await self._run_to_device_callbacks(to_device_event) self._replace_decrypted_to_device(decrypted_to_device, response) async def _handle_invited_rooms(self, response): for room_id, info in response.rooms.invite.items(): room = self._get_invited_room(room_id) for event in info.invite_state: room.handle_event(event) for cb in self.event_callbacks: if (cb.filter is None or isinstance(event, cb.filter)): await asyncio.coroutine(cb.func)(room, event) async def _handle_joined_rooms(self, response): encrypted_rooms = set() for room_id, join_info in response.rooms.join.items(): self._handle_joined_state(room_id, join_info, encrypted_rooms) room = self.rooms[room_id] decrypted_events = [] for index, event in enumerate(join_info.timeline.events): decrypted_event = self._handle_timeline_event( event, room_id, room, encrypted_rooms ) if decrypted_event: event = decrypted_event decrypted_events.append((index, decrypted_event)) for cb in self.event_callbacks: if (cb.filter is None or isinstance(event, cb.filter)): await asyncio.coroutine(cb.func)(room, event) # Replace the Megolm events with decrypted ones for decrypted_event in decrypted_events: index, event = decrypted_event join_info.timeline.events[index] = event for event in join_info.ephemeral: room.handle_ephemeral_event(event) for cb in self.ephemeral_callbacks: if (cb.filter is None or isinstance(event, cb.filter)): await asyncio.coroutine(cb.func)(room, event) if room.encrypted and self.olm is not None: self.olm.update_tracked_users(room) self.encrypted_rooms.update(encrypted_rooms) if self.store: self.store.save_encrypted_rooms(encrypted_rooms) async def _handle_expired_verifications(self): expired_verifications = self.olm.clear_verifications() for event in expired_verifications: for cb in self.to_device_callbacks: if (cb.filter is None or isinstance(event, cb.filter)): await asyncio.coroutine(cb.func)(event) async def _handle_sync(self, response): # We already recieved such a sync response, do nothing in that case. if self.next_batch == response.next_batch: return if isinstance(response, SyncResponse): self.next_batch = response.next_batch await self._handle_to_device(response) await self._handle_invited_rooms(response) await self._handle_joined_rooms(response) if self.olm: await self._handle_expired_verifications() self._handle_olm_events(response) await self._collect_key_requests() async def _collect_key_requests(self): events = self.olm.collect_key_requests() for event in events: await self._run_to_device_callbacks(event)
[docs] async def receive_response(self, response): """Receive a Matrix Response and change the client state accordingly. Some responses will get edited for the callers convenience e.g. sync responses that contain encrypted messages. The encrypted messages will be replaced by decrypted ones if decryption is possible. Args: response (Response): the response that we wish the client to handle """ if not isinstance(response, Response): raise ValueError("Invalid response received") if isinstance(response, (SyncResponse, PartialSyncResponse)): await self._handle_sync(response) else: super().receive_response(response)
[docs] async def get_timeout_retry_wait_time(self, got_timeouts): # type: (int) -> float if got_timeouts < 2: return 0.0 return min( self.config.backoff_factor * (2 ** (got_timeouts - 1)), self.config.max_timeout_retry_wait_time )
async def _send( self, response_class, method, path, data=None, response_data=None, content_type=None, ): headers = {"content-type": content_type} if content_type else {} got_429 = 0 max_429 = self.config.max_limit_exceeded got_timeouts = 0 max_timeouts = self.config.max_timeouts while True: try: transport_resp = await self.send(method, path, data, headers) resp = await self.create_matrix_response( response_class, transport_resp, response_data ) if isinstance(resp, ErrorResponse) and resp.retry_after_ms: got_429 += 1 if max_429 is not None and got_429 > max_429: break await self.run_response_callbacks([resp]) await asyncio.sleep(resp.retry_after_ms / 1000) else: break except (ClientConnectionError, TimeoutError, asyncio.TimeoutError): got_timeouts += 1 if max_timeouts is not None and got_timeouts > max_timeouts: raise wait = await self.get_timeout_retry_wait_time(got_timeouts) await asyncio.sleep(wait) await self.receive_response(resp) return resp
[docs] @client_session async def send( self, method, # type: str path, # type: str data=None, # type: Union[None, str, _UploadDataT] headers=None # type: Optional[Dict[str, str]] ): # type: (...) -> ClientResponse """Send a request to the homeserver. Args: method (str): The request method that should be used. One of get, post, put, delete. path (str): The URL path of the request. data (str, optional): Data that will be posted with the request. headers (Dict[str,str] , optional): Additional request headers that should be used with the request. """ assert self.client_session return await self.client_session.request( method, self.homeserver + path, data=data, ssl=self.ssl, proxy=self.proxy, headers=headers )
[docs] async def login(self, password, device_name=""): # type: (str, str) -> Union[LoginResponse, LoginError] """Login to the homeserver. Args: password (str): The user's password. device_name (str): A display name to assign to a newly-created device. Ignored if the logged in device corresponds to a known device. Returns either a `LoginResponse` if the request was successful or a `LoginError` if there was an error with the request. """ method, path, data = Api.login( self.user, password=password, device_name=device_name, device_id=self.device_id ) return await self._send(LoginResponse, method, path, data)
[docs] @logged_in async def logout(self): """Logout from the homeserver. Returns either 'LogoutResponse' if the request was successful or a `Logouterror` if there was an error with the request. """ method, path, data = Api.logout( self.access_token ) return await self._send(LogoutResponse, method, path, data)
[docs] @logged_in async def sync( self, timeout=None, # type: Optional[int] sync_filter=None, # type: Optional[Dict[Any, Any]] since=None, # type: Optional[str] full_state=None # type: Optional[bool] ): # type: (...) -> Union[SyncResponse, SyncError] """Synchronise the client's state with the latest state on the server. Args: timeout(int, optional): The maximum time that the server should wait for new events before it should return the request anyways, in milliseconds. sync_filter (Dict[Any, Any], optional): A filter that should be used for this sync request. full_state(bool, optional): Controls whether to include the full state for all rooms the user is a member of. If this is set to true, then all state events will be returned, even if since is non-empty. The timeline will still be limited by the since parameter. since(str, optional): A token specifying a point in time where to continue the sync from. Defaults to the last sync token we received from the server using this API call. Returns either a `SyncResponse` if the request was successful or a `SyncError` if there was an error with the request. """ sync_token = since or self.next_batch method, path = Api.sync( self.access_token, since=sync_token or self.loaded_sync_token, timeout=timeout, filter=sync_filter, full_state=full_state ) response = await self._send(SyncResponse, method, path) self.synced.set() self.synced.clear() return response
[docs] @logged_in async def send_to_device_messages(self): # type: () -> List[ToDeviceResponse] """Send out outgoing to-device messages.""" if not self.outgoing_to_device_messages: return [] tasks = [] for message in self.outgoing_to_device_messages: task = asyncio.ensure_future(self.to_device(message)) tasks.append(task) return await asyncio.gather(*tasks)
[docs] async def run_response_callbacks(self, responses): """Run the configured response callbacks for the given responses.""" for response in responses: for cb in self.response_callbacks: if (cb.filter is None or isinstance(response, cb.filter)): await asyncio.coroutine(cb.func)(response)
[docs] @logged_in async def sync_forever( self, timeout=None, # type: Optional[int] sync_filter=None, # type: Optional[Dict[Any, Any]] since=None, # type: Optional[str] full_state=None, # type: Optional[bool] loop_sleep_time=None # type: Optional[int] ): # type: (...) -> None """Continuously sync with the configured homeserver. This method calls the sync method in a loop. To react to events event callbacks should be configured. The loop also makes sure to handle other required requests between syncs. To react to the responses a response callback should be added. Args: timeout (int, optional): The maximum time that the server should wait for new events before it should return the request anyways, in milliseconds. sync_filter (Dict[Any, Any], optional): A filter that should be used for this sync request. full_state (bool, optional): Controls whether to include the full state for all rooms the user is a member of. If this is set to true, then all state events will be returned, even if since is non-empty. The timeline will still be limited by the since parameter. This argument will be used only for the first sync request. since (str, optional): A token specifying a point in time where to continue the sync from. Defaults to the last sync token we received from the server using this API call. This argument will be used only for the first sync request, the subsequent sync requests will use the token from the last sync response. loop_sleep_time (int, optional): The sleep time, if any, between successful sync loop iterations in milliseconds. """ while True: try: tasks = [ asyncio.ensure_future(coro) for coro in ( self.sync(timeout, sync_filter, since, full_state), self.send_to_device_messages() ) ] if self.should_upload_keys: tasks.append(asyncio.ensure_future(self.keys_upload())) if self.should_query_keys: tasks.append(asyncio.ensure_future(self.keys_query())) if self.should_claim_keys: tasks.append(asyncio.ensure_future( self.keys_claim(self.get_users_for_key_claiming()) )) for response in asyncio.as_completed(tasks): await self.run_response_callbacks((await response,)) full_state = None since = None if loop_sleep_time: await asyncio.sleep(loop_sleep_time / 1000) except asyncio.CancelledError: for task in tasks: task.cancel() break
[docs] @logged_in @store_loaded async def start_key_verification( self, device, # type: OlmDevice tx_id=None # type: Optional[str] ): # type: (...) -> Union[ToDeviceResponse, ToDeviceError] """Start a interactive key verification with the given device. Returns either a `ToDeviceResponse` if the request was successful or a `ToDeviceError` if there was an error with the request. Args: device (OlmDevice): An device with which we would like to start the interactive key verification process. """ message = self.create_key_verification(device) return await self.to_device(message, tx_id)
[docs] @logged_in @store_loaded async def cancel_key_verification( self, transaction_id, # type: OlmDevice reject=False, # type: bool tx_id=None # type: Optional[str] ): # type: (...) -> Union[ToDeviceResponse, ToDeviceError] """Cancel a interactive key verification with the given device. Returns either a `ToDeviceResponse` if the request was successful or a `ToDeviceError` if there was an error with the request. Args: transaction_id (str): An transaction id of a valid key verification process. reject (bool): Is the cancelation reason because we're rejecting the short auth string and mark it as mismatching or a normal user cancelation. Raises a LocalProtocolError no verification process with the given transaction ID exists or if reject is True and the short auth string couldn't be shown yet because plublic keys weren't yet exchanged. """ if transaction_id not in self.key_verifications: raise LocalProtocolError("Key verification with the transaction " "id {} does not exist.".format( transaction_id )) sas = self.key_verifications[transaction_id] if reject: sas.reject_sas() else: sas.cancel() message = sas.get_cancellation() return await self.to_device(message, tx_id)
[docs] @logged_in @store_loaded async def accept_key_verification(self, transaction_id, tx_id=None): # type: (str, Optional[str]) -> Union[ToDeviceResponse, ToDeviceError] """Accept a key verification start event. Returns either a `ToDeviceResponse` if the request was successful or a `ToDeviceError` if there was an error with the request. Args: transaction_id (str): An transaction id of a valid key verification process. """ if transaction_id not in self.key_verifications: raise LocalProtocolError("Key verification with the transaction " "id {} does not exist.".format( transaction_id )) sas = self.key_verifications[transaction_id] message = sas.accept_verification() return await self.to_device(message, tx_id)
[docs] @logged_in @store_loaded async def confirm_short_auth_string(self, transaction_id, tx_id=None): # type: (str, Optional[str]) -> Union[ToDeviceResponse, ToDeviceError] """Confirm a short auth string and mark it as matching. Returns either a `ToDeviceResponse` if the request was successful or a `ToDeviceError` if there was an error with the request. Args: transaction_id (str): An transaction id of a valid key verification process. """ message = self.confirm_key_verification(transaction_id) return await self.to_device(message, tx_id)
[docs] @logged_in async def to_device( self, message, # type: ToDeviceMessage tx_id=None # type: Optional[str] ): # type: (...) -> Union[ToDeviceResponse, ToDeviceError] """Send a to-device message. Args: message (ToDeviceMessage): The message that should be sent out. tx_id (str, optional): The transaction ID for this message. Should be unique. Returns either a `ToDeviceResponse` if the request was successful or a `ToDeviceError` if there was an error with the request. """ uuid = tx_id or uuid4() method, path, data = Api.to_device( self.access_token, message.type, message.as_dict(), uuid ) return await self._send( ToDeviceResponse, method, path, data, response_data=(message, ) )
[docs] @logged_in @store_loaded async def keys_upload(self): """Upload the E2E encryption keys. This uploads the long lived session keys as well as the required amount of one-time keys. Raises LocalProtocolError if the client isn't logged in, if the session store isn't loaded or if no encryption keys need to be uploaded. """ if not self.should_upload_keys: raise LocalProtocolError("No key upload needed.") keys_dict = self.olm.share_keys() method, path, data = Api.keys_upload( self.access_token, keys_dict ) return await self._send(KeysUploadResponse, method, path, data)
[docs] @logged_in @store_loaded async def keys_query(self): # type: () -> Union[KeysQueryResponse] """Query the server for user keys. This queries the server for device keys of users with which we share an encrypted room. Raises LocalProtocolError if the client isn't logged in, if the session store isn't loaded or if no key query needs to be performed. """ # TODO refactor that out into the base client, and use our knowledge of # already queried users to limit the user list. user_list = [ user_id for room in self.rooms.values() if room.encrypted for user_id in room.users ] if not user_list: raise LocalProtocolError("No key query required.") # TODO pass the sync token here if it's a device update that triggered # our need for a key query. method, path, data = Api.keys_query( self.access_token, user_list ) return await self._send(KeysQueryResponse, method, path, data)
[docs] @logged_in async def joined_members(self, room_id): # type: (str) -> Union[JoinedMembersResponse, JoinedMembersError] """Send a message to a room. Args: room_id(str): The room id of the room for which we wan't to request the joined member list. Returns either a `JoinedMembersResponse` if the request was successful or a `JoinedMembersError` if there was an error with the request. """ method, path = Api.joined_members( self.access_token, room_id ) return await self._send( JoinedMembersResponse, method, path, response_data=(room_id, ) )
[docs] @logged_in async def room_send( self, room_id, message_type, content, tx_id=None, ignore_unverified_devices=False ): """Send a message to a room. Args: room_id(str): The room id of the room where the message should be sent to. message_type(str): A string identifying the type of the message. content(Dict[Any, Any]): A dictionary containing the content of the message. tx_id(str, optional): The transaction ID of this event used to uniquely identify this message. ignore_unverified_devices(bool): If the room is encrypted and contains unverified devices, the devices can be marked as ignored here. Ignored devices will still receive encryption keys for messages but they won't be marked as verified. If the room where the message should be sent is encrypted the message will be encrypted before sending. This method also makes sure that the room members are fully synced and that keys are queried before sending messages to an encrypted room. If the method can't sync the state fully to send out an encrypted message after a couple of retries it raises `SendRetryError`. Raises `LocalProtocolError` if the client isn't logged in. """ async def send(room_id, message_type, content, tx_id): if self.olm: try: room = self.rooms[room_id] except KeyError: raise LocalProtocolError( "No such room with id {} found.".format(room_id) ) if room.encrypted: message_type, content = self.encrypt(room_id, message_type, content) method, path, data = Api.room_send(self.access_token, room_id, message_type, content, tx_id) return await self._send(RoomSendResponse, method, path, data, (room_id, )) retries = 5 uuid = tx_id or uuid4() for i in range(retries): try: return await send(room_id, message_type, content, uuid) except GroupEncryptionError: sharing_event = self.sharing_session.get(room_id, None) if sharing_event: await sharing_event.wait() else: share = await self.share_group_session( room_id, ignore_unverified_devices=ignore_unverified_devices ) await self.run_response_callbacks([share]) except MembersSyncError: responses = [] responses.append(await self.joined_members(room_id)) if self.should_query_keys: responses.append(await self.keys_query()) await self.run_response_callbacks(responses) raise SendRetryError("Max retries exceeded while trying to send " "the message")
[docs] @logged_in @store_loaded async def keys_claim( self, user_set # type: Dict[str, Iterable[str]] ): # type: (...) -> Union[KeysClaimResponse, KeysClaimError] """Claim one-time keys for a set of user and device pairs. Args: user_set(Dict[str, Iterator[str]]): A dictionary maping from a user id to a iterator of device ids. If a user set for a specific room is required it can be obtained using the `get_missing_sessions()` method. Raises LocalProtocolError if the client isn't logged in, if the session store isn't loaded, no room with the given room id exists or the room isn't an encrypted room. """ method, path, data = Api.keys_claim( self.access_token, user_set ) return await self._send(KeysClaimResponse, method, path, data)
[docs] @logged_in @store_loaded async def share_group_session( self, room_id, # type: str tx_id=None, # type: Optional[str] ignore_unverified_devices=False # type: bool ): # type: (...) -> _ShareGroupSessionT """Share a group session with a room. This method sends a group session to members of a room. Args: room_id(str): The room id of the room where the message should be sent to. tx_id(str, optional): The transaction ID of this event used to uniquely identify this message. ignore_unverified_devices(bool): Mark unverified devices as ignored. Ignored devices will still receive encryption keys for messages but they won't be marked as verified. Raises LocalProtocolError if the client isn't logged in, if the session store isn't loaded, no room with the given room id exists, the room isn't an encrypted room or a key sharing request is already in flight for this room. """ assert self.olm try: room = self.rooms[room_id] except KeyError: raise LocalProtocolError("No such room with id {}".format(room_id)) if not room.encrypted: raise LocalProtocolError("Room with id {} is not encrypted".format( room_id)) if room_id in self.sharing_session: raise LocalProtocolError( "Already sharing a group session for {}".format(room_id) ) self.sharing_session[room_id] = Event() shared_with = set() missing_sessions = self.get_missing_sessions(room_id) if missing_sessions: await self.keys_claim(missing_sessions) try: while True: user_set, to_device_dict = self.olm.share_group_session( room_id, list(room.users.keys()), ignore_missing_sessions=True, ignore_unverified_devices=ignore_unverified_devices ) uuid = tx_id or uuid4() method, path, data = Api.to_device( self.access_token, "m.room.encrypted", to_device_dict, uuid ) response = await self._send( ShareGroupSessionResponse, method, path, data, (room_id, user_set) ) if isinstance(response, ShareGroupSessionResponse): shared_with.update(response.users_shared_with) except LocalProtocolError: return ShareGroupSessionResponse(room_id, shared_with) except ClientConnectionError: raise finally: event = self.sharing_session.pop(room_id) event.set()
[docs] @logged_in @store_loaded async def request_room_key( self, event, # type: MegolmEvent tx_id=None # type: Optional[str] ): # type: (...) -> Union[RoomKeyRequestResponse, RoomKeyRequestError] """Request a missing room key. This sends out a message to other devices requesting a room key from them. Args: event (str): An undecrypted MegolmEvent for which we would like to request the decryption key. Returns either a `RoomKeyRequestResponse` if the request was successful or a `RoomKeyRequestError` if there was an error with the request. Raises a LocalProtocolError if the room key was already requested. """ uuid = tx_id or uuid4() if event.session_id in self.outgoing_key_requests: raise LocalProtocolError("A key sharing request is already sent" " out for this session id.") assert self.user_id assert self.device_id message = event.as_key_request(self.user_id, self.device_id) method, path, data = Api.to_device( self.access_token, message.type, message.as_dict(), uuid ) return await self._send( RoomKeyRequestResponse, method, path, data, ( event.session_id, event.session_id, event.room_id, event.algorithm ) )
[docs] async def close(self): """Close the underlying http session.""" if self.client_session: await self.client_session.close() self.client_session = None
[docs] @store_loaded async def export_keys(self, outfile, passphrase, count=10000): """Export all the Megolm decryption keys of this device. The keys will be encrypted using the passphrase. Note that this does not save other information such as the private identity keys of the device. Args: outfile (str): The file to write the keys to. passphrase (str): The encryption passphrase. count (int): Optional. Round count for the underlying key derivation. It is not recommended to specify it unless absolutely sure of the consequences. """ assert self.store assert self.olm loop = asyncio.get_event_loop() inbound_group_store = self.store.load_inbound_group_sessions() export_keys = partial(self.olm.export_keys_static, inbound_group_store, outfile, passphrase, count) await loop.run_in_executor(None, export_keys)
[docs] @store_loaded async def import_keys(self, infile, passphrase): """Import Megolm decryption keys. The keys will be added to the current instance as well as written to database. Args: infile (str): The file containing the keys. passphrase (str): The decryption passphrase. Raises `EncryptionError` if the file is invalid or couldn't be decrypted. Raises the usual file errors if the file couldn't be opened. """ assert self.store assert self.olm loop = asyncio.get_event_loop() import_keys = partial(self.olm.import_keys_static, infile, passphrase) sessions = await loop.run_in_executor(None, import_keys) for session in sessions: # This could be improved by writing everything to db at once at # the end if self.olm.inbound_group_store.add(session): self.store.save_inbound_group_session(session)
[docs] @logged_in async def join(self, room_id): # type: (str) -> Union[JoinResponse, JoinError] """Join a room. This tells the server to join the given room. If the room is not public, the user must be invited. Returns either a `JoinResponse` if the request was successful or a `JoinError` if there was an error with the request. Args: room_id: The room id or alias of the room to join. """ method, path, data = Api.join(self.access_token, room_id) return await self._send(JoinResponse, method, path, data)
[docs] @logged_in async def room_leave(self, room_id): # type: (str) -> Union[RoomLeaveResponse, RoomLeaveError] """Leave a room or reject an invite. This tells the server to leave the given room. If the user was only invited, the invite is rejected. Returns either a `RoomLeaveResponse` if the request was successful or a `RoomLeaveError` if there was an error with the request. Args: room_id: The room id of the room to leave. """ method, path, data = Api.room_leave(self.access_token, room_id) return await self._send(RoomLeaveResponse, method, path, data)
[docs] @logged_in async def room_forget(self, room_id): # type: (str) -> Union[RoomForgetResponse, RoomForgetError] """Forget a room. This tells the server to forget the given room's history for our user. If all users on a homeserver forget the room, the room will be eligible for deletion from that homeserver. Returns either a `RoomForgetResponse` if the request was successful or a `RoomForgetError` if there was an error with the request. Args: room_id (str): The room id of the room to forget. """ method, path, data = Api.room_forget(self.access_token, room_id) return await self._send( RoomForgetResponse, method, path, data, response_data=(room_id,) )
[docs] @logged_in async def room_context( self, room_id, # type: str event_id, # type: str limit=None, # type: Optional[int] ): # type: (...) -> Union[RoomContextResponse, RoomContextError] """Fetch a number of events that happened before and after an event. This allows clients to get the context surrounding an event. Returns either a `RoomContextResponse` if the request was successful or a `RoomContextError` if there was an error with the request. Args: room_id (str): The room id of the room that contains the event and its context. event_id (str): The event_id of the event that we wish to get the context for. limit(int, optional): The maximum number of events to request. """ method, path = Api.room_context(self.access_token, room_id, event_id, limit) return await self._send(RoomContextResponse, method, path, response_data=(room_id, ))
[docs] @logged_in async def room_messages( self, room_id, # type: str start, # type: str end=None, # type: Optional[str] direction=MessageDirection.back, # type: MessageDirection limit=10 # type: int ): # type: (...) -> Union[RoomMessagesResponse, RoomMessagesError] """Fetch a list of message and state events for a room. It uses pagination query parameters to paginate history in the room. Args: room_id (str): The room id of the room for which we would like to fetch the messages. start (str): The token to start returning events from. This token can be obtained from a prev_batch token returned for each room by the sync API, or from a start or end token returned by a previous request to this endpoint. end (str, optional): The token to stop returning events at. This token can be obtained from a prev_batch token returned for each room by the sync endpoint, or from a start or end token returned by a previous request to this endpoint. direction (MessageDirection, optional): The direction to return events from. Defaults to MessageDirection.back. limit (int, optional): The maximum number of events to return. Defaults to 10. Returns either a `RoomContextResponse` if the request was successful or a `RoomContextError` if there was an error with the request. Example: >>> response = await client.room_messages(room_id, previous_batch) >>> next_response = await client.room_messages(room_id, ... response.end) """ method, path = Api.room_messages( self.access_token, room_id, start, end=end, direction=direction, limit=limit ) return await self._send( RoomMessagesResponse, method, path, response_data=(room_id, ) )
[docs] @logged_in async def room_typing( self, room_id, # type: str typing_state=True, # type: bool timeout=30000 # type: int ): # type: (...) -> Union[RoomTypingResponse, RoomTypingError] """Send a typing notice to the server. This tells the server that the user is typing for the next N milliseconds or that the user has stopped typing. Returns either a `RoomTypingResponse` if the request was successful or a `RoomTypingError` if there was an error with the request. Args: room_id (str): The room id of the room where the user is typing. typign_state (bool): A flag representing whether the user started or stopped typing. timeout (int): For how long should the new typing notice be valid for in milliseconds. """ method, path, data = Api.room_typing( self.access_token, room_id, self.user_id, typing_state, timeout ) return await self._send( RoomTypingResponse, method, path, data, response_data=(room_id, ) )
[docs] @logged_in async def upload( self, data, # type: _UploadDataT content_type, # type: str filename=None # type: Optional[str] ): # type: (...) -> Union[UploadResponse, UploadError] """Upload a file's content to the content repository. Returns either a `UploadResponse` if the request was successful or a `UploadError` if there was an error with the request. Args: data (bytes/BinaryIO/AsyncIterable[bytes]): The file's binary content. Using a binary file-like object or async iterable allows sending large files without reading them into memory. content_type (str): The content MIME type of the file, e.g. "image/png" filename (str, optional): The file's original name. Example: >>> with open("vid.webm", "rb") as f: >>> response = await client.upload(f, "video/webm", "vid.webm") >>> http_url = nio.Api.mxc_to_http(response.content_uri) """ http_method, path, _ = Api.upload(self.access_token, filename) return await self._send( UploadResponse, http_method, path, data, content_type=content_type )
[docs] @logged_in async def thumbnail( self, server_name, # type: str media_id, # type: str width, # type: int height, # type: int method=ResizingMethod.scale, # ŧype: ResizingMethod allow_remote=True, # type: bool ): # type: (...) -> Union[ThumbnailResponse, ThumbnailError] """Get the thumbnail of a file from the content repository. Note: The actual thumbnail may be larger than the size specified. Returns either a `ThumbnailResponse` if the request was successful or a `ThumbnailError` if there was an error with the request. Args: server_name (str): The server name from the mxc:// URI. media_id (str): The media ID from the mxc:// URI. width (int): The desired width of the thumbnail. height (int): The desired height of the thumbnail. method (ResizingMethod): The desired resizing method. allow_remote (bool): Indicates to the server that it should not attempt to fetch the media if it is deemed remote. This is to prevent routing loops where the server contacts itself. """ http_method, path = Api.thumbnail( self.access_token, server_name, media_id, width, height, method, allow_remote ) return await self._send(ThumbnailResponse, http_method, path)
[docs] @logged_in async def get_profile(self, user_id=None): # type: (Optional[str]) -> Union[ProfileGetResponse, ProfileGetError] """Get a user's combined profile information. This queries the display name and avatar matrix content URI of a user from the server. Additional profile information may be present. The currently logged in user is queried if no user is specified. Returns either a `ProfileGetResponse` if the request was successful or a `ProfileGetError` if there was an error with the request. Args: user_id (str): User id of the user to get the profile for. """ method, path = Api.profile_get( self.access_token, user_id or self.user_id ) return await self._send( ProfileGetResponse, method, path, )
[docs] @logged_in async def get_displayname( self, user_id=None # type: Optional[str] ): # type: (...) -> _ProfileGetDisplayNameT """Get a user's display name. This queries the display name of a user from the server. The currently logged in user is queried if no user is specified. Returns either a `ProfileGetDisplayNameResponse` if the request was successful or a `ProfileGetDisplayNameError` if there was an error with the request. Args: user_id (str): User id of the user to get the display name for. """ method, path = Api.profile_get_displayname( self.access_token, user_id or self.user_id ) return await self._send( ProfileGetDisplayNameResponse, method, path, )
[docs] @logged_in async def set_displayname(self, displayname): # type: (str) -> _ProfileSetDisplayNameT """Set user's display name. This tells the server to set display name of the currently logged in user to the supplied string. Returns either a `ProfileSetDisplayNameResponse` if the request was successful or a `ProfileSetDisplayNameError` if there was an error with the request. Args: displayname (str): Display name to set. """ method, path, data = Api.profile_set_displayname( self.access_token, self.user_id, displayname ) return await self._send( ProfileSetDisplayNameResponse, method, path, data, )
[docs] @logged_in async def get_avatar( self, user_id=None # type: Optional[str] ): # type: (...) -> Union[ProfileGetAvatarResponse, ProfileGetAvatarError] """Get a user's avatar URL. This queries the avatar matrix content URI of a user from the server. The currently logged in user is queried if no user is specified. Returns either a `ProfileGetAvatarResponse` if the request was successful or a `ProfileGetAvatarError` if there was an error with the request. Args: user_id (str): User id of the user to get the avatar for. """ method, path = Api.profile_get_avatar( self.access_token, user_id or self.user_id ) return await self._send( ProfileGetAvatarResponse, method, path, )
[docs] @logged_in async def set_avatar(self, avatar_url): # type: (str) -> Union[ProfileSetAvatarResponse, ProfileSetAvatarError] """Set the user's avatar URL. This tells the server to set the avatar of the currently logged in user to supplied matrix content URI. Returns either a `ProfileSetAvatarResponse` if the request was successful or a `ProfileSetAvatarError` if there was an error with the request. Args: avatar_url (str): matrix content URI of the avatar to set. """ method, path, data = Api.profile_set_avatar( self.access_token, self.user_id, avatar_url ) return await self._send( ProfileSetAvatarResponse, method, path, data, )