Skip to content

Commit

Permalink
option to rerun inferences and use GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
KatrionaGoldmann committed Aug 1, 2024
1 parent 5eee0af commit cedac81
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 32 deletions.
59 changes: 36 additions & 23 deletions s3_download_with_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@
import boto3
import torch
import pandas as pd
import datetime
import os
import argparse

from utils.aws_scripts import get_objects, get_deployments
from utils.custom_models import load_models

device = torch.device('cpu')
# Use GPU if available
print(f' - Cuda available: {torch.cuda.is_available()}')
if (torch.cuda.is_available()):
device = torch.device("cuda")
else:
device = torch.device("cpu")

def display_menu(country, deployment, crops_interval):
def display_menu(country, deployment, crops_interval, csv_file, rerun_existing):
"""Display the main menu and handle user interaction."""

print("- Read in configs and credentials")
Expand All @@ -29,19 +34,13 @@ def display_menu(country, deployment, crops_interval):

all_deployments = get_deployments(username, password)

#countries = list({d["country"] for d in all_deployments if d["status"] == "active"})
print('- Analysing: ', country)

country_deployments = [
f"{d['location_name']} - {d['camera_id']}"
for d in all_deployments
if d["country"] == country and d["status"] == "active"
]
country_deployments = country_deployments


#data_types = ["snapshot_images", "audible_recordings", "ultrasound_recordings"]
data_type = "snapshot_images"

s3_bucket_name = [
d["country_code"]
Expand All @@ -53,6 +52,7 @@ def display_menu(country, deployment, crops_interval):
remove_image = True

print(' - Removing images after analysis: ', remove_image)
print(' - Rerun existing inferences: ', rerun_existing)
print(' - Performing inference: ', perform_inference)

if deployment == 'All':
Expand All @@ -71,9 +71,7 @@ def display_menu(country, deployment, crops_interval):
and d["status"] == "active"
][0]



prefix = f"{dep_id}/{data_type}"
prefix = f"{dep_id}/snapshot_images"
get_objects(session,
aws_credentials,
s3_bucket_name,
Expand All @@ -93,6 +91,7 @@ def display_menu(country, deployment, crops_interval):
device=device,
order_data_thresholds=order_data_thresholds,
csv_file=csv_file,
rerun_existing=rerun_existing,
crops_interval=crops_interval)

if __name__ == "__main__":
Expand All @@ -115,17 +114,10 @@ def display_menu(country, deployment, crops_interval):
local_directory_path = aws_credentials['directory']
print(' - Scratch storage: ', local_directory_path)

date_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
csv_file = f'{local_directory_path}/mila_outputs_{date_time}.csv'
# date_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
# csv_file = f'{local_directory_path}/mila_outputs_{date_time}.csv'


all_boxes = pd.DataFrame(
columns=['image_path', 'box_score', 'x_min', 'y_min', 'x_max', 'y_max',
'class_name', 'class_confidence',
'order_name', 'order_confidence',
'species_name', 'species_confidence',
'cropped_image_path']
)
all_boxes.to_csv(csv_file, index=False)

model_loc, classification_model, regional_model, regional_category_map, order_model, order_data_thresholds, order_labels = load_models(device)

Expand All @@ -138,6 +130,12 @@ def display_menu(country, deployment, crops_interval):
parser.add_argument("--crops_interval", type=str,
help="The interval for which to preserve the crops",
default=10)
parser.add_argument("--csv_file", type=str,
help="The path to the csv file to save the results",
default=f'./{(parser.parse_args().country).replace(" ", "_")}_results.csv')
parser.add_argument("--rerun_existing", action=argparse.BooleanOptionalAction,
default=False,
help="Whether to rerun images which have already been analysed")

args = parser.parse_args()

Expand All @@ -149,4 +147,19 @@ def display_menu(country, deployment, crops_interval):
print(' - Not keeping crops')
crops_interval = None

display_menu(args.country, args.deployment, crops_interval)
print(f'Saving results to: {args.csv_file}')

# if the file doesn't exist, create it with just the header row
csv_file = args.csv_file
if not os.path.isfile(csv_file):
all_boxes = pd.DataFrame(
columns=['image_path', 'analysis_datetime',
'box_score', 'x_min', 'y_min', 'x_max', 'y_max',
'class_name', 'class_confidence',
'order_name', 'order_confidence',
'species_name', 'species_confidence',
'cropped_image_path']
)
all_boxes.to_csv(csv_file, index=False)

