The Python files in this repository provide a simple, clean starting point for training TensorFlow classification models for computer vision. The only required Python package is TensorFlow. Tested with TensorFlow v2.8.0 and v2.9.1.
Make sure you have the `tensorflow` Python library installed and a valid `config.yaml` file. Then start the training process by running `python train.py` in your terminal of choice.
In most cases the training script itself should not need to be edited; instead, the configuration file is used to change the dataset, model, optimizer, callbacks and training parameters. Below is a reference for all possible options in the configuration file. Options marked with * are required.
Dataset option | Description |
---|---|
src* | Root directory of your dataset, containing a training and a validation subdirectory. |
classes* | Names of the classes in the dataset. |
class_mode | Mode of the dataset classes (`binary`, `categorical`, etc.). Defaults to `categorical`. |
batch | Size of the batches of data for training. Defaults to `32`. |
train_options | Additional parameters for loading the training data with `flow_from_directory`. See docs. |
valid_options | Additional parameters for loading the validation data with `flow_from_directory`. See docs. |
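As a rough illustration, assuming the data is loaded with Keras's `ImageDataGenerator.flow_from_directory`, the dataset options could be turned into that method's keyword arguments as sketched below. The helper `flow_kwargs` and the subdirectory names `train` and `validation` are hypothetical; check `train.py` for the names it actually expects.

```python
import os

# Defaults taken from the dataset table.
DEFAULT_CLASS_MODE = "categorical"
DEFAULT_BATCH = 32


def flow_kwargs(dataset_cfg, split, subdir):
    """Build keyword arguments for ImageDataGenerator.flow_from_directory.

    Hypothetical helper: `split` is "train" or "valid" and selects
    train_options / valid_options; `subdir` is the assumed name of that
    split's folder under `src`.
    """
    for key in ("src", "classes"):  # required options
        if key not in dataset_cfg:
            raise KeyError(f"missing required dataset option: {key}")
    kwargs = {
        "directory": os.path.join(dataset_cfg["src"], subdir),
        "classes": list(dataset_cfg["classes"]),
        "class_mode": dataset_cfg.get("class_mode", DEFAULT_CLASS_MODE),
        "batch_size": dataset_cfg.get("batch", DEFAULT_BATCH),
    }
    # Per-split extras (e.g. target_size, shuffle) are applied last, so they
    # also override the base values above if they repeat a key.
    kwargs.update(dataset_cfg.get(f"{split}_options") or {})
    return kwargs


config = {
    "src": "data",
    "classes": ["cat", "dog"],
    "train_options": {"target_size": (224, 224), "shuffle": True},
}
train_kwargs = flow_kwargs(config, "train", subdir="train")
valid_kwargs = flow_kwargs(config, "valid", subdir="validation")
# With TensorFlow installed, these could be passed on as, e.g.:
# ImageDataGenerator(rescale=1 / 255).flow_from_directory(**train_kwargs)
```

Applying the per-split options last keeps the shared settings (`classes`, `batch`) in one place while still letting each split tweak its own loading behaviour.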
Model option | Description |
---|---|
cls* | Model class to import from the specified module. |
module | Python module to import the model class from. Defaults to `tensorflow.keras.applications`. See docs for a list of available models. |
weights | Pre-trained weights to load (`imagenet`, etc.) or path to a TensorFlow weights file. Defaults to `None`. |
name | Overrides the name of the resulting model. Defaults to the name of the specified model. |
checkpoints | Path to the directory to save checkpoints to. Defaults to `./models`. |
class_options | Additional parameters for initializing the model class. See the docs of the model class for more information. |
compile_options | Additional parameters for compiling the model (`Model.compile()`). See docs. |
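To make the defaults in the model table concrete, here is a minimal sketch that fills them into a parsed model section. The helper `resolve_model_config` is hypothetical, and treating an empty YAML value (which parses to `None`) the same as a missing key is an assumption, not necessarily what `train.py` does.

```python
import copy

# Defaults for the optional model options, as listed in the table.
MODEL_DEFAULTS = {
    "module": "tensorflow.keras.applications",
    "weights": None,
    "name": None,  # None: keep the model's own name
    "checkpoints": "./models",
    "class_options": {},
    "compile_options": {},
}


def resolve_model_config(model_cfg):
    """Return the model section with documented defaults filled in (hypothetical helper)."""
    if not model_cfg.get("cls"):
        raise KeyError("missing required model option: cls")
    resolved = {"cls": model_cfg["cls"]}
    for key, default in MODEL_DEFAULTS.items():
        value = model_cfg.get(key)
        # Assumption: an empty YAML value (e.g. `module:`) parses to None and
        # is treated the same as a missing key. Defaults are deep-copied so
        # callers can mutate the result without touching MODEL_DEFAULTS.
        resolved[key] = copy.deepcopy(default) if value is None else value
    return resolved


model_cfg = resolve_model_config({
    "cls": "ResNet50",
    "weights": "imagenet",
    "class_options": {"include_top": False},
})
# With TensorFlow installed, the model could then be built along these lines:
#   module = importlib.import_module(model_cfg["module"])
#   model = getattr(module, model_cfg["cls"])(
#       weights=model_cfg["weights"], **model_cfg["class_options"])
```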
Optimizer option | Description |
---|---|
cls* | Optimizer class to import from the specified module. |
module | Python module to import the optimizer class from. Defaults to `tf.keras.optimizers`. See docs for a list of available optimizers. |
options | Constructor parameters for the specified optimizer. |
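The `cls`/`module`/`options` triple describes a class that is looked up by name and instantiated with keyword arguments. A minimal sketch of that pattern with `importlib` follows; `build_from_config` is a hypothetical name, and the demo uses a standard-library class so it runs without TensorFlow. Note that `importlib.import_module` needs the full package name (`tensorflow.keras.optimizers`), since `tf` is only an import alias.

```python
import importlib


def build_from_config(entry, default_module):
    """Import entry["cls"] from entry["module"] (or default_module) and
    instantiate it with entry["options"] as keyword arguments."""
    module_name = entry.get("module") or default_module
    cls = getattr(importlib.import_module(module_name), entry["cls"])
    return cls(**(entry.get("options") or {}))


# Stand-in demo with a standard-library class, so it runs without TensorFlow.
half = build_from_config(
    {"cls": "Fraction", "module": "fractions",
     "options": {"numerator": 1, "denominator": 2}},
    default_module="tensorflow.keras.optimizers",
)
```

With TensorFlow installed, an entry such as `{"cls": "Adam", "options": {"learning_rate": 0.001}}` would, under this sketch, resolve to `tf.keras.optimizers.Adam(learning_rate=0.001)`.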
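Callbacks are configured as a list under `callbacks`. Each entry gives the class name as `cls`, spelled with its exact capitalization (e.g. `ReduceLROnPlateau`), optionally the `module` to import it from, and the constructor parameters under `options`. See docs for a list of available callbacks. For example: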
```yaml
callbacks:
  - cls: ReduceLROnPlateau
    module:
    options:
      monitor: val_loss
      mode: min
      patience: 5
      factor: 0.5
      min_lr: 0.000001
      verbose: 1
```
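When parsed, the YAML callback example becomes a list of dictionaries in which the empty `module:` is `None`. The sketch below shows how such a list could be turned into callback instances. The helper `build_callbacks` and its default module `tensorflow.keras.callbacks` are assumptions for illustration, not taken from `train.py`.

```python
import importlib

# The YAML callback example, as it parses into Python (empty `module:` -> None).
CALLBACKS_CFG = [
    {
        "cls": "ReduceLROnPlateau",
        "module": None,
        "options": {
            "monitor": "val_loss",
            "mode": "min",
            "patience": 5,
            "factor": 0.5,
            "min_lr": 0.000001,
            "verbose": 1,
        },
    }
]


def build_callbacks(entries, default_module="tensorflow.keras.callbacks"):
    """Instantiate each configured callback (hypothetical helper).

    An empty or missing `module` falls back to default_module (assumed).
    """
    callbacks = []
    for entry in entries or []:
        module = importlib.import_module(entry.get("module") or default_module)
        callback_cls = getattr(module, entry["cls"])  # names are case-sensitive
        callbacks.append(callback_cls(**(entry.get("options") or {})))
    return callbacks
```

With TensorFlow installed, `build_callbacks(CALLBACKS_CFG)` would return a list holding one `ReduceLROnPlateau` instance, which could be passed to `Model.fit(callbacks=...)`.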
Training option | Description |
---|---|
epochs | Number of epochs to train the model for. |
training_steps_per_epoch | Number of training steps per epoch. Defaults to `train_gen.n // train_gen.batch_size`. |
validation_steps_per_epoch | Number of validation steps per epoch. Defaults to `valid_gen.n // valid_gen.batch_size`. |
options | Additional parameters for the `Model.fit()` method. See docs. |
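The default step counts are plain floor division, so any final partial batch is not counted as a step. A minimal sketch of how the training options could map onto `Model.fit()` arguments is below. The helper `resolve_fit_args` is hypothetical, `train_n`/`valid_n` stand in for `train_gen.n`/`valid_gen.n`, and the sketch assumes `epochs` is always set.

```python
def resolve_fit_args(train_cfg, train_n, valid_n, batch_size):
    """Hypothetical helper applying the training table's defaults.

    train_n / valid_n stand in for train_gen.n / valid_gen.n (number of
    images found), and batch_size for the generators' batch_size.
    """
    args = {
        "epochs": train_cfg["epochs"],  # assumed to be set in the config
        # Floor division: a final partial batch is not counted as a step.
        "steps_per_epoch": train_cfg.get("training_steps_per_epoch")
        or train_n // batch_size,
        "validation_steps": train_cfg.get("validation_steps_per_epoch")
        or valid_n // batch_size,
    }
    # Extra Model.fit() keyword arguments from `options`.
    args.update(train_cfg.get("options") or {})
    return args


# 1000 training and 250 validation images with the default batch size of 32:
# 1000 // 32 = 31 steps (992 images) and 250 // 32 = 7 steps (224 images).
fit_args = resolve_fit_args({"epochs": 10}, train_n=1000, valid_n=250, batch_size=32)
# With TensorFlow: model.fit(train_gen, validation_data=valid_gen, **fit_args)
```

One consequence of this formula: a split with fewer images than one batch gets 0 steps by default, so in that case set the step count explicitly or lower `batch`.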