-
Notifications
You must be signed in to change notification settings - Fork 0
/
pred_validator.py
55 lines (55 loc) · 1.81 KB
/
pred_validator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
Course: Statistical Machine Learning.
This script checks the predictions csv file for the most common formatting
mistakes.
"""
import argparse
import collections
import sys
# Exact number of comma-separated predictions the first line of the CSV must hold.
N_TEST_CASES: int = 387
def parse_args(argv=None):
    """Parse the validator's command-line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse fall
            back to ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace whose ``file`` attribute is the path to the
        predictions CSV (``"predictions.csv"`` when ``--file`` is omitted).
    """
    parser = argparse.ArgumentParser(description="Argument parser")
    # Options are never required by default, so ``required=False`` is omitted.
    parser.add_argument(
        "--file",
        default="predictions.csv",
        help="Path to a CSV file with predictions. Default: predictions.csv",
    )
    # Accepting argv explicitly lets callers and tests parse arbitrary
    # argument lists without mutating sys.argv.
    return parser.parse_args(argv)
def import_file(path):
    """Read a text file and return its contents as a list of lines.

    Args:
        path: Path of the file to read.

    Returns:
        list[str]: The lines as returned by ``readlines()`` (line terminators
        are kept).

    Exits:
        Prints a message and terminates the process with status 1 if the
        file does not exist or cannot be read/decoded.
    """
    print(f"Importing {path}...")
    try:
        with open(path, "rt", encoding="utf-8") as f:
            data = f.readlines()
    except FileNotFoundError:
        print(f"File {path} not found. Exiting.")
        sys.exit(1)
    except (OSError, UnicodeDecodeError) as err:
        # A bare ``except:`` used to report every failure (permission errors,
        # directories, undecodable bytes, even Ctrl-C) as "not found".
        print(f"Could not read file {path}: {err}. Exiting.")
        sys.exit(1)
    return data
def parse_and_check_(lines, n_expected=None):
    """Validate the predictions stored on the first line of a CSV file.

    The first line must contain exactly ``n_expected`` comma-separated
    values, each being the string ``"0"`` or ``"1"``. Later lines are ignored.

    Args:
        lines: Lines of the predictions file (e.g. as returned by
            ``import_file``).
        n_expected: Required number of predictions. ``None`` (the default)
            uses the module constant ``N_TEST_CASES``.

    Returns:
        list[str]: The individual predictions as strings.

    Exits:
        Status 2 if the number of predictions is wrong (including an empty
        file or empty first line); status 3 if any value is not "0" or "1".
    """
    if n_expected is None:
        n_expected = N_TEST_CASES
    first_line = lines[0].strip() if lines else ""
    # "".split(",") yields [""], which would be miscounted as one prediction;
    # an empty file or blank first line holds zero predictions.
    predictions = first_line.split(",") if first_line else []
    # Calculate frequencies (reported in the success message).
    freq = collections.Counter(predictions)
    n = len(predictions)
    # Check the number of elements.
    if n != n_expected:
        print(f"Error: the number of predictions must be {n_expected}. Got {n}.")
        sys.exit(2)
    # Check the values. Use a subset test: predictions that are all "0" or
    # all "1" are valid, but an equality test against {"0", "1"} rejected them.
    if set(predictions) - {"0", "1"}:
        print(f"Error: predicted values must be 0 or 1. Got {set(predictions)}.")
        sys.exit(3)
    print(f"The format seems to be correct. Your predicted frequencies: {freq}. Total number of predictions: {n}.")
    return predictions
def main():
    """Command-line entry point: load the file named by --file and validate it."""
    csv_path = parse_args().file
    parse_and_check_(import_file(csv_path))
# Run the validator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()