Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph AD for Forces #119

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Xtals = "ede5f01d-793e-4c47-9885-c447d1f18d6d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CSV = "0.7, 0.8, 0.9"
Expand All @@ -37,12 +39,16 @@ JSON3 = "1.9"
MolecularGraph = "0.11"
NearestNeighbors = "0.4"
PyCall = "1"
StaticArrays = "1"
SimpleWeightedGraphs = "1.2"
Xtals = "0.3"
Zygote = "0.6"
julia = "1.6"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["FiniteDifferences", "Test", "Zygote"]
65 changes: 65 additions & 0 deletions src/utils/adjoints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using Zygote # , ChainRulesCore
using Zygote: @adjoint
using StaticArrays
using LinearAlgebra

# Zygote adjoint for constructing a `Zip` iterator: the forward value is the
# `Zip` itself, and the pullback hands back the unzipped `iter` field of the
# incoming cotangent as the tangent for `is`.
@adjoint function Base.Iterators.Zip(is)
    zipped = Base.Iterators.Zip(is)
    return zipped, Δ -> (Zygote.unzip(Δ.iter),)
end

# Zygote adjoint for building a `Dict` from a generator.
# Forward: run `Zygote.pullback(g.f, args)` on every element of `g.iter`,
# keeping the produced values and their pullbacks, then build the `Dict`.
# Backward: pair each stored pullback with an entry of the cotangent `Δ`,
# collect the `.second` of each pullback result into a `Dict`, and return a
# generator over that `Dict` as the tangent for `g`.
@adjoint function Dict(g::Base.Generator)
    results = [Zygote.pullback(g.f, args) for args in g.iter]
    ys, backs = Zygote.unzip(results)
    function dict_pullback(Δ)
        dd = Dict(k => b(Δ)[1].second for (b, (k, v)) in zip(backs, pairs(Δ)))
        return ((x for x in dd),)
    end
    return Dict(ys...), dict_pullback
end

# Zygote adjoint for constructing a `Generator`: no tangent flows to the
# mapped function `f`; the cotangent is passed through unchanged to `args`.
@adjoint function Base.Generator(f, args)
    generator_pullback(Δ) = (nothing, Δ)
    return Base.Generator(f, args), generator_pullback
end
# Zygote adjoint for constructing a `Pair`: the key receives no tangent and
# the value receives the entry of the cotangent `Δ` stored under key `k`.
# (A leftover `@show Δ` debugging print was removed so the pullback no longer
# writes to stdout on every backward pass.)
@adjoint function Pair(k, v)
    pair_pullback(Δ) = (nothing, Δ[k])
    return Pair(k, v), pair_pullback
end

# `zero` that also accepts `nothing`: ordinary values map to their additive
# identity, while `nothing` is passed through untouched.
function _zero(x)
    return zero(x)
end

function _zero(::Nothing)
    return nothing
end

# Custom Zygote adjoint for `_cutoff!`.
# Forward pass: calls `_cutoff!` with the same arguments and returns its
# two-element result unchanged.
# Pullback: returns one tangent per positional argument, in order:
#   weight_mat    -> the rescaled cotangent, reshaped to the incoming shape
#   f             -> nothing
#   ijd           -> a vector of `(nothing, nothing, Δ_k)` tuples
#   nb_counts     -> nothing
#   longest_dists -> nothing
@adjoint function _cutoff!(weight_mat, f, ijd,
nb_counts, longest_dists;
max_num_nbr = 12)
y, ld = _cutoff!(weight_mat, f, ijd,
nb_counts, longest_dists;
max_num_nbr = max_num_nbr)
# The incoming cotangent is destructured as a 2-tuple; `nt` is never used.
function cutoff_pb((Δ,nt))
s = size(Δ)
# Flatten a collected copy of Δ so its entries can be rescaled in place.
Δ = vec(collect(Δ))
# Multiply successive entries of Δ by the derivative of `f` at each distance
# `d`; `zip` stops at the shorter of `eachindex(Δ)` and `ijd`.
# NOTE(review): Δ is addressed by linear position `ix`, not by the `(i, j)`
# carried in `ijd` — confirm this matches where the forward pass wrote.
# NOTE(review): the `max_num_nbr` / `nb_counts` acceptance test of the forward
# pass is not consulted here — confirm skipped pairs should receive a gradient.
for (ix, (_,_,d)) in zip(eachindex(Δ), ijd)
y_, back_ = Zygote.pullback(f, d)
Δ[ix] *= first(back_(Zygote.sensitivity(d)))
end
# NOTE(review): Δ is a vector at this point, so `size(Δ,1)` is `length(Δ)` and
# the `ijd` tangent has `length(Δ)` entries rather than `length(ijd)` — confirm.
(reshape(Δ, s), nothing,
collect(zip(fill(nothing, size(Δ,1)),
fill(nothing, size(Δ,1)),
Δ)),
nothing,
nothing)
end

