Skip to content

Commit

Permalink
feat: support indicator alias
Browse files Browse the repository at this point in the history
  • Loading branch information
yuangn committed Nov 25, 2024
1 parent 3d4d3e3 commit 7bfb194
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
6 changes: 3 additions & 3 deletions backtest/daily_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

def test_stock_bond_ratio(): # 股债性价比回测

df = get_wind_data(indicator_list=['沪深300PE-TTM', '中债国债到期收益率10y'
df = get_wind_data(indicator_list=['沪深300PE:TTM', '中债国债到期收益率10y'
, '沪深300收盘价'])

df['股债性价比'] = 1 / df['沪深300PE-TTM'] * 100 - df['中债国债到期收益率10y']
df['股债性价比'] = 1 / df['沪深300PE:TTM'] * 100 - df['中债国债到期收益率10y']

def signal_func(factor_series: pd.Series, param: Dict):
threshold = param.get('threshold')
Expand Down Expand Up @@ -61,5 +61,5 @@ def signal_func(factor_series: pd.Series, param: Dict):
print(ret)

if __name__ == '__main__':
# test_stock_bond_ratio() # 股债性价比回测
test_stock_bond_ratio() # 股债性价比回测
test_strong_stock_ratio()
56 changes: 51 additions & 5 deletions backtest/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,53 @@ def visit_Subscript(self, node):
return [v.replace('_', ':') for v in variables]


def _get_alias_map():
'''
从数据库中取得指标别名映射表,用于支持用户使用方便的别名
'''
sql_stat = '''
select replace(indicator_name,':', '_') indicator_name, replace(remark,':', '_') remark from (
select indicator_name, remark from edb_desc ed
union
select indicator_name, remark from wind_desc wd
union
select name, remark from indicator_description
)
where remark is not null
'''
conn = sqlite3.connect(SQLITE_FILE_PATH)
df = pd.read_sql(sql_stat, conn)
return dict(zip(df['remark'], df['indicator_name']))


def _replace_variables_in_expr(expr: str) -> str:
expr = expr.replace(':', '_')
var_map = _get_alias_map()
# 替换表达式中的变量名
class CustomVisitor(ast.NodeVisitor):
def __init__(self, var_map):
self.var_map = var_map

def visit_Name(self, node):
# 如果节点的名称在map中,替换成对应的值
if node.id in self.var_map:
node.id = self.var_map[node.id]
self.generic_visit(node)

try:
# 解析表达式为AST
parsed_ast = ast.parse(expr, mode='eval')
visitor = CustomVisitor(var_map)
visitor.visit(parsed_ast)

# 将修改后的AST转换回表达式字符串
modified_expr = ast.unparse(parsed_ast) # 需要Python 3.9及以上
return modified_expr

except SyntaxError:
raise ValueError(f"无法解析表达式: {expr}")


# 所有序列函数的实现
def ema(series: pd.Series, window: int) -> pd.Series:
return series.ewm(span=window, adjust=False).mean()
Expand Down Expand Up @@ -637,8 +684,8 @@ def visit_Name(self, node):


def get_data(indicators: List, start_date: str='20200101'):
origin_indicators = [indi for indi in indicators if indi] # 用户看到的带冒号的指标名 不支持 减号 下划线
expr_indicators = [indi.replace(':', '_') for indi in indicators if indi] # 支持 eval的指标名
origin_indicators = [_replace_variables_in_expr(indi) for indi in indicators if indi] # 用户看到的带冒号的指标名 不支持 减号 下划线
expr_indicators = [_replace_variables_in_expr(indi) for indi in indicators if indi] # 支持 eval的指标名
indicators_set = set() # 支持eval的单个指标名
for indi in expr_indicators:
if '(' in indi or '-' in indi or '+' in indi: # 若是需要表达式解析的
Expand All @@ -650,7 +697,6 @@ def get_data(indicators: List, start_date: str='20200101'):
for indi in expr_indicators:
if '(' in indi or '-' in indi or '+' in indi: # 若是需要表达式解析的
df = calc_complicate_indicator(df, indi)
df.columns = [c.replace('_', ':') for c in df.columns]
return df[origin_indicators]


Expand All @@ -669,8 +715,8 @@ def iReturn(arr, start_date="20050101") -> PDFrame(headings=True, index=True):

@xlo.func
def iPlot(y_indicators, y2_indicators, start_date="20050101", convert2return=False, title=""):
y_indicators = [x for x in y_indicators.flatten().tolist() if x]
y2_indicators = [x for x in y2_indicators.flatten().tolist() if x]
y_indicators = [_replace_variables_in_expr(x) for x in y_indicators.flatten().tolist() if x]
y2_indicators = [_replace_variables_in_expr(x) for x in y2_indicators.flatten().tolist() if x]
df_data = get_data(y_indicators + y2_indicators, start_date)
if convert2return:
df_data = ((df_data.pct_change() + 1).cumprod() - 1) * 100
Expand Down

0 comments on commit 7bfb194

Please sign in to comment.