diff --git a/src/abstract_tensor_train.jl b/src/abstract_tensor_train.jl index 594ed4d..1e60991 100644 --- a/src/abstract_tensor_train.jl +++ b/src/abstract_tensor_train.jl @@ -28,6 +28,21 @@ function check_bond_dims(tensors::Vector{<:Array}) return true end +""" + evaluate(A::AbstractTensorTrain, X...) + +Evaluate the Tensor Train `A` at input `X` + +Example: +```@example + L = 3 + q = (2, 3) + A = rand_tt(4, L, q...) + X = [[rand(1:qi) for qi in q] for l in 1:L] + evaluate(A, X) +``` +""" +evaluate(A::AbstractTensorTrain, X...) = tr(prod(@view a[:, :, x...] for (a,x) in zip(A, X...))) """ normalize_eachmatrix!(A::AbstractTensorTrain) diff --git a/src/periodic_tensor_train.jl b/src/periodic_tensor_train.jl index b7f4586..1988a84 100644 --- a/src/periodic_tensor_train.jl +++ b/src/periodic_tensor_train.jl @@ -53,10 +53,6 @@ function rand_periodic_tt(bondsizes::AbstractVector{<:Integer}, q...) end rand_periodic_tt(d::Integer, L::Integer, q...) = rand_periodic_tt(fill(d, L-1), q...) -evaluate(A::PeriodicTensorTrain, X...) = tr(prod(@view a[:, :, x...] for (a,x) in zip(A, X...))) - - - function _compose(f, A::PeriodicTensorTrain{F,NA}, B::PeriodicTensorTrain{F,NB}) where {F,NA,NB} @assert NA == NB diff --git a/src/tensor_train.jl b/src/tensor_train.jl index 3271970..6c851d1 100644 --- a/src/tensor_train.jl +++ b/src/tensor_train.jl @@ -58,23 +58,6 @@ end rand_tt(d::Integer, L::Integer, q...) = rand_tt([1; fill(d, L-1); 1], q...) -""" - evaluate(A::AbstractTensorTrain, X...) - -Evaluate the Tensor Train `A` at input `X` - -Example: -```@example - L = 3 - q = (2, 3) - A = rand_tt(4, L, q...) - X = [[rand(1:qi) for qi in q] for l in 1:L] - evaluate(A, X) -``` -""" -evaluate(A::TensorTrain, X...) = only(prod(@view a[:, :, x...] for (a,x) in zip(A, X...))) - - """ orthogonalize_right!(A::AbstractTensorTrain; svd_trunc::SVDTrunc)