Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
johann-petrak committed Jul 31, 2024
1 parent 8ee57e0 commit d10eb35
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
2 changes: 2 additions & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
scikit-learn
scipy
37 changes: 23 additions & 14 deletions python/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ def check_columns(data, columns):
if column not in columns:
raise ValueError(f"Column {column} not expected in data, expected {columns}")

def check_allowed(data, column, allowed=["0", "1"]):
def check_allowed(data, column, allowed=None):
"""Check if the predictions are in the allowed set, if not, print an error message to
stderr also showing the id and throw an exception. Otherwise return to the caller."""
if allowed is None:
allowed = ["0", "1"]
if column not in data:
raise ValueError(f"Column {column} not found in data")
return
for i, value in enumerate(data[column]):
if value not in allowed:
raise ValueError(f"Invalid value {value} not one of {allowed} in column {column} at index {i} with id {data['id'][i]}")
Expand All @@ -71,24 +72,28 @@ def check_dist(data, columns):
if abs(sum - 1.0) > EPS:
raise ValueError(f"Values in columns {columns} do not sum to 1.0 at index {i} with id {theid}")

def load_tsv(submission_dir, expected_rows, expected_cols):

def load_tsv(submission_dir, expected_rows, expected_cols, file=None):
"""
Try to load a TSV file from the submission directory. This expects a single TSV file to be present in the submission directory.
If there is no TSV file or there are multiple files, it will log an error to stderr and return None.
"""
tsv_files = [f for f in os.listdir(submission_dir) if f.endswith('.tsv')]
if len(tsv_files) == 0:
print("No TSV file ending with '.tsv' found in submission directory", file=sys.stderr)
return None
if len(tsv_files) > 1:
print("Multiple TSV files found in submission directory", file=sys.stderr)
return None
tsv_file = tsv_files[0]
if file is not None:
tsv_file = file
else:
tsv_files = [f for f in os.listdir(submission_dir) if f.endswith('.tsv')]
if len(tsv_files) == 0:
print("No TSV file ending with '.tsv' found in submission directory", file=sys.stderr)
return None
if len(tsv_files) > 1:
print("Multiple TSV files found in submission directory", file=sys.stderr)
return None
tsv_file = tsv_files[0]
tsv_path = os.path.join(submission_dir, tsv_file)
print("Loading TSV file", tsv_path)
# Read the TSV file incrementally row by row and create a dictionary where the key is the column name and the value is a list of values for that column.
# Expect the column names in the first row of the TSV file.
# Abort reading and log an error to stderr if the file is not a valid TSV file, if it contains more than one row with the same id,
# Abort reading and log an error to stderr if the file is not a valid TSV file, if it contains more than one row with the same id,
# if the column name is not known, or if there are more than N_MAX rows.
data = defaultdict(list)
with open(tsv_path, 'rt') as infp:
Expand Down Expand Up @@ -189,6 +194,7 @@ def main():

parser = argparse.ArgumentParser(description='Scorer for the competition')
parser.add_argument('--submission-dir', help='Directory containing the submission (.)', default=".")
parser.add_argument('--submission-file', help='If submission directory contains more than one file, name of the file to use (None)', default=None)
parser.add_argument('--reference-dir', help='Directory containing the reference data (./dev_phase/reference_data/)', default="./dev_phase/reference_data/")
parser.add_argument('--score-dir', help='Directory to write the scores to (.)', default=".")
parser.add_argument('--codabench', help='Indicate we are running on codabench, not locally', action='store_true')
Expand Down Expand Up @@ -238,9 +244,12 @@ def main():

# load the submissing tsv file
if args.st == "1":
data = load_tsv(submission_dir, expected_rows=len(targets), expected_cols=ST1_COLUMNS)
data = load_tsv(submission_dir, expected_rows=len(targets), expected_cols=ST1_COLUMNS, file=args.submission_file)
elif args.st == "2":
data = load_tsv(submission_dir, expected_rows=len(targets), expected_cols=ST2_COLUMNS)
data = load_tsv(submission_dir, expected_rows=len(targets), expected_cols=ST2_COLUMNS, file=args.submission_file)
else:
print("Unknown subtask", file=sys.stderr)
sys.exit(1)
if data is None:
print("Problems loading the submission, aborting", file=sys.stderr)
sys.exit(1)
Expand Down

0 comments on commit d10eb35

Please sign in to comment.