From ba15e5ef696c8e4ce0439cd05917f3df40f9a89f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 6 Oct 2024 17:05:42 -0400 Subject: [PATCH] Refactor by importing `Reactant.TracedRArray` --- ext/TenetReactantExt.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index f7816f8d..f75a271a 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -3,7 +3,7 @@ module TenetReactantExt using Tenet using EinExprs using Reactant -using Reactant: @reactant_override +using Reactant: @reactant_override, TracedRArray const MLIR = Reactant.MLIR const stablehlo = MLIR.Dialects.stablehlo @@ -125,7 +125,7 @@ end function Tenet.contract( a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}; dims=(∩(inds(a), inds(b))), out=nothing -) where {Ta,Na,Aa<:Reactant.TracedRArray,Tb,Nb,Ab<:Reactant.TracedRArray} +) where {Ta,Na,Aa<:TracedRArray,Tb,Nb,Ab<:TracedRArray} ia = collect(inds(a)) ib = collect(inds(b)) i = ∩(dims, ia, ib) @@ -154,12 +154,12 @@ function Tenet.contract( result = Reactant.MLIR.IR.result(stablehlo.einsum(op_a, op_b; result_0, einsum_config)) - data = Reactant.TracedRArray{T,length(ic)}((), result, rsize) + data = TracedRArray{T,length(ic)}((), result, rsize) _res = Tensor(data, ic) return _res end -function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:Reactant.TracedRArray} +function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:TracedRArray} ia = inds(a) i = ∩(dims, ia) @@ -178,8 +178,13 @@ function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) result = Reactant.MLIR.IR.result(stablehlo.unary_einsum(operand; result_0, einsum_config)) - data = Reactant.TracedRArray{T,length(ic)}((), result, rsize) + data = TracedRArray{T,length(ic)}((), result, rsize) return Tensor(data, ic) end +Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...) +function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb} + return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...) +end + end