Skip to content

Commit

Permalink
Add AtomsBase conversions (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesgardner1421 authored Aug 19, 2023
1 parent b0ec32c commit b2a8c24
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 9 deletions.
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,33 @@ authors = ["James Gardner <[email protected]> and contributors"]
version = "0.2.0"

[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PeriodicTable = "7b2266bf-644c-5ea3-82d8-af4bbd25a884"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
UnitfulAtomic = "a7773ee8-282e-5fa2-be4e-bd808c38a91a"

[compat]
AtomsBase = "0.3"
Distances = "0.10"
ExtXYZ = "0.1"
PeriodicTable = "1"
PyCall = "1"
Requires = "1"
StaticArraysCore = "1"
Unitful = "1"
UnitfulAtomic = "1"
julia = "1"

[extras]
AtomsBaseTesting = "ed7c10db-df7e-4efa-a7be-4f4190f7f227"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["PyCall", "Test", "SafeTestsets"]
test = ["AtomsBaseTesting", "PyCall", "Test", "SafeTestsets"]
6 changes: 6 additions & 0 deletions src/NQCBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,10 @@ function __init__()
@require PyCall="438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @eval include("io/ase.jl")
end

include("atoms_base.jl")
export Cell
export System
export Trajectory
export Position, Velocity

end
125 changes: 125 additions & 0 deletions src/atoms_base.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using AtomsBase: AtomsBase
using StaticArraysCore: SVector

function NQCBase.Atoms(system::AtomsBase.AbstractSystem)
return NQCBase.Atoms(AtomsBase.atomic_symbol(system))
end

function Cell(system::AtomsBase.AbstractSystem)
if AtomsBase.isinfinite(system)
return NQCBase.InfiniteCell()
else
box = AtomsBase.bounding_box(system)
cell = PeriodicCell(reduce(hcat, box))
NQCBase.set_periodicity!(cell, AtomsBase.periodicity(system))
return cell
end
end

function Position(system::AtomsBase.AbstractSystem)
r = AtomsBase.position(system)
output = zeros(AtomsBase.n_dimensions(system), Base.length(system))
for i in axes(output, 2)
for j in axes(output, 1)
output[j, i] = austrip(r[i][j])
end
end
return output
end

function Velocity(system::AtomsBase.AbstractSystem)
v = AtomsBase.velocity(system)
output = zeros(AtomsBase.n_dimensions(system), Base.length(system))
for i in axes(output, 2)
for j in axes(output, 1)
output[j, i] = austrip(v[i][j])
end
end
return output
end

function AtomsBase.bounding_box(cell::PeriodicCell)
S = size(cell.vectors, 2)
return SVector{S}(auconvert.(u"Å", vec) for vec in eachcol(cell.vectors))
end

function AtomsBase.boundary_conditions(cell::PeriodicCell)
S = size(cell.vectors, 2)
return SVector{S}(
bc ? AtomsBase.Periodic() : AtomsBase.DirichletZero() for bc in cell.periodicity
)
end
AtomsBase.isinfinite(::PeriodicCell) = false
AtomsBase.isinfinite(::InfiniteCell) = true

function System(atoms::NQCBase.Atoms, position::AbstractMatrix, cell::AbstractCell=InfiniteCell())
output_atoms = AtomsBaseAtoms(atoms, position)
return build_system(output_atoms, cell)
end

function System(atoms::NQCBase.Atoms, position::AbstractMatrix, velocity::AbstractMatrix, cell::AbstractCell=InfiniteCell())
output_atoms = AtomsBaseAtoms(atoms, position, velocity)
return build_system(output_atoms, cell)
end

function AtomsBaseAtoms(atoms::NQCBase.Atoms, position::AbstractMatrix)
if length(atoms) != size(position, 2)
@error atoms position
error("The provided `Atoms` do not match the `position` array.")
end

output_atoms = AtomsBase.Atom[]
sizehint!(output_atoms, length(atoms))
for i in axes(position,2)
r = auconvert.(u"Å", position[:,i])
push!(output_atoms, AtomsBase.Atom(atoms.numbers[i], r))
end
return output_atoms
end

function AtomsBaseAtoms(atoms::NQCBase.Atoms, position::AbstractMatrix, velocity::AbstractMatrix)
if length(atoms) != size(position, 2)
@error atoms position
error("The provided `Atoms` do not match the `position` array.")
end
if length(atoms) != size(velocity, 2)
@error atoms velocity
error("The provided `Atoms` do not match the `velocity` array.")
end

output_atoms = AtomsBase.Atom[]
sizehint!(output_atoms, length(atoms))
for i in axes(position,2)
r = auconvert.(u"Å", position[:,i])
v = auconvert.(u"Å/ps", velocity[:,i])
push!(output_atoms, AtomsBase.Atom(atoms.numbers[i], r, v))
end
return output_atoms
end

function build_system(atoms, cell)
if AtomsBase.isinfinite(cell)
return AtomsBase.isolated_system(atoms)
else
box = AtomsBase.bounding_box(cell)
bc = AtomsBase.boundary_conditions(cell)
return AtomsBase.atomic_system(atoms, box, bc)
end
end

function Trajectory(
atoms::NQCBase.Atoms,
position::Vector{<:AbstractMatrix},
velocity::Vector{<:AbstractMatrix},
cell::AbstractCell=InfiniteCell()
)

trajectory = AtomsBase.FlexibleSystem[]
sizehint!(trajectory, length(position))

for i in eachindex(position, velocity)
push!(trajectory, System(atoms, position[i], velocity[i], cell))
end

return trajectory
end
9 changes: 5 additions & 4 deletions src/cells.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,20 @@ end

Base.eltype(::PeriodicCell{T}) where {T} = T

function PeriodicCell(vectors::AbstractMatrix{T}) where {T<:AbstractFloat}
PeriodicCell{T}(vectors, [true, true, true])
function PeriodicCell(vectors::AbstractMatrix)
vectors = austrip.(vectors)
PeriodicCell{eltype(vectors)}(vectors, [true, true, true])
end

function PeriodicCell(vectors::AbstractMatrix{<:Integer})
PeriodicCell{Float64}(vectors, [true, true, true])
end

function set_periodicity!(cell::PeriodicCell, periodicity::Vector{Bool})
function set_periodicity!(cell::PeriodicCell, periodicity::AbstractVector{Bool})
cell.periodicity .= periodicity
end

function set_vectors!(cell::PeriodicCell, vectors::Matrix)
function set_vectors!(cell::PeriodicCell, vectors::AbstractMatrix)
cell.vectors .= vectors
cell.inverse .= inv(cell.vectors)
end
Expand Down
45 changes: 45 additions & 0 deletions test/atoms_base.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using NQCBase
using AtomsBase: AtomsBase
using AtomsBaseTesting: AtomsBaseTesting
using Unitful, UnitfulAtomic
using Test

hydrogen = AtomsBase.isolated_system([:H => [0, 0, 1.]u"bohr",
:H => [0, 0, 3.]u"bohr"])

box = 10.26 / 2 * [[0, 0, 1], [1, 0, 1], [1, 1, 0]]u"bohr"
silicon = AtomsBase.periodic_system([:Si => ones(3)/8,
:Si => -ones(3)/8],
box, fractional=true)

@testset "Atoms conversions" begin
@test Atoms(hydrogen) == Atoms([:H, :H])
@test Atoms(silicon) == Atoms([:Si, :Si])
end

@testset "Cell conversions" begin
@test Cell(hydrogen) === InfiniteCell()
@test Cell(silicon) isa PeriodicCell
end

@testset "System" begin
@test System(Atoms([:H, :H]), rand(3,2)) isa AtomsBase.FlexibleSystem
@test System(Atoms([:H, :H]), rand(3,2), rand(3,2)) isa AtomsBase.FlexibleSystem
@test System(Atoms([:H, :H]), rand(3,2), rand(3,2), Cell(silicon)) isa AtomsBase.FlexibleSystem
end

@testset "Forward and backward conversion" begin
atoms = Atoms(silicon)
cell = Cell(silicon)
position = Position(silicon)
velocity = Velocity(silicon)
AtomsBaseTesting.test_approx_eq(System(atoms, position, cell), silicon)
AtomsBaseTesting.test_approx_eq(System(atoms, position, velocity, cell), silicon)
end

@testset "Trajectory" begin
position = [rand(3,3) for i in 1:10]
velocity = [rand(3,3) for i in 1:10]
Trajectory(Atoms([:H, :C, :N]), position, velocity) isa Vector{<:AtomsBase.FlexibleSystem}
Trajectory(Atoms([:H, :C, :N]), position, velocity, Cell(silicon)) isa Vector{<:AtomsBase.FlexibleSystem}
end
9 changes: 5 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using NQCBase
using Test, SafeTestsets

@time @safetestset "Atoms tests" begin include("atoms.jl") end
@time @safetestset "Cells tests" begin include("cells.jl") end
@time @safetestset "ExtXYZ tests" begin include("io/extxyz.jl") end
@time @safetestset "ase tests" begin include("io/ase.jl") end
@safetestset "Atoms tests" begin include("atoms.jl") end
@safetestset "Cells tests" begin include("cells.jl") end
@safetestset "ExtXYZ tests" begin include("io/extxyz.jl") end
@safetestset "ase tests" begin include("io/ase.jl") end
@safetestset "AtomsBase tests" begin include("atoms_base.jl") end

0 comments on commit b2a8c24

Please sign in to comment.