diff --git a/python/requirements.txt b/python/requirements.txt new file mode 100644 index 0000000..c914aca --- /dev/null +++ b/python/requirements.txt @@ -0,0 +1,2 @@ +scikit-learn +scipy diff --git a/python/scoring.py b/python/scoring.py index c6a34a8..f83cfe9 100755 --- a/python/scoring.py +++ b/python/scoring.py @@ -40,12 +40,13 @@ def check_columns(data, columns): if column not in columns: raise ValueError(f"Column {column} not expected in data, expected {columns}") -def check_allowed(data, column, allowed=["0", "1"]): +def check_allowed(data, column, allowed=None): """Check if the predictions are in the allowed set, if not, print an error message to stderr also showing the id and throw an exception. Otherwise return to the caller.""" + if allowed is None: + allowed = ["0", "1"] if column not in data: raise ValueError(f"Column {column} not found in data") - return for i, value in enumerate(data[column]): if value not in allowed: raise ValueError(f"Invalid value {value} not one of {allowed} in column {column} at index {i} with id {data['id'][i]}") @@ -71,24 +72,28 @@ def check_dist(data, columns): if abs(sum - 1.0) > EPS: raise ValueError(f"Values in columns {columns} do not sum to 1.0 at index {i} with id {theid}") -def load_tsv(submission_dir, expected_rows, expected_cols): + +def load_tsv(submission_dir, expected_rows, expected_cols, file=None): """ Try to load a TSV file from the submission directory. This expects a single TSV file to be present in the submission directory. If there is no TSV file or there are multiple files, it will log an error to stderr and return None. """ - tsv_files = [f for f in os.listdir(submission_dir) if f.endswith('.tsv')] - if len(tsv_files) == 0: - print("No TSV file ending with '.tsv' found in submission directory", file=sys.stderr) - return None - if len(tsv_files) > 1: - print("Multiple TSV files found in submission directory", file=sys.stderr) - return None - tsv_file = tsv_files[0] + if file is not None: + tsv_file = file + else: + tsv_files = [f for f in os.listdir(submission_dir) if f.endswith('.tsv')] + if len(tsv_files) == 0: + print("No TSV file ending with '.tsv' found in submission directory", file=sys.stderr) + return None + if len(tsv_files) > 1: + print("Multiple TSV files found in submission directory", file=sys.stderr) + return None + tsv_file = tsv_files[0] tsv_path = os.path.join(submission_dir, tsv_file) print("Loading TSV file", tsv_path) # Read the TSV file incrementally row by row and create a dictionary where the key is the column name and the value is a list of values for that column. # Expect the column names in the first row of the TSV file. - # Abort reading and log an error to stderr if the file is not a valid TSV file, if it contains more than one row with the same id, + # Abort reading and log an error to stderr if the file is not a valid TSV file, if it contains more than one row with the same id, # if the column name is not known, or if there are more than N_MAX rows. data = defaultdict(list) with open(tsv_path, 'rt') as infp: @@ -189,6 +194,7 @@ def main(): parser = argparse.ArgumentParser(description='Scorer for the competition') parser.add_argument('--submission-dir', help='Directory containing the submission (.)', default=".") + parser.add_argument('--submission-file', help='If submission directory contains more than one file, name of the file to use (None)', default=None) parser.add_argument('--reference-dir', help='Directory containing the reference data (./dev_phase/reference_data/)', default="./dev_phase/reference_data/") parser.add_argument('--score-dir', help='Directory to write the scores to (.)', default=".") parser.add_argument('--codabench', help='Indicate we are running on codabench, not locally', action='store_true') @@ -238,9 +244,12 @@ def main(): # load the submissing tsv file if args.st == "1": - data = load_tsv(submission_dir, expected_rows=len(targets), expected_cols=ST1_COLUMNS) + data = load_tsv(submission_dir, expected_rows=len(targets), expected_cols=ST1_COLUMNS, file=args.submission_file) elif args.st == "2": - data = load_tsv(submission_dir, expected_rows=len(targets), expected_cols=ST2_COLUMNS) + data = load_tsv(submission_dir, expected_rows=len(targets), expected_cols=ST2_COLUMNS, file=args.submission_file) + else: + print("Unknown subtask", file=sys.stderr) + sys.exit(1) if data is None: print("Problems loading the submission, aborting", file=sys.stderr) sys.exit(1)