Source code for copra.websocket.client

# -*- coding: utf-8 -*-
"""Asynchronous WebSocket client for the Coinbase Pro platform.

"""

import asyncio
import base64
import hashlib
import hmac
import json
import logging
import time
from urllib.parse import urlparse

from autobahn.asyncio.websocket import WebSocketClientFactory
from autobahn.asyncio.websocket import WebSocketClientProtocol

logger = logging.getLogger(__name__)

FEED_URL = 'wss://ws-feed.pro.coinbase.com:443'
SANDBOX_FEED_URL = 'wss://ws-feed-public.sandbox.pro.coinbase.com:443'


class ClientProtocol(WebSocketClientProtocol):
    """Protocol object that relays WebSocket events to its factory.

    Subclasses autobahn.asyncio.websocket.WebSocketClientProtocol. Every
    callback here simply forwards to the matching ``on_*`` method of
    ``self.factory``, so application code normally has no reason to touch
    or subclass this class.
    """

    def __call__(self):
        # Calling an instance hands back that same instance, so an
        # already-built protocol can be used wherever a callable that
        # produces a protocol is expected.
        return self

    def onOpen(self):
        """Handle completion of the WebSocket opening handshake.

        Messages may be sent and received from this point on; the event is
        passed to the factory's ``on_open``.
        """
        self.factory.on_open()

    def onClose(self, wasClean, code, reason):
        """Handle the WebSocket connection having been closed.

        Fired both after a finished closing handshake and after an unclean
        drop. All three arguments are forwarded unchanged to the factory's
        ``on_close``.

        Args:
          wasClean (bool): True iff the WebSocket connection closed cleanly.
          code (int or None): Close status code as sent by the WebSocket peer.
          reason (str or None): Close reason as sent by the WebSocket peer.
        """
        self.factory.on_close(wasClean, code, reason)

    def onMessage(self, payload, isBinary):
        """Handle one complete incoming WebSocket message.

        The payload is decoded as UTF-8 and parsed as JSON. A message whose
        ``'type'`` is ``'error'`` goes to the factory's ``on_error`` (with
        its ``'message'`` and, if present, ``'reason'``); anything else goes
        to the factory's ``on_message`` as the parsed object.

        Args:
            payload (bytes): The WebSocket message received.
            isBinary (bool): Flag indicating whether payload is binary or
                UTF-8 encoded text. Not consulted here; the payload is
                always decoded as UTF-8 text.
        """
        text = payload.decode('utf8')
        parsed = json.loads(text)

        if msg_is_error := (parsed['type'] == 'error'):
            pass

        if not msg_is_error:
            self.factory.on_message(parsed)
            return

        description = parsed['message']
        detail = parsed.get('reason', '')
        self.factory.on_error(description, detail)

