Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion qualtran/bloqs/multiplexers/selected_majorana_fermion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Iterator, Sequence, Tuple, Union
from typing import Dict, Iterator, Sequence, Tuple, Union

import attrs
import cirq
Expand All @@ -25,6 +25,7 @@
from qualtran._infra.data_types import BQUInt
from qualtran._infra.gate_with_registers import total_bits
from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate
from qualtran.simulation.classical_sim import ClassicalValT


@attrs.frozen
Expand Down Expand Up @@ -137,5 +138,43 @@ def nth_operation( # type: ignore[override]
yield self.target_gate(target[target_idx]).controlled_by(control)
yield cirq.CZ(*accumulator, target[target_idx])

def on_classical_vals(self, **vals) -> Dict[str, 'ClassicalValT']:
if self.target_gate != cirq.X and self.target_gate != cirq.Z:
return NotImplemented
if len(self.control_registers) > 1 or len(self.selection_registers) > 1:
return NotImplemented
Comment on lines +144 to +145
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this restriction necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure - it is hard for me to understand what this gate does in the general case. Is my understanding in #1699 (comment) correct?

control_name = self.control_registers[0].name
control = vals[control_name]
selection_name = self.selection_registers[0].name
selection = vals[selection_name]
target = vals['target']

# When target_gate == cirq.X, the action is (modulo phase) a single bitflip.
if control and self.target_gate == cirq.X:
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
target = (2 ** (max_selection - selection)) ^ target
Comment on lines +154 to +155
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment describing how this logic works

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

# When target_gate == cirq.Z, the action is only in the phase.

return {control_name: control, selection_name: selection, 'target': target}

def basis_state_phase(self, **vals) -> Union[complex, None]:
if self.target_gate != cirq.X and self.target_gate != cirq.Z:
return None
if len(self.control_registers) > 1 or len(self.selection_registers) > 1:
return None
control_name = self.control_registers[0].name
control = vals[control_name]
selection_name = self.selection_registers[0].name
selection = vals[selection_name]
target = vals['target']
if control:
max_selection = self.selection_registers[0].dtype.iteration_length_or_zero() - 1
if self.target_gate == cirq.X:
num_phases = (target >> (max_selection - selection + 1)).bit_count()
else:
num_phases = (target >> (max_selection - selection)).bit_count()
return 1 if (num_phases % 2) == 0 else -1
return 1

def __str__(self):
return f'SelectedMajoranaFermion({self.target_gate})'
16 changes: 15 additions & 1 deletion qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion
from qualtran.cirq_interop.testing import GateHelper
from qualtran.testing import assert_valid_bloq_decomposition
from qualtran.testing import (
assert_consistent_phased_classical_action,
assert_valid_bloq_decomposition,
)


@pytest.mark.slow
Expand Down Expand Up @@ -148,3 +151,14 @@ def test_selected_majorana_fermion_gate_make_on():
op = gate.on_registers(**get_named_qubits(gate.signature))
op2 = SelectedMajoranaFermion.make_on(target_gate=cirq.X, **get_named_qubits(gate.signature))
assert op == op2


@pytest.mark.parametrize("selection_bitsize, target_bitsize", [(2, 4), (3, 5)])
@pytest.mark.parametrize("target_gate", [cirq.X, cirq.Z])
def test_selected_majorana_fermion_classical_action(selection_bitsize, target_bitsize, target_gate):
gate = SelectedMajoranaFermion(
Register('selection', BQUInt(selection_bitsize, target_bitsize)), target_gate=target_gate
)
assert_consistent_phased_classical_action(
gate, selection=range(target_bitsize), target=range(2**target_bitsize), control=range(2)
)
27 changes: 27 additions & 0 deletions qualtran/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Side,
)
from qualtran._infra.composite_bloq import _get_flat_dangling_soqs
from qualtran.simulation.classical_sim import do_phased_classical_simulation
from qualtran.symbolics import is_symbolic

if TYPE_CHECKING:
Expand Down Expand Up @@ -716,3 +717,29 @@ def assert_consistent_classical_action(
np.testing.assert_equal(
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
)


def assert_consistent_phased_classical_action(
bloq: Bloq,
**parameter_ranges: Union[NDArray, Sequence[int], Sequence[Union[Sequence[int], NDArray]]],
):
"""Check that the bloq has a phased classical action consistent with its decomposition.

Args:
bloq: bloq to test.
parameter_ranges: named arguments giving ranges for each of the registers of the bloq.
"""
cb = bloq.decompose_bloq()
parameter_names = tuple(parameter_ranges.keys())
for vals in itertools.product(*[parameter_ranges[p] for p in parameter_names]):
call_with = {p: v for p, v in zip(parameter_names, vals)}
bloq_res, bloq_phase = do_phased_classical_simulation(bloq, call_with)
decomposed_res, decomposed_phase = do_phased_classical_simulation(cb, call_with)
np.testing.assert_equal(
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
)
np.testing.assert_equal(
bloq_phase,
decomposed_phase,
err_msg=f'{bloq=} {call_with=} {bloq_phase=} {decomposed_phase=}',
)
Loading