Skip to content

Commit

Permalink
fix typos add singapore pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
KatrionaGoldmann committed Oct 30, 2024
1 parent 02834aa commit 70f81f6
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 16 deletions.
5 changes: 2 additions & 3 deletions cr_analysis.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
#SBATCH --output=cr_garden2.out
#SBATCH --output=./logs/cr_garden2.out

source ~/miniforge3/bin/activate

Expand All @@ -8,12 +8,11 @@ CONDA_ENV_PATH="~/moth_detector_env/"

# Activate the environment
conda activate "${CONDA_ENV_PATH}"
#conda install --yes --file requirements.txt

# Print the Costa Rica deployments avaialble on the object store
# python print_deployments.py --subset_countries 'Costa Rica'

# Run the Python script on JASMIN
# Run the Inference script
python s3_download_with_inference.py \
--country "Costa Rica" \
--deployment "Garden - 3F1C4908"
Expand Down
3 changes: 1 addition & 2 deletions s3_download_with_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ def download_and_inference(
print("\033[93m\033[1m" + "Pipeline parameters" + "\033[0m\033[0m")
print(f"\033[93m - Scratch and crops storage: {data_storage_path}\033[0m")

crops_interval = args.crops_interval
if args.keep_crops:
crops_interval = args.crops_interval
print(f"\033[93m - Keeping crops every {crops_interval}mins\033[0m")
Expand All @@ -272,7 +271,7 @@ def download_and_inference(
download_and_inference(
args.country,
args.deployment,
crops_interval,
int(crops_interval),
args.rerun_existing,
data_storage_path,
args.perform_inference,
Expand Down
23 changes: 23 additions & 0 deletions singapore_analysis.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --output=./logs/singapore.out
#SBATCH --time=24:00:00

source ~/miniforge3/bin/activate

# Define the path to your existing Conda environment (modify as appropriate)
CONDA_ENV_PATH="~/moth_detector_env/"

# Activate the environment
conda activate "${CONDA_ENV_PATH}"

# Print the Costa Rica deployments avaialble on the object store
# python print_deployments.py --subset_countries 'Costa Rica'

# Run the Inference script
python s3_download_with_inference.py \
--country "Singapore" \
--deployment "All" \
--crops_interval 10 \
--keep_crops \
--data_storage_path ./data/singapore

15 changes: 4 additions & 11 deletions utils/custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,38 +85,31 @@ def load_models(device, localisation_model_path, binary_model_path, order_model_
in_features, num_classes
)
)
checkpoint = torch.load(weights_path, map_location=device)
checkpoint = torch.load(weights_path, map_location=device, weights_only=True)
state_dict = checkpoint.get("model_state_dict") or checkpoint
model_loc.load_state_dict(state_dict)
model_loc = model_loc.to(device)
model_loc.eval()

# Load the binary model
weights_path = binary_model_path
# labels_path = "/bask/homes/f/fspo1218/amber/data/mila_models/05-moth-nonmoth_category_map.json"
num_classes = 2 # moth, non-moth
classification_model = timm.create_model(
"tf_efficientnetv2_b3", num_classes=num_classes, weights=None
)
classification_model = classification_model.to(device)
checkpoint = torch.load(weights_path, map_location=device)
checkpoint = torch.load(weights_path, map_location=device, weights_only=True)
state_dict = checkpoint.get("model_state_dict") or checkpoint
classification_model.load_state_dict(state_dict)
classification_model.eval()

# Load the order model
savedWeights = order_model_path
thresholdFile = order_threshold_path
# img_size = 128
order_data_thresholds = pd.read_csv(thresholdFile)
order_labels = order_data_thresholds["ClassName"].to_list()
# thresholds = data_thresholds["Threshold"].to_list()
# means = data_thresholds["Mean"].to_list()
# stds = data_thresholds["Std"].to_list()
# img_depth = 3
num_classes = len(order_labels)
model_order = ResNet50_order(num_classes=num_classes)
model_order.load_state_dict(torch.load(savedWeights, map_location=device))
model_order.load_state_dict(torch.load(savedWeights, map_location=device, weights_only=True))
model_order = model_order.to(device)
model_order.eval()

Expand All @@ -130,7 +123,7 @@ def load_models(device, localisation_model_path, binary_model_path, order_model_
num_classes = len(species_category_map)
species_model = Resnet50_species(num_classes=num_classes)
species_model = species_model.to(device)
checkpoint = torch.load(weights, map_location=device)
checkpoint = torch.load(weights, map_location=device, weights_only=True)
# The model state dict is nested in some checkpoints, and not in others
state_dict = checkpoint.get("model_state_dict") or checkpoint
species_model.load_state_dict(state_dict)
Expand Down

0 comments on commit 70f81f6

Please sign in to comment.