From e009628c091d2b108ec7fae4649bc97ea670eb65 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <67932820+kshyatt-aws@users.noreply.github.com> Date: Fri, 27 Sep 2024 11:25:48 -0400 Subject: [PATCH] fix: Support casting for most types (#54) --- src/Quasar.jl | 19 ++++++++++++++---- test/test_openqasm3.jl | 44 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/Quasar.jl b/src/Quasar.jl index 672f727..8019dbe 100644 --- a/src/Quasar.jl +++ b/src/Quasar.jl @@ -1769,10 +1769,21 @@ function (v::AbstractVisitor)(program_expr::QasmExpression) return evaluate_unary_op(op, arg) elseif head(program_expr) == :cast casting_to = program_expr.args[1].args[1] - value = v(program_expr.args[2]) - if casting_to == Bool - return value > 0 - # TODO + value = v(program_expr.args[2]) + if casting_to == Bool && !(value isa Period) + return value isa BitVector ? any(value) : value > 0 + elseif casting_to isa SizedBitVector && !(value isa Period || value isa AbstractFloat) + new_size = v(casting_to.size) + return value isa BitVector ? value[1:new_size] : BitVector(reverse(digits(value, base=2, pad=new_size))) + elseif casting_to isa SizedNumber && !(value isa Period) + num_value = value isa BitVector ? sum(reverse(value)[k]*2^(k-1) for k=1:length(value)) : value + if casting_to isa SizedInt + return Int(num_value) + elseif casting_to isa SizedUInt + return UInt(num_value) + elseif casting_to isa SizedFloat + return Float64(num_value) + end else throw(QasmVisitorError("unable to evaluate cast expression $program_expr")) end diff --git a/test/test_openqasm3.jl b/test/test_openqasm3.jl index 5fffd09..4e02dbd 100644 --- a/test/test_openqasm3.jl +++ b/test/test_openqasm3.jl @@ -261,6 +261,50 @@ get_tol(shots::Int) = return ( @test simulation.state_vector ≈ sv @test circuit.result_types == [BraketSimulator.Amplitude(["00", "01", "10", "11"])] end + @testset "Casting" begin + @testset "Casting to $to_type from $from_type" for (to_type, to_value) in (("bool", true),), (from_type, from_value) in (("int[32]", "32",), + ("uint[16]", "1",), + ("float", "2.5",), + ("bool", "true",), + ("bit", "\"1\"",), + ) + qasm = """ + $from_type a = $from_value; + $to_type b = $to_type(a); + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + visitor(parsed) + @test visitor.classical_defs["b"].val == to_value + end + @testset "Casting to $to_type from $from_type" for (to_type, to_value) in (("uint[32]", 32), ("int[16]", 32), ("float", 32.0)), (from_type, from_value) in (("int[32]", "32",), + ("uint[16]", "32",), + ("float", "32.0",), + ("bit[6]", "\"100000\"",), + ) + qasm = """ + $from_type a = $from_value; + $to_type b = $to_type(a); + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + visitor(parsed) + @test visitor.classical_defs["b"].val == to_value + end + @testset "Casting to $to_type from $from_type" for (to_type, to_value) in (("bit[6]", BitVector([1,0,0,0,0,0])),), (from_type, from_value) in (("int[32]", "32",), + ("uint[16]", "32",), + ("bit[6]", "\"100000\"",), + ) + qasm = """ + $from_type a = $from_value; + $to_type b = $to_type(a); + """ + parsed = parse_qasm(qasm) + visitor = QasmProgramVisitor() + visitor(parsed) + @test visitor.classical_defs["b"].val == to_value + end + end @testset "Numbers $qasm_str" for (qasm_str, var_name, output_val) in (("float[32] a = 1.24e-3;", "a", 1.24e-3), ("complex[float] b = 1-0.23im;", "b", 1-0.23im), ("const bit c = \"0\";", "c", falses(1)),