Skip to content

Commit

Permalink
add tqdm
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelndor committed Mar 4, 2024
1 parent 2d5a42e commit 22ec31e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ gymnasium>=0.29.1
numpy >=1.21.0
scikit-learn>=1.4.0
pytest
tqdm
matplotlib
3 changes: 1 addition & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import os
from tqdm import tqdm
from sklearn.ensemble import RandomForestRegressor
import joblib # Use this import instead if you have a newer version of scikit-learn

env = TimeLimit(
env=HIVPatient(domain_randomization=False), max_episode_steps=200
Expand Down Expand Up @@ -63,7 +62,7 @@ def act(self, observation, use_random=False):
net_pop = self.model_pop
device = "cuda" if next(network.parameters()).is_cuda else "cpu"
with torch.no_grad():
Q = network(torch.Tensor(observation).unsqueeze(0).to(device))*0.32 + net_pop(torch.Tensor(observation).unsqueeze(0).to(device))
Q = network(torch.Tensor(observation).unsqueeze(0).to(device)) + 0.7*net_pop(torch.Tensor(observation).unsqueeze(0).to(device))
return torch.argmax(Q).item()
else :
return np.random.randint(4)
Expand Down

0 comments on commit 22ec31e

Please sign in to comment.