Skip to content

Commit

Permalink
Merge pull request #13 from libAtoms/v07
Browse files Browse the repository at this point in the history
V07
  • Loading branch information
cortner authored Feb 16, 2019
2 parents 65d1389 + d8460bc commit 5a1da97
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 212 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
julia 0.6
julia 0.7
StaticArrays
52 changes: 29 additions & 23 deletions src/cell_list.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
using Base.Threads
using Base.Threads, LinearAlgebra

export npairs, nsites

PairList{T}(X::Vector{SVec{T}}, cutoff::AbstractFloat, cell::AbstractMatrix,
pbc, int_type = zero(Int);
store_first::Bool = true, sorted ::Bool= true, fixcell::Bool = true) =
_pairlist_(X, SMat{T}(cell), SVec{Bool}(pbc), T(cutoff), int_type,
PairList(X::Vector{SVec{T}}, cutoff::AbstractFloat, cell::AbstractMatrix, pbc;
int_type::Type = Int, store_first = true, sorted = true, fixcell = true) where {T} =
_pairlist_(X, SMat{T}(cell), SVec{Bool}(pbc), T(cutoff), zero(int_type),
store_first, sorted, fixcell)

PairList{T}(X::Matrix{T}, args...; kwargs...) =
PairList(X::Matrix{T}, args...; kwargs...) where {T} =
PairList(reinterpret(SVec{T}, X, (size(X,2),)), args...; varargs...)

npairs(nlist::PairList) = length(nlist.i)
Expand Down Expand Up @@ -43,25 +42,28 @@ end
pbc ? bin_wrap(i, n) : bin_trunc(i, n)

"Map particle position to a (cartesian) cell index"
@inline position_to_cell_index{T, TI <: Integer}(
inv_cell::SMat{T}, x::SVec{T}, ns::SVec{TI}) =
@inline position_to_cell_index(inv_cell::SMat{T}, x::SVec{T}, ns::SVec{TI}
) where {T, TI <: Integer} =
floor.(TI, ((inv_cell' * x) .* ns + 1))


# ------------ The next two functions are the only dimension-dependent
# parts of the code!

# an extension of sub2ind for the case when i is a vector (cartesian index)
@inline Base.sub2ind{TI <: Integer}(dims::NTuple{3,TI}, i::SVec{TI}) =
sub2ind(dims, i[1], i[2], i[3])
# @inline Base.sub2ind(dims::NTuple{3,TI}, i::SVec{TI}) where {TI <: Integer} =
# sub2ind(dims, i[1], i[2], i[3])
# WARNING: this smells like a performance regression!
@inline _sub2ind(dims::NTuple{3,TI}, i::SVec{TI}) where {TI <: Integer} =
(LinearIndices(dims))[i[1], i[2], i[3]]

lengths{T}(C::SMat{T}) =
lengths(C::SMat{T}) where {T} =
det(C) ./ SVec{T}(norm(C[2,:]×C[3,:]), norm(C[3,:]×C[1,:]), norm(C[1,:]×C[2,:]))


# --------------------------------------------------------------------------

function analyze_cell(cell, cutoff, _::TI) where TI
function analyze_cell(cell, cutoff, _::TI) where {TI <: Integer}
# check the cell volume (allow only 3D volumes!)
volume = abs(det(cell))
@assert volume > 1e-12
Expand All @@ -70,15 +72,18 @@ function analyze_cell(cell, cutoff, _::TI) where TI
# Compute distance of cell faces
lens = abs.(lengths(cell))
# Number of cells for cell subdivision
ns_vec = max.(floor.(TI, lens / cutoff), 1)
_t = floor.(TI, lens / cutoff)
ns_vec = max.(_t, one(TI))
return inv_cell, ns_vec, lens
end

# multi-threading setup

function setup_mt(niter::TI, maxnt = MAX_THREADS[1]) where TI
nt = minimum([6, nthreads(), ceil(TI, niter / 20), maxnt])
nn = ceil.(TI, linspace(1, niter+1, nt+1))
# nn = ceil.(TI, linspace(1, niter+1, nt+1))
# range(start, stop=stop, length=length)
nn = ceil.(TI, range(1, stop=niter+1, length=nt+1))
return nt, nn
end

Expand All @@ -103,8 +108,8 @@ function _celllist_(X::Vector{SVec{T}}, cell::SMat{T}, pbc::SVec{Bool},
# data structure to store a linked list for each bin
ncells = prod(ns_vec)
seed = fill(TI(-1), ncells)
last = Vector{TI}(ncells)
next = Vector{TI}(nat)
last = Vector{TI}(undef, ncells)
next = Vector{TI}(undef, nat)
nats = zeros(TI, ncells)

for i = 1:nat
Expand All @@ -113,7 +118,7 @@ function _celllist_(X::Vector{SVec{T}}, cell::SMat{T}, pbc::SVec{Bool},
# Periodic/non-periodic boundary conditions
c = bin_wrap_or_trunc.(c, pbc, ns_vec)
# linear cell index # (+1 due to 1-based indexing)
ci = sub2ind(ns, c) # <<<<
ci = _sub2ind(ns, c) # <<<<

# Put atom into appropriate bin (list of linked lists)
if seed[ci] < 0 # ci contains no atom yet
Expand Down Expand Up @@ -169,8 +174,9 @@ function _pairlist_(clist::CellList{T, TI}) where {T, TI}

# Find out over how many neighbor cells we need to loop (if the box is small)
nxyz = ceil.(TI, cutoff * (ns_vec ./ lens))
cxyz = CartesianIndex(nxyz.data)
xyz_range = CartesianRange(- cxyz, cxyz)
# cxyz = CartesianIndex(nxyz.data)
# WARNING : 3D-specific hack; also potential performance regression
xyz_range = CartesianIndices((-nxyz[1]:nxyz[1], -nxyz[2]:nxyz[2], -nxyz[3]:nxyz[3]))

# Loop over threads
@threads for it = 1:nt
Expand Down Expand Up @@ -230,7 +236,7 @@ function _find_neighbours_!(i, clist, ns_vec::SVec{TI}, bins, xyz_range,
# skip this bin if not inside the domain
all(1 .<= cj .<= ns_vec) || continue
# linear cell index
ncj = sub2ind(ns_vec.data, cj)
ncj = _sub2ind(ns_vec.data, cj)
# Offset of the neighboring bins
off = bins * xyz

Expand Down Expand Up @@ -308,9 +314,9 @@ element of `i` with value `n`. Further, `first[nat+1]` will be
If `first[n] == first[n+1]` then this means that `i` contains no element `n`.
"""
function get_first{TI}(i::Vector{TI}, nat::Integer = i[end])
function get_first(i::Vector{TI}, nat::Integer = i[end]) where {TI}
# compute the first index for each site
first = Vector{TI}(nat + 1)
first = Vector{TI}(undef, nat + 1)
idx = 1
n = 1
while n <= nat && idx <= length(i)
Expand All @@ -323,7 +329,7 @@ function get_first{TI}(i::Vector{TI}, nat::Integer = i[end])
end
n += 1
end
first[n:end] = length(i)+1
first[n:end] .= length(i)+1
return first
end

Expand Down
20 changes: 11 additions & 9 deletions src/iterators.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@

import Base: start, done, next, length
import Base: iterate, length, pairs

export pairs, sites, site, nbodies

abstract type AbstractIterator end

inc{T <: Integer}(i::T) = i + T(1)
inc(i::T) where {T <: Integer} = i + T(1)

# -------------- iterator over pairs ---------------

Expand All @@ -16,10 +16,11 @@ struct PairIterator{T,TI} <: AbstractIterator
nlist::PairList{T,TI}
end

start{T,TI}(it::PairIterator{T,TI}) = TI(1)
done(it::PairIterator, i::Integer) = (i > npairs(it.nlist))
next(it::PairIterator, i) =
(it.nlist.i[i], it.nlist.j[i], it.nlist.r[i], it.nlist.R[i]), inc(i)
_item(it::PairIterator, i::Integer) =
(it.nlist.i[i], it.nlist.j[i], it.nlist.r[i], it.nlist.R[i])
iterate(it::PairIterator{T,TI}) where {T,TI} = _item(it, 1), TI(1)
iterate(it::PairIterator, i::Integer) =
i >= npairs(it.nlist) ? nothing : (_item(it, inc(i)), inc(i))
length(it::PairIterator) = npairs(it.nlist)

# -------------- iterator over sites ---------------
Expand All @@ -35,9 +36,10 @@ struct SiteIterator{T,TI} <: AbstractIterator
nlist::PairList{T,TI}
end

start{T,TI}(it::SiteIterator{T,TI}) = one(TI)
done(it::SiteIterator, i::Integer) = (i > nsites(it.nlist))
next(it::SiteIterator, i::Integer) = (i, site(it.nlist, i)...), inc(i)
_item(it::SiteIterator, i::Integer) = (i, site(it.nlist, i)...)
iterate(it::SiteIterator{T,TI}) where {T,TI} = _item(it, 1), one(TI)
iterate(it::SiteIterator, i::Integer) =
i >= length(it) ? nothing : (_item(it, i+1), inc(i))
length(it::SiteIterator) = nsites(it.nlist)


Expand Down
Loading

0 comments on commit 5a1da97

Please sign in to comment.