Skip to content

Commit

Permalink
fix symbolic abs function
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Sep 25, 2023
1 parent cf5b36d commit 57f6fff
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 106 deletions.
4 changes: 3 additions & 1 deletion lib/YaoArrayRegister/src/instruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,10 @@ function YaoAPI.instruct!(::Val{2},
::Val{:H},
locs::NTuple{N,Int},
) where {T, N}
instruct!(Val(2), state, matchtype(T, YaoArrayRegister.Const.H), locs)
instruct!(Val(2), state, _hadamard_matrix(T), locs)
end
# this is an interface for future extension (e.g. in YaoSym)
_hadamard_matrix(::Type{T}) where T = matchtype(T, YaoArrayRegister.Const.H)

function YaoAPI.instruct!(::Val{2},
state::AbstractVecOrMat{T},
Expand Down
8 changes: 3 additions & 5 deletions lib/YaoSym/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,21 @@ version = "0.6.6"
BitBasis = "50ba71b6-fa0f-514d-ae9a-0916efc90dcf"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxurySparse = "d05aeea4-b7d4-55ac-b691-9e7fabb07ba2"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8"
YaoArrayRegister = "e600142f-9330-5003-8abb-0ebd767abc51"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[compat]
BitBasis = "0.8"
LuxurySparse = "0.7"
Requires = "1"
SymEngine = "0.8"
SymEngine = "0.8, 0.9, 0.10"
YaoArrayRegister = "0.9"
YaoBlocks = "0.13"
julia = "1"

[extras]
SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SymEngine"]
test = ["Test"]
19 changes: 14 additions & 5 deletions lib/YaoSym/src/YaoSym.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
module YaoSym

using Requires
using SparseArrays, LuxurySparse, LinearAlgebra
using BitBasis, YaoArrayRegister, YaoBlocks
import YaoArrayRegister: parametric_mat
using SymEngine
using SymEngine: @vars, Basic, BasicType, BasicOp, BasicTrigFunction, BasicComplexNumber

include("register.jl")
# SymEngine APIs
export @vars, Basic, subs, expand, simplify_expi

# Symbolic registers
export @ket_str, @bra_str
export SymReg, AdjointSymReg, SymRegOrAdjointSymReg
export szero_state

function __init__()
@require SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" include("symengine/backend.jl")
end
include("register.jl")
include("symengine/symengine.jl")

end # module
3 changes: 0 additions & 3 deletions lib/YaoSym/src/register.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
using SparseArrays, BitBasis, YaoArrayRegister
export @ket_str, @bra_str

function parse_str(s::String)
v = 0
k = 1
Expand Down
120 changes: 47 additions & 73 deletions lib/YaoSym/src/symengine/blocks.jl
Original file line number Diff line number Diff line change
@@ -1,63 +1,39 @@
using YaoBlocks
using LuxurySparse
using LinearAlgebra
using ..SymEngine
using ..SymEngine: BasicType, BasicOp, BasicTrigFunction

op_types = [:Mul, :Add, :Pow]
const BiVarOp = Union{[SymEngine.BasicType{Val{i}} for i in op_types]...}

export @vars

simag = SymFunction("Im")
sreal = SymFunction("Re")
sabs = SymFunction("abs")
sangle = SymFunction("angle")

Base.promote_rule(::Type{Bool}, ::Type{Basic}) = Basic
# NOTE: need to annotate the output because otherwise it is type unstable!
Base.conj(x::Basic)::Basic = Basic(conj(SymEngine.BasicType(x)))
Base.conj(x::BasicType) = real(x) - im * imag(x)
Base.abs(x::Basic)::Basic = Basic(abs(SymEngine.BasicType(x)))
#Base.abs(x::BasicComplexNumber)::Basic = sqrt(real(x)^2 + imag(x)^2)
# abs(a) ^ real(b) * exp(-angle(a) * imag(b))
function Base.abs(x::BasicType{Val{:Pow}})
a, b = get_args(x.x)
abs(a)^real(b) * exp(-angle(a) * imag(b))
end

Base.angle(x::Basic)::Basic = Basic(angle(SymEngine.BasicType(x)))
#Base.angle(x::BasicComplexNumber)::Basic = atan(imag(x), real(x))
Base.angle(x::BasicType)::Basic = sangle(x.x)
Base.angle(::BasicType{Val{:Symbol}})::Basic = Basic(0)
Base.angle(::BasicType{Val{:Constant}})::Basic = Basic(0)

Base.conj(x::BiVarOp) = juliafunc(x)(conj.(get_args(x.x))...)
Base.conj(x::BasicTrigFunction) = juliafunc(x)(conj.(get_args(x.x)...)...)
# WARNING: symbols and constants are assumed real!
Base.imag(x::BasicType{Val{:Constant}}) = Basic(0)
Base.imag(x::BasicType{Val{:Symbol}}) = Basic(0)
Base.abs(x::Basic) = Basic(abs(SymEngine.BasicType(x)))
Base.abs(x::BasicType{Val{:Constant}}) = x
Base.abs(x::BasicType{Val{:Symbol}}) = x

Base.imag(::BasicType{Val{:Constant}}) = Basic(0)
Base.imag(::BasicType{Val{:Symbol}}) = Basic(0)
function Base.imag(x::BasicType{Val{:Add}})
args = get_args(x.x)
mapreduce(imag, +, args)
end

function Base.real(x::BasicType{Val{:Add}})
args = get_args(x.x)
mapreduce(real, +, args)
end

function Base.abs(x::BasicType{Val{:Add}})
args = get_args(x.x)
mapreduce(abs, +, args)
end

function Base.imag(x::BasicType{Val{:Mul}})
args = (get_args(x.x)...,)
get_mul_imag(args)
end

function Base.real(x::BasicType{Val{:Pow}})
a, b = get_args(x.x)
if imag(a) == 0 && imag(b) == 0
return x.x
else
if imag(a) == 0
return a^real(b) * cos(log(a) * imag(b))
else
return sreal(x.x)
end
end
end

function Base.imag(x::BasicType{Val{:Pow}})
a, b = get_args(x.x)
if imag(a) == 0 && imag(b) == 0
Expand All @@ -70,53 +46,51 @@ function Base.imag(x::BasicType{Val{:Pow}})
end
end
end

function Base.abs(x::BasicType{Val{:Pow}})
a, b = get_args(x.x)
abs(a)^real(b)
end

function Base.real(x::BasicType{Val{:Mul}})
args = (get_args(x.x)...,)
get_mul_real(args)
function Base.imag(x::BasicTrigFunction)
a, = get_args(x.x)
if imag(a) == 0
return Basic(0)
else
return simag(x.x)
end
end

function get_mul_imag(args::NTuple{N,Any}) where {N}
imag(args[1]) * get_mul_real(args[2:end]) + real(args[1]) * get_mul_imag(args[2:end])
end
get_mul_imag(args::Tuple{Basic}) = imag(args[1])

function get_mul_real(args::NTuple{N,Any}) where {N}
real(args[1]) * get_mul_real(args[2:end]) - imag(args[1]) * get_mul_imag(args[2:end])
function Base.real(x::BasicType{Val{:Add}})
args = get_args(x.x)
mapreduce(real, +, args)
end
get_mul_real(args::Tuple{Basic}) = real(args[1])

function Base.real(x::BasicTrigFunction)
a, = get_args(x.x)
if imag(a) == 0
function Base.real(x::BasicType{Val{:Pow}})
a, b = get_args(x.x)
if imag(a) == 0 && imag(b) == 0
return x.x
else
return sreal(x.x)
if imag(a) == 0
return a^real(b) * cos(log(a) * imag(b))
else
return sreal(x.x)
end
end
end

function Base.abs(x::BasicTrigFunction)
function Base.real(x::BasicType{Val{:Mul}})
args = (get_args(x.x)...,)
get_mul_real(args)
end
function Base.real(x::BasicTrigFunction)
a, = get_args(x.x)
if imag(a) == 0
return x.x
else
return sabs(x.x)
return sreal(x.x)
end
end

function Base.imag(x::BasicTrigFunction)
a, = get_args(x.x)
if imag(a) == 0
return Basic(0)
else
return simag(x.x)
end
function get_mul_real(args::NTuple{N,Any}) where {N}
real(args[1]) * get_mul_real(args[2:end]) - imag(args[1]) * get_mul_imag(args[2:end])
end
get_mul_real(args::Tuple{Basic}) = real(args[1])

@generated function juliafunc(x::BasicType{Val{T}}) where {T}
SymEngine.map_fn(T, SymEngine.fn_map)
Expand All @@ -132,7 +106,8 @@ YaoBlocks.shift(θ::SymReal) = ShiftGate(θ)
YaoBlocks.mat(::Type{Basic}, gate::GT) where GT<:ConstantGate = _pretty_basic.(mat(gate))
YaoBlocks.mat(::Type{Basic}, gate::ConstGate.TGate) = Diagonal(Basic[1, exp(Basic(im)*Basic(π)/4)])
YaoBlocks.mat(::Type{Basic}, gate::ConstGate.TdagGate) = Diagonal(Basic[1, exp(-Basic(im)*Basic(π)/4)])
YaoBlocks.mat(::Type{Basic}, ::HGate) = 1 / sqrt(Basic(2)) * Basic[1 1; 1 -1]
YaoArrayRegister._hadamard_matrix(::Type{Basic}) = 1 / sqrt(Basic(2)) * Basic[1 1; 1 -1]
YaoBlocks.mat(::Type{Basic}, ::HGate) = YaoArrayRegister._hadamard_matrix(Basic)
YaoBlocks.mat(::Type{Basic}, gate::ShiftGate) = Diagonal([1, exp(im * gate.theta)])
YaoBlocks.mat(::Type{Basic}, gate::PhaseGate) = exp(im * gate.theta) * IMatrix(2)
function YaoBlocks.mat(::Type{Basic}, R::RotationGate{D}) where {D}
Expand All @@ -154,7 +129,6 @@ YaoBlocks.PSwap(n::Int, locs::Tuple{Int,Int}, θ::SymReal) =
YaoBlocks.pswap(n::Int, i::Int, j::Int, α::SymReal) = PSwap(n, (i, j), α)
YaoBlocks.pswap(i::Int, j::Int, α::SymReal) = n -> pswap(n, i, j, α)

export subs
SymEngine.subs(c::AbstractBlock, args...; kwargs...) = subs(Basic, c, args...; kwargs...)
function SymEngine.subs(::Type{T}, c::AbstractBlock, args...; kwargs...) where {T}
c = setiparams(c, map(x -> T(subs(x, args...; kwargs...)), getiparams(c))...)
Expand Down
3 changes: 0 additions & 3 deletions lib/YaoSym/src/symengine/instruct.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
using ..SymEngine
import YaoArrayRegister: parametric_mat

parametric_mat(::Type{T}, ::Val{:Rx}, theta::Basic) where {T} =
Basic[cos(theta / 2) -im*sin(theta / 2); -im*sin(theta / 2) cos(theta / 2)]
parametric_mat(::Type{T}, ::Val{:Ry}, theta::Basic) where {T} =
Expand Down
3 changes: 0 additions & 3 deletions lib/YaoSym/src/symengine/patch.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
using ..SymEngine: Basic
export simplify_expi

function Base.iszero(x::Basic)
isempty(free_symbols(x)) && iszero(N(x))
end
Expand Down
8 changes: 1 addition & 7 deletions lib/YaoSym/src/symengine/register.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
using SparseArrays, BitBasis, YaoArrayRegister
using ..SymEngine
export @ket_str, @bra_str
export SymReg, AdjointSymReg, SymRegOrAdjointSymReg, expand
export szero_state

YaoArrayRegister._warn_type(raw::AbstractMatrix{Basic}) = nothing

const SymReg{D,MT} = AbstractArrayReg{D,Basic,MT} where {MT<:AbstractMatrix{Basic}}
Expand Down Expand Up @@ -34,7 +28,7 @@ function _pretty_basic(x::Complex)
end
end

function ket_m(s)
function ket_m(s::AbstractString)
v, N = parse_str(s)
st = spzeros(Basic, 1 << N, 1)
st[v+1] = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
using ..SymEngine
using ..SymEngine: @vars, Basic
export @vars, Basic, subs

include("register.jl")
include("instruct.jl")
include("blocks.jl")
Expand Down
2 changes: 0 additions & 2 deletions lib/YaoSym/test/symengine/blocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ using Test
@test real(exp(im * a)) == cos(a)
@test imag(exp(im * a)) == sin(a)
@test abs(exp(im * a)) == 1
@test abs(sin(a)) == sin(a)
@test abs(sin(a) + 2) == sin(a) + 2
end

@testset "mat" begin
Expand Down

0 comments on commit 57f6fff

Please sign in to comment.