Skip to content

Commit

Permalink
Move updated bars to their own subscription (#574)
Browse files Browse the repository at this point in the history
* Move updated bars to their own subscription

* update unsubscribe

* pep8 format fix
  • Loading branch information
gnvk authored Feb 24, 2022
1 parent 294b105 commit 43f0766
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 20 deletions.
87 changes: 67 additions & 20 deletions alpaca_trade_api/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ def __init__(self,
self._raw_data = raw_data
self._stop_stream_queue = queue.Queue()
self._handlers = {
'trades': {},
'quotes': {},
'bars': {},
'dailyBars': {},
'trades': {},
'quotes': {},
'bars': {},
'updatedBars': {},
'dailyBars': {},
}
self._name = 'data'
self._should_run = True
Expand Down Expand Up @@ -137,7 +138,7 @@ def _cast(self, msg_type, msg):
quote_mapping_v2[k]: v
for k, v in msg.items() if k in quote_mapping_v2
})
elif msg_type in ('b', 'd'):
elif msg_type in ('b', 'u', 'd'):
result = Bar({
bar_mapping_v2[k]: v
for k, v in msg.items() if k in bar_mapping_v2
Expand All @@ -164,6 +165,11 @@ async def _dispatch(self, msg):
symbol, self._handlers['bars'].get('*', None))
if handler:
await handler(self._cast(msg_type, msg))
elif msg_type == 'u':
handler = self._handlers['updatedBars'].get(
symbol, self._handlers['updatedBars'].get('*', None))
if handler:
await handler(self._cast(msg_type, msg))
elif msg_type == 'd':
handler = self._handlers['dailyBars'].get(
symbol, self._handlers['dailyBars'].get('*', None))
Expand Down Expand Up @@ -201,15 +207,17 @@ async def _unsubscribe(self,
trades=(),
quotes=(),
bars=(),
updated_bars=(),
daily_bars=()):
if trades or quotes or bars or daily_bars:
if trades or quotes or bars or updated_bars or daily_bars:
await self._ws.send(
msgpack.packb({
'action': 'unsubscribe',
'trades': trades,
'quotes': quotes,
'bars': bars,
'dailyBars': daily_bars,
'action': 'unsubscribe',
'trades': trades,
'quotes': quotes,
'bars': bars,
'updatedBars': updated_bars,
'dailyBars': daily_bars,
}))

async def _run_forever(self):
Expand Down Expand Up @@ -261,6 +269,9 @@ def subscribe_quotes(self, handler, *symbols):
def subscribe_bars(self, handler, *symbols):
self._subscribe(handler, symbols, self._handlers['bars'])

def subscribe_updated_bars(self, handler, *symbols):
self._subscribe(handler, symbols, self._handlers['updatedBars'])

def subscribe_daily_bars(self, handler, *symbols):
self._subscribe(handler, symbols, self._handlers['dailyBars'])

Expand Down Expand Up @@ -288,6 +299,13 @@ def unsubscribe_bars(self, *symbols):
for symbol in symbols:
del self._handlers['bars'][symbol]

def unsubscribe_updated_bars(self, *symbols):
if self._running:
asyncio.get_event_loop().run_until_complete(
self._unsubscribe(updated_bars=symbols))
for symbol in symbols:
del self._handlers['updatedBars'][symbol]

def unsubscribe_daily_bars(self, *symbols):
if self._running:
asyncio.run_coroutine_threadsafe(
Expand Down Expand Up @@ -377,19 +395,22 @@ async def _unsubscribe(self,
trades=(),
quotes=(),
bars=(),
updated_bars=(),
daily_bars=(),
statuses=(),
lulds=()):
if trades or quotes or bars or daily_bars or statuses or lulds:
if (trades or quotes or bars or updated_bars or daily_bars or
statuses or lulds):
await self._ws.send(
msgpack.packb({
'action': 'unsubscribe',
'trades': trades,
'quotes': quotes,
'bars': bars,
'dailyBars': daily_bars,
'statuses': statuses,
'lulds': lulds,
'action': 'unsubscribe',
'trades': trades,
'quotes': quotes,
'bars': bars,
'updatedBars': updated_bars,
'dailyBars': daily_bars,
'statuses': statuses,
'lulds': lulds,
}))

def subscribe_statuses(self, handler, *symbols):
Expand Down Expand Up @@ -702,6 +723,9 @@ def subscribe_quotes(self, handler, *symbols):
def subscribe_bars(self, handler, *symbols):
self._data_ws.subscribe_bars(handler, *symbols)

def subscribe_updated_bars(self, handler, *symbols):
self._data_ws.subscribe_updated_bars(handler, *symbols)

def subscribe_daily_bars(self, handler, *symbols):
self._data_ws.subscribe_daily_bars(handler, *symbols)

Expand All @@ -720,6 +744,9 @@ def subscribe_crypto_quotes(self, handler, *symbols):
def subscribe_crypto_bars(self, handler, *symbols):
self._crypto_ws.subscribe_bars(handler, *symbols)

def subscribe_crypto_updated_bars(self, handler, *symbols):
self._crypto_ws.subscribe_updated_bars(handler, *symbols)

def subscribe_crypto_daily_bars(self, handler, *symbols):
self._crypto_ws.subscribe_daily_bars(handler, *symbols)

Expand Down Expand Up @@ -751,6 +778,13 @@ def decorator(func):

return decorator

def on_updated_bar(self, *symbols):
def decorator(func):
self.subscribe_updated_bars(func, *symbols)
return func

return decorator

def on_daily_bar(self, *symbols):
def decorator(func):
self.subscribe_daily_bars(func, *symbols)
Expand Down Expand Up @@ -807,6 +841,13 @@ def decorator(func):

return decorator

def on_crypto_updated_bar(self, *symbols):
def decorator(func):
self.subscribe_crypto_updated_bars(func, *symbols)
return func

return decorator

def on_crypto_daily_bar(self, *symbols):
def decorator(func):
self.subscribe_crypto_daily_bars(func, *symbols)
Expand All @@ -832,6 +873,9 @@ def unsubscribe_quotes(self, *symbols):
def unsubscribe_bars(self, *symbols):
self._data_ws.unsubscribe_bars(*symbols)

def unsubscribe_updated_bars(self, *symbols):
self._data_ws.unsubscribe_updated_bars(*symbols)

def unsubscribe_daily_bars(self, *symbols):
self._data_ws.unsubscribe_daily_bars(*symbols)

Expand All @@ -850,6 +894,9 @@ def unsubscribe_crypto_quotes(self, *symbols):
def unsubscribe_crypto_bars(self, *symbols):
self._crypto_ws.unsubscribe_bars(*symbols)

def unsubscribe_crypto_updated_bars(self, *symbols):
self._crypto_ws.unsubscribe_updated_bars(*symbols)

def unsubscribe_crypto_daily_bars(self, *symbols):
self._crypto_ws.unsubscribe_daily_bars(*symbols)

Expand Down Expand Up @@ -901,7 +948,7 @@ def is_open(self):
:return:
"""
open_ws = (self._trading_ws._ws or self._data_ws._ws
or self._crypto_ws._ws or self._news_ws) # noqa
or self._crypto_ws._ws or self._news_ws) # noqa
if open_ws:
return True
return False
4 changes: 4 additions & 0 deletions examples/websockets/v2_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def main():
async def _(bar):
print('bar', bar)

@stream.on_updated_bar('MSFT')
async def _(bar):
print('updated bar', bar)

@stream.on_status("*")
async def _(status):
print('status', status)
Expand Down

0 comments on commit 43f0766

Please sign in to comment.