Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

用自己的数据集来 预训练 模型时,如何配置 #59

Open
emanlee opened this issue Dec 30, 2023 · 2 comments
Open

用自己的数据集来 预训练 模型时,如何配置 #59

emanlee opened this issue Dec 30, 2023 · 2 comments

Comments

@emanlee
Copy link

emanlee commented Dec 30, 2023

作者您好!

我们想用自己的数据集来 预训练 模型。
我们看了一下 TSFormer_METR-LA.py 文件,里面有些配置不明白。因此,请教一下。

CFG.DATASET_INPUT_LEN = 288 * 7 这个288和7分别表示什么,谢谢!

CFG.MODEL.PARAM = {
"patch_size":12, ############ 请问这个12表示输出时间步吗?
"in_channel":1,
"embed_dim":96,
"num_heads":4,
"mlp_ratio":4,
"dropout":0.1,
"num_token":288 * 7 / 12, ############ 请问这个地方为什么除以12?
"mask_ratio":0.75,
"encoder_depth":4,
"decoder_depth":1,
"mode":"pre-train"
}

从 TSFormer_METR-LA.py 文件看不出原始的输入数据的文件名以及文件扩展名,请问输入数据应该放到哪个文件夹,并且,文件取名有什么要求吗? 是不是要这样取名 METR-LA.h5 ? 还是要类似于 scaler_in2016_out12.pkl ?

谢谢!

@zezhishao
Copy link
Collaborator

  1. 2887是历史时间步长。288个时间片在METR-LA中涵盖了一天的数据,因此2887代表着用历史七天数据进行预训练。这样方便显式捕捉“天”周期性和“周”周期性。
  2. patch_size不是输出步长,预训练阶段是一个重构任务,输出步长不参与运算。patch_size是切片大小。具体细节可以看论文呢,有详细描述如何做切片(patchify)
  3. num_token是token的数量。这个在论文中有专门提到,使用patch作为基本输入单元,而不是常用的point。因此,真正输入到模型的token的序列长度是288*7/12(序列长度/patch长度=patch数量=token数量)。
  4. 添加新的数据集的话,这个比较困难。你需要学习一下generate_training_data.py,仿照它的逻辑产生符合BasicTS规范的自己的数据集。产生完之后就比较简单了,在配置文件中改一下数据集名称,他就会自动去寻找然后读取了。

@emanlee
Copy link
Author

emanlee commented Dec 31, 2023

十分感谢您的解释!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants