Skip to content

Commit

Permalink
generalize evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
abraunst authored and stecrotti committed Oct 19, 2023
1 parent 1a58fc8 commit 00061b4
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 21 deletions.
15 changes: 15 additions & 0 deletions src/abstract_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@ function check_bond_dims(tensors::Vector{<:Array})
return true
end

"""
evaluate(A::AbstractTensorTrain, X...)
Evaluate the Tensor Train `A` at input `X`
Example:
```@example
L = 3
q = (2, 3)
A = rand_tt(4, L, q...)
X = [[rand(1:qi) for qi in q] for l in 1:L]
evaluate(A, X)
```
"""
evaluate(A::AbstractTensorTrain, X...) = tr(prod(@view a[:, :, x...] for (a,x) in zip(A, X...)))

"""
normalize_eachmatrix!(A::AbstractTensorTrain)
Expand Down
4 changes: 0 additions & 4 deletions src/periodic_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ function rand_periodic_tt(bondsizes::AbstractVector{<:Integer}, q...)
end
rand_periodic_tt(d::Integer, L::Integer, q...) = rand_periodic_tt(fill(d, L-1), q...)

evaluate(A::PeriodicTensorTrain, X...) = tr(prod(@view a[:, :, x...] for (a,x) in zip(A, X...)))




function _compose(f, A::PeriodicTensorTrain{F,NA}, B::PeriodicTensorTrain{F,NB}) where {F,NA,NB}
@assert NA == NB
Expand Down
17 changes: 0 additions & 17 deletions src/tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,6 @@ end
rand_tt(d::Integer, L::Integer, q...) = rand_tt([1; fill(d, L-1); 1], q...)


"""
evaluate(A::AbstractTensorTrain, X...)
Evaluate the Tensor Train `A` at input `X`
Example:
```@example
L = 3
q = (2, 3)
A = rand_tt(4, L, q...)
X = [[rand(1:qi) for qi in q] for l in 1:L]
evaluate(A, X)
```
"""
evaluate(A::TensorTrain, X...) = only(prod(@view a[:, :, x...] for (a,x) in zip(A, X...)))


"""
orthogonalize_right!(A::AbstractTensorTrain; svd_trunc::SVDTrunc)
Expand Down

0 comments on commit 00061b4

Please sign in to comment.