Skip to content

Commit

Permalink
[feature] proper I/O from/to hdf5 (#1028)
Browse files Browse the repository at this point in the history
* save config to hdf5
* restart from hdf5
* serialize species to json
* simplify hdf5 write of density matrix
* do not save potential as it can be generated from the density; use less strict convergence check
  • Loading branch information
toxa81 authored Dec 2, 2024
1 parent e05414b commit e4c7408
Show file tree
Hide file tree
Showing 22 changed files with 528 additions and 300 deletions.
38 changes: 29 additions & 9 deletions apps/mini_app/sirius.scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,24 @@ preprocess_json_input(std::string fname__)
std::unique_ptr<Simulation_context>
create_sim_ctx(std::string fname__, cmd_args const& args__)
{
auto json = preprocess_json_input(fname__);
std::string config_string;
if (isHDF5(fname__)) {
config_string = fname__;
} else {
auto json = preprocess_json_input(fname__);
config_string = json.dump();
}

auto ctx_ptr = std::make_unique<Simulation_context>(json.dump(), mpi::Communicator::world());
Simulation_context& ctx = *ctx_ptr;
auto ctx = std::make_unique<Simulation_context>(config_string);

auto& inp = ctx.cfg().parameters();
auto& inp = ctx->cfg().parameters();
if (inp.gamma_point() && !(inp.ngridk()[0] * inp.ngridk()[1] * inp.ngridk()[2] == 1)) {
RTE_THROW("this is not a Gamma-point calculation")
}

ctx.import(args__);
ctx->import(args__);

return ctx_ptr;
return ctx;
}

auto
Expand All @@ -110,6 +115,12 @@ ground_state(Simulation_context& ctx, int task_id, cmd_args const& args, int wri
<< "+----------------------+" << std::endl;
break;
}
case task_t::ground_state_restart: {
ctx.out() << "+--------------------------+" << std::endl
<< "| restart SCF ground state |" << std::endl
<< "+--------------------------+" << std::endl;
break;
}
case task_t::ground_state_new_relax: {
ctx.out() << "+---------------------------------------------+" << std::endl
<< "| new SCF ground state with atomic relaxation |" << std::endl
Expand Down Expand Up @@ -145,11 +156,18 @@ ground_state(Simulation_context& ctx, int task_id, cmd_args const& args, int wri
auto& density = dft.density();

if (task_id == task_t::ground_state_restart) {
if (!file_exists(storage_file_name)) {
auto fname = args.value<fs::path>("input", storage_file_name);
if (!isHDF5(fname)) {
fname = storage_file_name;
}
if (!file_exists(fname)) {
RTE_THROW("storage file is not found");
}
density.load(storage_file_name);
potential.load(storage_file_name);
density.load(fname);
density.generate_paw_density();
potential.generate(density, ctx.use_symmetry(), true);
Hamiltonian0<double> H0(potential, true);
initialize_subspace(kset, H0);
} else {
dft.initial_state();
}
Expand Down Expand Up @@ -579,6 +597,8 @@ main(int argn, char** argv)
{"iterative_solver.orthogonalize=", ""},
{"iterative_solver.early_restart=",
"{double} value between 0 and 1 to control the early restart ratio in Davidson"},
{"iterative_solver.energy_tolerance=", "{double} starting tolerance of iterative solver"},
{"iterative_solver.num_steps=", "{int} number of steps in iterative solver"},
{"mixer.type=", "{string} mixer name (anderson, anderson_stable, broyden2, linear)"},
{"mixer.beta=", "{double} mixing parameter"},
{"volume_scale0=", "{double} starting volume scale for EOS calculation"},
Expand Down
4 changes: 2 additions & 2 deletions apps/utils/unit_cell_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ create_supercell(cmd_args const& args__)
std::cout << std::endl;
}

Simulation_context ctx("sirius.json", mpi::Communicator::self());
Simulation_context ctx(std::string("sirius.json"), mpi::Communicator::self());

auto scell_lattice_vectors = dot(ctx.unit_cell().lattice_vectors(), r3::matrix<double>(scell));

Expand Down Expand Up @@ -113,7 +113,7 @@ create_supercell(cmd_args const& args__)
void
find_primitive()
{
Simulation_context ctx("sirius.json", mpi::Communicator::self());
Simulation_context ctx(std::string("sirius.json"), mpi::Communicator::self());

double lattice[3][3];
for (int i : {0, 1, 2}) {
Expand Down
5 changes: 3 additions & 2 deletions examples/pp-pw/Si7Ge/sirius.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"gk_cutoff" : 5.0,
"pw_cutoff" : 20.00,

"energy_tol" : 1e-8,
"energy_tol" : 1e-7,
"density_tol" : 1e-7,

"num_dft_iter" : 100,
Expand All @@ -33,7 +33,8 @@
"mixer" : {
"beta" : 0.8,
"type" : "anderson",
"max_history" : 8
"max_history" : 8,
"use_hartree" : true
},

"unit_cell": {
Expand Down
2 changes: 0 additions & 2 deletions python_module/py_sirius.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,6 @@ PYBIND11_MODULE(py_sirius, m)
.def(py::init<Simulation_context&>(), py::keep_alive<1, 2>(), "ctx"_a)
.def("generate", &Potential::generate, "density"_a, "use_sym"_a, "transform_to_rg"_a)
.def("fft_transform", &Potential::fft_transform)
.def("save", &Potential::save)
.def("load", &Potential::load)
.def_property("vxc", py::overload_cast<>(&Potential::xc_potential),
py::overload_cast<>(&Potential::xc_potential), py::return_value_policy::reference_internal)
.def_property("exc", py::overload_cast<>(&Potential::xc_energy_density),
Expand Down
4 changes: 1 addition & 3 deletions src/api/sirius_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2819,7 +2819,7 @@ sirius_generate_effective_potential(void* const* gs_handler__, int* error_code__
call_sirius(
[&]() {
auto& gs = get_gs(gs_handler__);
gs.potential().generate(gs.density(), gs.ctx().use_symmetry(), false);
gs.potential().generate(gs.density(), gs.ctx().use_symmetry(), true);
},
error_code__);
}
Expand Down Expand Up @@ -6454,7 +6454,6 @@ sirius_save_state(void** gs_handler__, const char* file_name__, int* error_code_
auto& gs = get_gs(gs_handler__);
std::string file_name(file_name__);
gs.ctx().create_storage_file(file_name);
gs.potential().save(file_name);
gs.density().save(file_name);
},
error_code__);
Expand Down Expand Up @@ -6486,7 +6485,6 @@ sirius_load_state(void** gs_handler__, const char* file_name__, int* error_code_
[&]() {
auto& gs = get_gs(gs_handler__);
std::string file_name(file_name__);
gs.potential().load(file_name);
gs.density().load(file_name);
},
error_code__);
Expand Down
25 changes: 19 additions & 6 deletions src/context/simulation_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,18 +1166,18 @@ Simulation_context::update()
void
Simulation_context::create_storage_file(std::string name__) const
{
if (!initialized_) {
RTE_THROW("Simulation_context must be initialized first");
}
if (comm_.rank() == 0) {
/* create new hdf5 file */
HDF5_tree fout(name__, hdf5_access_t::truncate);
fout.create_node("parameters");
fout.create_node("effective_potential");
fout.create_node("effective_magnetic_field");
fout.create_node("density");
fout.create_node("magnetization");

for (int j = 0; j < num_mag_dims(); j++) {
fout["magnetization"].create_node(j);
fout["effective_magnetic_field"].create_node(j);
}

fout["parameters"].write("num_spins", num_spins());
Expand All @@ -1196,10 +1196,23 @@ Simulation_context::create_storage_file(std::string name__) const

fout.create_node("unit_cell");
fout["unit_cell"].create_node("atoms");
for (int j = 0; j < unit_cell().num_atoms(); j++) {
fout["unit_cell"]["atoms"].create_node(j);
fout["unit_cell"]["atoms"][j].write("mt_basis_size", unit_cell().atom(j).mt_basis_size());
fout["unit_cell"].create_node("atom_types");
for (int ia = 0; ia < unit_cell().num_atoms(); ia++) {
fout["unit_cell"]["atoms"].create_node(ia);
fout["unit_cell"]["atoms"][ia].write("mt_basis_size", unit_cell().atom(ia).mt_basis_size());
}
for (int iat = 0; iat < unit_cell().num_atom_types(); iat++) {
fout["unit_cell"]["atom_types"].create_node(iat);
fout["unit_cell"]["atom_types"][iat].write("config", unit_cell().atom_type(iat).serialize().dump());
}
auto config = this->serialize()["config"];
config.erase("locked");
config["control"].erase("mpi_grid_dims");
config["control"].erase("fft_mode");
config["control"].erase("gen_evp_solver_name");
config["control"].erase("std_evp_solver_name");
config["settings"].erase("fft_grid_size");
fout.write("config", config.dump());
}
comm_.barrier();
}
Expand Down
48 changes: 31 additions & 17 deletions src/context/simulation_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ class Simulation_context : public Simulation_parameters
init_common();
}

/// Create an empty simulation context with an explicit communicators for k-point and
/// band parallelisation.
Simulation_context(mpi::Communicator const& comm__, mpi::Communicator const& comm_k__,
mpi::Communicator const& comm_band__)
: comm_(comm__)
Expand All @@ -327,29 +329,41 @@ class Simulation_context : public Simulation_parameters
init_common();
}

/// Create a simulation context with world communicator and load parameters from JSON string or JSON file.
Simulation_context(std::string const& str__)
: comm_(mpi::Communicator::world())
{
init_common();
import(str__);
unit_cell_->import(cfg().unit_cell());
}

explicit Simulation_context(nlohmann::json const& dict__)
: comm_(mpi::Communicator::world())
/// Create a simulation context with world communicator and load parameters from JSON string or a file.
explicit Simulation_context(std::string const& str__, mpi::Communicator const& comm__ = mpi::Communicator::world())
: comm_(comm__)
{
init_common();
import(dict__);
unit_cell_->import(cfg().unit_cell());
if (!is_json_string(str__) && isHDF5(str__)) {
HDF5_tree fin(str__, hdf5_access_t::read_only);
std::string json_string;
fin.read("config", json_string);
auto dict = read_json_from_file_or_string(json_string);
for (auto& e : dict["unit_cell"]["atom_types"]) {
auto label = e.get<std::string>();
dict["unit_cell"]["atom_files"][label] = "";
}
import(dict);
unit_cell_->import(cfg().unit_cell());

/* need to set type of calculation before parsing species */
this->electronic_structure_method(cfg().parameters().electronic_structure_method());
for (int iat = 0; iat < unit_cell_->num_atom_types(); iat++) {
fin["unit_cell"]["atom_types"][iat].read("config", json_string);
unit_cell_->atom_type(iat).read_input(json_string);
}
} else {
import(str__);
unit_cell_->import(cfg().unit_cell());
}
}

// /// Create a simulation context with world communicator and load parameters from JSON string or JSON file.
Simulation_context(std::string const& str__, mpi::Communicator const& comm__)
explicit Simulation_context(nlohmann::json const& dict__,
mpi::Communicator const& comm__ = mpi::Communicator::world())
: comm_(comm__)
{
init_common();
import(str__);
import(dict__);
unit_cell_->import(cfg().unit_cell());
}

Expand Down Expand Up @@ -764,7 +778,7 @@ class Simulation_context : public Simulation_parameters

/// Export parameters of simulation context as a JSON dictionary.
nlohmann::json
serialize()
serialize() const
{
nlohmann::json dict;
dict["config"] = cfg().dict();
Expand Down
13 changes: 8 additions & 5 deletions src/context/simulation_parameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,16 @@ get_section_options(std::string const& section__)
}

void
Simulation_parameters::import(std::string const& str__)
Simulation_parameters::import(nlohmann::json const& dict__)
{
auto json = read_json_from_file_or_string(str__);
import(json);
cfg_.import(dict__);
}

void
Simulation_parameters::import(nlohmann::json const& dict__)
Simulation_parameters::import(std::string const& str__)
{
cfg_.import(dict__);
auto dict = read_json_from_file_or_string(str__);
this->import(dict);
}

void
Expand All @@ -174,6 +174,9 @@ Simulation_parameters::import(cmd_args const& args__)

cfg_.iterative_solver().early_restart(
args__.value("iterative_solver.early_restart", cfg_.iterative_solver().early_restart()));
cfg_.iterative_solver().energy_tolerance(
args__.value("iterative_solver.energy_tolerance", cfg_.iterative_solver().energy_tolerance()));
cfg_.iterative_solver().num_steps(args__.value("iterative_solver.num_steps", cfg_.iterative_solver().num_steps()));
cfg_.mixer().beta(args__.value("mixer.beta", cfg_.mixer().beta()));
cfg_.mixer().type(args__.value("mixer.type", cfg_.mixer().type()));
}
Expand Down
Loading

0 comments on commit e4c7408

Please sign in to comment.