# Source code for interactive_python.connection

import asyncio
import json
import websockets
import collections

from .log import logger
from .encoding import Encoding, TextEncoding
from ._util import json_encoder


class Call:
    """
    A Call is an incoming message from the Interactive service.

    It wraps the decoded packet and lets the receiver answer it, either
    successfully (:meth:`reply`) or with an error (:meth:`reply_error`),
    through the connection it arrived on.
    """

    def __init__(self, connection, payload):
        """
        :param connection: the connection the message was received on; its
            ``reply`` method is used to send answers back
        :param payload: the decoded packet, a dict read for its ``method``,
            ``params`` and ``id`` keys
        :type payload: dict
        """
        self._connection = connection
        self._payload = payload
        # Replies have to echo the id of the packet they answer. ``get`` is
        # used so a packet without an id can still be wrapped and inspected.
        self._id = payload.get('id')

    @property
    def name(self):
        """
        :return: The name of the method being called.
        :rtype: str
        """
        return self._payload['method']

    @property
    def data(self):
        """
        :return: The payload of the method being called.
        :rtype: dict
        """
        return self._payload['params']

    def reply(self, result):
        """
        Submits a successful reply for the call.

        :param result: The result to send back to the caller
        """
        self._connection.reply(self._id, result=result)

    def reply_error(self, error):
        """
        Submits an errorful reply for the call.

        :param error: The error to send back to the caller
        """
        self._connection.reply(self._id, error=error)
