diff --git a/Project.toml b/Project.toml index c3d0aa3..d1e0ff7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NQCBase" uuid = "78c76ebc-5665-4934-b512-82d81b5cbfb7" authors = ["James Gardner and contributors"] -version = "0.2.1" +version = "0.2.9" [deps] AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" diff --git a/src/NQCBase.jl b/src/NQCBase.jl index efb58a5..509a668 100644 --- a/src/NQCBase.jl +++ b/src/NQCBase.jl @@ -10,7 +10,8 @@ include("cells.jl") include("io/extxyz.jl") function __init__() - @require PyCall="438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @eval include("io/ase.jl") + @require PyCall="438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @eval include("io/PyCall-ase.jl") + @require PythonCall="6099a3de-0909-46bc-b1f4-468b9a2dfc0d" @eval include("io/PythonCall-ase.jl") end include("atoms_base.jl") diff --git a/src/io/ase.jl b/src/io/PyCall-ase.jl similarity index 100% rename from src/io/ase.jl rename to src/io/PyCall-ase.jl diff --git a/src/io/PythonCall-ase.jl b/src/io/PythonCall-ase.jl new file mode 100644 index 0000000..f887ad2 --- /dev/null +++ b/src/io/PythonCall-ase.jl @@ -0,0 +1,41 @@ + +using .PythonCall +using Unitful, UnitfulAtomic + +export convert_from_ase_atoms +export convert_to_ase_atoms + +const ase = pyimport("ase") + +convert_to_ase_atoms(atoms::Atoms, R::Matrix) = + ase.Atoms(positions=ustrip.(u"Å", R'u"bohr"), symbols=string.(atoms.types)) + +convert_to_ase_atoms(atoms::Atoms, R::Matrix, ::InfiniteCell) = + convert_to_ase_atoms(atoms, R) + +function convert_to_ase_atoms(atoms::Atoms, R::Matrix, cell::PeriodicCell) + ase.Atoms( + positions=ustrip.(u"Å", R'u"bohr"), + cell=ustrip.(u"Å", cell.vectors'u"bohr"), + symbols=string.(atoms.types), + pbc=cell.periodicity) +end + +function convert_to_ase_atoms(atoms::Atoms, R::Vector{<:Matrix}, cell::AbstractCell) + convert_to_ase_atoms.(Ref(atoms), R, Ref(cell)) +end + +convert_from_ase_atoms(ase_atoms::Py) = + Atoms(ase_atoms), positions(ase_atoms), Cell(ase_atoms) + +Atoms(ase_atoms::Py) = Atoms{Float64}(Symbol.(PyList(ase_atoms.get_chemical_symbols()))) + +positions(ase_atoms::Py) = austrip.(PyArray(ase_atoms.get_positions())'u"Å") + +function Cell(ase_atoms::Py) + if all(PyArray(ase_atoms.cell.array) .== 0) + return InfiniteCell() + else + return PeriodicCell{Float64}(austrip.(PyArray(ase_atoms.cell.array)'u"Å"), [Bool(x) for x in ase_atoms.pbc]) + end +end