(y,ld), cutoff_pb
end

# Register `Xtals.Charges{Xtals.Frac}` with Zygote's `@nograd`, i.e. calls to
# it are treated as carrying no gradient.
Zygote.@nograd Xtals.Charges{Xtals.Frac}

# Reverse rule for the fully-parameterised `SArray{D, T, ND, L}` constructor.
# Forward: builds the static array from `x...`.
# Backward: converts every entry of the cotangent `Δy` with `eltype(x...)` and
# returns the converted entries as a tuple tangent for the constructor input.
# NOTE(review): `eltype(x...)` only works when `x` holds a single argument
# (presumably the data tuple) — confirm no caller passes several arguments.
#
# `NoTangent` is referenced through `Zygote.ChainRulesCore` so that it resolves
# from this file's own imports instead of relying on another file having
# loaded `ChainRulesCore` into the enclosing module first.
function Zygote.ChainRulesCore.rrule(::Type{SArray{D, T, ND, L}}, x...) where {D, T, ND, L}
    y = SArray{D, T, ND, L}(x...)
    function sarray_pb(Δy)
        Δy = map(t -> eltype(x...)(t...), Δy)
        return Zygote.ChainRulesCore.NoTangent(), (Δy...,)
    end
    return y, sarray_pb
end
91 changes: 61 additions & 30 deletions src/utils/graph_building.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Serialization
using Xtals
using NearestNeighbors
#rc[:paths][:crystals] = @__DIR__ # so that Xtals.jl knows where things are
using Zygote

