Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ for label, interval_data in results.groupby("interval_labels"):
- Log expected recompute failures #1470
- Track file created/deletion status of recomputes #1470
- Upgrade to pynwb>=3.1 #1506
- Remove imports of ndx extensions in main package to prevent errors in nwb io #1506
- Remove imports of ndx extensions in main package to prevent errors in nwb io
#1506
- Add `analysis_table` property to mixin for custom pipelines #1525

### Pipelines
Expand Down Expand Up @@ -199,6 +200,7 @@ for label, interval_data in results.groupby("interval_labels"):
- Implement short-transaction `SpikeSortingRecording.make` for v0 #1338
- Fix `FigURLCuration.make`. Postpone fetch of unhashable items #1505
- Improve get_recording efficiency #1522
- Allow `CurationV1` to save without any spikes #1533

## [0.5.5] (Aug 6, 2025)

Expand Down
28 changes: 18 additions & 10 deletions src/spyglass/spikesorting/v1/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,15 @@ def _write_sorting_to_nwb_with_curation(
) as io:
nwbf = io.read()
units = nwbf.units.to_dataframe()
units_dict = {
unit_id: spike_times
for unit_id, spike_times in zip(units.index, units["spike_times"])
}
if "spike_times" in units.columns:
units_dict = {
unit_id: spike_times
for unit_id, spike_times in zip(units.index, units["spike_times"])
}
else:
units_dict = {}

if apply_merge:
if apply_merge and units_dict:
for merge_group in merge_groups:
new_unit_id = np.max(list(units_dict.keys())) + 1
units_dict[new_unit_id] = np.concatenate(
Expand All @@ -372,12 +375,17 @@ def _write_sorting_to_nwb_with_curation(
) as io:
nwbf = io.read()
# write sorting to the nwb file
for unit_id in unit_ids:
# spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=units_dict[unit_id],
id=unit_id,
if not unit_ids:
nwbf.units = pynwb.misc.Units(
name="units", description="Empty units table."
)
else:
for unit_id in unit_ids:
# spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=units_dict[unit_id],
id=unit_id,
)
# add labels, merge groups, metrics
if labels is not None:
label_values = []
Expand Down
163 changes: 162 additions & 1 deletion tests/spikesorting/v1/test_curation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
from spikeinterface import BaseSorting
from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor
import pytest


def test_curation_rec(spike_v1, pop_curation):
Expand Down Expand Up @@ -110,3 +110,164 @@ def test_curation_sort_metric(spike_v1, pop_curation, pop_curation_metric):
assert (
sort_metric[k] == expected[k]
), f"CurationV1.get_sort_group_info unexpected value: {k}"


# ============================================================================
# No-Spikes Case Tests (Issue #1532)
# ============================================================================


@pytest.fixture
def empty_units_nwb(tmp_path):
    """Write an NWB file whose units table has no rows and no spike_times.

    Returns the path to the file, for exercising the no-spikes code path.
    """
    from datetime import datetime
    from uuid import uuid4

    import pynwb

    out_path = tmp_path / "empty_units.nwb"

    nwb = pynwb.NWBFile(
        session_description="Test session with no spikes",
        identifier=str(uuid4()),
        session_start_time=datetime.now(),
    )
    # An explicitly empty Units table: no rows and no spike_times column.
    nwb.units = pynwb.misc.Units(name="units", description="Empty units table.")

    with pynwb.NWBHDF5IO(str(out_path), "w") as io:
        io.write(nwb)

    return out_path


