Skip to content

Commit

Permalink
Removing argparse from tblogger example
Browse files Browse the repository at this point in the history
  • Loading branch information
Neeratyoy committed Oct 31, 2023
1 parent 17b9863 commit 74e1c0b
Showing 1 changed file with 24 additions and 33 deletions.
57 changes: 24 additions & 33 deletions neps_examples/convenience/neps_tblogger_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
"""

import argparse
import logging
import random
import time
Expand Down Expand Up @@ -95,7 +94,7 @@ def set_seed(seed=123):


def MNIST(
batch_size: int = 32, n_train_size: float = 0.9
batch_size: int = 256, n_train_size: float = 0.9
) -> Tuple[DataLoader, DataLoader, DataLoader]:
# Download MNIST training and test datasets if not already downloaded.
train_dataset = torchvision.datasets.MNIST(
Expand Down Expand Up @@ -226,7 +225,7 @@ def training(
# Design the pipeline search spaces.


def pipeline_space_BO() -> dict:
def pipeline_space() -> dict:
pipeline = dict(
lr=neps.FloatParameter(lower=1e-5, upper=1e-1, log=True),
optim=neps.CategoricalParameter(choices=["Adam", "SGD"]),
Expand All @@ -240,7 +239,7 @@ def pipeline_space_BO() -> dict:
# Implement the pipeline run search.


def run_pipeline_BO(lr, optim, weight_decay):
def run_pipeline(lr, optim, weight_decay):
# Create the network model.
model = MLP()

Expand Down Expand Up @@ -334,15 +333,6 @@ def run_pipeline_BO(lr, optim, weight_decay):
python neps_tblogger_tutorial.py
```
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--max_evaluations_total",
type=int,
default=3,
help="Number of different configs to train",
)
args = parser.parse_args()

start_time = time.time()

set_seed(112)
Expand All @@ -351,16 +341,18 @@ def run_pipeline_BO(lr, optim, weight_decay):
# To check the status of tblogger:
# tblogger.get_status()

# tblogger.disable()

neps.run(
run_pipeline=run_pipeline_BO,
pipeline_space=pipeline_space_BO(),
run_args = dict(
run_pipeline=run_pipeline,
pipeline_space=pipeline_space(),
root_directory="output",
max_evaluations_total=args.max_evaluations_total,
searcher="random_search",
)

neps.run(
**run_args,
max_evaluations_total=2,
)

"""
To check live plots during this search, please open a new terminal
and make sure to be at the same level directory of your project and
Expand All @@ -382,20 +374,19 @@ def run_pipeline_BO(lr, optim, weight_decay):
actually exists.
"""

end_time = time.time() # Record the end time
execution_time = end_time - start_time
logging.info(f"Execution time: {execution_time} seconds")
# Disables tblogger for the continued run
tblogger.disable()

"""
After your first run, you can continue with more experiments by
uncommenting `tblogger.enable()` before `neps.run()`and running
the following command in your terminal:
```bash:
python neps_tblogger_tutorial.py --max_evaluations_total 10
```

neps.run(
**run_args,
max_evaluations_total=3, # continues the previous run for 1 more evaluation
)

This adds seven more configurations to your search and turns off tblogger.
By default, tblogger is on, but you can control it with `tblogger.enable()`
or `tblogger.disable()` in your code."
"""
This second run of one more configuration will not add to the tensorboard logs.
"""

end_time = time.time() # Record the end time
execution_time = end_time - start_time
logging.info(f"Execution time: {execution_time} seconds\n")

0 comments on commit 74e1c0b

Please sign in to comment.