Skip to content

Commit

Permalink
pass inference arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
KatrionaGoldmann committed Aug 1, 2024
1 parent a0d4d18 commit 734ee83
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
10 changes: 4 additions & 6 deletions cr_analysis.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ conda activate "${CONDA_ENV_PATH}"


# Run the Python script on baskerville
python s3_download_with_inference.py "Costa Rica" "All of the above"





python s3_download_with_inference.py \
--country "Costa Rica" \
--deployment "Forest Edge - EC4AB109" \
--keep_crops
30 changes: 24 additions & 6 deletions s3_download_with_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,15 @@
from utils.custom_models import load_models


def display_menu(
country, deployment, crops_interval, csv_file, rerun_existing, local_directory_path
def download_and_inference(
country,
deployment,
crops_interval,
csv_file,
rerun_existing,
local_directory_path,
perform_inference,
remove_image,
):
"""
Display the main menu and handle user interaction.
Expand All @@ -28,9 +35,6 @@ def display_menu(
username = aws_credentials["UKCEH_username"]
password = aws_credentials["UKCEH_password"]

perform_inference = True
remove_image = True

print(f"\033[93m - Removing images after analysis: {remove_image}\033[0m")
print(f"\033[93m - Performing inference: {perform_inference}\033[0m")
print(f"\033[93m - Rerun existing inferences: {rerun_existing}\033[0m")
Expand Down Expand Up @@ -145,6 +149,18 @@ def display_menu(
default=False,
help="Whether to keep the crops",
)
parser.add_argument(
"--perform_inference",
action=argparse.BooleanOptionalAction,
default=True,
help="Whether to perform the inference",
)
parser.add_argument(
"--remove_image",
action=argparse.BooleanOptionalAction,
default=True,
help="Whether to remove the raw image after inference",
)
parser.add_argument(
"--crops_interval",
type=str,
Expand Down Expand Up @@ -213,11 +229,13 @@ def display_menu(
)
all_boxes.to_csv(csv_file, index=False)

display_menu(
download_and_inference(
args.country,
args.deployment,
crops_interval,
csv_file,
args.rerun_existing,
data_storage_path,
args.perform_inference,
args.remove_image,
)

0 comments on commit 734ee83

Please sign in to comment.