# options for decay of bond weights with distance...
# user can of course write their own as well
Expand Down Expand Up @@ -107,23 +108,41 @@ function weights_cutoff(is, js, dists; max_num_nbr = 12, dist_decay_func = inver

# iterate over list of tuples to build edge weights...
# note that neighbor list double counts so we only have to increment one counter per pair
weight_mat = zeros(Float32, num_atoms, num_atoms)
weight_mat = zeros(Float64, round(Int,num_atoms), round(Int,num_atoms))
weight_mat, longest_dists = _cutoff!(weight_mat,
dist_decay_func,
ijd,
nb_counts,
longest_dists)

# average across diagonal, just in case
weight_mat = 0.5 .* (weight_mat .+ weight_mat')

# normalize weights
weight_mat = weight_mat ./ maximum(weight_mat)
weight_mat
end

function _cutoff!(weight_mat, f, ijd,
nb_counts, longest_dists; max_num_nbr = 12)

for (i, j, d) in ijd
# FiniteDifferences doesn't like non integers as indices
# and is used to test
i, j = round.(Int, (i,j))

# if we're under the max OR if it's at the same distance as the previous one
if nb_counts[i] < max_num_nbr || isapprox(longest_dists[i], d)
weight_mat[i, j] += dist_decay_func(d)
weight_mat[i, j] += f(d)
longest_dists[i] = d
nb_counts[i] += 1
end
end

# average across diagonal, just in case
weight_mat = 0.5 .* (weight_mat .+ weight_mat')

# normalize weights
weight_mat = weight_mat ./ maximum(weight_mat)
weight_mat, longest_dists
end


"""
Build graph using neighbors from faces of Voronoi polyhedra and weights from areas. Based on the approach from https://github.com/ulissigroup/uncertainty_benchmarking/blob/aabb407807e35b5fd6ad06b14b440609ae09e6ef/BNN/data_pyro.py#L268
"""
Expand Down Expand Up @@ -155,6 +174,31 @@ function weights_voronoi(struc)
weight_mat = weight_mat ./ maximum(weight_mat)
end

"""
    index_works(crystal::Xtals.Crystal, n_atoms; cutoff_radius = 8.)

Return a vector of `(i, j)` index pairs of atoms in `crystal` that lie within
`cutoff_radius` of each other, where `i` runs over the index range
`13*n_atoms+1:14*n_atoms` and `j` over every atom found by the range query.
Self-pairs (`j == i`) are dropped. Pairs are ordered by `i`, then by the order
in which `inrange` reports the neighbours of `i`.

NOTE(review): the `13*n_atoms+1:14*n_atoms` range presumably selects the middle
cell of a 3 x 3 x 3 supercell built from a cell with `n_atoms` atoms — confirm
against the caller.
"""
function index_works(crystal::Xtals.Crystal, n_atoms; cutoff_radius = 8.)
    tree = BruteTree(Cart(crystal.atoms.coords, crystal.box).x)

    is_raw = (13 * n_atoms + 1):(14 * n_atoms)
    js_raw = inrange(tree,
                     Cart(crystal.atoms.coords[is_raw], crystal.box).x,
                     cutoff_radius)

    # One flat, concretely-typed comprehension instead of building ragged
    # tuples and splatting the whole neighbour list into varargs.
    return [(i, j) for (i, nbrs) in zip(is_raw, js_raw) for j in nbrs if j != i]
end

# Fold a 1-based index `i` back into the range `1:n_atoms` (1-based wrap-around
# using the remainder of `i - 1`).
function index_map(i, n_atoms)
    return rem(i - 1, n_atoms) + 1
end

# Split the `(i, j)` pairs produced by `index_works` into two parallel
# vectors: all first indices and all second indices.
function more_index_stuff(s, n; cutoff_radius = 8.)
    pair_list = index_works(s, n, cutoff_radius = cutoff_radius)
    sources = [pr[1] for pr in pair_list]
    targets = [pr[2] for pr in pair_list]
    return sources, targets
end

"""
Find all lists of pairs of atoms in `crys` that are within a distance of `cutoff_radius` of each other, respecting periodic boundary conditions.
Expand All @@ -166,7 +210,7 @@ function neighbor_list(crys::Crystal; cutoff_radius::Real = 8.0)

# make 3 x 3 x 3 supercell and find indices of "middle" atoms
# as well as index mapping from outer -> inner
supercell = replicate(crys, (3, 3, 3))
supercell = replicate2(crys, (3, 3, 3))

# check for size of cutoff radius relative to size of cell
min_celldim = min(crys.box.a, crys.box.b, crys.box.c)
Expand All @@ -175,33 +219,20 @@ function neighbor_list(crys::Crystal; cutoff_radius::Real = 8.0)
cutoff_radius = 0.99 * min_celldim
end

# todo: try BallTree, also perhaps other leafsize values
#tree = BruteTree(sc.atoms.coords.xf, PeriodicEuclidean([1.0, 1.0, 1.0]))
tree = BruteTree(Cart(supercell.atoms.coords, supercell.box).x)

is_raw = 13*n_atoms+1:14*n_atoms
js_raw =
inrange(tree, Cart(supercell.atoms.coords[is_raw], supercell.box).x, cutoff_radius)

index_map(i) = (i - 1) % n_atoms + 1 # I suddenly understand why some people dislike 1-based indexing

# this looks horrifying but it does do the right thing...
#ijraw_pairs = [p for p in Iterators.flatten([Iterators.product([p for p in zip(is_raw, js_raw)][n]...) for n in 1:4]) if p[1]!=p[2]]
split1 = map(zip(is_raw, js_raw)) do x
return [
p for p in [(x[1], [j for j in js if j != x[1]]...) for js in x[2]] if
length(p) == 2
]
is, js = Zygote.ignore() do
more_index_stuff(supercell, n_atoms; cutoff_radius = cutoff_radius)
end
ijraw_pairs = [(split1...)...]
get_pairdist((i, j)) = distance(supercell.atoms, supercell.box, i, j, false)
dists = get_pairdist.(ijraw_pairs)
is = index_map.([t[1] for t in ijraw_pairs])
js = index_map.([t[2] for t in ijraw_pairs])

dists = Xtals.distance(supercell.atoms.coords, supercell.box, is, js, false)

is, js = Int.(index_map.(is, n_atoms)),
Int.(index_map.(js, n_atoms))
return is, js, dists
end

# TODO: graphs from SMILES via OpenSMILES.jl

include("adjoints.jl")
include("xtals.jl")

end
60 changes: 60 additions & 0 deletions src/utils/xtals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using StaticArrays
using Xtals
using Zygote: @adjoint
using Zygote.ChainRulesCore
using NearestNeighbors
using LinearAlgebra

"""
    rl(f_to_c::Array{Float64, 2})

Compute the reciprocal lattice of the cell whose lattice vectors are the
columns of `f_to_c`. The result is a 3 x 3 matrix whose rows are the
reciprocal vectors, each scaled by `2π`.
"""
function rl(f_to_c::Array{Float64, 2})
    # lattice vectors are stored column-wise
    u, v, w = f_to_c[:, 1], f_to_c[:, 2], f_to_c[:, 3]

    # reciprocal vector dual to `p`, given the other two in cyclic order
    dual(p, q, r) = 2 * π * cross(q, r) / dot(p, cross(q, r))

    return vcat(dual(u, v, w)', dual(v, w, u)', dual(w, u, v)')
end

# Zygote adjoint for `Xtals.reciprocal_lattice`: both the forward value and
# the pullback are taken from differentiating `rl` at `x`.
@adjoint function Xtals.reciprocal_lattice(x)
    y, back = Zygote.pullback(rl, x)
    return y, back
end

"""
    replicate2(crystal::Crystal, repfactors::Tuple{Int, Int, Int})

Build a supercell of `crystal` containing `prod(repfactors)` copies of the
cell, without mutating any coordinate array in place (presumably so that
Zygote can differentiate through it — confirm).

Replicas are laid out block-wise: all atoms of the first replica, then all
atoms of the second, and so on, with the integer cell offsets sorted in
lexicographic order. Species and charge values follow the same block-wise
layout as their coordinates.

Throws an error if `crystal` has bonds assigned; `assert_P1_symmetry` is
called on `crystal` before replication.
"""
function replicate2(crystal::Crystal, repfactors::Tuple{Int, Int, Int})
    if Xtals.ne(crystal.bonds) != 0
        error("the crystal " * crystal.name * " has assigned bonds. to replicate, remove
its bonds with `remove_bonds!(crystal)`. then use `infer_bonds(crystal)` to
reassign the bonds")
    end

    assert_P1_symmetry(crystal)

    n_rep = prod(repfactors)
    n_atoms = crystal.atoms.n * n_rep

    box = replicate(crystal.box, repfactors)

    # Integer offsets of every replica, in lexicographic order. They are
    # constants with respect to differentiation, hence `Zygote.ignore`.
    cell_shifts = Zygote.ignore() do
        rf = range.(0, repfactors .- 1, step = 1)
        collect.(sort(vec(collect(Iterators.product(rf...)))))
    end

    # 3 x (n * n_rep) matrix holding each replica offset `n` times in a row,
    # matching the block-wise tiling produced by `repeat(xf, 1, n_rep)`.
    shift_columns(n) = Zygote.ignore() do
        reduce(hcat, repeat(cell_shifts, inner = n))
    end

    # Repeat Atoms
    xf_atoms_raw = repeat(crystal.atoms.coords.xf, 1, n_rep) .+
                   shift_columns(crystal.atoms.n)
    xf_atoms = xf_atoms_raw ./ repfactors

    # `outer` keeps species aligned with the block-wise coordinate tiling
    # (`inner` would attach the wrong species to most coordinates).
    species = repeat(crystal.atoms.species, outer = n_rep)
    atoms = Xtals.Atoms(length(species), species, Frac(xf_atoms))

    # Repeat Charges (same layout as atoms; shifts sized by the charge count)
    q = repeat(crystal.charges.q, outer = n_rep)
    xf_q_raw = repeat(crystal.charges.coords.xf, 1, n_rep)
    xf_q = if length(xf_q_raw) > 0
        (xf_q_raw .+ shift_columns(crystal.charges.n)) ./ repfactors
    else
        xf_q_raw
    end
    charges = Xtals.Charges(length(q), q, Frac(xf_q))

    return Crystal(crystal.name, box, atoms, charges, Xtals.MetaGraph(n_atoms), crystal.symmetry)
end
1 change: 1 addition & 0 deletions test/featurizations/GraphNodeFeaturization_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,4 @@ using ChemistryFeaturization.Featurization
end

end

25 changes: 25 additions & 0 deletions test/utils/GraphBuilding_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Test
using ChemistryFeaturization.Utils.GraphBuilding
using Xtals
using Zygote, FiniteDifferences


@testset "GraphBuilding" begin
path1 = abspath(@__DIR__, "..", "test_data", "strucs", "mp-195.cif")
Expand All @@ -20,3 +22,26 @@ using Xtals
@test adjc == wm_true
@test elsc == els_true
end

@testset "Graph Building AD tests" begin

    # Check Zygote's gradient of `sum(weights_cutoff(i, j, dist))` against a
    # forward finite-difference estimate. The index arguments must receive no
    # gradient; the distance gradient must agree to a relative tolerance.
    function test_fd(i, j, dist)
        fd = grad(forward_fdm(2, 1),
                  (i, j, dist) -> sum(GraphBuilding.weights_cutoff(i, j, dist)),
                  i, j, dist)

        gs = gradient(i, j, dist) do i, j, dist
            sum(GraphBuilding.weights_cutoff(i, j, dist))
        end

        @test gs[1] === nothing
        @test gs[2] === nothing
        # Written as a comparison so a failure prints both gradients instead
        # of a bare `false`.
        @test gs[3] ≈ fd[3] rtol = 1e-4
    end

    # test with non-overlapping indices
    test_fd(collect(1:10), collect(1:10), Float64.(collect(1:10)))
    # test with overlapping indices
    # NOTE(review): these inputs come from the unseeded global RNG, so a
    # failure here may not reproduce on the next run.
    test_fd(rand(1:10, 100), rand(1:10, 100), rand(100))
end