From 88a403f92a16f09ba5d827c691305c28b87b7104 Mon Sep 17 00:00:00 2001 From: yuangn Date: Tue, 26 Nov 2024 16:03:34 +0800 Subject: [PATCH] feat: add some funcs --- backtest/load.py | 417 +++++++++++++++++++++++++++----------- backtest/requirements.txt | 2 + 2 files changed, 302 insertions(+), 117 deletions(-) diff --git a/backtest/load.py b/backtest/load.py index 003ccc2..2a472c1 100644 --- a/backtest/load.py +++ b/backtest/load.py @@ -1,13 +1,21 @@ +import ast +import datetime as dt import re -import xloil as xlo -import pandas as pd +import win32api +import win32con + import numpy as np +import pandas as pd +import scipy.stats as stats +import statsmodels.regression.linear_model as sm_ols +from statsmodels.tools.tools import add_constant +from statsmodels.tsa.stattools import coint +from typing import List, Tuple import plotly.express as px -import datetime as dt -import win32api, win32con -from typing import List + +import xloil as xlo from xloil.pandas import PDFrame -import ast + import sqlite3 SQLITE_FILE_PATH = r'D:\onedrive\文档\etc\ifind.db' @@ -17,12 +25,14 @@ def print_status(*args): with xlo.StatusBar(2000) as status: status.msg(",".join([str(a) for a in args])) + def _is_fetching(data): for item in data: - if item==None or item=='' or item=='抓取中...': + if item == None or item == '' or item == '抓取中...': return True return False + def MsgBox(content: str = "", title: str = "知心提示") -> int: response = win32api.MessageBox( 0, content, title, 4, win32con.MB_SYSTEMMODAL) @@ -35,6 +45,7 @@ def execute_sql(sql_stat): cursor.execute(sql_stat) conn.commit() + def create_table(): execute_sql(''' CREATE TABLE IF NOT EXISTS tdate ( @@ -138,20 +149,25 @@ def fetch_one(): # 1 取得所有指标及其公式 indicator_df = pd.read_sql(f''' select a.id, b.date, a.name, a.formula from indicator_description a -join tdate b on b.is_trade_date=1 +join tdate b on b.is_trade_date=1 and b.date < strftime('%Y%m%d', 'now') -- 注意这里,判断含不含当天 and name in ( '{indicator_name}') ''', conn) - if len(indicator_df)==0: + if len(indicator_df) == 0: print_status(f'指标{indicator_name}没有待处理任务') return - indicator_df['tdate'] = indicator_df['date'].apply(lambda x: f"{(str(x))[0:4]}/{(str(x))[4:6]}/{(str(x))[6:8]}") - indicator_df['tformula'] = indicator_df.apply(lambda row: row['formula'].replace('date()', row['tdate']), axis=1) + indicator_df['tdate'] = indicator_df['date'].apply( + lambda x: f"{(str(x))[0:4]}/{(str(x))[4:6]}/{(str(x))[6:8]}") + indicator_df['tformula'] = indicator_df.apply( + lambda row: row['formula'].replace('date()', row['tdate']), axis=1) print_status(f'正在处理{indicator_df["name"].iloc[0]}') # 填充Excel数据 - ws.range(0, 0, len(indicator_df)-1, 0).value = np.array(indicator_df['tdate']).reshape(-1, 1) - ws.range(0, 1, len(indicator_df)-1, 1).Formula = np.array(indicator_df['tformula']).reshape(-1, 1) - ws.range(0, 2, len(indicator_df)-1, 2).value = np.array(indicator_df['id']).reshape(-1, 1) + ws.range(0, 0, len(indicator_df)-1, + 0).value = np.array(indicator_df['tdate']).reshape(-1, 1) + ws.range(0, 1, len(indicator_df)-1, + 1).Formula = np.array(indicator_df['tformula']).reshape(-1, 1) + ws.range(0, 2, len(indicator_df)-1, + 2).value = np.array(indicator_df['id']).reshape(-1, 1) xlo.app().calculate(full=True, rebuild=True) @@ -183,22 +199,25 @@ def save_one(): MsgBox(f'指标{indicator_name}还未完全计算完毕!') return # 组织数据并插入DB - data_df = pd.DataFrame(data, columns=['indicator_date','value', 'indicator_id']) - data_df = data_df[data_df['value']!=0] # TODO 把值为0的去掉,或许这个对某些指标来说不严谨。 - data_df['indicator_date'] = data_df['indicator_date'].apply(lambda x: xlo.from_excel_date(x).strftime('%Y%m%d')) + data_df = pd.DataFrame( + data, columns=['indicator_date', 'value', 'indicator_id']) + data_df = data_df[data_df['value'] != 0] # TODO 把值为0的去掉,或许这个对某些指标来说不严谨。 + data_df['indicator_date'] = data_df['indicator_date'].apply( + lambda x: xlo.from_excel_date(x).strftime('%Y%m%d')) conn = sqlite3.connect(SQLITE_FILE_PATH) # 删DB数据 - execute_sql(f'delete from indicator_data where indicator_id={data_df["indicator_id"].iloc[0]}') + execute_sql(f'delete from indicator_data where indicator_id={ + data_df["indicator_id"].iloc[0]}') # 插入数据 data_df.to_sql('indicator_data', conn, if_exists='append', index=False) # 更新数据最新日期 - execute_sql(f'''update indicator_description set last_updated_date='{data_df["indicator_date"].iloc[-1]}' where id={data_df["indicator_id"].iloc[-1]}''') + execute_sql(f'''update indicator_description set last_updated_date='{ + data_df["indicator_date"].iloc[-1]}' where id={data_df["indicator_id"].iloc[-1]}''') # 清理Excel ws.range(0, 0, 5000, 2).clear() ws.cell(indicator_row, 8).value = '' - @xlo.func(command=True) def fetch_increment(): ''' @@ -207,35 +226,45 @@ def fetch_increment(): ws = xlo.active_worksheet() ws.range(0, 0, 5000, 4).clear() conn = sqlite3.connect(SQLITE_FILE_PATH) - + # 1 取得所有指标及其公式 indicator_df = pd.read_sql(f''' select a.id, b.date, a.name, a.formula from indicator_description a -join tdate b on b.is_trade_date=1 +join tdate b on b.is_trade_date=1 and b.date > coalesce(a.last_updated_date, '20241020') and b.date <= strftime('%Y%m%d', 'now') -- 注意这里,判断含不含当天 and a.last_updated_date is not null ''', conn) - if len(indicator_df)==0: + if len(indicator_df) == 0: print_status(f'没有待处理任务') return - + # 取得上一交易日 - date_df = pd.read_sql('''select date from tdate where is_trade_date=1 and date>'20241020' ''', conn) + date_df = pd.read_sql( + '''select date from tdate where is_trade_date=1 and date>'20241020' ''', conn) date_df['prev_date'] = date_df['date'].shift(1) date_df.dropna(inplace=True) indicator_df = indicator_df.merge(date_df, on='date', how='inner') - indicator_df['tdate'] = indicator_df['date'].apply(lambda x: f"{(str(x))[0:4]}/{(str(x))[4:6]}/{(str(x))[6:8]}") - indicator_df['tformula'] = indicator_df.apply(lambda row: row['formula'].replace('date()', row['tdate']), axis=1) - indicator_df['prev_tdate'] = indicator_df['prev_date'].apply(lambda x: f"{(str(x))[0:4]}/{(str(x))[4:6]}/{(str(x))[6:8]}") - indicator_df['prev_tformula'] = indicator_df.apply(lambda row: row['formula'].replace('date()', row['prev_tdate']), axis=1) - + indicator_df['tdate'] = indicator_df['date'].apply( + lambda x: f"{(str(x))[0:4]}/{(str(x))[4:6]}/{(str(x))[6:8]}") + indicator_df['tformula'] = indicator_df.apply( + lambda row: row['formula'].replace('date()', row['tdate']), axis=1) + indicator_df['prev_tdate'] = indicator_df['prev_date'].apply( + lambda x: f"{(str(x))[0:4]}/{(str(x))[4:6]}/{(str(x))[6:8]}") + indicator_df['prev_tformula'] = indicator_df.apply( + lambda row: row['formula'].replace('date()', row['prev_tdate']), axis=1) + # 填充Excel数据 - ws.range(0, 0, len(indicator_df)-1, 0).value = np.array(indicator_df['tdate']).reshape(-1, 1) - ws.range(0, 1, len(indicator_df)-1, 1).Formula = np.array(indicator_df['tformula']).reshape(-1, 1) - ws.range(0, 2, len(indicator_df)-1, 2).value = np.array(indicator_df['id']).reshape(-1, 1) - ws.range(0, 3, len(indicator_df)-1, 3).Formula = np.array(indicator_df['prev_tformula']).reshape(-1, 1) - ws.range(0, 4, len(indicator_df)-1, 4).value = np.array(indicator_df['name']).reshape(-1, 1) + ws.range(0, 0, len(indicator_df)-1, + 0).value = np.array(indicator_df['tdate']).reshape(-1, 1) + ws.range(0, 1, len(indicator_df)-1, + 1).Formula = np.array(indicator_df['tformula']).reshape(-1, 1) + ws.range(0, 2, len(indicator_df)-1, + 2).value = np.array(indicator_df['id']).reshape(-1, 1) + ws.range(0, 3, len(indicator_df)-1, + 3).Formula = np.array(indicator_df['prev_tformula']).reshape(-1, 1) + ws.range(0, 4, len(indicator_df)-1, + 4).value = np.array(indicator_df['name']).reshape(-1, 1) xlo.app().calculate(full=True, rebuild=True) @@ -254,16 +283,21 @@ def save_increment(): MsgBox(f'指标还未完全计算完毕!') return # 组织数据并插入DB - data_df = pd.DataFrame(data, columns=['indicator_date','value', 'indicator_id', 'pre_value','name']) - data_df = data_df[(data_df['value']!=data_df['pre_value']) & (data_df['value']!=0.0)].reset_index() - data_df['indicator_date'] = data_df['indicator_date'].apply(lambda x: xlo.from_excel_date(x).strftime('%Y%m%d')) + data_df = pd.DataFrame( + data, columns=['indicator_date', 'value', 'indicator_id', 'pre_value', 'name']) + data_df = data_df[(data_df['value'] != data_df['pre_value']) + & (data_df['value'] != 0.0)].reset_index() + data_df['indicator_date'] = data_df['indicator_date'].apply( + lambda x: xlo.from_excel_date(x).strftime('%Y%m%d')) conn = sqlite3.connect(SQLITE_FILE_PATH) # 插入数据 - data_df[['indicator_date','value','indicator_id']].to_sql('indicator_data', conn, if_exists='append', index=False) + data_df[['indicator_date', 'value', 'indicator_id']].to_sql( + 'indicator_data', conn, if_exists='append', index=False) # 更新数据最新日期 for indicator_id in list(data_df['indicator_id'].unique()): - sub_df = data_df[data_df['indicator_id']==indicator_id] - execute_sql(f'''update indicator_description set last_updated_date='{sub_df["indicator_date"].iloc[-1]}' where id={sub_df["indicator_id"].iloc[-1]}''') + sub_df = data_df[data_df['indicator_id'] == indicator_id] + execute_sql(f'''update indicator_description set last_updated_date='{ + sub_df["indicator_date"].iloc[-1]}' where id={sub_df["indicator_id"].iloc[-1]}''') # 清理Excel ws.range(0, 0, 5000, 4).clear() @@ -294,18 +328,20 @@ def fetch_one_edb(): select id, indicator_name, indicator_id, last_updated_date from edb_desc where indicator_name in ( '{indicator_name}' ) ''', conn) - if len(indicator_df)==0: + if len(indicator_df) == 0: print_status(f'指标{indicator_name}没有待处理任务') return - indicator_df['tdate'] = indicator_df['last_updated_date'].apply(lambda x: f"{(str(x))[0:4]}/{(str(x))[4:6]}/{(str(x))[6:8]}" if x else '2005/01/31') - indicator_df['tformula'] = indicator_df.apply(lambda row: f'=thsMEDB("{row["indicator_id"]}","{row["tdate"]}","","Format(isAsc=N,Display=R,FillBlank=B,DecimalPoint=2,LineBlank=N)")', axis=1) + indicator_df['tdate'] = indicator_df['last_updated_date'].apply( + lambda x: f"{(str(x))[0:4]}/{(str(x))[4:6]}/{(str(x))[6:8]}" if x else '2005/01/31') + indicator_df['tformula'] = indicator_df.apply(lambda row: f'=thsMEDB("{row["indicator_id"]}","{ + row["tdate"]}","","Format(isAsc=N,Display=R,FillBlank=B,DecimalPoint=2,LineBlank=N)")', axis=1) print_status(f'正在处理{indicator_df["indicator_name"].iloc[0]}') # 填充Excel数据 - ws.range(0, 1, len(indicator_df)-1, 1).Formula = np.array(indicator_df['tformula']).reshape(-1, 1) + ws.range(0, 1, len(indicator_df)-1, + 1).Formula = np.array(indicator_df['tformula']).reshape(-1, 1) xlo.app().calculate(full=True, rebuild=True) - @xlo.func(command=True) def save_one_edb(): ''' @@ -330,11 +366,10 @@ def save_one_edb(): select id, indicator_name, indicator_id, last_updated_date from edb_desc where indicator_name in ( '{indicator_name}' ) ''', conn) - if len(indicator_df)==0: + if len(indicator_df) == 0: print_status(f'指标{indicator_name}不存在,请检查数据库') return indicator_foreign_id = indicator_df['id'].iloc[0] # 获得外键中的指标id - data = ws.range(0, 0, 5000, 1).value i = 0 @@ -345,16 +380,19 @@ def save_one_edb(): MsgBox(f'指标{indicator_name}还未完全计算完毕!') return # 组织数据并插入DB - data_df = pd.DataFrame(data, columns=['indicator_date','value']) + data_df = pd.DataFrame(data, columns=['indicator_date', 'value']) data_df['desc_id'] = indicator_foreign_id - data_df = data_df[data_df['value']!=0] # TODO 把值为0的去掉,或许这个对某些指标来说不严谨。 - data_df['indicator_date'] = data_df['indicator_date'].apply(lambda x: xlo.from_excel_date(x).strftime('%Y%m%d')) + data_df = data_df[data_df['value'] != 0] # TODO 把值为0的去掉,或许这个对某些指标来说不严谨。 + data_df['indicator_date'] = data_df['indicator_date'].apply( + lambda x: xlo.from_excel_date(x).strftime('%Y%m%d')) # 删DB数据 - execute_sql(f'delete from edb_data where desc_id={data_df["desc_id"].iloc[0]}') + execute_sql(f'delete from edb_data where desc_id={ + data_df["desc_id"].iloc[0]}') # 插入数据 data_df.to_sql('edb_data', conn, if_exists='append', index=False) # 更新数据最新日期 - execute_sql(f'''update edb_desc set last_updated_date='{np.max(data_df["indicator_date"])}' where id={data_df["desc_id"].iloc[-1]}''') + execute_sql(f'''update edb_desc set last_updated_date='{np.max( + data_df["indicator_date"])}' where id={data_df["desc_id"].iloc[-1]}''') # 清理Excel ws.range(0, 0, 5000, 1).clear() ws.cell(indicator_row, 8).value = '' @@ -368,26 +406,29 @@ def fetch_daily_edb(): ws = xlo.active_worksheet() ws.range(0, 0, 5000, 500).clear() conn = sqlite3.connect(SQLITE_FILE_PATH) - + # 1 取得日频数据中最小日期的下一天 min_date = pd.read_sql(''' select min(last_updated_date) from edb_desc where frequency ='日' ''', conn) - if len(min_date)==0: + if len(min_date) == 0: print_status(f'没有日频指标') return - min_date = dt.datetime.strptime(f'{min_date.iloc[0,0]}', '%Y%m%d') + dt.timedelta(days=1) + min_date = dt.datetime.strptime( + f'{min_date.iloc[0, 0]}', '%Y%m%d') + dt.timedelta(days=1) min_date = min_date.strftime('%Y/%m/%d') # 2 取得所有EDB指标,写入B1及之后的横排 indicator_df = pd.read_sql(f''' select indicator_id from edb_desc where frequency ='日' ''', conn) - if len(indicator_df)==0: + if len(indicator_df) == 0: print_status(f'没有待处理任务') return - ws.range(0, 1, 0, len(indicator_df)).value = indicator_df['indicator_id'].to_numpy().reshape(1, -1) + ws.range(0, 1, 0, len(indicator_df) + ).value = indicator_df['indicator_id'].to_numpy().reshape(1, -1) indicator_address = ws.range(0, 1, 0, len(indicator_df)).address() - ws.cell(1,1).Formula = f'''=thsMEDB({str(indicator_address)},"{min_date}","","Format(isAsc=N,Display=R,FillBlank=B,DecimalPoint=2,LineBlank=N)")''' + ws.cell(1, 1).Formula = f'''=thsMEDB({str(indicator_address)},"{ + min_date}","","Format(isAsc=N,Display=R,FillBlank=B,DecimalPoint=2,LineBlank=N)")''' xlo.app().calculate(full=True, rebuild=True) @@ -400,32 +441,37 @@ def save_daily_edb(): # 1 拿到数据并melt为合适格式 data = ws.used_range.value data[0][0] = 'indicator_date' - + i = 0 while i < len(data) and data[i][0] is not None: i += 1 data = data[:i] data_df = pd.DataFrame(data[1:], columns=data[0]) - data_df['indicator_date'] = data_df['indicator_date'].apply(lambda x: xlo.from_excel_date(x).strftime('%Y%m%d')) - melted_df = data_df.melt(id_vars='indicator_date', var_name='indicator_id', value_name='indicator_value') - + data_df['indicator_date'] = data_df['indicator_date'].apply( + lambda x: xlo.from_excel_date(x).strftime('%Y%m%d')) + melted_df = data_df.melt( + id_vars='indicator_date', var_name='indicator_id', value_name='indicator_value') + # 2 得到具体id conn = sqlite3.connect(SQLITE_FILE_PATH) desc_df = pd.read_sql('select id, indicator_id from edb_desc', conn) melted_df = melted_df.merge(desc_df, on='indicator_id', how='left') # 3 删除原有的 date_str = ",".join(list(melted_df["indicator_date"].unique())) - execute_sql(f'''delete from edb_data where indicator_date in ( {date_str} ) ''') + execute_sql( + f'''delete from edb_data where indicator_date in ( {date_str} ) ''') # 4 组织数据并插入DB - data_df = melted_df[['id','indicator_date','indicator_value']].rename(columns={'id':'desc_id', 'indicator_value':'value'}) + data_df = melted_df[['id', 'indicator_date', 'indicator_value']].rename( + columns={'id': 'desc_id', 'indicator_value': 'value'}) data_df = data_df.dropna(subset=['value']) data_df.to_sql('edb_data', conn, if_exists='append', index=False) - + # 5 更新数据最新日期 for desc_id in list(data_df['desc_id'].unique()): - sub_df = data_df[data_df['desc_id']==desc_id] - execute_sql(f'''update edb_desc set last_updated_date='{np.max(sub_df["indicator_date"])}' where id={sub_df["desc_id"].iloc[-1]}''') + sub_df = data_df[data_df['desc_id'] == desc_id] + execute_sql(f'''update edb_desc set last_updated_date='{np.max( + sub_df["indicator_date"])}' where id={sub_df["desc_id"].iloc[-1]}''') # 清理Excel ws.used_range.clear() @@ -449,8 +495,8 @@ def get_data_from_db(indicators: List, start_date: str): and b.indicator_name in ( {indicators_str} ) union all - - select b.indicator_name, a.value, a.indicator_date + + select b.indicator_name, a.value, a.indicator_date from wind_data a join wind_desc b on a.indicator_id=b.id and a.indicator_date >'{start_date}' @@ -459,10 +505,12 @@ def get_data_from_db(indicators: List, start_date: str): ''' df = pd.read_sql(sql_stat, conn) pivot_df = df.pivot(index='日期', columns='name', values='value') - pivot_df.index = pd.to_datetime(pivot_df.index.astype(str), format='%Y%m%d') + pivot_df.index = pd.to_datetime( + pivot_df.index.astype(str), format='%Y%m%d') pivot_df = pivot_df.dropna() return pivot_df + def extract_variables(expr: str): """ 从表达式中提取所有变量名 支持变量含有冒号 不支持变量含有减号、下划线!! @@ -470,7 +518,7 @@ def extract_variables(expr: str): :param expr: 表达式字符串,例如 "ma40(ma20(col1) - ema30(col2) - col3:2) + ema10(col1:2)" :return: 包含所有变量名的集合 """ - expr = expr.replace(':','_') + expr = expr.replace(':', '_') variables = set() functions = set() @@ -519,12 +567,13 @@ def _get_alias_map(): conn = sqlite3.connect(SQLITE_FILE_PATH) df = pd.read_sql(sql_stat, conn) return dict(zip(df['remark'], df['indicator_name'])) - + def _replace_variables_in_expr(expr: str) -> str: expr = expr.replace(':', '_') var_map = _get_alias_map() # 替换表达式中的变量名 + class CustomVisitor(ast.NodeVisitor): def __init__(self, var_map): self.var_map = var_map @@ -553,34 +602,42 @@ def visit_Name(self, node): def ema(series: pd.Series, window: int) -> pd.Series: return series.ewm(span=window, adjust=False, min_periods=window).mean() + def ma(series: pd.Series, window: int) -> pd.Series: return series.rolling(window=window, min_periods=window).mean() + def zscore(series: pd.Series, window: int) -> pd.Series: rolling_mean = series.rolling(window=window, min_periods=window).mean() rolling_std = series.rolling(window=window, min_periods=window).std() return (series - rolling_mean) / rolling_std + def pr(series: pd.Series, window: int) -> pd.Series: """Calculate the percentile rank of the most recent value in the rolling window.""" - return series.rolling(window=window, min_periods=window).apply( + return series.rolling(window=window, min_periods=1).apply( lambda x: (x.argsort().argsort()[-1] + 1) / len(x) * 100, raw=True ) -def p(series: pd.Series, percentile: int) -> pd.Series: # p70 表示所有数据中P70的点位,水平线来的 + +def p(series: pd.Series, percentile: int) -> pd.Series: # p70 表示所有数据中P70的点位,水平线来的 value = series.quantile(percentile / 100.0) return pd.Series(value, index=series.index) -def macd(series: pd.Series, short_window: int, long_window: int, signal_window: int) -> pd.DataFrame: # # MACD (指数平滑异同移动平均线) - short_ema = series.ewm(span=short_window, adjust=False, min_periods=short_window).mean() - long_ema = series.ewm(span=long_window, adjust=False, min_periods=long_window).mean() + +def macd(series: pd.Series, short_window: int, long_window: int, signal_window: int) -> pd.DataFrame: # MACD (指数平滑异同移动平均线) + short_ema = series.ewm(span=short_window, adjust=False, + min_periods=short_window).mean() + long_ema = series.ewm(span=long_window, adjust=False, + min_periods=long_window).mean() macd_line = short_ema - long_ema - signal_line = macd_line.ewm(span=signal_window, adjust=False, min_periods=signal_window).mean() + signal_line = macd_line.ewm( + span=signal_window, adjust=False, min_periods=signal_window).mean() histogram = macd_line - signal_line return pd.DataFrame({'MACD': macd_line, 'Signal': signal_line, 'Histogram': histogram}) -def rsi(series: pd.Series, n: int) -> pd.Series: # RSI (相对强弱指数) +def rsi(series: pd.Series, n: int) -> pd.Series: # RSI (相对强弱指数) delta = series.diff() gain = (delta.where(delta > 0, 0)).rolling(window=n, min_periods=n).mean() loss = (-delta.where(delta < 0, 0)).rolling(window=n, min_periods=n).mean() @@ -588,7 +645,8 @@ def rsi(series: pd.Series, n: int) -> pd.Series: # RSI (相对强弱指数) return 100 - (100 / (1 + rs)) -def bollinger_bands(series: pd.Series, n: int, num_std: float) -> pd.DataFrame: # Bollinger Bands (布林带) +# Bollinger Bands (布林带) +def bollinger_bands(series: pd.Series, n: int, num_std: float) -> pd.DataFrame: sma = series.rolling(window=n, min_periods=n).mean() std = series.rolling(window=n, min_periods=n).std() upper_band = sma + num_std * std @@ -596,6 +654,8 @@ def bollinger_bands(series: pd.Series, n: int, num_std: float) -> pd.DataFrame: return pd.DataFrame({'SMA': sma, 'Upper Band': upper_band, 'Lower Band': lower_band}) # 处理字段名,去掉特殊字符并映射到新字段 + + def normalize_column_names(df: pd.DataFrame, expr: str) -> (pd.DataFrame, dict, str): special_char_pattern = r"[:]" # 指标的特殊字符只支持冒号 不支持减号 column_map = {} @@ -610,13 +670,14 @@ def normalize_column_names(df: pd.DataFrame, expr: str) -> (pd.DataFrame, dict, expr = re.sub(rf"\b{re.escape(original)}\b", normalized, expr) return df, column_map, expr + def calc_complicate_indicator(df: pd.DataFrame, expr: str) -> pd.DataFrame: # 处理字段名特殊字符 df, column_map, expr = normalize_column_names(df, expr) # 解析 AST parsed_ast = ast.parse(expr, mode='eval') - + # 提取函数和变量 class FunctionAndVariableVisitor(ast.NodeVisitor): def __init__(self): @@ -641,30 +702,40 @@ def visit_Name(self, node): # 动态生成函数并存储 my_functions = {} for func in functions: - if re.match(r"^ema(\d+)$", func): # 例如 ema20 + if re.match(r"^ema(\d+)$", func): # 例如 ema20 window = int(re.match(r"^ema(\d+)$", func).group(1)) - my_functions[func] = lambda series, window=window: ema(series, window) - elif re.match(r"^ma(\d+)$", func): # 例如 ma30 + my_functions[func] = lambda series, window=window: ema( + series, window) + elif re.match(r"^ma(\d+)$", func): # 例如 ma30 window = int(re.match(r"^ma(\d+)$", func).group(1)) - my_functions[func] = lambda series, window=window: ma(series, window) - elif re.match(r"^zscore(\d+)$", func): # 例如 zscore30 + my_functions[func] = lambda series, window=window: ma( + series, window) + elif re.match(r"^zscore(\d+)$", func): # 例如 zscore30 window = int(re.match(r"^zscore(\d+)$", func).group(1)) - my_functions[func] = lambda series, window=window: zscore(series, window) - elif re.match(r"^pr(\d+)$", func): # 例如 pr70 + my_functions[func] = lambda series, window=window: zscore( + series, window) + elif re.match(r"^pr(\d+)$", func): # 例如 pr70 window = int(re.match(r"^pr(\d+)$", func).group(1)) - my_functions[func] = lambda series, window=window: pr(series, window) - elif re.match(r"^p(\d+)$", func): # 例如 p70 + my_functions[func] = lambda series, window=window: pr( + series, window) + elif re.match(r"^p(\d+)$", func): # 例如 p70 percentile = int(re.match(r"^p(\d+)$", func).group(1)) - my_functions[func] = lambda series, percentile=percentile: p(series, percentile) + my_functions[func] = lambda series, percentile=percentile: p( + series, percentile) elif re.match(r"^macd(\d+)_(\d+)_(\d+)$", func): # 例如 macd12_26_9 - params = list(map(int, re.match(r"^macd(\d+)_(\d+)_(\d+)$", func).groups())) - my_functions[func] = lambda series, params=params: macd(series, *params) + params = list( + map(int, re.match(r"^macd(\d+)_(\d+)_(\d+)$", func).groups())) + my_functions[func] = lambda series, params=params: macd( + series, *params) elif re.match(r"^rsi(\d+)$", func): # 例如 rsi14 window = int(re.match(r"^rsi(\d+)$", func).group(1)) - my_functions[func] = lambda series, window=window: rsi(series, window) + my_functions[func] = lambda series, window=window: rsi( + series, window) elif re.match(r"^bollinger(\d+)_(\d+)$", func): # 例如 bollinger20_2 - params = list(map(int, re.match(r"^bollinger(\d+)_(\d+)$", func).groups())) - my_functions[func] = lambda series, params=params: bollinger_bands(series, *params) + params = list( + map(int, re.match(r"^bollinger(\d+)_(\d+)$", func).groups())) + my_functions[func] = lambda series, params=params: bollinger_bands( + series, *params) # 构造 pd.eval 环境 eval_env = {"df": df} @@ -676,26 +747,29 @@ def visit_Name(self, node): # 计算表达式并添加到 DataFrame res = pd.eval(expr, local_dict=eval_env) - if type(res)==pd.Series: + if type(res) == pd.Series: df[expr.replace('df.', '')] = res - elif type(res)==pd.DataFrame: + elif type(res) == pd.DataFrame: df = pd.concat([df, res], axis=1) return df -def get_data(indicators: List, start_date: str='20200101'): - origin_indicators = [_replace_variables_in_expr(indi) for indi in indicators if indi] # 用户看到的带冒号的指标名 不支持 减号 下划线 - expr_indicators = [_replace_variables_in_expr(indi) for indi in indicators if indi] # 支持 eval的指标名 - indicators_set = set() # 支持eval的单个指标名 +def get_data(indicators: List, start_date: str = '20200101'): + origin_indicators = [_replace_variables_in_expr( + indi) for indi in indicators if indi] # 用户看到的带冒号的指标名 不支持 减号 下划线 + expr_indicators = [_replace_variables_in_expr( + indi) for indi in indicators if indi] # 支持 eval的指标名 + indicators_set = set() # 支持eval的单个指标名 for indi in expr_indicators: - if '(' in indi or '-' in indi or '+' in indi: # 若是需要表达式解析的 + if '(' in indi or '-' in indi or '+' in indi: # 若是需要表达式解析的 indicators_set.update(extract_variables(indi)) else: indicators_set.add(indi) - df = get_data_from_db([indi.replace('_', ':') for indi in list(indicators_set)], start_date) # 从数据库得到所有指标 + df = get_data_from_db([indi.replace('_', ':') for indi in list( + indicators_set)], start_date) # 从数据库得到所有指标 df.columns = [c.replace(':', '_') for c in df.columns] for indi in expr_indicators: - if '(' in indi or '-' in indi or '+' in indi: # 若是需要表达式解析的 + if '(' in indi or '-' in indi or '+' in indi: # 若是需要表达式解析的 df = calc_complicate_indicator(df, indi) return df[origin_indicators] @@ -715,8 +789,10 @@ def iReturn(arr, start_date="20050101") -> PDFrame(headings=True, index=True): @xlo.func def iPlot(y_indicators, y2_indicators, start_date="20050101", convert2return=False, title=""): - y_indicators = [_replace_variables_in_expr(x) for x in y_indicators.flatten().tolist() if x] - y2_indicators = [_replace_variables_in_expr(x) for x in y2_indicators.flatten().tolist() if x] + y_indicators = [_replace_variables_in_expr( + x) for x in y_indicators.flatten().tolist() if x] + y2_indicators = [_replace_variables_in_expr( + x) for x in y2_indicators.flatten().tolist() if x] df_data = get_data(y_indicators + y2_indicators, start_date) if convert2return: df_data = ((df_data.pct_change() + 1).cumprod() - 1) * 100 @@ -725,10 +801,11 @@ def iPlot(y_indicators, y2_indicators, start_date="20050101", convert2return=Fal y2_indicators = [f'{x}(右)' for x in y2_indicators] fig = px.line(df_data, x=df_data.index, y=y_indicators) - for trace in fig.data: # 设置y1如果是分位水平线,则要改成虚线 + for trace in fig.data: # 设置y1如果是分位水平线,则要改成虚线 if trace.name in y_indicators: - trace.line.dash = 'dash' if re.match(r'^p\d+', trace.name) else 'solid' - + trace.line.dash = 'dash' if re.match( + r'^p\d+', trace.name) else 'solid' + # 设置第二个 Y 轴 fig.update_layout( title=title, @@ -738,18 +815,124 @@ def iPlot(y_indicators, y2_indicators, start_date="20050101", convert2return=Fal side='right' ) ) - + # 添加第二个 Y 轴的数据 for col in y2_indicators: - line_dash = 'dash' if re.match(r'^p\d+', col) else 'solid' # 如果是p30之类的,则画虚线 - fig.add_scatter(x=df_data.index, y=df_data[col], mode='lines', name=col, yaxis='y2', line=dict(dash=line_dash), hovertemplate='variable='+col+'
日期=%{x}
value=%{y}') + line_dash = 'dash' if re.match( + r'^p\d+', col) else 'solid' # 如果是p30之类的,则画虚线 + fig.add_scatter(x=df_data.index, y=df_data[col], mode='lines', name=col, yaxis='y2', line=dict( + dash=line_dash), hovertemplate='variable='+col+'
日期=%{x}
value=%{y}') # legend中太长的描述要换行 - fig.for_each_trace(lambda trace: trace.update(name='
'.join([trace.name[i:i+18] for i in range(0, len(trace.name), 18)]))) + fig.for_each_trace(lambda trace: trace.update(name='
'.join( + [trace.name[i:i+18] for i in range(0, len(trace.name), 18)]))) fig.show() - return '查看' + return '折线图' + + +@xlo.func +def iMatrix(indicators: List[str], start_date="20050101") -> PDFrame(headings=True, index=True): + df = get_data(indicators, start_date) + + correlation_df = df.corr() # 计算相关性矩阵 + + fig = px.imshow(correlation_df, + text_auto=True, + color_continuous_scale='RdYlGn', + title=f'相关性矩阵,起始日{start_date}') # 使用Plotly生成热力图 + + fig.show() + return '相关性矩阵' + + +@xlo.func +def iCoint(indicators: List[str], start_date="20050101") -> PDFrame(headings=True, index=True): + # 获取多个指标的数据 + df = get_data(indicators, start_date) + + # 计算各个资产的日收益率(如果数据是价格的话) + for col in df.columns: + df[col] = np.log(df[col] / df[col].shift(1)) # 对数收益率 + + # 清理数据:去除缺失值(NaN)和无穷大值(inf) + df.replace([np.inf, -np.inf], np.nan, inplace=True) # 将无穷大值替换为NaN + df.dropna(inplace=True) # 删除含有NaN的行 + + # 计算协整性 p 值矩阵 + p_values = [] + columns = df.columns + for i in range(len(columns)): + for j in range(i + 1, len(columns)): + # 进行协整性测试 (Engle-Granger 二步法) + score, p_value, _ = coint(df[columns[i]], df[columns[j]]) + p_values.append((columns[i], columns[j], p_value)) + + # 将 p 值结果存储到 DataFrame + result_df = pd.DataFrame( + p_values, columns=["Asset 1", "Asset 2", "p-value"]) + + # 构造协整性 p 值矩阵 + p_matrix = pd.pivot_table(result_df, values='p-value', + index='Asset 1', columns='Asset 2') + + # 使用 Plotly 生成协整性 p 值矩阵的热力图 + fig = px.imshow(p_matrix, + text_auto=True, + color_continuous_scale='RdYlGn', # 使用绿黄红渐变色 + title=f'Cointegration Test p-value Matrix 起始日{start_date}') + fig.show() + return '协整性测试' + + +@xlo.func +def iAlphaBeta(assets: str, benchmark: str, start_date: str = "20050101") -> Tuple[PDFrame, any]: + assets = assets.replace(':', '_') + benchmark = benchmark.replace(':', '_') + # 获取资产和基准的数据 + asset_df = get_data([assets], start_date) + benchmark_df = get_data([benchmark], start_date) + + # 计算对数收益率 + asset_returns = np.log(asset_df[assets] / asset_df[assets].shift(1)) + benchmark_returns = np.log( + benchmark_df[benchmark] / benchmark_df[benchmark].shift(1)) + + # 清理数据:去除NaN或无穷值 + data = pd.concat([asset_returns, benchmark_returns], axis=1).dropna() + data.columns = ['Asset', 'Benchmark'] + + # 计算 Alpha 和 Beta + X = add_constant(data['Benchmark']) # 为基准添加常数项 + y = data['Asset'] + + model = OLS(y, X).fit() + alpha, beta = model.params + r_squared = model.rsquared + + # 计算 Sharpe Ratio(假设无风险利率为 0) + excess_returns = data['Asset'] - data['Benchmark'] # 超额收益 + sharpe_ratio = np.mean(excess_returns) / \ + np.std(excess_returns) * np.sqrt(252) # 年化 Sharpe 比率 + + # 生成回归图 + fig = px.scatter(data_frame=data, x='Benchmark', y='Asset', + title=f"Asset vs Benchmark: Alpha = {alpha:.4f}, Beta = {beta:.4f}, R² = {r_squared:.4f}, Sharpe Ratio = {sharpe_ratio}") + fig.update_layout( + xaxis_title=f'{benchmark} Daily Returns', + yaxis_title=f'{assets} Daily Returns' + ) + # 添加红色回归线 + fig.add_trace( + px.line(x=data['Benchmark'], y=model.fittedvalues, + line_shape='linear').data[0] + ) + fig.data[-1].line.color = 'red' # 设置回归线的颜色为红色 + + fig.show() + + return '资产基准回归图' # if __name__ == '__main__': -# create_table() \ No newline at end of file +# create_table() diff --git a/backtest/requirements.txt b/backtest/requirements.txt index 6e63b02..019aac1 100644 --- a/backtest/requirements.txt +++ b/backtest/requirements.txt @@ -10,3 +10,5 @@ scipy bottleneck SQLAlchemy==2.0.20 cx_Oracle==8.3.0 +statsmodels==0.14.4 +