AD rules that apply to KroneckerProducts #92
Comments
The question might be what is fixed and what you want to compute the derivative of. I originally conceived Kronecker to work with systems such as … Taking the gradients of the Kronecker matrix itself would be a … Maybe the … I have been working with ChainRulesCore, so you might open a PR and we can look at it together?
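As a point of reference, here is a rough sketch of what the reverse-mode pullback of a dense `kron` would compute, assuming `A` is `m×n` and `B` is `p×q`. This is my own illustration, not Kronecker.jl's actual rule; inside a `ChainRulesCore.rrule` this would be the body of the pullback closure (with `NoTangent()` prepended for `kron` itself):

```julia
# Sketch of the pullback of Y = kron(A, B) for dense matrices.
# Given the cotangent Ȳ of Y, return the cotangents Ā and B̄.
function kron_pullback(Ȳ, A, B)
    m, n = size(A)
    p, q = size(B)
    # kron(A, B)[(i-1)p + k, (j-1)q + l] == A[i, j] * B[k, l], so a
    # column-major reshape exposes the block structure as (k, i, l, j).
    Ȳr = reshape(Ȳ, p, m, q, n)
    Ā = [sum(Ȳr[k, i, l, j] * B[k, l] for k in 1:p, l in 1:q) for i in 1:m, j in 1:n]
    B̄ = [sum(Ȳr[k, i, l, j] * A[i, j] for i in 1:m, j in 1:n) for k in 1:p, l in 1:q]
    return Ā, B̄
end
```

For a lazy `KroneckerProduct` one would presumably keep `Ā` and `B̄` as the tangents of the two factors rather than materialising the full product.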
It doesn't appear to be quite that straightforward. Care must be taken in setting the appropriate size of …
I've managed to put together a semi-working example with the eager …:

```julia
using LinearAlgebra
using Random
using Zygote

M, N = 3, 2
n_samples = 3

Random.seed!(0)
A = rand(1, N)
B = rand(1, M)
x = rand(M * N, n_samples)
y = rand(n_samples)

model(A, B, X) = kron(A, B) * X

function loss(A, B, X)
    Z = model(A, B, X) - y'
    L = 0.5 * Z * Z'
    return L[1]
end

function gradient_A(A, B, x)
    Z = model(A, B, x) - y'
    n = size(A, 2)
    IA_col = Diagonal(ones(n))
    return Z * (kron(IA_col', B) * x)'
end

function gradient_B(A, B, x)
    Z = model(A, B, x) - y'
    n = size(B, 2)
    IB_col = Diagonal(ones(n))
    return Z * (kron(A, IB_col) * x)'
end

# Compare the hand-written gradients with Zygote.gradient on the loss function.
@assert gradient_A(A, B, x) ≈ gradient(loss, A, B, x)[1]
@assert gradient_B(A, B, x) ≈ gradient(loss, A, B, x)[2]

# Show the partial derivatives of the loss w.r.t. the Kronecker factors.
@show gradient(loss, A, B, x)[1:2]
```
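As an additional sanity check that doesn't involve Zygote at all, the analytic gradient with respect to `A` can be compared against central finite differences. This is a self-contained sketch of mine that repeats the setup above; `fd_gradient_A` is a hypothetical helper name:

```julia
using LinearAlgebra
using Random

# Same setup as the example above.
Random.seed!(0)
M, N, n_samples = 3, 2, 3
A, B = rand(1, N), rand(1, M)
x, y = rand(M * N, n_samples), rand(n_samples)

# Scalar loss as a function of A alone (B, x, y captured).
loss_A(A) = 0.5 * sum(abs2, kron(A, B) * x .- y')

# Hypothetical helper: central finite differences of loss_A w.r.t. each entry of A.
function fd_gradient_A(A; h = 1e-6)
    G = similar(A)
    for j in eachindex(A)
        Ap, Am = copy(A), copy(A)
        Ap[j] += h
        Am[j] -= h
        G[j] = (loss_A(Ap) - loss_A(Am)) / (2h)
    end
    return G
end

# Analytic gradient, as in gradient_A above: Z * (kron(I_N, B) * x)'.
Z = kron(A, B) * x .- y'
grad_A = Z * (kron(Diagonal(ones(N)), B) * x)'
```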
What did you have in mind for this?
(Related to #11)
I'm trying to wrap my head around getting gradients with kron/kronecker.