Skip to content

Commit

Permalink
simplifying fix for GradientKernelElement
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianAment committed Jun 3, 2022
1 parent 7c95b1c commit c551902
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 66 deletions.
64 changes: 37 additions & 27 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@ version = "2.3.0"
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[deps.ArrayInterface]]
deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "81f0cb60dc994ca17f68d9fb7c942a5ae70d9ee4"
deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "Static"]
git-tree-sha1 = "a24db3a330d0ff64789abd52a26c732805619a53"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "5.0.8"
version = "6.0.5"

[[deps.ArrayInterfaceCore]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "d3a275e927d411e054c4192e5aca03998c233e94"
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
version = "0.1.7"

[[deps.ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra", "SparseArrays"]
Expand Down Expand Up @@ -84,21 +90,21 @@ version = "0.4.2"

[[deps.CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "19fb33957a5f85efb3cc10e70cf4dd4e30174ac9"
git-tree-sha1 = "925a16b909fdae16920c1319feadecffb6695b9d"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.10.0"
version = "3.10.1"

[[deps.ChainRules]]
deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"]
git-tree-sha1 = "de68815ccf15c7d3e3e3338f0bd3a8a0528f9b9f"
git-tree-sha1 = "e9023f88b1655ffc6a4aaef2502878e8116151ef"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.33.0"
version = "1.35.1"

[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f"
git-tree-sha1 = "9489214b993cd42d17f44c36e359bf6a7c919abf"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.14.0"
version = "1.15.0"

[[deps.ChangesOfVariables]]
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
Expand All @@ -114,9 +120,9 @@ version = "0.3.0"

[[deps.Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "b153278a25dd42c65abbf4e62344f9d22e59191b"
git-tree-sha1 = "87e84b2293559571802f97dd9c94cfd6be52c5e5"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.43.0"
version = "3.44.0"

[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down Expand Up @@ -146,9 +152,9 @@ version = "1.10.0"

[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "cc1a8e22627f33c789ab60b36a9132ac050bbf75"
git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.12"
version = "0.18.13"

[[deps.DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
Expand Down Expand Up @@ -237,9 +243,9 @@ version = "0.13.2"

[[deps.Flux]]
deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"]
git-tree-sha1 = "f84e50845ab88702c721dc7c6129a85cbc1de332"
git-tree-sha1 = "62350a872545e1369b1d8f11358a21681aa73929"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.13.1"
version = "0.13.3"

[[deps.FoldsThreads]]
deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
Expand Down Expand Up @@ -317,9 +323,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[deps.Intervals]]
deps = ["Dates", "Printf", "RecipesBase", "Serialization", "TimeZones"]
git-tree-sha1 = "b993074580045d1551d30990dc0fa5ba6feef92b"
git-tree-sha1 = "1fd6fccdbdccee5997fb245289d98386c8996180"
uuid = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
version = "1.6.0"
version = "1.7.0"

[[deps.InverseFunctions]]
deps = ["Test"]
Expand Down Expand Up @@ -363,9 +369,9 @@ version = "1.1.1"

[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "c8d47589611803a0f3b4813d9e267cd4e3dbcefb"
git-tree-sha1 = "10a20c556107dc5833d3bb7c5e45c4a6e191bd28"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "4.11.1"
version = "4.13.0"

[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
Expand Down Expand Up @@ -428,9 +434,9 @@ uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
version = "2022.0.0+0"

[[deps.MLStyle]]
git-tree-sha1 = "e49789e5eb7b2d5577aaea395bfcac769df64bb8"
git-tree-sha1 = "2041c1fd6833b3720d363c3ea8140bffaf86d9c4"
uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
version = "0.4.11"
version = "0.4.12"

[[deps.MLUtils]]
deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"]
Expand Down Expand Up @@ -506,9 +512,9 @@ version = "0.8.5"

[[deps.NNlibCUDA]]
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
git-tree-sha1 = "0d18b4c80a92a00d3d96e8f9677511a7422a946e"
git-tree-sha1 = "e161b835c6aa9e2339c1e72c3d4e39891eac7a4f"
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
version = "0.2.2"
version = "0.2.3"

[[deps.NaNMath]]
git-tree-sha1 = "737a5957f387b17e74d4ad2f440eb330b39a62c5"
Expand Down Expand Up @@ -546,9 +552,9 @@ version = "0.5.5+0"

[[deps.Optimisers]]
deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "2442c3ddbda547c80e8b6451a103719d6a3593dd"
git-tree-sha1 = "26f58049054343c8103d67a5530284a35f1186cb"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.2.4"
version = "0.2.5"

[[deps.OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
Expand Down Expand Up @@ -682,9 +688,9 @@ version = "0.1.14"

[[deps.Static]]
deps = ["IfElse"]
git-tree-sha1 = "3a2a99b067090deb096edecec1dc291c5b4b31cb"
git-tree-sha1 = "5d2c08cef80c7a3a8ba9ca023031a85c263012c5"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.6.5"
version = "0.6.6"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
Expand All @@ -708,6 +714,10 @@ git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.16"

[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[deps.SymEngine]]
deps = ["Compat", "Libdl", "LinearAlgebra", "RecipesBase", "SpecialFunctions", "SymEngine_jll"]
git-tree-sha1 = "6cf88a0b98c758a36e6e978a41e8a12f6f5cdacc"
Expand Down
54 changes: 15 additions & 39 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,27 @@ Base.eltype(K::GradientKernelElement{T}) where T = T
# gradient kernel element only used for sparsely representable elements
Base.Matrix(K::GradientKernelElement) = K * I(size(K, 1))

function GradientKernelElement{T}(k, x, y, it::InputTrait) where T
GradientKernelElement{T, typeof(k), typeof(x), typeof(y), typeof(it)}(k, x, y, it)
end

function gradient_kernel(k, x, y, it::InputTrait)
T = gramian_eltype(k, x, y)
GradientKernelElement{T}(k, x, y, it)
end

function gradient_kernel!(K::GradientKernelElement, k, x, y, it::InputTrait)
GradientKernelElement{eltype(K)}(k, x, y, it)
end

function Base.:*(G::GradientKernelElement, a)
T = promote_type(eltype(G), eltype(a))
b = zeros(T, size(a))
mul!(b, G, a)
end

const GenericGradientKernelElement{T, K, X, Y} = GradientKernelElement{T, K, X, Y, <:GenericInput}

const IsotropicGradientKernelElement{T, K, X, Y} = GradientKernelElement{T, K, X, Y, IsotropicInput}
function IsotropicGradientKernelElement{T}(k, x, y) where T
IsotropicGradientKernelElement{T, typeof(k), typeof(x), typeof(y)}(k, x, y, IsotropicInput())
end

# isotropic kernel
function LinearAlgebra.mul!(b, G::IsotropicGradientKernelElement, a, α::Number = 1, β::Number = 0) #, ::IsotropicInput = G.input_trait)
Expand All @@ -95,19 +104,8 @@ function WoodburyFactorizations.Woodbury(K::IsotropicGradientKernelElement)
return K = Woodbury(D, r, C, r')
end

function gradient_kernel!(K::IsotropicGradientKernelElement, k, x, y, ::IsotropicInput)
typeof(K)(k, x, y, IsotropicInput())
end

function gradient_kernel(k, x, y, ::IsotropicInput)
T = gramian_eltype(k, x, y)
IsotropicGradientKernelElement{T}(k, x, y)
end

const DotProductGradientKernelElement{T, K, X, Y} = GradientKernelElement{T, K, X, Y, DotProductInput}
function DotProductGradientKernelElement{T}(k, x, y) where T
DotProductGradientKernelElement{T, typeof(k), typeof(x), typeof(y)}(k, x, y, DotProductInput())
end

function LinearAlgebra.mul!(b, K::DotProductGradientKernelElement, a, α::Number = 1, β::Number = 0)
k, x, y = K.k, K.x, K.y
= dot(x, y)
Expand All @@ -126,19 +124,7 @@ function WoodburyFactorizations.Woodbury(K::DotProductGradientKernelElement)
return K = Woodbury(D, copy(y), C, copy(x)')
end

function gradient_kernel!(K::DotProductGradientKernelElement, k, x, y, ::DotProductInput)
typeof(K)(k, x, y, DotProductInput())
end

function gradient_kernel(k, x, y, ::DotProductInput)
T = gramian_eltype(k, x, y)
DotProductGradientKernelElement{T}(k, x, y)
end

const LinearFunctionalGradientKernelElement{T, K, X, Y} = GradientKernelElement{T, K, X, Y, StationaryLinearFunctionalInput}
function LinearFunctionalGradientKernelElement{T}(k, x, y) where T
LinearFunctionalGradientKernelElement{T, typeof(k), typeof(x), typeof(y)}(k, x, y, StationaryLinearFunctionalInput())
end

function LinearAlgebra.mul!(b, K::LinearFunctionalGradientKernelElement, a, α::Number = 1, β::Number = 0)
k, x, y = K.k, K.x, K.y
Expand All @@ -162,16 +148,6 @@ function LazyMatrixProduct(K::LinearFunctionalGradientKernelElement)
return LazyMatrixProduct(c, c2')
end

# is this necessary?
function gradient_kernel!(K::LinearFunctionalGradientKernelElement, k, x, y, ::StationaryLinearFunctionalInput)
typeof(K)(k, x, y, StationaryLinearFunctionalInput())
end

function gradient_kernel(k, x, y, ::StationaryLinearFunctionalInput)
T = gramian_eltype(k, x, y)
LinearFunctionalGradientKernelElement{T}(k, x, y)
end

function evaluate_block!(Gij, k::GradientKernel, x, y, IT = input_trait(k))
gradient_kernel!(Gij, k.k, x, y, IT)
end
Expand Down Expand Up @@ -420,7 +396,7 @@ end
################################################################################
# [f, ∂f] ∼ GP([μ, ∂μ], dK) # value + gradient kernel
# IDEA: For efficiency, maybe create ValueGradientKernelElement like in hessian.jl
# currently, this is an order of magnitude slower than GradientKernel
# might not be necessary anymore, benchmark against GradientKernel
struct ValueGradientKernel{T, K, IT<:InputTrait} <: AbstractDerivativeKernel{T, K}
k::K
input_trait::IT
Expand Down
1 change: 1 addition & 0 deletions src/gradient_algebra.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
############################ Gradient Algebra ##################################
# IDEA: could have specialization for gradient kernels of Power kernels of composite kernels
################################### Sum ########################################
# allocates space for gradient kernel evaluation but does not evaluate
# the separation from evaluation is useful for ValueGradientKernel
Expand Down

0 comments on commit c551902

Please sign in to comment.