diff --git a/CHANGELOG.md b/CHANGELOG.md index d09bca1ba..b5c6061bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,18 @@ from spyglass.lfp.analysis.v1 import LFPBandV1 LFPBandV1().fix_1481() ``` +#### AutomaticCuration Fix + +If you were using `v0.AutomaticCuration` after April 2025, you may have stored +inaccurate labels due to #1513. To fix these, please run the following after +updating: + +```python +from spyglass.spikesorting.v0 import AutomaticCuration + +AutomaticCuration().fix_1513() +``` + ### Breaking Changes #### Decoding Results Structure @@ -44,17 +56,19 @@ memory usage significantly. ```python # OLD (before v0.5.6): results.isel(intervals=0) # Get first interval -for i in range(results.sizes['intervals']): # Iterate intervals +for i in range(results.sizes["intervals"]): # Iterate intervals interval_data = results.isel(intervals=i) # NEW (v0.5.6+): results.where(results.interval_labels == 0, drop=True) # Get first interval for label in np.unique(results.interval_labels.values): # Iterate intervals - if label >= 0: # Skip -1 (outside intervals, only with estimate_decoding_params=True) + if ( + label >= 0 + ): # Skip -1 (outside intervals, only with estimate_decoding_params=True) interval_data = results.where(results.interval_labels == label, drop=True) # Or use groupby: -for label, interval_data in results.groupby('interval_labels'): +for label, interval_data in results.groupby("interval_labels"): if label >= 0: # process interval_data pass @@ -64,7 +78,7 @@ for label, interval_data in results.groupby('interval_labels'): - `0, 1, 2, ...` - Sequential interval indices (0-indexed) - `-1` - Time points outside any decoding interval (only when - `estimate_decoding_params=True`) + `estimate_decoding_params=True`) ### Documentation @@ -76,8 +90,8 @@ for label, interval_data in results.groupby('interval_labels'): ### Infrastructure -- Add cross-platform installer script with Docker support, input validation, - and automated environment setup #1414 
+- Add cross-platform installer script with Docker support, input validation, and + automated environment setup #1414 - Set default codecov threshold for test fail, disable patch check #1370, #1372 - Simplify PR template #1370 - Allow email send on space check success, clean up maintenance logging #1381 @@ -139,9 +153,10 @@ for label, interval_data in results.groupby('interval_labels'): merge IDs are fetched in non-chronological order #1471 - Separate `ClusterlessDecodingV1` to tri-part `make` #1467 - **BREAKING**: Remove `intervals` dimension from decoding results. Results - from multiple intervals are now concatenated along the `time` dimension - with an `interval_labels` coordinate to track interval membership. This - eliminates NaN padding and reduces memory usage. See migration guide above. + from multiple intervals are now concatenated along the `time` dimension + with an `interval_labels` coordinate to track interval membership. This + eliminates NaN padding and reduces memory usage. See migration guide + above. - LFP @@ -158,6 +173,7 @@ for label, interval_data in results.groupby('interval_labels'): - Implement short-transaction `SpikeSortingRecording.make` for v0 #1338 - Fix `FigURLCuration.make`. 
Postpone fetch of unhashable items #1505 + - Implement fix for `AutomaticCuration` incorrect labels #1513 ## [0.5.5] (Aug 6, 2025) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index eb732bde5..ea96dae94 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -1089,13 +1089,314 @@ def get_labels(sorting, parent_labels, quality_metrics, label_params): label = label_params[metric] if compare(quality_metrics[metric][unit_id], label[1]): - if unit_id not in parent_labels: - parent_labels[unit_id] = label[2] + if int(unit_id) not in parent_labels: + parent_labels[int(unit_id)] = label[2].copy() # check if the label is already there, and if not, add it - elif label[2] not in parent_labels[unit_id]: - parent_labels[unit_id].extend(label[2]) + else: + # remove 'accept' label if it exists + if "accept" in parent_labels[int(unit_id)]: + parent_labels[int(unit_id)].remove("accept") + for element in label[2].copy(): + if element not in parent_labels[int(unit_id)]: + parent_labels[int(unit_id)].append(element) - return parent_labels + return parent_labels + + @staticmethod + def _normalize_labels(labels): + """Return labels dict with all keys cast to int. + + Quality metrics loaded from JSON have string keys, while + the fixed ``get_labels`` uses ``int(unit_id)``. Normalize + to int so comparisons are consistent regardless of source. + """ + return {int(k): v for k, v in labels.items()} + + def fix_1513(self, restriction=True, dry_run=True, verbose=True): + """Find and repair entries affected by get_labels bugs. + + For more information, see issue #1513. + + Parameters + ---------- + restriction : str, dict, optional + Restrict to a subset of AutomaticCuration entries. + dry_run : bool + If True, report affected entries without modifying the + database. Default True. + verbose : bool + If True, print progress and details. 
Default True. + + Returns + ------- + list of dict + Each dict has keys: auto_curation_key, curation_key, + old_labels, new_labels, changed, has_downstream, + reject_status_changed. + """ + raise NotImplementedError("fix_1513 is not yet fully implemented") + + restr = (self & restriction) if restriction else self + if verbose: + logger.info(f"fix_1513: scanning {len(restr)} entries") + + results = [] + for key in restr: + result = self._fix1_1513(key, dry_run, verbose) + if result is not None: + results.append(result) + + if verbose: + logger.info( + f"fix_1513: {len(results)} impacted entries" + + (" (dry run)" if dry_run else " (applied)") + ) + + return results + + def _fix1_1513(self, key, dry_run=True, verbose=True): + """Detect and repair a single AutomaticCuration entry. + + Returns a result dict if the entry is impacted, None otherwise. + See `fix_1513` for full documentation. + """ + from copy import deepcopy + from datetime import datetime + + bug_date = datetime(2025, 4, 22).timestamp() + + # --- Early return 1: empty label_params --- + params = (self & key) * AutomaticCurationParameters + label_params = params.fetch1("label_params") + if not label_params: + return None + + # --- Early return 2: created before bug date --- + auto_curation_key = (self & key).fetch1("auto_curation_key") + time_created = (Curation & auto_curation_key).fetch1("time_of_creation") + if time_created < bug_date: + return None + + # --- Load quality metrics --- + metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path") + try: + with open(metrics_path) as f: + quality_metrics = json.load(f) + except FileNotFoundError: + if verbose: + logger.warning( + f"fix_1513: metrics file not found: " + f"{metrics_path}; skipping {key}" + ) + return None + + # --- Recompute labels --- + parent_curation = (Curation & key).fetch(as_dict=True)[0] + parent_labels = parent_curation["curation_labels"] + + new_labels = self._normalize_labels( + self.get_labels( + sorting=None, + 
parent_labels=deepcopy(parent_labels), + quality_metrics=quality_metrics, + label_params=label_params, + ) + ) + stored_labels = self._normalize_labels( + (Curation & auto_curation_key).fetch1("curation_labels") + ) + + if new_labels == stored_labels: + return None + + # --- Determine downstream impact --- + has_downstream = ( + len(CuratedSpikeSortingSelection & auto_curation_key) > 0 + ) + + reject_changed = [] + if has_downstream: + all_uids = set(stored_labels.keys()) | set(new_labels.keys()) + for uid in all_uids: + old_reject = ( + uid in stored_labels and "reject" in stored_labels[uid] + ) + new_reject = uid in new_labels and "reject" in new_labels[uid] + if old_reject != new_reject: + reject_changed.append( + { + "unit": uid, + "was_rejected": old_reject, + "should_reject": new_reject, + } + ) + + result = { + "auto_curation_key": auto_curation_key, + "curation_key": { + k: auto_curation_key[k] + for k in Curation.primary_key + if k in auto_curation_key + }, + "old_labels": stored_labels, + "new_labels": new_labels, + "changed": True, + "has_downstream": has_downstream, + "reject_status_changed": reject_changed, + } + + if verbose: + n_diff = sum( + 1 + for u in set(stored_labels) | set(new_labels) + if stored_labels.get(u) != new_labels.get(u) + ) + logger.info( + f"fix_1513: {auto_curation_key} — " + f"{n_diff} unit(s) with label changes" + + ( + f", {len(reject_changed)} reject " f"status change(s)" + if reject_changed + else "" + ) + ) + + if not dry_run: + # NWB edit happens before the DB transaction. If the + # file edit fails, DB state is unchanged and the entry + # can be retried. If the DB commit fails after a + # successful file write, the labels in the NWB are + # already correct and idempotent reprocessing is safe. 
+ with Curation.connection.transaction: + Curation.update1( + { + **auto_curation_key, + "curation_labels": new_labels, + } + ) + if has_downstream: + self._fix_1513_units(auto_curation_key, new_labels, verbose) + + return result + + @staticmethod + def _fix_1513_units(curation_key, new_labels, verbose=True): + """Update CuratedSpikeSorting.Unit labels for a curation. + + Updates the label column on existing Unit rows. Does NOT + add or remove rows — if accept/reject status changed, a + full repopulation of CuratedSpikeSorting is needed. + + Parameters + ---------- + curation_key : dict + Primary key to the Curation entry. + new_labels : dict + Corrected labels dict with int keys + ``{unit_id: [label, ...]}``. + verbose : bool + If True, log changes. + """ + unit_rows = (CuratedSpikeSorting.Unit & curation_key).fetch( + as_dict=True + ) + for row in unit_rows: + uid = int(row["unit_id"]) + old_label = row["label"] + new_label = ",".join(new_labels.get(uid, [])) + if old_label != new_label: + if verbose: + logger.info( + f" unit {uid}: " f"'{old_label}' -> '{new_label}'" + ) + CuratedSpikeSorting.Unit.update1( + { + **{ + k: row[k] + for k in CuratedSpikeSorting.Unit.primary_key + }, + "label": new_label, + } + ) + + @staticmethod + def _fix_1513_nwb(curation_key, new_labels, verbose=True): + """Update labels in the CuratedSpikeSorting NWB analysis file. + + Edits the ``label`` column in the NWB units table in place + and updates the external-table checksum so DataJoint's + filepath store stays consistent. + + Does NOT add or remove unit rows. If accept/reject status + changed such that units need to be added, a full + repopulation of CuratedSpikeSorting is required. + + Parameters + ---------- + curation_key : dict + Primary key to the Curation entry (used to look up + the CuratedSpikeSorting analysis file). + new_labels : dict + Corrected labels dict ``{unit_id: [label, ...]}``. + verbose : bool + If True, log file path and changes. 
+ """ + import pynwb + + css_row = (CuratedSpikeSorting & curation_key).fetch(as_dict=True) + if not css_row: + return + css_row = css_row[0] + analysis_file_name = css_row["analysis_file_name"] + abs_path = AnalysisNwbfile().get_abs_path(analysis_file_name) + + if verbose: + logger.info(f" NWB: editing {abs_path}") + + with pynwb.NWBHDF5IO( + path=abs_path, mode="a", load_namespaces=True + ) as io: + nwbf = io.read() + if nwbf.units is None or "label" not in nwbf.units: + if verbose: + logger.info(" NWB: no units/label column; skip") + return + + unit_ids = list(nwbf.units.id.data[:]) + label_col = nwbf.units["label"] + changed = False + + for idx, uid in enumerate(unit_ids): + new_val = ",".join(new_labels.get(int(uid), [])) + + old_val = label_col[idx] + if old_val != new_val: + label_col.data[idx] = new_val + changed = True + if verbose: + logger.info( + f" NWB unit {uid}: " f"'{old_val}' -> '{new_val}'" + ) + + if changed: + io.write(nwbf) + + # Update the external-table checksum to match edited file + if changed: + abs_path = Path(abs_path) + anwb = AnalysisNwbfile() + rel_path = abs_path.relative_to(anwb._analysis_dir) + ext_tbl = anwb._ext_tbl + ext_key = (ext_tbl & f"filepath = '{str(rel_path)}'").fetch1() + ext_key.update( + { + "contents_hash": dj.hash.uuid_from_file(abs_path), + "size": abs_path.stat().st_size, + } + ) + ext_tbl.update1(ext_key) + if verbose: + logger.info(" NWB: checksum updated") @schema @@ -1221,7 +1522,7 @@ def make_compute( recording = si.load_extractor(recording_path) timestamps = SpikeSortingRecording._get_recording_timestamps(recording) - (analysis_file_name, units_object_id) = Curation().save_sorting_nwb( + analysis_file_name, units_object_id = Curation().save_sorting_nwb( key=key, sorting=sorting, timestamps=timestamps, diff --git a/tests/spikesorting/v0/test_bug_1513.py b/tests/spikesorting/v0/test_bug_1513.py new file mode 100644 index 000000000..550df9199 --- /dev/null +++ b/tests/spikesorting/v0/test_bug_1513.py @@ -0,0 
"""Tests for bug 1513: AutomaticCuration.get_labels early return.

