Skip to content

Commit

Permalink
using the new function
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati committed Sep 16, 2022
1 parent fb3718f commit 34fbc34
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 18 deletions.
15 changes: 5 additions & 10 deletions GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_unique_timestamp,
resample_image,
reverse_one_hot,
get_ground_truths_and_predictions_tensor,
)
from GANDLF.metrics import overall_stats
from tqdm import tqdm
Expand Down Expand Up @@ -102,16 +103,10 @@ def validate_network(

# get ground truths for classification problem, validation set
if is_classification and mode == "validation":
ground_truth_array = torch.from_numpy(
params["validation_data"][
params["validation_data"].columns[
params["headers"]["predictionHeaders"]
]
]
.to_numpy()
.ravel()
).type(torch.int)
predictions_array = torch.zeros_like(ground_truth_array)
(
ground_truth_array,
predictions_array,
) = get_ground_truths_and_predictions_tensor(params, "validation_data")

for batch_idx, (subject) in enumerate(
tqdm(valid_dataloader, desc="Looping over " + mode + " data")
Expand Down
13 changes: 5 additions & 8 deletions GANDLF/compute/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
version_check,
write_training_patches,
print_model_summary,
get_ground_truths_and_predictions_tensor,
)
from GANDLF.metrics import overall_stats
from GANDLF.logger import Logger
Expand Down Expand Up @@ -72,14 +73,10 @@ def train_network(model, train_dataloader, optimizer, params):

# get ground truths
if params["problem_type"] == "classification":
ground_truth_array = torch.from_numpy(
params["training_data"][
params["training_data"].columns[params["headers"]["predictionHeaders"]]
]
.to_numpy()
.ravel()
).type(torch.int)
predictions_array = torch.zeros_like(ground_truth_array)
(
ground_truth_array,
predictions_array,
) = get_ground_truths_and_predictions_tensor(params, "training_data")
# Set the model to train
model.train()
for batch_idx, (subject) in enumerate(
Expand Down

0 comments on commit 34fbc34

Please sign in to comment.