class Client(WebSocketClientFactory):
    """Asynchronous WebSocket client for Coinbase Pro.

    The client keeps a dict of its current subscriptions (``self.channels``,
    keyed by channel name) and two asyncio events, ``connected`` and
    ``disconnected``, that mirror the state of the connection.
    """

    def __init__(self, loop, channels, feed_url=FEED_URL, auth=False,
                 key='', secret='', passphrase='', auto_connect=True,
                 auto_reconnect=True, name='WebSocket Client'):
        """
        :param loop: The asyncio loop that the client runs in.
        :type loop: asyncio loop

        :param channels: The channels to initially subscribe to.
        :type channels: Channel or list of Channels

        :param str feed_url: The url of the WebSocket server. The default
            is the module-level FEED_URL
            (wss://ws-feed.pro.coinbase.com:443).

        :param bool auth: Whether or not the (entire) WebSocket session is
            authenticated. If True, you will need an API key from the
            Coinbase Pro website. The default is False.

        :param str key: The API key to use for authentication. Required if
            auth is True. The default is ''.

        :param str secret: The secret string for the API key used for
            authentication. Required if auth is True. The default is ''.

        :param str passphrase: The passphrase for the API key used for
            authentication. Required if auth is True. The default is ''.

        :param bool auto_connect: If True, the Client will automatically
            add itself to its event loop (i.e., open a connection if the
            loop is running or as soon as it starts). If False,
            add_as_task_to_loop() needs to be explicitly called to add the
            client to the loop. The default is True.

        :param bool auto_reconnect: If True, the Client will attempt to
            automatically reconnect and resubscribe if the connection is
            closed any way but by the Client explicitly itself. The default
            is True.

        :param str name: A name to identify this client in logging, etc.

        :raises ValueError: If auth is True and key, secret, and passphrase
            are not provided.
        """
        # Validate before touching any state so a bad call fails fast.
        if auth and not (key and secret and passphrase):
            raise ValueError('auth requires key, secret, and passphrase')

        self.loop = loop

        # connected / disconnected are kept as exact opposites; the client
        # starts out disconnected.
        self.connected = asyncio.Event()
        self.disconnected = asyncio.Event()
        self.disconnected.set()
        self.closing = False

        if not isinstance(channels, list):
            channels = [channels]
        self._initial_channels = channels
        self.feed_url = feed_url

        # Not connected yet, so this only records the subscriptions; the
        # actual subscribe message is sent from on_open().
        self.channels = {}
        self.subscribe(channels)

        self.auth = auth
        self.key = key
        self.secret = secret
        self.passphrase = passphrase

        self.auto_connect = auto_connect
        self.auto_reconnect = auto_reconnect
        self.name = name

        super().__init__(self.feed_url)

        if self.auto_connect:
            self.add_as_task_to_loop()

    def _get_subscribe_message(self, channels, unsubscribe=False,
                               timestamp=None):
        """Create and return the subscription message for the provided
        channels.

        When ``self.auth`` is True the message additionally carries the
        ``signature``, ``key``, ``passphrase`` and ``timestamp`` fields. The
        signature is the base64-encoded HMAC-SHA256 of
        ``timestamp + 'GET' + '/users/self/verify'``, keyed with the
        base64-decoded API secret.

        :param channels: Channels to be subscribed to (any iterable).
        :type channels: list of Channel

        :param bool unsubscribe: If True, returns an unsubscribe message
            instead of a subscribe message. The default is False.

        :param str timestamp: Timestamp string used for the signature.
            Only consulted when ``self.auth`` is True. If falsy, the current
            ``time.time()`` converted to a string is used. The default is
            None.

        :returns: JSON-formatted, UTF-8 encoded bytes object representing
            the subscription message for the provided channels.
        """
        msg_type = 'unsubscribe' if unsubscribe else 'subscribe'
        msg = {'type': msg_type,
               'channels': [channel._as_dict() for channel in channels]}

        if self.auth:
            if not timestamp:
                timestamp = str(time.time())
            message = (timestamp + 'GET' + '/users/self/verify')
            message = message.encode('ascii')
            hmac_key = base64.b64decode(self.secret)
            signature = hmac.new(hmac_key, message, hashlib.sha256)
            signature_b64 = base64.b64encode(signature.digest())
            signature_b64 = signature_b64.decode('utf-8').rstrip('\n')

            msg['signature'] = signature_b64
            msg['key'] = self.key
            msg['passphrase'] = self.passphrase
            msg['timestamp'] = timestamp

        return json.dumps(msg).encode('utf8')

    def subscribe(self, channels):
        """Subscribe to the given channels.

        Each channel is merged into ``self.channels``. Only the part of a
        channel that is not already subscribed is sent to the server, and
        nothing is sent at all if the client is not connected or if
        everything requested is already subscribed.

        :param channels: The channels to subscribe to.
        :type channels: Channel or list of Channels
        """
        if not isinstance(channels, list):
            channels = [channels]

        sub_channels = []

        for channel in channels:
            if channel.name in self.channels:
                # Only the difference from the existing subscription is new.
                sub_channel = channel - self.channels[channel.name]
                if sub_channel:
                    self.channels[channel.name] += channel
                    sub_channels.append(sub_channel)
            else:
                self.channels[channel.name] = channel
                sub_channels.append(channel)

        # Do not send a subscribe message with an empty channel list.
        if sub_channels and self.connected.is_set():
            msg = self._get_subscribe_message(sub_channels)
            self.protocol.sendMessage(msg)

    def unsubscribe(self, channels):
        """Unsubscribe from the given channels.

        Each channel is subtracted from the matching entry of
        ``self.channels``; an entry that becomes empty is removed. If the
        client is connected, an unsubscribe message for the given channels
        is sent to the server.

        :param channels: The channels to unsubscribe from.
        :type channels: Channel or list of Channels
        """
        if not isinstance(channels, list):
            channels = [channels]

        for channel in channels:
            if channel.name in self.channels:
                self.channels[channel.name] -= channel
                if not self.channels[channel.name]:
                    del self.channels[channel.name]

        # Do not send an unsubscribe message with an empty channel list.
        if channels and self.connected.is_set():
            msg = self._get_subscribe_message(channels, unsubscribe=True)
            self.protocol.sendMessage(msg)

    def add_as_task_to_loop(self):
        """Add the client to the asyncio loop.

        Creates a coroutine for making a connection to the WebSocket server
        and adds it as a task to the asyncio loop.
        """
        self.protocol = ClientProtocol()

        # NOTE(review): self.url is not assigned in this class; presumably
        # it is set by WebSocketClientFactory.__init__ from the feed url.
        url = urlparse(self.url)
        secure = url.scheme == 'wss'

        # urlparse reports port as None when the url has no explicit port;
        # fall back to the scheme's standard port instead of passing None.
        port = url.port or (443 if secure else 80)

        self.coro = self.loop.create_connection(self, url.hostname, port,
                                                ssl=secure)
        self.loop.create_task(self.coro)

    def on_open(self):
        """Callback fired on initial WebSocket opening handshake completion.

        The WebSocket is open. This method marks the client as connected
        and sends the subscription message for every channel currently in
        ``self.channels`` to the server.
        """
        self.connected.set()
        self.disconnected.clear()
        self.closing = False
        logger.info('{} connected to {}'.format(self.name, self.url))
        msg = self._get_subscribe_message(self.channels.values())
        self.protocol.sendMessage(msg)

    def on_close(self, was_clean, code, reason):
        """Callback fired when the WebSocket connection has been closed.

        (WebSocket closing handshake has been finished or the connection
        was closed uncleanly).

        Marks the client as disconnected, logs the closure and, unless the
        closure was requested through close() or auto_reconnect is False,
        schedules a new connection attempt.

        :param bool was_clean: True iff the WebSocket connection closed
            cleanly.

        :param code: Close status code as sent by the WebSocket peer.
        :type code: int or None

        :param reason: Close reason as sent by the WebSocket peer.
        :type reason: str or None
        """
        self.connected.clear()
        self.disconnected.set()

        msg = '{} connection to {} {}closed. {}'
        # A closure not initiated through close() is reported as unexpected.
        qualifier = '' if self.closing else 'unexpectedly '
        logger.info(msg.format(self.name, self.url, qualifier, reason))

        if not self.closing and self.auto_reconnect:
            msg = '{} attempting to reconnect to {}.'
            logger.info(msg.format(self.name, self.url))
            self.add_as_task_to_loop()

    def on_error(self, message, reason=''):
        """Callback fired when an error message is received.

        The default implementation logs the error.

        :param str message: A general description of the error.

        :param str reason: A more detailed description of the error.
        """
        logger.error('{}. {}'.format(message, reason))

    def on_message(self, message):
        """Callback fired when a complete WebSocket message was received.

        The default implementation prints the message. You will likely want
        to override this method.

        :param dict message: Dictionary representing the message.
        """
        print(message)

    async def close(self):
        """Close the WebSocket connection.

        Flags the closure as intentional (so on_close() does not try to
        reconnect), sends the close frame if a connection is open, and
        waits until the client is disconnected. If the client is not
        connected, this returns without sending anything.
        """
        self.closing = True
        # Without an open connection there is nothing to send a close frame
        # on, and self.protocol may not even have been created yet.
        if self.connected.is_set():
            self.protocol.sendClose()
        await self.disconnected.wait()
