refactor build api in data and nn (deepmodeling#106)
* refactor build api in data and nn

* fix random choice out

* fix test json

* update test data split
floatingCatty authored Apr 2, 2024
1 parent 48e503c commit 7ff210a
Showing 17 changed files with 112 additions and 99 deletions.
62 changes: 35 additions & 27 deletions dptb/data/build.py
@@ -104,26 +104,36 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
     return instance


-def build_dataset(set_options, common_options):
+def build_dataset(
+    # set_options
+    root: str,
+    type: str = "DefaultDataset",
+    prefix: str = None,
+    get_Hamiltonian: bool = False,
+    get_overlap: bool = False,
+    get_DM: bool = False,
+    get_eigenvalues: bool = False,
+
+    # common_options
+    basis: str = None,
+    **kwargs,
+):

"""
Build a dataset based on the provided set options and common options.
Args:
set_options (dict): A dictionary containing the set options for building the dataset.
- "type" (str): The type of dataset to build. Default is "DefaultDataset".
- "root" (str): The main directory storing all trajectory folders.
- "prefix" (str, optional): Load selected trajectory folders with the specified prefix.
- "get_Hamiltonian" (bool, optional): Load the Hamiltonian file to edges of the graph or not.
- "get_eigenvalues" (bool, optional): Load the eigenvalues to the graph or not.
e.g.
"train": {
"type": "DefaultDataset",
"root": "foo/bar/data_files_here",
"prefix": "set"
}
common_options (dict): A dictionary containing common options for building the dataset.
- "basis" (str, optional): The basis for the OrbitalMapper.
- type (str): The type of dataset to build. Default is "DefaultDataset".
- root (str): The main directory storing all trajectory folders.
- prefix (str, optional): Load selected trajectory folders with the specified prefix.
- get_Hamiltonian (bool, optional): Load the Hamiltonian file to edges of the graph or not.
- get_eigenvalues (bool, optional): Load the eigenvalues to the graph or not.
e.g.
type = "DefaultDataset",
root = "foo/bar/data_files_here",
prefix = "set"
- basis (str, optional): The basis for the OrbitalMapper.
Returns:
dataset: The built dataset.
@@ -132,18 +142,16 @@ def build_dataset(set_options, common_options):
         ValueError: If the dataset type is not supported.
         Exception: If the info.json file is not properly provided for a trajectory folder.
     """
-    dataset_type = set_options.get("type", "DefaultDataset")
+    dataset_type = type

     if dataset_type in ["DefaultDataset", "DeePHDataset"]:
         # See if we can get a OrbitalMapper.
-        if "basis" in common_options:
-            idp = OrbitalMapper(common_options["basis"])
+        if basis is not None:
+            idp = OrbitalMapper(basis=basis)
         else:
             idp = None

         # Explore the dataset's folder structure.
-        root = set_options["root"]
-        prefix = set_options.get("prefix", None)
         include_folders = []
         for dir_name in os.listdir(root):
             dir_path = os.path.join(root, dir_name)
@@ -194,18 +202,18 @@
             dataset = DeePHE3Dataset(
                 root=root,
                 type_mapper=idp,
-                get_Hamiltonian=set_options.get("get_Hamiltonian", False),
-                get_eigenvalues=set_options.get("get_eigenvalues", False),
+                get_Hamiltonian=get_Hamiltonian,
+                get_eigenvalues=get_eigenvalues,
                 info_files = info_files
             )
         else:
             dataset = DefaultDataset(
                 root=root,
                 type_mapper=idp,
-                get_Hamiltonian=set_options.get("get_Hamiltonian", False),
-                get_overlap=set_options.get("get_overlap", False),
-                get_DM=set_options.get("get_DM", False),
-                get_eigenvalues=set_options.get("get_eigenvalues", False),
+                get_Hamiltonian=get_Hamiltonian,
+                get_overlap=get_overlap,
+                get_DM=get_DM,
+                get_eigenvalues=get_eigenvalues,
                 info_files = info_files
             )
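
With this refactor, build_dataset takes flat keyword arguments in place of the old set_options/common_options dicts, so callers can pass fields directly or unpack existing config sections. A minimal usage sketch (the path and basis values are illustrative, mirroring the docstring and the tests below):

    from dptb.data.build import build_dataset

    # Direct keyword call; each argument replaces a key of the former dicts.
    dataset = build_dataset(
        root="foo/bar/data_files_here",  # illustrative path
        type="DefaultDataset",
        prefix="set",
        get_eigenvalues=True,
        basis={"Si": ["3s", "3p"]},      # forwarded to the OrbitalMapper
    )

    # Equivalent call by unpacking config sections, as the entrypoints below now do.
    set_options = {"root": "foo/bar/data_files_here", "type": "DefaultDataset", "prefix": "set"}
    common_options = {"basis": {"Si": ["3s", "3p"]}}
    dataset = build_dataset(**set_options, **common_options)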
13 changes: 7 additions & 6 deletions dptb/entrypoints/data.py
@@ -55,7 +55,8 @@ def data(
     # }

     # setup seed
-    setup_seed(seed=jdata.get("seed", 1021312))
+    rds = np.random.RandomState(1)


     dataset_dir = jdata.get("dataset_dir")
     filenames = list(glob.glob(os.path.join(dataset_dir, jdata.get("prefix")+"*")))
@@ -73,21 +74,21 @@ def data(
     n_test = int(nfile * test_ratio)
     n_val = int(nfile * val_ratio)

-    indices = np.random.choice(nfile, nfile, replace=False)
+    indices = rds.choice(nfile, nfile, replace=False)
     train_indices = indices[:n_train]
     val_indices = indices[n_train:n_train+n_val]
     test_indices = indices[n_train+n_val:]

     os.mkdir(os.path.join(dataset_dir, "train"))
     os.mkdir(os.path.join(dataset_dir, "val"))

     for id in tqdm(train_indices, desc="Copying files to training sets..."):
-        os.system("cp " + filenames[id] + " " + os.path.join(dataset_dir, "train")+" -r")
+        os.system(f"cp -r {filenames[id]} {dataset_dir}/train")

     for id in tqdm(val_indices, desc="Copying files to validation sets..."):
-        os.system("cp " + filenames[id] + " " + os.path.join(dataset_dir, "val")+" -r")
+        os.system(f"cp -r {filenames[id]} {dataset_dir}/val")

     if n_test > 0:
         os.mkdir(os.path.join(dataset_dir, "test"))
         for id in tqdm(test_indices, desc="Copying files to testing sets..."):
-            os.system("cp " + filenames[id] + " " + os.path.join(dataset_dir, "test")+" -r")
+            os.system(f"cp -r {filenames[id]} {dataset_dir}/test")
2 changes: 1 addition & 1 deletion dptb/entrypoints/pth2json.py
@@ -25,7 +25,7 @@ def pth2json(
         "log_path": log_path,
     }

-    model = build_model(run_options)
+    model = build_model(run_options["init_model"])

     if model.name == "nnsk":
         nnsk = model
4 changes: 2 additions & 2 deletions dptb/entrypoints/test.py
@@ -74,8 +74,8 @@ def _test(
         jdata["model_options"] = f["config"]["model_options"]
         del f

-    test_datasets = build_dataset(set_options=jdata["data_options"]["test"], common_options=jdata["common_options"])
-    model = build_model(run_options=run_opt, model_options=jdata["model_options"], common_options=jdata["common_options"])
+    test_datasets = build_dataset(**jdata["data_options"]["test"], **jdata["common_options"])
+    model = build_model(run_opt["init_model"], model_options=jdata["model_options"], common_options=jdata["common_options"])
     model.eval()
     tester = Tester(
         test_options=jdata["test_options"],
9 changes: 5 additions & 4 deletions dptb/entrypoints/train.py
@@ -163,13 +163,13 @@ def train(
     #     json.dump(jdata, fp, indent=4)

     # build dataset
-    train_datasets = build_dataset(set_options=jdata["data_options"]["train"], common_options=jdata["common_options"])
+    train_datasets = build_dataset(**jdata["data_options"]["train"], **jdata["common_options"])
     if jdata["data_options"].get("validation"):
-        validation_datasets = build_dataset(set_options=jdata["data_options"]["validation"], common_options=jdata["common_options"])
+        validation_datasets = build_dataset(**jdata["data_options"]["validation"], **jdata["common_options"])
     else:
         validation_datasets = None
     if jdata["data_options"].get("reference"):
-        reference_datasets = build_dataset(set_options=jdata["data_options"]["reference"], common_options=jdata["common_options"])
+        reference_datasets = build_dataset(**jdata["data_options"]["reference"], **jdata["common_options"])
     else:
         reference_datasets = None

@@ -185,7 +185,8 @@ def train(
     else:
         # include the init model and from scratch
         # build model will handle the init model cases where the model options provided is not equals to the ones in checkpoint.
-        model = build_model(run_options=run_opt, model_options=jdata["model_options"], common_options=jdata["common_options"], statistics=train_datasets.E3statistics())
+        checkpoint = init_model if init_model else None
+        model = build_model(checkpoint=checkpoint, model_options=jdata["model_options"], common_options=jdata["common_options"], statistics=train_datasets.E3statistics())
     trainer = Trainer(
         train_options=jdata["train_options"],
         common_options=jdata["common_options"],
16 changes: 8 additions & 8 deletions dptb/nn/build.py
@@ -7,7 +7,12 @@

 log = logging.getLogger(__name__)

-def build_model(run_options, model_options: dict={}, common_options: dict={}, statistics: dict=None):
+def build_model(
+    checkpoint: str=None,
+    model_options: dict={},
+    common_options: dict={},
+    statistics: dict=None
+):
     """
     The build model method should composed of the following steps:
         1. process the configs from user input and the config from the checkpoint (if any).
@@ -16,18 +21,13 @@ def build_model(run_options, model_options: dict={}, common_options: dict={}, st
     run_opt = {
         "init_model": init_model,
         "restart": restart,
-        "freeze": freeze,
-        "log_path": log_path,
-        "log_level": log_level,
-        "use_correction": use_correction
     }
     """
     # this is the
     # process the model_options
-    assert not all((run_options.get("init_model"), run_options.get("restart"))), "You can only choose one of the init_model and restart options."
-    if any((run_options.get("init_model"), run_options.get("restart"))):
+    # assert not all((init_model, restart)), "You can only choose one of the init_model and restart options."
+    if checkpoint is not None:
         from_scratch = False
-        checkpoint = run_options.get("init_model") or run_options.get("restart")
     else:
         from_scratch = True
         if not all((model_options, common_options)):
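
build_model now receives the checkpoint path directly (or None for a fresh model) instead of extracting init_model/restart from a run_options dict; callers such as pth2json, _test, and Trainer.restart simply forward whichever path they hold. A minimal sketch of the two call patterns (the checkpoint path is illustrative):

    from dptb.nn.build import build_model

    # From scratch: no checkpoint, so model_options and common_options must be provided.
    model = build_model(checkpoint=None, model_options=model_options, common_options=common_options)

    # From a saved model: per the docstring, the user-supplied configs are reconciled
    # with the config stored in the checkpoint.
    model = build_model(checkpoint="path/to/model.pth", model_options={}, common_options={})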
7 changes: 1 addition & 6 deletions dptb/nnops/trainer.py
@@ -136,12 +136,7 @@ def restart(
         """restart the training from a checkpoint, it does not support model options change."""

         ckpt = torch.load(checkpoint)
-
-        run_opt = {
-            "restart": checkpoint,
-        }
-
-        model = build_model(run_opt, ckpt["config"]["model_options"], ckpt["config"]["common_options"])
+        model = build_model(checkpoint, ckpt["config"]["model_options"], ckpt["config"]["common_options"])
         if len(train_options) == 0:
             train_options = ckpt["config"]["train_options"]
         if len(common_options) == 0:
2 changes: 1 addition & 1 deletion dptb/tests/test_SKHamiltonian.py
@@ -46,7 +46,7 @@ class TestSKHamiltonian:
         }
     }

-    train_datasets = build_dataset(set_options=data_options["train"], common_options=common_options)
+    train_datasets = build_dataset(**data_options["train"], **common_options)
     train_loader = DataLoader(dataset=train_datasets, batch_size=1, shuffle=True)

     batch = next(iter(train_loader))
4 changes: 2 additions & 2 deletions dptb/tests/test_build_dataset.py
@@ -18,7 +18,7 @@ def test_build_dataset_success(root_directory):
     common_options={"basis": {"Si": ["3s", "3p"]}}


-    dataset = build_dataset(set_options, common_options)
+    dataset = build_dataset(**set_options, **common_options)

     # Assert that the dataset is of the expected type
     assert isinstance(dataset, DefaultDataset)
@@ -61,7 +61,7 @@ def test_build_dataset_fail(root_directory):
     common_options={"basis": {"Si": ["3s", "3p"]}}

     with pytest.raises(AssertionError) as excinfo:
-        dataset = build_dataset(set_options, common_options)
+        dataset = build_dataset(**set_options, **common_options)
     assert "Hamiltonian file not found" in str(excinfo.value)

     #TODO: Add failure test cases for build_dataset when get_eigenvalues is True and get_Hamiltonian is False. When we add the E3 test cases, there will be a dataset that has only the Hamiltonian and no eigenvalues; we need to test that situation.
24 changes: 12 additions & 12 deletions dptb/tests/test_build_model.py
@@ -47,7 +47,7 @@ def test_build_nnsk_from_scratch():
         "seed": 3982377700
     }
     statistics = None
-    model = build_model(run_options, model_options, common_options, statistics)
+    model = build_model(None, model_options, common_options, statistics)

     assert isinstance(model, NNSK)
     assert model.device == "cpu"
@@ -113,7 +113,7 @@ def test_build_model_MIX_from_scratch():
     }
     statistics = None

-    model = build_model(run_options, model_options, common_options, statistics)
+    model = build_model(None, model_options, common_options, statistics)

     assert isinstance(model, MIX)
     assert model.name == "mix"
@@ -136,60 +136,60 @@ def test_build_model_failure():
     common_options = {}

     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "You need to provide model_options and common_options" in str(excinfo.value)

     common_options = {"basis": {"Si": ["3s", "3p"]}}

     # T F T
     model_options = {"embedding":True, "prediction":False, "nnsk":True}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "Model_options are not set correctly!" in str(excinfo.value)

     # F T T
     model_options = {"embedding":False,"prediction":True, "nnsk":True}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "Model_options are not set correctly!" in str(excinfo.value)

     # F T F
     model_options = {"embedding":False,"prediction":True, "nnsk":False}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "Model_options are not set correctly!" in str(excinfo.value)

     # T F F
     model_options = {"embedding":True,"prediction":False, "nnsk":False}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "Model_options are not set correctly!" in str(excinfo.value)

     # F F F
     model_options = {"embedding":False,"prediction":False, "nnsk":False}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "Model_options are not set correctly!" in str(excinfo.value)


     model_options = {"embedding":{"method":"se2"},"prediction":{"method":"e3tb"}, "nnsk":True}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "The prediction method must be sktb for mix mode." in str(excinfo.value)

     model_options = {"embedding":{"method":"e3"},"prediction":{"method":"sktb"}, "nnsk":True}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "The embedding method must be se2 for mix mode." in str(excinfo.value)

     model_options = {"embedding":{"method":"e3"},"prediction":{"method":"sktb"}, "nnsk":False}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "The embedding method must be se2 for sktb prediction in deeptb mode." in str(excinfo.value)

     model_options = {"embedding":{"method":"se2"},"prediction":{"method":"e3tb"}, "nnsk":False}
     with pytest.raises(ValueError) as excinfo:
-        build_model(run_options, model_options, common_options)
+        build_model(None, model_options, common_options)
     assert "The embedding method can not be se2 for e3tb prediction in deeptb mode" in str(excinfo.value)

     #TODO: add test for dptb-e3tb from scratch
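
Read by elimination, the failure cases above pin down which model_options combinations build_model accepts; a sketch of the inferred mode selection (comments only, since the full option bodies are collapsed in this diff):

    # Accepted shapes of model_options, inferred from the assertions above:
    #   nnsk only                     -> NNSK model
    #   embedding + prediction        -> deeptb mode; sktb prediction requires the se2 embedding
    #   embedding + prediction + nnsk -> mix mode; requires se2 embedding and sktb prediction
    # Every other truth-combination of the three keys raises
    # "Model_options are not set correctly!".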
22 changes: 15 additions & 7 deletions dptb/tests/test_data_split.py
@@ -16,14 +16,22 @@ def test_data_split(root_directory):
     assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/test")
     assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/val")

-    assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/train/frame.1")
-    assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/train/frame.2")
-    assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/train/frame.0")
-    assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/train/frame.4")
+    assert len(os.listdir(root_directory + "/dptb/tests/data/fake_dataset/train")) == 4
+    assert len(os.listdir(root_directory + "/dptb/tests/data/fake_dataset/test")) == 2
+    assert len(os.listdir(root_directory + "/dptb/tests/data/fake_dataset/val")) == 1

-    assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/test/frame.3")
-    assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/test/frame.5")
+    # assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/train/frame.1")
+    # assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/train/frame.2")
+    # assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/train/frame.0")
+    # assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/train/frame.5")

-    assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/val/frame.6")
+    # assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/test/frame.4")
+    # assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/test/frame.6")

+    # assert os.path.exists(root_directory + "/dptb/tests/data/fake_dataset/val/frame.3")

     os.system("rm -r " + root_directory + "/dptb/tests/data/fake_dataset/train")
     os.system("rm -r " + root_directory + "/dptb/tests/data/fake_dataset/val")
     os.system("rm -r " + root_directory + "/dptb/tests/data/fake_dataset/test")


2 changes: 1 addition & 1 deletion dptb/tests/test_dataloader_batch.py
@@ -29,7 +29,7 @@ class TestDataLoaderBatch:
         "overlap": False,
         "seed": 3982377700
     }
-    train_datasets = build_dataset(set_options=data_options["train"], common_options=common_options)
+    train_datasets = build_dataset(**data_options["train"], **common_options)

     def test_init(self):
         train_loader = DataLoader(dataset=self.train_datasets, batch_size=1, shuffle=True)
(The remaining 5 changed files are not shown.)
