Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smooth sharp edges #87

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.8
2,819 changes: 2,819 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
[tool.poetry]
name = "spacetimeformer"
version = "0.1.0"
description = "QData Research Library for the Spacetimeformer project"
authors = ["[email protected]"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.8"
cython = ">=0.22"
cmdstanpy = "0.9.68"
# https://discourse.mc-stan.org/t/pystan-build-is-not-pep-517-compliant/32971/6
# https://pystan.readthedocs.io/en/latest/faq.html#how-can-i-run-pystan-on-macos-with-apple-silicon-chips-apple-m1-m2-etc
pystan = ">=2.19.1.1,<2.20.0.0"
numpy = ">=1.15.4"
pandas = ">=1.0.4"
matplotlib = ">=2.0.0"
convertdate = ">=2.1.2"
python-dateutil = ">=2.8.0"
performer-pytorch = "^1.1.4"
tqdm = ">=4.36.1"
nystrom-attention = "^0.0.11"
pytorch-lightning = "1.6"
netcdf4 = "^1.6.5"
scikit-learn = "^1.3.2"
omegaconf = "^2.3.0"
seaborn = "^0.13.0"
opencv-python = "^4.8.1.78"
wandb = "^0.16.0"
einops = "^0.7.0"
chardet = "^5.2.0"
opt-einsum = "^3.3.0"
torchmetrics = "0.5.1"
torch = "1.11.0"
torchvision = "0.12.0"
joblib = "^1.3.2"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
15 changes: 14 additions & 1 deletion spacetimeformer/data/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(

if raw_df is None:
self.data_path = data_path
assert os.path.exists(self.data_path)
assert os.path.exists(self.data_path), f"data_path not found: {self.data_path}"
raw_df = pd.read_csv(
self.data_path,
**read_csv_kwargs,
Expand Down Expand Up @@ -212,6 +212,10 @@ def val_data(self):
def test_data(self):
    # Accessor for the held-out test split prepared during dataset setup.
    # NOTE(review): the sibling val_data above appears to be a @property;
    # any decorator for this accessor sits outside the visible diff hunk — confirm.
    return self._test_data

@property
def scaler_obj(self):
    # Expose the fitted scaler instance so callers can persist it after
    # training (train.py pickles this for inference-time inverse scaling).
    return self._scaler

def length(self, split):
return {
"train": len(self.train_data),
Expand Down Expand Up @@ -240,6 +244,14 @@ def __init__(
self.target_points = target_points
self.time_resolution = time_resolution

assert (
self.series.length(split) + time_resolution * (-target_points - context_points) + 1 > 0
), (f"Split {split} is too short to yield even one (context, target) window. Check time_resolution, context_points, and target_points.\n"
f"Dataset length: {self.series.length(split)}\n"
f"Target points: {target_points}\n"
f"Context points: {context_points}\n"
f"Time Resolution: {time_resolution}")

self._slice_start_points = [
i
for i in range(
Expand All @@ -249,6 +261,7 @@ def __init__(
+ 1,
)
]
print(f"{split} dataset length: {len(self)}")

def __len__(self):
    # One sample per precomputed valid window start index, so dataset
    # length equals the number of slice start points built in __init__.
    return len(self._slice_start_points)
Expand Down
731 changes: 731 additions & 0 deletions spacetimeformer/data/sine_waves_with_dates.csv

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions spacetimeformer/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,8 @@ def configure_optimizers(self):
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val/loss",
},
"lr_scheduler": scheduler,
"monitor": "val/loss",
}

@classmethod
Expand Down
5 changes: 5 additions & 0 deletions spacetimeformer/spacetimeformer_model/nn/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def forward(self, x, self_mask_seq=None, output_attn=False):
if self.local_attention:
# attention on tokens of each variable ind.
x1 = self.norm1(x)
# assert that second dim of x1 is a multiple of d_yc
assert (x1.shape[1] % self.d_yc == 0), (
"x1.shape[1] is not a multiple of d_yc. Check that train arg yc_dim matches the number of variables. "
f"x1.shape[1] = {x1.shape[1]}, d_yc = {self.d_yc}"
)
x1 = Localize(x1, self.d_yc)
# TODO: localize self_mask_seq
x1, _ = self.local_attention(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,11 @@ def configure_optimizers(self):
patience=3,
factor=self.decay_factor,
)
return [self.optimizer], [self.scheduler]
return {
"optimizer": self.optimizer,
"lr_scheduler": self.scheduler,
"monitor": "val/loss",
}

@classmethod
def add_cli(self, parser):
Expand Down
36 changes: 33 additions & 3 deletions spacetimeformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
import warnings
import os
import uuid

import pytorch_lightning as pl
import torch

import spacetimeformer as stf
import pickle

_MODELS = ["spacetimeformer", "mtgnn", "heuristic", "lstm", "lstnet", "linear", "s4"]

Expand All @@ -19,6 +18,7 @@
"exchange",
"precip",
"toy2",
"sinewaves",
"solar_energy",
"syn",
"mnist",
Expand Down Expand Up @@ -105,6 +105,8 @@ def create_parser():
parser.add_argument("--accumulate", type=int, default=1)
parser.add_argument("--val_check_interval", type=float, default=1.0)
parser.add_argument("--limit_val_batches", type=float, default=1.0)
parser.add_argument("--max_epochs", type=int)
parser.add_argument("--log_every_n_steps", type=int, default=50)
parser.add_argument("--no_earlystopping", action="store_true")
parser.add_argument("--patience", type=int, default=5)
parser.add_argument(
Expand All @@ -119,6 +121,9 @@ def create_parser():


def create_model(config):
# x_dim time embedding dimension
# yc_dim number of variables in context
# yt_dim number of variables in target
x_dim, yc_dim, yt_dim = None, None, None
if config.dset == "metr-la":
x_dim = 2
Expand Down Expand Up @@ -148,6 +153,10 @@ def create_model(config):
x_dim = 6
yc_dim = 20
yt_dim = 20
elif config.dset == "sinewaves":
x_dim = 6
yc_dim = 3
yt_dim = 3
elif config.dset == "syn":
x_dim = 5
yc_dim = 20
Expand Down Expand Up @@ -448,6 +457,7 @@ def create_dset(config):
)
INV_SCALER = dset.reverse_scaling
SCALER = dset.apply_scaling
SCALER_OBJ = dset.scaler_obj
elif config.dset in ["mnist", "cifar"]:
if config.dset == "mnist":
config.target_points = 28 - config.context_points
Expand Down Expand Up @@ -575,6 +585,7 @@ def create_dset(config):
)
INV_SCALER = dset.reverse_scaling
SCALER = dset.apply_scaling
SCALER_OBJ = dset.scaler_obj
NULL_VAL = None
# PAD_VAL = -32.0
PLOT_VAR_NAMES = target_cols
Expand Down Expand Up @@ -605,6 +616,7 @@ def create_dset(config):
)
INV_SCALER = dset.reverse_scaling
SCALER = dset.apply_scaling
SCALER_OBJ = dset.scaler_obj
NULL_VAL = None
PLOT_VAR_NAMES = ["OT", "p (mbar)", "raining (s)"]
PLOT_VAR_IDXS = [20, 0, 15]
Expand All @@ -627,6 +639,11 @@ def create_dset(config):
else:
raise ValueError(f"Unrecognized toy dataset {config.dset}")
target_cols = [f"D{i}" for i in range(1, 21)]
elif "sinewaves" in config.dset:
target_cols = [
"Sine Wave 1","Sine Wave 2","Sine Wave 3"
]
data_path = "./spacetimeformer/data/sine_waves_with_dates.csv"
elif config.dset == "exchange":
if data_path == "auto":
data_path = "./data/exchange_rate_converted.csv"
Expand Down Expand Up @@ -670,6 +687,7 @@ def create_dset(config):
)
INV_SCALER = dset.reverse_scaling
SCALER = dset.apply_scaling
SCALER_OBJ = dset.scaler_obj
NULL_VAL = None

return (
Expand All @@ -680,6 +698,7 @@ def create_dset(config):
PLOT_VAR_IDXS,
PLOT_VAR_NAMES,
PAD_VAL,
SCALER_OBJ,
)


Expand Down Expand Up @@ -770,8 +789,17 @@ def main(args):
plot_var_idxs,
plot_var_names,
pad_val,
scaler_obj,
) = create_dset(args)

# save scaler for inference post-training
with open('scaler_method.pkl', 'wb') as file:
pickle.dump(scaler, file)
with open('fitted_scaler_obj.pkl', 'wb') as file:
pickle.dump(scaler_obj, file)

assert (len(data_module.test_dataloader()) > 0), "The DataLoader should not be empty, check the Dataset __init__ and __getitem__"

# Model
args.null_value = null_val
args.pad_value = pad_val
Expand Down Expand Up @@ -840,8 +868,10 @@ def main(args):
gradient_clip_algorithm="norm",
overfit_batches=20 if args.debug else 0,
accumulate_grad_batches=args.accumulate,
sync_batchnorm=True,
sync_batchnorm=False, #set False on CPU, else "SyncBatchNorm layers only work with GPU modules"
limit_val_batches=args.limit_val_batches,
max_epochs=args.max_epochs,
log_every_n_steps=args.log_every_n_steps,
**val_control,
)

Expand Down