Skip to content

Commit

Permalink
refactor: move pagination logic for historical market data to the base rest client
Browse files Browse the repository at this point in the history
  • Loading branch information
gnvk committed Nov 8, 2024
1 parent 370b86b commit d3444c9
Show file tree
Hide file tree
Showing 14 changed files with 316 additions and 969 deletions.
12 changes: 9 additions & 3 deletions alpaca/common/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ def map_values(val: Any) -> Any:

return val

d = self.model_dump(exclude_none=True)
if "symbol_or_symbols" in d:
s = d["symbol_or_symbols"]
if isinstance(s, list):
s = ",".join(s)
d["symbols"] = s
del d["symbol_or_symbols"]

# pydantic almost has what we need by passing exclude_none to dict() but it returns:
# {trusted_contact: {}, contact: {}, identity: None, etc}
# so we do a simple list comprehension to filter out None and {}
return {
key: map_values(val)
for key, val in self.model_dump(exclude_none=True).items()
if val and len(str(val)) > 0
key: map_values(val) for key, val in d.items() if val and len(str(val)) > 0
}
73 changes: 72 additions & 1 deletion alpaca/common/rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections import defaultdict
from collections.abc import Callable
import time
import base64
from abc import ABC
from typing import Any, List, Optional, Type, Union, Tuple, Iterator
from typing import Any, Dict, List, Optional, Type, Union, Tuple, Iterator

from pydantic import BaseModel
from requests import Session
Expand Down Expand Up @@ -362,3 +364,72 @@ def _validate_credentials(
)

return api_key, secret_key, oauth_token

def get_marketdata(
    self,
    path: str,
    params: Optional[Dict[str, Any]] = None,
    page_limit: int = 10_000,
    no_sub_key: bool = False,
) -> Dict[str, List[Any]]:
    """Fetch a market data endpoint, following ``next_page_token`` pagination.

    Pages are requested with ``self.get`` until the response carries no
    ``next_page_token`` or, when a ``limit`` was given, until the number of
    collected items reaches it. List values found under the same key on
    successive pages are concatenated; any non-list value is overwritten by
    the most recent page.

    Args:
        path: Endpoint path passed through to ``self.get``.
        params: Query parameters. ``limit`` and ``page_token`` are read from
            it as the total item cap and the starting page token. The
            mapping is copied, so the caller's object is never modified.
        page_limit: Largest ``limit`` value sent in a single request.
        no_sub_key: Forwarded to ``_get_marketdata_entries``; when true the
            whole response is merged instead of one known data key.

    Returns:
        A plain ``dict`` with the merged entries of every fetched page.
    """
    # Copy so that the per-page "limit"/"page_token" rewrites below do not
    # leak into the caller's dict; also tolerates the ``None`` default.
    query = dict(params) if params else {}

    merged = defaultdict(list)
    limit = query.get("limit")
    total_items = 0
    page_token = query.get("page_token")

    while True:
        actual_limit = None

        # Cap the per-request limit at page_limit and stop once the
        # requested total has been collected.
        if limit:
            actual_limit = min(int(limit) - total_items, page_limit)
            if actual_limit < 1:
                break

        query["limit"] = actual_limit
        query["page_token"] = page_token

        response = self.get(path=path, data=query)

        for key, value in _get_marketdata_entries(response, no_sub_key).items():
            if isinstance(value, list):
                merged[key].extend(value)
            else:
                merged[key] = value

        # Only track the running total when a limit is being enforced.
        # Values without a length (e.g. a ``None`` next_page_token merged in
        # no_sub_key mode) are skipped instead of raising TypeError.
        if actual_limit:
            total_items = sum(
                len(items) for items in merged.values() if hasattr(items, "__len__")
            )

        page_token = response.get("next_page_token")
        if page_token is None:
            break

    return dict(merged)


def _get_marketdata_entries(response: HTTPResult, no_sub_key: bool) -> RawData:
if no_sub_key:
return response

data_keys = {
"bar",
"bars",
"corporate_actions",
"news",
"orderbook",
"orderbooks",
"quote",
"quotes",
"snapshot",
"snapshots",
"trade",
"trades",
}
selected_key = data_keys.intersection(response)
if selected_key is None or len(selected_key) < 1:
raise ValueError("The data in response does not match any known keys.")
selected_key = selected_key.pop()
if selected_key == "news":
return {"news": response[selected_key]}
return response[selected_key]
59 changes: 6 additions & 53 deletions alpaca/data/historical/corporate_actions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from collections import defaultdict
from typing import Callable, Optional, Union
from typing import Optional, Union

from alpaca.common.enums import BaseURL
from alpaca.common.rest import RESTClient
from alpaca.common.types import RawData
from alpaca.data.historical.utils import get_data_from_response
from alpaca.data.models.corporate_actions import CorporateActionsSet
from alpaca.data.requests import CorporateActionsRequest

Expand Down Expand Up @@ -41,7 +39,7 @@ def __init__(
secret_key=secret_key,
oauth_token=oauth_token,
use_basic_auth=use_basic_auth,
api_version="v1beta1",
api_version="v1",
base_url=url_override if url_override is not None else BaseURL.DATA,
sandbox=False,
raw_data=raw_data,
Expand All @@ -62,57 +60,12 @@ def get_corporate_actions(
if request_params.types:
params["types"] = ",".join(request_params.types)

response = self._data_get(
path="/corporate-actions", api_version=self._api_version, **params
response = self.get_marketdata(
path="/corporate-actions",
params=params,
page_limit=1000,
)
if self._use_raw_data:
return response

return CorporateActionsSet(response)

# TODO: Refactor data_get (common to all historical data queries!)
def _data_get(
    self,
    path: str,
    limit: Optional[int] = None,
    page_limit: int = 1000,
    api_version: str = "v1",
    **kwargs,
) -> RawData:
    """Page through a corporate actions endpoint and merge the results.

    Requests are issued with ``self.get`` until the response has no
    ``next_page_token`` or, when ``limit`` is set, until that many items
    have been collected. Extra keyword arguments are sent as query
    parameters.

    Returns:
        A plain ``dict`` mapping each key of the extracted response data to
        the concatenated list of its entries across all pages.
    """
    query = kwargs

    # Entries are grouped by corporate action type
    # (reverse_splits, forward_splits, etc.).
    grouped = defaultdict(list)

    fetched = 0
    next_token = None

    while True:
        page_size = None

        if limit:
            # Never ask for more than page_limit, nor for more than is
            # still missing from the requested total.
            remaining = int(limit) - fetched
            page_size = min(remaining, page_limit)
            if page_size < 1:
                break

        query.update(limit=page_size, page_token=next_token)

        response = self.get(path=path, data=query, api_version=api_version)

        for ca_type, cas in get_data_from_response(response).items():
            grouped[ca_type].extend(cas)

        # The running total only matters when a limit is being enforced.
        if page_size:
            fetched = sum(len(entries) for entries in grouped.values())

        next_token = response.get("next_page_token", None)
        if next_token is None:
            break

    # Hand back a regular dict rather than the defaultdict.
    return dict(grouped)
Loading

0 comments on commit d3444c9

Please sign in to comment.