Skip to content

Commit

Permalink
Fix pydantic issue and update variable names to more meaningful names…
Browse files Browse the repository at this point in the history
… in models.py
  • Loading branch information
rgaveiga committed Sep 27, 2024
1 parent 3e6e201 commit ec03b85
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 26 deletions.
20 changes: 10 additions & 10 deletions optionlab/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing


VERSION = "1.3.0"
VERSION = "1.3.1"


if typing.TYPE_CHECKING:
Expand All @@ -11,16 +11,16 @@
Inputs,
OptionType,
OptionInfo,
OptionStrategy,
Option,
Outputs,
ClosedPosition,
ProbabilityOfProfitArrayInputs,
ProbabilityOfProfitInputs,
BlackScholesInfo,
Distribution,
Strategy,
StrategyLeg,
StrategyType,
StockStrategy,
Stock,
Country,
Action,
)
Expand Down Expand Up @@ -53,16 +53,16 @@
"Inputs",
"OptionType",
"OptionInfo",
"OptionStrategy",
"Option",
"Outputs",
"ClosedPosition",
"ProbabilityOfProfitArrayInputs",
"ProbabilityOfProfitInputs",
"BlackScholesInfo",
"Distribution",
"Strategy",
"StrategyLeg",
"StrategyType",
"StockStrategy",
"Stock",
"Country",
"Action",
# engine
Expand Down Expand Up @@ -96,15 +96,15 @@
"Outputs": (__package__, ".models"),
"OptionType": (__package__, ".models"),
"OptionInfo": (__package__, ".models"),
"OptionStrategy": (__package__, ".models"),
"Option": (__package__, ".models"),
"ClosedPosition": (__package__, ".models"),
"ProbabilityOfProfitArrayInputs": (__package__, ".models"),
"ProbabilityOfProfitInputs": (__package__, ".models"),
"BlackScholesInfo": (__package__, ".models"),
"Distribution": (__package__, ".models"),
"Strategy": (__package__, ".models"),
"StrategyLeg": (__package__, ".models"),
"StrategyType": (__package__, ".models"),
"StockStrategy": (__package__, ".models"),
"Stock": (__package__, ".models"),
"Country": (__package__, ".models"),
"Action": (__package__, ".models"),
# engine
Expand Down
8 changes: 4 additions & 4 deletions optionlab/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from optionlab.models import (
Inputs,
Action,
OptionStrategy,
StockStrategy,
Option,
Stock,
ClosedPosition,
Outputs,
ProbabilityOfProfitInputs,
Expand Down Expand Up @@ -75,7 +75,7 @@ def _init_inputs(inputs: Inputs) -> EngineData:
for i, strategy in enumerate(inputs.strategy):
data.type.append(strategy.type)

if isinstance(strategy, OptionStrategy):
if isinstance(strategy, Option):
data.strike.append(strategy.strike)
data.premium.append(strategy.premium)
data.n.append(strategy.n)
Expand Down Expand Up @@ -110,7 +110,7 @@ def _init_inputs(inputs: Inputs) -> EngineData:
else:
raise ValueError("Expiration must be a date, an int or None.")

elif isinstance(strategy, StockStrategy):
elif isinstance(strategy, Stock):
data.n.append(strategy.n)
data.action.append(strategy.action)
data._previous_position.append(strategy.prev_pos or 0.0)
Expand Down
20 changes: 8 additions & 12 deletions optionlab/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@
]


class BaseStrategy(BaseModel):
class BaseLeg(BaseModel):
n: int = Field(gt=0)
action: Action
prev_pos: float | None = None


class StockStrategy(BaseStrategy):
class Stock(BaseLeg):
"""
"type" : string
It must be 'stock'. It is mandatory.
Expand All @@ -48,15 +49,12 @@ class StockStrategy(BaseStrategy):
negative, it means that the position is closed and the
difference between this price and the current price is
considered in the payoff calculation.
"""

type: Literal["stock"] = "stock"
n: int = Field(gt=0)
premium: float | None = None


class OptionStrategy(BaseStrategy):
class Option(BaseLeg):
"""
"type" : string
Either 'call' or 'put'. It is mandatory.
Expand All @@ -83,7 +81,6 @@ class OptionStrategy(BaseStrategy):
type: OptionType
strike: float = Field(gt=0)
premium: float = Field(gt=0)
n: int = Field(gt=0)
expiration: dt.date | int | None = None

@field_validator("expiration")
Expand All @@ -107,7 +104,7 @@ class ClosedPosition(BaseModel):
prev_pos: float


Strategy = StockStrategy | OptionStrategy | ClosedPosition
StrategyLeg = Stock | Option | ClosedPosition


class ProbabilityOfProfitInputs(BaseModel):
Expand Down Expand Up @@ -203,7 +200,7 @@ class Inputs(BaseModel):
interest_rate: float = Field(gt=0, le=0.2)
min_stock: float
max_stock: float
strategy: list[Strategy] = Field(..., min_length=1, discriminator="type")
strategy: list[StrategyLeg] = Field(..., min_length=1)
dividend_yield: float = 0.0
profit_target: float | None = None
loss_limit: float | None = None
Expand All @@ -221,7 +218,7 @@ class Inputs(BaseModel):

@field_validator("strategy")
@classmethod
def validate_strategy(cls, v: list[Strategy]) -> list[Strategy]:
def validate_strategy(cls, v: list[StrategyLeg]) -> list[StrategyLeg]:
types = [strategy.type for strategy in v]
if types.count("closed") > 1:
raise ValueError("Only one position of type 'closed' is allowed!")
Expand All @@ -232,8 +229,7 @@ def validate_dates(self) -> "Inputs":
expiration_dates = [
strategy.expiration
for strategy in self.strategy
if isinstance(strategy, OptionStrategy)
and isinstance(strategy.expiration, dt.date)
if isinstance(strategy, Option) and isinstance(strategy.expiration, dt.date)
]
if self.start_date and self.target_date:
if any(
Expand Down

0 comments on commit ec03b85

Please sign in to comment.