if __name__ == '__main__':
    # A sanity check: connect, subscribe to one heartbeat channel, add a
    # second one after 20 seconds, then drop the first 20 seconds later.

    # Channel is not imported at module level in this file, so it is
    # brought in here, where it is actually needed.
    from copra.websocket.channel import Channel

    logging.getLogger().setLevel(logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler())

    loop = asyncio.get_event_loop()

    ws = Client(loop, [Channel('heartbeat', 'BTC-USD')])

    async def add_a_channel():
        """Subscribe to LTC-USD heartbeats after a delay, then schedule
        the removal of the BTC-USD heartbeat subscription."""
        await asyncio.sleep(20)
        ws.subscribe(Channel('heartbeat', 'LTC-USD'))
        loop.create_task(remove_a_channel())

    async def remove_a_channel():
        """Unsubscribe from BTC-USD heartbeats after a delay."""
        await asyncio.sleep(20)
        ws.unsubscribe(Channel('heartbeat', 'BTC-USD'))

    loop.create_task(add_a_channel())

    try:
        loop.run_forever()
    except KeyboardInterrupt:
        # Ctrl-C: close the WebSocket cleanly before shutting down.
        loop.run_until_complete(ws.close())
    finally:
        # Release the loop's resources on every exit path.
        loop.close()