Commit 8778e4b (parent 732257e)
adding alternative entry point to train several models sequentially in one go
Showing 2 changed files with 57 additions and 0 deletions.
@@ -0,0 +1,14 @@
[
    {
        "dataframe_csv": "filepath to csv",
        "data_path": "filepath to data",
        "wandb_run_name": "run name"
    },
    {
        "dataframe_csv": "filepath to csv",
        "data_path": "filepath to data",
        "wandb_run_name": "run name",
        "learning_rate": 0.001,
        "batch_size": 42
    }
]
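Each object in this list defines one training run: keys that match attributes on pyha_analyzer's config.cfg override the defaults for that run, and unknown keys are skipped with a warning by the script below. As a quick pre-flight check before launching a long sequence, a standalone sketch like the following can confirm the file parses and list the runs it defines (the filename sequential_run_cfg.json is an assumption, not something this commit fixes):

import json

# Hypothetical pre-flight check for a sequential-run config file.
# NOTE: the filename is an assumption; point it at your own config.
with open("sequential_run_cfg.json", "r") as f:
    runs = json.load(f)

for i, run in enumerate(runs, start=1):
    name = run.get("wandb_run_name", "Unnamed Run")
    overrides = {k: v for k, v in run.items() if k != "wandb_run_name"}
    print(f"Run {i}: {name} with overrides {overrides}")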
@@ -0,0 +1,43 @@
# Script for running several training runs with different configurations
import time
import json

import torch

import pyha_analyzer as pa
from pyha_analyzer import config

cfg = config.cfg

# Update the shared config and run a single training run
def run_training(config_dict):
    # Dynamically update cfg attributes from the config_dict
    for key, value in config_dict.items():
        if hasattr(cfg, key):
            setattr(cfg, key, value)
        else:
            print(f"Warning: Config has no attribute '{key}', skipping...")

    # Set multiprocessing strategy; force=True so the start method can be
    # set again on the second and later runs without raising a RuntimeError
    torch.multiprocessing.set_sharing_strategy('file_system')
    torch.multiprocessing.set_start_method('spawn', force=True)

    # Run the training
    pa.train.main(in_sweep=False)

# Load configurations from a JSON file and run them sequentially
def run_training_runs(config_file):
    # Load configurations from the JSON file
    with open(config_file, 'r') as file:
        configurations = json.load(file)

    # Iterate over each configuration and run the training
    # (named run_config so it does not shadow the imported config module)
    for run_config in configurations:
        print(f"Running training for: {run_config.get('wandb_run_name', 'Unnamed Run')}")
        run_training(run_config)
        time.sleep(10)  # Optional wait between runs so logging can finish fully

# Entry point to start the training runs
if __name__ == "__main__":
    config_file_path = "sequential_run_cfg"  # Change to your config file path if needed
    run_training_runs(config_file_path)
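One behavioral note grounded in the script: run_training mutates the shared cfg in place and nothing resets it between runs, so an override from one run (say, learning_rate) carries over into any later run that does not set it explicitly. If per-run isolation is wanted, a small wrapper like the hypothetical run_with_restore below (not part of this commit; it assumes the cfg and run_training defined above) snapshots the overridden attributes and restores them afterwards:

# Hypothetical wrapper (not part of this commit): run one configuration,
# then restore the cfg attributes it overrode so later runs start from
# the original defaults.
def run_with_restore(config_dict):
    saved = {k: getattr(cfg, k) for k in config_dict if hasattr(cfg, k)}
    try:
        run_training(config_dict)
    finally:
        for k, v in saved.items():
            setattr(cfg, k, v)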