display_menu(args.country, args.deployment, crops_interval, csv_file, args.rerun_existing)
27 changes: 22 additions & 5 deletions utils/aws_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
from datetime import datetime, timedelta
from utils.inference_scripts import perform_inf
import pandas as pd

def get_deployments(username, password):
"""Fetch deployments from the API with authentication."""
Expand Down Expand Up @@ -87,24 +88,39 @@ def download_batch(s3_client, bucket_name, keys, local_path,
order_labels=None, species_model=None, species_labels=None,
country='UK', region='UKCEH', device=None,
order_data_thresholds=None, csv_file='results.csv',
intervals=None):
rerun_existing=False, intervals=None):
"""
Download a batch of objects from S3.
"""


existing_df = pd.read_csv(csv_file)
print(existing_df['image_path'])

for key in keys:
file_path, filename = os.path.split(key)


os.makedirs(os.path.join(local_path, file_path), exist_ok=True)
download_path = os.path.join(local_path, file_path, filename)

print(download_path)
print(not rerun_existing)
print(download_path in existing_df['image_path'])
# check whether this file is already listed in the csv_file 'image_path' column
if not rerun_existing:
if existing_df['image_path'].str.contains(download_path).any():
print(f'{download_path} has already been processed. Skipping...')
continue


download_object(s3_client, bucket_name, key, download_path,
perform_inference, remove_image,
localisation_model, binary_model,
order_model, order_labels, species_model,
species_labels, country, region, device,
order_data_thresholds, csv_file, intervals)
order_data_thresholds, csv_file,
intervals)

def count_files(s3_client, bucket_name, prefix):
"""
Expand All @@ -129,7 +145,7 @@ def get_objects(session, aws_credentials, bucket_name, key, local_path,
order_labels=None, species_model=None, species_labels=None,
country='UK', region='UKCEH', device=None,
order_data_thresholds=None, csv_file='results.csv',
crops_interval=None):
rerun_existing=False, crops_interval=None):
"""
Fetch objects from the S3 bucket and download them synchronously in batches.
"""
Expand Down Expand Up @@ -165,15 +181,16 @@ def get_objects(session, aws_credentials, bucket_name, key, local_path,
localisation_model, binary_model, order_model,
order_labels, species_model, species_labels,
country, region, device, order_data_thresholds,
csv_file, intervals)
csv_file, rerun_existing, intervals)
keys = []
progress_bar.update(batch_size)
if keys:
download_batch(s3_client, bucket_name, keys, local_path,
perform_inference, remove_image, localisation_model,
binary_model, order_model, order_labels,
species_model, species_labels, country, region,
device, order_data_thresholds, csv_file, intervals)
device, order_data_thresholds, csv_file,
rerun_existing, intervals)
progress_bar.update(len(keys))

progress_bar.close()
8 changes: 4 additions & 4 deletions utils/inference_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from PIL import Image, ImageDraw
import torchvision.transforms as transforms
import numpy as np
from datetime import datetime

# from utils.custom_models import Resnet50_species, ResNet50_order, load_models


def classify_species(image_tensor, regional_model, regional_category_map):
'''
Classify the species of the moth using the regional model.
Expand Down Expand Up @@ -100,7 +100,7 @@ def perform_inf(image_path, loc_model, binary_model, order_model,
input_tensor = transform_loc(image).unsqueeze(0).to(device)

all_boxes = pd.DataFrame(
columns=['image_path',
columns=['image_path', 'analysis_datetime',
'box_score', 'x_min', 'y_min', 'x_max', 'y_max', #localisation info
'class_name', 'class_confidence', # binary class info
'order_name', 'order_confidence', # order info
Expand Down Expand Up @@ -175,12 +175,12 @@ def perform_inf(image_path, loc_model, binary_model, order_model,

# append to csv with pandas
df = pd.DataFrame(
[[image_path,
[[image_path, str(datetime.now()),
box_score, x_min, y_min, x_max, y_max,
class_name, class_confidence ,
order_name, order_confidence,
species_name, species_confidence, crop_path]],
columns=['image_path',
columns=['image_path', 'analysis_datetime',
'box_score', 'x_min', 'y_min', 'x_max', 'y_max',
'class_name', 'class_confidence',
'order_name', 'order_confidence',
Expand Down

0 comments on commit cedac81

Please sign in to comment.