
BECoTTA: Input-dependent Online Blending of Experts for Continual Test-time Adaptation [ICML2024]

Project Website | arXiv

🚨 A new version of BECoTTA with improved methods will be released soon!


Abstract

Continual Test Time Adaptation (CTTA) is required to adapt efficiently to continuously changing unseen domains while retaining previously learned knowledge. However, despite the progress of CTTA, the forgetting-adaptation trade-off and efficiency remain underexplored. Moreover, current CTTA scenarios assume only disjoint domain shifts, even though real-world domains change seamlessly. To tackle these challenges, this paper proposes BECoTTA, an input-dependent yet efficient framework for CTTA. We propose Mixture-of-Domain Low-rank Experts (MoDE), which contains two core components: (i) Domain-Adaptive Routing, which helps selectively capture domain-adaptive knowledge with multiple domain routers, and (ii) a Domain-Expert Synergy Loss that maximizes the dependency between each domain and expert. We validate that our method outperforms baselines on multiple CTTA scenarios, including disjoint and gradual domain shifts, while requiring ∼98% fewer trainable parameters. We also provide analyses of our method, including the construction of experts, the effect of domain-adaptive experts, and visualizations.
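As a rough illustration (not the paper's exact formulation), the Domain-Expert Synergy Loss can be viewed as maximizing the mutual information between domain assignments and expert selections. A minimal sketch, assuming you have estimated a joint domain-expert probability table from the router outputs:

    import torch

    def synergy_loss(joint: torch.Tensor) -> torch.Tensor:
        # joint: (num_domains, num_experts) probabilities summing to 1
        # (a hypothetical input, estimated from router probabilities).
        p_domain = joint.sum(dim=1, keepdim=True)   # marginal over experts
        p_expert = joint.sum(dim=0, keepdim=True)   # marginal over domains
        ratio = joint.clamp_min(1e-9) / (p_domain * p_expert).clamp_min(1e-9)
        mi = (joint * ratio.log()).sum()            # mutual information I(domain; expert)
        return -mi                                  # minimizing -MI maximizes dependency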

🚗 Main process of CTTA (Continual Test-time Adaptation)

  • Our main config file is becotta/local_configs/segformer/B5/tta.py.
  • You can find our initialized model here. Note that this code is based on the w/ WAD setting.
# CTTA process 
bash ./tools/becotta.sh

🖥️ Setup

[1] Environment

  • We follow the mmsegmentation codebase provided by the CoTTA authors.

    • You can refer to this issue for environment setup.
    • 📣 Note that our source model (Segformer) requires a fairly old mmcv version (mmcv==1.2.0).

1. You can create the conda environment using the environment.yml file we provide.

conda env update --name cotta --file environment.yml
conda activate cotta

2. Alternatively, you can install torch and mmcv yourself.

  • Our code is tested with torch 1.7.0 + CUDA 11.0 + mmcv 1.2.0.
pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
  • Install the lower mmcv version by referring to this issue (here {mmcv_version} is 1.2.0).
pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html

[2] Dataset

  • You can download the target domain ACDC dataset from here.

    • Set up the Fog -> Night -> Rain -> Snow scenario using the train set.

    • Change the data_root in becotta/local_configs/_base_/datasets/acdc_1024x1024_repeat_origin.py to your own path:

      # dataset settings
      dataset_type = 'ACDCDataset'
      data_root = 'your data path'   
  • We also provide a bunch of config files for driving datasets at becotta/mmseg/datasets! Note that you should match the segmentation label format to the Cityscapes style. You can freely use these data configs to design your own scenario (see the sketch after this list).

    • BDD100k: bdd.py
    • Kitti Seg: kitti.py
    • Foggy Driving: fog.py
    • GTAV & SYNTHIA: gtav_syn.py
    • Dark Zurich: dark.py
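For example, pointing one of these configs at your own data follows the same pattern as the ACDC config above. A minimal sketch in the mmsegmentation config style; the dataset class name and path below are placeholders, not files from this repo:

    # hypothetical excerpt in the style of acdc_1024x1024_repeat_origin.py
    dataset_type = 'BDDDataset'     # placeholder: a dataset class registered under becotta/mmseg/datasets
    data_root = '/path/to/bdd100k'  # <- your own path
    # remember: labels must already be converted to the Cityscapes format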

[3] Pre-trained model

  • We mainly adopt a Segformer pre-trained on the Cityscapes dataset.
    • You can download segformer.b5.1024x1024.city.160k.pth here.
    • You can also find the mit_b5.pth backbone here. Please place them in the ./pretrained/ directory.
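As a quick sanity check after downloading, a minimal sketch for inspecting the weights (assuming a standard torch checkpoint; mmseg checkpoints typically wrap the weights in a state_dict key):

    import torch

    # adjust the filename to the checkpoint you placed in ./pretrained/
    ckpt = torch.load('./pretrained/segformer.b5.1024x1024.city.160k.pth',
                      map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)
    print(len(state_dict), 'tensors; first key:', next(iter(state_dict)))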

📁 Note

[1] Checkpoint of our model

  • We provide our warm-up initialized model checkpoints here.

[2] Flexibility of BECoTTA

  • As mentioned in our paper, you can freely change the rank of the experts, the number of experts, and the number of selected experts ($K$).

  • E.g., you can modify it as follows:

    class mit_b5_EveryMOEadapter_wDomain(MOE_MixVisionTransformer_EveryAdapter_wDomain):
        def __init__(self, **kwargs):
            super(mit_b5_EveryMOEadapter_wDomain, self).__init__(
                patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
                mlp_ratios=[4, 4, 4, 4],
                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                depths=[3, 6, 40, 3],
                sr_ratios=[8, 4, 2, 1],
                drop_rate=0.0, drop_path_rate=0.1,
                expert_num=6,                  # Modify here: number of experts
                select_mode='new_topk',
                hidden_dims=[2, 4, 10, 16],    # Modify here: expert rank per stage
                num_k=3                        # Modify here: number of selected experts (K)
            )
  • Our models use these parameters as follows. Please refer to our main paper for more details.

    | Model     | Exp | K | Rank             | MoDE  |
    |-----------|-----|---|------------------|-------|
    | BECoTTA-S | 4   | 3 | [0, 0, 0, 6]     | Last  |
    | BECoTTA-M | 6   | 3 | [2, 4, 10, 16]   | Every |
    | BECoTTA-L | 6   | 3 | [16, 32, 60, 80] | Every |
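To make the roles of these hyperparameters concrete, here is a self-contained sketch (not the repository code) of input-dependent top-$K$ blending over low-rank experts; expert_num, rank, and num_k mirror the constructor arguments above, while the pooled-token router and adapter shapes are simplifying assumptions:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LowRankExpert(nn.Module):
        """Adapter-style expert: down-project to a small rank, then back up."""
        def __init__(self, dim, rank):
            super().__init__()
            self.down = nn.Linear(dim, rank)
            self.up = nn.Linear(rank, dim)

        def forward(self, x):
            return self.up(F.gelu(self.down(x)))

    class TopKMoE(nn.Module):
        """Input-dependent blending of the top-K experts (simplified sketch)."""
        def __init__(self, dim, expert_num=6, rank=16, num_k=3):
            super().__init__()
            self.router = nn.Linear(dim, expert_num)   # per-expert routing logits
            self.experts = nn.ModuleList(LowRankExpert(dim, rank) for _ in range(expert_num))
            self.num_k = num_k

        def forward(self, x):                          # x: (batch, tokens, dim)
            logits = self.router(x.mean(dim=1))        # route on pooled tokens: (batch, expert_num)
            topk_val, topk_idx = logits.topk(self.num_k, dim=-1)
            weights = F.softmax(topk_val, dim=-1)      # renormalize over selected experts
            out = torch.zeros_like(x)
            for b in range(x.size(0)):                 # loop for clarity, not speed
                for w, i in zip(weights[b], topk_idx[b]):
                    out[b] = out[b] + w * self.experts[int(i)](x[b])
            return x + out                             # residual connection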

[3] TODO

  • The construction process for the Continual Gradual Shifts (CGS) scenario will be added.
  • The warm-up initialization process will be added.
  • The whole CTTA process has been added.

Reference

@inproceedings{lee2024becotta,
    title={BECoTTA: Input-dependent Online Blending of Experts for Continual Test-time Adaptation},
    author={Lee, Daeun and Yoon, Jaehong and Hwang, Sung Ju},
    booktitle={International Conference on Machine Learning},
    year={2024},
}
