Skip to content

Commit

Permalink
Refactor by importing Reactant.TracedRArray
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Oct 29, 2024
1 parent 7b66f1d commit ba15e5e
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module TenetReactantExt
using Tenet
using EinExprs
using Reactant
using Reactant: @reactant_override
using Reactant: @reactant_override, TracedRArray
const MLIR = Reactant.MLIR
const stablehlo = MLIR.Dialects.stablehlo

Expand Down Expand Up @@ -125,7 +125,7 @@ end

function Tenet.contract(
a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}; dims=((inds(a), inds(b))), out=nothing
) where {Ta,Na,Aa<:Reactant.TracedRArray,Tb,Nb,Ab<:Reactant.TracedRArray}
) where {Ta,Na,Aa<:TracedRArray,Tb,Nb,Ab<:TracedRArray}
ia = collect(inds(a))
ib = collect(inds(b))
i = (dims, ia, ib)
Expand Down Expand Up @@ -154,12 +154,12 @@ function Tenet.contract(

result = Reactant.MLIR.IR.result(stablehlo.einsum(op_a, op_b; result_0, einsum_config))

data = Reactant.TracedRArray{T,length(ic)}((), result, rsize)
data = TracedRArray{T,length(ic)}((), result, rsize)
_res = Tensor(data, ic)
return _res
end

function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:Reactant.TracedRArray}
function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:TracedRArray}
ia = inds(a)
i = (dims, ia)

Expand All @@ -178,8 +178,13 @@ function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing)

result = Reactant.MLIR.IR.result(stablehlo.unary_einsum(operand; result_0, einsum_config))

data = Reactant.TracedRArray{T,length(ic)}((), result, rsize)
data = TracedRArray{T,length(ic)}((), result, rsize)
return Tensor(data, ic)
end

Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...)
function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb}
return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...)
end

end

0 comments on commit ba15e5e

Please sign in to comment.