Skip to content

Commit

Permalink
update dpmd test
Browse files Browse the repository at this point in the history
  • Loading branch information
y1xiaoc committed Apr 13, 2021
1 parent b7f869d commit 4f7af6b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
6 changes: 3 additions & 3 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,8 @@ def gen_doc(self, paths: Optional[List[str]] = None,
if kwargs.get("make_link"):
if not kwargs.get("make_anchor"):
raise ValueError("`make_link` only works with `make_anchor` set")
fnstr, target = make_ref_pair(paths+[self.flag_name], fnstr, "emph")
body_list.append("\n" + target)
fnstr, target = make_ref_pair(paths+[self.flag_name], fnstr, "flag")
body_list.append(target + "\n")
for choice in self.choice_dict.values():
body_list.append("")
choice_path = self._make_cpath(choice.name, paths, showflag)
Expand Down Expand Up @@ -526,7 +526,7 @@ def gen_doc_flag(self, paths: Optional[List[str]] = None, **kwargs) -> str:
self._make_cpath(c.name, paths, kwargs["showflag"]),
text=f"``{c.name}``", prefix="code")
for c in self.choice_dict.values()))
targetdoc = indent('\n' + '\n'.join(l_target), INDENT)
targetdoc = indent('\n'.join(l_target) + "\n", INDENT)
else:
l_choice = [c.name for c in self.choice_dict.values()]
targetdoc = None
Expand Down
53 changes: 40 additions & 13 deletions tests/dpmdargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,13 @@ def descrpt_variant_type_args():
link_se_a_3be = make_link('se_a_3be', 'model/descriptor[se_a_3be]')
link_se_a_tpe = make_link('se_a_tpe', 'model/descriptor[se_a_tpe]')
link_hybrid = make_link('hybrid', 'model/descriptor[hybrid]')
doc_descrpt_type = f'The type of the descritpor. Valid types are {link_lf}, {link_se_a}, {link_se_r}, {link_se_a_3be}, {link_se_a_tpe}, {link_hybrid}. \n\n\
doc_descrpt_type = f'The type of the descritpor. See explanation below. \n\n\
- `loc_frame`: Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame.\n\n\
- `se_a`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor.\n\n\
- `se_r`: Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor.\n\n\
- `se_a_3be`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Three-body embedding will be used by this descriptor.\n\n\
- `se_a_tpe`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Type embedding will be used by this descriptor.\n\n\
- `hybrid`: Concatenate of a list of descriptors as a new descriptor.\n\n\
- `se_ar`: A hybrid of `se_a` and `se_r`. Typically `se_a` has a smaller cut-off while the `se_r` has a larger cut-off. Deprecated, use `hybrid` instead.'
- `hybrid`: Concatenate of a list of descriptors as a new descriptor.'

return Variant("type", [
Argument("loc_frame", dict, descrpt_local_frame_args()),
Expand All @@ -197,7 +196,6 @@ def descrpt_variant_type_args():
Argument("se_a_3be", dict, descrpt_se_a_3be_args(), alias = ['se_at']),
Argument("se_a_tpe", dict, descrpt_se_a_tpe_args(), alias = ['se_a_ebd']),
Argument("hybrid", dict, descrpt_hybrid_args()),
Argument("se_ar", dict, descrpt_se_ar_args()),
], doc = doc_descrpt_type)


Expand Down Expand Up @@ -275,7 +273,7 @@ def fitting_dipole():


def fitting_variant_type_args():
doc_descrpt_type = 'The type of the fitting. Valid types are `ener`, `dipole`, `polar` and `global_polar`. \n\n\
doc_descrpt_type = 'The type of the fitting. See explanation below. \n\n\
- `ener`: Fit an energy model (potential energy surface).\n\n\
- `dipole`: Fit an atomic dipole model. Atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file has number of frames lines and 3 times of number of selected atoms columns.\n\n\
- `polar`: Fit an atomic polarizability model. Atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file has number of frames lines and 9 times of number of selected atoms columns.\n\n\
Expand All @@ -289,13 +287,39 @@ def fitting_variant_type_args():
default_tag = 'ener',
doc = doc_descrpt_type)

def modifier_dipole_charge():
doc_model_name = "The name of the frozen dipole model file."
doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model/fitting_net[dipole]/sel_type')}. "
doc_sys_charge_map = f"The charge of real atoms. The list length should be the same as the {make_link('type_map', 'model/type_map')}"
doc_ewald_h = f"The grid spacing of the FFT grid. Unit is A"
doc_ewald_beta = f"The splitting parameter of Ewald sum. Unit is A^{-1}"

