Skip to content

Commit

Permalink
For real this time
Browse files Browse the repository at this point in the history
  • Loading branch information
alex404 committed Oct 31, 2024
1 parent f95d7e1 commit cddd898
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions runner/frameworks/classification/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def analyze(
if epoch == 0:
_perform_initialization_analysis(
channel_analysis,
analyses_dir,
use_wandb,
analyses_dir,
plot_dir,
checkpoint_plot_dir,
run_dir,
Expand All @@ -128,6 +128,7 @@ def analyze(

_perform_reconstruction_analysis(
use_wandb,
analyses_dir,
plot_dir,
checkpoint_plot_dir,
device,
Expand All @@ -152,8 +153,8 @@ def _plot_and_save_histories(plot_dir: Path, histories: Dict[str, List[float]]):

def _perform_initialization_analysis(
channel_analysis: bool,
analyses_dir: Path,
use_wandb: bool,
analyses_dir: Path,
plot_dir: Path,
checkpoint_plot_dir: Path,
run_dir: Path,
Expand Down Expand Up @@ -326,6 +327,7 @@ def _analyze_regular_layer(

def _perform_reconstruction_analysis(
use_wandb: bool,
analyses_dir: Path,
plot_dir: Path,
checkpoint_plot_dir: Path,
device: torch.device,
Expand All @@ -348,7 +350,7 @@ def _perform_reconstruction_analysis(
reconstruct_images(device, brain, decoder, train_set, test_set, 5)
)
# Save the reconstructions
rec_path = plot_dir / f"{decoder}_reconstructions_epoch_{epoch}.json"
rec_path = analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json"
with open(rec_path, "w") as f:
json.dump(rec_dict, f, cls=NumpyEncoder)

Expand Down

0 comments on commit cddd898

Please sign in to comment.