Skip to content

Commit

Permalink
remove NEP2 in train part
Browse files Browse the repository at this point in the history
  • Loading branch information
brucefan1983 committed Jul 3, 2024
1 parent 37d8b96 commit 5ca9d36
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 191 deletions.
32 changes: 6 additions & 26 deletions src/main_nep/fitness.cu
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,7 @@ void Fitness::output(
void Fitness::write_nep_txt(FILE* fid_nep, Parameters& para, float* elite)
{
if (para.train_mode == 0) { // potential model
if (para.version == 2) {
if (para.enable_zbl) {
fprintf(fid_nep, "nep_zbl %d ", para.num_types);
} else {
fprintf(fid_nep, "nep %d ", para.num_types);
}
} else if (para.version == 3) {
if (para.version == 3) {
if (para.enable_zbl) {
fprintf(fid_nep, "nep3_zbl %d ", para.num_types);
} else {
Expand All @@ -277,29 +271,19 @@ void Fitness::write_nep_txt(FILE* fid_nep, Parameters& para, float* elite)
}
}
} else if (para.train_mode == 1) { // dipole model
if (para.version == 2) {
fprintf(fid_nep, "nep_dipole %d ", para.num_types);
} else if (para.version == 3) {
if (para.version == 3) {
fprintf(fid_nep, "nep3_dipole %d ", para.num_types);
} else if (para.version == 4) {
fprintf(fid_nep, "nep4_dipole %d ", para.num_types);
}
} else if (para.train_mode == 2) { // polarizability model
if (para.version == 2) {
fprintf(fid_nep, "nep_polarizability %d ", para.num_types);
} else if (para.version == 3) {
if (para.version == 3) {
fprintf(fid_nep, "nep3_polarizability %d ", para.num_types);
} else if (para.version == 4) {
fprintf(fid_nep, "nep4_polarizability %d ", para.num_types);
}
} else if (para.train_mode == 3) { // temperature model
if (para.version == 2) {
if (para.enable_zbl) {
fprintf(fid_nep, "nep_zbl_temperature %d ", para.num_types);
} else {
fprintf(fid_nep, "nep_temperature %d ", para.num_types);
}
} else if (para.version == 3) {
if (para.version == 3) {
if (para.enable_zbl) {
fprintf(fid_nep, "nep3_zbl_temperature %d ", para.num_types);
} else {
Expand Down Expand Up @@ -333,12 +317,8 @@ void Fitness::write_nep_txt(FILE* fid_nep, Parameters& para, float* elite)
max_NN_radial,
max_NN_angular);
fprintf(fid_nep, "n_max %d %d\n", para.n_max_radial, para.n_max_angular);
if (para.version >= 3) {
fprintf(fid_nep, "basis_size %d %d\n", para.basis_size_radial, para.basis_size_angular);
fprintf(fid_nep, "l_max %d %d %d\n", para.L_max, para.L_max_4body, para.L_max_5body);
} else {
fprintf(fid_nep, "l_max %d\n", para.L_max);
}
fprintf(fid_nep, "basis_size %d %d\n", para.basis_size_radial, para.basis_size_angular);
fprintf(fid_nep, "l_max %d %d %d\n", para.L_max, para.L_max_4body, para.L_max_5body);

fprintf(fid_nep, "ANN %d %d\n", para.num_neurons1, 0);
for (int m = 0; m < para.number_of_variables; ++m) {
Expand Down
157 changes: 52 additions & 105 deletions src/main_nep/nep3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -148,25 +148,15 @@ static __global__ void find_descriptors_radial(
find_fc(rc, rcinv, d12, fc12);

float fn12[MAX_NUM_N];
if (paramb.version == 2) {
find_fn(paramb.n_max_radial, rcinv, d12, fc12, fn12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float c = (paramb.num_types == 1)
? 1.0f
: annmb.c[(n * paramb.num_types + t1) * paramb.num_types + t2];
q[n] += fn12[n] * c;
}
} else {
find_fn(paramb.basis_size_radial, rcinv, d12, fc12, fn12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float gn12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_radial; ++k) {
int c_index = (n * (paramb.basis_size_radial + 1) + k) * paramb.num_types_sq;
c_index += t1 * paramb.num_types + t2;
gn12 += fn12[k] * annmb.c[c_index];
}
q[n] += gn12;
find_fn(paramb.basis_size_radial, rcinv, d12, fc12, fn12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float gn12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_radial; ++k) {
int c_index = (n * (paramb.basis_size_radial + 1) + k) * paramb.num_types_sq;
c_index += t1 * paramb.num_types + t2;
gn12 += fn12[k] * annmb.c[c_index];
}
q[n] += gn12;
}
}
for (int n = 0; n <= paramb.n_max_radial; ++n) {
Expand Down Expand Up @@ -211,26 +201,15 @@ static __global__ void find_descriptors_angular(
}
float rcinv = 1.0f / rc;
find_fc(rc, rcinv, d12, fc12);
if (paramb.version == 2) {
float fn;
find_fn(n, rcinv, d12, fc12, fn);
fn *=
(paramb.num_types == 1)
? 1.0f
: annmb.c
[((paramb.n_max_radial + 1 + n) * paramb.num_types + t1) * paramb.num_types + t2];
accumulate_s(d12, x12, y12, z12, fn, s);
} else {
float fn12[MAX_NUM_N];
find_fn(paramb.basis_size_angular, rcinv, d12, fc12, fn12);
float gn12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_angular; ++k) {
int c_index = (n * (paramb.basis_size_angular + 1) + k) * paramb.num_types_sq;
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, x12, y12, z12, gn12, s);
float fn12[MAX_NUM_N];
find_fn(paramb.basis_size_angular, rcinv, d12, fc12, fn12);
float gn12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_angular; ++k) {
int c_index = (n * (paramb.basis_size_angular + 1) + k) * paramb.num_types_sq;
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
}
accumulate_s(d12, x12, y12, z12, gn12, s);
}
if (paramb.num_L == paramb.L_max) {
find_q(paramb.n_max_angular + 1, n, s, q);
Expand Down Expand Up @@ -272,13 +251,11 @@ NEP3::NEP3(
paramb.n_max_angular = para.n_max_angular;
paramb.L_max = para.L_max;
paramb.num_L = paramb.L_max;
if (version >= 3) {
if (para.L_max_4body == 2) {
paramb.num_L += 1;
}
if (para.L_max_5body == 1) {
paramb.num_L += 1;
}
if (para.L_max_4body == 2) {
paramb.num_L += 1;
}
if (para.L_max_5body == 1) {
paramb.num_L += 1;
}
paramb.dim_angular = (para.n_max_angular + 1) * paramb.num_L;

Expand Down Expand Up @@ -331,7 +308,7 @@ void NEP3::update_potential(Parameters& para, float* parameters, ANN& ann)
{
float* pointer = parameters;
for (int t = 0; t < paramb.num_types; ++t) {
if (t > 0 && paramb.version != 4) { // Use the same set of NN parameters for NEP2 and NEP3
if (t > 0 && paramb.version != 4) { // Use the same set of NN parameters for NEP3
pointer -= (ann.dim + 2) * ann.num_neurons1;
}
ann.w0[t] = pointer;
Expand All @@ -346,7 +323,7 @@ void NEP3::update_potential(Parameters& para, float* parameters, ANN& ann)

if (para.train_mode == 2) {
for (int t = 0; t < paramb.num_types; ++t) {
if (t > 0 && paramb.version != 4) { // Use the same set of NN parameters for NEP2 and NEP3
if (t > 0 && paramb.version != 4) { // Use the same set of NN parameters for NEP3
pointer -= (ann.dim + 2) * ann.num_neurons1;
}
ann.w0_pol[t] = pointer;
Expand Down Expand Up @@ -599,31 +576,17 @@ static __global__ void find_force_radial(
float fnp12[MAX_NUM_N];
float f12[3] = {0.0f};

if (paramb.version == 2) {
find_fn_and_fnp(paramb.n_max_radial, rcinv, d12, fc12, fcp12, fn12, fnp12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float tmp12 = g_Fp[n1 + n * N] * fnp12[n] * d12inv;
tmp12 *= (paramb.num_types == 1)
? 1.0f
: annmb.c[(n * paramb.num_types + t1) * paramb.num_types + t2];
for (int d = 0; d < 3; ++d) {
f12[d] += tmp12 * r12[d];
}
find_fn_and_fnp(paramb.basis_size_radial, rcinv, d12, fc12, fcp12, fn12, fnp12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float gnp12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_radial; ++k) {
int c_index = (n * (paramb.basis_size_radial + 1) + k) * paramb.num_types_sq;
c_index += t1 * paramb.num_types + t2;
gnp12 += fnp12[k] * annmb.c[c_index];
}
} else {
find_fn_and_fnp(
paramb.basis_size_radial, rcinv, d12, fc12, fcp12, fn12, fnp12);
for (int n = 0; n <= paramb.n_max_radial; ++n) {
float gnp12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_radial; ++k) {
int c_index = (n * (paramb.basis_size_radial + 1) + k) * paramb.num_types_sq;
c_index += t1 * paramb.num_types + t2;
gnp12 += fnp12[k] * annmb.c[c_index];
}
float tmp12 = g_Fp[n1 + n * N] * gnp12 * d12inv;
for (int d = 0; d < 3; ++d) {
f12[d] += tmp12 * r12[d];
}
float tmp12 = g_Fp[n1 + n * N] * gnp12 * d12inv;
for (int d = 0; d < 3; ++d) {
f12[d] += tmp12 * r12[d];
}
}

Expand Down Expand Up @@ -710,42 +673,26 @@ static __global__ void find_force_angular(
find_fc_and_fcp(rc, rcinv, d12, fc12, fcp12);
float f12[3] = {0.0f};

if (paramb.version == 2) {
for (int n = 0; n <= paramb.n_max_angular; ++n) {
float fn;
float fnp;
find_fn_and_fnp(n, rcinv, d12, fc12, fcp12, fn, fnp);
const float c =
(paramb.num_types == 1)
? 1.0f
: annmb.c
[((paramb.n_max_radial + 1 + n) * paramb.num_types + t1) * paramb.num_types + t2];
fn *= c;
fnp *= c;
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, fn, fnp, Fp, sum_fxyz, f12);
float fn12[MAX_NUM_N];
float fnp12[MAX_NUM_N];
find_fn_and_fnp(paramb.basis_size_angular, rcinv, d12, fc12, fcp12, fn12, fnp12);
for (int n = 0; n <= paramb.n_max_angular; ++n) {
float gn12 = 0.0f;
float gnp12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_angular; ++k) {
int c_index = (n * (paramb.basis_size_angular + 1) + k) * paramb.num_types_sq;
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
gnp12 += fnp12[k] * annmb.c[c_index];
}
} else {
float fn12[MAX_NUM_N];
float fnp12[MAX_NUM_N];
find_fn_and_fnp(paramb.basis_size_angular, rcinv, d12, fc12, fcp12, fn12, fnp12);
for (int n = 0; n <= paramb.n_max_angular; ++n) {
float gn12 = 0.0f;
float gnp12 = 0.0f;
for (int k = 0; k <= paramb.basis_size_angular; ++k) {
int c_index = (n * (paramb.basis_size_angular + 1) + k) * paramb.num_types_sq;
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
gn12 += fn12[k] * annmb.c[c_index];
gnp12 += fnp12[k] * annmb.c[c_index];
}
if (paramb.num_L == paramb.L_max) {
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
} else if (paramb.num_L == paramb.L_max + 1) {
accumulate_f12_with_4body(
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
} else {
accumulate_f12_with_5body(
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
}
if (paramb.num_L == paramb.L_max) {
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
} else if (paramb.num_L == paramb.L_max + 1) {
accumulate_f12_with_4body(
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
} else {
accumulate_f12_with_5body(
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/main_nep/nep3.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ public:
float rc_angular = 0.0f; // angular cutoff
float rcinv_radial = 0.0f; // inverse of the radial cutoff
float rcinv_angular = 0.0f; // inverse of the angular cutoff
int basis_size_radial = 0; // for nep3
int basis_size_angular = 0; // for nep3
int basis_size_radial = 0;
int basis_size_angular = 0;
int n_max_radial = 0; // n_radial = 0, 1, 2, ..., n_max_radial
int n_max_angular = 0; // n_angular = 0, 1, 2, ..., n_max_angular
int L_max = 0; // l = 1, 2, ..., L_max
int dim_angular;
int num_L;
int num_types = 0;
int num_types_sq = 0; // for nep3
int num_c_radial = 0; // for nep3
int version = 2; // 2 for NEP2 and 3 for NEP3
int num_types_sq = 0;
int num_c_radial = 0;
int version = 4; // 3 for NEP3 and 4 for NEP4
int atomic_numbers[NUM_ELEMENTS];
};

Expand Down
19 changes: 7 additions & 12 deletions src/main_nep/parameters.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ void Parameters::calculate_parameters()
}
dim_radial = n_max_radial + 1; // 2-body descriptors q^i_n
dim_angular = (n_max_angular + 1) * L_max; // 3-body descriptors q^i_nl
if (version >= 3 && L_max_4body == 2) { // 4-body descriptors q^i_n222
if (L_max_4body == 2) { // 4-body descriptors q^i_n222
dim_angular += n_max_angular + 1;
}
if (version >= 3 && L_max_5body == 1) { // 5-body descriptors q^i_n1111
if (L_max_5body == 1) { // 5-body descriptors q^i_n1111
dim_angular += n_max_angular + 1;
}
dim = dim_radial + dim_angular;
Expand All @@ -182,14 +182,9 @@ void Parameters::calculate_parameters()

number_of_variables_ann = (dim + 2) * num_neurons1 * (version == 4 ? num_types : 1) + 1;

if (version == 2) {
number_of_variables_descriptor =
(num_types == 1) ? 0 : num_types * num_types * (n_max_radial + n_max_angular + 2);
} else {
number_of_variables_descriptor =
num_types * num_types *
(dim_radial * (basis_size_radial + 1) + (n_max_angular + 1) * (basis_size_angular + 1));
}
number_of_variables_descriptor =
num_types * num_types *
(dim_radial * (basis_size_radial + 1) + (n_max_angular + 1) * (basis_size_angular + 1));

number_of_variables = number_of_variables_ann + number_of_variables_descriptor;
if (train_mode == 2) {
Expand Down Expand Up @@ -516,8 +511,8 @@ void Parameters::parse_version(const char** param, int num_param)
if (!is_valid_int(param[1], &version)) {
PRINT_INPUT_ERROR("version should be an integer.\n");
}
if (version < 2 || version > 4) {
PRINT_INPUT_ERROR("version should = 2 or 3 or 4.");
if (version < 3 || version > 4) {
PRINT_INPUT_ERROR("version should = 3 or 4.");
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main_nep/parameters.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public:
Parameters();

// parameters to be read in
int version; // nep version, can be 2 or 3
int version; // nep version, can be 3 or 4
int batch_size; // number of configurations in one batch
int use_full_batch; // 1 for effective full-batch even though batch_size is not full-batch
int num_types; // number of atom types
Expand Down
Loading

0 comments on commit 5ca9d36

Please sign in to comment.