return [
Argument("model_name", str, optional = False, doc = doc_model_name),
Argument("model_charge_map", list, optional = False, doc = doc_model_charge_map),
Argument("sys_charge_map", list, optional = False, doc = doc_sys_charge_map),
Argument("ewald_beta", float, optional = True, default = 0.4, doc = doc_ewald_beta),
Argument("ewald_h", float, optional = True, default = 1.0, doc = doc_ewald_h),
]

def modifier_variant_type_args():
doc_modifier_type = "The type of modifier. See explanation below.\n\n\
-`dipole_charge`: Use WFCC to model the electronic structure of the system. Correct the long-range interaction"
return Variant("type",
[
Argument("dipole_charge", dict, modifier_dipole_charge()),
],
optional = False,
doc = doc_modifier_type)


def model_args ():
doc_type_map = 'A list of strings. Give the name to each type of atoms.'
doc_data_stat_nbatch = 'The model determines the normalization from the statistics of the data. This key specifies the number of `frames` in each `system` used for statistics.'
doc_data_stat_protect = 'Protect parameter for atomic energy regression.'
doc_descrpt = 'The descriptor of atomic environment.'
doc_fitting = 'The fitting of physical properties.'
doc_modifier = 'The modifier of model output.'
doc_use_srtab = 'The table for the short-range pairwise interaction added on top of DP. The table is a text data file with (N_t + 1) * N_t / 2 + 1 columes. The first colume is the distance between atoms. The second to the last columes are energies for pairs of certain types. For example we have two atom types, 0 and 1. The columes from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly.'
doc_smin_alpha = 'The short-range tabulated interaction will be swithed according to the distance of the nearest neighbor. This distance is calculated by softmin. This parameter is the decaying parameter in the softmin. It is only required when `use_srtab` is provided.'
doc_sw_rmin = 'The lower boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided.'
Expand All @@ -310,7 +334,8 @@ def model_args ():
Argument("sw_rmin", float, optional = True, doc = doc_sw_rmin),
Argument("sw_rmax", float, optional = True, doc = doc_sw_rmax),
Argument("descriptor", dict, [], [descrpt_variant_type_args()], doc = doc_descrpt),
Argument("fitting_net", dict, [], [fitting_variant_type_args()], doc = doc_fitting)
Argument("fitting_net", dict, [], [fitting_variant_type_args()], doc = doc_fitting),
Argument("modifier", dict, [], [modifier_variant_type_args()], optional = True, doc = doc_modifier),
])
# print(ca.gen_doc())
return ca
Expand All @@ -330,7 +355,7 @@ def learning_rate_exp():


def learning_rate_variant_type_args():
doc_lr = 'The type of the learning rate. Current type `exp`, the exponentially decaying learning rate is supported.'
doc_lr = 'The type of the learning rate.'

return Variant("type",
[Argument("exp", dict, learning_rate_exp())],
Expand Down Expand Up @@ -376,7 +401,7 @@ def loss_ener():


def loss_variant_type_args():
doc_loss = 'The type of the loss. For fitting type `ener`, the loss type should be set to `ener` or left unset. For tensorial fitting types `dipole`, `polar` and `global_polar`, the type should be left unset.\n\.'
doc_loss = 'The type of the loss. \n\.'

return Variant("type",
[Argument("ener", dict, loss_ener())],
Expand Down Expand Up @@ -452,16 +477,18 @@ def make_index(keys):
return ', '.join(ret)


def gen_doc(**kwargs):
def gen_doc(*, make_anchor=True, make_link=True, **kwargs):
if make_link:
make_anchor = True
ma = model_args()
lra = learning_rate_args()
la = loss_args()
ta = training_args()
ptr = []
ptr.append(ma.gen_doc(**kwargs))
ptr.append(la.gen_doc(**kwargs))
ptr.append(lra.gen_doc(**kwargs))
ptr.append(ta.gen_doc(**kwargs))
ptr.append(ma.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
ptr.append(la.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
ptr.append(lra.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))
ptr.append(ta.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs))

key_words = []
for ii in "\n\n".join(ptr).split('\n'):
Expand Down

0 comments on commit 4f7af6b

Please sign in to comment.