@pytest.fixture
def curation_mocks(tmp_path):
    """Fixture providing mocked dependencies for _write_sorting_to_nwb_with_curation.

    Returns a ``CurationMocks`` helper whose ``setup`` installs patches for
    the DataJoint tables and NWB I/O used by the function under test, so the
    no-spikes code path can run without a database or real NWB files. Use it
    as a context manager: ``with curation_mocks.setup(): ...``.
    """
    from unittest.mock import MagicMock, patch

    import pandas as pd

    class CurationMocks:
        def __init__(self):
            self.tmp_path = tmp_path
            self.patches = []
            # NWB file object handed to the write-mode IO context; tests
            # inspect its ``units`` attribute after the call under test.
            self.write_nwbf = None

        def setup(self, units_df=None):
            """Setup mocks with given units DataFrame."""
            if units_df is None:
                units_df = pd.DataFrame()  # Empty DataFrame, no spike_times

            mock_nwbf = MagicMock()
            mock_nwbf.units.to_dataframe.return_value = units_df

            class _WriteNWBMock(MagicMock):
                # Real property so assignments to ``units`` are recorded
                # (a plain MagicMock would silently absorb the assignment).
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)
                    self._units = None

                @property
                def units(self):
                    return self._units

                @units.setter
                def units(self, value):
                    self._units = value

            self.write_nwbf = _WriteNWBMock()
            mock_io_read = MagicMock()
            mock_io_read.read.return_value = mock_nwbf
            mock_io_read.__enter__ = MagicMock(return_value=mock_io_read)
            mock_io_read.__exit__ = MagicMock(return_value=False)

            mock_io_write = MagicMock()
            mock_io_write.read.return_value = self.write_nwbf
            mock_io_write.__enter__ = MagicMock(return_value=mock_io_write)
            mock_io_write.__exit__ = MagicMock(return_value=False)

            def mock_nwbhdf5io(*args, **kwargs):
                # Route read-mode opens to the read mock, everything else to
                # the write mock. Test positional args for an exact "r" value
                # rather than substring-matching str(args), which would also
                # match any file path containing the letter "r".
                if kwargs.get("mode") == "r" or "r" in args:
                    return mock_io_read
                return mock_io_write

            self.patches = [
                patch(
                    "spyglass.spikesorting.v1.curation.SpikeSortingSelection",
                    self._mock_table("test.nwb"),
                ),
                patch(
                    "spyglass.spikesorting.v1.curation.SpikeSorting",
                    self._mock_table("test_analysis.nwb"),
                ),
                patch(
                    "spyglass.spikesorting.v1.curation.AnalysisNwbfile",
                    self._mock_analysis_nwb(),
                ),
                patch(
                    "spyglass.spikesorting.v1.curation.pynwb.NWBHDF5IO",
                    side_effect=mock_nwbhdf5io,
                ),
            ]
            return self

        def _mock_table(self, return_value):
            # DataJoint-style table: (Table & key).fetch1() -> return_value
            mock = MagicMock()
            mock_instance = MagicMock()
            mock_instance.fetch1.return_value = return_value
            mock.__and__.return_value = mock_instance
            return mock

        def _mock_analysis_nwb(self):
            # AnalysisNwbfile: instance .create() names the new analysis
            # file; class-level .get_abs_path() resolves to a temp path.
            mock = MagicMock()
            mock_instance = MagicMock()
            mock_instance.create.return_value = "new_analysis.nwb"
            mock.return_value = mock_instance
            mock.get_abs_path.return_value = str(self.tmp_path / "test.nwb")
            return mock

        def __enter__(self):
            for p in self.patches:
                p.__enter__()
            return self

        def __exit__(self, *args):
            # Unwind in reverse order, mirroring nested `with` semantics.
            for p in reversed(self.patches):
                p.__exit__(*args)

    return CurationMocks()


def test_write_sorting_no_spikes(curation_mocks):
    """Test _write_sorting_to_nwb_with_curation handles missing spike_times."""
    import pynwb

    from spyglass.spikesorting.v1.curation import (
        _write_sorting_to_nwb_with_curation,
    )

    with curation_mocks.setup():
        analysis_info = _write_sorting_to_nwb_with_curation(
            sorting_id="test_sorting_id",
            labels=None,
            merge_groups=[["unit1", "unit2"]],
            metrics=None,
            apply_merge=True,  # Also tests apply_merge guard
        )

    # A (file name, object id) style pair should come back even with no units.
    assert analysis_info is not None
    assert len(analysis_info) == 2
    # The writer must have installed an empty Units table on the output file.
    assert isinstance(curation_mocks.write_nwbf._units, pynwb.misc.Units)


def test_empty_units_nwb_readable(empty_units_nwb):
    """Round-trip check: an NWB file with an empty units table loads cleanly."""
    import pynwb

    with pynwb.NWBHDF5IO(str(empty_units_nwb), "r") as io:
        loaded = io.read()
        assert loaded.units is not None
        frame = loaded.units.to_dataframe()
        assert len(frame) == 0
Loading