Skip to content

Commit

Permalink
Merge pull request #299 from quantumlib/numpy-params
Browse files Browse the repository at this point in the history
Improve parameter validation
  • Loading branch information
95-martin-orion authored Mar 3, 2021
2 parents d1f7680 + 83f2c8e commit f0df655
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
21 changes: 17 additions & 4 deletions qsimcirq/qsim_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
from typing import Dict, Union


# List of parameter names that appear in valid Cirq protos.
GATE_PARAMS = [
'exponent', 'phase_exponent', 'global_shift',
'x_exponent', 'z_exponent', 'axis_phase_exponent',
'phi', 'theta'
]


def _cirq_gate_kind(gate: cirq.ops.Gate):
if isinstance(gate, cirq.ops.ControlledGate):
return _cirq_gate_kind(gate.sub_gate)
Expand Down Expand Up @@ -215,10 +223,15 @@ def add_op_to_circuit(
else:
qsim.add_matrix_gate_channel(time, qsim_qubits, m, circuit)
else:
params = {
p.strip('_'): val for p, val in vars(qsim_gate).items()
if isinstance(val, float) or isinstance(val, int)
}
params = {}
for p, val in vars(qsim_gate).items():
key = p.strip('_')
if key not in GATE_PARAMS:
continue
if isinstance(val, (int, float, np.integer, np.floating)):
params[key] = val
else:
raise ValueError('Parameters must be numeric.')
if isinstance(circuit, qsim.Circuit):
qsim.add_gate(gate_kind, time, qsim_qubits, params, circuit)
else:
Expand Down
23 changes: 23 additions & 0 deletions qsimcirq_tests/qsimcirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_cirq_qsim_simulate_sweep(mode: str):
assert cirq.linalg.allclose_up_to_global_phase(
qsim_result[i].state_vector(), cirq_result[i].state_vector())


def test_input_vector_validation():
cirq_circuit = cirq.Circuit(
cirq.X(cirq.LineQubit(0)), cirq.X(cirq.LineQubit(1))
Expand All @@ -192,6 +193,28 @@ def test_input_vector_validation():
cirq_circuit, params, initial_state=initial_state)


def test_numpy_params():
q0 = cirq.LineQubit(0)
x, y = sympy.Symbol('x'), sympy.Symbol('y')
circuit = cirq.Circuit(cirq.X(q0) ** x, cirq.H(q0) ** y)
prs = [{x: np.int64(0), y: np.int64(1)}, {x: np.int64(1), y: np.int64(0)}]

qsim_simulator = qsimcirq.QSimSimulator()
qsim_result = qsim_simulator.simulate_sweep(circuit, params=prs)


def test_invalid_params():
# Parameters must have numeric values.
q0 = cirq.LineQubit(0)
x, y = sympy.Symbol('x'), sympy.Symbol('y')
circuit = cirq.Circuit(cirq.X(q0) ** x, cirq.H(q0) ** y)
prs = [{x: np.int64(0), y: np.int64(1)}, {x: np.int64(1), y: 'z'}]

qsim_simulator = qsimcirq.QSimSimulator()
with pytest.raises(ValueError, match='Parameters must be numeric'):
_ = qsim_simulator.simulate_sweep(circuit, params=prs)


@pytest.mark.parametrize('mode', ['noiseless', 'noisy'])
def test_cirq_qsim_run(mode: str):
# Pick qubits.
Expand Down

0 comments on commit f0df655

Please sign in to comment.