Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ for label, interval_data in results.groupby("interval_labels"):
- Fix `FigURLCuration.make`. Postpone fetch of unhashable items #1505
- Improve get_recording efficiency #1522
- Raise error if `FigURLCurationSelection` finds no curation label #1531
- Allow `CurationV1` to save without any spikes #1533

## [0.5.5] (Aug 6, 2025)

Expand Down
27 changes: 17 additions & 10 deletions src/spyglass/spikesorting/v1/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,15 @@ def _write_sorting_to_nwb_with_curation(
) as io:
nwbf = io.read()
units = nwbf.units.to_dataframe()
units_dict = {
unit_id: spike_times
for unit_id, spike_times in zip(units.index, units["spike_times"])
}
if "spike_times" in units.columns:
units_dict = {
unit_id: spike_times
for unit_id, spike_times in zip(units.index, units["spike_times"])
}
else:
units_dict = {}

if apply_merge:
if apply_merge and units_dict:
for merge_group in merge_groups:
new_unit_id = np.max(list(units_dict.keys())) + 1
units_dict[new_unit_id] = np.concatenate(
Expand All @@ -372,12 +375,16 @@ def _write_sorting_to_nwb_with_curation(
) as io:
nwbf = io.read()
# write sorting to the nwb file
for unit_id in unit_ids:
# spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=units_dict[unit_id],
id=unit_id,
if not unit_ids:
nwbf.units = pynwb.misc.Units(
name="units", description="Empty units table."
)
else:
for unit_id in unit_ids:
nwbf.add_unit(
spike_times=units_dict[unit_id],
id=unit_id,
)
# add labels, merge groups, metrics
if labels is not None:
label_values = []
Expand Down
163 changes: 162 additions & 1 deletion tests/spikesorting/v1/test_curation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
from spikeinterface import BaseSorting
from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor
import pytest


def test_curation_rec(spike_v1, pop_curation):
Expand Down Expand Up @@ -110,3 +110,164 @@ def test_curation_sort_metric(spike_v1, pop_curation, pop_curation_metric):
assert (
sort_metric[k] == expected[k]
), f"CurationV1.get_sort_group_info unexpected value: {k}"


# ============================================================================
# No-Spikes Case Tests (Issue #1532)
# ============================================================================


@pytest.fixture
def empty_units_nwb(tmp_path):
"""Create NWB file with empty units table for testing no-spikes case."""
from datetime import datetime
from uuid import uuid4

import pynwb

nwb_path = tmp_path / "empty_units.nwb"

nwbfile = pynwb.NWBFile(
session_description="Test session with no spikes",
identifier=str(uuid4()),
session_start_time=datetime.now(),
)
# Create empty units table (no spike_times column)
nwbfile.units = pynwb.misc.Units(
name="units", description="Empty units table."
)

with pynwb.NWBHDF5IO(str(nwb_path), "w") as io:
io.write(nwbfile)

return nwb_path


@pytest.fixture
def curation_mocks(tmp_path):
"""Fixture providing mocked dependencies for _write_sorting_to_nwb_with_curation."""
from unittest.mock import MagicMock, patch

import pandas as pd

class CurationMocks:
def __init__(self):
self.tmp_path = tmp_path
self.patches = []
self.write_nwbf = None

def setup(self, units_df=None):
"""Setup mocks with given units DataFrame."""
if units_df is None:
units_df = pd.DataFrame() # Empty DataFrame, no spike_times

mock_nwbf = MagicMock()
mock_nwbf.units.to_dataframe.return_value = units_df

class _WriteNWBMock(MagicMock):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._units = None

@property
def units(self):
return self._units

@units.setter
def units(self, value):
self._units = value

self.write_nwbf = _WriteNWBMock()
mock_io_read = MagicMock()
mock_io_read.read.return_value = mock_nwbf
mock_io_read.__enter__ = MagicMock(return_value=mock_io_read)
mock_io_read.__exit__ = MagicMock(return_value=False)

mock_io_write = MagicMock()
mock_io_write.read.return_value = self.write_nwbf
mock_io_write.__enter__ = MagicMock(return_value=mock_io_write)
mock_io_write.__exit__ = MagicMock(return_value=False)

def mock_nwbhdf5io(*args, **kwargs):
if kwargs.get("mode") == "r" or (args and "r" in str(args)):
return mock_io_read
return mock_io_write

self.patches = [
patch(
"spyglass.spikesorting.v1.curation.SpikeSortingSelection",
self._mock_table("test.nwb"),
),
patch(
"spyglass.spikesorting.v1.curation.SpikeSorting",
self._mock_table("test_analysis.nwb"),
),
patch(
"spyglass.spikesorting.v1.curation.AnalysisNwbfile",
self._mock_analysis_nwb(),
),
patch(
"spyglass.spikesorting.v1.curation.pynwb.NWBHDF5IO",
side_effect=mock_nwbhdf5io,
),
]
return self

def _mock_table(self, return_value):
mock = MagicMock()
mock_instance = MagicMock()
mock_instance.fetch1.return_value = return_value
mock.__and__.return_value = mock_instance
return mock

def _mock_analysis_nwb(self):
mock = MagicMock()
mock_instance = MagicMock()
mock_instance.create.return_value = "new_analysis.nwb"
mock.return_value = mock_instance
mock.get_abs_path.return_value = str(self.tmp_path / "test.nwb")
return mock

def __enter__(self):
for p in self.patches:
p.__enter__()
return self

def __exit__(self, *args):
for p in reversed(self.patches):
p.__exit__(*args)

return CurationMocks()


def test_write_sorting_no_spikes(curation_mocks):
"""Test _write_sorting_to_nwb_with_curation handles missing spike_times."""
import pynwb

from spyglass.spikesorting.v1.curation import (
_write_sorting_to_nwb_with_curation,
)

with curation_mocks.setup():
result = _write_sorting_to_nwb_with_curation(
sorting_id="test_sorting_id",
labels=None,
merge_groups=[["unit1", "unit2"]],
metrics=None,
apply_merge=True, # Also tests apply_merge guard
)

assert result is not None
assert len(result) == 2
assert isinstance(curation_mocks.write_nwbf._units, pynwb.misc.Units)


def test_empty_units_nwb_readable(empty_units_nwb):
"""Test that NWB file with empty units table is readable."""
import pynwb

with pynwb.NWBHDF5IO(str(empty_units_nwb), "r") as io:
nwbf = io.read()
assert nwbf.units is not None
units_df = nwbf.units.to_dataframe()
assert len(units_df) == 0
Loading