diff --git a/CHANGELOG.md b/CHANGELOG.md index c30ec7d6c..49c0a9eaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -201,6 +201,7 @@ for label, interval_data in results.groupby("interval_labels"): - Fix `FigURLCuration.make`. Postpone fetch of unhashable items #1505 - Improve get_recording efficiency #1522 - Raise error if `FigURLCurationSelection` finds no curation label #1531 + - Allow `CurationV1` to save without any spikes #1533 ## [0.5.5] (Aug 6, 2025) diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py index de30fd53b..97c3bf259 100644 --- a/src/spyglass/spikesorting/v1/curation.py +++ b/src/spyglass/spikesorting/v1/curation.py @@ -346,12 +346,15 @@ def _write_sorting_to_nwb_with_curation( ) as io: nwbf = io.read() units = nwbf.units.to_dataframe() - units_dict = { - unit_id: spike_times - for unit_id, spike_times in zip(units.index, units["spike_times"]) - } + if "spike_times" in units.columns: + units_dict = { + unit_id: spike_times + for unit_id, spike_times in zip(units.index, units["spike_times"]) + } + else: + units_dict = {} - if apply_merge: + if apply_merge and units_dict: for merge_group in merge_groups: new_unit_id = np.max(list(units_dict.keys())) + 1 units_dict[new_unit_id] = np.concatenate( @@ -372,12 +375,16 @@ def _write_sorting_to_nwb_with_curation( ) as io: nwbf = io.read() # write sorting to the nwb file - for unit_id in unit_ids: - # spike_times = sorting.get_unit_spike_train(unit_id) - nwbf.add_unit( - spike_times=units_dict[unit_id], - id=unit_id, + if not unit_ids: + nwbf.units = pynwb.misc.Units( + name="units", description="Empty units table." ) + else: + for unit_id in unit_ids: + nwbf.add_unit( + spike_times=units_dict[unit_id], + id=unit_id, + ) # add labels, merge groups, metrics if labels is not None: label_values = [] diff --git a/tests/spikesorting/v1/test_curation.py b/tests/spikesorting/v1/test_curation.py index 89617ce21..4237e1dcb 100644 --- a/tests/spikesorting/v1/test_curation.py +++ b/tests/spikesorting/v1/test_curation.py @@ -1,7 +1,7 @@ import numpy as np +import pytest from spikeinterface import BaseSorting from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor -import pytest def test_curation_rec(spike_v1, pop_curation): @@ -110,3 +110,164 @@ def test_curation_sort_metric(spike_v1, pop_curation, pop_curation_metric): assert ( sort_metric[k] == expected[k] ), f"CurationV1.get_sort_group_info unexpected value: {k}" + + +# ============================================================================ +# No-Spikes Case Tests (Issue #1532) +# ============================================================================ + + +@pytest.fixture +def empty_units_nwb(tmp_path): + """Create NWB file with empty units table for testing no-spikes case.""" + from datetime import datetime + from uuid import uuid4 + + import pynwb + + nwb_path = tmp_path / "empty_units.nwb" + + nwbfile = pynwb.NWBFile( + session_description="Test session with no spikes", + identifier=str(uuid4()), + session_start_time=datetime.now(), + ) + # Create empty units table (no spike_times column) + nwbfile.units = pynwb.misc.Units( + name="units", description="Empty units table." + ) + + with pynwb.NWBHDF5IO(str(nwb_path), "w") as io: + io.write(nwbfile) + + return nwb_path + + +@pytest.fixture +def curation_mocks(tmp_path): + """Fixture providing mocked dependencies for _write_sorting_to_nwb_with_curation.""" + from unittest.mock import MagicMock, patch + + import pandas as pd + + class CurationMocks: + def __init__(self): + self.tmp_path = tmp_path + self.patches = [] + self.write_nwbf = None + + def setup(self, units_df=None): + """Setup mocks with given units DataFrame.""" + if units_df is None: + units_df = pd.DataFrame() # Empty DataFrame, no spike_times + + mock_nwbf = MagicMock() + mock_nwbf.units.to_dataframe.return_value = units_df + + class _WriteNWBMock(MagicMock): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._units = None + + @property + def units(self): + return self._units + + @units.setter + def units(self, value): + self._units = value + + self.write_nwbf = _WriteNWBMock() + mock_io_read = MagicMock() + mock_io_read.read.return_value = mock_nwbf + mock_io_read.__enter__ = MagicMock(return_value=mock_io_read) + mock_io_read.__exit__ = MagicMock(return_value=False) + + mock_io_write = MagicMock() + mock_io_write.read.return_value = self.write_nwbf + mock_io_write.__enter__ = MagicMock(return_value=mock_io_write) + mock_io_write.__exit__ = MagicMock(return_value=False) + + def mock_nwbhdf5io(*args, **kwargs): + if kwargs.get("mode") == "r" or (args and "r" in str(args)): + return mock_io_read + return mock_io_write + + self.patches = [ + patch( + "spyglass.spikesorting.v1.curation.SpikeSortingSelection", + self._mock_table("test.nwb"), + ), + patch( + "spyglass.spikesorting.v1.curation.SpikeSorting", + self._mock_table("test_analysis.nwb"), + ), + patch( + "spyglass.spikesorting.v1.curation.AnalysisNwbfile", + self._mock_analysis_nwb(), + ), + patch( + "spyglass.spikesorting.v1.curation.pynwb.NWBHDF5IO", + side_effect=mock_nwbhdf5io, + ), + ] + return self + + def _mock_table(self, return_value): + mock = MagicMock() + mock_instance = MagicMock() + mock_instance.fetch1.return_value = return_value + mock.__and__.return_value = mock_instance + return mock + + def _mock_analysis_nwb(self): + mock = MagicMock() + mock_instance = MagicMock() + mock_instance.create.return_value = "new_analysis.nwb" + mock.return_value = mock_instance + mock.get_abs_path.return_value = str(self.tmp_path / "test.nwb") + return mock + + def __enter__(self): + for p in self.patches: + p.__enter__() + return self + + def __exit__(self, *args): + for p in reversed(self.patches): + p.__exit__(*args) + + return CurationMocks() + + +def test_write_sorting_no_spikes(curation_mocks): + """Test _write_sorting_to_nwb_with_curation handles missing spike_times.""" + import pynwb + + from spyglass.spikesorting.v1.curation import ( + _write_sorting_to_nwb_with_curation, + ) + + with curation_mocks.setup(): + result = _write_sorting_to_nwb_with_curation( + sorting_id="test_sorting_id", + labels=None, + merge_groups=[["unit1", "unit2"]], + metrics=None, + apply_merge=True, # Also tests apply_merge guard + ) + + assert result is not None + assert len(result) == 2 + assert isinstance(curation_mocks.write_nwbf._units, pynwb.misc.Units) + + +def test_empty_units_nwb_readable(empty_units_nwb): + """Test that NWB file with empty units table is readable.""" + import pynwb + + with pynwb.NWBHDF5IO(str(empty_units_nwb), "r") as io: + nwbf = io.read() + assert nwbf.units is not None + units_df = nwbf.units.to_dataframe() + assert len(units_df) == 0