From 171b2b595668d0d9b432615e07721e59edbeb7f0 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 28 Jan 2026 14:19:43 -0800 Subject: [PATCH 1/4] WIP: get_labels bugfix 1 --- CHANGELOG.md | 12 + .../spikesorting/v0/spikesorting_curation.py | 330 +++++++++++++++++- tests/spikesorting/v0/test_bug_1281.py | 198 +++++++++++ 3 files changed, 536 insertions(+), 4 deletions(-) create mode 100644 tests/spikesorting/v0/test_bug_1281.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d09bca1ba..dd75ed699 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,17 @@ from spyglass.lfp.analysis.v1 import LFPBandV1 LFPBandV1().fix_1481() ``` +#### AutomaticCuration Fix + +If you were using `v0.AutomaticCuration` after April 2025, you may have stored +inaccurate labels due to #14XX. To fix these, please run the following after updating: + +```python +from spyglass.spikesorting.v0 import AutomaticCuration + +AutomaticCuration().fix_15XX() +``` + ### Breaking Changes #### Decoding Results Structure @@ -158,6 +169,7 @@ for label, interval_data in results.groupby('interval_labels'): - Implement short-transaction `SpikeSortingRecording.make` for v0 #1338 - Fix `FigURLCuration.make`. 
Postpone fetch of unhashable items #1505 + - Implement fix for `AutomaticCuration` incorrect labels #15XY ## [0.5.5] (Aug 6, 2025) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index eb732bde5..95466edd7 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -1090,12 +1090,334 @@ def get_labels(sorting, parent_labels, quality_metrics, label_params): if compare(quality_metrics[metric][unit_id], label[1]): if unit_id not in parent_labels: - parent_labels[unit_id] = label[2] + parent_labels[unit_id] = label[2].copy() # check if the label is already there, and if not, add it - elif label[2] not in parent_labels[unit_id]: - parent_labels[unit_id].extend(label[2]) + else: + for element in label[2].copy(): + if element not in parent_labels[unit_id]: + parent_labels[unit_id].append(element) + + return parent_labels # Unindent to fix #15XX + + def fix_15XX(self, restriction=True, dry_run=True, verbose=True): + """Find and repair entries affected by get_labels bugs. + + PR #1281 (2025-04-22) introduced three bugs in `get_labels`: + + A. **Early return**: `return parent_labels` was indented inside + the `for metric` loop, so only the first metric was + processed. Affects entries with >1 metric in label_params. + B. **List aliasing**: `parent_labels[unit_id] = label[2]` + assigned the label list without copying, so units sharing + a metric shared the same list object. Mutations on one + unit could corrupt others. + C. **Duplicate comparison**: `label[2] not in parent_labels[unit_id]` + compared a list against a list of strings. This always + evaluated True for flat lists, so `.extend()` ran + unconditionally, creating duplicate labels. + + Bugs B and C can affect single-metric entries when + `parent_labels` has pre-existing entries from a prior + curation step. + + For each impacted entry, this method: + 1. 
Recomputes labels using the fixed `get_labels` logic. + 2. Updates `Curation.curation_labels` with corrected labels. + 3. Updates `CuratedSpikeSorting.Unit.label` on existing rows. + + Steps 2-3 are wrapped in a transaction per entry to prevent + partial updates on interruption. + + Note: if a unit's accept/reject status changed, the NWB + analysis file still contains the old unit set. A full + repopulation of `CuratedSpikeSorting` is needed in those + cases. Such entries are flagged in `reject_status_changed`. - return parent_labels + Parameters + ---------- + restriction : str, dict, optional + Restrict to a subset of AutomaticCuration entries. + dry_run : bool + If True, report affected entries without modifying the + database. Default True. + verbose : bool + If True, print progress and details. Default True. + + Returns + ------- + list of dict + Each dict has keys: auto_curation_key, curation_key, + old_labels, new_labels, changed, has_downstream, + reject_status_changed. + """ + restr = (self & restriction) if restriction else self + if verbose: + logger.info(f"fix_15XX: scanning {len(restr)} entries") + + results = [] + for key in restr: + result = self._fix1_15XX(key, dry_run, verbose) + if result is not None: + results.append(result) + + if verbose: + logger.info( + f"fix_15XX: {len(results)} impacted entries" + + (" (dry run)" if dry_run else " (applied)") + ) + + return results + + def _fix1_15XX(self, key, dry_run=True, verbose=True): + """Detect and repair a single AutomaticCuration entry. + + Returns a result dict if the entry is impacted, None otherwise. + See `fix_15XX` for full documentation. 
+ """ + from copy import deepcopy + from datetime import datetime + + bug_date = datetime(2025, 4, 22).timestamp() + + # --- Early return 1: empty label_params --- + params = (self & key) * AutomaticCurationParameters + label_params = params.fetch1("label_params") + if not label_params: + return None + + # --- Early return 2: created before bug date --- + auto_curation_key = (self & key).fetch1("auto_curation_key") + time_created = (Curation & auto_curation_key).fetch1("time_of_creation") + if time_created < bug_date: + return None + + # --- Load quality metrics --- + metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path") + try: + with open(metrics_path) as f: + quality_metrics = json.load(f) + except FileNotFoundError: + if verbose: + logger.warning( + f"fix_15XX: metrics file not found: " + f"{metrics_path}; skipping {key}" + ) + return None + + # --- Recompute labels --- + parent_curation = (Curation & key).fetch(as_dict=True)[0] + parent_labels = parent_curation["curation_labels"] + + new_labels = self.get_labels( + sorting=None, + parent_labels=deepcopy(parent_labels), + quality_metrics=quality_metrics, + label_params=label_params, + ) + stored_labels = (Curation & auto_curation_key).fetch1("curation_labels") + + if new_labels == stored_labels: + return None + + # --- Determine downstream impact --- + has_downstream = ( + len(CuratedSpikeSortingSelection & auto_curation_key) > 0 + ) + + reject_changed = [] + if has_downstream: + all_uids = set(stored_labels.keys()) | set(new_labels.keys()) + for uid in all_uids: + old_reject = ( + uid in stored_labels and "reject" in stored_labels[uid] + ) + new_reject = uid in new_labels and "reject" in new_labels[uid] + if old_reject != new_reject: + reject_changed.append( + { + "unit": uid, + "was_rejected": old_reject, + "should_reject": new_reject, + } + ) + + result = { + "auto_curation_key": auto_curation_key, + "curation_key": { + k: auto_curation_key[k] + for k in Curation.primary_key + if k in 
auto_curation_key + }, + "old_labels": stored_labels, + "new_labels": new_labels, + "changed": True, + "has_downstream": has_downstream, + "reject_status_changed": reject_changed, + } + + if verbose: + n_diff = sum( + 1 + for u in set(stored_labels) | set(new_labels) + if stored_labels.get(u) != new_labels.get(u) + ) + logger.info( + f"fix_15XX: {auto_curation_key} — " + f"{n_diff} unit(s) with label changes" + + ( + f", {len(reject_changed)} reject " f"status change(s)" + if reject_changed + else "" + ) + ) + + if not dry_run: + # Update Curation and any downstream Unit rows inside a + # single transaction so an interruption cannot leave them + # partially updated. The NWB analysis file is NOT edited + # here (see `_fix_15XX_nwb`). Re-running on a repaired + # entry is a no-op because stored and new labels match. + with Curation.connection.transaction: + Curation.update1( + { + **auto_curation_key, + "curation_labels": new_labels, + } + ) + if has_downstream: + self._fix_15XX_units(auto_curation_key, new_labels, verbose) + + return result + + @staticmethod + def _fix_15XX_units(curation_key, new_labels, verbose=True): + """Update CuratedSpikeSorting.Unit labels for a curation. + + Updates the label column on existing Unit rows. Does NOT + add or remove rows — if accept/reject status changed, a + full repopulation of CuratedSpikeSorting is needed. + + Parameters + ---------- + curation_key : dict + Primary key to the Curation entry. + new_labels : dict + Corrected labels dict {unit_id: [label, ...]}. + verbose : bool + If True, log changes. 
+ """ + unit_rows = (CuratedSpikeSorting.Unit & curation_key).fetch( + as_dict=True + ) + for row in unit_rows: + uid = row["unit_id"] + str_uid = str(uid) + old_label = row["label"] + if str_uid in new_labels: + new_label = ",".join(new_labels[str_uid]) + elif uid in new_labels: + new_label = ",".join(new_labels[uid]) + else: + new_label = "" + if old_label != new_label: + if verbose: + logger.info( + f" unit {uid}: " f"'{old_label}' -> '{new_label}'" + ) + CuratedSpikeSorting.Unit.update1( + { + **{ + k: row[k] + for k in CuratedSpikeSorting.Unit.primary_key + }, + "label": new_label, + } + ) + + @staticmethod + def _fix_15XX_nwb(curation_key, new_labels, verbose=True): + """Update labels in the CuratedSpikeSorting NWB analysis file. + + Edits the ``label`` column in the NWB units table in place + and updates the external-table checksum so DataJoint's + filepath store stays consistent. + + Does NOT add or remove unit rows. If accept/reject status + changed such that units need to be added, a full + repopulation of CuratedSpikeSorting is required. + + Parameters + ---------- + curation_key : dict + Primary key to the Curation entry (used to look up + the CuratedSpikeSorting analysis file). + new_labels : dict + Corrected labels dict ``{unit_id: [label, ...]}``. + verbose : bool + If True, log file path and changes. 
+ """ + import pynwb + + css_row = (CuratedSpikeSorting & curation_key).fetch(as_dict=True) + if not css_row: + return + css_row = css_row[0] + analysis_file_name = css_row["analysis_file_name"] + abs_path = AnalysisNwbfile().get_abs_path(analysis_file_name) + + if verbose: + logger.info(f" NWB: editing {abs_path}") + + with pynwb.NWBHDF5IO( + path=abs_path, mode="a", load_namespaces=True + ) as io: + nwbf = io.read() + if nwbf.units is None or "label" not in nwbf.units: + if verbose: + logger.info(" NWB: no units/label column; skip") + return + + unit_ids = list(nwbf.units.id.data[:]) + label_col = nwbf.units["label"] + changed = False + + for idx, uid in enumerate(unit_ids): + str_uid = str(uid) + if str_uid in new_labels: + new_val = ",".join(new_labels[str_uid]) + elif uid in new_labels: + new_val = ",".join(new_labels[uid]) + else: + new_val = "" + + old_val = label_col[idx] + if old_val != new_val: + label_col.data[idx] = new_val + changed = True + if verbose: + logger.info( + f" NWB unit {uid}: " f"'{old_val}' -> '{new_val}'" + ) + + if changed: + io.write(nwbf) + + # Update the external-table checksum to match edited file + if changed: + abs_path = Path(abs_path) + anwb = AnalysisNwbfile() + rel_path = abs_path.relative_to(anwb._analysis_dir) + ext_tbl = anwb._ext_tbl + ext_key = (ext_tbl & f"filepath = '{str(rel_path)}'").fetch1() + ext_key.update( + { + "contents_hash": dj.hash.uuid_from_file(abs_path), + "size": abs_path.stat().st_size, + } + ) + ext_tbl.update1(ext_key) + if verbose: + logger.info(" NWB: checksum updated") @schema diff --git a/tests/spikesorting/v0/test_bug_1281.py b/tests/spikesorting/v0/test_bug_1281.py new file mode 100644 index 000000000..a936b75a0 --- /dev/null +++ b/tests/spikesorting/v0/test_bug_1281.py @@ -0,0 +1,198 @@ +"""Tests for bug 1281: AutomaticCuration.get_labels early return. 
+ +The bug caused get_labels() to return after processing only the first +metric in label_params, silently dropping labels from subsequent metrics. + +These tests verify the fix and define the expected behavior for the +`AutomaticCuration.fix_15XX` repair function. +""" + +import pytest + + +# -- Helpers ---------------------------------------------------------- + + +@pytest.fixture(scope="session") +def get_labels(spike_v0): + """Wrapper: sorting arg is unused in label logic.""" + + def _get_labels(parent_labels, quality_metrics, label_params): + return spike_v0.AutomaticCuration.get_labels( + sorting=None, + parent_labels=parent_labels, + quality_metrics=quality_metrics, + label_params=label_params, + ) + + yield _get_labels + + +# -- Task 5.1: get_labels unit tests --------------------------------- + + +class TestGetLabelsMultiMetric: + """Verify get_labels processes ALL metrics in label_params.""" + + two_metric_params = { + "nn_noise_overlap": [">", 0.1, ["noise", "reject"]], + "isi_violation": [">", 0.5, ["mua"]], + } + quality_metrics = { + "nn_noise_overlap": {"0": 0.2, "1": 0.05, "2": 0.3}, + "isi_violation": {"0": 0.8, "1": 0.9, "2": 0.1}, + } + + def test_both_metrics_applied(self, get_labels): + """Unit matching both metrics gets labels from both.""" + result = get_labels({}, self.quality_metrics, self.two_metric_params) + # unit 0: nn_noise_overlap 0.2 > 0.1 -> noise,reject + # isi_violation 0.8 > 0.5 -> mua + assert "noise" in result["0"] + assert "reject" in result["0"] + assert "mua" in result["0"] + + def test_second_metric_only(self, get_labels): + """Unit matching only second metric still gets its labels.""" + result = get_labels({}, self.quality_metrics, self.two_metric_params) + # unit 1: nn_noise_overlap 0.05 <= 0.1 -> no label + # isi_violation 0.9 > 0.5 -> mua + assert "1" in result + assert "mua" in result["1"] + assert "noise" not in result.get("1", []) + + def test_no_match_no_label(self, get_labels): + """Unit matching neither metric gets no labels.""" + 
result = get_labels({}, self.quality_metrics, self.two_metric_params) + # unit 2: nn_noise_overlap 0.3 > 0.1 -> noise,reject + # isi_violation 0.1 <= 0.5 -> no label + assert "noise" in result["2"] + assert "reject" in result["2"] + assert "mua" not in result["2"] + + def test_all_units_covered(self, get_labels): + """All units present in quality_metrics are evaluated.""" + result = get_labels({}, self.quality_metrics, self.two_metric_params) + assert set(result.keys()) >= {"0", "1"} + + +class TestGetLabelsSingleMetric: + """Single-metric label_params works identically pre/post fix.""" + + single_metric_params = { + "nn_noise_overlap": [">", 0.1, ["noise", "reject"]], + } + quality_metrics = { + "nn_noise_overlap": {"0": 0.2, "1": 0.05}, + } + + def test_single_metric_labels(self, get_labels): + result = get_labels({}, self.quality_metrics, self.single_metric_params) + assert "noise" in result["0"] + assert "reject" in result["0"] + + def test_single_metric_no_match(self, get_labels): + result = get_labels({}, self.quality_metrics, self.single_metric_params) + assert "1" not in result + + +class TestGetLabelsEmpty: + """Empty label_params returns parent_labels unchanged.""" + + def test_empty_params_returns_parent(self, get_labels): + parent = {"0": ["accept"]} + result = get_labels(parent, {"metric": {"0": 1.0}}, {}) + assert result == {"0": ["accept"]} + + def test_empty_params_empty_parent(self, get_labels): + result = get_labels({}, {"metric": {"0": 1.0}}, {}) + assert result == {} + + +class TestGetLabelsParentPreservation: + """Existing parent labels are preserved and extended.""" + + def test_parent_labels_preserved(self, get_labels): + parent = {"0": ["accept"]} + params = {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]} + qm = {"nn_noise_overlap": {"0": 0.2}} + result = get_labels(parent, qm, params) + assert "accept" in result["0"] + assert "noise" in result["0"] + + def test_no_duplicate_labels(self, get_labels): + parent = {"0": ["noise"]} + 
params = {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]} + qm = {"nn_noise_overlap": {"0": 0.2}} + result = get_labels(parent, qm, params) + assert result["0"].count("noise") == 1 + + def test_parent_labels_extended_by_second_metric(self, get_labels): + """Parent labels from first metric are extended by second.""" + parent = {} + params = { + "nn_noise_overlap": [">", 0.1, ["noise", "reject"]], + "isi_violation": [">", 0.5, ["mua"]], + } + qm = { + "nn_noise_overlap": {"0": 0.2}, + "isi_violation": {"0": 0.8}, + } + result = get_labels(parent, qm, params) + assert set(result["0"]) == {"noise", "reject", "mua"} + + +class TestGetLabelsMetricSkipping: + """Metrics not in quality_metrics are skipped, not errored.""" + + def test_missing_metric_skipped(self, get_labels): + params = { + "nn_noise_overlap": [">", 0.1, ["noise", "reject"]], + "nonexistent_metric": [">", 0.5, ["mua"]], + } + qm = {"nn_noise_overlap": {"0": 0.2}} + result = get_labels({}, qm, params) + assert "noise" in result["0"] + assert "mua" not in result.get("0", []) + + def test_only_overlapping_metrics_apply(self, get_labels): + """With 3 metrics, only 2 present in qm are applied.""" + params = { + "nn_noise_overlap": [">", 0.1, ["noise"]], + "missing": [">", 0.5, ["artifact"]], + "isi_violation": [">", 0.5, ["mua"]], + } + qm = { + "nn_noise_overlap": {"0": 0.2}, + "isi_violation": {"0": 0.8}, + } + result = get_labels({}, qm, params) + assert "noise" in result["0"] + assert "mua" in result["0"] + assert "artifact" not in result.get("0", []) + + +class TestGetLabelsComparisonOperators: + """All comparison operators work correctly.""" + + @pytest.mark.parametrize( + "op,threshold,value,should_label", + [ + (">", 0.5, 0.6, True), + (">", 0.5, 0.5, False), + (">=", 0.5, 0.5, True), + ("<", 0.5, 0.4, True), + ("<", 0.5, 0.5, False), + ("<=", 0.5, 0.5, True), + ("==", 0.5, 0.5, True), + ("==", 0.5, 0.6, False), + ], + ) + def test_operator(self, get_labels, op, threshold, value, should_label): + 
params = {"nn_noise_overlap": [op, threshold, ["reject"]]} + qm = {"nn_noise_overlap": {"0": value}} + result = get_labels({}, qm, params) + if should_label: + assert "0" in result and "reject" in result["0"] + else: + assert "0" not in result From ef26261269ef33ed0f7564e49cd1103d12324baa Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 28 Jan 2026 14:38:15 -0800 Subject: [PATCH 2/4] WIP: get_labels bugfix 2, add temp doc, normalize labels --- .../spikesorting/v0/spikesorting_curation.py | 47 +-- track-1281.py | 359 ++++++++++++++++++ 2 files changed, 384 insertions(+), 22 deletions(-) create mode 100644 track-1281.py diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 95466edd7..88ee1ef49 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -1099,6 +1099,16 @@ def get_labels(sorting, parent_labels, quality_metrics, label_params): return parent_labels # Unindent to fix #15XX + @staticmethod + def _normalize_labels(labels): + """Return labels dict with all keys cast to strings. + + Quality metrics loaded from JSON always have string keys, but + DataJoint blob serialization may store them as ints. Normalize + to strings so comparisons are consistent. + """ + return {str(k): v for k, v in labels.items()} + def fix_15XX(self, restriction=True, dry_run=True, verbose=True): """Find and repair entries affected by get_labels bugs. 
@@ -1208,13 +1218,17 @@ def _fix1_15XX(self, key, dry_run=True, verbose=True): parent_curation = (Curation & key).fetch(as_dict=True)[0] parent_labels = parent_curation["curation_labels"] - new_labels = self.get_labels( - sorting=None, - parent_labels=deepcopy(parent_labels), - quality_metrics=quality_metrics, - label_params=label_params, + new_labels = self._normalize_labels( + self.get_labels( + sorting=None, + parent_labels=deepcopy(parent_labels), + quality_metrics=quality_metrics, + label_params=label_params, + ) + ) + stored_labels = self._normalize_labels( + (Curation & auto_curation_key).fetch1("curation_labels") ) - stored_labels = (Curation & auto_curation_key).fetch1("curation_labels") if new_labels == stored_labels: return None @@ -1302,7 +1316,8 @@ def _fix_15XX_units(curation_key, new_labels, verbose=True): curation_key : dict Primary key to the Curation entry. new_labels : dict - Corrected labels dict {unit_id: [label, ...]}. + Corrected labels dict with string keys + ``{unit_id: [label, ...]}``. verbose : bool If True, log changes. 
""" @@ -1310,15 +1325,9 @@ def _fix_15XX_units(curation_key, new_labels, verbose=True): as_dict=True ) for row in unit_rows: - uid = row["unit_id"] - str_uid = str(uid) + uid = str(row["unit_id"]) old_label = row["label"] - if str_uid in new_labels: - new_label = ",".join(new_labels[str_uid]) - elif uid in new_labels: - new_label = ",".join(new_labels[uid]) - else: - new_label = "" + new_label = ",".join(new_labels.get(uid, [])) if old_label != new_label: if verbose: logger.info( @@ -1382,13 +1391,7 @@ def _fix_15XX_nwb(curation_key, new_labels, verbose=True): changed = False for idx, uid in enumerate(unit_ids): - str_uid = str(uid) - if str_uid in new_labels: - new_val = ",".join(new_labels[str_uid]) - elif uid in new_labels: - new_val = ",".join(new_labels[uid]) - else: - new_val = "" + new_val = ",".join(new_labels.get(str(uid), [])) old_val = label_col[idx] if old_val != new_val: diff --git a/track-1281.py b/track-1281.py new file mode 100644 index 000000000..a8277a4e3 --- /dev/null +++ b/track-1281.py @@ -0,0 +1,359 @@ +"""Temporary table to track impact of `AutomaticCuration.get_labels` + +File to be deleted before merge. +""" + +import json +from copy import deepcopy +from datetime import datetime + +import datajoint as dj + +from spyglass.spikesorting.v0.spikesorting_curation import ( + AutomaticCuration, + AutomaticCurationParameters, + CuratedSpikeSorting, + CuratedSpikeSortingSelection, + Curation, + QualityMetrics, +) + +schema = dj.schema("cbroz_bugs") + + +@schema +class Bug1281(dj.Computed): + definition = """ + -> AutomaticCuration + --- + is_impacted: bool # This key is impacted by bug 1281 + has_downstream: bool # A CuratedSpikeSorting entry depends on this + missing_metrics=null: blob # List of metrics not applied + """ + + _impact_date = datetime(2025, 4, 22) + + # -- Helpers ------------------------------------------------------ + + @staticmethod + def _normalize_labels(labels): + """Return labels dict with all keys cast to strings. 
+ + Quality metrics loaded from JSON always have string keys, but + DataJoint blob serialization may store them as ints. Normalize + to strings so comparisons are consistent. + + Parameters + ---------- + labels : dict + ``{unit_id: [label, ...]}`` with str or int keys. + + Returns + ------- + dict + Same structure with all keys as ``str``. + """ + return {str(k): v for k, v in labels.items()} + + @staticmethod + def _fetch_auto_curation_key(key): + """Return ``auto_curation_key`` blob for an AutomaticCuration key.""" + return (AutomaticCuration & key).fetch1("auto_curation_key") + + @staticmethod + def _fetch_label_params(key): + """Return ``label_params`` dict for an AutomaticCuration key.""" + params = (AutomaticCuration & key) * AutomaticCurationParameters + return params.fetch1("label_params") + + @staticmethod + def _fetch_quality_metrics(key): + """Load quality metrics JSON for a key, or None if missing.""" + metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path") + try: + with open(metrics_path) as f: + return json.load(f) + except FileNotFoundError: + return None + + @classmethod + def _compute_expected(cls, key, label_params, quality_metrics): + """Recompute labels using the fixed ``get_labels``. + + Parameters + ---------- + key : dict + Primary key to the parent Curation (same as AutomaticCuration + primary key minus the auto_curation_key). + label_params : dict + Label parameter rules. + quality_metrics : dict + Quality metrics loaded from JSON. + + Returns + ------- + expected_labels : dict + Recomputed labels with normalized (string) keys. 
+ """ + parent_curation = (Curation & key).fetch(as_dict=True)[0] + parent_labels = parent_curation["curation_labels"] + expected = AutomaticCuration.get_labels( + sorting=None, + parent_labels=deepcopy(parent_labels), + quality_metrics=quality_metrics, + label_params=label_params, + ) + return cls._normalize_labels(expected) + + @classmethod + def _compare_labels( + cls, stored_labels, expected_labels, label_params, quality_metrics + ): + """Compare stored vs expected labels, return diffs and metrics. + + Both dicts are normalized to string keys before comparison. + + Parameters + ---------- + stored_labels : dict + Labels fetched from the Curation row. + expected_labels : dict + Labels recomputed by the fixed ``get_labels``. + label_params : dict + Label parameter rules (metric -> [op, thresh, tags]). + quality_metrics : dict + Quality metrics loaded from JSON. + + Returns + ------- + diffs : dict + ``{unit_id: {"stored": [...], "expected": [...]}}`` for + units whose labels differ. + missing_metrics : list[str] + Metrics whose labels are absent or wrong in stored data. 
+ """ + stored = cls._normalize_labels(stored_labels) + expected = cls._normalize_labels(expected_labels) + + all_uids = set(stored.keys()) | set(expected.keys()) + + diffs = {} + for uid in all_uids: + s = stored.get(uid) + e = expected.get(uid) + if s != e: + diffs[uid] = {"stored": s, "expected": e} + + missing_metrics = [] + for metric in label_params: + if metric not in quality_metrics: + continue + for unit_id in quality_metrics[metric]: + uid = str(unit_id) + if uid in all_uids and stored.get(uid) != expected.get(uid): + missing_metrics.append(metric) + break + + return diffs, missing_metrics + + @staticmethod + def _has_downstream(auto_curation_key): + """Check if a CuratedSpikeSortingSelection entry exists.""" + return len(CuratedSpikeSortingSelection & auto_curation_key) > 0 + + # -- Core methods ------------------------------------------------- + + def _insert( + self, + key, + is_impacted=False, + has_downstream=False, + missing_metrics=None, + ): + self.insert1( + { + **key, + "is_impacted": is_impacted, + "has_downstream": has_downstream, + "missing_metrics": missing_metrics, + } + ) + + def make(self, key): + # --- Early return 1: Empty label_params --- + # No labels to compute, nothing to break. + label_params = self._fetch_label_params(key) + if not label_params: + self._insert(key) + return + + # --- Early return 2: Curation created before bug date --- + # All three bugs (indentation, aliasing, duplicate comparison) + # were introduced in PR #1281. Entries created before that + # date used different code and are not affected. 
+ auto_curation_key = self._fetch_auto_curation_key(key) + time_of_creation = (Curation & auto_curation_key).fetch1( + "time_of_creation" + ) + if time_of_creation < self._impact_date.timestamp(): + self._insert(key) + return + + # --- Early return 3: Missing quality metrics file --- + quality_metrics = self._fetch_quality_metrics(key) + if quality_metrics is None: + self._insert(key) + return + + # --- Full check: Recompute labels and compare --- + has_downstream = self._has_downstream(auto_curation_key) + expected_labels = self._compute_expected( + key, label_params, quality_metrics + ) + stored_labels = (Curation & auto_curation_key).fetch1("curation_labels") + + _, missing_metrics = self._compare_labels( + stored_labels, expected_labels, label_params, quality_metrics + ) + + is_impacted = len(missing_metrics) > 0 + self.insert1( + { + **key, + "is_impacted": is_impacted, + "has_downstream": has_downstream, + "missing_metrics": (missing_metrics if is_impacted else None), + } + ) + + def inspect(self, key): + """Print detailed diagnostics for one AutomaticCuration entry. + + Parameters + ---------- + key : dict + Primary key to AutomaticCuration (and thus Bug1281). 
+ """ + # --- Fetch Bug1281 row if populated --- + row = (self & key).fetch(as_dict=True) + if row: + row = row[0] + print("=== Bug1281 record ===") + print(f" is_impacted: {row['is_impacted']}") + print(f" has_downstream: {row['has_downstream']}") + print(f" missing_metrics: {row['missing_metrics']}") + else: + print("=== Bug1281 record: not yet populated ===") + return + + # --- label_params --- + label_params = self._fetch_label_params(key) + print(f"\n=== label_params ({len(label_params)} metric(s)) ===") + for metric, rule in label_params.items(): + print(f" {metric}: {rule[0]} {rule[1]} -> {rule[2]}") + + # --- auto_curation_key & Curation timestamps --- + auto_curation_key = self._fetch_auto_curation_key(key) + time_of_creation = (Curation & auto_curation_key).fetch1( + "time_of_creation" + ) + created = datetime.fromtimestamp(time_of_creation) + bug_date = self._impact_date + print("\n=== Curation created ===") + print(f" {created} (bug introduced {bug_date.date()})") + print( + f" after bug date: " f"{time_of_creation >= bug_date.timestamp()}" + ) + + # --- quality_metrics overlap --- + quality_metrics = self._fetch_quality_metrics(key) + if quality_metrics is None: + print("\n=== Quality metrics: FILE NOT FOUND ===") + return + + overlap = sorted(set(label_params.keys()) & set(quality_metrics.keys())) + missing_from_qm = sorted( + set(label_params.keys()) - set(quality_metrics.keys()) + ) + print("\n=== Metric overlap ===") + print(f" label_params metrics: {sorted(label_params.keys())}") + print(f" quality_metrics keys: {sorted(quality_metrics.keys())}") + print(f" overlap ({len(overlap)}): {overlap}") + if missing_from_qm: + print(f" skipped (not in qm): {missing_from_qm}") + + # --- Stored vs expected labels --- + expected_labels = self._compute_expected( + key, label_params, quality_metrics + ) + stored_labels = (Curation & auto_curation_key).fetch1("curation_labels") + + diffs, _ = self._compare_labels( + stored_labels, expected_labels, label_params, 
quality_metrics + ) + + stored_norm = self._normalize_labels(stored_labels) + expected_norm = self._normalize_labels(expected_labels) + + print("\n=== Label comparison ===") + print(f" total units (stored): {len(stored_norm)}") + print(f" total units (expected): {len(expected_norm)}") + print(f" units with differences: {len(diffs)}") + if diffs: + print(f"\n {'unit':>8} {'stored':<30} {'expected':<30}") + print(f" {'----':>8} {'------':<30} {'--------':<30}") + for uid in sorted( + diffs, + key=lambda x: int(x) if x.isdigit() else x, + ): + d = diffs[uid] + s_str = str(d["stored"]) if d["stored"] is not None else "---" + e_str = ( + str(d["expected"]) if d["expected"] is not None else "---" + ) + print(f" {uid:>8} {s_str:<30} {e_str:<30}") + + # --- Downstream impact --- + has_downstream = self._has_downstream(auto_curation_key) + print("\n=== Downstream ===") + print(f" CuratedSpikeSortingSelection entry: {has_downstream}") + if has_downstream: + n_units = len(CuratedSpikeSorting.Unit & auto_curation_key) + print(f" CuratedSpikeSorting.Unit rows: {n_units}") + + # Show which units would change accept/reject status + if diffs: + reject_changes = [] + for uid, d in diffs.items(): + was_reject = ( + d["stored"] is not None and "reject" in d["stored"] + ) + should_reject = ( + d["expected"] is not None and "reject" in d["expected"] + ) + if was_reject != should_reject: + reject_changes.append( + { + "unit": uid, + "was_rejected": was_reject, + "should_reject": should_reject, + } + ) + if reject_changes: + print( + f"\n Units with changed accept/reject " + f"status: {len(reject_changes)}" + ) + for rc in reject_changes: + status = ( + "SHOULD BE REJECTED (was accepted)" + if rc["should_reject"] + else "SHOULD BE ACCEPTED (was rejected)" + ) + print(f" unit {rc['unit']}: {status}") + else: + print( + "\n No units change accept/reject status " + "(label differences are non-reject labels " + "only)" + ) From e6b739fc97c734caf31748e5e9e1e77231426d52 Mon Sep 17 00:00:00 
2001 From: CBroz1 Date: Fri, 30 Jan 2026 09:14:29 -0800 Subject: [PATCH 3/4] Sep bug components --- .../spikesorting/v0/spikesorting_curation.py | 29 +- tests/spikesorting/v0/test_bug_1281.py | 59 +-- track-1281.py | 396 +++++++++++------- 3 files changed, 299 insertions(+), 185 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 88ee1ef49..9dfdba202 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -1089,25 +1089,28 @@ def get_labels(sorting, parent_labels, quality_metrics, label_params): label = label_params[metric] if compare(quality_metrics[metric][unit_id], label[1]): - if unit_id not in parent_labels: - parent_labels[unit_id] = label[2].copy() + if int(unit_id) not in parent_labels: + parent_labels[int(unit_id)] = label[2].copy() # check if the label is already there, and if not, add it else: + # remove 'accept' label if it exists + if "accept" in parent_labels[int(unit_id)]: + parent_labels[int(unit_id)].remove("accept") for element in label[2].copy(): - if element not in parent_labels[unit_id]: - parent_labels[unit_id].append(element) + if element not in parent_labels[int(unit_id)]: + parent_labels[int(unit_id)].append(element) - return parent_labels # Unindent to fix #15XX + return parent_labels @staticmethod def _normalize_labels(labels): - """Return labels dict with all keys cast to strings. + """Return labels dict with all keys cast to int. - Quality metrics loaded from JSON always have string keys, but - DataJoint blob serialization may store them as ints. Normalize - to strings so comparisons are consistent. + Quality metrics loaded from JSON have string keys, while + the fixed ``get_labels`` uses ``int(unit_id)``. Normalize + to int so comparisons are consistent regardless of source. 
""" - return {str(k): v for k, v in labels.items()} + return {int(k): v for k, v in labels.items()} def fix_15XX(self, restriction=True, dry_run=True, verbose=True): """Find and repair entries affected by get_labels bugs. @@ -1316,7 +1319,7 @@ def _fix_15XX_units(curation_key, new_labels, verbose=True): curation_key : dict Primary key to the Curation entry. new_labels : dict - Corrected labels dict with string keys + Corrected labels dict with int keys ``{unit_id: [label, ...]}``. verbose : bool If True, log changes. @@ -1325,7 +1328,7 @@ def _fix_15XX_units(curation_key, new_labels, verbose=True): as_dict=True ) for row in unit_rows: - uid = str(row["unit_id"]) + uid = int(row["unit_id"]) old_label = row["label"] new_label = ",".join(new_labels.get(uid, [])) if old_label != new_label: @@ -1391,7 +1394,7 @@ def _fix_15XX_nwb(curation_key, new_labels, verbose=True): changed = False for idx, uid in enumerate(unit_ids): - new_val = ",".join(new_labels.get(str(uid), [])) + new_val = ",".join(new_labels.get(int(uid), [])) old_val = label_col[idx] if old_val != new_val: diff --git a/tests/spikesorting/v0/test_bug_1281.py b/tests/spikesorting/v0/test_bug_1281.py index a936b75a0..550df9199 100644 --- a/tests/spikesorting/v0/test_bug_1281.py +++ b/tests/spikesorting/v0/test_bug_1281.py @@ -48,32 +48,32 @@ def test_both_metrics_applied(self, get_labels): result = get_labels({}, self.quality_metrics, self.two_metric_params) # unit 0: nn_noise_overlap 0.2 > 0.1 -> noise,reject # isi_violation 0.8 > 0.5 -> mua - assert "noise" in result["0"] - assert "reject" in result["0"] - assert "mua" in result["0"] + assert "noise" in result[0] + assert "reject" in result[0] + assert "mua" in result[0] def test_second_metric_only(self, get_labels): """Unit matching only second metric still gets its labels.""" result = get_labels({}, self.quality_metrics, self.two_metric_params) # unit 1: nn_noise_overlap 0.05 <= 0.1 -> no label # isi_violation 0.9 > 0.5 -> mua - assert "1" in result - 
assert "mua" in result["1"] - assert "noise" not in result.get("1", []) + assert 1 in result + assert "mua" in result[1] + assert "noise" not in result.get(1, []) def test_no_match_no_label(self, get_labels): """Unit matching neither metric gets no labels.""" result = get_labels({}, self.quality_metrics, self.two_metric_params) # unit 2: nn_noise_overlap 0.3 > 0.1 -> noise,reject # isi_violation 0.1 <= 0.5 -> no label - assert "noise" in result["2"] - assert "reject" in result["2"] - assert "mua" not in result["2"] + assert "noise" in result[2] + assert "reject" in result[2] + assert "mua" not in result[2] def test_all_units_covered(self, get_labels): """All units present in quality_metrics are evaluated.""" result = get_labels({}, self.quality_metrics, self.two_metric_params) - assert set(result.keys()) >= {"0", "1"} + assert set(result.keys()) >= {0, 1} class TestGetLabelsSingleMetric: @@ -88,21 +88,21 @@ class TestGetLabelsSingleMetric: def test_single_metric_labels(self, get_labels): result = get_labels({}, self.quality_metrics, self.single_metric_params) - assert "noise" in result["0"] - assert "reject" in result["0"] + assert "noise" in result[0] + assert "reject" in result[0] def test_single_metric_no_match(self, get_labels): result = get_labels({}, self.quality_metrics, self.single_metric_params) - assert "1" not in result + assert 1 not in result class TestGetLabelsEmpty: """Empty label_params returns parent_labels unchanged.""" def test_empty_params_returns_parent(self, get_labels): - parent = {"0": ["accept"]} + parent = {0: ["accept"]} result = get_labels(parent, {"metric": {"0": 1.0}}, {}) - assert result == {"0": ["accept"]} + assert result == {0: ["accept"]} def test_empty_params_empty_parent(self, get_labels): result = get_labels({}, {"metric": {"0": 1.0}}, {}) @@ -112,20 +112,21 @@ def test_empty_params_empty_parent(self, get_labels): class TestGetLabelsParentPreservation: """Existing parent labels are preserved and extended.""" - def 
test_parent_labels_preserved(self, get_labels): - parent = {"0": ["accept"]} + def test_accept_removed_when_other_labels_added(self, get_labels): + """Accept label is removed when metric labels are applied.""" + parent = {0: ["accept"]} params = {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]} qm = {"nn_noise_overlap": {"0": 0.2}} result = get_labels(parent, qm, params) - assert "accept" in result["0"] - assert "noise" in result["0"] + assert "accept" not in result[0] + assert "noise" in result[0] def test_no_duplicate_labels(self, get_labels): - parent = {"0": ["noise"]} + parent = {0: ["noise"]} params = {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]} qm = {"nn_noise_overlap": {"0": 0.2}} result = get_labels(parent, qm, params) - assert result["0"].count("noise") == 1 + assert result[0].count("noise") == 1 def test_parent_labels_extended_by_second_metric(self, get_labels): """Parent labels from first metric are extended by second.""" @@ -139,7 +140,7 @@ def test_parent_labels_extended_by_second_metric(self, get_labels): "isi_violation": {"0": 0.8}, } result = get_labels(parent, qm, params) - assert set(result["0"]) == {"noise", "reject", "mua"} + assert set(result[0]) == {"noise", "reject", "mua"} class TestGetLabelsMetricSkipping: @@ -152,8 +153,8 @@ def test_missing_metric_skipped(self, get_labels): } qm = {"nn_noise_overlap": {"0": 0.2}} result = get_labels({}, qm, params) - assert "noise" in result["0"] - assert "mua" not in result.get("0", []) + assert "noise" in result[0] + assert "mua" not in result.get(0, []) def test_only_overlapping_metrics_apply(self, get_labels): """With 3 metrics, only 2 present in qm are applied.""" @@ -167,9 +168,9 @@ def test_only_overlapping_metrics_apply(self, get_labels): "isi_violation": {"0": 0.8}, } result = get_labels({}, qm, params) - assert "noise" in result["0"] - assert "mua" in result["0"] - assert "artifact" not in result.get("0", []) + assert "noise" in result[0] + assert "mua" in result[0] + assert 
"artifact" not in result.get(0, []) class TestGetLabelsComparisonOperators: @@ -193,6 +194,6 @@ def test_operator(self, get_labels, op, threshold, value, should_label): qm = {"nn_noise_overlap": {"0": value}} result = get_labels({}, qm, params) if should_label: - assert "0" in result and "reject" in result["0"] + assert 0 in result and "reject" in result[0] else: - assert "0" not in result + assert 0 not in result diff --git a/track-1281.py b/track-1281.py index a8277a4e3..2c5a10553 100644 --- a/track-1281.py +++ b/track-1281.py @@ -16,6 +16,7 @@ CuratedSpikeSortingSelection, Curation, QualityMetrics, + _comparison_to_function, ) schema = dj.schema("cbroz_bugs") @@ -26,22 +27,23 @@ class Bug1281(dj.Computed): definition = """ -> AutomaticCuration --- - is_impacted: bool # This key is impacted by bug 1281 - has_downstream: bool # A CuratedSpikeSorting entry depends on this - missing_metrics=null: blob # List of metrics not applied + return_bug: bool # Bug A: early return after first metric + list_bug: bool # Bug B: list aliasing across units + dupe_bug: bool # Bug C: duplicate label comparison + has_downstream: bool # CuratedSpikeSorting depends on this """ - _impact_date = datetime(2025, 4, 22) + _return_bug_impact_date = datetime(2025, 4, 22) - # -- Helpers ------------------------------------------------------ + # -- Normalization ------------------------------------------------ @staticmethod def _normalize_labels(labels): - """Return labels dict with all keys cast to strings. + """Return labels dict with all keys cast to int. - Quality metrics loaded from JSON always have string keys, but - DataJoint blob serialization may store them as ints. Normalize - to strings so comparisons are consistent. + Quality metrics loaded from JSON have string keys, while + the fixed ``get_labels`` uses ``int(unit_id)``. Normalize + to int so comparisons are consistent regardless of source. 
Parameters ---------- @@ -51,24 +53,26 @@ def _normalize_labels(labels): Returns ------- dict - Same structure with all keys as ``str``. + Same structure with all keys as ``int``. """ - return {str(k): v for k, v in labels.items()} + return {int(k): v for k, v in labels.items()} + + # -- Fetch helpers ------------------------------------------------ @staticmethod def _fetch_auto_curation_key(key): - """Return ``auto_curation_key`` blob for an AutomaticCuration key.""" + """Return ``auto_curation_key`` blob.""" return (AutomaticCuration & key).fetch1("auto_curation_key") @staticmethod def _fetch_label_params(key): - """Return ``label_params`` dict for an AutomaticCuration key.""" + """Return ``label_params`` dict.""" params = (AutomaticCuration & key) * AutomaticCurationParameters return params.fetch1("label_params") @staticmethod def _fetch_quality_metrics(key): - """Load quality metrics JSON for a key, or None if missing.""" + """Load quality metrics JSON, or None if missing.""" metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path") try: with open(metrics_path) as f: @@ -76,15 +80,21 @@ def _fetch_quality_metrics(key): except FileNotFoundError: return None + @staticmethod + def _has_downstream(auto_curation_key): + """Check if a CuratedSpikeSortingSelection entry exists.""" + return len(CuratedSpikeSortingSelection & auto_curation_key) > 0 + + # -- Label computation -------------------------------------------- + @classmethod - def _compute_expected(cls, key, label_params, quality_metrics): - """Recompute labels using the fixed ``get_labels``. + def _compute_expected(cls, parent_labels, label_params, quality_metrics): + """Compute fully-fixed labels, normalized to int keys. Parameters ---------- - key : dict - Primary key to the parent Curation (same as AutomaticCuration - primary key minus the auto_curation_key). + parent_labels : dict + Labels from the parent Curation (int-normalized). label_params : dict Label parameter rules. 
quality_metrics : dict @@ -92,11 +102,9 @@ def _compute_expected(cls, key, label_params, quality_metrics): Returns ------- - expected_labels : dict - Recomputed labels with normalized (string) keys. + dict + Recomputed labels with int keys. """ - parent_curation = (Curation & key).fetch(as_dict=True)[0] - parent_labels = parent_curation["curation_labels"] expected = AutomaticCuration.get_labels( sorting=None, parent_labels=deepcopy(parent_labels), @@ -105,124 +113,247 @@ def _compute_expected(cls, key, label_params, quality_metrics): ) return cls._normalize_labels(expected) + @staticmethod + def _get_labels_buggy( + parent_labels, + quality_metrics, + label_params, + bug_a=False, + bug_b=False, + bug_c=False, + ): + """Run labeling logic with specified bugs enabled. + + Each flag re-introduces one historical bug while leaving the + rest of the logic fixed. Caller must ``deepcopy`` both + *parent_labels* and *label_params* before calling, because + Bug B mutates label_params through aliased references. + + All keys are normalized to ``int`` so that bug detection is + not confounded by key-type mismatches. + + Parameters + ---------- + parent_labels : dict + Starting labels (int keys). Will be mutated. + quality_metrics : dict + Quality metrics (string keys from JSON). + label_params : dict + Label parameter rules. + bug_a : bool + If True, return inside the ``for metric`` loop. + bug_b : bool + If True, skip ``.copy()`` on ``label[2]``. + bug_c : bool + If True, use list-in-list comparison + ``.extend()``. + + Returns + ------- + dict + Labels dict with int keys. 
+ """ + if not label_params: + return parent_labels + + for metric in label_params: + if metric not in quality_metrics: + continue + + compare = _comparison_to_function[label_params[metric][0]] + + for unit_id in quality_metrics[metric]: + label = label_params[metric] + uid = int(unit_id) + + if not compare(quality_metrics[metric][unit_id], label[1]): + continue + + if uid not in parent_labels: + if bug_b: + parent_labels[uid] = label[2] + else: + parent_labels[uid] = label[2].copy() + else: + if bug_c: + if label[2] not in parent_labels[uid]: + parent_labels[uid].extend(label[2]) + else: + if "accept" in parent_labels[uid]: + parent_labels[uid].remove("accept") + for element in label[2].copy(): + if element not in parent_labels[uid]: + parent_labels[uid].append(element) + + if bug_a: + return parent_labels + + return parent_labels + + # -- Per-bug detection -------------------------------------------- + @classmethod - def _compare_labels( - cls, stored_labels, expected_labels, label_params, quality_metrics + def _detect_return_bug( + cls, + auto_curation_key, + parent_labels, + label_params, + quality_metrics, + expected, ): - """Compare stored vs expected labels, return diffs and metrics. + """Bug A: would early return produce different labels? - Both dicts are normalized to string keys before comparison. + Only possible when >1 metric overlaps with quality_metrics + AND the Curation was created on or after the date Bug A was + introduced (PR #1281, 2025-04-22). 
+ """ + time_of_creation = (Curation & auto_curation_key).fetch1( + "time_of_creation" + ) + if time_of_creation < cls._return_bug_impact_date.timestamp(): + return False + overlap = set(label_params) & set(quality_metrics) + if len(overlap) <= 1: + return False + buggy = cls._get_labels_buggy( + deepcopy(parent_labels), + quality_metrics, + deepcopy(label_params), + bug_a=True, + ) + return cls._normalize_labels(buggy) != expected + + @classmethod + def _detect_list_bug( + cls, parent_labels, label_params, quality_metrics, expected + ): + """Bug B: would list aliasing cause cross-unit leakage? + + Aliasing manifests when multiple units match the same + metric (sharing a list object) and at least one of those + units also matches a subsequent metric, causing the + append to propagate to all aliased units. + """ + buggy = cls._get_labels_buggy( + deepcopy(parent_labels), + quality_metrics, + deepcopy(label_params), + bug_b=True, + ) + return cls._normalize_labels(buggy) != expected + + @classmethod + def _detect_dupe_bug( + cls, parent_labels, label_params, quality_metrics, expected + ): + """Bug C: would list-in-list comparison create duplicates? + + The old code checked ``label[2] not in parent_labels[uid]`` + (always True for flat string lists), so ``.extend()`` + always ran, producing duplicate labels. + """ + buggy = cls._get_labels_buggy( + deepcopy(parent_labels), + quality_metrics, + deepcopy(label_params), + bug_c=True, + ) + return cls._normalize_labels(buggy) != expected + + # -- Comparison helper -------------------------------------------- + + @classmethod + def _compare_labels(cls, stored_labels, expected_labels): + """Return per-unit diffs between stored and expected. + + Both dicts are normalized to int keys before comparison. Parameters ---------- stored_labels : dict - Labels fetched from the Curation row. + Labels from the Curation row. expected_labels : dict - Labels recomputed by the fixed ``get_labels``. 
- label_params : dict - Label parameter rules (metric -> [op, thresh, tags]). - quality_metrics : dict - Quality metrics loaded from JSON. + Labels from the fixed ``get_labels``. Returns ------- - diffs : dict - ``{unit_id: {"stored": [...], "expected": [...]}}`` for + dict + ``{uid: {"stored": [...], "expected": [...]}}`` for units whose labels differ. - missing_metrics : list[str] - Metrics whose labels are absent or wrong in stored data. """ stored = cls._normalize_labels(stored_labels) expected = cls._normalize_labels(expected_labels) - - all_uids = set(stored.keys()) | set(expected.keys()) - + all_uids = set(stored) | set(expected) diffs = {} for uid in all_uids: s = stored.get(uid) e = expected.get(uid) if s != e: diffs[uid] = {"stored": s, "expected": e} - - missing_metrics = [] - for metric in label_params: - if metric not in quality_metrics: - continue - for unit_id in quality_metrics[metric]: - uid = str(unit_id) - if uid in all_uids and stored.get(uid) != expected.get(uid): - missing_metrics.append(metric) - break - - return diffs, missing_metrics - - @staticmethod - def _has_downstream(auto_curation_key): - """Check if a CuratedSpikeSortingSelection entry exists.""" - return len(CuratedSpikeSortingSelection & auto_curation_key) > 0 + return diffs # -- Core methods ------------------------------------------------- - def _insert( - self, - key, - is_impacted=False, - has_downstream=False, - missing_metrics=None, - ): + def _insert_clean(self, key, has_downstream=False): + """Insert an unaffected entry.""" self.insert1( { **key, - "is_impacted": is_impacted, + "return_bug": False, + "list_bug": False, + "dupe_bug": False, "has_downstream": has_downstream, - "missing_metrics": missing_metrics, } ) def make(self, key): - # --- Early return 1: Empty label_params --- - # No labels to compute, nothing to break. 
+ # --- Early return: empty label_params --- label_params = self._fetch_label_params(key) if not label_params: - self._insert(key) + self._insert_clean(key) return - # --- Early return 2: Curation created before bug date --- - # All three bugs (indentation, aliasing, duplicate comparison) - # were introduced in PR #1281. Entries created before that - # date used different code and are not affected. auto_curation_key = self._fetch_auto_curation_key(key) - time_of_creation = (Curation & auto_curation_key).fetch1( - "time_of_creation" - ) - if time_of_creation < self._impact_date.timestamp(): - self._insert(key) - return - # --- Early return 3: Missing quality metrics file --- + # --- Early return: missing quality metrics file --- quality_metrics = self._fetch_quality_metrics(key) if quality_metrics is None: - self._insert(key) + self._insert_clean(key) return - # --- Full check: Recompute labels and compare --- - has_downstream = self._has_downstream(auto_curation_key) - expected_labels = self._compute_expected( - key, label_params, quality_metrics + # --- Normalize parent labels to int keys --- + parent_curation = (Curation & key).fetch(as_dict=True)[0] + parent_labels = self._normalize_labels( + parent_curation["curation_labels"] ) - stored_labels = (Curation & auto_curation_key).fetch1("curation_labels") - _, missing_metrics = self._compare_labels( - stored_labels, expected_labels, label_params, quality_metrics + # --- Compute fully-fixed expected labels --- + expected = self._compute_expected( + parent_labels, label_params, quality_metrics ) - is_impacted = len(missing_metrics) > 0 + # --- Detect each bug independently --- + return_bug = self._detect_return_bug( + auto_curation_key, + parent_labels, + label_params, + quality_metrics, + expected, + ) + list_bug = self._detect_list_bug( + parent_labels, label_params, quality_metrics, expected + ) + dupe_bug = self._detect_dupe_bug( + parent_labels, label_params, quality_metrics, expected + ) + has_downstream = 
self._has_downstream(auto_curation_key) + self.insert1( { **key, - "is_impacted": is_impacted, + "return_bug": return_bug, + "list_bug": list_bug, + "dupe_bug": dupe_bug, "has_downstream": has_downstream, - "missing_metrics": (missing_metrics if is_impacted else None), } ) @@ -234,78 +365,61 @@ def inspect(self, key): key : dict Primary key to AutomaticCuration (and thus Bug1281). """ - # --- Fetch Bug1281 row if populated --- + # --- Bug1281 record --- row = (self & key).fetch(as_dict=True) if row: row = row[0] print("=== Bug1281 record ===") - print(f" is_impacted: {row['is_impacted']}") + print(f" return_bug (A): {row['return_bug']}") + print(f" list_bug (B): {row['list_bug']}") + print(f" dupe_bug (C): {row['dupe_bug']}") print(f" has_downstream: {row['has_downstream']}") - print(f" missing_metrics: {row['missing_metrics']}") else: print("=== Bug1281 record: not yet populated ===") return # --- label_params --- label_params = self._fetch_label_params(key) - print(f"\n=== label_params ({len(label_params)} metric(s)) ===") + print(f"\n=== label_params " f"({len(label_params)} metric(s)) ===") for metric, rule in label_params.items(): print(f" {metric}: {rule[0]} {rule[1]} -> {rule[2]}") - # --- auto_curation_key & Curation timestamps --- - auto_curation_key = self._fetch_auto_curation_key(key) - time_of_creation = (Curation & auto_curation_key).fetch1( - "time_of_creation" - ) - created = datetime.fromtimestamp(time_of_creation) - bug_date = self._impact_date - print("\n=== Curation created ===") - print(f" {created} (bug introduced {bug_date.date()})") - print( - f" after bug date: " f"{time_of_creation >= bug_date.timestamp()}" - ) - # --- quality_metrics overlap --- quality_metrics = self._fetch_quality_metrics(key) if quality_metrics is None: print("\n=== Quality metrics: FILE NOT FOUND ===") return - overlap = sorted(set(label_params.keys()) & set(quality_metrics.keys())) - missing_from_qm = sorted( - set(label_params.keys()) - set(quality_metrics.keys()) - 
) + overlap = sorted(set(label_params) & set(quality_metrics)) + missing_from_qm = sorted(set(label_params) - set(quality_metrics)) print("\n=== Metric overlap ===") - print(f" label_params metrics: {sorted(label_params.keys())}") - print(f" quality_metrics keys: {sorted(quality_metrics.keys())}") - print(f" overlap ({len(overlap)}): {overlap}") + print(f" label_params: {sorted(label_params.keys())}") + print(f" quality_metrics: " f"{sorted(quality_metrics.keys())}") + print(f" overlap ({len(overlap)}): {overlap}") if missing_from_qm: - print(f" skipped (not in qm): {missing_from_qm}") + print(f" skipped (not in qm): {missing_from_qm}") # --- Stored vs expected labels --- - expected_labels = self._compute_expected( - key, label_params, quality_metrics + parent_curation = (Curation & key).fetch(as_dict=True)[0] + parent_labels = self._normalize_labels( + parent_curation["curation_labels"] ) - stored_labels = (Curation & auto_curation_key).fetch1("curation_labels") - - diffs, _ = self._compare_labels( - stored_labels, expected_labels, label_params, quality_metrics + expected = self._compute_expected( + parent_labels, label_params, quality_metrics ) + auto_curation_key = self._fetch_auto_curation_key(key) + stored_labels = (Curation & auto_curation_key).fetch1("curation_labels") + diffs = self._compare_labels(stored_labels, expected) stored_norm = self._normalize_labels(stored_labels) - expected_norm = self._normalize_labels(expected_labels) - print("\n=== Label comparison ===") print(f" total units (stored): {len(stored_norm)}") - print(f" total units (expected): {len(expected_norm)}") + print(f" total units (expected): {len(expected)}") print(f" units with differences: {len(diffs)}") if diffs: - print(f"\n {'unit':>8} {'stored':<30} {'expected':<30}") - print(f" {'----':>8} {'------':<30} {'--------':<30}") - for uid in sorted( - diffs, - key=lambda x: int(x) if x.isdigit() else x, - ): + print(f"\n {'unit':>8} " f"{'stored':<30} {'expected':<30}") + print(f" 
{'----':>8} " f"{'------':<30} {'--------':<30}") + for uid in sorted(diffs): d = diffs[uid] s_str = str(d["stored"]) if d["stored"] is not None else "---" e_str = ( @@ -316,12 +430,11 @@ def inspect(self, key): # --- Downstream impact --- has_downstream = self._has_downstream(auto_curation_key) print("\n=== Downstream ===") - print(f" CuratedSpikeSortingSelection entry: {has_downstream}") + print(f" CuratedSpikeSortingSelection: {has_downstream}") if has_downstream: n_units = len(CuratedSpikeSorting.Unit & auto_curation_key) - print(f" CuratedSpikeSorting.Unit rows: {n_units}") + print(f" CuratedSpikeSorting.Unit rows: {n_units}") - # Show which units would change accept/reject status if diffs: reject_changes = [] for uid, d in diffs.items(): @@ -341,19 +454,16 @@ def inspect(self, key): ) if reject_changes: print( - f"\n Units with changed accept/reject " - f"status: {len(reject_changes)}" + f"\n Units with changed " + f"accept/reject status: " + f"{len(reject_changes)}" ) for rc in reject_changes: status = ( - "SHOULD BE REJECTED (was accepted)" + "SHOULD BE REJECTED " "(was accepted)" if rc["should_reject"] - else "SHOULD BE ACCEPTED (was rejected)" + else "SHOULD BE ACCEPTED " "(was rejected)" ) print(f" unit {rc['unit']}: {status}") else: - print( - "\n No units change accept/reject status " - "(label differences are non-reject labels " - "only)" - ) + print("\n No units change accept/reject " "status") From e8a405c0b8f28fd52bf98a7b5f6e6ab6063fe80e Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 17:26:12 +0100 Subject: [PATCH 4/4] Revise number --- CHANGELOG.md | 28 +- .../spikesorting/v0/spikesorting_curation.py | 57 +-- .../v0/{test_bug_1281.py => test_bug_1513.py} | 0 track-1281.py | 469 ------------------ 4 files changed, 31 insertions(+), 523 deletions(-) rename tests/spikesorting/v0/{test_bug_1281.py => test_bug_1513.py} (100%) delete mode 100644 track-1281.py diff --git a/CHANGELOG.md b/CHANGELOG.md index dd75ed699..b5c6061bb 100644 --- 
a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,12 +30,13 @@ LFPBandV1().fix_1481() #### AutomaticCuration Fix If you were using `v0.AutomaticCuration` after April 2025, you may have stored -inaccurate labels due to #14XX. To fix these, please run the following after updating: +inaccurate labels due to #1513. To fix these, please run the following after +updating: ```python from spyglass.spikesorting.v0 import AutomaticCuration -AutomaticCuration().fix_15XX() +AutomaticCuration().fix_1513() ``` ### Breaking Changes @@ -55,17 +56,19 @@ memory usage significantly. ```python # OLD (before v0.5.6): results.isel(intervals=0) # Get first interval -for i in range(results.sizes['intervals']): # Iterate intervals +for i in range(results.sizes["intervals"]): # Iterate intervals interval_data = results.isel(intervals=i) # NEW (v0.5.6+): results.where(results.interval_labels == 0, drop=True) # Get first interval for label in np.unique(results.interval_labels.values): # Iterate intervals - if label >= 0: # Skip -1 (outside intervals, only with estimate_decoding_params=True) + if ( + label >= 0 + ): # Skip -1 (outside intervals, only with estimate_decoding_params=True) interval_data = results.where(results.interval_labels == label, drop=True) # Or use groupby: -for label, interval_data in results.groupby('interval_labels'): +for label, interval_data in results.groupby("interval_labels"): if label >= 0: # process interval_data pass @@ -75,7 +78,7 @@ for label, interval_data in results.groupby('interval_labels'): - `0, 1, 2, ...` - Sequential interval indices (0-indexed) - `-1` - Time points outside any decoding interval (only when - `estimate_decoding_params=True`) + `estimate_decoding_params=True`) ### Documentation @@ -87,8 +90,8 @@ for label, interval_data in results.groupby('interval_labels'): ### Infrastructure -- Add cross-platform installer script with Docker support, input validation, - and automated environment setup #1414 +- Add cross-platform installer script with Docker 
support, input validation, and + automated environment setup #1414 - Set default codecov threshold for test fail, disable patch check #1370, #1372 - Simplify PR template #1370 - Allow email send on space check success, clean up maintenance logging #1381 @@ -150,9 +153,10 @@ for label, interval_data in results.groupby('interval_labels'): merge IDs are fetched in non-chronological order #1471 - Separate `ClusterlessDecodingV1` to tri-part `make` #1467 - **BREAKING**: Remove `intervals` dimension from decoding results. Results - from multiple intervals are now concatenated along the `time` dimension - with an `interval_labels` coordinate to track interval membership. This - eliminates NaN padding and reduces memory usage. See migration guide above. + from multiple intervals are now concatenated along the `time` dimension + with an `interval_labels` coordinate to track interval membership. This + eliminates NaN padding and reduces memory usage. See migration guide + above. - LFP @@ -169,7 +173,7 @@ for label, interval_data in results.groupby('interval_labels'): - Implement short-transaction `SpikeSortingRecording.make` for v0 #1338 - Fix `FigURLCuration.make`. Postpone fetch of unhashable items #1505 - - Implement fix for `AutomaticCuration` incorrect labels #15XY + - Implement fix for `AutomaticCuration` incorrect labels #1513 ## [0.5.5] (Aug 6, 2025) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 9dfdba202..ea96dae94 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -1112,39 +1112,10 @@ def _normalize_labels(labels): """ return {int(k): v for k, v in labels.items()} - def fix_15XX(self, restriction=True, dry_run=True, verbose=True): + def fix_1513(self, restriction=True, dry_run=True, verbose=True): """Find and repair entries affected by get_labels bugs. 
- PR #1281 (2025-04-22) introduced three bugs in `get_labels`: - - A. **Early return**: `return parent_labels` was indented inside - the `for metric` loop, so only the first metric was - processed. Affects entries with >1 metric in label_params. - B. **List aliasing**: `parent_labels[unit_id] = label[2]` - assigned the label list without copying, so units sharing - a metric shared the same list object. Mutations on one - unit could corrupt others. - C. **Duplicate comparison**: `label[2] not in parent_labels` - compared a list against a list of strings. This always - evaluated True for flat lists, so `.extend()` ran - unconditionally, creating duplicate labels. - - Bugs B and C can affect single-metric entries when - `parent_labels` has pre-existing entries from a prior - curation step. - - For each impacted entry, this method: - 1. Recomputes labels using the fixed `get_labels` logic. - 2. Updates `Curation.curation_labels` with corrected labels. - 3. Updates `CuratedSpikeSorting.Unit.label` on existing rows. - - Steps 2-3 are wrapped in a transaction per entry to prevent - partial updates on interruption. - - Note: if a unit's accept/reject status changed, the NWB - analysis file still contains the old unit set. A full - repopulation of `CuratedSpikeSorting` is needed in those - cases. Such entries are flagged in `reject_status_changed`. + For more information, see issue #1513. Parameters ---------- @@ -1163,29 +1134,31 @@ def fix_15XX(self, restriction=True, dry_run=True, verbose=True): old_labels, new_labels, changed, has_downstream, reject_status_changed. 
""" + raise NotImplementedError("fix_1513 is not yet fully implemented") + restr = (self & restriction) if restriction else self if verbose: - logger.info(f"fix_15XX: scanning {len(restr)} entries") + logger.info(f"fix_1513: scanning {len(restr)} entries") results = [] for key in restr: - result = self._fix1_15XX(key, dry_run, verbose) + result = self._fix1_1513(key, dry_run, verbose) if result is not None: results.append(result) if verbose: logger.info( - f"fix_15XX: {len(results)} impacted entries" + f"fix_1513: {len(results)} impacted entries" + (" (dry run)" if dry_run else " (applied)") ) return results - def _fix1_15XX(self, key, dry_run=True, verbose=True): + def _fix1_1513(self, key, dry_run=True, verbose=True): """Detect and repair a single AutomaticCuration entry. Returns a result dict if the entry is impacted, None otherwise. - See `fix_15XX` for full documentation. + See `fix_1513` for full documentation. """ from copy import deepcopy from datetime import datetime @@ -1212,7 +1185,7 @@ def _fix1_15XX(self, key, dry_run=True, verbose=True): except FileNotFoundError: if verbose: logger.warning( - f"fix_15XX: metrics file not found: " + f"fix_1513: metrics file not found: " f"{metrics_path}; skipping {key}" ) return None @@ -1279,7 +1252,7 @@ def _fix1_15XX(self, key, dry_run=True, verbose=True): if stored_labels.get(u) != new_labels.get(u) ) logger.info( - f"fix_15XX: {auto_curation_key} — " + f"fix_1513: {auto_curation_key} — " f"{n_diff} unit(s) with label changes" + ( f", {len(reject_changed)} reject " f"status change(s)" @@ -1302,12 +1275,12 @@ def _fix1_15XX(self, key, dry_run=True, verbose=True): } ) if has_downstream: - self._fix_15XX_units(auto_curation_key, new_labels, verbose) + self._fix_1513_units(auto_curation_key, new_labels, verbose) return result @staticmethod - def _fix_15XX_units(curation_key, new_labels, verbose=True): + def _fix_1513_units(curation_key, new_labels, verbose=True): """Update CuratedSpikeSorting.Unit labels for a 
curation. Updates the label column on existing Unit rows. Does NOT @@ -1347,7 +1320,7 @@ def _fix_15XX_units(curation_key, new_labels, verbose=True): ) @staticmethod - def _fix_15XX_nwb(curation_key, new_labels, verbose=True): + def _fix_1513_nwb(curation_key, new_labels, verbose=True): """Update labels in the CuratedSpikeSorting NWB analysis file. Edits the ``label`` column in the NWB units table in place @@ -1549,7 +1522,7 @@ def make_compute( recording = si.load_extractor(recording_path) timestamps = SpikeSortingRecording._get_recording_timestamps(recording) - (analysis_file_name, units_object_id) = Curation().save_sorting_nwb( + analysis_file_name, units_object_id = Curation().save_sorting_nwb( key=key, sorting=sorting, timestamps=timestamps, diff --git a/tests/spikesorting/v0/test_bug_1281.py b/tests/spikesorting/v0/test_bug_1513.py similarity index 100% rename from tests/spikesorting/v0/test_bug_1281.py rename to tests/spikesorting/v0/test_bug_1513.py diff --git a/track-1281.py b/track-1281.py deleted file mode 100644 index 2c5a10553..000000000 --- a/track-1281.py +++ /dev/null @@ -1,469 +0,0 @@ -"""Temporary table to track impact of `AutomaticCuration.get_labels` - -File to be deleted before merge. 
-""" - -import json -from copy import deepcopy -from datetime import datetime - -import datajoint as dj - -from spyglass.spikesorting.v0.spikesorting_curation import ( - AutomaticCuration, - AutomaticCurationParameters, - CuratedSpikeSorting, - CuratedSpikeSortingSelection, - Curation, - QualityMetrics, - _comparison_to_function, -) - -schema = dj.schema("cbroz_bugs") - - -@schema -class Bug1281(dj.Computed): - definition = """ - -> AutomaticCuration - --- - return_bug: bool # Bug A: early return after first metric - list_bug: bool # Bug B: list aliasing across units - dupe_bug: bool # Bug C: duplicate label comparison - has_downstream: bool # CuratedSpikeSorting depends on this - """ - - _return_bug_impact_date = datetime(2025, 4, 22) - - # -- Normalization ------------------------------------------------ - - @staticmethod - def _normalize_labels(labels): - """Return labels dict with all keys cast to int. - - Quality metrics loaded from JSON have string keys, while - the fixed ``get_labels`` uses ``int(unit_id)``. Normalize - to int so comparisons are consistent regardless of source. - - Parameters - ---------- - labels : dict - ``{unit_id: [label, ...]}`` with str or int keys. - - Returns - ------- - dict - Same structure with all keys as ``int``. 
- """ - return {int(k): v for k, v in labels.items()} - - # -- Fetch helpers ------------------------------------------------ - - @staticmethod - def _fetch_auto_curation_key(key): - """Return ``auto_curation_key`` blob.""" - return (AutomaticCuration & key).fetch1("auto_curation_key") - - @staticmethod - def _fetch_label_params(key): - """Return ``label_params`` dict.""" - params = (AutomaticCuration & key) * AutomaticCurationParameters - return params.fetch1("label_params") - - @staticmethod - def _fetch_quality_metrics(key): - """Load quality metrics JSON, or None if missing.""" - metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path") - try: - with open(metrics_path) as f: - return json.load(f) - except FileNotFoundError: - return None - - @staticmethod - def _has_downstream(auto_curation_key): - """Check if a CuratedSpikeSortingSelection entry exists.""" - return len(CuratedSpikeSortingSelection & auto_curation_key) > 0 - - # -- Label computation -------------------------------------------- - - @classmethod - def _compute_expected(cls, parent_labels, label_params, quality_metrics): - """Compute fully-fixed labels, normalized to int keys. - - Parameters - ---------- - parent_labels : dict - Labels from the parent Curation (int-normalized). - label_params : dict - Label parameter rules. - quality_metrics : dict - Quality metrics loaded from JSON. - - Returns - ------- - dict - Recomputed labels with int keys. - """ - expected = AutomaticCuration.get_labels( - sorting=None, - parent_labels=deepcopy(parent_labels), - quality_metrics=quality_metrics, - label_params=label_params, - ) - return cls._normalize_labels(expected) - - @staticmethod - def _get_labels_buggy( - parent_labels, - quality_metrics, - label_params, - bug_a=False, - bug_b=False, - bug_c=False, - ): - """Run labeling logic with specified bugs enabled. - - Each flag re-introduces one historical bug while leaving the - rest of the logic fixed. 
Caller must ``deepcopy`` both - *parent_labels* and *label_params* before calling, because - Bug B mutates label_params through aliased references. - - All keys are normalized to ``int`` so that bug detection is - not confounded by key-type mismatches. - - Parameters - ---------- - parent_labels : dict - Starting labels (int keys). Will be mutated. - quality_metrics : dict - Quality metrics (string keys from JSON). - label_params : dict - Label parameter rules. - bug_a : bool - If True, return inside the ``for metric`` loop. - bug_b : bool - If True, skip ``.copy()`` on ``label[2]``. - bug_c : bool - If True, use list-in-list comparison + ``.extend()``. - - Returns - ------- - dict - Labels dict with int keys. - """ - if not label_params: - return parent_labels - - for metric in label_params: - if metric not in quality_metrics: - continue - - compare = _comparison_to_function[label_params[metric][0]] - - for unit_id in quality_metrics[metric]: - label = label_params[metric] - uid = int(unit_id) - - if not compare(quality_metrics[metric][unit_id], label[1]): - continue - - if uid not in parent_labels: - if bug_b: - parent_labels[uid] = label[2] - else: - parent_labels[uid] = label[2].copy() - else: - if bug_c: - if label[2] not in parent_labels[uid]: - parent_labels[uid].extend(label[2]) - else: - if "accept" in parent_labels[uid]: - parent_labels[uid].remove("accept") - for element in label[2].copy(): - if element not in parent_labels[uid]: - parent_labels[uid].append(element) - - if bug_a: - return parent_labels - - return parent_labels - - # -- Per-bug detection -------------------------------------------- - - @classmethod - def _detect_return_bug( - cls, - auto_curation_key, - parent_labels, - label_params, - quality_metrics, - expected, - ): - """Bug A: would early return produce different labels? - - Only possible when >1 metric overlaps with quality_metrics - AND the Curation was created on or after the date Bug A was - introduced (PR #1281, 2025-04-22). 
- """ - time_of_creation = (Curation & auto_curation_key).fetch1( - "time_of_creation" - ) - if time_of_creation < cls._return_bug_impact_date.timestamp(): - return False - overlap = set(label_params) & set(quality_metrics) - if len(overlap) <= 1: - return False - buggy = cls._get_labels_buggy( - deepcopy(parent_labels), - quality_metrics, - deepcopy(label_params), - bug_a=True, - ) - return cls._normalize_labels(buggy) != expected - - @classmethod - def _detect_list_bug( - cls, parent_labels, label_params, quality_metrics, expected - ): - """Bug B: would list aliasing cause cross-unit leakage? - - Aliasing manifests when multiple units match the same - metric (sharing a list object) and at least one of those - units also matches a subsequent metric, causing the - append to propagate to all aliased units. - """ - buggy = cls._get_labels_buggy( - deepcopy(parent_labels), - quality_metrics, - deepcopy(label_params), - bug_b=True, - ) - return cls._normalize_labels(buggy) != expected - - @classmethod - def _detect_dupe_bug( - cls, parent_labels, label_params, quality_metrics, expected - ): - """Bug C: would list-in-list comparison create duplicates? - - The old code checked ``label[2] not in parent_labels[uid]`` - (always True for flat string lists), so ``.extend()`` - always ran, producing duplicate labels. - """ - buggy = cls._get_labels_buggy( - deepcopy(parent_labels), - quality_metrics, - deepcopy(label_params), - bug_c=True, - ) - return cls._normalize_labels(buggy) != expected - - # -- Comparison helper -------------------------------------------- - - @classmethod - def _compare_labels(cls, stored_labels, expected_labels): - """Return per-unit diffs between stored and expected. - - Both dicts are normalized to int keys before comparison. - - Parameters - ---------- - stored_labels : dict - Labels from the Curation row. - expected_labels : dict - Labels from the fixed ``get_labels``. 
- - Returns - ------- - dict - ``{uid: {"stored": [...], "expected": [...]}}`` for - units whose labels differ. - """ - stored = cls._normalize_labels(stored_labels) - expected = cls._normalize_labels(expected_labels) - all_uids = set(stored) | set(expected) - diffs = {} - for uid in all_uids: - s = stored.get(uid) - e = expected.get(uid) - if s != e: - diffs[uid] = {"stored": s, "expected": e} - return diffs - - # -- Core methods ------------------------------------------------- - - def _insert_clean(self, key, has_downstream=False): - """Insert an unaffected entry.""" - self.insert1( - { - **key, - "return_bug": False, - "list_bug": False, - "dupe_bug": False, - "has_downstream": has_downstream, - } - ) - - def make(self, key): - # --- Early return: empty label_params --- - label_params = self._fetch_label_params(key) - if not label_params: - self._insert_clean(key) - return - - auto_curation_key = self._fetch_auto_curation_key(key) - - # --- Early return: missing quality metrics file --- - quality_metrics = self._fetch_quality_metrics(key) - if quality_metrics is None: - self._insert_clean(key) - return - - # --- Normalize parent labels to int keys --- - parent_curation = (Curation & key).fetch(as_dict=True)[0] - parent_labels = self._normalize_labels( - parent_curation["curation_labels"] - ) - - # --- Compute fully-fixed expected labels --- - expected = self._compute_expected( - parent_labels, label_params, quality_metrics - ) - - # --- Detect each bug independently --- - return_bug = self._detect_return_bug( - auto_curation_key, - parent_labels, - label_params, - quality_metrics, - expected, - ) - list_bug = self._detect_list_bug( - parent_labels, label_params, quality_metrics, expected - ) - dupe_bug = self._detect_dupe_bug( - parent_labels, label_params, quality_metrics, expected - ) - has_downstream = self._has_downstream(auto_curation_key) - - self.insert1( - { - **key, - "return_bug": return_bug, - "list_bug": list_bug, - "dupe_bug": dupe_bug, - 
"has_downstream": has_downstream, - } - ) - - def inspect(self, key): - """Print detailed diagnostics for one AutomaticCuration entry. - - Parameters - ---------- - key : dict - Primary key to AutomaticCuration (and thus Bug1281). - """ - # --- Bug1281 record --- - row = (self & key).fetch(as_dict=True) - if row: - row = row[0] - print("=== Bug1281 record ===") - print(f" return_bug (A): {row['return_bug']}") - print(f" list_bug (B): {row['list_bug']}") - print(f" dupe_bug (C): {row['dupe_bug']}") - print(f" has_downstream: {row['has_downstream']}") - else: - print("=== Bug1281 record: not yet populated ===") - return - - # --- label_params --- - label_params = self._fetch_label_params(key) - print(f"\n=== label_params " f"({len(label_params)} metric(s)) ===") - for metric, rule in label_params.items(): - print(f" {metric}: {rule[0]} {rule[1]} -> {rule[2]}") - - # --- quality_metrics overlap --- - quality_metrics = self._fetch_quality_metrics(key) - if quality_metrics is None: - print("\n=== Quality metrics: FILE NOT FOUND ===") - return - - overlap = sorted(set(label_params) & set(quality_metrics)) - missing_from_qm = sorted(set(label_params) - set(quality_metrics)) - print("\n=== Metric overlap ===") - print(f" label_params: {sorted(label_params.keys())}") - print(f" quality_metrics: " f"{sorted(quality_metrics.keys())}") - print(f" overlap ({len(overlap)}): {overlap}") - if missing_from_qm: - print(f" skipped (not in qm): {missing_from_qm}") - - # --- Stored vs expected labels --- - parent_curation = (Curation & key).fetch(as_dict=True)[0] - parent_labels = self._normalize_labels( - parent_curation["curation_labels"] - ) - expected = self._compute_expected( - parent_labels, label_params, quality_metrics - ) - auto_curation_key = self._fetch_auto_curation_key(key) - stored_labels = (Curation & auto_curation_key).fetch1("curation_labels") - diffs = self._compare_labels(stored_labels, expected) - - stored_norm = self._normalize_labels(stored_labels) - print("\n=== 
Label comparison ===") - print(f" total units (stored): {len(stored_norm)}") - print(f" total units (expected): {len(expected)}") - print(f" units with differences: {len(diffs)}") - if diffs: - print(f"\n {'unit':>8} " f"{'stored':<30} {'expected':<30}") - print(f" {'----':>8} " f"{'------':<30} {'--------':<30}") - for uid in sorted(diffs): - d = diffs[uid] - s_str = str(d["stored"]) if d["stored"] is not None else "---" - e_str = ( - str(d["expected"]) if d["expected"] is not None else "---" - ) - print(f" {uid:>8} {s_str:<30} {e_str:<30}") - - # --- Downstream impact --- - has_downstream = self._has_downstream(auto_curation_key) - print("\n=== Downstream ===") - print(f" CuratedSpikeSortingSelection: {has_downstream}") - if has_downstream: - n_units = len(CuratedSpikeSorting.Unit & auto_curation_key) - print(f" CuratedSpikeSorting.Unit rows: {n_units}") - - if diffs: - reject_changes = [] - for uid, d in diffs.items(): - was_reject = ( - d["stored"] is not None and "reject" in d["stored"] - ) - should_reject = ( - d["expected"] is not None and "reject" in d["expected"] - ) - if was_reject != should_reject: - reject_changes.append( - { - "unit": uid, - "was_rejected": was_reject, - "should_reject": should_reject, - } - ) - if reject_changes: - print( - f"\n Units with changed " - f"accept/reject status: " - f"{len(reject_changes)}" - ) - for rc in reject_changes: - status = ( - "SHOULD BE REJECTED " "(was accepted)" - if rc["should_reject"] - else "SHOULD BE ACCEPTED " "(was rejected)" - ) - print(f" unit {rc['unit']}: {status}") - else: - print("\n No units change accept/reject " "status")