Skip to content

Commit

Permalink
release PAOT code and model
Browse files Browse the repository at this point in the history
  • Loading branch information
yoxu515 committed Jun 27, 2023
1 parent bf91804 commit 04fe7d9
Show file tree
Hide file tree
Showing 200 changed files with 7,871 additions and 1,113 deletions.
29 changes: 0 additions & 29 deletions LICENSE

This file was deleted.

115 changes: 0 additions & 115 deletions MODEL_ZOO.md

This file was deleted.

174 changes: 47 additions & 127 deletions README.md

Large diffs are not rendered by default.

69 changes: 54 additions & 15 deletions configs/default.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
class DefaultEngineConfig():
def __init__(self, exp_name='default', model='aott'):
model_cfg = importlib.import_module('configs.models.' +
model).ModelConfig()
model.lower()).ModelConfig()
self.__dict__.update(model_cfg.__dict__) # add model config

self.EXP_NAME = exp_name + '_' + self.MODEL_NAME

self.STAGE_NAME = 'YTB'

self.DATASETS = ['youtubevos']
self.DATA_WORKERS = 8
self.DATA_WORKERS = 16 #8
self.DATA_RANDOMCROP = (465,
465) if self.MODEL_ALIGN_CORNERS else (464,
464)
Expand All @@ -22,16 +22,35 @@ def __init__(self, exp_name='default', model='aott'):
self.DATA_SHORT_EDGE_LEN = 480
self.DATA_MIN_SCALE_FACTOR = 0.7
self.DATA_MAX_SCALE_FACTOR = 1.3

self.DATA_PRE_STRONG_AUG = False # for PRE
self.DATA_TPS_PROB = 0.0
self.DATA_TPS_SCALE = 0.0
self.DATA_RANDOM_GAUSSIAN_BLUR = 0.0 #0.3
self.DATA_RANDOM_GRAYSCALE = 0.0 #0.2
self.DATA_RANDOM_COLOR_JITTER = 0.0 #0.8

self.DATA_RANDOM_REVERSE_SEQ = True
self.DATA_SEQ_LEN = 5
self.DATA_DAVIS_REPEAT = 5
self.DATA_YTB_REPEAT = 1
self.DATA_RANDOM_GAP_DAVIS = 12 # max frame interval between two sampled frames for DAVIS (24fps)
self.DATA_RANDOM_GAP_YTB = 3 # max frame interval between two sampled frames for YouTube-VOS (6fps)
self.DATA_RANDOM_GAP_BL30K = 12
self.DATA_RANDOM_GAP_VIP = 3
self.DATA_DYNAMIC_MERGE_PROB = 0.3
self.DATA_DYNAMIC_MERGE_PROB_BL30K = 0.0
self.DATA_DYNAMIC_MERGE_PROB_VIP = 0.1

self.DATA_YTB_BALANCE_SAMPLE = False
self.DATA_YTB_BALANCE_RATIO = 0.0
self.DATA_YTB_USE_VOSP = False

self.PRETRAIN = True
self.PRETRAIN_FULL = False # if False, load encoder only
self.PRETRAIN_MODEL = './data_wd/pretrain_model/mobilenet_v2.pth'
self.PRETRAIN_MODEL = ''
self.PRETRAIN_ID_MODEL = ''
# self.PRETRAIN_MODEL = './data_wd/pretrain_model/mobilenet_v2.pth'
# self.PRETRAIN_MODEL = './pretrain_models/mobilenet_v2-b0353104.pth'

self.TRAIN_TOTAL_STEPS = 100000
Expand All @@ -58,17 +77,19 @@ def __init__(self, exp_name='default', model='aott'):
self.TRAIN_SGD_MOMENTUM = 0.9
self.TRAIN_GPUS = 4
self.TRAIN_BATCH_SIZE = 16
self.TRAIN_TBLOG = False
self.TRAIN_TBLOG = True
self.TRAIN_TBLOG_STEP = 50
self.TRAIN_LOG_STEP = 20
self.TRAIN_IMG_LOG = True
self.TRAIN_IMG_LOG = False
self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15
self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank']
self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank','id_encoder','id_post_conv']
self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5
self.TRAIN_HARD_MINING_RATIO = 0.5
self.TRAIN_EMA_RATIO = 0.1
self.TRAIN_CLIP_GRAD_NORM = 5.
self.TRAIN_SAVE_STEP = 5000
self.TRAIN_SAVE_STEP = 1000
self.TRAIN_SAVE_MED_STEP = 10000
self.TRAIN_START_SAVE_MED_RATIO = 1.0
self.TRAIN_MAX_KEEP_CKPT = 8
self.TRAIN_RESUME = False
self.TRAIN_RESUME_CKPT = None
Expand All @@ -84,6 +105,7 @@ def __init__(self, exp_name='default', model='aott'):
self.TRAIN_LSTT_DROPPATH_LST = False
self.TRAIN_LSTT_LT_DROPOUT = 0.
self.TRAIN_LSTT_ST_DROPOUT = 0.
self.TRAIN_PANO = False

