This is an image classification model for the Seasats Autonomous Seacraft: it classifies images of the sea surface as containing a boat or not. The model is based on the Vision Transformer (ViT) architecture and is trained on a dataset of labeled sea-surface images.
To set up the project, first clone the repository:

```
git clone https://github.com/WillReynolds5/seasats-boat-classification.git
```
Next, create a new conda environment from the provided environment.yaml file:

```
conda env create -f environment.yaml
```
Activate the environment:

```
conda activate shipclassifier
```
You should now be able to run the model and train it on new data.
This code crops and resizes images to 256×256 px; input images can be of any aspect ratio or size (sketched below).

TODO: Optimize this step for whatever aspect ratio and resolution the Seasats camera will have.
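A minimal sketch of that crop/resize step, assuming torchvision is used (the repo's actual preprocess_data function may differ in detail):

```python
# Hypothetical preprocessing sketch, not the repo's exact pipeline:
# scale the shorter side to 256 px, then crop a random 256x256 square,
# so inputs of any aspect ratio end up the same size.
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),       # shorter side -> 256 px, aspect ratio preserved
    transforms.RandomCrop(256),   # random 256x256 square (see the eval TODO below)
    transforms.ToTensor(),        # HWC uint8 image -> CHW float tensor in [0, 1]
])
```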
Images are expected to be in the following directory structure:
- 70% of the images at /dataset/train, with boat and not_boat as separate folders
- 30% of the images at /dataset/val, with boat and not_boat as separate folders
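Laid out as a tree, with the class folders described above:

```
dataset/
├── train/
│   ├── boat/
│   └── not_boat/
└── val/
    ├── boat/
    └── not_boat/
```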
The ViT model has several hyperparameters that can be tuned for better performance on a particular task. Here are some tips for choosing hyperparameters for the Seasats dataset:
- image_size: Controls the size of the input images. Larger images preserve more detail but require more compute and longer training times.
- patch_size: Controls the size of the image patches used by the transformer. Larger patches capture more contextual information per token but lower the effective resolution; smaller patches are more precise but produce more tokens to process.
- dim: Controls the dimension of the transformer embeddings. Higher values let the model capture more complex relationships between patches at the cost of more compute.
- depth: Controls the number of transformer layers. Deeper models can capture more complex patterns but train more slowly.
- heads: Controls the number of attention heads in each transformer layer. More heads let the model capture more fine-grained patterns at the cost of more compute.
- mlp_dim: Controls the dimension of the multi-layer perceptron in each transformer layer. Higher values add capacity but also compute.

In every case, experiment with a few values to find the best tradeoff between accuracy, speed, and memory for your hardware.
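As a concrete reference, here is how these hyperparameters fit together, assuming the model is built with the vit-pytorch package (the values shown are illustrative, not necessarily this repo's settings):

```python
# Illustrative only: assumes the vit-pytorch package; tune the values below.
from vit_pytorch import ViT

model = ViT(
    image_size=256,   # must match the 256 px preprocessing size above
    patch_size=32,    # 256 / 32 = an 8x8 grid of patches per image
    num_classes=2,    # boat / not_boat
    dim=1024,         # transformer embedding dimension
    depth=6,          # number of transformer layers
    heads=16,         # attention heads per layer
    mlp_dim=2048,     # hidden dimension of each layer's MLP
)
```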
To train the model on new data, run the train.py script. First, choose your training parameters:
- `--epochs`: The number of epochs to train for. (default: 10)
- `--batch_size`: The batch size for training. (default: 32)
- `--lr`: The learning rate for the optimizer. (default: 1e-4)
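For example, to train for 20 epochs with the default batch size and learning rate (the epoch count here is just illustrative):

```
python train.py --epochs 20 --batch_size 32 --lr 1e-4
```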
![Training and validation losses](https://github.com/WillReynolds5/seasats-image-classification/blob/main/losses.png?raw=true)
To evaluate the trained model on a new image, use the evaluate.py script. Evaluation is exposed through a command line interface so the project can be driven as a subprocess from another programming language (whatever the hardware stack uses, e.g. C++); see the subprocess sketch after the examples below. The script takes the following arguments:
- image_path: The path to the image file to evaluate.
- --model_path: The path to the saved model file. (default: "checkpoints/model.pth")
For example, to evaluate the model on an image file named boat.png, you could run:

```
python evaluate.py boat.png
```
The script will preprocess the image using the preprocess_data function, run it through the trained model, and print the model's prediction for the image ("boat" or "not boat"). The default model file is checkpoints/model.pth, but you can specify a different path with the --model_path argument:

```
python evaluate.py boat.png --model_path my_model.pth
```
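A hypothetical sketch of the subprocess pattern (shown in Python for brevity; the same idea applies from C++), assuming the script prints its prediction to stdout:

```python
# Hypothetical sketch: drive evaluate.py as a subprocess and read the
# prediction ("boat" or "not boat") from its standard output.
import subprocess

result = subprocess.run(
    ["python", "evaluate.py", "boat.png", "--model_path", "checkpoints/model.pth"],
    capture_output=True,  # collect stdout/stderr instead of printing them
    text=True,            # decode output as str rather than bytes
    check=True,           # raise if the script exits with a nonzero status
)
prediction = result.stdout.strip()
print(f"Model prediction: {prediction}")
```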
TODO:
- Do not randomly crop images during evaluation.
- Add GPU support with `.to('cuda')` (see the sketch below).
- Make the model work on non-square images.
- Optimize the preprocessing for the Seasats dataset.
- ViT is optimized to run on a GPU; a different model (e.g. a CNN/ResNet) may be better suited to the Seasats hardware.
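For the GPU TODO, a minimal sketch of the usual PyTorch pattern (a hypothetical helper, not code from this repo):

```python
# Hypothetical sketch for the GPU-support TODO: pick CUDA when available
# and move the model there; input batches must be moved the same way.
import torch
import torch.nn as nn

def move_to_best_device(model: nn.Module) -> nn.Module:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device)
```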