class Connection:
    """
    The Connection is used to connect to the Interactive server. It connects
    to a provided socket address and provides an interface for making RPC
    calls. Example usage::

        connection = Connection(
            address=get_interactive_address(),
            authorization="Bearer {}".format(my_oauth_token),
            project_version_id=1234)
    """

    def __init__(self, address=None, authorization=None,
                 project_version_id=None, project_sharecode=None,
                 extra_headers=None, loop=None, socket=None,
                 protocol_version="2.0"):
        """
        :param address: websocket address to connect to (unused when
            ``socket`` is given)
        :param authorization: value for the ``Authorization`` header
        :param project_version_id: value for ``X-Interactive-Version``
        :param project_sharecode: value for ``X-Interactive-Sharecode``
        :param extra_headers: additional headers; copied, never mutated
        :param loop: event loop to use; defaults to the current event loop
            at construction time
        :param socket: an existing socket (or awaitable yielding one) to use
            instead of opening a new websocket
        :param protocol_version: value for ``X-Protocol-Version``
        """
        # Work on a private copy: mutating the caller's dict (or a shared
        # default) would leak one connection's credentials into another.
        headers = dict(extra_headers) if extra_headers is not None else {}
        if authorization is not None:
            headers['Authorization'] = authorization
        if project_version_id is not None:
            headers['X-Interactive-Version'] = project_version_id
        if project_sharecode is not None:
            headers['X-Interactive-Sharecode'] = project_sharecode
        headers['X-Protocol-Version'] = protocol_version

        # Resolve the loop per instance rather than once at import time.
        if loop is None:
            loop = asyncio.get_event_loop()

        self._socket_or_connector = socket or websockets.client.connect(
            address, loop=loop, extra_headers=headers)
        self._socket = None

        self._loop = loop
        self._encoding = TextEncoding()
        self._awaiting_replies = {}  # call id -> Future awaiting the reply
        self._call_counter = 0
        self._last_sequence_number = 0

        self._recv_queue = collections.deque()  # Calls not yet consumed
        self._recv_await = None  # Future that has_packet() is parked on
        self._recv_task = None  # background task running _read()

    async def connect(self):
        """
        Connects to the Interactive server, waiting until the connection is
        fully established (a ``hello`` method is received) before returning.
        Errors raised while reading, such as the socket being closed because
        authentication failed, propagate to the caller.
        """
        if not hasattr(self._socket_or_connector, '__await__'):
            self._socket = self._socket_or_connector
        else:
            self._socket = await self._socket_or_connector

        # await a hello event
        while True:
            packet = await self._read_single()
            if packet['type'] == 'method' and packet['method'] == 'hello':
                break

            # Anything arriving before the hello goes through the normal
            # receive path so the queue only ever holds Call objects.
            self._handle_recv(packet)

        self._recv_task = asyncio.ensure_future(self._read(), loop=self._loop)

    def _fallback_to_plain_text(self):
        """
        Switches the local encoding to plain text and asks the server, in the
        background, to do the same. Does nothing if already on plain text.
        """
        if isinstance(self._encoding, TextEncoding):
            return  # we're already falling back

        self._encoding = TextEncoding()
        asyncio.ensure_future(
            self.set_compression(TextEncoding()), loop=self._loop)

    def _decode(self, data):
        """
        Converts the packet data to a string, decompressing it if necessary.
        Always returns a string.

        :raises Exception: re-raises whatever the encoding raised, after
            falling back to plain text for subsequent packets
        """
        if isinstance(data, str):
            return data

        try:
            return self._encoding.decode(data)
        except Exception as e:
            self._fallback_to_plain_text()
            # assumes `logger` is a stdlib logging.Logger -- TODO confirm
            logger.info("error decoding Interactive message, falling back to "
                        "plain text", exc_info=e)
            # There is no string to hand back for this packet; surface the
            # real error instead of returning None to json.loads.
            raise

    def _encode(self, data):
        """
        Converts the packet data to a string or byte array, compressing it
        if necessary. If the encoding fails, the data is returned unencoded
        and the connection falls back to plain text.
        """
        try:
            return self._encoding.encode(data)
        except Exception as e:
            self._fallback_to_plain_text()
            # assumes `logger` is a stdlib logging.Logger -- TODO confirm
            logger.warning("error encoding Interactive message, falling back "
                           "to plain text", exc_info=e)
            return data

    def _handle_recv(self, data):
        """
        Handles a single received packet from the Interactive service.

        Replies resolve the matching pending :meth:`call`; every other packet
        is queued as a :class:`Call` and wakes up :meth:`has_packet`.
        """
        if 'seq' in data:
            self._last_sequence_number = data['seq']

        if data['type'] == 'reply':
            pending = self._awaiting_replies.pop(data.get('id'), None)
            # The future may already be cancelled (e.g. by a timeout), and
            # a reply may omit 'result'; neither should kill the read loop.
            if pending is not None and not pending.done():
                pending.set_result(data.get('result'))
            return

        self._recv_queue.append(Call(self, data))
        if self._recv_await is not None:
            self._recv_await.set_result(True)
            self._recv_await = None

    def _send(self, payload):
        """
        Encodes and sends a dict payload. The send is scheduled on the loop
        and not awaited.
        """
        future = self._socket.send(self._encode(json_encoder.encode(payload)))
        asyncio.ensure_future(future, loop=self._loop)

    async def _read_single(self):
        """
        Reads a single event off the websocket and returns the decoded JSON.

        If the read is cancelled or the socket is closed, anyone waiting in
        :meth:`has_packet` is released with ``False`` before the error is
        re-raised.
        """
        try:
            raw_data = await self._socket.recv()
        except (asyncio.CancelledError, websockets.ConnectionClosed):
            if self._recv_await is None:
                self._recv_await = asyncio.Future(loop=self._loop)
            self._recv_await.set_result(False)
            raise

        return json.loads(self._decode(raw_data))

    async def _read(self):
        """
        Endless read loop that runs until the socket is closed.
        """
        while True:
            try:
                data = await self._read_single()
            except (asyncio.CancelledError, websockets.ConnectionClosed):
                break  # will already be handled
            except Exception as e:
                # assumes `logger` is a stdlib logging.Logger -- TODO confirm
                logger.error("error in interactive read loop", exc_info=e)
                break

            # A frame may carry either one packet or a list of packets.
            if isinstance(data, list):
                for item in data:
                    self._handle_recv(item)
            else:
                self._handle_recv(data)

    async def set_compression(self, scheme):
        """Updates the compression used on the websocket. This should be
        called with an instance of the Encoding class, for example::

            connection.set_compression(GzipEncoding())

        You can, optionally, await on the resolution of this method, though
        doing so is not at all required. Returns True if the server agreed
        on and executed the switch.

        :param scheme: The compression scheme to use
        :type scheme: Encoding
        :return: Whether the upgrade was successful
        :rtype: bool
        """
        result = await self.call("setCompression", {'scheme': [scheme.name()]})
        if result['scheme'] == scheme.name():
            self._encoding = scheme
            return True

        return False

    def reply(self, call_id, result=None, error=None):
        """
        Sends a reply for a packet id. Either the result or error should
        be fulfilled.

        :param call_id: The ID of the call being replied to.
        :type call_id: int
        :param result: The successful result of the call.
        :param error: The errorful result of the call.
        """
        packet = {'type': 'reply', 'id': call_id}

        if result is not None:
            packet['result'] = result
        if error is not None:
            packet['error'] = error

        self._send(packet)

    async def call(self, method, params, discard=False, timeout=10):
        """
        Sends a method call to the interactive socket. If discard
        is false, we'll wait for a response before returning, up to the
        timeout duration in seconds, at which point it raises an
        asyncio.TimeoutError. If the timeout is None, we'll wait forever.

        :param method: Method name to call
        :type method: str
        :param params: Parameters to insert into the method, generally a dict.
        :param discard: ``True`` to not request any reply to the method.
        :type discard: bool
        :param timeout: Call timeout duration, in seconds.
        :type timeout: int
        :return: The call response, or None if it was discarded.
        :raises: asyncio.TimeoutError
        """
        call_id = self._call_counter
        packet = {
            'type': 'method',
            'method': method,
            'params': params,
            'id': call_id,
            'seq': self._last_sequence_number,
        }

        if discard:
            packet['discard'] = True

        self._call_counter += 1
        self._send(packet)

        if discard:
            return None

        future = asyncio.Future(loop=self._loop)
        self._awaiting_replies[call_id] = future

        try:
            return await asyncio.wait_for(future, timeout)
        finally:
            # Runs on success, timeout and cancellation alike, so a pending
            # entry is never left behind (on success it is already gone).
            self._awaiting_replies.pop(call_id, None)

    def get_packet(self):
        """
        Synchronously reads a packet from the connection. Returns None if
        there are no more packets in the queue. Example::

            while await connection.has_packet():
                dispatch_call(connection.get_packet())

        :rtype: Call
        """
        if self._recv_queue:
            return self._recv_queue.popleft()

        return None

    async def has_packet(self):
        """
        Blocks until a packet is read. Returns true if a packet is then
        available, or false if the connection is subsequently closed.
        Example::

            while await connection.has_packet():
                dispatch_call(connection.get_packet())

        :rtype: bool
        """
        if self._recv_queue:
            return True

        if self._recv_await is None:
            self._recv_await = asyncio.Future(loop=self._loop)

        return await self._recv_await

    async def close(self):
        """Closes the socket connection gracefully. Safe to call even if
        :meth:`connect` never completed."""
        if self._recv_task is not None:
            self._recv_task.cancel()
        if self._socket is not None:
            await self._socket.close()