self.TEST_GPU_ID = 0
self.TEST_GPU_NUM = 1
Expand All @@ -95,40 +117,57 @@ def __init__(self, exp_name='default', model='aott'):
# if "None", evaluate the latest checkpoint.
self.TEST_CKPT_STEP = None
self.TEST_FLIP = False
self.TEST_INPLACE_FLIP = False
self.TEST_MULTISCALE = [1]
self.TEST_MAX_SHORT_EDGE = None
self.TEST_MAX_LONG_EDGE = 800 * 1.3
self.TEST_WORKERS = 4
self.TEST_SAVE_PROB = False
self.TEST_SAVE_PROB_SCALE = 0.5
self.TEST_SAVE_LOGIT = False
self.TEST_BOX_FILTER = False
self.TEST_TOP_K = -1
self.TEST_PANO = False

self.TEST_INTERMEDIATE_PRED = False
self.TRAIN_INTERMEDIATE_PRED_LOSS = False

# GPU distribution
self.DIST_ENABLE = True
self.DIST_BACKEND = "nccl" # "gloo"
self.DIST_URL = "tcp://127.0.0.1:13241"
self.DIST_START_GPU = 0

def init_dir(self):
self.DIR_DATA = '../VOS02/datasets'#'./datasets'
def init_dir(self,data='./datasets',root='./results',eval=None):
self.DIR_DATA = data
self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS')
self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB')
self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static')
self.DIR_BL30K = os.path.join(self.DIR_DATA,'BL30K')
self.DIR_VIP = os.path.join(self.DIR_DATA,'VIPOSeg')

self.DIR_ROOT = './'#'./data_wd/youtube_vos_jobs'
self.DIR_ROOT = root

self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME,
self.STAGE_NAME)
self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt')
self.DIR_AUX = os.path.join(self.DIR_RESULT, 'aux')
self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt')
self.DIR_MED_CKPT = os.path.join(self.DIR_RESULT, 'med_ckpt')
self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log')
self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard')
# self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img')
# self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval')
self.DIR_IMG_LOG = './img_logs'
self.DIR_EVALUATION = './results'
self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img')

if eval == None:
self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval')
else:
self.DIR_EVAL = os.path.join(eval,'result',self.EXP_NAME,self.STAGE_NAME)
self.DIR_EVALUATION = os.path.join(self.DIR_EVAL,'eval')

for path in [
self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT,
self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG,
self.DIR_TB_LOG
self.DIR_TB_LOG, self.DIR_MED_CKPT
]:
if not os.path.isdir(path):
try:
Expand Down
17 changes: 17 additions & 0 deletions configs/models/aost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
self.MODEL_NAME = 'AOST'
self.MODEL_VOS = 'aost'
self.MODEL_ENGINE = 'aostengine'

self.MODEL_ENCODER = 'mobilenetv2'

self.MODEL_LSTT_NUM = 3

self.TRAIN_LONG_TERM_MEM_GAP = 2

self.TEST_LONG_TERM_MEM_GAP = 5
18 changes: 18 additions & 0 deletions configs/models/aost_share.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
self.MODEL_NAME = 'AOST_Share'
self.MODEL_VOS = 'aost_share'
self.MODEL_ENGINE = 'aostengine'

self.MODEL_ENCODER = 'mobilenetv2'

self.MODEL_LSTT_NUM = 3
self.MODEL_DECODER_INTERMEDIATE_LSTT = False

self.TRAIN_LONG_TERM_MEM_GAP = 2

self.TEST_LONG_TERM_MEM_GAP = 5
2 changes: 1 addition & 1 deletion configs/models/aotl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
Expand Down
16 changes: 16 additions & 0 deletions configs/models/aotv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
self.MODEL_NAME = 'AOTv2'
self.MODEL_VOS = 'aotv2'
self.MODEL_ENGINE = 'aotv2engine'
self.MODEL_ENCODER = 'mobilenetv2'
self.MODEL_DECODER_INTERMEDIATE_LSTT = False
self.MODEL_LSTT_NUM = 3

self.TRAIN_LONG_TERM_MEM_GAP = 2

self.TEST_LONG_TERM_MEM_GAP = 5
7 changes: 5 additions & 2 deletions configs/models/deaotl.py → configs/models/aotv2l.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from .default_deaot import DefaultModelConfig
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
self.MODEL_NAME = 'DeAOTL'
self.MODEL_NAME = 'AOTv2L'
self.MODEL_VOS = 'aotv2'

self.MODEL_ENCODER = 'mobilenetv2'

self.MODEL_LSTT_NUM = 3

