diff --git a/README.md b/README.md
index dc0cf7a8..3146659f 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ Here are the [results](https://github.com/AntixK/PyTorch-VAE/blob/master/README.
 ### Requirements
 - Python >= 3.5
 - PyTorch >= 1.3
-- Pytorch Lightning >= 0.6.0 ([GitHub Repo](https://github.com/PyTorchLightning/pytorch-lightning/tree/deb1581e26b7547baf876b7a94361e60bb200d32))
+- Pytorch Lightning >= 1.0.0 ([GitHub Repo](https://github.com/PyTorchLightning/pytorch-lightning))
 - CUDA enabled computing device
 
 ### Installation
@@ -60,7 +60,7 @@ exp_params:
 
 trainer_params:
   gpus: 1
-  max_nb_epochs: 50
+  max_epochs: 50
   gradient_clip_val: 1.5
 .
 .
diff --git a/run.py b/run.py
index d95dd657..5252a1f8 100644
--- a/run.py
+++ b/run.py
@@ -6,7 +6,7 @@
 from experiment import VAEXperiment
 import torch.backends.cudnn as cudnn
 from pytorch_lightning import Trainer
-from pytorch_lightning.logging import TestTubeLogger
+from pytorch_lightning.loggers import TestTubeLogger
 
 parser = argparse.ArgumentParser(description='Generic runner for VAE models')
 
@@ -42,14 +42,13 @@
 config['exp_params'])
 
-runner = Trainer(default_save_path=f"{tt_logger.save_dir}",
-                 min_nb_epochs=1,
+runner = Trainer(default_root_dir=f"{tt_logger.save_dir}",
+                 min_epochs=1,
                  logger=tt_logger,
-                 log_save_interval=100,
-                 train_percent_check=1.,
-                 val_percent_check=1.,
+                 flush_logs_every_n_steps=100,
+                 limit_train_batches=1.,
+                 limit_val_batches=1.,
                  num_sanity_val_steps=5,
-                 early_stop_callback = False,
                  **config['trainer_params'])
 
 print(f"======= Training {config['model_params']['name']} =======")
-runner.fit(experiment)
\ No newline at end of file
+runner.fit(experiment)