Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add get_corporate_actions #502

Merged
merged 6 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions alpaca/data/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,39 @@ class NewsImageSize(str, Enum):
THUMB = "thumb"
SMALL = "small"
LARGE = "large"


class CorporateActionsType(str, Enum):
"""
The type of corporate action.
ref. https://docs.alpaca.markets/reference/corporateactions-1

Attributes:
REVERSE_SPLIT (str): Reverse split
FORWARD_SPLIT (str): Forward split
UNIT_SPLIT (str): Unit split
CASH_DIVIDEND (str): Cash dividend
STOCK_DIVIDEND (str): Stock dividend
SPIN_OFF (str): Spin off
CASH_MERGER (str): Cash merger
STOCK_MERGER (str): Stock merger
STOCK_AND_CASH_MERGER (str): Stock and cash merger
REDEMPTION (str): Redemption
NAME_CHANGE (str): Name change
WORTHLESS_REMOVAL (str): Worthless removal
RIGHTS_DISTRIBUTION (str): Rights distribution
"""

REVERSE_SPLIT = "reverse_split"
FORWARD_SPLIT = "forward_split"
UNIT_SPLIT = "unit_split"
CASH_DIVIDEND = "cash_dividend"
STOCK_DIVIDEND = "stock_dividend"
SPIN_OFF = "spin_off"
CASH_MERGER = "cash_merger"
STOCK_MERGER = "stock_merger"
STOCK_AND_CASH_MERGER = "stock_and_cash_merger"
REDEMPTION = "redemption"
NAME_CHANGE = "name_change"
WORTHLESS_REMOVAL = "worthless_removal"
RIGHTS_DISTRIBUTION = "rights_distribution"
118 changes: 118 additions & 0 deletions alpaca/data/historical/corporate_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from collections import defaultdict
from typing import Callable, Optional, Union

from alpaca.common.enums import BaseURL
from alpaca.common.rest import RESTClient
from alpaca.common.types import RawData
from alpaca.data.historical.utils import get_data_from_response
from alpaca.data.models.corporate_actions import CorporateActionsSet
from alpaca.data.requests import CorporateActionsRequest


class CorporateActionsClient(RESTClient):
"""
The REST client for interacting with Alpaca Corporate Actions API endpoints.
"""

def __init__(
self,
api_key: Optional[str] = None,
secret_key: Optional[str] = None,
oauth_token: Optional[str] = None,
use_basic_auth: bool = False,
raw_data: bool = False,
url_override: Optional[str] = None,
) -> None:
"""
Instantiates a Corporate Actions Client.
Args:
api_key (Optional[str], optional): Alpaca API key. Defaults to None.
secret_key (Optional[str], optional): Alpaca API secret key. Defaults to None.
oauth_token (Optional[str]): The oauth token if authenticating via OAuth. Defaults to None.
use_basic_auth (bool, optional): If true, API requests will use basic authorization headers.
raw_data (bool, optional): If true, API responses will not be wrapped and raw responses will be returned from
methods. Defaults to False. This has not been implemented yet.
url_override (Optional[str], optional): If specified allows you to override the base url the client points
to for proxy/testing.
"""
super().__init__(
api_key=api_key,
secret_key=secret_key,
oauth_token=oauth_token,
use_basic_auth=use_basic_auth,
api_version="v1beta1",
base_url=url_override if url_override is not None else BaseURL.DATA,
sandbox=False,
raw_data=raw_data,
)

def get_corporate_actions(
self, request_params: CorporateActionsRequest
) -> Union[RawData, CorporateActionsSet]:
"""Returns corporate actions data
Args:
request_params (CorporateActionsRequest): The request params to filter the corporate actions data
"""
params = request_params.to_request_fields()

if request_params.symbols:
params["symbols"] = ",".join(request_params.symbols)
if request_params.types:
params["types"] = ",".join(request_params.types)

response = self._data_get(
path="/corporate-actions", api_version=self._api_version, **params
)
if self._use_raw_data:
return response

return CorporateActionsSet(response)

# TODO: Refactor data_get (common to all historical data queries!)
def _data_get(
self,
path: str,
limit: Optional[int] = None,
page_limit: int = 1000,
api_version: str = "v1",
**kwargs,
) -> RawData:
params = kwargs

# data is grouped by corporate action type (reverse_splits, forward_splits, etc.)
d = defaultdict(list)

total_items = 0
page_token = None

while True:
actual_limit = None

# adjusts the limit parameter value if it is over the page_limit
if limit:
# actual_limit is the adjusted total number of items to query per request
actual_limit = min(int(limit) - total_items, page_limit)
if actual_limit < 1:
break

params["limit"] = actual_limit
params["page_token"] = page_token

response = self.get(path=path, data=params, api_version=api_version)

for ca_type, cas in get_data_from_response(response).items():
d[ca_type].extend(cas)

# if we've sent a request with a limit, increment count
if actual_limit:
total_items = sum([len(items) for items in d.values()])

page_token = response.get("next_page_token", None)

if page_token is None:
break

# users receive Type dict
return dict(d)
1 change: 1 addition & 0 deletions alpaca/data/historical/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_data_from_response(response: HTTPResult) -> RawData:
"snapshots",
"orderbook",
"orderbooks",
"corporate_actions",
}

selected_key = data_keys.intersection(response)
Expand Down
7 changes: 6 additions & 1 deletion alpaca/data/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
from typing import Any, Dict, List

import numpy as np
import pandas as pd
from pandas import DataFrame

Expand All @@ -20,12 +21,16 @@ def df(self) -> DataFrame:
data_list = list(itertools.chain.from_iterable(self.dict().values()))

df = pd.DataFrame(data_list)
columns = df.columns

# set multi-level index
if "news" in self.dict():
# level=0 - id
df = df.set_index(["id"])
if set(["symbol", "timestamp"]).issubset(df.columns):
elif "corporate_action_type" in columns:
# level=0 - corporate_action_type
df = df.set_index(["corporate_action_type"])
elif set(["symbol", "timestamp"]).issubset(columns):
# level=0 - symbol
# level=1 - timestamp
df = df.set_index(["symbol", "timestamp"])
Expand Down
Loading
Loading