You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
`# DRL models from Stable Baselines 3
from __future__ import annotations
import time
import numpy as np
import pandas as pd
from stable_baselines3 import A2C
from stable_baselines3 import DDPG
from stable_baselines3 import PPO
from stable_baselines3 import SAC
from stable_baselines3 import TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv
from finrl import config
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.meta.preprocessor.preprocessors import data_split
`# DRL models from Stable Baselines 3
from __future__ import annotations
import time
import numpy as np
import pandas as pd
from stable_baselines3 import A2C
from stable_baselines3 import DDPG
from stable_baselines3 import PPO
from stable_baselines3 import SAC
from stable_baselines3 import TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv
from finrl import config
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.meta.preprocessor.preprocessors import data_split
# Supported Stable Baselines 3 algorithm classes, keyed by their short name.
MODELS = {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}

# Hyperparameter sets per algorithm, looked up in the ``finrl.config`` module
# namespace under ``<NAME>_PARAMS`` (e.g. "a2c" -> ``A2C_PARAMS``).
# A module has no ``dict`` attribute; its namespace mapping is ``__dict__``
# (the double underscores were stripped by markdown rendering in the paste).
MODEL_KWARGS = {
    name: config.__dict__[f"{name.upper()}_PARAMS"] for name in MODELS
}

# Action-noise classes selectable by name.
NOISE = {
    "normal": NormalActionNoise,
    "ornstein_uhlenbeck": OrnsteinUhlenbeckActionNoise,
}
class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    # NOTE(review): no methods are visible in this excerpt; the class body
    # appears to have been elided when the code was pasted — confirm against
    # the original file before relying on this definition.
class DRLAgent:
    """Provides implementations for DRL algorithms"""

    # NOTE(review): the remainder of the docstring and every method of this
    # class are missing from this excerpt; only the summary line is visible.
    # The docstring is closed here so the definitions after it parse as code.
class DRLEnsembleAgent:
# NOTE(review): indentation was lost in this excerpt; the lines below are
# presumably members of this class — confirm against the original file.
@staticmethod
def get_model(
model_name,
env,
policy="MlpPolicy",
policy_kwargs=None,
model_kwargs=None,
seed=None,
verbose=1,
):
# NOTE(review): the body of get_model is not shown in this excerpt. The
# statements that follow read `self`, `i`, `timesteps_dict`,
# `SAC_model_kwargs` and `TD3_model_kwargs`, none of which are parameters of
# this static method, so they presumably belong to a different method whose
# header was elided — TODO confirm.

# Create a SAC model on self.train_env with the "MlpPolicy" policy and the
# SAC-specific keyword arguments.
model_sac = self.get_model(
"sac", self.train_env, policy="MlpPolicy", model_kwargs=SAC_model_kwargs
)
# Hand the model to self.train_model, tagging the run with the iteration
# index `i` and using the "sac" entry of timesteps_dict as the step budget.
# The returned value replaces model_sac.
model_sac = self.train_model(
model_sac,
"sac",
tb_log_name=f"sac_{i}",
iter_num=i,
total_timesteps=timesteps_dict["sac"],
) # 100_000
print("======TD3 Training========")
# Same two steps for TD3: create the model, then pass it to self.train_model
# with the "td3" entry of timesteps_dict.
model_td3 = self.get_model(
"td3", self.train_env, policy="MlpPolicy", model_kwargs=TD3_model_kwargs
)
model_td3 = self.train_model(
model_td3,
"td3",
tb_log_name=f"td3_{i}",
iter_num=i,
total_timesteps=timesteps_dict["td3"],
) # 100_000
`
The text was updated successfully, but these errors were encountered: