Merge pull request #117 from mwalmsley/dev
Add support for greyscale models and Euclid
mwalmsley authored May 14, 2024
2 parents 98fcbea + 8463e98 commit 1517b20
Showing 7 changed files with 80 additions and 31 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -138,6 +138,7 @@ CUDA 12.1 for PyTorch 2.1.0:

### Recent release features (v2.0.0)

- **New in 2.0.1** Add greyscale encoders. Use `hf_hub:mwalmsley/zoobot-encoder-greyscale-convnext_nano` or [similar](https://huggingface.co/collections/mwalmsley/zoobot-encoders-greyscale-66427c51133285ca01b490c6) (see the sketch after this list).
- New pretrained architectures: ConvNeXT, EfficientNetV2, MaxViT, and more. Each in several sizes.
- Reworked finetuning procedure. All these architectures are finetuneable through a common method.
- Reworked finetuning options. Batch norm finetuning removed. Cosine schedule option added.
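
A minimal sketch of finetuning one of the new greyscale encoders, assuming `FinetuneableZoobotClassifier` accepts `name` and `num_classes` as in the finetuning examples:

```python
# Sketch only: load a greyscale (single-channel) pretrained encoder and set up finetuning.
from zoobot.pytorch.training import finetune

model = finetune.FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-greyscale-convnext_nano',  # greyscale weights from HuggingFace
    num_classes=2  # e.g. a binary ring/not-ring problem
)
# 'greyscale' in the name triggers the single-channel (in_chans=1) loading path in finetune.py
```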
33 changes: 28 additions & 5 deletions docs/pretrained_models.rst
@@ -32,26 +32,41 @@ Zoobot includes weights for the following pretrained models:
- Test loss
- Finetune
- HF |:hugging:|
* - ConvNeXT-Pico
- 9.1M
- 19.33
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-convnext_pico>`__
* - ConvNeXT-Nano
- 15.6M
- 19.23
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-convnext_nano>`__
* - ConvNeXT-Tiny
- 44.6M
- 19.08
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-convnext_tiny>`__
* - ConvNeXT-Small
- 58.5M
- 19.14
- 19.06
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-convnext_small>`__
* - ConvNeXT-Base
- 88.6M
- **19.04**
- **19.05**
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-convnext_base>`__
* - ConvNeXT-Large
- 197.8M
- 19.09
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-convnext_large>`__
* - MaxViT-Tiny
- 29.1M
- 19.22
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-maxvit_rmlp_tiny_rw_224>`__
* - MaxViT-Small
- 64.9M
- 19.20
@@ -61,7 +76,7 @@ Zoobot includes weights for the following pretrained models:
- 124.5M
- 19.09
- Yes
- TODO
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-maxvit_base_rw_224>`__
* - MaxViT-Large
- 211.8M
- 19.18
@@ -71,12 +86,12 @@ Zoobot includes weights for the following pretrained models:
- 5.33M
- 19.48
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-efficientnet_b0>`__
- WIP
* - EfficientNetV2-S
- 48.3M
- 19.33
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-tf_efficientnetv2_s>`__
- WIP
* - ResNet18
- 11.7M
- 19.83
@@ -87,12 +102,20 @@ Zoobot includes weights for the following pretrained models:
- 19.43
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-resnet50>`__
* - ResNet101
- 44.5M
- 19.37
- Yes
- `Link <https://huggingface.co/mwalmsley/zoobot-encoder-resnet101>`__


.. note::

Missing a model you need? Reach out! There's a good chance we can train any model supported by `timm <https://github.com/huggingface/pytorch-image-models>`_.

.. note::

New in Zoobot v2.0.1: greyscale (single channel) versions are available `here <https://huggingface.co/collections/mwalmsley/zoobot-encoders-greyscale-66427c51133285ca01b490c6>`_.
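
Any encoder listed above can also be loaded outside of finetuning, directly through timm — a minimal sketch, mirroring the `timm.create_model` call the finetuning classes use internally (the hub name below is one example from the table):

```python
# Sketch: load a pretrained Zoobot encoder from HuggingFace as a plain timm model.
import timm

encoder = timm.create_model(
    'hf_hub:mwalmsley/zoobot-encoder-convnext_nano',  # any hub name from the table above
    pretrained=True,
    num_classes=0  # drop the classification head, keep only the encoder features
)
```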

Which model should I use?
===========================
4 changes: 2 additions & 2 deletions setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
name="zoobot",
version="2.0.0",
version="2.0.1",
author="Mike Walmsley",
author_email="[email protected]",
description="Galaxy morphology classifiers",
@@ -117,6 +117,6 @@
'webdataset', # for reading webdataset files
'huggingface_hub', # login may be required
'setuptools', # no longer pinned
'galaxy-datasets>=0.0.17' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
'galaxy-datasets>=0.0.18' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets)
]
)
19 changes: 12 additions & 7 deletions zoobot/pytorch/examples/finetuning/finetune_counts_full_tree.py
@@ -10,6 +10,7 @@
from zoobot.pytorch.training import finetune
from zoobot.pytorch.predictions import predict_on_catalog
from zoobot.shared.schemas import gz_candels_ortho_schema
from zoobot.shared.load_predictions import prediction_hdf5_to_summary_parquet

"""
Example for finetuning Zoobot on counts of volunteer responses throughout a complex decision tree (here, GZ CANDELS).
@@ -67,12 +68,12 @@
resize_after_crop=resize_after_crop
)

checkpoint_loc = os.path.join(
# TODO replace with path to downloaded checkpoints. See Zoobot README for download links.
repo_dir, 'gz-decals-classifiers/results/benchmarks/pytorch/evo/uploaded/effnetb0_greyscale_224px.ckpt') # decals hparams

model = finetune.FinetuneableZoobotTree(checkpoint_loc=checkpoint_loc, schema=schema)

model = finetune.FinetuneableZoobotTree(
name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
schema=schema
)

# TODO set this to wherever you'd like to save your results
save_dir = os.path.join(
repo_dir, f'gz-decals-classifiers/results/finetune_{np.random.randint(1e8)}')

@@ -86,12 +87,16 @@
# now save predictions on test set to evaluate performance
datamodule_kwargs = {'batch_size': batch_size, 'resize_after_crop': resize_after_crop}
trainer_kwargs = {'devices': 1, 'accelerator': accelerator}

hdf5_loc = os.path.join(save_dir, 'test_predictions.hdf5')
predict_on_catalog.predict(
test_catalog,
model,
n_samples=1,
label_cols=schema.label_cols,
save_loc=os.path.join(save_dir, 'test_predictions.csv'),
save_loc=hdf5_loc,
datamodule_kwargs=datamodule_kwargs,
trainer_kwargs=trainer_kwargs
)

prediction_hdf5_to_summary_parquet(hdf5_loc=hdf5_loc, save_loc=hdf5_loc.replace('.hdf5', 'summary.parquet'), schema=schema)
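
The summary parquet written above can then be inspected with pandas (a quick sketch; the exact columns depend on the schema used):

```python
# Sketch: load the per-galaxy summary produced by prediction_hdf5_to_summary_parquet above.
import pandas as pd

summary = pd.read_parquet(hdf5_loc.replace('.hdf5', 'summary.parquet'))
print(summary.head())  # per-galaxy statistics derived from the test-set predictions
```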
42 changes: 26 additions & 16 deletions zoobot/pytorch/examples/representations/get_representations.py
@@ -1,32 +1,45 @@
import logging
import os

import timm

from galaxy_datasets import demo_rings

from zoobot.pytorch.training import finetune, representations
from zoobot.pytorch.estimators import define_model
from zoobot.pytorch.predictions import predict_on_catalog
from zoobot.pytorch.training import finetune
from zoobot.shared import load_predictions, schemas


def main(catalog, checkpoint_loc, save_dir):
def main(catalog, save_dir, name="hf_hub:mwalmsley/zoobot-encoder-convnext_nano"):

assert all([os.path.isfile(x) for x in catalog['file_loc']])

if not os.path.exists(save_dir):
os.mkdir(save_dir)

# can load from either ZoobotTree checkpoint (if trained from scratch)
encoder = define_model.ZoobotTree.load_from_checkpoint(checkpoint_loc).encoder
# or FinetuneableZoobotTree (if finetuned)
# currently, FinetuneableZoobotTree checkpoints should be loaded as ZoobotTree with the args below
# this is a bit awkward and I'm working on a clearer method - but it does work.
# encoder = define_model.ZoobotTree.load_from_checkpoint(checkpoint_loc, output_dim=TODO, question_index_groups=[]).encoder
# load the encoder

# OPTION 1
# Load a pretrained model from HuggingFace, with no finetuning, only as published
model = representations.ZoobotEncoder.load_from_name(name)
# or equivalently (the above is just a wrapper for these two lines below)
# encoder = timm.create_model(model_name=name, pretrained=True)
# model = representations.ZoobotEncoder(encoder=encoder)

# convert to simple pytorch lightning model
model = representations.ZoobotEncoder(encoder=encoder, pyramid=False)
"""
# OPTION 2
label_cols = [f'feat_{n}' for n in range(1280)]
# Load a model that has been finetuned on your own data
# (...do your usual finetuning..., or load a finetuned model with finetune.FinetuneableZoobotClassifier(checkpoint_loc=....ckpt)
encoder = finetuned_model.encoder
# and then convert to simple pytorch lightning model. You can use any pytorch model here.
model = representations.ZoobotEncoder(encoder=encoder)
"""

encoder_dim = define_model.get_encoder_dim(model.encoder)
label_cols = [f'feat_{n}' for n in range(encoder_dim)]
save_loc = os.path.join(save_dir, 'representations.hdf5')

accelerator = 'cpu' # or 'gpu' if available
@@ -52,20 +65,17 @@ def main(catalog, checkpoint_loc, save_dir):

logging.basicConfig(level=logging.INFO)

# load the gz evo model for representations
checkpoint_loc = '/home/walml/repos/gz-decals-classifiers/results/benchmarks/pytorch/evo/evo_py_gr_11941/checkpoints/epoch=73-step=42698.ckpt'

# use this demo dataset
# TODO change this to wherever you'd like, it will auto-download
data_dir = '/home/walml/repos/galaxy-datasets/roots/demo_rings'
data_dir = '/Users/user/repos/galaxy-datasets/roots/demo_rings'
catalog, _ = demo_rings(root=data_dir, download=True, train=True)
print(catalog.head())
# zoobot expects id_str and file_loc columns, so add these if needed

# save the representations here
# TODO change this to wherever you'd like
save_dir = os.path.join('/home/walml/repos/zoobot/results/pytorch/representations/example')
save_dir = os.path.join('/Users/user/repos/zoobot/results/pytorch/representations/example')

representations_loc = main(catalog, checkpoint_loc, save_dir)
representations_loc = main(catalog, save_dir)
rep_df = load_predictions.single_forward_pass_hdf5s_to_df(representations_loc)
print(rep_df)
9 changes: 8 additions & 1 deletion zoobot/pytorch/training/finetune.py
@@ -124,7 +124,14 @@ def __init__(

if name is not None:
assert encoder is None, 'Cannot pass both name and encoder to use'
self.encoder = timm.create_model(name, num_classes=0, pretrained=True)
if 'greyscale' in name:
# I'm not sure why timm is happy to convert a color model's stem to greyscale
# but doesn't correctly load a greyscale model without this hack
logging.info('Loading greyscale model (auto-detected from name)')
timm_kwargs = {'in_chans': 1}
else:
timm_kwargs = {}
self.encoder = timm.create_model(name, num_classes=0, pretrained=True, **timm_kwargs)
self.encoder_dim = self.encoder.num_features

elif zoobot_checkpoint_loc is not None:
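
From the caller's side, this is roughly equivalent to passing `in_chans=1` to timm directly — a sketch of what the branch above does for a greyscale hub name:

```python
# Sketch: load a greyscale encoder with a single-channel stem, as the branch above does.
import timm
import torch

encoder = timm.create_model(
    'hf_hub:mwalmsley/zoobot-encoder-greyscale-convnext_nano',
    num_classes=0,
    pretrained=True,
    in_chans=1  # required for the greyscale weights to load correctly (see comment above)
)
with torch.no_grad():
    features = encoder(torch.randn(2, 1, 224, 224))  # batch of 2 single-channel 224px images
print(features.shape)  # (2, encoder.num_features)
```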
3 changes: 3 additions & 0 deletions zoobot/shared/schemas.py
@@ -299,3 +299,6 @@ def answers(self):

gz_ukidss_schema = Schema(label_metadata.ukidss_ortho_pairs, label_metadata.ukidss_ortho_dependencies)
gz_jwst_schema = Schema(label_metadata.jwst_ortho_pairs, label_metadata.jwst_ortho_dependencies)

euclid_ortho_schema = Schema(label_metadata.euclid_ortho_pairs, label_metadata.euclid_ortho_dependencies)
euclid_schema = Schema(label_metadata.euclid_pairs, label_metadata.euclid_dependencies)
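
The new Euclid schemas plug into tree finetuning like any other schema — a brief sketch, assuming the Euclid label metadata ships with `galaxy-datasets>=0.0.18` as pinned in setup.py:

```python
# Sketch: finetune on a Euclid-style decision tree using the newly added schema.
from zoobot.shared.schemas import euclid_ortho_schema
from zoobot.pytorch.training import finetune

print(euclid_ortho_schema.label_cols)  # answer columns expected in your labelled catalog

model = finetune.FinetuneableZoobotTree(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    schema=euclid_ortho_schema
)
```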
