Skip to content

Commit

Permalink
Merge pull request #152 from ViCCo-Group/test_improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasMut authored Aug 7, 2023
2 parents 62a80b7 + 04f232b commit 4f99f0a
Showing 1 changed file with 90 additions and 49 deletions.
139 changes: 90 additions & 49 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,66 +16,70 @@
TEST_PATH = "./test_images"
OUT_PATH = "./test"

SSL_RN50_DEFAULT_CONFIG = {
"modules": ["avgpool"],
"pretrained": True,
"source": "ssl",
}

MODEL_AND_MODULE_NAMES = {
# Torchvision models
"vgg16": {
"model_name": "vgg16",
"modules": ["features.23", "classifier.3"],
"pretrained": True,
"source": "torchvision",
},
"vgg19_bn": {
"model_name": "vgg19_bn",
"modules": ["features.23", "classifier.3"],
"pretrained": False,
"source": "torchvision",
},
# Hardcoded models
"cornet_r": {
"model_name": "cornet_r",
"modules": ["decoder.flatten"],
"pretrained": True,
"source": "custom",
},
"cornet_rt": {
"model_name": "cornet_rt",
"modules": ["decoder.flatten"],
"pretrained": False,
"source": "custom",
},
"cornet_s": {
"model_name": "cornet_s",
"modules": ["decoder.flatten"],
"pretrained": False,
"source": "custom",
},
"cornet_z": {
"model_name": "cornet_z",
"modules": ["decoder.flatten"],
"pretrained": True,
"source": "custom",
},
# Custom models
"VGG16_ecoset": {
"model_name": "VGG16_ecoset",
"modules": ["classifier.3"],
"pretrained": True,
"source": "custom",
},
"clip": {
"clip_vitb32": {
"model_name": "clip",
"modules": ["visual"],
"pretrained": True,
"source": "custom",
"clip": True,
"kwargs": {"variant": "ViT-B/32"},
},
"clip": {
"clip_rn50": {
"model_name": "clip",
"modules": ["visual"],
"pretrained": True,
"source": "custom",
"clip": True,
"kwargs": {"variant": "RN50"},
},
"OpenCLIP": {
"OpenCLIP_vitb32": {
"model_name": "OpenCLIP",
"modules": ["visual"],
"pretrained": True,
"source": "custom",
Expand All @@ -84,82 +88,119 @@
},
# Timm models
"mixnet_l": {
"modules": ["conv_head"],
"pretrained": True,
"model_name": "mixnet_l",
"modules": ["conv_head"],
"pretrained": True,
"source": "timm"
},
# "gluon_inception_v3": {
# "modules": ["Mixed_6d"],
# "pretrained": False,
# "source": "timm",
# },

# Keras models
"VGG16": {
"VGG16_keras": {
"model_name": "VGG16",
"modules": ["block1_conv1", "flatten"],
"pretrained": True,
"source": "keras",
},
"VGG19": {
"VGG19_keras": {
"model_name": "VGG19",
"modules": ["block1_conv1", "flatten"],
"pretrained": False,
"source": "keras",
},
# Vissl models
"simclr-rn50": SSL_RN50_DEFAULT_CONFIG,
"mocov2-rn50": SSL_RN50_DEFAULT_CONFIG,
"jigsaw-rn50": SSL_RN50_DEFAULT_CONFIG,
"rotnet-rn50": SSL_RN50_DEFAULT_CONFIG,
"swav-rn50": SSL_RN50_DEFAULT_CONFIG,
"pirl-rn50": SSL_RN50_DEFAULT_CONFIG,
"barlowtwins-rn50": SSL_RN50_DEFAULT_CONFIG,
"vicreg-rn50": SSL_RN50_DEFAULT_CONFIG,
"dino-rn50" : SSL_RN50_DEFAULT_CONFIG,
"dino-vit-small-p8": {
"modules": ["norm"],
"simclr-rn50": {
"model_name": "simclr-rn50",
"modules": ["avgpool"],
"pretrained": True,
"source": "ssl",
},
"mocov2-rn50": {
"model_name": "mocov2-rn50",
"modules": ["avgpool"],
"pretrained": True,
"source": "ssl",
},
"jigsaw-rn50": {
"model_name": "jigsaw-rn50",
"modules": ["avgpool"],
"pretrained": True,
"source": "ssl",
},
"rotnet-rn50": {
"model_name": "rotnet-rn50",
"modules": ["avgpool"],
"pretrained": True,
"source": "ssl",
},
"swav-rn50": {
"model_name": "swav-rn50",
"modules": ["avgpool"],
"pretrained": True,
"source": "ssl",
},
"pirl-rn50": {
"model_name": "pirl-rn50",
"modules": ["avgpool"],
"pretrained": True,
"source": "ssl",
},
"barlowtwins-rn50": {
"model_name": "barlowtwins-rn50",
"modules": ["avgpool"],
"pretrained": True,
"source": "ssl",
},
"dino-vit-base-p8": {
"model_name": "dino-vit-base-p8",
"modules": ["norm"],
"pretrained": True,
"source": "ssl",
"extract_cls_token": True,
"kwargs": {"extract_cls_token": True}
},
# Harmonization models
"Harmonization": {
"modules": ["visual"],
"dinov2-vit-small-p14": {
"model_name": "dinov2-vit-small-p14",
"modules": ["norm"],
"pretrained": True,
"source": "ssl",
"kwargs": {"extract_cls_token": True}
},
# Additional models
"Harmonization_visual_ResNet50": {
"model_name": "Harmonization",
"modules": ["avg_pool"],
"pretrained": True,
"source": "custom",
"kwargs": {"variant": "ResNet50"},
},
"Harmonization": {
"Harmonization_fc2_VGG16": {
"model_name": "Harmonization",
"modules": ["fc2"],
"pretrained": True,
"source": "custom",
"kwargs": {"variant": "VGG16"},
},
"Harmonization": {
"Harmonization_head_ViT_B16": {
"model_name": "Harmonization",
"modules": ["head"],
"pretrained": True,
"source": "custom",
"kwargs": {"variant": "ViT_B16"},
},
"DreamSim": {
"DreamSim_mlp_clip_vitb32": {
"model_name": "DreamSim",
"modules": ["model.mlp"],
"pretrained": True,
"source": "custom",
"kwargs": {"variant": "clip_vitb32"},
},
"DreamSim": {
"DreamSim_mlp_open_clip_vitb32": {
"model_name": "DreamSim",
"modules": ["model.mlp"],
"pretrained": True,
"source": "custom",
"kwargs": {"variant": "open_clip_vitb32"},
}
},
}


FILE_FORMATS = ["hdf5", "npy", "mat", "pt", "txt"]
DISTANCES = ["correlation", "cosine", "euclidean", "gaussian"]

Expand All @@ -169,7 +210,6 @@
NUM_SAMPLES = int(BATCH_SIZE * 2)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


tf_model = Sequential()
tf_model.add(Dense(2, input_dim=1, activation="relu", use_bias=False, name="relu"))
weights = np.array([[[1, 1]]])
Expand Down Expand Up @@ -245,21 +285,22 @@ def __len__(self) -> int:


def iterate_through_all_model_combinations():
for model_name in MODEL_AND_MODULE_NAMES:
pretrained = MODEL_AND_MODULE_NAMES[model_name]["pretrained"]
source = MODEL_AND_MODULE_NAMES[model_name]["source"]
kwargs = MODEL_AND_MODULE_NAMES[model_name].get("kwargs", {})
for model_config in MODEL_AND_MODULE_NAMES.values():
model_name = model_config['model_name']
pretrained = model_config["pretrained"]
source = model_config["source"]
kwargs = model_config.get("kwargs", {})
extractor, dataset, batches = create_extractor_and_dataloader(
model_name, pretrained, source, kwargs
)

modules = MODEL_AND_MODULE_NAMES[model_name]["modules"]
clip = MODEL_AND_MODULE_NAMES[model_name].get("clip", False)
modules = model_config["modules"]
clip = model_config.get("clip", False)
yield extractor, dataset, batches, modules, model_name, clip


def create_extractor_and_dataloader(
model_name: str, pretrained: bool, source: str, kwargs: dict = {}
model_name: str, pretrained: bool, source: str, kwargs: dict = {}
):
"""Iterate through models and create model, dataset and data loader."""
extractor = get_extractor(
Expand Down Expand Up @@ -310,6 +351,6 @@ def create_test_images(n_samples: int = NUM_SAMPLES) -> None:
noisy_img = test_img + np.random.randn(H, W, C)
noisy_img = noisy_img.astype(np.uint8)
imageio.imsave(
os.path.join(TEST_PATH, cls, f"test_img_{i+1:03d}.png"), noisy_img
os.path.join(TEST_PATH, cls, f"test_img_{i + 1:03d}.png"), noisy_img
)
print("\n...Successfully created image dataset for testing.\n")

0 comments on commit 4f99f0a

Please sign in to comment.