Expand Down
12 changes: 12 additions & 0 deletions configs/models/aotv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
self.MODEL_NAME = 'AOTv3'
self.MODEL_VOS = 'aotv3'
self.MODEL_ENGINE = 'aotv3engine'
self.MODEL_ENCODER = 'mobilenetv2'

self.MODEL_DECODER_INTERMEDIATE_LSTT = True
9 changes: 0 additions & 9 deletions configs/models/deaots.py

This file was deleted.

7 changes: 0 additions & 7 deletions configs/models/deaott.py

This file was deleted.

41 changes: 39 additions & 2 deletions configs/models/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,45 @@ def __init__(self):
self.MODEL_USE_PREV_PROB = False

self.TRAIN_LONG_TERM_MEM_GAP = 9999
self.TRAIN_AUG_TYPE = 'v1'

self.TEST_LONG_TERM_MEM_GAP = 9999
self.TEST_SHORT_TERM_MEM_GAP = 1
self.TEST_LONG_TERM_MEM_MAX = 9999

self.TEST_SHORT_TERM_MEM_SKIP = 1
# multi-scale param
self.MODEL_MS_LSTT_NUMS = [2,1,1,1]
self.MODEL_MS_ENCODER_EMBEDDING_DIMS = [256,256,128,128]
self.MODEL_MS_SCALES = [16,16,8,4]
self.MODEL_MS_SELF_HEADS = [8,1,1,1]
self.MODEL_MS_ATT_HEADS = [8,1,1,1]
self.MODEL_MS_ATT_DIMS = [None,None,None,None]
self.MODEL_MS_FEEDFOWARD_DIMS = [1024,1024,512,512]
self.MODEL_MS_GLOBAL_DILATIONS = [1,1,2,4]
self.MODEL_MS_LOCAL_DILATIONS = [1,1,1,1]
self.MODEL_MS_CONV_DILATION = False
self.TRAIN_MS_LSTT_EMB_DROPOUTS = [0.,0.,0.,0.]
self.MODEL_MS_SHARE_ID = False
self.MODEL_MS_SHARE_ID_SCALE = 0
self.MODEL_DECODER_RES = False
self.MODEL_DECODER_RES_IN = False
self.MODEL_USE_RELATIVE_V = True
self.MODEL_USE_SELF_POS = True

self.TRAIN_MS_LSTT_DROPPATH = [0.1,0.1,0.1,0.1]
self.TRAIN_MS_LSTT_DROPPATH_SCALING = [False,False,False,False]
self.TRAIN_MS_LSTT_DROPPATH_LST = [False,False,False,False]
self.TRAIN_MS_LSTT_LT_DROPOUT = [0.,0.,0.,0.]
self.TRAIN_MS_LSTT_ST_DROPOUT = [0.,0.,0.,0.]
self.TRAIN_MS_LSTT_MEMORY_DILATION = True

self.MODEL_USE_ID_ENCODER = False
self.MODEL_ID_ENCODER = 'resnet18'
self.MODEL_ID_ENCODER_DIM = [64, 128, 256, 256]
self.MODEL_ID_ENCODER_FROZEN_BN = False
self.MODEL_ID_ENCODER_FREEZE_AT = -1
self.MODEL_ID_ENCODER_USE_LN = False
self.MODEL_USE_ID_BANK_POST_CONV = False
self.MODEL_ID_BANK_POST_CONV_USE_LN = False
self.MODEL_SEP_ID_BANK = False


Empty file modified configs/models/default_deaot.py
100755 → 100644
Empty file.
3 changes: 1 addition & 2 deletions configs/models/r101_aotl.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import os
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
self.MODEL_NAME = 'R101_AOTL'

self.MODEL_ENCODER = 'resnet101'
self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth
self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
self.MODEL_LSTT_NUM = 3

Expand Down
18 changes: 18 additions & 0 deletions configs/models/r50_aost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
self.MODEL_NAME = 'R50_AOST'
self.MODEL_VOS = 'aost'
self.MODEL_ENGINE = 'aostengine'

self.MODEL_ENCODER = 'resnet50'
self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x

self.MODEL_LSTT_NUM = 3

self.TRAIN_LONG_TERM_MEM_GAP = 2

self.TEST_LONG_TERM_MEM_GAP = 5
19 changes: 19 additions & 0 deletions configs/models/r50_aost_share.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
def __init__(self):
super().__init__()
self.MODEL_NAME = 'R50_AOST_Share'
self.MODEL_VOS = 'aost_share'
self.MODEL_ENGINE = 'aostengine'

self.MODEL_ENCODER = 'resnet50'
self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x

self.MODEL_LSTT_NUM = 3
self.MODEL_DECODER_INTERMEDIATE_LSTT = False

self.TRAIN_LONG_TERM_MEM_GAP = 2

self.TEST_LONG_TERM_MEM_GAP = 5
Loading

0 comments on commit 04fe7d9

Please sign in to comment.