The bug caused get_labels() to return after processing only the first
metric in label_params, silently dropping labels from subsequent metrics.

These tests verify the fix and define the expected behavior for the
fix_1513 repair function.

NOTE(fix): the original module docstring referenced "bug 1281" and
"fix_1281"; the issue and repair function are #1513 / fix_1513, as
used consistently throughout this file and the changelog.
"""

import pytest


# -- Helpers ----------------------------------------------------------


@pytest.fixture(scope="session")
def get_labels(spike_v0):
    """Wrapper: sorting arg is unused in label logic."""

    def _get_labels(parent_labels, quality_metrics, label_params):
        return spike_v0.AutomaticCuration.get_labels(
            sorting=None,
            parent_labels=parent_labels,
            quality_metrics=quality_metrics,
            label_params=label_params,
        )

    yield _get_labels


# -- Task 5.1: get_labels unit tests ---------------------------------


class TestGetLabelsMultiMetric:
    """Verify get_labels processes ALL metrics in label_params."""

    two_metric_params = {
        "nn_noise_overlap": [">", 0.1, ["noise", "reject"]],
        "isi_violation": [">", 0.5, ["mua"]],
    }
    quality_metrics = {
        "nn_noise_overlap": {"0": 0.2, "1": 0.05, "2": 0.3},
        "isi_violation": {"0": 0.8, "1": 0.9, "2": 0.1},
    }

    def test_both_metrics_applied(self, get_labels):
        """Unit matching both metrics gets labels from both."""
        result = get_labels({}, self.quality_metrics, self.two_metric_params)
        # unit 0: nn_noise_overlap 0.2 > 0.1 -> noise,reject
        #         isi_violation 0.8 > 0.5 -> mua
        assert "noise" in result[0]
        assert "reject" in result[0]
        assert "mua" in result[0]

    def test_second_metric_only(self, get_labels):
        """Unit matching only second metric still gets its labels."""
        result = get_labels({}, self.quality_metrics, self.two_metric_params)
        # unit 1: nn_noise_overlap 0.05 <= 0.1 -> no label
        #         isi_violation 0.9 > 0.5 -> mua
        assert 1 in result
        assert "mua" in result[1]
        assert "noise" not in result.get(1, [])

    def test_no_match_no_label(self, get_labels):
        """Unit matching neither metric gets no labels."""
        result = get_labels({}, self.quality_metrics, self.two_metric_params)
        # unit 2: nn_noise_overlap 0.3 > 0.1 -> noise,reject
        #         isi_violation 0.1 <= 0.5 -> no label
        assert "noise" in result[2]
        assert "reject" in result[2]
        assert "mua" not in result[2]

    def test_all_units_covered(self, get_labels):
        """All units present in quality_metrics are evaluated."""
        result = get_labels({}, self.quality_metrics, self.two_metric_params)
        assert set(result.keys()) >= {0, 1}


class TestGetLabelsSingleMetric:
    """Single-metric label_params works identically pre/post fix."""

    single_metric_params = {
        "nn_noise_overlap": [">", 0.1, ["noise", "reject"]],
    }
    quality_metrics = {
        "nn_noise_overlap": {"0": 0.2, "1": 0.05},
    }

    def test_single_metric_labels(self, get_labels):
        result = get_labels({}, self.quality_metrics, self.single_metric_params)
        assert "noise" in result[0]
        assert "reject" in result[0]

    def test_single_metric_no_match(self, get_labels):
        result = get_labels({}, self.quality_metrics, self.single_metric_params)
        assert 1 not in result


class TestGetLabelsEmpty:
    """Empty label_params returns parent_labels unchanged."""

    def test_empty_params_returns_parent(self, get_labels):
        parent = {0: ["accept"]}
        result = get_labels(parent, {"metric": {"0": 1.0}}, {})
        assert result == {0: ["accept"]}

    def test_empty_params_empty_parent(self, get_labels):
        result = get_labels({}, {"metric": {"0": 1.0}}, {})
        assert result == {}


class TestGetLabelsParentPreservation:
    """Existing parent labels are preserved and extended."""

    def test_accept_removed_when_other_labels_added(self, get_labels):
        """Accept label is removed when metric labels are applied."""
        parent = {0: ["accept"]}
        params = {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]}
        qm = {"nn_noise_overlap": {"0": 0.2}}
        result = get_labels(parent, qm, params)
        assert "accept" not in result[0]
        assert "noise" in result[0]

    def test_no_duplicate_labels(self, get_labels):
        parent = {0: ["noise"]}
        params = {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]}
        qm = {"nn_noise_overlap": {"0": 0.2}}
        result = get_labels(parent, qm, params)
        assert result[0].count("noise") == 1

    def test_parent_labels_extended_by_second_metric(self, get_labels):
        """Parent labels from first metric are extended by second."""
        parent = {}
        params = {
            "nn_noise_overlap": [">", 0.1, ["noise", "reject"]],
            "isi_violation": [">", 0.5, ["mua"]],
        }
        qm = {
            "nn_noise_overlap": {"0": 0.2},
            "isi_violation": {"0": 0.8},
        }
        result = get_labels(parent, qm, params)
        assert set(result[0]) == {"noise", "reject", "mua"}


class TestGetLabelsMetricSkipping:
    """Metrics not in quality_metrics are skipped, not errored."""

    def test_missing_metric_skipped(self, get_labels):
        params = {
            "nn_noise_overlap": [">", 0.1, ["noise", "reject"]],
            "nonexistent_metric": [">", 0.5, ["mua"]],
        }
        qm = {"nn_noise_overlap": {"0": 0.2}}
        result = get_labels({}, qm, params)
        assert "noise" in result[0]
        assert "mua" not in result.get(0, [])

    def test_only_overlapping_metrics_apply(self, get_labels):
        """With 3 metrics, only 2 present in qm are applied."""
        params = {
            "nn_noise_overlap": [">", 0.1, ["noise"]],
            "missing": [">", 0.5, ["artifact"]],
            "isi_violation": [">", 0.5, ["mua"]],
        }
        qm = {
            "nn_noise_overlap": {"0": 0.2},
            "isi_violation": {"0": 0.8},
        }
        result = get_labels({}, qm, params)
        assert "noise" in result[0]
        assert "mua" in result[0]
        assert "artifact" not in result.get(0, [])


class TestGetLabelsComparisonOperators:
    """All comparison operators work correctly."""

    @pytest.mark.parametrize(
        "op,threshold,value,should_label",
        [
            (">", 0.5, 0.6, True),
            (">", 0.5, 0.5, False),
            (">=", 0.5, 0.5, True),
            ("<", 0.5, 0.4, True),
            ("<", 0.5, 0.5, False),
            ("<=", 0.5, 0.5, True),
            ("==", 0.5, 0.5, True),
            ("==", 0.5, 0.6, False),
        ],
    )
    def test_operator(self, get_labels, op, threshold, value, should_label):
        params = {"nn_noise_overlap": [op, threshold, ["reject"]]}
        qm = {"nn_noise_overlap": {"0": value}}
        result = get_labels({}, qm, params)
        if should_label:
            assert 0 in result and "reject" in result[0]
        else:
            assert 0 not in result