Testing software is vital to ensure that code behaves as expected. In Machine Learning projects, however, testing is not as widespread as in traditional software development. The aim of this talk is to give a brief overview of unit testing and to show how a Data Scientist/Machine Learning Engineer can implement it in a modern Machine Learning Development Lifecycle, along with DevOps principles such as CI/CD.
Clone the project
git clone https://github.com/yudhiesh/ctmlp
Create the conda environment
conda create --name ctmlp python=3.7
conda activate ctmlp
Install dependencies
pip install -r requirements.txt
Train a model
python src/models/train_model.py --train_path="./data/raw/train.csv" --test_path="./data/raw/test.csv"
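For context, the entry point of `train_model.py` might look roughly like the sketch below. Only the `--train_path`/`--test_path` flags come from the command above; the target column name, model choice, preprocessing, and metric are illustrative assumptions.

```python
# Hypothetical sketch of the train_model.py entry point; only the
# --train_path/--test_path flags come from the command above. The target
# column, model, preprocessing, and metric are illustrative assumptions.
import argparse
import json
import pickle

import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

TARGET = "SalePrice"  # assumption: data resembles Kaggle's House Prices


def prepare(df: pd.DataFrame) -> pd.DataFrame:
    # Assumption: keep numeric features only and impute missing values
    return df.drop(columns=[TARGET]).select_dtypes("number").fillna(0)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_path", required=True)
    parser.add_argument("--test_path", required=True)
    args = parser.parse_args()

    train = pd.read_csv(args.train_path)
    test = pd.read_csv(args.test_path)

    # The pre-train tests described below would run here, before fitting.

    model = RandomForestRegressor(random_state=42)
    model.fit(prepare(train), train[TARGET])

    # Persist the model and its metric so the post-train tests can load them
    with open("models/model.pkl", "wb") as f:
        pickle.dump(model, f)
    preds = model.predict(prepare(test))
    rmse = mean_squared_error(test[TARGET], preds) ** 0.5
    with open("test_score.json", "w") as f:
        json.dump({"rmse": rmse}, f)


if __name__ == "__main__":
    main()
```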
Run the tests with the following command
pytest --no-header -v
├── LICENSE
├── README.md
├── conftest.py <- shares fixtures across all tests (see the sketch after this tree)
├── data <- data used
│ └── raw
│ ├── data_description.txt
│ ├── test.csv <- test data
│ └── train.csv <- train data
├── models
│ └── model.pkl <- the serialized trained model
├── pytest.ini <- configurations that are used for tests
├── requirements.txt <- dependencies
├── setup.cfg <- configures the behavior of the various setup commands for the project
├── src
│ ├── __init__.py
│ └── models
│ ├── __init__.py
│ └── train_model.py <- script to train the model
├── test_score.json <- json of the model metrics from training
└── tests
├── helpers
│ ├── __init__.py
│ └── utils.py <- helper methods used in tests
└── test_post_train.py <- contains the post-train tests
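Since `conftest.py` shares fixtures across the whole suite, a minimal sketch of it might look as follows; the fixture names are assumptions, while the file paths come from the tree above.

```python
# Hypothetical sketch of conftest.py: fixtures shared by all tests.
# The fixture names are assumptions; the paths come from the project tree.
import pickle

import pandas as pd
import pytest


@pytest.fixture(scope="session")
def model():
    """Load the trained model once for the whole test session."""
    with open("models/model.pkl", "rb") as f:
        return pickle.load(f)


@pytest.fixture(scope="session")
def test_data():
    """Load the held-out test data once per session."""
    return pd.read_csv("data/raw/test.csv")
```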
# pre-train tests
# located at src/models/train_model.py
is_data_leaking() # checks whether data is leaking between the train and test sets
is_overfitting_batch() # checks that the model is able to overfit a single batch of data
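These two checks could be implemented along the lines of the sketch below; the row-overlap leakage check and the single-batch overfit criterion are assumptions, and the real code in `src/models/train_model.py` may differ.

```python
# Hypothetical sketches of the pre-train checks; the real implementations
# in src/models/train_model.py may differ.
import pandas as pd


def is_data_leaking(train: pd.DataFrame, test: pd.DataFrame) -> bool:
    """Flag leakage by looking for identical rows in both splits."""
    overlap = train.merge(test, how="inner")
    return len(overlap) > 0


def is_overfitting_batch(model, X_batch, y_batch, threshold=1e-3) -> bool:
    """A model with enough capacity should drive training error on a single
    small batch close to zero; failure to do so hints at a bug."""
    model.fit(X_batch, y_batch)
    preds = model.predict(X_batch)
    return ((preds - y_batch) ** 2).mean() < threshold
```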
# post-train tests
# located at tests/test_post_train.py
test_invariance_tests() # checks that small perturbations which should not affect the output do not change the model's predictions
test_directional_expectation_tests() # checks that small perturbations which should affect the output do change the model's predictions
test_model_inference_times() # checks that the model's inference latency at the 99th percentile is acceptable
test_model_metric() # checks that the model's metric is below a set threshold
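As an illustration, the invariance and latency tests might be written as in the sketch below, reusing the session-scoped fixtures from `conftest.py`; the perturbed column, tolerance, and latency budget are all assumptions. The directional expectation and metric tests would follow the same pattern.

```python
# Hypothetical sketches of two post-train tests in tests/test_post_train.py;
# the perturbed column, tolerance, and latency budget are assumptions.
import time

import numpy as np


def test_invariance_tests(model, test_data):
    """A tiny perturbation of one feature should not change the prediction."""
    row = test_data.select_dtypes("number").fillna(0).head(1).copy()
    baseline = model.predict(row)
    row["LotArea"] += 1  # assumption: a negligible change to one feature
    perturbed = model.predict(row)
    np.testing.assert_allclose(perturbed, baseline, rtol=0.01)


def test_model_inference_times(model, test_data):
    """99th-percentile single-row latency should stay under a budget."""
    row = test_data.select_dtypes("number").fillna(0).head(1)
    latencies = []
    for _ in range(100):
        start = time.perf_counter()
        model.predict(row)
        latencies.append(time.perf_counter() - start)
    assert np.percentile(latencies, 99) < 0.1  # assumption: 100 ms budget
```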
Here are some resources I used when preparing this talk
- How to Test Machine Learning Code and Systems
- Effective testing for machine learning systems
- Automated Testing for Machine Learning
- Decouple the model definition from the training code to allow more flexibility
- Add more test cases
- Use DVC to version the data, as real-world datasets are typically too large to include in a repository