From d12105c6220dbc7fdc60ef29a058d766879ce76e Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 18 Feb 2026 16:10:50 +0100 Subject: [PATCH 01/30] More tri-part makes. Allow db no-connect import fail in validate.py --- scripts/validate.py | 5 ++ src/spyglass/spikesorting/v1/recording.py | 78 +++++++++++++++++----- src/spyglass/spikesorting/v1/sorting.py | 79 +++++++++++++++-------- tests/common/test_video_import_fail.py | 2 +- tests/conftest.py | 5 +- 5 files changed, 124 insertions(+), 45 deletions(-) diff --git a/scripts/validate.py b/scripts/validate.py index 42deda8b8..a932afb33 100755 --- a/scripts/validate.py +++ b/scripts/validate.py @@ -198,6 +198,11 @@ def check_spyglass_import() -> None: print(f"✓ Spyglass version: {version}") except ImportError as e: raise RuntimeError(f"Cannot import spyglass: {e}") + except Exception as e: + if type(e).__name__ == "OperationalError" and "Can't connect" in str(e): + print("⚠ Spyglass import warning: Database connection issues") + return # Elsewhere, this is not treated as critical, so only warn + raise def check_spyglass_config() -> None: diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 190c8e169..80efef467 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -16,7 +16,7 @@ from spyglass.common import Session # noqa: F401 from spyglass.common.common_device import Probe from spyglass.common.common_ephys import Electrode, Raw # noqa: F401 -from spyglass.common.common_interval import IntervalList +from spyglass.common.common_interval import IntervalLike, IntervalList from spyglass.common.common_lab import LabTeam from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile from spyglass.settings import analysis_dir, test_mode @@ -194,9 +194,15 @@ class SpikeSortingRecording(SpyglassMixin, dj.Computed): hash=null: varchar(32) # Hash of the NWB file """ - _use_transaction, _allow_insert = False, True + def 
_insert_sort_interval(self, key): + """Insert sort interval valid times into IntervalList. - def make(self, key): + Separated from make() so it can be called before _hash_upstream in + the no-transaction populate path, preventing a false-positive hash + mismatch that would silently delete the just-populated row. + """ + + def make_fetch(self, key): """Populate SpikeSortingRecording. 1. Get valid times for sort interval from IntervalList @@ -210,26 +216,53 @@ def make(self, key): nwb_file_name = (SpikeSortingRecordingSelection & key).fetch1( "nwb_file_name" ) - - key.update(self._make_file(key)) - - # INSERT: - # - valid times into IntervalList - # - analysis NWB file holding processed recording into AnalysisNwbfile - # - entry into SpikeSortingRecording sort_interval_valid_times = self._get_sort_interval_valid_times(key) sort_interval_valid_times.set_key( nwb_file_name=nwb_file_name, interval_list_name=key["recording_id"], pipeline="spikesorting_recording_v1", ) + return [nwb_file_name, sort_interval_valid_times] + + def make_compute( + self, key, nwb_file_name, sort_interval_valid_times + ) -> dict: + """Compute/save SpikeSortingRecording + + Returns + ------- + dict + Result of _make_file, containing: + analysis_file_name: str + object_id: UUID + electrodes_id: str + hash: str + """ + file_dict = self._make_file(key, parent_file_name=nwb_file_name) + + return [nwb_file_name, file_dict, sort_interval_valid_times] + + def make_insert( + self, + key: dict, + nwb_file_name: str, + file_dict: dict, + sort_interval_valid_times: IntervalLike, + ) -> dict: + insert_key = dict(key, **file_dict) + + # INSERT: + # - valid times into IntervalList + # - analysis NWB file holding processed recording into AnalysisNwbfile + # - entry into SpikeSortingRecording + IntervalList.insert1( sort_interval_valid_times.as_dict, skip_duplicates=True ) - AnalysisNwbfile().add(nwb_file_name, key["analysis_file_name"]) + AnalysisNwbfile().add(nwb_file_name, insert_key["analysis_file_name"]) - 
self.insert1(key) - self._record_environment(key) + self.insert1(insert_key) + self._record_environment(insert_key) def _record_environment(self, key): """Record environment details for this recording.""" @@ -244,7 +277,8 @@ def _make_file( recompute_file_name: str = None, save_to: Union[str, Path] = None, rounding: int = 4, - ): + parent_file_name: str = None, + ) -> dict: """Preprocess recording and write to NWB file. All `_make_file` methods should exit early if the file already exists. @@ -264,6 +298,15 @@ def _make_file( save_to : Union[str,Path], Optional Default None, save to analysis directory. If provided, save to specified path. Used for recomputation prior to deletion. + rounding : int, Optional + Decimal places to round to when hashing. Default 4, which is typical + for microvolt precision. Only used for hash computation, does not + affect data written to NWB file. + parent_file_name : str, Optional + If specified, use this NWB file as the source of the recording to be + preprocessed and written to the new NWB file. If none, fetch source + NWB file from SpikeSortingRecordingSelection. Used avoiding fetch + during tri-part make. 
""" if not key and not recompute_file_name: raise ValueError( @@ -295,11 +338,14 @@ def _make_file( else: recompute_object_id, recompute_electrodes_id = None, None - parent = SpikeSortingRecordingSelection & key + if not parent_file_name: + parent = SpikeSortingRecordingSelection & key + parent.fetch1("nwb_file_name") + recording_nwb_file_name, recording_object_id, electrodes_id = ( _write_recording_to_nwb( **cls()._get_preprocessed_recording(key), - nwb_file_name=parent.fetch1("nwb_file_name"), + nwb_file_name=parent_file_name, recompute_file_name=recompute_file_name, recompute_object_id=recompute_object_id, recompute_electrodes_id=recompute_electrodes_id, diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py index 5738a26d9..6cfa2e9da 100644 --- a/src/spyglass/spikesorting/v1/sorting.py +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -14,7 +14,7 @@ import spikeinterface.sorters as sis from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spyglass.common.common_interval import IntervalList +from spyglass.common.common_interval import IntervalLike, IntervalList from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.settings import temp_dir from spyglass.spikesorting.v1.recording import ( # noqa: F401 @@ -187,14 +187,13 @@ class SpikeSorting(SpyglassMixin, dj.Computed): time_of_sort: int # in Unix time, to the nearest second """ - _use_transaction, _allow_insert = False, True _parallel_make = True # True if n_workers > 1 - def make(self, key: dict): + def make_fetch(self, key: dict) -> list: """Runs spike sorting on the data and parameters specified by the SpikeSortingSelection table and inserts a new entry to SpikeSorting table. 
""" - # FETCH (Spyglass logic - always tested): + # FETCH # - information about the recording # - artifact free intervals # - spike sorter and sorter params @@ -202,25 +201,42 @@ def make(self, key: dict): recording_key = ( SpikeSortingRecording * SpikeSortingSelection & key ).fetch1() + + nwb_file_name = recording_key["nwb_file_name"] + artifact_removed_intervals = ( IntervalList & { - "nwb_file_name": (SpikeSortingSelection & key).fetch1( - "nwb_file_name" - ), - "interval_list_name": (SpikeSortingSelection & key).fetch1( - "interval_list_name" - ), + "nwb_file_name": nwb_file_name, + "interval_list_name": recording_key["interval_list_name"], } ).fetch1("valid_times") + sorter, sorter_params = ( SpikeSorterParameters * SpikeSortingSelection & key ).fetch1("sorter", "sorter_params") + recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( recording_key["analysis_file_name"] ) - # External dependency - MOCKABLE in tests + return [ + nwb_file_name, + artifact_removed_intervals, + sorter, + sorter_params, + recording_analysis_nwb_file_abs_path, + ] + + def make_compute( + self, + key: dict, + nwb_file_name: str, + artifact_removed_intervals: IntervalLike, + sorter: str, + sorter_params: dict, + recording_analysis_nwb_file_abs_path: str, + ): sorting, timestamps = self._run_spike_sorter( recording_analysis_nwb_file_abs_path=recording_analysis_nwb_file_abs_path, artifact_removed_intervals=artifact_removed_intervals, @@ -228,25 +244,34 @@ def make(self, key: dict): sorter_params=sorter_params, ) - # External I/O - MOCKABLE in tests - key["time_of_sort"] = int(time.time()) - key["analysis_file_name"], key["object_id"] = ( - self._save_sorting_results( - sorting=sorting, - timestamps=timestamps, - artifact_removed_intervals=artifact_removed_intervals, - nwb_file_name=(SpikeSortingSelection & key).fetch1( - "nwb_file_name" - ), - ) + time_of_sort = int(time.time()) + analysis_file_name, object_id = self._save_sorting_results( + sorting=sorting, + 
timestamps=timestamps, + artifact_removed_intervals=artifact_removed_intervals, + nwb_file_name=(SpikeSortingSelection & key).fetch1("nwb_file_name"), ) - # Database operations (Spyglass logic - always tested) - AnalysisNwbfile().add( - (SpikeSortingSelection & key).fetch1("nwb_file_name"), - key["analysis_file_name"], + return [nwb_file_name, time_of_sort, analysis_file_name, object_id] + + def make_insert( + self, + key: dict, + nwb_file_name: str, + time_of_sort: int, + analysis_file_name: str, + object_id: str, + ): + AnalysisNwbfile().add(nwb_file_name, analysis_file_name) + self.insert1( + dict( + key, + time_of_sort=time_of_sort, + analysis_file_name=analysis_file_name, + object_id=object_id, + ), + skip_duplicates=True, ) - self.insert1(key, skip_duplicates=True) def _run_spike_sorter( self, diff --git a/tests/common/test_video_import_fail.py b/tests/common/test_video_import_fail.py index fa0326a74..dc91cad6a 100644 --- a/tests/common/test_video_import_fail.py +++ b/tests/common/test_video_import_fail.py @@ -25,11 +25,11 @@ def nwb_with_video_no_task(raw_dir, common): experimenter=["Test Experimenter"], ) nwbfile.subject = mock_Subject() + camera_device = CameraDevice( name="camera_device 0", meters_per_pixel=0.001, manufacturer="Test Camera Co", - model="TestCam 3000", lens="50mm", camera_name="test_camera", ) diff --git a/tests/conftest.py b/tests/conftest.py index 1034c094d..c95aae750 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1423,7 +1423,10 @@ def pop_rec(spike_v1, mini_dict, team_name): ssr_pk = ( (spike_v1.SpikeSortingRecordingSelection & key).proj().fetch1("KEY") ) - spike_v1.SpikeSortingRecording.populate(ssr_pk) + spike_v1.SpikeSortingRecording.populate() + + if not spike_v1.SpikeSortingRecording() & ssr_pk: + raise ValueError("SpikeSortingRecording failed to populate.") yield ssr_pk From 3841e9add17c7581c0b6d4c2ca6f11f27900e8d8 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 19 Feb 2026 13:13:29 +0100 Subject: [PATCH 02/30] Fix 
local docker test_install.py runs --- tests/conftest.py | 26 ++++++++++++++++++-------- tests/setup/test_install.py | 11 ++++++++--- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c95aae750..a4c7bbfc8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -219,27 +219,37 @@ def worker_id(request): @pytest.fixture(scope="session") -def dj_conn(request, server, worker_id, verbose, teardown): - """Fixture for datajoint connection with pytest-xdist support. +def dj_config(verbose): + """Fixture for branch-specific config name""" + SERVER.wait() # ensure MySQL is ready before any test uses these credentials - For parallel execution, each worker gets its own database schema prefix - to avoid race conditions and ensure test isolation. - """ # Worker-specific config file to avoid conflicts config_file = "dj_local_conf.json" - if branch_name := server.branch_name: + if branch_name := SERVER.branch_name: config_file = f"dj_local_conf_{branch_name}.json" if Path(config_file).exists(): os.remove(config_file) # Set worker-specific schema prefix for database isolation - dj.config.update(server.credentials) + dj.config.update(SERVER.credentials) dj.config["loglevel"] = "INFO" if verbose else "ERROR" dj.config["database.prefix"] = "pytests" dj.config["custom"]["spyglass_dirs"] = {"base": str(BASE_DIR)} dj.config.save(config_file) + return config_file + + +@pytest.fixture(scope="session") +def dj_conn(dj_config): + """Fixture for datajoint connection with pytest-xdist support. + + For parallel execution, each worker gets its own database schema prefix + to avoid race conditions and ensure test isolation. 
+ """ + dj.config.load(dj_config) + try: dj.conn().ping() except Exception as e: # If can't connect, exit all tests @@ -344,7 +354,7 @@ def mini_insert( dj_logger.info("Inserting test data.") - if not server.connected: + if not SERVER.connected: raise ConnectionError("No server connection.") if len(Nwbfile & mini_dict) != 0: diff --git a/tests/setup/test_install.py b/tests/setup/test_install.py index 2f14583a7..ce42f43c6 100644 --- a/tests/setup/test_install.py +++ b/tests/setup/test_install.py @@ -22,6 +22,7 @@ from pathlib import Path from unittest.mock import Mock, patch +import datajoint as dj import pytest # Add scripts to path @@ -1120,6 +1121,7 @@ def test_remote_without_credentials_fails(self, tmp_path): str(tmp_path), ], capture_output=True, + stdin=subprocess.DEVNULL, text=True, env={**os.environ, "HOME": str(tmp_path)}, ) @@ -1238,6 +1240,7 @@ def test_config_only_remote_missing_host(self, tmp_path): str(tmp_path), ], capture_output=True, + stdin=subprocess.DEVNULL, text=True, env={**os.environ, "HOME": str(tmp_path)}, timeout=5, # Should fail quickly, not hang @@ -2046,8 +2049,11 @@ class TestConfigCompatibility: can consume without missing required keys. 
""" - def test_installer_has_all_settings_keys(self, tmp_path): + def test_installer_has_all_settings_keys(self, dj_config, tmp_path): """Installer config contains all keys expected by settings.py.""" + + dj.config.load(dj_config) # Database connection before import + from spyglass.settings import SpyglassConfig base_dir = tmp_path / "spyglass_data" @@ -2122,9 +2128,8 @@ def test_installer_has_all_settings_keys(self, tmp_path): } # Get settings.py config structure - sg_config = SpyglassConfig() + sg_config = SpyglassConfig(base_dir=str(base_dir)) settings_config = sg_config._generate_dj_config( - base_dir=str(base_dir), database_user="testuser", database_password="testpass", database_host="localhost", From 9fcd4dea1b6d2323a7533c993ff10b51299c18f2 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 19 Feb 2026 13:33:43 +0100 Subject: [PATCH 03/30] Fix missing variable name --- src/spyglass/spikesorting/v1/recording.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 80efef467..6ee2376bc 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -340,7 +340,7 @@ def _make_file( if not parent_file_name: parent = SpikeSortingRecordingSelection & key - parent.fetch1("nwb_file_name") + parent_file_name = parent.fetch1("nwb_file_name") recording_nwb_file_name, recording_object_id, electrodes_id = ( _write_recording_to_nwb( From 5c89f6d4f926f0a5a4cd60114587b25519a14be3 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 19 Feb 2026 14:40:10 +0100 Subject: [PATCH 04/30] Fetch #1529 fix from @samuelbray --- src/spyglass/common/common_user.py | 7 ++++- .../spikesorting/analysis/v1/group.py | 2 +- src/spyglass/utils/dj_graph.py | 26 +++++++++++++------ src/spyglass/utils/mixins/base.py | 8 +++--- tests/utils/test_graph.py | 17 ++++++++++-- 5 files changed, 43 insertions(+), 17 deletions(-) diff --git 
a/src/spyglass/common/common_user.py b/src/spyglass/common/common_user.py index 9f3cc203b..a290e7bcf 100644 --- a/src/spyglass/common/common_user.py +++ b/src/spyglass/common/common_user.py @@ -157,10 +157,15 @@ def _comment_install(self): ) def _warn_if_custom_or_conflict(self): + # NOTE: allowed editable items... + # spyglass - this package is often installed as editable for fetching + # updates without reinstalling. + # jsonschema - this package is often installed as editable in a tmp dir + pip_custom_no_spy = { k: "".join(v) for k, v in self._pip_custom.items() - if "spyglass" not in k + if "spyglass" not in k and "jsonschema" not in k } if pip_custom_no_spy: logger.warning( diff --git a/src/spyglass/spikesorting/analysis/v1/group.py b/src/spyglass/spikesorting/analysis/v1/group.py index aa8512cd2..cb60e715a 100644 --- a/src/spyglass/spikesorting/analysis/v1/group.py +++ b/src/spyglass/spikesorting/analysis/v1/group.py @@ -90,7 +90,7 @@ def create_group( if test_mode: return raise ValueError( - f"Group {nwb_file_name}: {group_name} already exists", + f"Group {nwb_file_name}: {group_name} already exists ", "please delete the group before creating a new one", ) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index dc8f7c688..e9eb80ba8 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -1232,6 +1232,19 @@ def file_dict(self) -> Dict[str, List[str]]: self.cascade(warn=False) return {t: self._get_node(t).get("files", []) for t in self.restr_ft} + def _stored_files(self, as_dict=False) -> Dict[str, str] | Set[str]: + """Return dictionary of table names and files.""" + # Added for debugging + self.cascade(warn=False) + + files = { + table: file + for table in self.included_tables + for file in self._get_node(table).get("files", []) + } + + return files if as_dict else set(files.values()) + @property def file_paths(self) -> List[str]: """Return list of unique analysis files from all visited nodes. 
@@ -1239,15 +1252,12 @@ def file_paths(self) -> List[str]: This covers intermediate analysis files that may not have been fetched directly by the user. """ - self.cascade() - - files = { - file - for table in self.included_tables - for file in self._get_node(table).get("files", []) - } + self.cascade(warn=False) - return [self.analysis_file_tbl.get_abs_path(file) for file in files] + return [ + self.analysis_file_tbl.get_abs_path(file) + for file in self._stored_files() + ] class TableChain(RestrGraph): diff --git a/src/spyglass/utils/mixins/base.py b/src/spyglass/utils/mixins/base.py index b9c7ed2d2..9affa8d5a 100644 --- a/src/spyglass/utils/mixins/base.py +++ b/src/spyglass/utils/mixins/base.py @@ -36,7 +36,7 @@ def _graph_deps(self) -> list: return [TableChain, RestrGraph] - @property + @cached_property def _test_mode(self) -> bool: """Return True if in test mode. @@ -49,11 +49,9 @@ def _test_mode(self) -> bool: - BaseMixin._spyglass_version - HelpersMixin """ - import datajoint as dj + from spyglass.settings import config as sg_config - # Check dj.config directly instead of importing module-level variable - # which gets stale if load_config() is called after initial import - return dj.config.get("custom", {}).get("test_mode", False) + return sg_config.get("test_mode", False) @cached_property def _spyglass_version(self): diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index dab599f7d..9062a201c 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -145,8 +145,21 @@ def test_rg_restr_ft(restr_graph): def test_rg_file_paths(restr_graph): - """Test collection of upstream file paths.""" - assert len(restr_graph.file_paths) == 3, "Unexpected number of file paths." + """Test collection of upstream file paths. + + NOTE: This test previously tested how many files were collected, which may + differ if only subset of tests are run. Instead, we now check which tables + store collected files. See #1440, #1534 for context. 
+ """ + expected_tbls = [ + "`position_linearization_v1`.`__linearized_position_v1`", + "`position_v1_trodes_position`.`__trodes_pos_v1`", + ] + stored_files = restr_graph._stored_files(as_dict=True) + for tbl in expected_tbls: + assert tbl in stored_files, f"Expected table {tbl} did not show file." + + assert len(restr_graph.file_paths) > 1, "Unexpected file paths collected." def test_rg_invalid_table(restr_graph): From 651180f144b739452dcb7a4a6992fa2a4a5e0e5e Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 19 Feb 2026 16:00:27 +0100 Subject: [PATCH 05/30] Revise mock cdf to avoid teardown error --- src/spyglass/utils/dj_graph.py | 9 ++++----- tests/decoding/conftest.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index e9eb80ba8..0b3345420 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -1237,13 +1237,12 @@ def _stored_files(self, as_dict=False) -> Dict[str, str] | Set[str]: # Added for debugging self.cascade(warn=False) - files = { - table: file + pairs = [ + (table, file) for table in self.included_tables for file in self._get_node(table).get("files", []) - } - - return files if as_dict else set(files.values()) + ] + return dict(pairs) if as_dict else {file for _, file in pairs} @property def file_paths(self) -> List[str]: diff --git a/tests/decoding/conftest.py b/tests/decoding/conftest.py index a88af5b51..2cd097a2b 100644 --- a/tests/decoding/conftest.py +++ b/tests/decoding/conftest.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest.mock import patch import numpy as np @@ -30,10 +31,18 @@ def mock_to_netcdf( # Return bytes if no path given (original behavior for some use cases) return None + # Ensure parent directory exists + Path(path).parent.mkdir(parents=True, exist_ok=True) + # Keep the .nc extension to match expectations, but write pickle format # This avoids netCDF4/HDF5 errors while maintaining file path 
compatibility - with open(path, "wb") as f: - pickle.dump(self, f) + try: + with open(path, "wb") as f: + pickle.dump(self, f) + except (FileNotFoundError, PermissionError, OSError): + # Copilot suggested that this is where a file might throw error + # during teatdown, attempted automatic cleanup. + pass return None From 9543e5b776803bd5122944b4f3213cbd19f87871 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 19 Feb 2026 16:16:52 +0100 Subject: [PATCH 06/30] Prevent decoding teardown error --- src/spyglass/decoding/decoding_merge.py | 33 +++++++++---------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/src/spyglass/decoding/decoding_merge.py b/src/spyglass/decoding/decoding_merge.py index e25e94524..12fa215cd 100644 --- a/src/spyglass/decoding/decoding_merge.py +++ b/src/spyglass/decoding/decoding_merge.py @@ -1,4 +1,3 @@ -from itertools import chain from pathlib import Path import datajoint as dj @@ -38,22 +37,23 @@ class SortedSpikesDecodingV1(SpyglassMixin, dj.Part): # noqa: F811 -> SortedSpikesDecodingV1 """ + def _fetch_registered_paths(self, attr): + """Fetch a filepath attribute from all part parents, skipping missing.""" + paths = [] + for tbl in self.merge_get_parent(multi_source=True): + try: + paths.extend(tbl.fetch(attr).tolist()) + except FileNotFoundError: + pass + return paths + def cleanup(self, dry_run=False): """Remove any decoding outputs that are not in the merge table""" if dry_run: logger.info("Dry run, not removing any files") else: logger.info("Cleaning up decoding outputs") - table_results_paths = list( - chain( - *[ - part_parent_table.fetch("results_path").tolist() - for part_parent_table in self.merge_get_parent( - multi_source=True - ) - ] - ) - ) + table_results_paths = self._fetch_registered_paths("results_path") for path in Path(config["SPYGLASS_ANALYSIS_DIR"]).glob("**/*.nc"): if str(path) not in table_results_paths: logger.info(f"Removing {path}") @@ -63,16 +63,7 @@ def cleanup(self, dry_run=False): except 
PermissionError: logger.warning(f"Unable to remove {path}, skipping") - table_model_paths = list( - chain( - *[ - part_parent_table.fetch("classifier_path").tolist() - for part_parent_table in self.merge_get_parent( - multi_source=True - ) - ] - ) - ) + table_model_paths = self._fetch_registered_paths("classifier_path") for path in Path(config["SPYGLASS_ANALYSIS_DIR"]).glob("**/*.pkl"): if str(path) not in table_model_paths: logger.info(f"Removing {path}") From 2c38b43860e73c85bd26d38da22abeeaa9aa0604 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 19 Feb 2026 16:37:09 +0100 Subject: [PATCH 07/30] Denoising tests --- pyproject.toml | 2 ++ src/spyglass/common/common_user.py | 7 ++++++- src/spyglass/decoding/v0/clusterless.py | 2 +- src/spyglass/decoding/v0/core.py | 2 +- .../decoding/v0/dj_decoder_conversion.py | 2 +- src/spyglass/decoding/v0/sorted_spikes.py | 2 +- src/spyglass/decoding/v0/utils.py | 2 +- src/spyglass/decoding/v1/clusterless.py | 2 +- .../position/v1/position_dlc_cohort.py | 2 +- .../spikesorting/v1/metric_curation.py | 18 ++++++++++-------- src/spyglass/spikesorting/v1/recompute.py | 5 ++++- src/spyglass/utils/dj_graph.py | 2 +- src/spyglass/utils/mixins/analysis.py | 6 +++--- tests/position/v1/conftest.py | 2 +- 14 files changed, 34 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 632069a1d..e566c60a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,6 +172,8 @@ filterwarnings = [ "ignore::ResourceWarning:.*", "ignore::DeprecationWarning:.*", "ignore::UserWarning:.*", + "ignore::FutureWarning:.*", + "ignore::PerformanceWarning:.*", "ignore::MissingRequiredBuildWarning:.*", ] markers = [ diff --git a/src/spyglass/common/common_user.py b/src/spyglass/common/common_user.py index a290e7bcf..68cafee8d 100644 --- a/src/spyglass/common/common_user.py +++ b/src/spyglass/common/common_user.py @@ -47,6 +47,7 @@ class UserEnvironment(SpyglassMixin, dj.Manual): _pip_custom = dict() # Custom pip installs from the 
environment _freeze_comments = dict() # Comments from pip freeze _conda_conflicts = dict() # Conda and pip conflicts in the environment + _env_warned = False # Suppress repeated warnings per process def _get_conda_export(self) -> Tuple[dict, dict]: """Fetch the current Conda environment export. @@ -162,6 +163,10 @@ def _warn_if_custom_or_conflict(self): # updates without reinstalling. # jsonschema - this package is often installed as editable in a tmp dir + if UserEnvironment._env_warned: + return + UserEnvironment._env_warned = True + pip_custom_no_spy = { k: "".join(v) for k, v in self._pip_custom.items() @@ -247,7 +252,7 @@ def _parse_pip_line(self, line: str) -> bool: return True # successfully parsed basic dependency # --- if conflicting versions, log conflict, overwrite with pip --- - logger.info(f"Conda/pip conflict: {line}") + logger.debug(f"Conda/pip conflict: {line}") self._conda_conflicts[package] = [pip_version, conda_version] self._conda_pip_dict[package] = line return True # successfully parsed basic dependency conflict diff --git a/src/spyglass/decoding/v0/clusterless.py b/src/spyglass/decoding/v0/clusterless.py index f3fe89a3b..f6e03aba0 100644 --- a/src/spyglass/decoding/v0/clusterless.py +++ b/src/spyglass/decoding/v0/clusterless.py @@ -45,7 +45,7 @@ DiagonalDiscrete, UniformInitialConditions, ) = [None] * 5 - logger.warning(e) + logger.debug(e) from tqdm.auto import tqdm diff --git a/src/spyglass/decoding/v0/core.py b/src/spyglass/decoding/v0/core.py index 11baf063a..4bd1f0a89 100644 --- a/src/spyglass/decoding/v0/core.py +++ b/src/spyglass/decoding/v0/core.py @@ -14,7 +14,7 @@ ) except (ImportError, ModuleNotFoundError) as e: RandomWalk, Uniform, Environment, ObservationModel = None, None, None, None - logger.warning(e) + logger.debug(e) from spyglass.common.common_behav import PositionIntervalMap, RawPosition from spyglass.common.common_interval import IntervalList diff --git a/src/spyglass/decoding/v0/dj_decoder_conversion.py 
b/src/spyglass/decoding/v0/dj_decoder_conversion.py index af03f541e..ad13bd687 100644 --- a/src/spyglass/decoding/v0/dj_decoder_conversion.py +++ b/src/spyglass/decoding/v0/dj_decoder_conversion.py @@ -41,7 +41,7 @@ UniformOneEnvironmentInitialConditions, ObservationModel, ) = [None] * 13 - logger.warning(e) + logger.debug(e) from track_linearization import make_track_graph diff --git a/src/spyglass/decoding/v0/sorted_spikes.py b/src/spyglass/decoding/v0/sorted_spikes.py index 914b8b936..ceda58cd9 100644 --- a/src/spyglass/decoding/v0/sorted_spikes.py +++ b/src/spyglass/decoding/v0/sorted_spikes.py @@ -34,7 +34,7 @@ DiagonalDiscrete, UniformInitialConditions, ) = [None] * 5 - logger.warning(e) + logger.debug(e) from spyglass.common.common_behav import ( convert_epoch_interval_name_to_position_interval_name, diff --git a/src/spyglass/decoding/v0/utils.py b/src/spyglass/decoding/v0/utils.py index 7f1ae1c15..d99ea4681 100644 --- a/src/spyglass/decoding/v0/utils.py +++ b/src/spyglass/decoding/v0/utils.py @@ -25,7 +25,7 @@ DiagonalDiscrete, UniformInitialConditions, ) = [None] * 6 - logger.warning(e) + logger.debug(e) def get_time_bins_from_interval(interval_times: np.array, sampling_rate: int): diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 905fe54f3..9e0ac68aa 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -65,7 +65,7 @@ def create_group( "waveform_features_group_name": group_name, } if self & group_key: - logger.error( # No error on duplicate helps with pytests + logger.warning( # No error on duplicate helps with pytests f"Group {nwb_file_name}: {group_name} already exists" + "please delete the group before creating a new one", ) diff --git a/src/spyglass/position/v1/position_dlc_cohort.py b/src/spyglass/position/v1/position_dlc_cohort.py index d246e087b..1f9b90fce 100644 --- a/src/spyglass/position/v1/position_dlc_cohort.py +++ 
b/src/spyglass/position/v1/position_dlc_cohort.py @@ -114,7 +114,7 @@ def _logged_make(self, key): table_entries = [] bp_params_dict = cohort_selection.pop("bodyparts_params_dict") if len(bp_params_dict) == 0: - logger.warn("No bodyparts specified in bodyparts_params_dict") + logger.warning("No bodyparts specified in bodyparts_params_dict") self.insert1(key) return temp_key = cohort_selection.copy() diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index fe18afc0d..f170f76ef 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -373,14 +373,16 @@ def get_waveforms( # Extract non-sparse waveforms by default waveform_params.setdefault("sparse", False) - waveforms = si.extract_waveforms( - recording=recording, - sorting=sorting, - folder=waveforms_dir, - overwrite=overwrite, - load_if_exists=not overwrite, - **waveform_params, - ) + if overwrite: + waveforms = si.extract_waveforms( + recording=recording, + sorting=sorting, + folder=waveforms_dir, + overwrite=True, + **waveform_params, + ) + else: + waveforms = si.load_waveforms(waveforms_dir) self._waves_cache[key_hash] = waveforms diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 7ce782bf9..30b486f00 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -339,7 +339,7 @@ def attempt_all( if not bool(RecordingRecompute & key) ] if not inserts: - logger.info(f"No rows to insert from:\n\t{source}") + logger.debug(f"No rows to insert from:\n\t{source}") return logger.info(f"Inserting recompute attempts for {len(inserts)} files.") @@ -903,6 +903,9 @@ def delete_files( file_names = query.fetch("analysis_file_name") prefix = "DRY RUN: " if dry_run else "" + if not len(file_names): + logger.debug(f"{prefix}Delete 0 files. 
Nothing to do.") + return msg = f"{prefix}Delete {len(file_names)} files?\n\t" + "\n\t".join( file_names[:10] ) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 0b3345420..dbcf82707 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -1000,7 +1000,7 @@ def cascade(self, show_progress=None, direction="up", warn=True) -> None: cascade=True, ) cascaded_leaves.append(leaf_graph) - logger.info("adding cascaded leaves") + logger.debug("adding cascaded leaves") self = self + cascaded_leaves self.cascaded = True # Mark here so next step can use `restr_ft` diff --git a/src/spyglass/utils/mixins/analysis.py b/src/spyglass/utils/mixins/analysis.py index 55f6f5963..86d9b60bf 100644 --- a/src/spyglass/utils/mixins/analysis.py +++ b/src/spyglass/utils/mixins/analysis.py @@ -832,7 +832,7 @@ def add_units( # to ensure that things go in the right order metric_values = metric_values[np.argsort(unit_ids)] - self._logger.info(f"Adding metric {metric} : {metric_values}") + self._logger.debug(f"Adding metric {metric} : {metric_values}") nwbf.add_unit_column( name=metric, description=f"{metric} metric", @@ -917,7 +917,7 @@ def add_units_waveforms( # If metrics were specified, add one column per metric if metrics is not None: for metric_name, metric_dict in metrics.items(): - self._logger.info( + self._logger.debug( f"Adding metric {metric_name} : {metric_dict}" ) metric_data = metric_dict.values().to_list() @@ -963,7 +963,7 @@ def add_units_metrics(self, analysis_file_name: str, metrics: dict): nwbf.add_unit(id=id) for metric_name, metric_dict in metrics.items(): - self._logger.info( + self._logger.debug( f"Adding metric {metric_name} : {metric_dict}" ) metric_data = list(metric_dict.values()) diff --git a/tests/position/v1/conftest.py b/tests/position/v1/conftest.py index a5db185b2..497baed9a 100644 --- a/tests/position/v1/conftest.py +++ b/tests/position/v1/conftest.py @@ -77,7 +77,7 @@ def increment_count(): def 
process_value(x): return increment_count() if x == 1 else x - return df.applymap(process_value) + return df.map(process_value) @pytest.fixture(scope="session") From a8ec95dd9a200a6f48dd08012f253e2eb716eb25 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 19 Feb 2026 16:57:37 +0100 Subject: [PATCH 08/30] Denoising tests 2 --- src/spyglass/spikesorting/v1/metric_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index f170f76ef..740e2f31b 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -373,7 +373,7 @@ def get_waveforms( # Extract non-sparse waveforms by default waveform_params.setdefault("sparse", False) - if overwrite: + if overwrite or not Path(waveforms_dir).exists(): waveforms = si.extract_waveforms( recording=recording, sorting=sorting, From f0fe31133da38ff58d30fcfa200d88f242583937 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 12:24:26 +0100 Subject: [PATCH 09/30] Denoising tests 3 --- pyproject.toml | 2 + scripts/install.py | 749 +++++++++++++--------- src/spyglass/position/utils_dlc.py | 52 +- src/spyglass/spikesorting/v1/recompute.py | 4 +- src/spyglass/utils/mixins/analysis.py | 4 +- src/spyglass/utils/mixins/base.py | 12 + src/spyglass/utils/mixins/ingestion.py | 2 +- tests/conftest.py | 30 + tests/position/v1/test_pos_merge.py | 4 +- tests/setup/conftest.py | 16 + 10 files changed, 554 insertions(+), 321 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e566c60a1..5a74e516c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,6 +175,8 @@ filterwarnings = [ "ignore::FutureWarning:.*", "ignore::PerformanceWarning:.*", "ignore::MissingRequiredBuildWarning:.*", + # DLC training ends by cancelling a TF enqueue thread; expected, not a bug + "ignore::pytest.PytestUnhandledThreadExceptionWarning", ] markers = [ # Speed-based markers (based 
on total time: setup + call + teardown) diff --git a/scripts/install.py b/scripts/install.py index eb7b6d8ce..a70ceede0 100755 --- a/scripts/install.py +++ b/scripts/install.py @@ -52,6 +52,7 @@ else {"success": "[OK]", "error": "[X]", "warning": "[!]", "step": "->"} ) + # System constants BYTES_PER_GB = 1024**3 LOCALHOST_ADDRESSES = frozenset(["localhost", "127.0.0.1", "::1"]) @@ -125,26 +126,37 @@ class Console: For standalone status messages, use success/warning/error directly. """ + # Suppress non-essential output when running under pytest + _quiet: bool = os.getenv("SPYGLASS_INSTALL_QUIET", "0") not in ("", "0") + @staticmethod def step(msg: str) -> None: """Print step message without newline, waiting for done/fail.""" + if Console._quiet: + return print(f"{msg}... ", end="", flush=True) @staticmethod def done(msg: str = "") -> None: """Complete a step with success checkmark.""" + if Console._quiet: + return suffix = f" {msg}" if msg else "" print(f"{COLORS['green']}{SYMBOLS['success']}{COLORS['reset']}{suffix}") @staticmethod def fail(msg: str = "") -> None: """Complete a step with failure mark.""" + if Console._quiet: + return suffix = f" {msg}" if msg else "" print(f"{COLORS['red']}{SYMBOLS['error']}{COLORS['reset']}{suffix}") @staticmethod def success(msg: str, indent: bool = False) -> None: """Print standalone success message (with newline).""" + if Console._quiet: + return prefix = " " if indent else "" print( f"{prefix}{COLORS['green']}{SYMBOLS['success']}{COLORS['reset']} {msg}" @@ -153,6 +165,8 @@ def success(msg: str, indent: bool = False) -> None: @staticmethod def warning(msg: str, indent: bool = False) -> None: """Print warning message.""" + if Console._quiet: + return prefix = " " if indent else "" print( f"{prefix}{COLORS['yellow']}{SYMBOLS['warning']}{COLORS['reset']} {msg}" @@ -161,6 +175,8 @@ def warning(msg: str, indent: bool = False) -> None: @staticmethod def error(msg: str, indent: bool = False) -> None: """Print error message.""" + if 
Console._quiet: + return prefix = " " if indent else "" print( f"{prefix}{COLORS['red']}{SYMBOLS['error']}{COLORS['reset']} {msg}" @@ -169,9 +185,38 @@ def error(msg: str, indent: bool = False) -> None: @staticmethod def info(msg: str, indent: bool = False) -> None: """Print info message (blue color).""" + if Console._quiet: + return prefix = " " if indent else "" print(f"{prefix}{COLORS['blue']}{msg}{COLORS['reset']}") + @staticmethod + def print(msg: str, color: Optional[str] = None, indent: int = 0) -> None: + """Print message with optional color and indentation. + + Parameters + ---------- + msg : str + Message to print + color : str, optional + Color name: "blue", "green", "yellow", "red" (default: no color) + indent : int + Indentation level (number of 2-space indents, default: 0) + """ + if Console._quiet: + return + prefix = " " * indent + c = COLORS.get(color, COLORS["reset"]) + r = COLORS["reset"] + print(f"{prefix}{c}{msg}{r}") + + def multi( + self, msgs: List[str], color: Optional[str] = None, indent: int = 0 + ) -> None: + """Print multiple messages with optional color and indentation.""" + for msg in msgs: + self.print(msg, color=color, indent=indent) + @staticmethod def banner(msg: str, color: str = "blue", width: int = 60) -> None: """Print a banner/header line. 
@@ -185,6 +230,8 @@ def banner(msg: str, color: str = "blue", width: int = 60) -> None: width : int Width of the separator line """ + if Console._quiet: + return separator = "=" * width c = COLORS.get(color, COLORS["reset"]) r = COLORS["reset"] @@ -198,6 +245,8 @@ def banner(msg: str, color: str = "blue", width: int = 60) -> None: @staticmethod def manual_password_instructions(env_name: str) -> None: """Print instructions for manual password change.""" + if Console._quiet: + return print("\nYou can change it manually later:") print(f" conda activate {env_name}") print(" python -c 'import datajoint as dj; dj.set_password()'") @@ -230,6 +279,8 @@ def progress(operation: str, estimated_minutes: int) -> None: Unlike step(), this prints on its own line with details below. """ + if Console._quiet: + return print(f"{operation}...") print(f" Estimated time: ~{estimated_minutes} minute(s)") print(" This may take a while - please be patient...") @@ -349,7 +400,7 @@ def check_prerequisites( -------- >>> check_prerequisites("minimal", Path("/tmp/spyglass_data")) """ - print("Checking prerequisites...") + Console.print("Checking prerequisites...") # Get Python version requirement from pyproject.toml min_version = get_required_python_version() @@ -398,26 +449,27 @@ def check_prerequisites( return needed_to_free = required_gb - available_gb + min_total = DISK_SPACE_REQUIREMENTS["minimal"]["total"] + Console.error( "Insufficient disk space - installation cannot continue", indent=True, ) - print(f" Checking: {base_dir}") - print(f" Available: {available_gb} GB") - print( - f" Required: {required_gb} GB ({install_type}: ~{packages_gb} GB packages + ~{buffer_gb} GB buffer)" - ) - print(f" Need to free: {needed_to_free} GB") - print() - print(" To fix:") - print(f" 1. Free at least {needed_to_free} GB in this location") - print( - " 2. 
Choose different directory: python scripts/install.py --base-dir /other/path"
-    )
-    min_total = DISK_SPACE_REQUIREMENTS["minimal"]["total"]
-    print(
-        f"  3. Use minimal install (needs {min_total} GB): python scripts/install.py --minimal"
+    Console().multi(
+        [
+            f"Checking: {base_dir}",
+            f"Available: {available_gb} GB",
+            f"Required: {required_gb} GB ({install_type}: ~{packages_gb} GB packages + ~{buffer_gb} GB buffer)",
+            f"Need to free: {needed_to_free} GB",
+            "",
+            "To fix:",
+            f"  1. Free at least {needed_to_free} GB in this location",
+            "  2. Choose different directory: python scripts/install.py --base-dir /other/path",
+            f"  3. Use minimal install (needs {min_total} GB): python scripts/install.py --minimal",
+        ],
+        indent=2,
     )
+
     raise RuntimeError("Insufficient disk space")
@@ -530,17 +582,20 @@ def create(self, env_file: str, force: bool = False) -> None:
                 Console.success(
                     f"Using existing environment '{self.env_name}'"
                 )
-                print(
-                    "   Package installation will continue (updates if needed)"
-                )
-                print(
-                    "   To use a different name, run with: --env-name <name>"
+                Console().multi(
+                    [
+                        "Package installation will continue (updates if needed)",
+                        "To use a different name, run with: --env-name <name>",
+                    ],
+                    indent=1,
                 )
                 return

            self.remove()

        conda_cmd = self.get_command()
-        print("  Installing packages... (this will take several minutes)")
+        Console.print(
+            "  Installing packages... (this will take several minutes)"
+        )

        env_file_path = REPO_ROOT / env_file
@@ -565,8 +620,9 @@
                    kw in line
                    for kw in ["Solving", "Downloading", "Extracting"]
                ):
-                    print(".", end="", flush=True)
-        print()
+                    if not Console._quiet:
+                        print(".", end="", flush=True)
+        Console.print("")

        if process.returncode != 0:
            raise subprocess.CalledProcessError(
@@ -731,14 +786,17 @@ def validate_and_test_write(path: Path) -> Path:
        return validated_path

    # 3. 
Interactive prompt - print("\nWhere should Spyglass store data?") - print(" This will store raw NWB files, analysis results, and video data.") - print(" Typical usage: 10-100+ GB depending on your experiments.") - print() default = Path.home() / "spyglass_data" - print(f" Default: {default}") - print( - " Tip: Set SPYGLASS_BASE_DIR environment variable to skip this prompt" + Console.print("\nWhere should Spyglass store data?") + Console().multi( + [ + "This will store raw NWB files, analysis results, and video data.", + "Typical usage: 10-100+ GB depending on your experiments.", + "", + f"Default: {default}", + "Tip: Set SPYGLASS_BASE_DIR environment variable to skip this prompt", + ], + indent=1, ) while True: @@ -761,7 +819,7 @@ def validate_and_test_write(path: Path) -> Path: Console.error( f"Parent directory does not exist: {base_path.parent}" ) - print( + Console.print( " Please create parent directory first or choose another location" ) continue @@ -814,9 +872,7 @@ def prompt_install_type() -> Tuple[str, str]: >>> env_file, install_type = prompt_install_type() >>> print(f"Using {env_file} for {install_type} installation") """ - print("\n" + "=" * 60) - print("Installation Type") - print("=" * 60) + Console.banner("Installation Type") # Get disk space values from constants for consistency min_pkg = DISK_SPACE_REQUIREMENTS["minimal"]["packages"] @@ -824,32 +880,36 @@ def prompt_install_type() -> Tuple[str, str]: full_pkg = DISK_SPACE_REQUIREMENTS["full"]["packages"] full_total = DISK_SPACE_REQUIREMENTS["full"]["total"] - print("\n1. Minimal (Recommended for getting started)") - print(f" ├─ Install time: ~{ENV_CREATION_TIME_MINIMAL} minutes") - print( - f" ├─ Disk space: ~{min_pkg} GB packages ({min_total} GB total with buffer)" + Console.print("\n1. 
Minimal (Recommended for getting started)") + Console().multi( + [ + f"├─ Install time: ~{ENV_CREATION_TIME_MINIMAL} minutes", + f"├─ Disk space: ~{min_pkg} GB packages ({min_total} GB total with buffer)", + "├─ Includes:", + "│ • Core Spyglass functionality", + "│ • Common data tables", + "│ • Position tracking", + "│ • LFP analysis", + "│ • Basic spike sorting", + "└─ Good for: Learning, basic workflows", + ], + indent=2, ) - print(" ├─ Includes:") - print(" │ • Core Spyglass functionality") - print(" │ • Common data tables") - print(" │ • Position tracking") - print(" │ • LFP analysis") - print(" │ • Basic spike sorting") - print(" └─ Good for: Learning, basic workflows") - - print("\n2. Full (For advanced analysis)") - print(f" ├─ Install time: ~{ENV_CREATION_TIME_FULL} minutes") - print( - f" ├─ Disk space: ~{full_pkg} GB packages ({full_total} GB total with buffer)" + Console.print("\n2. Full (For advanced analysis)") + Console().multi( + [ + f"├─ Install time: ~{ENV_CREATION_TIME_FULL} minutes", + f"├─ Disk space: ~{full_pkg} GB packages ({full_total} GB total with buffer)", + "├─ Includes: Everything in Minimal, plus:", + "│ • Advanced spike sorting (Kilosort, etc.)", + "│ • Ripple detection", + "│ • Track linearization", + "└─ Good for: Production work, all features", + ], + indent=2, ) - print(" ├─ Includes: Everything in Minimal, plus:") - print(" │ • Advanced spike sorting (Kilosort, etc.)") - print(" │ • Ripple detection") - print(" │ • Track linearization") - print(" └─ Good for: Production work, all features") - - print("\nNote: DeepLabCut, Moseq, and some decoding features") - print(" require separate installation (see docs)") + Console.print("\nNote: DeepLabCut, Moseq, and some decoding features") + Console.print(" require separate installation (see docs)") # Map choices to (env_file, install_type) choice_map = { @@ -1016,7 +1076,7 @@ def cleanup() -> None: cwd=REPO_ROOT, ) except (subprocess.TimeoutExpired, FileNotFoundError): - print("Failed to 
clean up Docker Compose setup") + Console.warning("Failed to clean up Docker Compose setup") # ============================================================================ @@ -1155,8 +1215,8 @@ def build_directory_structure( directories = {} if verbose and create: - print(f"Creating Spyglass directory structure in {base_dir}") - print(" Creating:") + Console.print(f"Creating Spyglass directory structure in {base_dir}") + Console.print("Creating:", indent=1) for prefix, dir_map in schema.items(): for key, rel_path in dir_map.items(): @@ -1166,10 +1226,10 @@ def build_directory_structure( if create: full_path.mkdir(parents=True, exist_ok=True) if verbose: - print(f" • {rel_path}") + Console.print(f" • {rel_path}") if verbose and create: - print(f" ✓ Created {len(directories)} directories") + Console.success(f"Created {len(directories)} directories", indent=True) return directories @@ -1196,13 +1256,19 @@ def determine_tls(host: str) -> bool: # User-friendly messaging (plain language instead of technical terms) if not use_tls: # localhost Console.info(f"✓ Connecting to local database at {host}") - print(" Security: Using unencrypted connection (safe for localhost)") + Console.print( + "Security: Using unencrypted connection (safe for localhost)", + indent=1, + ) else: Console.info(f"✓ Connecting to remote database at {host}") - print( - " Security: Using encrypted connection (TLS) to protect your data" + Console.print( + "Security: Using encrypted connection (TLS) to protect your data", + indent=1, + ) + Console.print( + "This is required when connecting over a network", indent=1 ) - print(" This is required when connecting over a network") return use_tls @@ -1334,24 +1400,28 @@ def create_database_config( # Handle existing config file with better UX if config_file.exists(): Console.warning(f"Configuration file already exists: {config_file}") - print("\nExisting database settings:") + Console.print("\nExisting database settings:") try: with config_file.open() as f: 
existing = json.load(f) existing_host = existing.get("database.host", "unknown") existing_port = existing.get("database.port", "unknown") existing_user = existing.get("database.user", "unknown") - print(f" Database: {existing_host}:{existing_port}") - print(f" User: {existing_user}") + Console.print( + f"Database: {existing_host}:{existing_port}", indent=1 + ) + Console.print(f"User: {existing_user}", indent=1) except (OSError, IOError, json.JSONDecodeError, KeyError) as e: - print(f" (Unable to read existing config: {e})") - - print("\nOptions:") - print( - " [b] Backup and create new (saves to .datajoint_config.json.backup)" + Console.print(f"(Unable to read existing config: {e})", indent=1) + + Console().multi( + [ + "\nOptions:", + " [b] Backup and create new (saves to .datajoint_config.json.backup)", + " [o] Overwrite with new settings", + " [k] Keep existing (cancel installation)", + ] ) - print(" [o] Overwrite with new settings") - print(" [k] Keep existing (cancel installation)") choice = input("\nChoice [B/o/k]: ").strip().lower() or "b" @@ -1359,11 +1429,13 @@ def create_database_config( Console.warning( "Keeping existing configuration. Installation cancelled." ) - print("\nTo install with different settings:") - print( - " 1. Backup your config: cp ~/.datajoint_config.json ~/.datajoint_config.json.backup" + Console().multi( + [ + "\nTo install with different settings:", + " 1. Backup your config: cp ~/.datajoint_config.json ~/.datajoint_config.json.backup", + " 2. Run installer again", + ] ) - print(" 2. 
Run installer again") return elif choice in ["b", "backup"]: backup_file = config_file.with_suffix(".json.backup") @@ -1406,31 +1478,38 @@ def create_database_config( # Atomic move (on same filesystem) shutil.move(tmp_path, config_file) - Console.success(f"Configuration saved to: {config_file}") - print(" Permissions: Owner read/write only (secure)") + tls_status = "Yes" if use_tls else "No (localhost)" - # Enhanced success message with next steps - print() + Console.success(f"Configuration saved to: {config_file}") + Console().multi( + [ + " Permissions: Owner read/write only (secure)", + "", + ] + ) Console.success("✓ Spyglass configuration complete!") - print() - print("Database connection:") - print(f" • Server: {host}:{port}") - print(f" • User: {user}") - tls_status = "Yes" if use_tls else "No (localhost)" - print(f" • Encrypted: {tls_status}") - print() - print("Data directories:") - print(f" • Base: {base_dir}") - print(f" • Raw data: {config['custom']['spyglass_dirs']['raw']}") - print(f" • Analysis: {config['custom']['spyglass_dirs']['analysis']}") - print(f" • ({len(dirs)} directories total)") - print() - print("Next steps:") - print(" 1. Activate environment: conda activate spyglass") - print(" 2. Test your installation: python scripts/validate.py") - print(" 3. Start using Spyglass: python -c 'import spyglass'") - print() - print("Need help? See: https://lorenfranklab.github.io/spyglass/") + Console().multi( + [ + "", + "Database connection:", + f" • Server: {host}:{port}", + f" • User: {user}", + f" • Encrypted: {tls_status}", + "", + "Data directories:", + f" • Base: {base_dir}", + f" • Raw data: {config['custom']['spyglass_dirs']['raw']}", + f" • Analysis: {config['custom']['spyglass_dirs']['analysis']}", + f" • ({len(dirs)} directories total)", + "", + "Next steps:", + " 1. Activate environment: conda activate spyglass", + " 2. Test your installation: python scripts/validate.py", + " 3. 
Start using Spyglass: python -c 'import spyglass'", + "", + "Need help? See: https://lorenfranklab.github.io/spyglass/", + ] + ) # ============================================================================= @@ -1746,10 +1825,14 @@ def prompt_remote_database_config() -> Optional[Dict[str, Any]]: >>> if config: ... print(f"Connecting to {config['host']}:{config['port']}") """ - print("\nRemote database configuration:") - print(" Your lab admin should have provided these credentials.") - print(" Check your welcome email or contact your admin if unsure.") - print(" (Press Ctrl+C to cancel)") + Console().multi( + [ + "\nRemote database configuration:", + " Your lab admin should have provided these credentials.", + " Check your welcome email or contact your admin if unsure.", + " (Press Ctrl+C to cancel)", + ] + ) try: host = input(" Host (e.g., db.lab.edu): ").strip() @@ -1757,7 +1840,9 @@ def prompt_remote_database_config() -> Optional[Dict[str, Any]]: # Require explicit host for remote database if not host: Console.error("Host is required for remote database connection") - print(" Ask your lab admin for the database hostname") + Console.print( + "Ask your lab admin for the database hostname", indent=1 + ) return None port_str = input(" Port [3306]: ").strip() or "3306" @@ -1783,7 +1868,7 @@ def prompt_remote_database_config() -> Optional[Dict[str, Any]]: } except KeyboardInterrupt: - print("\n") + Console.print("") Console.warning("Database configuration cancelled") return None @@ -1881,13 +1966,12 @@ def prompt_database_setup() -> str: >>> if choice == "compose": ... 
setup_database_compose() """ - print("\n" + "=" * 60) - print("Database Setup") - print("=" * 60) + Console.print("") + Console.banner("Database Setup") options, compose_available = get_database_options() - print("\nOptions:") + Console.print("\nOptions:") for opt in options: # Color status based on availability (check for success/error symbols) if SYMBOLS["error"] in opt.status: @@ -1896,22 +1980,24 @@ def prompt_database_setup() -> str: status_color = COLORS["green"] else: status_color = COLORS["reset"] - print( + Console.print( f" {opt.number}. {opt.name:20} {status_color}{opt.status}{COLORS['reset']}" ) - print(f" {opt.description}") + Console.print(f" {opt.description}") # If Docker not available, guide user if not compose_available: - print() + Console.print("") Console.warning("Docker is not available") - print(" To enable Docker setup:") - print( - " 1. Install Docker Desktop: https://docs.docker.com/get-docker/" + Console().multi( + [ + " To enable Docker setup:", + " 1. Install Docker Desktop: https://docs.docker.com/get-docker/", + " 2. Start Docker Desktop", + " 3. Verify: docker compose version", + " 4. Re-run installer", + ] ) - print(" 2. Start Docker Desktop") - print(" 3. Verify: docker compose version") - print(" 4. Re-run installer") # Map choices to actions (updated order: Remote first, then Docker) choice_map = { @@ -1998,50 +2084,67 @@ def setup_database_compose() -> Tuple[bool, str]: port_available, port_msg = is_port_available("localhost", actual_port) if not port_available: Console.error(port_msg) - print(f"\n Port {actual_port} is already in use. Solutions:") + Console.print( + f"\n Port {actual_port} is already in use. Solutions:" + ) # Platform-specific guidance if sys.platform == "darwin": # macOS - print(" 1. Stop existing MySQL (if installed):") - print(" brew services stop mysql") - print( - " # or: sudo launchctl unload -w /Library/LaunchDaemons/com.mysql.mysql.plist" + Console().multi( + [ + " 1. 
Stop existing MySQL (if installed):", + " brew services stop mysql", + " # or: sudo launchctl unload -w /Library/LaunchDaemons/com.mysql.mysql.plist", + " 2. Find what's using the port:", + f" lsof -i :{actual_port}", + ] ) - print(" 2. Find what's using the port:") - print(f" lsof -i :{actual_port}") elif sys.platform.startswith("linux"): # Linux - print(" 1. Stop existing MySQL service:") - print(" sudo systemctl stop mysql") - print(" # or: sudo service mysql stop") - print(" 2. Find what's using the port:") - print(f" sudo lsof -i :{actual_port}") - print(f" # or: sudo netstat -tulpn | grep {actual_port}") + Console().multi( + [ + " 1. Stop existing MySQL service:", + " sudo systemctl stop mysql", + " # or: sudo service mysql stop", + " 2. Find what's using the port:", + f" sudo lsof -i :{actual_port}", + f" # or: sudo netstat -tulpn | grep {actual_port}", + ] + ) elif sys.platform == "win32": # Windows - print(" 1. Stop existing MySQL service:") - print(" net stop MySQL") - print(" # or use Services app (services.msc)") - print(" 2. Find what's using the port:") - print(f" netstat -ano | findstr :{actual_port}") + Console().multi( + [ + " 1. Stop existing MySQL service:", + " net stop MySQL", + " # or use Services app (services.msc)", + " 2. 
Find what's using the port:", + f" netstat -ano | findstr :{actual_port}", + ] + ) - print(" Alternative: Use a different port:") + Console.print(" Alternative: Use a different port:") if env_path.exists(): - print(f" Edit {env_path} and set MYSQL_PORT=3307") + Console.print(f" Edit {env_path} and set MYSQL_PORT=3307") else: - print(" Create .env file with: MYSQL_PORT=3307") - print(" (and update DataJoint config to match)") + Console.print(" Create .env file with: MYSQL_PORT=3307") + Console.print(" (and update DataJoint config to match)") return False, "port_in_use" # Show what will happen - print("\n" + "=" * 60) - print("Docker Database Setup") - print("=" * 60) - print("\nThis will:") - print(" • Download MySQL 8.0 Docker image (~500 MB)") - print(" • Create a container named 'spyglass-db'") - print(f" • Start MySQL on localhost:{actual_port}") - print(" • Save credentials to ~/.datajoint_config.json") - print("\nEstimated time: 2-3 minutes") - print("=" * 60) + Console.print("") + Console.banner("Docker Database Setup") + Console().multi( + [ + "", + "This will:", + " • Download MySQL 8.0 Docker image (~500 MB)", + " • Create a container named 'spyglass-db'", + f" • Start MySQL on localhost:{actual_port}", + " • Save credentials to ~/.datajoint_config.json", + "", + "Estimated time: 2-3 minutes", + ] + ) + Console.banner("") # Get compose command compose_cmd = DockerManager.get_compose_command() @@ -2061,20 +2164,28 @@ def setup_database_compose() -> Tuple[bool, str]: # Prioritize causes by likelihood (network issues most common) if "no space" in error_lower or "disk" in error_lower: - print("\n Most likely cause: Insufficient disk space") - print(" Fix: Free up space with: docker system prune -a") + Console.print("\n Most likely cause: Insufficient disk space") + Console.print( + " Fix: Free up space with: docker system prune -a" + ) elif "timeout" in error_lower or "connection" in error_lower: - print("\n Most likely cause: Network connection issue") - 
print(" Fix: Check internet connection and retry") + Console.print("\n Most likely cause: Network connection issue") + Console.print(" Fix: Check internet connection and retry") else: - print("\n Most likely cause: Network or Docker Hub issue") - print(" Fix: Wait a moment and retry") - - print("\n Other steps to try:") - print(" 1. Check internet connection") - print(" 2. Check disk space: docker system df") - print(" 3. Retry: docker compose pull") - print(" 4. If persistent, try: docker system prune") + Console.print( + "\n Most likely cause: Network or Docker Hub issue" + ) + Console.print(" Fix: Wait a moment and retry") + + Console().multi( + [ + "\n Other steps to try:", + " 1. Check internet connection", + " 2. Check disk space: docker system df", + " 3. Retry: docker compose pull", + " 4. If persistent, try: docker system prune", + ] + ) return False, "pull_failed" # Start services @@ -2137,7 +2248,8 @@ def setup_database_compose() -> Tuple[bool, str]: pass if attempt < MYSQL_HEALTH_CHECK_ATTEMPTS - 1: - print(".", end="", flush=True) + if not Console._quiet: + print(".", end="", flush=True) time.sleep(MYSQL_HEALTH_CHECK_INTERVAL) else: # Timeout - provide debug info @@ -2145,8 +2257,8 @@ def setup_database_compose() -> Tuple[bool, str]: Console.error( f"MySQL did not become ready within {MYSQL_HEALTH_CHECK_TIMEOUT} seconds" ) - print("\n Check logs:") - print(" docker compose logs mysql") + Console.print("\n Check logs:") + Console.print(" docker compose logs mysql") DockerManager.cleanup() return False, "timeout" @@ -2227,7 +2339,7 @@ def test_database_connection( except ImportError: # pymysql not available yet (before pip install) Console.warning("Cannot test connection (pymysql not available)") - print(" Connection will be tested during validation") + Console.print("Connection will be tested during validation", indent=1) return True, None # Allow to proceed try: @@ -2304,19 +2416,23 @@ def handle_database_setup_interactive(env_name: str) -> None: else: 
Console.error("Docker setup failed") if reason == "compose_unavailable": - print("\nDocker is not available.") - print(" Option 1: Install Docker Desktop and restart") - print(" Option 2: Choose remote database") - print(" Option 3: Skip for now") + Console().multi( + [ + "\nDocker is not available.", + " Option 1: Install Docker Desktop and restart", + " Option 2: Choose remote database", + " Option 3: Skip for now", + ] + ) else: - print(f" Error: {reason}") + Console.print(f" Error: {reason}") if not Console.prompt_yes_no( "\nTry different option?", default_yes=True ): Console.warning("Skipping database setup") - print(" Configure later: docker compose up -d") - print(" Or manually: see docs/DATABASE.md") + Console.print(" Configure later: docker compose up -d") + Console.print(" Or manually: see docs/DATABASE.md") break # Loop continues to show menu again @@ -2328,8 +2444,8 @@ def handle_database_setup_interactive(env_name: str) -> None: else: # skip Console.warning("Skipping database setup") - print(" Configure later: docker compose up -d") - print(" Or manually: see docs/DATABASE.md") + Console.print(" Configure later: docker compose up -d") + Console.print(" Or manually: see docs/DATABASE.md") break @@ -2372,10 +2488,12 @@ def handle_database_setup_cli( Console.error("Docker setup failed") if reason == "compose_unavailable": Console.warning("Docker not available") - print(" Install from: https://docs.docker.com/get-docker/") + Console.print( + " Install from: https://docs.docker.com/get-docker/" + ) else: Console.error(f"Error: {reason}") - print(" You can configure manually later") + Console.print(" You can configure manually later") elif db_type == "remote": success = setup_database_remote( env_name=env_name, @@ -2386,7 +2504,7 @@ def handle_database_setup_cli( ) if not success: Console.warning("Remote database setup cancelled") - print(" You can configure manually later") + Console.print(" You can configure manually later") def change_database_password( 
@@ -2432,12 +2550,16 @@ def change_database_password( """ import getpass - print("\n" + "=" * 60) - print("Password Change (Recommended for lab members)") - print("=" * 60) - print("\nIf you received temporary credentials from your lab admin,") - print("you should change your password now for security.") - print() + Console.print("") + Console.banner("Password Change (Recommended for lab members)") + Console().multi( + [ + "", + "If you received temporary credentials from your lab admin,", + "you should change your password now for security.", + "", + ] + ) if not Console.prompt_yes_no("Change password?", default_yes=True): Console.warning("Keeping current password") @@ -2445,7 +2567,7 @@ def change_database_password( # Prompt for new password with confirmation while True: - print() + Console.print("") new_password = getpass.getpass(" New password: ") if not new_password: Console.error("Password cannot be empty") @@ -2569,7 +2691,7 @@ def setup_database_remote( >>> if setup_database_remote(host="db.example.com", user="myuser"): ... 
print("Non-interactive setup succeeded") """ - print("Setting up remote database connection...") + Console.print("Setting up remote database connection...") # Get config either from prompt or from provided parameters if host is None or user is None or password is None: @@ -2598,56 +2720,62 @@ def setup_database_remote( if not valid: Console.error("Invalid database configuration:") for err in errors: - print(f" - {err}") + Console.print(f"- {err}", indent=1) return False # Check if port is reachable (for remote hosts only) if Validators.should_use_tls(host): - print(f" Testing connection to {host}:{port}...") + Console.print(f" Testing connection to {host}:{port}...") port_reachable, port_msg = is_port_available(host, port) if not port_reachable: Console.warning(port_msg) - print("\n Possible causes:") - print(" • Wrong port number (MySQL usually uses 3306)") - print(" • Firewall blocking connections") - print(" • Database server not running") - print(" • Wrong hostname") + Console().multi( + [ + "\n Possible causes:", + " • Wrong port number (MySQL usually uses 3306)", + " • Firewall blocking connections", + " • Database server not running", + " • Wrong hostname", + ] + ) if not Console.prompt_yes_no( "\n Continue anyway?", default_yes=False ): return False else: - print(" ✓ Port is reachable") + Console.print(" ✓ Port is reachable") # Determine TLS based on host (use TLS for non-localhost) use_tls = Validators.should_use_tls(host) config["use_tls"] = use_tls - print(f" Connecting to {host}:{port} as {user}") + Console.print(f" Connecting to {host}:{port} as {user}") if use_tls: - print(" TLS: enabled") + Console.print(" TLS: enabled") # Test connection before saving success, _error = test_database_connection(**config) if not success: Console.error(f"Cannot connect to database: {_error}") - print() - print("Most common causes (in order):") - print(" 1. Wrong password - Double check credentials") - print(" 2. Firewall blocking connection") - print(" 3. 
Database not running") - print(" 4. TLS mismatch") - print() - print("Diagnostic steps:") - print(f" Test port: nc -zv {host} {port}") - print(f" Test MySQL: mysql -h {host} -P {port} -u {user} -p") - print() - print( - "Need help? See: docs/TROUBLESHOOTING.md#database-connection-fails" + Console().multi( + [ + "", + "Most common causes (in order):", + " 1. Wrong password - Double check credentials", + " 2. Firewall blocking connection", + " 3. Database not running", + " 4. TLS mismatch", + "", + "Diagnostic steps:", + f" Test port: nc -zv {host} {port}", + f" Test MySQL: mysql -h {host} -P {port} -u {user} -p", + "", + "Need help? See: docs/TROUBLESHOOTING.md#database-connection-fails", + "", + ] ) - print() if Console.prompt_yes_no( "Retry with different settings?", default_yes=False @@ -2710,25 +2838,27 @@ def validate_installation(env_name: str) -> bool: except subprocess.CalledProcessError: Console.fail() Console.warning("Some optional validation checks did not pass") - print( - "\n Core installation succeeded, but some features may need attention." + Console().multi( + [ + "\n Core installation succeeded, but some features may need attention.", + " Many warnings are not critical for getting started.", + "\n Common non-critical warnings:", + " - Database connection: Configure later if needed", + " - Optional packages: Install when you need them", + "\n To investigate:", + f" 1. Activate environment: conda activate {env_name}", + " 2. Run detailed validation: python scripts/validate.py -v", + " 3. See docs/TROUBLESHOOTING.md for specific issues", + ] ) - print(" Many warnings are not critical for getting started.") - print("\n Common non-critical warnings:") - print(" - Database connection: Configure later if needed") - print(" - Optional packages: Install when you need them") - print("\n To investigate:") - print(f" 1. Activate environment: conda activate {env_name}") - print(" 2. Run detailed validation: python scripts/validate.py -v") - print(" 3. 
See docs/TROUBLESHOOTING.md for specific issues") return False def print_installation_header() -> None: """Print installation header banner.""" - print() + Console.print("") Console.banner(" Spyglass Installation", color="blue") - print() + Console.print("") def determine_installation_type(args: argparse.Namespace) -> Tuple[str, str]: @@ -2813,21 +2943,27 @@ def print_completion_message(env_name: str, validation_passed: bool) -> None: validation_passed : bool Whether validation checks passed """ - print() + Console.print("") if validation_passed: Console.banner("Installation complete!", color="green") - print() + Console.print("") else: Console.banner("Installation complete with warnings", color="yellow") - print() - print("Core installation succeeded but some features may not work.") - print("Review warnings above and see: docs/TROUBLESHOOTING.md\n") - - print("Next steps:") - print(f" 1. Activate environment: conda activate {env_name}") - print(" 2. Start tutorial: jupyter notebook notebooks/") - print( - " 3. View documentation: https://lorenfranklab.github.io/spyglass/" + Console().multi( + [ + "", + "Core installation succeeded but some features may not work.", + "Review warnings above and see: docs/TROUBLESHOOTING.md\n", + ] + ) + + Console().multi( + [ + "Next steps:", + f" 1. Activate environment: conda activate {env_name}", + " 2. Start tutorial: jupyter notebook notebooks/", + " 3. 
View documentation: https://lorenfranklab.github.io/spyglass/", + ] ) @@ -2839,9 +2975,9 @@ def run_dry_run(args: argparse.Namespace) -> None: args : argparse.Namespace Parsed command-line arguments """ - print("\n" + "=" * 60) - print("DRY RUN MODE - No changes will be made") - print("=" * 60 + "\n") + Console.print("") + Console.banner("DRY RUN MODE - No changes will be made") + Console.print("") # Determine installation type if args.minimal: @@ -2875,48 +3011,51 @@ def run_dry_run(args: argparse.Namespace) -> None: install_type.split()[0], DISK_SPACE_REQUIREMENTS["minimal"] ) - print("Would perform the following steps:\n") - - print(f"1. {SYMBOLS['step']} Check prerequisites") - print( - f" Python version: {sys.version_info.major}.{sys.version_info.minor}" - ) - print(f" Required disk space: {space_req['total']} GB") - print() - - print(f"2. {SYMBOLS['step']} Create conda environment") - print(f" Environment name: {args.env_name}") - print(f" Environment file: {env_file}") - print(f" Install type: {install_type}") - print() - - print(f"3. {SYMBOLS['step']} Install spyglass package") - print(" Command: pip install -e .") - print() - - print(f"4. {SYMBOLS['step']} Create directory structure") - print(f" Base directory: {base_dir}") - print( - " Subdirectories: raw, analysis, recording, sorting, waveforms, etc." + Console.print("Would perform the following steps:\n") + + Console().multi( + [ + f"1. {SYMBOLS['step']} Check prerequisites", + f" Python version: {sys.version_info.major}.{sys.version_info.minor}", + f" Required disk space: {space_req['total']} GB", + "", + f"2. {SYMBOLS['step']} Create conda environment", + f" Environment name: {args.env_name}", + f" Environment file: {env_file}", + f" Install type: {install_type}", + "", + f"3. {SYMBOLS['step']} Install spyglass package", + " Command: pip install -e .", + "", + f"4. 
{SYMBOLS['step']} Create directory structure", + f" Base directory: {base_dir}", + " Subdirectories: raw, analysis, recording, sorting, waveforms, etc.", + "", + f"5. {SYMBOLS['step']} Setup database", + f" Method: {db_setup}", + "", + f"6. {SYMBOLS['step']} Create configuration file", + f" Location: {Path.home() / '.datajoint_config.json'}", + "", + ] ) - print() - - print(f"5. {SYMBOLS['step']} Setup database") - print(f" Method: {db_setup}") - print() - - print(f"6. {SYMBOLS['step']} Create configuration file") - print(f" Location: {Path.home() / '.datajoint_config.json'}") - print() if not args.skip_validation: - print(f"7. {SYMBOLS['step']} Validate installation") - print(" Run: python scripts/validate.py") - print() + Console().multi( + [ + f"7. {SYMBOLS['step']} Validate installation", + " Run: python scripts/validate.py", + "", + ] + ) - print("=" * 60) - print("To perform installation, run without --dry-run flag") - print("=" * 60) + Console().multi( + [ + "=" * 60, + "To perform installation, run without --dry-run flag", + "=" * 60, + ] + ) def run_config_only(args: argparse.Namespace) -> None: @@ -2932,9 +3071,9 @@ def run_config_only(args: argparse.Namespace) -> None: args : argparse.Namespace Parsed command-line arguments containing database and path options """ - print("\n" + "=" * 60) - print("CONFIG-ONLY MODE - Generating configuration file") - print("=" * 60 + "\n") + Console.print("") + Console.banner("CONFIG-ONLY MODE - Generating configuration file") + Console.print("") # Get base directory base_dir = get_base_directory(args.base_dir) @@ -2966,9 +3105,13 @@ def run_config_only(args: argparse.Namespace) -> None: raise ValueError("Database password is required") else: # Interactive mode - print("Database configuration:") - print(" 1. Local Docker database (localhost)") - print(" 2. Remote database (e.g., lmf-db.cin.ucsf.edu)") + Console().multi( + [ + "Database configuration:", + " 1. Local Docker database (localhost)", + " 2. 
Remote database (e.g., lmf-db.cin.ucsf.edu)", + ] + ) choice = input("\nChoice [1/2]: ").strip() or "1" if choice == "1": @@ -3006,19 +3149,23 @@ def run_config_only(args: argparse.Namespace) -> None: Console.done() config_file = Path.home() / ".datajoint_config.json" - print() - print("=" * 60) + Console.print("") + Console.banner("") Console.success(f"Configuration created: {config_file}") - print("=" * 60) - print() - print("Configuration summary:") - print(f" Database: {host}:{port}") - print(f" User: {user}") - print(f" Base directory: {base_dir}") - print(f" TLS: {'enabled' if determine_tls(host) else 'disabled'}") - print() - print("To test your configuration:") - print(' python -c "import datajoint as dj; dj.conn()"') + Console.banner("") + Console().multi( + [ + "", + "Configuration summary:", + f" Database: {host}:{port}", + f" User: {user}", + f" Base directory: {base_dir}", + f" TLS: {'enabled' if determine_tls(host) else 'disabled'}", + "", + "To test your configuration:", + ' python -c "import datajoint as dj; dj.conn()"', + ] + ) def run_installation(args: argparse.Namespace) -> None: @@ -3159,7 +3306,7 @@ def main() -> None: try: run_installation(args) except KeyboardInterrupt: - print("\n\nInstallation cancelled by user.") + Console.print("\n\nInstallation cancelled by user.") sys.exit(1) except RuntimeError as e: # Expected errors from our code (prerequisites, validation, etc.) 
@@ -3168,12 +3315,12 @@ def main() -> None: except subprocess.CalledProcessError as e: # Process execution failures Console.error(f"Command failed: {e}") - print(" Check the output above for details") + Console.print("Check the output above for details", indent=1) sys.exit(1) except (OSError, IOError) as e: # File system errors Console.error(f"File system error: {e}") - print(" Check disk space and permissions") + Console.print("Check disk space and permissions", indent=1) sys.exit(1) except ValueError as e: # Configuration/validation errors diff --git a/src/spyglass/position/utils_dlc.py b/src/spyglass/position/utils_dlc.py index d77d62e2a..e4e76b322 100644 --- a/src/spyglass/position/utils_dlc.py +++ b/src/spyglass/position/utils_dlc.py @@ -1,6 +1,7 @@ -import builtins import contextlib import csv +import inspect +import sys from pathlib import Path try: @@ -14,25 +15,48 @@ @contextlib.contextmanager def suppress_print_from_package(package: str = "deeplabcut"): - original_print = builtins.print + """Suppress stdout/stderr writes that originate from *package*. - def dummy_print(*args, **kwargs): - stack = [ - frame.f_globals.get("__name__") - for frame in inspect.stack() - if hasattr(frame, "f_globals") - ] - if any(name and name.startswith(package) for name in stack): - return # Suppress if the call comes from the target package - return original_print(*args, **kwargs) + Replaces sys.stdout and sys.stderr with a proxy that walks the call stack + on every write; output whose innermost package-level frame matches + ``package`` is dropped, everything else passes through unchanged. - import inspect + More reliable than patching builtins.print because it also catches tqdm + progress bars and any code that calls sys.stdout.write() directly. 
+ """ - builtins.print = dummy_print + class _PackageFilter: + """Proxy stream: suppress writes from *package*, pass others through.""" + + def __init__(self, stream: object) -> None: + self._stream = stream + + def write(self, text: str) -> int: + for frame_info in inspect.stack(): + # Real FrameInfo objects store the frame in .frame; + # test mocks may expose f_globals directly on the object. + fg = getattr(frame_info, "f_globals", None) + if fg is None: + raw = getattr(frame_info, "frame", None) + fg = getattr(raw, "f_globals", {}) if raw else {} + if fg.get("__name__", "").startswith(package): + return len(text) # drop — came from target package + return self._stream.write(text) + + def flush(self) -> None: + return self._stream.flush() + + def __getattr__(self, name: str): + return getattr(self._stream, name) + + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout = _PackageFilter(old_stdout) + sys.stderr = _PackageFilter(old_stderr) try: yield finally: - builtins.print = original_print + sys.stdout = old_stdout + sys.stderr = old_stderr def get_dlc_model_eval( diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 30b486f00..5b066dc35 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -339,7 +339,7 @@ def attempt_all( if not bool(RecordingRecompute & key) ] if not inserts: - logger.debug(f"No rows to insert from:\n\t{source}") + logger.info(f"No rows to insert from:\n\t{source}") return logger.info(f"Inserting recompute attempts for {len(inserts)} files.") @@ -904,7 +904,7 @@ def delete_files( file_names = query.fetch("analysis_file_name") prefix = "DRY RUN: " if dry_run else "" if not len(file_names): - logger.debug(f"{prefix}Delete 0 files. Nothing to do.") + logger.info(f"{prefix}Delete 0 files. 
Nothing to do.") return msg = f"{prefix}Delete {len(file_names)} files?\n\t" + "\n\t".join( file_names[:10] diff --git a/src/spyglass/utils/mixins/analysis.py b/src/spyglass/utils/mixins/analysis.py index 86d9b60bf..9004e4c21 100644 --- a/src/spyglass/utils/mixins/analysis.py +++ b/src/spyglass/utils/mixins/analysis.py @@ -264,7 +264,7 @@ def create( # write the new file if not recompute_file_name: - self._logger.info(f"Writing new NWB file {analysis_file_name}") + self._info_msg(f"Writing new NWB file {analysis_file_name}") analysis_file_abs_path = self.get_abs_path( analysis_file_name, from_schema=bool(recompute_file_name) @@ -413,7 +413,7 @@ def copy(cls, nwb_file_name: str): original_nwb_file_name = query.fetch("nwb_file_name")[0] analysis_file_name = cls.__get_new_file_name(original_nwb_file_name) # write the new file - cls()._logger.info(f"Writing new NWB file {analysis_file_name}...") + cls()._info_msg(f"Writing new NWB file {analysis_file_name}...") analysis_file_abs_path = cls().get_abs_path(analysis_file_name) # export the new NWB file with pynwb.NWBHDF5IO( diff --git a/src/spyglass/utils/mixins/base.py b/src/spyglass/utils/mixins/base.py index 9affa8d5a..3c8b44283 100644 --- a/src/spyglass/utils/mixins/base.py +++ b/src/spyglass/utils/mixins/base.py @@ -36,6 +36,18 @@ def _graph_deps(self) -> list: return [TableChain, RestrGraph] + def _info_msg(self, msg: str) -> str: + """Log info message, but debug if in test mode. + + Quiets logs during testing, but preserves user experience during use. + + Used by ... + - AnalysisMixin.copy and .create + - IngestionMixin._insert_logline + """ + log = self._logger.debug if self._test_mode else self._logger.info + log(msg) + @cached_property def _test_mode(self) -> bool: """Return True if in test mode. 
diff --git a/src/spyglass/utils/mixins/ingestion.py b/src/spyglass/utils/mixins/ingestion.py index d5f5ebb2e..e578bc59f 100644 --- a/src/spyglass/utils/mixins/ingestion.py +++ b/src/spyglass/utils/mixins/ingestion.py @@ -201,7 +201,7 @@ def _camel(tbl=None): this_tbl, self_tbl = _camel(table), _camel(self) suffix = "" if this_tbl == self_tbl else f" via {self_tbl}" - logger.info( + self._info_msg( f"{nwb_file_name} inserts {n_entries} into {this_tbl}{suffix}" ) diff --git a/tests/conftest.py b/tests/conftest.py index a4c7bbfc8..8be896caa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,36 @@ """ import os + +# --------------------------------------------------------------------------- +# Environment variables — set before any package imports so that TensorFlow, +# CUDA, and Qt pick them up at their first import. +# --------------------------------------------------------------------------- + +# Suppress TensorFlow C++ logging (0=DEBUG … 3=FATAL-only). +os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3") + +# Disable oneDNN fused-ops to avoid the "numerical results may differ" banner. +os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0") + +# Qt requires a display; offscreen keeps headless CI from crashing. +os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") +os.environ.setdefault("DISPLAY", ":0") + +# Disable all tqdm progress bars; they pollute test output. +os.environ.setdefault("TQDM_DISABLE", "1") + +# Suppress ResourceWarning at the OS level so datajoint/hash.py unclosed-file +# warnings don't bleed through even during GC finalisation. 
+_existing = os.environ.get("PYTHONWARNINGS", "") +_rw_filter = "ignore::ResourceWarning" +if _rw_filter not in _existing: + os.environ["PYTHONWARNINGS"] = ( + f"{_existing},{_rw_filter}" if _existing else _rw_filter + ) + +# --------------------------------------------------------------------------- + import sys import warnings from contextlib import nullcontext diff --git a/tests/position/v1/test_pos_merge.py b/tests/position/v1/test_pos_merge.py index acfb712e9..140e909b9 100644 --- a/tests/position/v1/test_pos_merge.py +++ b/tests/position/v1/test_pos_merge.py @@ -47,7 +47,9 @@ def test_merge_fetch_video_path(pos_merge, dlc_key, populate_dlc): assert Path(path).exists(), f"Video path does not exist: {path}" -def test_merge_id_order(pos_merge): +def test_merge_id_order(trodes_pos_v1, pos_merge): + _ = trodes_pos_v1 # Ensure populated + merge_keys = pos_merge.TrodesPosV1().fetch("KEY") assert len(merge_keys) > 1 nwb_file_list, merge_ids = (pos_merge & merge_keys).fetch_nwb( diff --git a/tests/setup/conftest.py b/tests/setup/conftest.py index 1c6eab878..355e8e466 100644 --- a/tests/setup/conftest.py +++ b/tests/setup/conftest.py @@ -5,8 +5,24 @@ spyglass imports that connect to MySQL. """ +import sys +from pathlib import Path + import pytest +# Suppress non-essential install.py console output in-process without +# propagating to subprocess test runs (env vars propagate; class patch does not). 
+try: + _sp = str(Path(__file__).parent.parent.parent / "scripts") + if _sp not in sys.path: + sys.path.insert(0, _sp) + import install as _install + + _install.Console._quiet = True + del _sp, _install +except Exception: + pass + # Override ALL session-scoped fixtures that might trigger database connections # or spyglass imports during collection From 7f7add869c8ab9912bfb1dd9e45a20b599081e65 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 15:39:00 +0100 Subject: [PATCH 10/30] Denoising tests 4 --- src/spyglass/common/common_behav.py | 14 ++++++++------ src/spyglass/common/common_ephys.py | 5 +++-- src/spyglass/common/common_interval.py | 4 +++- src/spyglass/common/common_lab.py | 8 ++++---- src/spyglass/common/common_nwbfile.py | 10 +++++----- src/spyglass/common/common_optogenetics.py | 12 ++++++------ src/spyglass/common/common_position.py | 17 +++++++++-------- src/spyglass/common/common_sensors.py | 4 +++- src/spyglass/common/common_task.py | 4 ++-- src/spyglass/common/common_usage.py | 6 +++++- src/spyglass/common/populate_all_common.py | 4 +++- src/spyglass/data_import/insert_sessions.py | 18 ++++++++++-------- src/spyglass/decoding/decoding_merge.py | 8 ++++---- src/spyglass/lfp/lfp_imported.py | 12 ++++++------ src/spyglass/position/v1/dlc_utils_makevid.py | 1 + src/spyglass/position/v1/imported_pose.py | 1 + .../position/v1/position_dlc_selection.py | 4 ++-- .../position/v1/position_trodes_position.py | 4 ++-- src/spyglass/settings.py | 2 +- src/spyglass/spikesorting/imported.py | 2 +- .../spikesorting/v0/spikesorting_curation.py | 2 +- .../spikesorting/v0/spikesorting_recording.py | 3 +++ .../spikesorting/v0/spikesorting_sorting.py | 3 +++ .../spikesorting/v1/metric_curation.py | 4 ++-- src/spyglass/spikesorting/v1/recompute.py | 2 +- src/spyglass/utils/database_settings.py | 1 + src/spyglass/utils/dj_merge_tables.py | 2 +- src/spyglass/utils/mixins/base.py | 10 +++++++++- src/spyglass/utils/mixins/export.py | 4 ++-- 
src/spyglass/utils/mixins/ingestion.py | 2 +- src/spyglass/utils/mixins/restrict_by.py | 4 +--- src/spyglass/utils/nwb_helper_fn.py | 10 +++++++--- tests/common/test_task_epoch_tags.py | 8 +------- tests/conftest.py | 4 +--- tests/container.py | 5 +++-- tests/utils/test_graph.py | 17 ++++++++++------- tests/utils/test_mixin.py | 8 ++++++-- 37 files changed, 132 insertions(+), 97 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 03ff7e556..e81362999 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -87,7 +87,9 @@ def insert_from_nwbfile(cls, nwb_file_name, skip_duplicates=False) -> None: src_key = dict(**sess_key, source="imported", import_file_name="") if all_pos is None: - logger.info(f"No position data found in {nwb_file_name}. Skipping.") + cls()._info_msg( + f"No position data found in {nwb_file_name}. Skipping." + ) return sources = [] @@ -406,7 +408,7 @@ def make(self, key): "associated_files" ) or nwbf.processing.get("associated files") if associated_files is None: - logger.info( + self._info_msg( "Unable to import StateScriptFile: no processing module named " + f'"associated_files" found in {nwb_file_name}.' ) @@ -485,7 +487,7 @@ def _prepare_video_entry( The video object from the NWB file cam_device_regex : str, optional Regular expression pattern to extract camera device number. 
- Default: r"camera_device (\d+)" + Default: r"camera_device (\\d+)" Returns ------- @@ -645,7 +647,7 @@ def make(self, key, verbose=True, skip_duplicates=False): if isinstance(obj, pynwb.image.ImageSeries) } if not videos: - logger.warning( + self._warn_msg( f"No video data interface found in {nwb_file_name}\n" ) return @@ -720,7 +722,7 @@ def make(self, key, verbose=True, skip_duplicates=False): ) elif imported_count == 0 and verbose: - logger.info( + self._info_msg( f"No video found corresponding to file {nwb_file_name}, " f"epoch {interval_list_name}" ) @@ -910,7 +912,7 @@ def _no_transaction_make(self, key): dict(key, position_interval_name=matching_pos_intervals[0]), **insert_opts, ) - logger.info( + self._info_msg( "Populated PosIntervalMap for " + f'{nwb_file_name}, {key["interval_list_name"]}' ) diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index d0b985e50..d4c35129e 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -319,7 +319,7 @@ def _rate_fallback(self, nwb_object): if timestamps is None: raise ValueError("Neither rate nor timestamps are available.") return estimate_sampling_rate( - np.asarray(timestamps[: int(1e6)]), 1.5, verbose=True + np.asarray(timestamps[: int(1e6)]), 1.5, verbose=not self._test_mode ) def _valid_times_from_raw(self, nwb_object): @@ -337,6 +337,7 @@ def _valid_times_from_raw(self, nwb_object): sampling_rate=self._rate_fallback(nwb_object), gap_proportion=1.75, min_valid_len=0, + warn=not self._test_mode, ) def generate_entries_from_nwb_object(self, nwb_obj, base_key=None): @@ -400,7 +401,7 @@ def make(self, key): # TODO: change name when nwb file is changed sample_count = get_data_interface(nwbf, "sample_count") if sample_count is None: - logger.info( + self._info_msg( "Unable to import SampleCount: no data interface named " + f'"sample_count" found in {nwb_file_name}.' 
) diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 3650c1567..4e0cc83a8 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -11,6 +11,7 @@ import pynwb from spyglass.common.common_session import Session # noqa: F401 +from spyglass.settings import test_mode from spyglass.utils import SpyglassIngestion, logger from spyglass.utils.dj_helper_fn import get_child_tables @@ -334,7 +335,7 @@ def __init__( from_inds=False, no_overlap=False, no_duplicates=True, - warn=True, + warn=not test_mode, # warn by default, unless running pytests **kwargs, ) -> None: """Initialize the Intervals class with a list of intervals. @@ -358,6 +359,7 @@ def __init__( Additional keyword arguments to pass to the class, including "valid_times" and "interval_list_name" for times and name. """ + self.kwargs = dict( # Returned objects will set this behavior kwargs, no_overlap=no_overlap, diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py index 77278afd2..3046f8c93 100644 --- a/src/spyglass/common/common_lab.py +++ b/src/spyglass/common/common_lab.py @@ -56,7 +56,7 @@ def generate_entries_from_nwb_object( base_key = base_key or dict() experimenter_list = nwb_obj.experimenter if not experimenter_list: - logger.info("No experimenter metadata found.\n") + self._info_msg("No experimenter metadata found.\n") return dict() entries = [] @@ -181,7 +181,7 @@ def generate_entries_from_nwb_object(self, nwb_obj, base_key=None): base_key = base_key or dict() experimenter_list = nwb_obj.experimenter if not experimenter_list: - logger.info("No experimenter metadata found for LabTeam.\n") + self._info_msg("No experimenter metadata found for LabTeam.\n") return dict() team_entries = [] @@ -259,7 +259,7 @@ def create_new_team( query = (LabMember.LabMemberInfo() & member_dict).fetch( "google_user_name" ) - if not query: + if not query and not cls()._test_mode: logger.warning( "To help 
manage permissions in LabMemberInfo, please add " + f"Google user ID for {team_member}" @@ -348,7 +348,7 @@ def insert_from_nwbfile( .get(self) ) if not insert_entries: - logger.info("No lab metadata found.\n") + self._info_msg("No lab metadata found.\n") return dict() if len(insert_entries) > 1: logger.info( diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 7e1b1b093..dc491cbe4 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -740,7 +740,7 @@ def _cleanup_custom_table( unused = analysis_tbl.cleanup_external( dry_run=dry_run, delete_external_files=True ) - logger.info( + self._info_msg( f" [{table_num}/{num_tables}] {prefix}: {n_orphans} orphans, " + f"{len(unused)} unused externals" ) @@ -792,7 +792,7 @@ def cleanup(self, dry_run: bool = False) -> None: """ heading = "============== Analysis Cleanup " suffix = "(Dry Run) ==============" if dry_run else "==============" - logger.info(heading + suffix) + self._info_msg(heading + suffix) registry = AnalysisRegistry() @@ -822,7 +822,7 @@ def cleanup(self, dry_run: bool = False) -> None: dry_run=dry_run, delete_external_files=False ) - logger.info( + self._info_msg( f" [{num_tables}/{num_tables}] common: {n_orphans} " f"orphans, {len(unused)} unused externals" ) @@ -859,7 +859,7 @@ def check_all_files(self) -> dict: """ from spyglass.common.common_file_tracking import AnalysisFileIssues - logger.info("Checking analysis files across all tables") + self._info_msg("Checking analysis files across all tables") registry = AnalysisRegistry() # Include common table + all custom tables @@ -871,7 +871,7 @@ def check_all_files(self) -> dict: for i, analysis_tbl in enumerate(analysis_tables, start=1): tbl_name = analysis_tbl.full_table_name - logger.info(f" [{i}/{num_tables}] Checking {tbl_name} files") + self._info_msg(f" [{i}/{num_tables}] Checking {tbl_name} files") issue_count = file_checker.check_files(analysis_tbl) 
results[tbl_name] = issue_count diff --git a/src/spyglass/common/common_optogenetics.py b/src/spyglass/common/common_optogenetics.py index 72a30f99b..ee6d7520e 100644 --- a/src/spyglass/common/common_optogenetics.py +++ b/src/spyglass/common/common_optogenetics.py @@ -53,7 +53,7 @@ def make(self, key): nwb = (Nwbfile() & nwb_key).fetch_nwb()[0] opto_epoch_obj = nwb.intervals.get("optogenetic_epochs", None) if opto_epoch_obj is None: - logger.warning( + self._warn_msg( f"No optogenetic epochs found in NWB file {nwb_key['nwb_file_name']}" ) return @@ -103,15 +103,15 @@ def make(self, key): spatial_inserts.append(spatial_key) # insert keys with self._safe_context(): - logger.info("Inserting Protocol") + self._info_msg("Inserting Protocol") self.insert(epoch_inserts) - logger.info("Inserting RippleTrigger") + self._info_msg("Inserting RippleTrigger") self.RippleTrigger.insert(ripple_inserts) - logger.info("Inserting ThetaTrigger") + self._info_msg("Inserting ThetaTrigger") self.ThetaTrigger.insert(theta_inserts) - logger.info("Inserting SpeedConditional") + self._info_msg("Inserting SpeedConditional") self.SpeedConditional.insert(speed_inserts) - logger.info("Inserting SpatialConditional") + self._info_msg("Inserting SpatialConditional") self.SpatialConditional.insert(spatial_inserts) @staticmethod diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 338dbfa3c..4734b16b7 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -212,10 +212,11 @@ def generate_pos_components( **time_comments, ) else: - logger.info( - "No video frame index found. Assuming all camera frames " - + "are present." - ) + if not test_mode: + logger.info( + "No video frame index found. Assuming all camera " + + "frames are present." 
+ ) velocity.create_timeseries( name="video_frame_ind", unit="index", @@ -466,7 +467,7 @@ def calculate_position_info( # set orientation to NaN in single LED data if np.all(front_LED == 0) or np.all(back_LED == 0): - logger.warning( + self._warn_msg( "Single LED data detected. Setting orientation to NaN." ) orientation = np.full_like(orientation, np.nan) @@ -566,7 +567,7 @@ def make(self, key): """ M_TO_CM = 100 - logger.info("Loading position data...") + self._info_msg("Loading position data...") nwb_dict = dict(nwb_file_name=key["nwb_file_name"]) @@ -584,7 +585,7 @@ def make(self, key): } ).fetch1_dataframe() - logger.info("Loading video data...") + self._info_msg("Loading video data...") epoch = get_position_interval_epoch( key["nwb_file_name"], key["interval_list_name"] ) @@ -623,7 +624,7 @@ def make(self, key): position_time = np.asarray(position_info_df.index) cm_per_pixel = nwb_video.device.meters_per_pixel * M_TO_CM - logger.info("Making video...") + self._info_msg("Making video...") self.make_video( f"{video_dir}/{video_filename}", centroids, diff --git a/src/spyglass/common/common_sensors.py b/src/spyglass/common/common_sensors.py index 9d11ea518..bd689af46 100644 --- a/src/spyglass/common/common_sensors.py +++ b/src/spyglass/common/common_sensors.py @@ -54,7 +54,9 @@ def make(self, key: dict) -> None: # Validate the sensor data if sensor is None: - logger.info(f"No conforming sensor data found in {nwb_file_name}\n") + self._info_msg( + f"No conforming sensor data found in {nwb_file_name}\n" + ) return columns = sensor.time_series["analog"].description.split() diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index b4b644a13..51d85db00 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -240,7 +240,7 @@ def make(self, key): tasks_mod = nwbf.processing.get("tasks") config_tasks = config.get("Tasks", []) if tasks_mod is None and (not config_tasks): - logger.warning( + self._warn_msg( 
f"No tasks processing module found in {nwbf} or config\n" ) # Issue #1444: Check for orphaned ImageSeries @@ -377,7 +377,7 @@ def get_epoch_interval_name(cls, epoch, session_intervals): warn = "Multiple" if len(possible_targets) > 1 else "No" - logger.warning( + cls()._warn_msg( f"{warn} interval(s) found for epoch {epoch}. " f"Available intervals: {session_intervals}" ) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index a54264052..612d26286 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -567,7 +567,11 @@ def make(self, key): unlinked_files = set() if self._n_file_link_processes == 1: - for file in tqdm(file_paths, desc="Checking linked nwb files"): + for file in tqdm( + file_paths, + desc="Checking linked nwb files", + disable=not test_mode, + ): unlinked_files.update(get_unlinked_files(file)) else: with Pool(processes=self._n_file_link_processes) as pool: diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py index d57d98db3..acec27cde 100644 --- a/src/spyglass/common/populate_all_common.py +++ b/src/spyglass/common/populate_all_common.py @@ -103,6 +103,8 @@ def single_transaction_make( all tables will have exactly one key_source entry per nwb file. 
""" + nwbfile_tbl = Nwbfile() + file_restr = {"nwb_file_name": nwb_file_name} with Nwbfile._safe_context(): for table in tables: @@ -123,7 +125,7 @@ def single_transaction_make( continue # If imported/computed table, get key from key_source - logger.info(f"Populating {table.__name__}...") + nwbfile_tbl._info_msg(f"Populating {table.__name__}...") key_source = getattr(table, "key_source", None) if key_source is None: # Generate key from parents parents = table.parents(as_objects=True) diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py index 5362943e8..84dcb28eb 100644 --- a/src/spyglass/data_import/insert_sessions.py +++ b/src/spyglass/data_import/insert_sessions.py @@ -8,7 +8,7 @@ from spyglass.common import Nwbfile, get_raw_eseries, populate_all_common from spyglass.common.common_nwbfile import schema as nwbfile_schema -from spyglass.settings import debug_mode, raw_dir +from spyglass.settings import debug_mode, raw_dir, test_mode from spyglass.utils import logger from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename @@ -113,10 +113,11 @@ def copy_nwb_link_raw_ephys( str The absolute path of the new NWB file. 
""" - logger.info( - f"Creating a copy of NWB file {nwb_file_name} " - + f"with link to raw ephys data: {out_nwb_file_name}" - ) + if not test_mode: + logger.info( + f"Creating a copy of NWB file {nwb_file_name} " + + f"with link to raw ephys data: {out_nwb_file_name}" + ) nwb_file_abs_path = Nwbfile.get_abs_path(nwb_file_name, new_file=True) @@ -130,9 +131,10 @@ def copy_nwb_link_raw_ephys( if os.path.exists(out_nwb_file_abs_path): if debug_mode or keep_existing: return out_nwb_file_abs_path - logger.warning( - f"Output file exists, will be overwritten: {out_nwb_file_abs_path}" - ) + if not test_mode: + logger.warning( + f"Output file exists, will be overwritten: {out_nwb_file_abs_path}" + ) with pynwb.NWBHDF5IO( path=nwb_file_abs_path, mode="r", load_namespaces=True diff --git a/src/spyglass/decoding/decoding_merge.py b/src/spyglass/decoding/decoding_merge.py index 12fa215cd..bfaa62696 100644 --- a/src/spyglass/decoding/decoding_merge.py +++ b/src/spyglass/decoding/decoding_merge.py @@ -50,13 +50,13 @@ def _fetch_registered_paths(self, attr): def cleanup(self, dry_run=False): """Remove any decoding outputs that are not in the merge table""" if dry_run: - logger.info("Dry run, not removing any files") + self._info_msg("Dry run, not removing any files") else: - logger.info("Cleaning up decoding outputs") + self._info_msg("Cleaning up decoding outputs") table_results_paths = self._fetch_registered_paths("results_path") for path in Path(config["SPYGLASS_ANALYSIS_DIR"]).glob("**/*.nc"): if str(path) not in table_results_paths: - logger.info(f"Removing {path}") + self._info_msg(f"Removing {path}") if not dry_run: try: path.unlink(missing_ok=True) # Ignore FileNotFoundError @@ -66,7 +66,7 @@ def cleanup(self, dry_run=False): table_model_paths = self._fetch_registered_paths("classifier_path") for path in Path(config["SPYGLASS_ANALYSIS_DIR"]).glob("**/*.pkl"): if str(path) not in table_model_paths: - logger.info(f"Removing {path}") + self._info_msg(f"Removing {path}") if 
not dry_run: try: path.unlink() diff --git a/src/spyglass/lfp/lfp_imported.py b/src/spyglass/lfp/lfp_imported.py index 63a54e043..207dc92a5 100644 --- a/src/spyglass/lfp/lfp_imported.py +++ b/src/spyglass/lfp/lfp_imported.py @@ -3,10 +3,8 @@ import pynwb from spyglass.common.common_interval import IntervalList # noqa: F401 -from spyglass.common.common_nwbfile import ( - AnalysisNwbfile, - Nwbfile, -) # noqa: F401 +from spyglass.common.common_nwbfile import AnalysisNwbfile # noqa: F401 +from spyglass.common.common_nwbfile import Nwbfile from spyglass.common.common_session import Session # noqa: F401 from spyglass.lfp.lfp_electrode import LFPElectrodeGroup # noqa: F401 from spyglass.utils import logger @@ -46,7 +44,7 @@ def make(self, key): ] if len(lfp_objects) == 0: - logger.warning( + self._warn_msg( f"No LFP objects found in {nwb_file_name}. Skipping." ) return @@ -96,7 +94,9 @@ def make(self, key): interval_key = { "nwb_file_name": nwb_file_name, "interval_list_name": f"imported lfp {i} valid times", - "valid_times": get_valid_intervals(timestamps, sampling_rate), + "valid_times": get_valid_intervals( + timestamps, sampling_rate, warn=not self._test_mode + ), "pipeline": "imported_lfp", } IntervalList().insert1(interval_key) diff --git a/src/spyglass/position/v1/dlc_utils_makevid.py b/src/spyglass/position/v1/dlc_utils_makevid.py index 21d7fe770..655a5c44c 100644 --- a/src/spyglass/position/v1/dlc_utils_makevid.py +++ b/src/spyglass/position/v1/dlc_utils_makevid.py @@ -115,6 +115,7 @@ def __init__( self.ffmpeg_fmt_args = ["-c:v", "libx264", "-pix_fmt", "yuv420p"] prev_backend = matplotlib.get_backend() + plt.close("all") # Required before backend switch (matplotlib >= 3.8) matplotlib.use("Agg") # Use non-interactive backend _ = self._set_frame_info() diff --git a/src/spyglass/position/v1/imported_pose.py b/src/spyglass/position/v1/imported_pose.py index f69c02e8d..eb982180e 100644 --- a/src/spyglass/position/v1/imported_pose.py +++ 
b/src/spyglass/position/v1/imported_pose.py @@ -71,6 +71,7 @@ def insert_from_nwbfile(self, nwb_file_name, **kwargs): timestamps, sampling_rate=sampling_rate, min_valid_len=sampling_rate, + warn=not self._test_mode, ) interval_pk = { "nwb_file_name": nwb_file_name, diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index 1833a27a6..2b5e63a03 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -434,8 +434,8 @@ def make(self, key): if pose_estimation_params is None: pose_estimation_params = dict() - logger.info(f"video filename: {video_filename}") - logger.info("Loading position data...") + self._info_msg(f"video filename: {video_filename}") + self._info_msg("Loading position data...") v1_key = {k: v for k, v in key.items() if k in DLCPosV1.primary_key} pos_info_df = ( diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 2e85581e5..0816612ee 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -313,7 +313,7 @@ def make(self, key): """ M_TO_CM = 100 - logger.info("Loading position data...") + self._info_msg("Loading position data...") raw_df = ( RawPosition.PosObject & { @@ -323,7 +323,7 @@ def make(self, key): ).fetch1_dataframe() pos_df = (TrodesPosV1() & key).fetch1_dataframe() - logger.info("Loading video data...") + self._info_msg("Loading video data...") epoch = get_position_interval_epoch( key["nwb_file_name"], key["interval_list_name"] ) diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 7bede8ebb..12c04ec5c 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -314,7 +314,7 @@ def _set_dj_config_stores(self, check_match=True, set_stores=True) -> None: ) if set_stores: - if mismatch_raw or mismatch_analysis: + if (mismatch_raw or 
mismatch_analysis) and not self.test_mode: logger.warning( "Setting config DJ stores to resolve mismatch.\n\t" + f"raw : {self.raw_dir}\n\t" diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py index 702641bd6..aa3f043b0 100644 --- a/src/spyglass/spikesorting/imported.py +++ b/src/spyglass/spikesorting/imported.py @@ -45,7 +45,7 @@ def _source_nwb_object_type(self): def get_nwb_objects(self, nwb_file, nwb_file_name=None): """Override to get units from nwb_file.units.""" if not getattr(nwb_file, "units", None): - logger.warn("No units found in NWB file") + self._warn_msg("No units found in NWB file") return [] return [nwb_file.units] diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 24e38f42e..fa8dc54df 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -682,7 +682,7 @@ def make_compute( ) qm[metric_name] = metric - logger.info(f"Computed all metrics: {qm}") + self._info_msg(f"Computed all metrics: {qm}") self._dump_to_json(qm, quality_metrics_path) # save dict as json object_id = AnalysisNwbfile().add_units_metrics( diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py index b6630da6c..c8ce113ac 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recording.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recording.py @@ -901,6 +901,9 @@ def _get_filtered_recording(self, key: dict): def cleanup(self, dry_run=False, verbose=True): """Removes the recording data from the recording directory.""" + if self._test_mode: + verbose = False + rec_dir = Path(recording_dir) tracked = set(self.fetch("recording_path")) all_dirs = {str(f) for f in rec_dir.iterdir() if f.is_dir()} diff --git a/src/spyglass/spikesorting/v0/spikesorting_sorting.py b/src/spyglass/spikesorting/v0/spikesorting_sorting.py index 
166971f6d..bf233d97b 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_sorting.py +++ b/src/spyglass/spikesorting/v0/spikesorting_sorting.py @@ -315,6 +315,9 @@ def fetch_nwb(self, *attrs, **kwargs): def cleanup(self, dry_run=False, verbose=True): """Clean up spike sorting directories that are not in the table.""" + if self._test_mode: + verbose = False + sort_dir = Path(sorting_dir) tracked = set(self.fetch("sorting_path")) all_dirs = {str(f) for f in sort_dir.iterdir() if f.is_dir()} diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index 740e2f31b..960032d80 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -306,11 +306,11 @@ def make_compute(self, key, upstream): for unit_id, value in metrics["nn_isolation"].items() } - logger.info("Applying curation...") + self._info_msg("Applying curation...") labels = self._compute_labels(metrics, label_params) merge_groups = self._compute_merge_groups(metrics, merge_params) - logger.info("Saving to NWB...") + self._info_msg("Saving to NWB...") analysis_file_name, object_id = _write_metric_curation_to_nwb( nwb_file_name, waveforms, metrics, labels, merge_groups ) diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 5b066dc35..2506b2f74 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -163,7 +163,7 @@ def make(self, key): try: path = AnalysisNwbfile().get_abs_path(parent["analysis_file_name"]) except (FileNotFoundError, dj.DataJointError) as e: - logger.warning( # pragma: no cover + self._warn_msg( # pragma: no cover f"Issue w/{parent['analysis_file_name']}. 
Skipping.\n{e}"
            )
            return  # pragma: no cover
diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py
index d48806db2..182fc32ac 100755
--- a/src/spyglass/utils/database_settings.py
+++ b/src/spyglass/utils/database_settings.py
@@ -14,6 +14,7 @@
     "behavior",
     "common",
     "decoding",
+    "figurl",
     "lfp",
     "linearization",
     "mua",
diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py
index a622fa8aa..c69f0bc8c 100644
--- a/src/spyglass/utils/dj_merge_tables.py
+++ b/src/spyglass/utils/dj_merge_tables.py
@@ -889,7 +889,7 @@ def super_delete(self, warn=True, *args, **kwargs):
         Added to support MRO of SpyglassMixin
         """
         if warn:
-            logger.warning("!! Bypassing cautious_delete !!")
+            self._warn_msg("!! Bypassing cautious_delete !!")
         self._log_delete(start=time(), super_delete=True)
         super().delete(*args, **kwargs)
diff --git a/src/spyglass/utils/mixins/base.py b/src/spyglass/utils/mixins/base.py
index 3c8b44283..b157fa41c 100644
--- a/src/spyglass/utils/mixins/base.py
+++ b/src/spyglass/utils/mixins/base.py
@@ -36,7 +36,7 @@ def _graph_deps(self) -> list:
 
         return [TableChain, RestrGraph]
 
-    def _info_msg(self, msg: str) -> str:
+    def _info_msg(self, msg: str) -> None:
         """Log info message, but debug if in test mode.
 
         Quiets logs during testing, but preserves user experience during use.
@@ -48,6 +48,14 @@ def _info_msg(self, msg: str) -> str:
         log = self._logger.debug if self._test_mode else self._logger.info
         log(msg)
 
+    def _warn_msg(self, msg: str) -> None:
+        """Log warning message, but debug if in test mode.
+
+        Quiets logs during testing, but preserves user experience during use.
+        """
+        log = self._logger.debug if self._test_mode else self._logger.warning
+        log(msg)
+
     @cached_property
     def _test_mode(self) -> bool:
         """Return True if in test mode. 
diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index dbf727bc3..25e3086b9 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -111,7 +111,7 @@ def _export_id_cleanup(self): def _start_export(self, paper_id, analysis_id): """Start export process.""" if self.export_id: - self._logger.info( + self._info_msg( f"Export {self.export_id} in progress. Starting new." ) self._stop_export(warn=False) @@ -127,7 +127,7 @@ def _start_export(self, paper_id, analysis_id): def _stop_export(self, warn=True): """End export process.""" if not self.export_id and warn: - self._logger.warning("Export not in progress.") + self._warn_msg("Export not in progress.") del self.export_id # --------------------------- Utility Functions --------------------------- diff --git a/src/spyglass/utils/mixins/ingestion.py b/src/spyglass/utils/mixins/ingestion.py index e578bc59f..520b310ad 100644 --- a/src/spyglass/utils/mixins/ingestion.py +++ b/src/spyglass/utils/mixins/ingestion.py @@ -311,7 +311,7 @@ def _key_has_required_attrs(self, key): if attr.nullable or attr.autoincrement or attr.default is not None: continue # skip nullable, autoincrement, or default val attrs if attr.name not in key or key.get(attr.name) is None: - logger.info( + self._info_msg( f"Key {key} missing required attribute {attr.name}." ) return False diff --git a/src/spyglass/utils/mixins/restrict_by.py b/src/spyglass/utils/mixins/restrict_by.py index 0ab91edd7..8687c9e28 100644 --- a/src/spyglass/utils/mixins/restrict_by.py +++ b/src/spyglass/utils/mixins/restrict_by.py @@ -77,9 +77,7 @@ def restrict_by( if len(ret) < len(self): # If it actually restricts, if not it might by a dict that # is not a valid restriction, returned as True - self._logger.warning( - "Restriction valid for this table. Using as is." - ) + self._warn_msg("Restriction valid for this table. 
Using as is.") return ret except DataJointError: # need assert_join_compatible return bool self._logger.debug("Restriction not valid. Attempting to cascade.") diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 6ae5fa092..a45b1a07c 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -177,11 +177,12 @@ def get_config(nwb_file_path: str, calling_table: str = None) -> dict: obj_path.stem[:-1] + "_spyglass_config.yaml" ) if not os.path.exists(config_path): - from spyglass.settings import base_dir # noqa: F401 + from spyglass.settings import base_dir, test_mode # noqa: F401 rel_path = obj_path.relative_to(base_dir) table = f"{calling_table}: " if calling_table else "" - logger.info(f"{table}No config found at {rel_path}") + if not test_mode: + logger.info(f"{table}No config found at {rel_path}") ret = dict() __configs[nwb_file_path] = ret # cache to avoid repeated null lookups return ret @@ -368,7 +369,7 @@ def estimate_sampling_rate( def get_valid_intervals( - timestamps, sampling_rate, gap_proportion=2.5, min_valid_len=0 + timestamps, sampling_rate, gap_proportion=2.5, min_valid_len=0, warn=True ): """Finds the set of all valid intervals in a list of timestamps. @@ -388,6 +389,9 @@ def get_valid_intervals( min_valid_len : float, optional Length of smallest valid interval. Default to 0. If greater than interval duration, log warning and use half the total time. + warn : bool, optional + Whether to log a warning if the minimum valid interval length is greater + than the total time of the timestamps. Default, True. 
Returns ------- diff --git a/tests/common/test_task_epoch_tags.py b/tests/common/test_task_epoch_tags.py index e9388215a..4219b2fbd 100644 --- a/tests/common/test_task_epoch_tags.py +++ b/tests/common/test_task_epoch_tags.py @@ -154,7 +154,7 @@ def test_interval_list_accepts_all_tag_formats( assert 3 in task_epochs, "TaskEpoch should accept epoch 3 with tag '003'" -def test_task_epoch_get_epoch_interval_name(common, caplog): +def test_task_epoch_get_epoch_interval_name(common): """Test get_epoch_interval_name with single digit tags.""" get_epoch = common.TaskEpoch.get_epoch_interval_name msg_template = "get_epoch_interval_name should find '{}' when epoch is {}" @@ -172,12 +172,6 @@ def test_task_epoch_get_epoch_interval_name(common, caplog): result = get_epoch(epoch, session_intervals) assert result == expected, msg_template.format(expected, epoch) - # Test non-matching descriptive tag - caplog.clear() - result = get_epoch("fake_epoch", session_intervals) - assert result is None, "Should return None for non-matching epoch" - assert "for epoch fake_epoch" in caplog.text - def test_franklab_task_epoch_tags(common): """Test task epoch tags in the franklab format are handled correctly.""" diff --git a/tests/conftest.py b/tests/conftest.py index 8be896caa..bf8023299 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -382,8 +382,6 @@ def mini_insert( ["Root User", "email", "root", 1], skip_duplicates=True ) - dj_logger.info("Inserting test data.") - if not SERVER.connected: raise ConnectionError("No server connection.") @@ -667,9 +665,9 @@ def trodes_pos_v1(teardown, sgp, trodes_sel_keys): @pytest.fixture(scope="session") def pos_merge_tables(dj_conn): """Return the merge tables as activated.""" - from spyglass.common.common_position import TrackGraph from spyglass.lfp.lfp_merge import LFPOutput from spyglass.linearization.merge import LinearizedPositionOutput + from spyglass.linearization.v0.main import TrackGraph from spyglass.position.position_merge import 
PositionOutput # must import common_position before LinOutput to avoid circular import diff --git a/tests/container.py b/tests/container.py index 871c2c9cd..c61295ebd 100644 --- a/tests/container.py +++ b/tests/container.py @@ -230,10 +230,11 @@ def wait(self, timeout=120, wait=3) -> None: self.start() print("") - for i in range(timeout // wait): + self.logger.info(f"Container {self.container_name} starting...") + for _ in range(timeout // wait): if self.container.health == "healthy": break - self.logger.info(f"Container {self.container_name} starting... {i}") + print(".", end="") time.sleep(wait) self.logger.info( f"Container {self.container_name}, {self.container.health}." diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 9062a201c..9aabe50d7 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -10,7 +10,7 @@ def leaf(lin_merge): @pytest.fixture(scope="session") -def restr_graph(leaf, verbose, lin_merge_key): +def restr_graph(leaf, lin_merge_key): from spyglass.utils.dj_graph import RestrGraph _ = lin_merge_key # linearization merge table populated @@ -20,7 +20,7 @@ def restr_graph(leaf, verbose, lin_merge_key): leaves={leaf.full_table_name: True}, include_files=True, cascade=True, - verbose=verbose, + verbose=False, ) @@ -44,6 +44,7 @@ def add_graph_rgs(add_graph_tables): leaves={tables["B1"].full_table_name: restr_1}, direction="up", cascade=True, + verbose=False, ) rg_1.cascade() @@ -51,6 +52,7 @@ def add_graph_rgs(add_graph_tables): seed_table=add_graph_tables["B2"], direction="up", cascade=True, + verbose=False, ) rg_2.add_leaf(table_name=tables["B2"].full_table_name, restriction=restr_2) rg_2.cascade() @@ -59,6 +61,7 @@ def add_graph_rgs(add_graph_tables): seed_table=add_graph_tables["B2"], direction="up", cascade=True, + verbose=False, ) rg_3.add_leaf(table_name=tables["B2"].full_table_name, restriction=restr_3) rg_3.cascade() @@ -183,11 +186,10 @@ def test_rg_restr_subset(restr_graph, leaf): assert len(prev_ft) 
== len(new_ft), "Subset sestriction changed length." -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy") def test_rg_no_restr(caplog, restr_graph, common): restr_graph._set_restr(common.LabTeam, restriction=False) - restr_graph._get_ft(common.LabTeam.full_table_name, with_restr=True) - assert "No restr" in caplog.text, "No warning logged on no restriction." + ret = restr_graph._get_ft(common.LabTeam.full_table_name, with_restr=True) + assert not ret, "Expected empty restricted table when no restriction." def test_rg_invalid_direction(restr_graph, leaf): @@ -298,8 +300,9 @@ def test_null_restrict_by(graph_tables): @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") def test_restrict_by_this_table(caplog, graph_tables): PkNode = graph_tables["PkNode"]() - PkNode >> "pk_id > 4" - assert "valid for" in caplog.text, "No warning logged without search." + dist = (PkNode >> "pk_id > 4").restriction + plain = (PkNode & "pk_id > 4").restriction + assert dist == plain, "Restricting by own table did not use existing restr." def test_invalid_restr_direction(graph_tables): diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 9197c67b7..374e1a7e5 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -93,9 +93,13 @@ def test_cautious_del_dry_run(Nwbfile, frequent_imports): @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") def test_empty_cautious_del(caplog, schema_test, Mixin): schema_test(Mixin) - Mixin().cautious_delete(safemode=False) - Mixin().cautious_delete(safemode=False) + mixin = Mixin() + prev_level = mixin._logger.level + mixin._logger.setLevel("INFO") + mixin.cautious_delete(safemode=False) + mixin.cautious_delete(safemode=False) assert "empty" in caplog.text, "No warning issued." 
+ mixin._logger.setLevel(prev_level) def test_super_delete(schema_test, Mixin, common): From 7d7f0a1142ffe52af769f6a9fff145d4838bd1af Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 15:56:48 +0100 Subject: [PATCH 11/30] Denoising tests 5 --- src/spyglass/common/common_file_tracking.py | 4 +- src/spyglass/common/common_session.py | 2 +- src/spyglass/common/common_usage.py | 2 +- .../position/v1/position_dlc_cohort.py | 6 +- .../v1/position_dlc_pose_estimation.py | 2 +- src/spyglass/spikesorting/v1/artifact.py | 9 ++- tests/conftest.py | 74 ++++++++++++++++++- 7 files changed, 83 insertions(+), 16 deletions(-) diff --git a/src/spyglass/common/common_file_tracking.py b/src/spyglass/common/common_file_tracking.py index fb324813d..ace016025 100644 --- a/src/spyglass/common/common_file_tracking.py +++ b/src/spyglass/common/common_file_tracking.py @@ -184,7 +184,7 @@ def show_downstream(self, restriction=True): """ entries = (self & "can_read=0" & restriction).fetch("KEY", as_dict=True) if not entries: - logger.info("No issues found.") + self._info_msg("No issues found.") return [] # Get unique analysis tables from entries @@ -206,6 +206,6 @@ def show_downstream(self, restriction=True): ret.append(child & entries) if not ret: - logger.info("No issues found.") + self._info_msg("No issues found.") return ret or [] diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index 1b6cf4933..8a79f9340 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -90,7 +90,7 @@ def generate_entries_from_nwb_object( base_key = base_key or dict() experimenter_list = nwb_obj.experimenter if not experimenter_list: - logger.info("No experimenter metadata found for Session.\n") + self._info_msg("No experimenter metadata found for Session.\n") return dict() entries = [] diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 612d26286..2bbab7471 100644 --- 
a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -141,7 +141,7 @@ def insert1_return_pk(self, key: dict, **kwargs) -> int: if query := (Export & export_key): safemode = False if test_mode else None # No prompt in tests query.super_delete(warn=False, safemode=safemode) - logger.info(f"{status} {export_key}") + self._info_msg(f"{status} {export_key}") return export_id def start_export(self, paper_id, analysis_id) -> None: diff --git a/src/spyglass/position/v1/position_dlc_cohort.py b/src/spyglass/position/v1/position_dlc_cohort.py index 1f9b90fce..4dc168cf5 100644 --- a/src/spyglass/position/v1/position_dlc_cohort.py +++ b/src/spyglass/position/v1/position_dlc_cohort.py @@ -6,9 +6,9 @@ from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.position.v1.dlc_utils import file_log, infer_output_dir -from spyglass.position.v1.position_dlc_pose_estimation import ( # noqa: F401 +from spyglass.position.v1.position_dlc_pose_estimation import ( DLCPoseEstimation, -) +) # noqa: F401 from spyglass.position.v1.position_dlc_position import DLCSmoothInterp from spyglass.utils import SpyglassMixin, logger @@ -103,7 +103,7 @@ def make(self, key): output_dir = infer_output_dir(key=key, makedir=False) self.log_path = Path(output_dir) / "log.log" self._logged_make(key) - logger.info("Inserted entry into DLCSmoothInterpCohort") + self._info_msg("Inserted entry into DLCSmoothInterpCohort") @file_log(logger, console=False) def _logged_make(self, key): diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 0869a7b7f..36a9859a8 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -276,7 +276,7 @@ def _logged_make(self, key): # video_frame_ind from RawPosition, which also has timestamps # Insert entry into DLCPoseEstimation - logger.info( + self._info_msg( "Inserting %s, epoch %02d 
into DLCPoseEsimation", key["nwb_file_name"], key["epoch"], diff --git a/src/spyglass/spikesorting/v1/artifact.py b/src/spyglass/spikesorting/v1/artifact.py index 4e06614a7..529c350f5 100644 --- a/src/spyglass/spikesorting/v1/artifact.py +++ b/src/spyglass/spikesorting/v1/artifact.py @@ -250,10 +250,11 @@ def _get_artifact_times( # if both thresholds are None, we skip artifract detection if amplitude_thresh_uV is zscore_thresh is None: - logger.info( - "Amplitude and zscore thresholds are both None, " - + "skipping artifact detection" - ) + if verbose: + logger.info( + "Amplitude and zscore thresholds are both None, " + + "skipping artifact detection" + ) return np.asarray( [valid_timestamps[0], valid_timestamps[-1]] ), np.asarray([]) diff --git a/tests/conftest.py b/tests/conftest.py index bf8023299..ed4eeb2ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,8 +43,11 @@ from shutil import rmtree as shutil_rmtree import datajoint as dj +import datajoint.hash as _dj_hash import numpy as np import pynwb +import pynwb.device as _pynwb_device +import pynwb.io.device as _pynwb_io_device import pytest from datajoint.logging import logger as dj_logger from hdmf.build.warnings import MissingRequiredBuildWarning @@ -56,6 +59,69 @@ # ------------------------------- TESTS CONFIG ------------------------------- + +# ---------- Fix ResourceWarning from datajoint.hash.uuid_from_file ----------- +# Patch uuid_from_file to properly close file handles (upstream opens without +# `with`, triggering ResourceWarning on GC). This is safe: the function reads +# the whole file before returning, so closing after uuid_from_stream is fine. 
+def _uuid_from_file_safe(filepath, *, init_string=""): + with Path(filepath).open("rb") as f: + return _dj_hash.uuid_from_stream(f, init_string=init_string) + + +_dj_hash.uuid_from_file = _uuid_from_file_safe + +# ----------- Prevent NWB-2.9 migration warnings from test NWB file ----------- +# Patch pynwb Device NWB-2.9 migration warnings triggered by the test NWB file, +# which was written before NWB 2.9 (Device.model stored as string, manufacturer +# as a field). pynwb uses stacklevel= values that attribute these warnings to +# hdmf internals (hdmf.build.objectmapper / hdmf.utils) rather than to pynwb, +# so they bypass any module-specific filter and can defeat category-only filters +# once an hdmf module's __warningregistry__ pre-dates the filter installation. +# +# Two distinct call sites require two different strategies: +# +# (a) pynwb/io/device.py uses `from warnings import warn` — the `warn` name +# lives in pynwb.io.device's namespace, so it is directly patchable. +# +# (b) pynwb/device.py uses `import warnings; warnings.warn(...)` — we cannot +# replace an attribute on the warnings module itself without side-effects, +# so we wrap Device.__init__ instead. + +_orig_io_device_warn = _pynwb_io_device.warn + + +def _io_device_warn_filtered(message, *args, **kwargs): + if "Device.model was detected as a string" not in str(message): + _orig_io_device_warn(message, *args, **kwargs) + + +_pynwb_io_device.warn = _io_device_warn_filtered + +_orig_device_init = _pynwb_device.Device.__init__ + + +def _device_init_no_field_deprecations(*args, **kwargs): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=r"The '(?:manufacturer|model_number|model_name)' field is deprecated", + category=DeprecationWarning, + ) + return _orig_device_init(*args, **kwargs) + + +# hdmf's objectmapper calls get_docval(cls.__init__) to discover constructor +# arguments. 
get_docval reads the __docval__ / __docval_idx__ attributes that +# the @docval decorator stores in func.__dict__. Copy the original function's +# __dict__ to the wrapper so that Device subclasses which inherit __init__ +# (e.g. ndx-franklab-novela's CameraDevice) still work correctly. +_device_init_no_field_deprecations.__dict__.update(_orig_device_init.__dict__) +_device_init_no_field_deprecations.__name__ = _orig_device_init.__name__ +_device_init_no_field_deprecations.__module__ = _orig_device_init.__module__ + +_pynwb_device.Device.__init__ = _device_init_no_field_deprecations + # globals in pytest_configure: # BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD, NO_DLC @@ -392,10 +458,10 @@ def mini_insert( # Useful try/except for avoiding a full run on insert failure # Should be commented out in favor of vanilla insert for debugging # the insert_sessions function itself. - try: - insert_sessions(mini_path.name, raise_err=True) - except Exception as e: # If can't insert session, exit all tests - pytest.exit(f"Failed to insert sessions: {e}") + # try: + insert_sessions(mini_path.name, raise_err=True) + # except Exception as e: # If can't insert session, exit all tests + # pytest.exit(f"Failed to insert sessions: {e}") if len(Session()) == 0: raise ValueError("No sessions inserted.") From 7994c623abec324b6db78b292b198ce3e3998e48 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 16:02:21 +0100 Subject: [PATCH 12/30] Denoising tests 6 --- src/spyglass/spikesorting/v0/spikesorting_recompute.py | 2 +- src/spyglass/spikesorting/v1/recompute.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_recompute.py b/src/spyglass/spikesorting/v0/spikesorting_recompute.py index 938790db2..65a571d5f 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recompute.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recompute.py @@ -803,7 +803,7 @@ def delete_files( if days_since_creation > 0: 
date_temp = "created_at < DATE_SUB(CURDATE(), INTERVAL {} DAY)" query = query & date_temp.format(days_since_creation) - logger.info( + self._info_msg( f"Excluding files created within {days_since_creation} days" ) diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 2506b2f74..88aa91437 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -897,7 +897,7 @@ def delete_files( if days_since_creation > 0: date_templ = "created_at < DATE_SUB(CURDATE(), INTERVAL {} DAY)" query = query & date_templ.format(days_since_creation) - logger.info( + self._info_msg( f"Excluding files created within {days_since_creation} days" ) From ac839e35522affa6e27cfcf39533900998a9442a Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 16:16:43 +0100 Subject: [PATCH 13/30] Revert subset of denoising --- src/spyglass/decoding/v0/clusterless.py | 2 +- src/spyglass/decoding/v0/core.py | 2 +- src/spyglass/decoding/v0/dj_decoder_conversion.py | 2 +- src/spyglass/decoding/v0/sorted_spikes.py | 2 +- src/spyglass/decoding/v0/utils.py | 2 +- src/spyglass/utils/mixins/analysis.py | 8 +++----- 6 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/spyglass/decoding/v0/clusterless.py b/src/spyglass/decoding/v0/clusterless.py index f6e03aba0..f3fe89a3b 100644 --- a/src/spyglass/decoding/v0/clusterless.py +++ b/src/spyglass/decoding/v0/clusterless.py @@ -45,7 +45,7 @@ DiagonalDiscrete, UniformInitialConditions, ) = [None] * 5 - logger.debug(e) + logger.warning(e) from tqdm.auto import tqdm diff --git a/src/spyglass/decoding/v0/core.py b/src/spyglass/decoding/v0/core.py index 4bd1f0a89..11baf063a 100644 --- a/src/spyglass/decoding/v0/core.py +++ b/src/spyglass/decoding/v0/core.py @@ -14,7 +14,7 @@ ) except (ImportError, ModuleNotFoundError) as e: RandomWalk, Uniform, Environment, ObservationModel = None, None, None, None - logger.debug(e) + logger.warning(e) from 
spyglass.common.common_behav import PositionIntervalMap, RawPosition from spyglass.common.common_interval import IntervalList diff --git a/src/spyglass/decoding/v0/dj_decoder_conversion.py b/src/spyglass/decoding/v0/dj_decoder_conversion.py index ad13bd687..af03f541e 100644 --- a/src/spyglass/decoding/v0/dj_decoder_conversion.py +++ b/src/spyglass/decoding/v0/dj_decoder_conversion.py @@ -41,7 +41,7 @@ UniformOneEnvironmentInitialConditions, ObservationModel, ) = [None] * 13 - logger.debug(e) + logger.warning(e) from track_linearization import make_track_graph diff --git a/src/spyglass/decoding/v0/sorted_spikes.py b/src/spyglass/decoding/v0/sorted_spikes.py index ceda58cd9..914b8b936 100644 --- a/src/spyglass/decoding/v0/sorted_spikes.py +++ b/src/spyglass/decoding/v0/sorted_spikes.py @@ -34,7 +34,7 @@ DiagonalDiscrete, UniformInitialConditions, ) = [None] * 5 - logger.debug(e) + logger.warning(e) from spyglass.common.common_behav import ( convert_epoch_interval_name_to_position_interval_name, diff --git a/src/spyglass/decoding/v0/utils.py b/src/spyglass/decoding/v0/utils.py index d99ea4681..7f1ae1c15 100644 --- a/src/spyglass/decoding/v0/utils.py +++ b/src/spyglass/decoding/v0/utils.py @@ -25,7 +25,7 @@ DiagonalDiscrete, UniformInitialConditions, ) = [None] * 6 - logger.debug(e) + logger.warning(e) def get_time_bins_from_interval(interval_times: np.array, sampling_rate: int): diff --git a/src/spyglass/utils/mixins/analysis.py b/src/spyglass/utils/mixins/analysis.py index 9004e4c21..267f57abb 100644 --- a/src/spyglass/utils/mixins/analysis.py +++ b/src/spyglass/utils/mixins/analysis.py @@ -832,7 +832,7 @@ def add_units( # to ensure that things go in the right order metric_values = metric_values[np.argsort(unit_ids)] - self._logger.debug(f"Adding metric {metric} : {metric_values}") + self._info_msg(f"Adding metric {metric} : {metric_values}") nwbf.add_unit_column( name=metric, description=f"{metric} metric", @@ -917,7 +917,7 @@ def add_units_waveforms( # If metrics 
were specified, add one column per metric if metrics is not None: for metric_name, metric_dict in metrics.items(): - self._logger.debug( + self._info_msg( f"Adding metric {metric_name} : {metric_dict}" ) metric_data = metric_dict.values().to_list() @@ -963,9 +963,7 @@ def add_units_metrics(self, analysis_file_name: str, metrics: dict): nwbf.add_unit(id=id) for metric_name, metric_dict in metrics.items(): - self._logger.debug( - f"Adding metric {metric_name} : {metric_dict}" - ) + self._info_msg(f"Adding metric {metric_name} : {metric_dict}") metric_data = list(metric_dict.values()) nwbf.add_unit_column( name=metric_name, description=metric_name, data=metric_data From 57bc8e03a74b38a7ff43d998cd89dcd7e5a336ea Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 16:17:54 +0100 Subject: [PATCH 14/30] Update changelog --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 64242f2c4..b48e92bb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -137,8 +137,10 @@ for label, interval_data in results.groupby("interval_labels"): - Log expected recompute failures #1470 - Track file created/deletion status of recomputes #1470 - Upgrade to pynwb>=3.1 #1506 -- Remove imports of ndx extensions in main package to prevent errors in nwb io #1506 +- Remove imports of ndx extensions in main package to prevent errors in nwb io + #1506 - Add `analysis_table` property to mixin for custom pipelines #1525 +- Quiet pytest output for expected warnings in test runs #1534 ### Pipelines From 4e61e6f9e6aad7696a7f4452c18ed23c878cac7c Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 16:41:01 +0100 Subject: [PATCH 15/30] Revert 2 --- src/spyglass/common/common_file_tracking.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spyglass/common/common_file_tracking.py b/src/spyglass/common/common_file_tracking.py index ace016025..5086e956f 100644 --- 
a/src/spyglass/common/common_file_tracking.py +++ b/src/spyglass/common/common_file_tracking.py @@ -5,6 +5,7 @@ from tqdm import tqdm from spyglass.common.common_nwbfile import AnalysisRegistry +from spyglass.settings import test_mode from spyglass.utils import SpyglassAnalysis, logger schema = dj.Schema("common_file_tracking") @@ -183,8 +184,8 @@ def show_downstream(self, restriction=True): Downstream tables that reference files with issues """ entries = (self & "can_read=0" & restriction).fetch("KEY", as_dict=True) - if not entries: - self._info_msg("No issues found.") + if not entries and not test_mode: + logger.info("No issues found.") return [] # Get unique analysis tables from entries @@ -205,7 +206,7 @@ def show_downstream(self, restriction=True): if child & entries: ret.append(child & entries) - if not ret: - self._info_msg("No issues found.") + if not ret and not test_mode: + logger.info("No issues found.") return ret or [] From 99dd9c51b230021706377594cb5f2be2a4aedc16 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 17:18:24 +0100 Subject: [PATCH 16/30] Revert 3 --- src/spyglass/common/common_file_tracking.py | 2 +- tests/conftest.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spyglass/common/common_file_tracking.py b/src/spyglass/common/common_file_tracking.py index 5086e956f..0b4279e01 100644 --- a/src/spyglass/common/common_file_tracking.py +++ b/src/spyglass/common/common_file_tracking.py @@ -184,7 +184,7 @@ def show_downstream(self, restriction=True): Downstream tables that reference files with issues """ entries = (self & "can_read=0" & restriction).fetch("KEY", as_dict=True) - if not entries and not test_mode: + if not entries: logger.info("No issues found.") return [] diff --git a/tests/conftest.py b/tests/conftest.py index ed4eeb2ae..9ed15902d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -458,10 +458,10 @@ def mini_insert( # Useful try/except for avoiding a full run on insert failure # 
Should be commented out in favor of vanilla insert for debugging # the insert_sessions function itself. - # try: - insert_sessions(mini_path.name, raise_err=True) - # except Exception as e: # If can't insert session, exit all tests - # pytest.exit(f"Failed to insert sessions: {e}") + try: + insert_sessions(mini_path.name, raise_err=True) + except Exception as e: # If can't insert session, exit all tests + pytest.exit(f"Failed to insert sessions: {e}") if len(Session()) == 0: raise ValueError("No sessions inserted.") From 2b843ccfc82edbe18ccde6601f15b2fbe3e3041c Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Fri, 20 Feb 2026 18:47:13 +0100 Subject: [PATCH 17/30] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Samuel Bray --- src/spyglass/common/common_usage.py | 2 +- src/spyglass/utils/dj_merge_tables.py | 2 +- tests/decoding/conftest.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 2bbab7471..d7eee94c1 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -570,7 +570,7 @@ def make(self, key): for file in tqdm( file_paths, desc="Checking linked nwb files", - disable=not test_mode, + disable=test_mode, ): unlinked_files.update(get_unlinked_files(file)) else: diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index c69f0bc8c..eeda06a7d 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -889,7 +889,7 @@ def super_delete(self, warn=True, *args, **kwargs): Added to support MRO of SpyglassMixin """ if warn: - logger._warn_msg("!! Bypassing cautious_delete !!") + self._warn_msg("!! 
Bypassing cautious_delete !!") self._log_delete(start=time(), super_delete=True) super().delete(*args, **kwargs) diff --git a/tests/decoding/conftest.py b/tests/decoding/conftest.py index 2cd097a2b..32b5135f1 100644 --- a/tests/decoding/conftest.py +++ b/tests/decoding/conftest.py @@ -41,7 +41,7 @@ def mock_to_netcdf( pickle.dump(self, f) except (FileNotFoundError, PermissionError, OSError): # Copilot suggested that this is where a file might throw error - # during teatdown, attempted automatic cleanup. + # during teardown, attempted automatic cleanup. pass return None From e6b02aa6a80234f5c2dfd727ca1b5fadc6be3c85 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 18:53:22 +0100 Subject: [PATCH 18/30] Apply more suggestions --- scripts/install.py | 2 +- src/spyglass/spikesorting/analysis/v1/group.py | 4 ++-- src/spyglass/spikesorting/v1/metric_curation.py | 4 +++- tests/common/test_video_import_fail.py | 8 +++++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/scripts/install.py b/scripts/install.py index a70ceede0..3f0d0dcc0 100755 --- a/scripts/install.py +++ b/scripts/install.py @@ -620,7 +620,7 @@ def create(self, env_file: str, force: bool = False) -> None: kw in line for kw in ["Solving", "Downloading", "Extracting"] ): - Console.print(".", end="", flush=True) + print(".", end="", flush=True) Console.print() if process.returncode != 0: diff --git a/src/spyglass/spikesorting/analysis/v1/group.py b/src/spyglass/spikesorting/analysis/v1/group.py index cb60e715a..be69219e8 100644 --- a/src/spyglass/spikesorting/analysis/v1/group.py +++ b/src/spyglass/spikesorting/analysis/v1/group.py @@ -90,8 +90,8 @@ def create_group( if test_mode: return raise ValueError( - f"Group {nwb_file_name}: {group_name} already exists ", - "please delete the group before creating a new one", + f"Group {nwb_file_name}: {group_name} already exists " + + "please delete the group before creating a new one", ) parts_insert = [{**key, **group_key} for key in keys] 
diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index 960032d80..ef905eb11 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -373,7 +373,9 @@ def get_waveforms( # Extract non-sparse waveforms by default waveform_params.setdefault("sparse", False) - if overwrite or not Path(waveforms_dir).exists(): + dir_empty = not any(Path(waveforms_dir).iterdir()) + + if overwrite or dir_empty: waveforms = si.extract_waveforms( recording=recording, sorting=sorting, diff --git a/tests/common/test_video_import_fail.py b/tests/common/test_video_import_fail.py index dc91cad6a..2224b45d1 100644 --- a/tests/common/test_video_import_fail.py +++ b/tests/common/test_video_import_fail.py @@ -6,6 +6,7 @@ import pytest from ndx_franklab_novela import CameraDevice from pynwb import NWBHDF5IO +from pynwb.device import DeviceModel from pynwb.image import ImageSeries from pynwb.testing.mock.file import mock_NWBFile, mock_Subject @@ -26,13 +27,18 @@ def nwb_with_video_no_task(raw_dir, common): ) nwbfile.subject = mock_Subject() + camera_model = DeviceModel( + name="TestCam 3000", + manufacturer="Test Camera Co", + ) camera_device = CameraDevice( name="camera_device 0", meters_per_pixel=0.001, - manufacturer="Test Camera Co", + model=camera_model, lens="50mm", camera_name="test_camera", ) + nwbfile.add_device_model(camera_model) nwbfile.add_device(camera_device) # Create ImageSeries (video data) with timestamps From 077ea43ead981508c5844c9c5d724b3b1d1fb8ce Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Feb 2026 18:54:31 +0100 Subject: [PATCH 19/30] Fix 'overwrite' overwrite --- src/spyglass/spikesorting/v1/metric_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index ef905eb11..cc78cb73f 100644 --- 
a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -380,7 +380,7 @@ def get_waveforms( recording=recording, sorting=sorting, folder=waveforms_dir, - overwrite=True, + overwrite=overwrite, **waveform_params, ) else: From f22b65549cfa8cdd4ac9cba963b78b7b60eeddc0 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sat, 21 Feb 2026 11:44:53 +0100 Subject: [PATCH 20/30] Denoise tests 7 --- src/spyglass/common/common_position.py | 2 +- src/spyglass/lfp/analysis/v1/lfp_band.py | 8 +- src/spyglass/lfp/v1/lfp.py | 2 +- .../position/v1/position_dlc_centroid.py | 16 +-- .../position/v1/position_dlc_model.py | 2 +- .../v1/position_dlc_pose_estimation.py | 21 ++-- .../position/v1/position_dlc_position.py | 18 +-- .../position/v1/position_dlc_training.py | 4 +- .../position/v1/position_trodes_position.py | 2 +- .../spikesorting/v1/metric_curation.py | 8 +- src/spyglass/spikesorting/v1/recompute.py | 40 ++++--- src/spyglass/utils/database_settings.py | 4 + src/spyglass/utils/sql_helper_fn.py | 2 +- tests/conftest.py | 107 ++++++++++++++++++ 14 files changed, 177 insertions(+), 59 deletions(-) diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 4734b16b7..c3368a9e8 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -113,7 +113,7 @@ class IntervalPositionInfo(SpyglassMixin, dj.Computed): def make(self, key): """Insert smoothed head position, orientation and velocity.""" - logger.info(f"Computing position for: {key}") + self._info_msg(f"Computing position for: {key}") analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) diff --git a/src/spyglass/lfp/analysis/v1/lfp_band.py b/src/spyglass/lfp/analysis/v1/lfp_band.py index 3558c6085..8c82bee1b 100644 --- a/src/spyglass/lfp/analysis/v1/lfp_band.py +++ b/src/spyglass/lfp/analysis/v1/lfp_band.py @@ -197,7 +197,7 @@ def set_lfp_band_electrodes( # Warn if the sampling rate is not the 
same as the original if lfp_band_sampling_rate is not None: - logger.info( + self._info_msg( "It is recommended to use the same sampling rate as the original " + "lfp data to avoid aliasing." ) @@ -268,7 +268,7 @@ def set_lfp_band_electrodes( "electrode_id" ) if set(existing_electrodes) == set(electrode_list): - logger.info( + self._info_msg( f"LFPBandSelection already exists for {master_key}; " + "not inserting" ) @@ -736,12 +736,12 @@ def fix_1481(self, restriction: Optional[dict] = True) -> None: ] if not fixed_keys: - logger.info("No entries needed to be fixed for the 1481 bug.") + self._info_msg("No entries needed to be fixed for the 1481 bug.") return from spyglass.ripple.v1.ripple import RippleTimesV1 - logger.info( + self._info_msg( f"Fixing {len(fixed_keys)} entries in the RippleTimesV1 table " "due to the 1481 bug. See github issue #1481 for more details." ) diff --git a/src/spyglass/lfp/v1/lfp.py b/src/spyglass/lfp/v1/lfp.py index a2d837101..4f09410f7 100644 --- a/src/spyglass/lfp/v1/lfp.py +++ b/src/spyglass/lfp/v1/lfp.py @@ -97,7 +97,7 @@ def make(self, key): raw_valid_times, min_length=MIN_LFP_INTERVAL_DURATION ) - logger.info( + self._info_msg( f"LFP: found {len(valid_times)} intervals > " + f"{MIN_LFP_INTERVAL_DURATION} sec long." 
) diff --git a/src/spyglass/position/v1/position_dlc_centroid.py b/src/spyglass/position/v1/position_dlc_centroid.py index eacbf5a94..680361649 100644 --- a/src/spyglass/position/v1/position_dlc_centroid.py +++ b/src/spyglass/position/v1/position_dlc_centroid.py @@ -179,7 +179,7 @@ def make(self, key): output_dir = infer_output_dir(key=key, makedir=False) self.log_path = Path(output_dir, "log.log") self._logged_make(key) - logger.info("inserted entry into DLCCentroid") + self._info_msg("inserted entry into DLCCentroid") def _fetch_pos_df(self, key, bodyparts_to_use): return pd.concat( @@ -200,7 +200,7 @@ def _available_bodyparts(self, key): def _logged_make(self, key): METERS_PER_CM = 0.01 idx = pd.IndexSlice - logger.info("Centroid Calculation") + self._info_msg("Centroid Calculation") # Get labels to smooth from Parameters table params = (DLCCentroidParams() & key).fetch1("params") @@ -241,7 +241,9 @@ def _logged_make(self, key): pos_df = self._fetch_pos_df(key=key, bodyparts_to_use=bodyparts_to_use) - logger.info("Calculating centroid") # now done using number of points + self._info_msg( + "Calculating centroid" + ) # now done using number of points centroid = Centroid( pos_df=pos_df, points=params.get("points"), @@ -255,7 +257,7 @@ def _logged_make(self, key): if params.get("interpolate"): if np.any(np.isnan(centroid)): - logger.info("interpolating over NaNs") + self._info_msg("interpolating over NaNs") nan_inds = ( pd.isnull(centroid_df.loc[:, idx[("x", "y")]]) .any(axis=1) @@ -279,7 +281,7 @@ def _logged_make(self, key): smooth_func = _key_to_smooth_func_dict[ smooth_params["smooth_method"] ] - logger.info( + self._info_msg( f"Smoothing using method: {smooth_func.__name__}", ) final_df = smooth_func( @@ -288,7 +290,7 @@ def _logged_make(self, key): else: final_df = interp_df.copy() - logger.info("getting velocity") + self._info_msg("getting velocity") velocity = get_velocity( final_df.loc[:, idx[("x", "y")]].to_numpy(), time=pos_df.index.to_numpy(), @@ -303,7 
+305,7 @@ def _logged_make(self, key): ) total_nan = np.sum(final_df.loc[:, idx[("x", "y")]].isna().any(axis=1)) - logger.info(f"total NaNs in centroid dataset: {total_nan}") + self._info_msg(f"total NaNs in centroid dataset: {total_nan}") position = pynwb.behavior.Position() velocity = pynwb.behavior.BehavioralTimeSeries() if query := (RawPosition() & key): diff --git a/src/spyglass/position/v1/position_dlc_model.py b/src/spyglass/position/v1/position_dlc_model.py index 34e0fd6b2..1b96e9d82 100644 --- a/src/spyglass/position/v1/position_dlc_model.py +++ b/src/spyglass/position/v1/position_dlc_model.py @@ -305,7 +305,7 @@ def make(self, key): self.BodyPart.insert( {**part_key, "bodypart": bp} for bp in dlc_config["bodyparts"] ) - logger.info( + self._info_msg( f"Finished inserting {model_name}, training iteration" f" {dlc_config['iteration']} into DLCModel" ) diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 36a9859a8..a5bf664c6 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -114,7 +114,7 @@ def insert_estimation_task( self._insert_est_with_log( key, task_mode, params, check_crop, skip_duplicates, output_dir ) - logger.info("inserted entry into Pose Estimation Selection") + self._info_msg("inserted entry into Pose Estimation Selection") return {**key, "task_mode": task_mode} @file_log(logger, console=False) @@ -124,8 +124,8 @@ def _insert_est_with_log( v_path, v_fname, _, _ = get_video_info(key) if not v_path: raise FileNotFoundError(f"Video file not found for {key}") - logger.info("Pose Estimation Selection") - logger.info(f"video_dir: {v_path}") + self._info_msg("Pose Estimation Selection") + self._info_msg(f"video_dir: {v_path}") v_path = find_mp4(video_path=Path(v_path), video_filename=v_fname) if check_crop: params["cropping"] = self.get_video_crop( @@ -220,8 +220,8 @@ def make(self, key): 
def _logged_make(self, key): METERS_PER_CM = 0.01 - logger.info("----------------------") - logger.info("Pose Estimation") + self._info_msg("----------------------") + self._info_msg("Pose Estimation") # ID model and directories dlc_model = (DLCModel & key).fetch1() bodyparts = (DLCModel.BodyPart & key).fetch("bodypart") @@ -252,7 +252,7 @@ def _logged_make(self, key): dlc_result.creation_time ).strftime("%Y-%m-%d %H:%M:%S") - logger.info("getting raw position") + self._info_msg("getting raw position") interval_list_name = ( convert_epoch_interval_name_to_position_interval_name( { @@ -277,9 +277,8 @@ def _logged_make(self, key): # Insert entry into DLCPoseEstimation self._info_msg( - "Inserting %s, epoch %02d into DLCPoseEsimation", - key["nwb_file_name"], - key["epoch"], + f"Inserting {key['nwb_file_name']}, epoch {key['epoch']:02}" + + " into DLCPoseEsimation" ) self.insert1({**key, "pose_estimation_time": creation_time}) @@ -298,9 +297,9 @@ def _logged_make(self, key): ) idx = pd.IndexSlice for body_part, part_df in body_parts_df.items(): - logger.info("converting to cm") + self._info_msg("converting to cm") part_df = convert_to_cm(part_df, meters_per_pixel) - logger.info("adding timestamps to DataFrame") + self._info_msg("adding timestamps to DataFrame") part_df = add_timestamps( part_df, pos_time=getattr(spatial_series, "timestamps", video_time), diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index c8f3ec95a..fac7b522c 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -189,18 +189,18 @@ def make(self, key): Path(infer_output_dir(key=key, makedir=False)) / "log.log" ) self._logged_make(key) - logger.info("inserted entry into DLCSmoothInterp") + self._info_msg("inserted entry into DLCSmoothInterp") @file_log(logger, console=False) def _logged_make(self, key): METERS_PER_CM = 0.01 - logger.info("-----------------------") + 
self._info_msg("-----------------------") idx = pd.IndexSlice # Get labels to smooth from Parameters table params = (DLCSmoothInterpParams() & key).fetch1("params") # Get DLC output dataframe - logger.info("fetching Pose Estimation Dataframe") + self._info_msg("fetching Pose Estimation Dataframe") bp_key = key.copy() if test_mode: # during testing, analysis_file not in BodyPart table @@ -208,7 +208,7 @@ def _logged_make(self, key): dlc_df = (DLCPoseEstimation.BodyPart() & bp_key).fetch1_dataframe() dt = np.median(np.diff(dlc_df.index.to_numpy())) - logger.info("Identifying indices to NaN") + self._info_msg("Identifying indices to NaN") likelihood_thresh = params.pop("likelihood_thresh") df_w_nans, bad_inds = nan_inds( dlc_df.copy(), @@ -221,11 +221,11 @@ def _logged_make(self, key): if params.get("interpolate"): interp_params = params.get("interp_params", dict()) - logger.info("interpolating across low likelihood times") + self._info_msg("interpolating across low likelihood times") interp_df = interp_pos(df_w_nans.copy(), nan_spans, **interp_params) else: interp_df = df_w_nans.copy() - logger.info("skipping interpolation") + self._info_msg("skipping interpolation") if params.get("smooth"): smooth_params = params.get("smoothing_params") @@ -238,13 +238,13 @@ def _logged_make(self, key): ].pop("smoothing_duration", None) dt = np.median(np.diff(dlc_df.index.to_numpy())) - logger.info(f"Smoothing using method: {smooth_method}") + self._info_msg(f"Smoothing using method: {smooth_method}") smooth_df = smooth_func( interp_df, smoothing_duration=smooth_dur, sampling_rate=1 / dt ) else: smooth_df = interp_df.copy() - logger.info("skipping smoothing") + self._info_msg("skipping smoothing") final_df = smooth_df.drop(["likelihood"], axis=1) final_df = final_df.rename_axis("time").reset_index() @@ -261,7 +261,7 @@ def _logged_make(self, key): nwb_analysis_file = AnalysisNwbfile() position = pynwb.behavior.Position() video_frame_ind = pynwb.behavior.BehavioralTimeSeries() - 
logger.info("Creating NWB objects") + self._info_msg("Creating NWB objects") position.create_spatial_series( name="position", timestamps=final_df.time.to_numpy(), diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index f06f83fa0..eeefecda0 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -200,7 +200,7 @@ def make_compute( for k, v in dlc_config.items() if k in get_param_names(create_training_dataset) } - logger.info("creating training dataset") + self._info_msg("creating training dataset") # NOTE: if DLC > 3, this will raise engine error create_training_dataset(dlc_cfg_filepath, **training_dataset_kwargs) @@ -220,7 +220,7 @@ def make_compute( with suppress_print_from_package(): train_network(dlc_cfg_filepath, **train_network_kwargs) except KeyboardInterrupt: # pragma: no cover - logger.info("DLC training stopped via Keyboard Interrupt") + self._info_msg("DLC training stopped via Keyboard Interrupt") except Exception as e: msg = str(e) hit_end_of_train = ("CancelledError" in msg) and ( diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 0816612ee..37ee569a6 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -199,7 +199,7 @@ def make(self, key): 3. Generate AnalysisNwbfile and insert the key into the table. 4. Insert the key into the PositionOutput Merge table. 
""" - logger.info(f"Computing position for: {key}") + self._info_msg(f"Computing position for: {key}") orig_key = copy.deepcopy(key) analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index cc78cb73f..b6dc6b2c0 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -290,11 +290,11 @@ def make_compute(self, key, upstream): # cannot handle these objects. # TODO: refactor upstream to allow for passing of keys to avoid fetch, # only fetching data from disk here. - logger.info("Extracting waveforms...") + self._info_msg("Extracting waveforms...") waveforms = self.get_waveforms(key) # compute metrics - logger.info("Computing metrics...") + self._info_msg("Computing metrics...") metrics = {} for metric_name, metric_param_dict in metric_params.items(): metrics[metric_name] = self._compute_metric( @@ -373,7 +373,9 @@ def get_waveforms( # Extract non-sparse waveforms by default waveform_params.setdefault("sparse", False) - dir_empty = not any(Path(waveforms_dir).iterdir()) + dir_empty = not Path(waveforms_dir).exists() or not any( + Path(waveforms_dir).iterdir() + ) if overwrite or dir_empty: waveforms = si.extract_waveforms( diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 88aa91437..ec7ae2f1c 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -204,7 +204,7 @@ def default_rounding(self) -> int: @cached_property def env_dict(self): - logger.info("Initializing UserEnvironment") + self._info_msg("Initializing UserEnvironment") return UserEnvironment().insert_current_env() def insert( @@ -247,7 +247,7 @@ def insert( return if not rows: - logger.info("No rows to insert.") + self._info_msg("No rows to insert.") return if not isinstance(rows, (list, tuple)): rows = [rows] @@ -339,10 
+339,12 @@ def attempt_all( if not bool(RecordingRecompute & key) ] if not inserts: - logger.info(f"No rows to insert from:\n\t{source}") + self._info_msg(f"No rows to insert from:\n\t{source}") return - logger.info(f"Inserting recompute attempts for {len(inserts)} files.") + self._info_msg( + f"Inserting recompute attempts for {len(inserts)} files." + ) self.insert(inserts, at_creation=False, **kwargs) @@ -498,25 +500,25 @@ def remove_matched( return 0 prefix = "DRY RUN: " if dry_run else "" - logger.info( + self._info_msg( f"{prefix}Found {count} selection entries for already-matched files" ) if dry_run: # Show sample of what would be deleted sample = redundant.fetch("KEY", as_dict=True, limit=10) - logger.info(f"{prefix}Sample entries (up to 10):") + self._info_msg(f"{prefix}Sample entries (up to 10):") for i, key in enumerate(sample, 1): analysis_file = key.get("analysis_file_name", "unknown") env_id = key.get("env_id", "unknown") - logger.info(f" {i}. {analysis_file} (env: {env_id})") + self._info_msg(f" {i}. {analysis_file} (env: {env_id})") if count > 10: - logger.info(f" ... and {count - 10} more") + self._info_msg(f" ... and {count - 10} more") return redundant # Actually delete the redundant entries redundant.delete_quick() - logger.info(f"Deleted {count} redundant entries") + self._info_msg(f"Deleted {count} redundant entries") return count @@ -794,7 +796,7 @@ def make(self, key, force_check=False) -> None: rec_key = dict(recording_id=key["recording_id"]) if not force_check and (self & rec_key & "matched=1"): RecordingRecomputeSelection().remove_matched(rec_key, dry_run=False) - logger.info("Previous match found. Skipping recompute.") + self._info_msg("Previous match found. 
Skipping recompute.") return parent = self.get_parent_key(key) @@ -802,7 +804,9 @@ def make(self, key, force_check=False) -> None: # Skip recompute for files with xfail reasons created_key = dict(created_at=self._get_file_created_at(key)) if parent.get("xfail_reason"): - logger.info(f"Skipping xfail entry: {parent.get('xfail_reason')}") + self._info_msg( + f"Skipping xfail entry: {parent.get('xfail_reason')}" + ) self.insert1( dict( key, @@ -815,7 +819,7 @@ def make(self, key, force_check=False) -> None: # Skip recompute for files logged at creation if parent["logged_at_creation"]: - logger.info("Skipping entry logged at creation.") + self._info_msg("Skipping entry logged at creation.") self.insert1(dict(key, matched=True, **created_key)) return @@ -835,7 +839,7 @@ def make(self, key, force_check=False) -> None: return if new_hasher.hash == old_hasher.hash: - logger.info(f"V1 Recompute match: {new_hasher.path.name}") + self._info_msg(f"V1 Recompute match: {new_hasher.path.name}") self.insert1(dict(key, matched=True, **created_key)) return @@ -904,7 +908,7 @@ def delete_files( file_names = query.fetch("analysis_file_name") prefix = "DRY RUN: " if dry_run else "" if not len(file_names): - logger.info(f"{prefix}Delete 0 files. Nothing to do.") + self._info_msg(f"{prefix}Delete 0 files. 
Nothing to do.") return msg = f"{prefix}Delete {len(file_names)} files?\n\t" + "\n\t".join( file_names[:10] @@ -916,7 +920,7 @@ def delete_files( restr = query.fetch("KEY", as_dict=True) space = self.get_disk_space(which="old", restr=restr) msg += f"\n{space}" - logger.info(msg) + self._info_msg(msg) return space if dj.utils.user_choice(msg).lower() not in ["yes", "y"]: @@ -960,10 +964,10 @@ def update_secondary(self, restriction=True) -> None: total = len(query) if total == 0: - logger.info("No entries to update") + self._info_msg("No entries to update") return - logger.info( + self._info_msg( f"Updating created_at for {total} entries from file timestamps" ) @@ -974,4 +978,4 @@ def update_secondary(self, restriction=True) -> None: dict(key, created_at=created_at, deleted=not old.exists()) ) - logger.info("Update complete") + self._info_msg("Update complete") diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py index 182fc32ac..f3fa1bf78 100755 --- a/src/spyglass/utils/database_settings.py +++ b/src/spyglass/utils/database_settings.py @@ -240,4 +240,8 @@ def exec(self, file): + f"{self.exec_user} --password={self.exec_pass} < {file.name}" ) + # Suppress mysql's "Using a password on the command line interface" + # warning that fires when -p or --password= is used. + # Redirect stderr through grep to drop only that warning line. 
+ cmd += " 2>&1 | { grep -v 'mysql: \\[Warning\\]' || true; }" os.system(cmd) diff --git a/src/spyglass/utils/sql_helper_fn.py b/src/spyglass/utils/sql_helper_fn.py index d195f2642..7c7fc874e 100644 --- a/src/spyglass/utils/sql_helper_fn.py +++ b/src/spyglass/utils/sql_helper_fn.py @@ -149,7 +149,7 @@ def write_mysqldump( self._remove_encoding(dump_script) self._write_version_file() - self._logger.info(f"Export script written to {dump_script}") + self._logger(f"Export script written to {dump_script}") self._export_conda_env() diff --git a/tests/conftest.py b/tests/conftest.py index 9ed15902d..0d7e14771 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,12 +43,15 @@ from shutil import rmtree as shutil_rmtree import datajoint as dj +import datajoint.external as _dj_external import datajoint.hash as _dj_hash +import hdmf.build.objectmapper as _hdmf_objectmapper import numpy as np import pynwb import pynwb.device as _pynwb_device import pynwb.io.device as _pynwb_io_device import pytest +import sklearn.utils.parallel as _sklearn_parallel from datajoint.logging import logger as dj_logger from hdmf.build.warnings import MissingRequiredBuildWarning from numba import NumbaWarning @@ -70,6 +73,10 @@ def _uuid_from_file_safe(filepath, *, init_string=""): _dj_hash.uuid_from_file = _uuid_from_file_safe +# datajoint.external uses `from .hash import uuid_from_file` at import time, +# creating a local binding that bypasses the patch above. Patch the external +# module's namespace directly so both paths use the safe version. 
+_dj_external.uuid_from_file = _uuid_from_file_safe # ----------- Prevent NWB-2.9 migration warnings from test NWB file ----------- # Patch pynwb Device NWB-2.9 migration warnings triggered by the test NWB file, @@ -122,6 +129,78 @@ def _device_init_no_field_deprecations(*args, **kwargs): _pynwb_device.Device.__init__ = _device_init_no_field_deprecations +# ------ Suppress warnings that bypass Python-level filters at call-time ------ +# +# Two remaining warnings survive even broad `filterwarnings("ignore", ...)` +# calls because they fire inside contexts where `warnings.filters` has been +# cleared or overridden: +# +# (1) MissingRequiredBuildWarning — hdmf.build.objectmapper.__check_quantity +# warns when an NWB container is missing a required attribute. The test +# NWB file predates NWB 2.9 and lacks 'source_script_file_name'. +# +# (2) sklearn UserWarning — sklearn.utils.parallel._FuncWrapper.__call__ +# warns when sklearn.delayed is used with non-sklearn Parallel. Worse, +# that same __call__ executes ``warnings.filters = []`` inside a +# catch_warnings block, which clears ALL Python-level filters for the +# duration of every wrapped parallel task — causing (1) to escape even +# when our "ignore" filters are present. +# +# Both modules use ``import warnings; warnings.warn(...)`` style, so we cannot +# patch the `warn` name directly in their namespace the way we did for +# pynwb.io.device. Instead we replace each module's `warnings` attribute with +# a thin proxy object that: +# • Intercepts warn() and suppresses the specific message/category. +# • Stores attribute *writes* (e.g. proxy.filters = []) locally so they never +# propagate to the real warnings module — preventing _FuncWrapper from +# clearing the real warnings.filters state. +# • Delegates all other attribute *reads* to the real warnings module. + + +class _ModuleWarningsProxy: + """Proxy for a module-level `warnings` reference. 
+ + Suppresses specific warn() calls before they reach the real warnings + module. Attribute writes are stored locally (preventing callers like + sklearn._FuncWrapper from zeroing out the real warnings.filters list). + Attribute reads fall through to the real warnings module. + """ + + def __init__(self, suppress_fn): + # Use object.__setattr__ to avoid triggering our own __setattr__ logic. + object.__setattr__(self, "_suppress_fn", suppress_fn) + + def warn(self, message, *args, **kwargs): + if not object.__getattribute__(self, "_suppress_fn")( + message, *args, **kwargs + ): + # stacklevel=2 so the warning is attributed to the caller of the + # module's warnings.warn(), not to this proxy line. + kwargs.setdefault("stacklevel", 2) + warnings.warn(message, *args, **kwargs) + + def __getattr__(self, name): + return getattr(warnings, name) + + +# (1) Suppress MissingRequiredBuildWarning from hdmf objectmapper. +_hdmf_objectmapper.warnings = _ModuleWarningsProxy( + lambda msg, *a, **kw: ( + a + and isinstance(a[0], type) + and issubclass(a[0], MissingRequiredBuildWarning) + ) +) + +# (2) Suppress sklearn cross-library delayed/Parallel mismatch warning AND +# prevent _FuncWrapper from clearing the real warnings.filters. +_sklearn_parallel.warnings = _ModuleWarningsProxy( + lambda msg, *a, **kw: ( + "sklearn.utils.parallel.delayed" in str(msg) + and "sklearn.utils.parallel.Parallel" in str(msg) + ) +) + # globals in pytest_configure: # BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD, NO_DLC @@ -137,6 +216,34 @@ def _device_init_no_field_deprecations(*args, **kwargs): warnings.filterwarnings("ignore", category=PerformanceWarning, module="pandas") warnings.filterwarnings("ignore", category=NumbaWarning, module="numba") +# RuntimeWarning: os.fork() was called after os.forkserver() or JAX import. +# JAX disables fork after parallelism starts; these are harmless in tests. 
+warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*fork.*") +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message=".*os\\.fork.*" +) + +# spikeinterface leaves mmap'd file handles open (traces_cached_seg*.raw). +# These show up as ResourceWarning during GC; suppress by module path. +warnings.filterwarnings( + "ignore", + category=ResourceWarning, + message=".*traces_cached_seg.*", +) +warnings.filterwarnings( + "ignore", + category=ResourceWarning, + module="spikeinterface", +) + +# TemporaryDirectory objects may be GC'd before __exit__ is called in some +# test teardown scenarios; suppress the resulting ResourceWarning. +warnings.filterwarnings( + "ignore", + category=ResourceWarning, + message=".*TemporaryDirectory.*", +) + def pytest_addoption(parser): """Permit constants when calling pytest at command line From cc47d9fbb289f5af99e87a4b2110971a581c325c Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sat, 21 Feb 2026 11:45:49 +0100 Subject: [PATCH 21/30] Revert --- src/spyglass/utils/sql_helper_fn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/utils/sql_helper_fn.py b/src/spyglass/utils/sql_helper_fn.py index 7c7fc874e..d195f2642 100644 --- a/src/spyglass/utils/sql_helper_fn.py +++ b/src/spyglass/utils/sql_helper_fn.py @@ -149,7 +149,7 @@ def write_mysqldump( self._remove_encoding(dump_script) self._write_version_file() - self._logger(f"Export script written to {dump_script}") + self._logger.info(f"Export script written to {dump_script}") self._export_conda_env() From 3b914fb80e0314bdf2c08be6d8950c3faf4fb06d Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sun, 22 Feb 2026 15:10:48 +0100 Subject: [PATCH 22/30] Respond to PR comments --- scripts/install.py | 73 +++++++++++-------- src/spyglass/common/common_behav.py | 4 +- src/spyglass/common/common_file_tracking.py | 3 +- src/spyglass/common/common_nwbfile.py | 2 +- src/spyglass/decoding/v1/clusterless.py | 14 ++-- src/spyglass/lfp/lfp_electrode.py | 
11 ++- src/spyglass/position/utils_dlc.py | 6 ++ src/spyglass/position/v1/dlc_reader.py | 38 +++++----- src/spyglass/position/v1/dlc_utils.py | 19 +++-- src/spyglass/position/v1/dlc_utils_makevid.py | 4 +- .../position/v1/position_dlc_project.py | 16 ++-- .../position/v1/position_dlc_training.py | 9 ++- .../spikesorting/v0/spikesorting_curation.py | 2 +- .../spikesorting/v0/spikesorting_sorting.py | 4 +- src/spyglass/spikesorting/v1/sorting.py | 2 +- src/spyglass/utils/dj_helper_fn.py | 8 +- src/spyglass/utils/dj_merge_tables.py | 3 +- src/spyglass/utils/mixins/analysis_builder.py | 27 ++++--- src/spyglass/utils/mixins/base.py | 12 ++- src/spyglass/utils/mixins/cautious_delete.py | 4 +- src/spyglass/utils/mixins/restrict_by.py | 4 +- tests/common/test_analysis_builder.py | 18 ++--- tests/common/test_file_tracking.py | 6 +- tests/common/test_video_import_fail.py | 3 + tests/conftest.py | 16 ++++ tests/container.py | 8 +- tests/decoding/test_clusterless.py | 3 + .../decoding/test_intervals_removal_simple.py | 3 + tests/position/test_utils.py | 3 + tests/spikesorting/v1/test_recompute.py | 28 +++++-- tests/utils/test_graph.py | 12 +-- tests/utils/test_merge.py | 21 +++--- 32 files changed, 244 insertions(+), 142 deletions(-) diff --git a/scripts/install.py b/scripts/install.py index 3f0d0dcc0..75581c9db 100755 --- a/scripts/install.py +++ b/scripts/install.py @@ -159,7 +159,8 @@ def success(msg: str, indent: bool = False) -> None: return prefix = " " if indent else "" print( - f"{prefix}{COLORS['green']}{SYMBOLS['success']}{COLORS['reset']} {msg}" + f"{prefix}{COLORS['green']}{SYMBOLS['success']}{COLORS['reset']}" + + f" {msg}" ) @staticmethod @@ -169,7 +170,8 @@ def warning(msg: str, indent: bool = False) -> None: return prefix = " " if indent else "" print( - f"{prefix}{COLORS['yellow']}{SYMBOLS['warning']}{COLORS['reset']} {msg}" + f"{prefix}{COLORS['yellow']}{SYMBOLS['warning']}{COLORS['reset']}" + + f" {msg}" ) @staticmethod @@ -210,12 +212,19 @@ def print(msg: 
str, color: Optional[str] = None, indent: int = 0) -> None: r = COLORS["reset"] print(f"{prefix}{c}{msg}{r}") + @staticmethod def multi( - self, msgs: List[str], color: Optional[str] = None, indent: int = 0 + msgs: List[str], color: Optional[str] = None, indent: int = 0 ) -> None: """Print multiple messages with optional color and indentation.""" + if Console._quiet: + return + prefix = " " * indent + c = COLORS.get(color, COLORS["reset"]) + r = COLORS["reset"] + for msg in msgs: - self.print(msg, color=color, indent=indent) + print(f"{prefix}{c}{msg}{r}") @staticmethod def banner(msg: str, color: str = "blue", width: int = 60) -> None: @@ -455,7 +464,7 @@ def check_prerequisites( "Insufficient disk space - installation cannot continue", indent=True, ) - Console().multi( + Console.multi( [ f"Checking: {base_dir}", f"Available: {available_gb} GB", @@ -582,7 +591,7 @@ def create(self, env_file: str, force: bool = False) -> None: Console.success( f"Using existing environment '{self.env_name}'" ) - Console().multi( + Console.multi( [ "Package installation will continue (updates if needed)", "To use a different name, run with: --env-name ", @@ -788,7 +797,7 @@ def validate_and_test_write(path: Path) -> Path: # 3. Interactive prompt default = Path.home() / "spyglass_data" Console.print("\nWhere should Spyglass store data?") - Console().multi( + Console.multi( [ "This will store raw NWB files, analysis results, and video data.", "Typical usage: 10-100+ GB depending on your experiments.", @@ -881,7 +890,7 @@ def prompt_install_type() -> Tuple[str, str]: full_total = DISK_SPACE_REQUIREMENTS["full"]["total"] Console.print("\n1. Minimal (Recommended for getting started)") - Console().multi( + Console.multi( [ f"├─ Install time: ~{ENV_CREATION_TIME_MINIMAL} minutes", f"├─ Disk space: ~{min_pkg} GB packages ({min_total} GB total with buffer)", @@ -896,7 +905,7 @@ def prompt_install_type() -> Tuple[str, str]: indent=2, ) Console.print("\n2. 
Full (For advanced analysis)") - Console().multi( + Console.multi( [ f"├─ Install time: ~{ENV_CREATION_TIME_FULL} minutes", f"├─ Disk space: ~{full_pkg} GB packages ({full_total} GB total with buffer)", @@ -1414,7 +1423,7 @@ def create_database_config( except (OSError, IOError, json.JSONDecodeError, KeyError) as e: Console.print(f"(Unable to read existing config: {e})", indent=1) - Console().multi( + Console.multi( [ "\nOptions:", " [b] Backup and create new (saves to .datajoint_config.json.backup)", @@ -1429,7 +1438,7 @@ def create_database_config( Console.warning( "Keeping existing configuration. Installation cancelled." ) - Console().multi( + Console.multi( [ "\nTo install with different settings:", " 1. Backup your config: cp ~/.datajoint_config.json ~/.datajoint_config.json.backup", @@ -1481,14 +1490,14 @@ def create_database_config( tls_status = "Yes" if use_tls else "No (localhost)" Console.success(f"Configuration saved to: {config_file}") - Console().multi( + Console.multi( [ " Permissions: Owner read/write only (secure)", "", ] ) Console.success("✓ Spyglass configuration complete!") - Console().multi( + Console.multi( [ "", "Database connection:", @@ -1825,7 +1834,7 @@ def prompt_remote_database_config() -> Optional[Dict[str, Any]]: >>> if config: ... print(f"Connecting to {config['host']}:{config['port']}") """ - Console().multi( + Console.multi( [ "\nRemote database configuration:", " Your lab admin should have provided these credentials.", @@ -1989,7 +1998,7 @@ def prompt_database_setup() -> str: if not compose_available: Console.print("") Console.warning("Docker is not available") - Console().multi( + Console.multi( [ " To enable Docker setup:", " 1. Install Docker Desktop: https://docs.docker.com/get-docker/", @@ -2090,7 +2099,7 @@ def setup_database_compose() -> Tuple[bool, str]: # Platform-specific guidance if sys.platform == "darwin": # macOS - Console().multi( + Console.multi( [ " 1. 
Stop existing MySQL (if installed):", " brew services stop mysql", @@ -2100,7 +2109,7 @@ def setup_database_compose() -> Tuple[bool, str]: ] ) elif sys.platform.startswith("linux"): # Linux - Console().multi( + Console.multi( [ " 1. Stop existing MySQL service:", " sudo systemctl stop mysql", @@ -2111,7 +2120,7 @@ def setup_database_compose() -> Tuple[bool, str]: ] ) elif sys.platform == "win32": # Windows - Console().multi( + Console.multi( [ " 1. Stop existing MySQL service:", " net stop MySQL", @@ -2132,7 +2141,7 @@ def setup_database_compose() -> Tuple[bool, str]: # Show what will happen Console.print("") Console.banner("Docker Database Setup") - Console().multi( + Console.multi( [ "", "This will:", @@ -2177,7 +2186,7 @@ def setup_database_compose() -> Tuple[bool, str]: ) Console.print(" Fix: Wait a moment and retry") - Console().multi( + Console.multi( [ "\n Other steps to try:", " 1. Check internet connection", @@ -2416,7 +2425,7 @@ def handle_database_setup_interactive(env_name: str) -> None: else: Console.error("Docker setup failed") if reason == "compose_unavailable": - Console().multi( + Console.multi( [ "\nDocker is not available.", " Option 1: Install Docker Desktop and restart", @@ -2552,7 +2561,7 @@ def change_database_password( Console.print("") Console.banner("Password Change (Recommended for lab members)") - Console().multi( + Console.multi( [ "", "If you received temporary credentials from your lab admin,", @@ -2729,7 +2738,7 @@ def setup_database_remote( port_reachable, port_msg = is_port_available(host, port) if not port_reachable: Console.warning(port_msg) - Console().multi( + Console.multi( [ "\n Possible causes:", " • Wrong port number (MySQL usually uses 3306)", @@ -2759,7 +2768,7 @@ def setup_database_remote( if not success: Console.error(f"Cannot connect to database: {_error}") - Console().multi( + Console.multi( [ "", "Most common causes (in order):", @@ -2838,7 +2847,7 @@ def validate_installation(env_name: str) -> bool: except 
subprocess.CalledProcessError: Console.fail() Console.warning("Some optional validation checks did not pass") - Console().multi( + Console.multi( [ "\n Core installation succeeded, but some features may need attention.", " Many warnings are not critical for getting started.", @@ -2949,7 +2958,7 @@ def print_completion_message(env_name: str, validation_passed: bool) -> None: Console.print("") else: Console.banner("Installation complete with warnings", color="yellow") - Console().multi( + Console.multi( [ "", "Core installation succeeded but some features may not work.", @@ -2957,7 +2966,7 @@ def print_completion_message(env_name: str, validation_passed: bool) -> None: ] ) - Console().multi( + Console.multi( [ "Next steps:", f" 1. Activate environment: conda activate {env_name}", @@ -3013,7 +3022,7 @@ def run_dry_run(args: argparse.Namespace) -> None: Console.print("Would perform the following steps:\n") - Console().multi( + Console.multi( [ f"1. {SYMBOLS['step']} Check prerequisites", f" Python version: {sys.version_info.major}.{sys.version_info.minor}", @@ -3041,7 +3050,7 @@ def run_dry_run(args: argparse.Namespace) -> None: ) if not args.skip_validation: - Console().multi( + Console.multi( [ f"7. {SYMBOLS['step']} Validate installation", " Run: python scripts/validate.py", @@ -3049,7 +3058,7 @@ def run_dry_run(args: argparse.Namespace) -> None: ] ) - Console().multi( + Console.multi( [ "=" * 60, "To perform installation, run without --dry-run flag", @@ -3105,7 +3114,7 @@ def run_config_only(args: argparse.Namespace) -> None: raise ValueError("Database password is required") else: # Interactive mode - Console().multi( + Console.multi( [ "Database configuration:", " 1. 
Local Docker database (localhost)", @@ -3153,7 +3162,7 @@ def run_config_only(args: argparse.Namespace) -> None: Console.banner("") Console.success(f"Configuration created: {config_file}") Console.banner("") - Console().multi( + Console.multi( [ "", "Configuration summary:", diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index e81362999..0bb910cc8 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -862,7 +862,7 @@ def _no_transaction_make(self, key): # Skip populating if no pos interval list names if len(pos_intervals) == 0: - logger.error(f"NO POS INTERVALS FOR {key};\n{no_pop_msg}") + self._err_msg(f"NO POS INTERVALS FOR {key};\n{no_pop_msg}") self.insert1(null_key, **insert_opts) return @@ -899,7 +899,7 @@ def _no_transaction_make(self, key): # Check that each pos interval was matched to only one epoch if len(matching_pos_intervals) != 1: - logger.warning( + self._warn_msg( f"{no_pop_msg}. Found {len(matching_pos_intervals)} pos " + f"intervals for\n\t{key}\n\t" + f"Matching intervals: {matching_pos_intervals}" diff --git a/src/spyglass/common/common_file_tracking.py b/src/spyglass/common/common_file_tracking.py index 0b4279e01..3b4bf4fae 100644 --- a/src/spyglass/common/common_file_tracking.py +++ b/src/spyglass/common/common_file_tracking.py @@ -185,7 +185,8 @@ def show_downstream(self, restriction=True): """ entries = (self & "can_read=0" & restriction).fetch("KEY", as_dict=True) if not entries: - logger.info("No issues found.") + if not test_mode: + logger.info("No issues found.") return [] # Get unique analysis tables from entries diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index dc491cbe4..482727a06 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -880,6 +880,6 @@ def check_all_files(self) -> dict: logger.warning(f" Found {issue_count} file issues") total_issues = 
sum(results.values()) - logger.info(f"File check complete: {total_issues} issues found") + self._info_msg(f"File check complete: {total_issues} issues found") return results diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 9e0ac68aa..3dfb14a5d 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -22,10 +22,8 @@ from spyglass.common.common_interval import IntervalList # noqa: F401 from spyglass.common.common_session import Session # noqa: F401 -from spyglass.decoding.v1.core import ( - DecodingParameters, # noqa: F401 - PositionGroup, -) +from spyglass.decoding.v1.core import DecodingParameters # noqa: F401 +from spyglass.decoding.v1.core import PositionGroup from spyglass.decoding.v1.utils import ( _get_interval_range, concatenate_interval_results, @@ -33,8 +31,8 @@ get_valid_kwargs, ) from spyglass.decoding.v1.waveform_features import ( - UnitWaveformFeatures, # noqa: F401 -) + UnitWaveformFeatures, +) # noqa: F401 from spyglass.position.position_merge import PositionOutput # noqa: F401 from spyglass.settings import config from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger @@ -65,8 +63,8 @@ def create_group( "waveform_features_group_name": group_name, } if self & group_key: - logger.warning( # No error on duplicate helps with pytests - f"Group {nwb_file_name}: {group_name} already exists" + self._warn_msg( # No error on duplicate helps with pytests + f"Group {nwb_file_name}: {group_name} already exists " + "please delete the group before creating a new one", ) return diff --git a/src/spyglass/lfp/lfp_electrode.py b/src/spyglass/lfp/lfp_electrode.py index c8ddc2173..8c5d8525e 100644 --- a/src/spyglass/lfp/lfp_electrode.py +++ b/src/spyglass/lfp/lfp_electrode.py @@ -5,6 +5,7 @@ from spyglass.common.common_ephys import Electrode from spyglass.common.common_session import Session # noqa: F401 +from spyglass.settings import test_mode from spyglass.utils 
import logger from spyglass.utils.dj_mixin import SpyglassMixin @@ -120,10 +121,12 @@ def create_lfp_electrode_group( # Insert part table entries LFPElectrodeGroup.LFPElectrode.insert(part_keys, **kwargs) - logger.info( - f"Successfully created/updated LFPElectrodeGroup {nwb_file_name}, {group_name} " - f"with {len(electrode_list)} electrodes." - ) + if not test_mode: + logger.info( + "Successfully created/updated LFPElectrodeGroup " + + f"{nwb_file_name}, {group_name} with {len(electrode_list)} " + + "electrodes." + ) def cautious_insert( self, session_key: dict, electrode_ids: List[int], group_name: str diff --git a/src/spyglass/position/utils_dlc.py b/src/spyglass/position/utils_dlc.py index e4e76b322..6a0aa1da5 100644 --- a/src/spyglass/position/utils_dlc.py +++ b/src/spyglass/position/utils_dlc.py @@ -11,6 +11,7 @@ evaluate_network, get_evaluation_folder = None, None # pragma: no cover from spyglass.position.utils import get_most_recent_file +from spyglass.settings import test_mode @contextlib.contextmanager @@ -59,6 +60,11 @@ def __getattr__(self, name: str): sys.stderr = old_stderr +test_mode_suppress = ( + suppress_print_from_package if test_mode else contextlib.nullcontext +) + + def get_dlc_model_eval( yml_path: str, model_prefix: str, diff --git a/src/spyglass/position/v1/dlc_reader.py b/src/spyglass/position/v1/dlc_reader.py index 05b74a93c..f3813e79b 100644 --- a/src/spyglass/position/v1/dlc_reader.py +++ b/src/spyglass/position/v1/dlc_reader.py @@ -8,6 +8,7 @@ import pandas as pd import ruamel.yaml as yaml +from spyglass.position.utils_dlc import test_mode_suppress from spyglass.settings import test_mode @@ -234,21 +235,22 @@ def do_pose_estimation( if dlc_project_path != output_dir: config_filepath = save_yaml(dlc_project_path, dlc_config) # ---- Trigger DLC prediction job ---- - analyze_videos( - config=config_filepath, - videos=video_filepaths, - shuffle=dlc_model["shuffle"], - trainingsetindex=dlc_model["trainingsetindex"], - 
destfolder=output_dir, - modelprefix=dlc_model["model_prefix"], - videotype=videotype, - gputouse=gputouse, - save_as_csv=save_as_csv, - batchsize=batchsize, - cropping=cropping, - TFGPUinference=TFGPUinference, - dynamic=dynamic, - robust_nframes=robust_nframes, - allow_growth=allow_growth, - use_shelve=use_shelve, - ) + with test_mode_suppress(): + analyze_videos( + config=config_filepath, + videos=video_filepaths, + shuffle=dlc_model["shuffle"], + trainingsetindex=dlc_model["trainingsetindex"], + destfolder=output_dir, + modelprefix=dlc_model["model_prefix"], + videotype=videotype, + gputouse=gputouse, + save_as_csv=save_as_csv, + batchsize=batchsize, + cropping=cropping, + TFGPUinference=TFGPUinference, + dynamic=dynamic, + robust_nframes=robust_nframes, + allow_growth=allow_growth, + use_shelve=use_shelve, + ) diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index caa303826..738123c26 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -20,7 +20,7 @@ from spyglass.common.common_behav import VideoFile from spyglass.common.common_usage import ActivityLog -from spyglass.settings import dlc_output_dir, dlc_video_dir, raw_dir +from spyglass.settings import dlc_output_dir, dlc_video_dir, raw_dir, test_mode from spyglass.utils.logging import logger, stream_handler @@ -610,11 +610,13 @@ def _get_new_dim(dim, span_start, span_stop, start_time, stop_time): if (span_stop + 1) >= len(dlc_df): dlc_df.loc[idx_span, idx[["x", "y"]]] = np.nan - logger.info(no_x_msg.format(ind=ind, coord="end")) + if not test_mode: + logger.info(no_x_msg.format(ind=ind, coord="end")) continue if span_start < 1: dlc_df.loc[idx_span, idx[["x", "y"]]] = np.nan - logger.info(no_x_msg.format(ind=ind, coord="start")) + if not test_mode: + logger.info(no_x_msg.format(ind=ind, coord="start")) continue x = [dlc_df["x"].iloc[span_start - 1], dlc_df["x"].iloc[span_stop + 1]] @@ -627,7 +629,10 @@ def _get_new_dim(dim, 
span_start, span_stop, start_time, stop_time): if span_len > max_pts_to_interp or change > max_cm_to_interp: dlc_df.loc[idx_span, idx[["x", "y"]]] = np.nan - logger.info(no_interp_msg.format(start=span_start, stop=span_stop)) + if not test_mode: + logger.info( + no_interp_msg.format(start=span_start, stop=span_stop) + ) if change > max_cm_to_interp: continue @@ -650,11 +655,13 @@ def interp_orientation(df, spans_to_interp, **kwargs): idx_span = idx[span_start:span_stop] if (span_stop + 1) >= len(df): df.loc[idx_span, idx["orientation"]] = np.nan - logger.info(no_x_msg.format(ind=ind, x="stop")) + if not test_mode: + logger.info(no_x_msg.format(ind=ind, x="stop")) continue if span_start < 1: df.loc[idx_span, idx["orientation"]] = np.nan - logger.info(no_x_msg.format(ind=ind, x="start")) + if not test_mode: + logger.info(no_x_msg.format(ind=ind, x="start")) continue orient = [df_orient.iloc[span_start - 1], df_orient.iloc[span_stop + 1]] diff --git a/src/spyglass/position/v1/dlc_utils_makevid.py b/src/spyglass/position/v1/dlc_utils_makevid.py index 655a5c44c..82f5c0122 100644 --- a/src/spyglass/position/v1/dlc_utils_makevid.py +++ b/src/spyglass/position/v1/dlc_utils_makevid.py @@ -400,7 +400,9 @@ def _generate_single_frame(self, frame_ind): # pragma: no cover def process_frames(self): """Process video frames in batches and generate matplotlib frames.""" - progress_bar = tqdm(leave=True, position=0, disable=self.debug) + disable = False if test_mode else self.debug + + progress_bar = tqdm(leave=True, position=0, disable=disable) progress_bar.reset(total=self.n_frames) for start_frame in range(0, self.n_frames, self.batch_size): diff --git a/src/spyglass/position/v1/position_dlc_project.py b/src/spyglass/position/v1/position_dlc_project.py index 757706e33..b96e4f6d1 100644 --- a/src/spyglass/position/v1/position_dlc_project.py +++ b/src/spyglass/position/v1/position_dlc_project.py @@ -10,6 +10,7 @@ from spyglass.common.common_lab import LabTeam from 
spyglass.position.utils import sanitize_filename +from spyglass.position.utils_dlc import test_mode_suppress from spyglass.position.v1.dlc_utils import find_mp4, get_video_info from spyglass.settings import dlc_project_dir, dlc_video_dir from spyglass.utils import SpyglassMixin, logger @@ -399,7 +400,9 @@ def add_training_files(cls, key, **kwargs): ] } - cfg = read_config(config_path) + with test_mode_suppress(): + cfg = read_config(config_path) + video_names = list(cfg["video_sets"]) label_dir = Path(cfg["project_path"]) / "labeled-data" training_files = [] @@ -423,7 +426,7 @@ def add_training_files(cls, key, **kwargs): ) if len(training_files) == 0: - logger.warning("No training files to add") + cls()._warn_msg("No training files to add") return training_file_inserts = [] @@ -451,7 +454,8 @@ def run_extract_frames(cls, key, **kwargs): config_path = (cls & key).fetch1("config_path") from deeplabcut import extract_frames - extract_frames(config_path, **kwargs) + with test_mode_suppress(): + extract_frames(config_path, **kwargs) @classmethod def run_label_frames(cls, key): # pragma: no cover @@ -466,7 +470,8 @@ def run_label_frames(cls, key): # pragma: no cover logger.error("DLC loaded in light mode, cannot label frames") return - label_frames(config_path) # pragma: no cover + with test_mode_suppress(): + label_frames(config_path) # pragma: no cover @classmethod def check_labels(cls, key, **kwargs): # pragma: no cover @@ -476,7 +481,8 @@ def check_labels(cls, key, **kwargs): # pragma: no cover config_path = (cls & key).fetch1("config_path") from deeplabcut import check_labels - check_labels(config_path, **kwargs) + with test_mode_suppress(): + check_labels(config_path, **kwargs) @classmethod def import_labeled_frames( diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index eeefecda0..5a551aa28 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ 
b/src/spyglass/position/v1/position_dlc_training.py @@ -4,7 +4,10 @@ import datajoint as dj from spyglass.position.utils import get_param_names -from spyglass.position.utils_dlc import suppress_print_from_package +from spyglass.position.utils_dlc import ( + suppress_print_from_package, + test_mode_suppress, +) from spyglass.position.v1.dlc_utils import file_log from spyglass.position.v1.position_dlc_project import DLCProject from spyglass.settings import test_mode @@ -203,7 +206,9 @@ def make_compute( self._info_msg("creating training dataset") # NOTE: if DLC > 3, this will raise engine error - create_training_dataset(dlc_cfg_filepath, **training_dataset_kwargs) + with test_mode_suppress(): + create_training_dataset(dlc_cfg_filepath, **training_dataset_kwargs) + # ---- Trigger DLC model training job ---- train_network_kwargs = { k: v diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index fa8dc54df..c387421c2 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -1140,7 +1140,7 @@ def make_fetch(self, key): "quality_metrics", "curation_labels" ) if metrics == {}: - logger.warning( + self._warn_msg( f"Metrics for Curation {key} should normally be calculated " + "before insertion here" ) diff --git a/src/spyglass/spikesorting/v0/spikesorting_sorting.py b/src/spyglass/spikesorting/v0/spikesorting_sorting.py index bf233d97b..df001c1b4 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_sorting.py +++ b/src/spyglass/spikesorting/v0/spikesorting_sorting.py @@ -253,7 +253,7 @@ def make_compute( mode="zeros", ) - logger.info(f"Running spike sorting on {key}...") + self._info_msg(f"Running spike sorting on {key}...") sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir) # add tempdir option for mountainsort @@ -283,6 +283,8 @@ def make_compute( # whiten recording separately; make sure dtype is float32 # to avoid 
downstream error with svd recording = sip.whiten(recording, dtype="float32") + # NOTE: mountainsort4's ms4alg.py calls warnings.resetwarnings() + # at import time, which clears all user-defined warning filters. sorting = sis.run_sorter( sorter, recording, diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py index 6cfa2e9da..7ab3334f0 100644 --- a/src/spyglass/spikesorting/v1/sorting.py +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -249,7 +249,7 @@ def make_compute( sorting=sorting, timestamps=timestamps, artifact_removed_intervals=artifact_removed_intervals, - nwb_file_name=(SpikeSortingSelection & key).fetch1("nwb_file_name"), + nwb_file_name=nwb_file_name, ) return [nwb_file_name, time_of_sort, analysis_file_name, object_id] diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 6a824c115..7264fd41e 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -105,9 +105,9 @@ def declare_all_merge_tables() -> Tuple[Type[dj.Table]]: from spyglass.decoding.decoding_merge import DecodingOutput # noqa: F401 from spyglass.lfp.lfp_merge import LFPOutput # noqa: F401 from spyglass.position.position_merge import PositionOutput # noqa: F401 - from spyglass.spikesorting.spikesorting_merge import ( # noqa: F401 + from spyglass.spikesorting.spikesorting_merge import ( SpikeSortingOutput, - ) + ) # noqa: F401 return DecodingOutput, LFPOutput, PositionOutput, SpikeSortingOutput @@ -686,8 +686,8 @@ def accept_divergence( """ if test_mode: # If get here in test mode, is because want to test failure - logger.warning( - "accept_divergence called in test mode, returning False w/o prompt" + logger.debug( + "\naccept_divergence called in testing, returning False w/o prompt" ) return False tbl_msg = "" diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index eeda06a7d..fd7b29168 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ 
b/src/spyglass/utils/dj_merge_tables.py @@ -13,6 +13,7 @@ from IPython.core.display import HTML from spyglass.utils.logging import logger +from spyglass.utils.mixins.base import BaseMixin from spyglass.utils.mixins.export import ExportMixin RESERVED_PRIMARY_KEY = "merge_id" @@ -270,7 +271,7 @@ def _merge_repr( ) ] if not parts: - logger.warning("No parts found. Try adjusting restriction.") + cls()._warn_msg("No parts found. Try adjusting restriction.") return attr_dict = { # NULL for non-numeric, 0 for numeric diff --git a/src/spyglass/utils/mixins/analysis_builder.py b/src/spyglass/utils/mixins/analysis_builder.py index ca039e9ba..c66e482d2 100644 --- a/src/spyglass/utils/mixins/analysis_builder.py +++ b/src/spyglass/utils/mixins/analysis_builder.py @@ -9,6 +9,7 @@ import pynwb +from spyglass.settings import test_mode from spyglass.utils.logging import logger @@ -81,6 +82,7 @@ def __init__(self, analysis_table, nwb_file_name: str): self.analysis_file_name = None self._state = "INIT" self._exception_occurred = False + self._exception_log = dict() def __enter__(self): """Create analysis file (CREATE phase). @@ -124,22 +126,27 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self._exception_occurred: # Log failed file for cleanup - logger.warning( - f"Analysis file '{self.analysis_file_name}' created but not " - f"registered due to exception: {exc_type.__name__}. " - f"File will be detected and cleaned up by " - f"AnalysisNwbfile.cleanup()" - ) + self._exception_log = {self.analysis_file_name: exc_type} + if not test_mode: + logger.error( + f"Analysis file '{self.analysis_file_name}' created but " + f"not registered due to exception: {exc_type.__name__}. 
" + f"File will be detected and cleaned up by " + f"AnalysisNwbfile.cleanup()" + ) return False # Don't suppress exception # Always auto-register on successful exit try: self.register() except Exception as e: - logger.error( - f"Failed to register analysis file " - f"'{self.analysis_file_name}': {e}" - ) + self._exception_occurred = True + self._exception_log = {self.analysis_file_name: e.__name__} + if not test_mode: + logger.error( + f"Failed to register analysis file " + f"'{self.analysis_file_name}': {e}" + ) raise # Re-raise registration error return False # Never suppress exceptions diff --git a/src/spyglass/utils/mixins/base.py b/src/spyglass/utils/mixins/base.py index b157fa41c..19a6f4cb3 100644 --- a/src/spyglass/utils/mixins/base.py +++ b/src/spyglass/utils/mixins/base.py @@ -56,14 +56,22 @@ def _warn_msg(self, msg: str) -> None: log = self._logger.debug if self._test_mode else self._logger.warning log(msg) + def _err_msg(self, msg: str) -> None: + """Log error message, but debug if in test mode. + + Quiets logs during testing, but preserves user experience during use. + """ + log = self._logger.debug if self._test_mode else self._logger.error + log(msg) + @cached_property def _test_mode(self) -> bool: """Return True if in test mode. Avoids circular import. Prevents prompt on delete. - Note: Using @property instead of @cached_property so we always get - current value from dj.config, even if test_mode changes after first access. + Note: Using cached property b/c we don't expect test_mode to change + during runtime, and it avoids repeated lookups. Used by ... - BaseMixin._spyglass_version diff --git a/src/spyglass/utils/mixins/cautious_delete.py b/src/spyglass/utils/mixins/cautious_delete.py index fcc8b2b0b..1170ca2ef 100644 --- a/src/spyglass/utils/mixins/cautious_delete.py +++ b/src/spyglass/utils/mixins/cautious_delete.py @@ -215,7 +215,7 @@ def cautious_delete( Passed to datajoint.table.Table.delete. 
""" if len(self) == 0: - self._logger.warning(f"Table is empty. No need to delete.\n{self}") + self._logger.warning("Table is empty. Nothing to delete.") return if self._has_updated_dj_version and not isinstance(self, dj.Part): @@ -249,6 +249,6 @@ def delete(self, *args, **kwargs): def super_delete(self, warn=True, *args, **kwargs): """Alias for datajoint.table.Table.delete.""" if warn: - self._logger.warning("!! Bypassing cautious_delete !!") + self._warn_msg("!! Bypassing cautious_delete !!") self._log_delete(start=time(), super_delete=True) super().delete(*args, **kwargs) diff --git a/src/spyglass/utils/mixins/restrict_by.py b/src/spyglass/utils/mixins/restrict_by.py index 8687c9e28..64b61453f 100644 --- a/src/spyglass/utils/mixins/restrict_by.py +++ b/src/spyglass/utils/mixins/restrict_by.py @@ -112,9 +112,9 @@ def restrict_by( + "See `help(YourTable.restrict_by)`" ) if len(ret) == len(self): - self._logger.warning("Same length" + warn_text) + self._warn_msg("Same length" + warn_text) elif len(ret) == 0: - self._logger.warning("No entries" + warn_text) + self._warn_msg("No entries" + warn_text) return ret diff --git a/tests/common/test_analysis_builder.py b/tests/common/test_analysis_builder.py index b9e905b1c..e58e0d36b 100644 --- a/tests/common/test_analysis_builder.py +++ b/tests/common/test_analysis_builder.py @@ -388,7 +388,7 @@ def test_idempotent_registration( class TestExceptionHandling: - """Test exception handling and cleanup logging.""" + """Test exception handling and cleanup.""" def test_file_not_registered_on_exception( self, master_analysis_table, mini_copy_name, mock_create, teardown @@ -423,7 +423,6 @@ def test_failed_file_logged_for_cleanup( mini_copy_name, mock_create, teardown, - caplog, ): """Test that failed files are logged for cleanup detection.""" table = master_analysis_table @@ -431,10 +430,6 @@ def test_failed_file_logged_for_cleanup( analysis_file = None - import logging - - caplog.set_level(logging.WARNING) - try: with 
table.build(mini_copy_name) as builder: analysis_file = builder.analysis_file_name @@ -444,11 +439,12 @@ def test_failed_file_logged_for_cleanup( except RuntimeError: pass # Expected - # Check that warning was logged - assert any( - "not registered due to exception" in record.message - for record in caplog.records - ), "Should log warning about unregistered file" + assert ( + builder._exception_occurred + ), "Builder should have recorded that an exception occurred" + + err_type = builder._exception_log[analysis_file] + assert err_type is RuntimeError, "Failed to log exception type" # Cleanup if teardown and analysis_file: diff --git a/tests/common/test_file_tracking.py b/tests/common/test_file_tracking.py index 9bc941a9a..7eae9cbb6 100644 --- a/tests/common/test_file_tracking.py +++ b/tests/common/test_file_tracking.py @@ -44,6 +44,7 @@ def analysis_file_issues(file_tracking_module, teardown): @pytest.fixture(scope="module") def custom_analysis_with_files(custom_config, dj_conn, common_nwbfile): """Create custom analysis table with test files.""" + from spyglass.common import Nwbfile # noqa: F401 from spyglass.utils.dj_mixin import SpyglassAnalysis prefix = custom_config @@ -152,15 +153,14 @@ def test_get_tbl_method_exists(analysis_file_issues): assert callable(analysis_file_issues.get_tbl) -def test_show_downstream_no_issues(analysis_file_issues, caplog): +def test_show_downstream_no_issues(analysis_file_issues): """Test show_downstream() with no issues.""" # Call with restriction that matches nothing result = analysis_file_issues.show_downstream( restriction={"analysis_file_name": "definitely_nonexistent_file.nwb"} ) - assert "No issues found" in caplog.text - assert not result # Should return empty list + assert isinstance(result, list) and not result, "Should be an empty list" def test_integration_check_all_files_method_exists(common_nwbfile): diff --git a/tests/common/test_video_import_fail.py b/tests/common/test_video_import_fail.py index 2224b45d1..f80f871b8 
100644 --- a/tests/common/test_video_import_fail.py +++ b/tests/common/test_video_import_fail.py @@ -10,6 +10,8 @@ from pynwb.image import ImageSeries from pynwb.testing.mock.file import mock_NWBFile, mock_Subject +from tests.conftest import VERBOSE + @pytest.fixture(scope="function") def nwb_with_video_no_task(raw_dir, common): @@ -72,6 +74,7 @@ def nwb_with_video_no_task(raw_dir, common): nwbfile_path.unlink() +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") def test_video_import_without_task_silent_failure( nwb_with_video_no_task, common, caplog ): diff --git a/tests/conftest.py b/tests/conftest.py index 0d7e14771..8f3c970d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -244,6 +244,22 @@ def __getattr__(self, name): message=".*TemporaryDirectory.*", ) +# numcodecs/__init__.py registers `atexit.register(blosc.destroy)` where +# `blosc.destroy` is decorated with @deprecated (PyPI `deprecated` package). +# This fires a DeprecationWarning at process exit. We could filter it, but +# ms4alg.py calls `warnings.resetwarnings()` during sorting — since pytest +# runs with `-p no:warnings` (no catch_warnings restoration), that clears all +# our filters and they are not restored before atexit fires. +# Unregistering the atexit handler is cleaner: blosc._init() has already run, +# and skipping _destroy() in the test process is harmless. 
+try: + import atexit as _atexit + import numcodecs.blosc as _numcodecs_blosc + + _atexit.unregister(_numcodecs_blosc.destroy) +except Exception: + pass # numcodecs not installed — nothing to unregister + def pytest_addoption(parser): """Permit constants when calling pytest at command line diff --git a/tests/container.py b/tests/container.py index c61295ebd..05109efb6 100644 --- a/tests/container.py +++ b/tests/container.py @@ -228,6 +228,8 @@ def wait(self, timeout=120, wait=3) -> None: return None if not self.container_status or self.container_status == "exited": self.start() + if self.container.health == "healthy": + return print("") self.logger.info(f"Container {self.container_name} starting...") @@ -310,8 +312,10 @@ def stop(self, remove=True) -> None: container_name = self.container_name self.container.stop() # Logger I/O operations close during teardown - print(f"Container {container_name} stopped.") + logline = f"Container {container_name} stopped" if remove: self.container.remove() - print(f"Container {container_name} removed.") + logline += " and removed" + + print(f"{logline}.") diff --git a/tests/decoding/test_clusterless.py b/tests/decoding/test_clusterless.py index c439d420e..795cfeebc 100644 --- a/tests/decoding/test_clusterless.py +++ b/tests/decoding/test_clusterless.py @@ -2,6 +2,8 @@ import pandas as pd import pytest +from tests.conftest import VERBOSE + @pytest.mark.very_slow def test_fetch_results(clusterless_pop, result_coordinates): @@ -75,6 +77,7 @@ def test_get_ahead(clusterless_pop): assert dist is not None, "Distance is None" +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy") def test_insert_existing_group(caplog, group_unitwave): file, group = group_unitwave.fetch1( "nwb_file_name", "waveform_features_group_name" diff --git a/tests/decoding/test_intervals_removal_simple.py b/tests/decoding/test_intervals_removal_simple.py index 92e63a6b6..b73684a8d 100644 --- a/tests/decoding/test_intervals_removal_simple.py 
+++ b/tests/decoding/test_intervals_removal_simple.py @@ -13,6 +13,8 @@ import pytest import xarray as xr +from tests.conftest import VERBOSE + @pytest.fixture def create_interval_labels(): @@ -482,6 +484,7 @@ class TestIntervalIdxWarning: when interval_idx is specified but results don't have interval_labels. """ + @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy") def test_interval_idx_warning_when_no_labels(self, caplog): """Should warn when interval_idx specified but no interval_labels.""" from unittest.mock import patch diff --git a/tests/position/test_utils.py b/tests/position/test_utils.py index ec4842b09..5947a8dbb 100644 --- a/tests/position/test_utils.py +++ b/tests/position/test_utils.py @@ -5,6 +5,8 @@ import pandas as pd import pytest +from tests.conftest import VERBOSE + def test_get_params_fallback(sgp): get_params = sgp.utils.get_param_names @@ -50,6 +52,7 @@ def test_valid_list_error(sgp): validate(option_list=["a", "b"], required_items=["a", "c"]) +@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy") def test_log_decorator_no_path(sgp): from spyglass.position.v1.dlc_utils import file_log from spyglass.utils import logger diff --git a/tests/spikesorting/v1/test_recompute.py b/tests/spikesorting/v1/test_recompute.py index ddb077b47..1c7e24650 100644 --- a/tests/spikesorting/v1/test_recompute.py +++ b/tests/spikesorting/v1/test_recompute.py @@ -63,18 +63,34 @@ def test_recompute_env(recomp_tbl): assert ret, "Recompute failed" -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy") def test_selection_attempt(caplog, recomp_selection): - """Test that the selection attempt works.""" + """Test that the selection reattempt does not add new entries.""" _ = recomp_selection.attempt_all() - assert "No rows" in caplog.text, "Selection attempt failed null log" + prev_len = len(recomp_selection) + ret = recomp_selection.attempt_all() + post_len = len(recomp_selection) + assert ret is None, 
"Selection attempt failed" + assert prev_len == post_len, "Selection attempt should not add new entries" -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy") -def test_delete_dry_run(caplog, recomp_tbl): +def test_delete_dry_run(recomp_tbl): """Test dry run delete.""" + prev_len = len(recomp_tbl) _ = recomp_tbl.delete_files(dry_run=True) - assert "DRY" in caplog.text, "Dry run delete failed to log" + post_len = len(recomp_tbl) + assert prev_len == post_len, "Dry run delete should not remove entries" + + +def test_recompute_disk_check(recomp_tbl): + """Test that the disk check works.""" + from spyglass.utils.dj_helper_fn import bytes_to_human_readable + + key = recomp_tbl.fetch("KEY")[0] + path, _ = recomp_tbl._get_paths(key) + size = Path(path).stat().st_size if Path(path).exists() else 0 + expected = bytes_to_human_readable(size) + result = recomp_tbl.get_disk_space(which="old", restr=key) + assert expected in result, "Disk check failed" @pytest.mark.slow diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 9aabe50d7..4437ca290 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -311,15 +311,15 @@ def test_invalid_restr_direction(graph_tables): PkNode.restrict_by("bad_attr > 0", direction="invalid_direction") -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_warn_nonrestrict(caplog, graph_tables): +def test_warn_nonrestrict(graph_tables): ParentNode = graph_tables["ParentNode"]() restr_parent = ParentNode & "parent_id > 4 AND parent_id < 9" - restr_parent >> "sk_id > 0" - assert "Same length" in caplog.text, "No warning logged on non-restrict." - restr_parent >> "sk_id > 99" - assert "No entries" in caplog.text, "No warning logged on non-restrict." + ret = restr_parent >> "sk_id > 0" + assert len(ret) == len(restr_parent), "Restriction should have no effect." + + ret = restr_parent >> "sk_id > 99" + assert len(ret) == 0, "Return should be empty." 
def test_restr_many_to_one(graph_tables_many_to_one): diff --git a/tests/utils/test_merge.py b/tests/utils/test_merge.py index fc225cf21..6636544aa 100644 --- a/tests/utils/test_merge.py +++ b/tests/utils/test_merge.py @@ -1,5 +1,6 @@ import datajoint as dj import pytest +from datajoint.logging import logger as dj_logger from tests.conftest import VERBOSE @@ -24,8 +25,11 @@ class BadChild(SpyglassMixin, dj.Part): yield BadMerge + prev_level = dj_logger.level + dj_logger.setLevel("ERROR") BadMerge.BadChild().drop_quick() BadMerge().drop_quick() + dj_logger.setLevel(prev_level) @pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") @@ -72,11 +76,9 @@ def test_merge_view(pos_merge): assert len(view.heading.names) > 14, "Repr not showing all columns." -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_merge_view_warning(caplog, merge_table): - _ = merge_table.merge_restrict(restriction='source="bad"') - txt = caplog.text - assert "No parts" in txt, "Warning not caught." +def test_merge_view_null(merge_table): + ret = merge_table.merge_restrict(restriction='source="bad"') + assert ret is None, "Restriction should return None for no matches." def test_merge_get_class(merge_table): @@ -85,8 +87,7 @@ def test_merge_get_class(merge_table): assert parent_cls.__name__ == part_name, "Class not found." -@pytest.mark.skip(reason="Pending populated merge table.") -def test_merge_get_class_invalid(caplog, merge_table): - _ = merge_table.merge_get_parent_class("bad") - txt = caplog.text - assert "No source" in txt, "Warning not caught." +# @pytest.mark.skip(reason="Pending populated merge table.") +def test_merge_get_class_invalid(spike_merge, pop_spike_merge): + ret = spike_merge.merge_get_parent_class("bad") + assert ret is None, "Should return None for invalid part name." 
From 439dd4e5f6342fddd21ea28305a1e089b33a5228 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sun, 22 Feb 2026 15:31:52 +0100 Subject: [PATCH 23/30] Denoise tests 8 --- src/spyglass/utils/dj_merge_tables.py | 6 ++--- src/spyglass/utils/mixins/base.py | 2 ++ src/spyglass/utils/mixins/cautious_delete.py | 2 +- src/spyglass/utils/mixins/helpers.py | 8 +++--- src/spyglass/utils/mixins/restrict_by.py | 2 +- tests/decoding/test_clusterless.py | 9 +++---- tests/utils/test_graph.py | 26 +++++++------------- tests/utils/test_merge.py | 22 ++++++++--------- tests/utils/test_mixin.py | 25 ++++++++----------- 9 files changed, 44 insertions(+), 58 deletions(-) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index fd7b29168..9599a8bd4 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -67,14 +67,14 @@ def __init__(self): self._reserved_sk = RESERVED_SECONDARY_KEY if not self.is_declared: if not is_merge_table(self): # Check definition - logger.warning( + self._warn_msg( "Merge table with non-default definition\n" + f"Expected:\n{MERGE_DEFINITION.strip()}\n" + f"Actual :\n{self.definition.strip()}" ) for part in self.parts(as_objects=True): if part.primary_key != self.primary_key: - logger.warning( # PK is only 'merge_id' in parts, no others + self._warn_msg( # PK is only 'merge_id' in parts, no others f"Unexpected primary key in {part.table_name}" + f"\n\tExpected: {self.primary_key}" + f"\n\tActual : {part.primary_key}" @@ -101,7 +101,7 @@ def parts(self, camel_case=False, *args, **kwargs) -> list: self._ensure_dependencies_loaded() if camel_case and kwargs.get("as_objects"): - logger.warning( + self._warn_msg( "Overriding as_objects=True to return CamelCase part names." 
) kwargs["as_objects"] = False diff --git a/src/spyglass/utils/mixins/base.py b/src/spyglass/utils/mixins/base.py index 19a6f4cb3..08fc641d6 100644 --- a/src/spyglass/utils/mixins/base.py +++ b/src/spyglass/utils/mixins/base.py @@ -16,6 +16,7 @@ def _logger(self): - RestrictByMixin - ExportMixin - AnalysisMixin + - Merge """ from spyglass.utils import logger @@ -44,6 +45,7 @@ def _info_msg(self, msg: str) -> None: Used by ... - AnalysisMixin.copy and .create - IngestionMixin._insert_logline + - Merge._merge_repr """ log = self._logger.debug if self._test_mode else self._logger.info log(msg) diff --git a/src/spyglass/utils/mixins/cautious_delete.py b/src/spyglass/utils/mixins/cautious_delete.py index 1170ca2ef..224570845 100644 --- a/src/spyglass/utils/mixins/cautious_delete.py +++ b/src/spyglass/utils/mixins/cautious_delete.py @@ -215,7 +215,7 @@ def cautious_delete( Passed to datajoint.table.Table.delete. """ if len(self) == 0: - self._logger.warning("Table is empty. Nothing to delete.") + self._warn_msg("Table is empty. 
Nothing to delete.") return if self._has_updated_dj_version and not isinstance(self, dj.Part): diff --git a/src/spyglass/utils/mixins/helpers.py b/src/spyglass/utils/mixins/helpers.py index f2d655091..927f02b59 100644 --- a/src/spyglass/utils/mixins/helpers.py +++ b/src/spyglass/utils/mixins/helpers.py @@ -51,9 +51,7 @@ def file_like(self, name=None, **kwargs): attr = field break if not attr: - self._logger.error( - f"No file_like field found in {self.full_table_name}" - ) + self._err_msg(f"No file_like field found in {self.full_table_name}") return return self & f"{attr} LIKE '%{name}%'" @@ -79,7 +77,9 @@ def find_insert_fail(self, key): rets.append(f"{parent_name}:\n{query}") else: rets.append(f"{parent_name}: MISSING") - self._logger.info("\n".join(rets)) + result = "\n".join(rets) + self._info_msg(result) + return result @classmethod def _safe_context(cls): diff --git a/src/spyglass/utils/mixins/restrict_by.py b/src/spyglass/utils/mixins/restrict_by.py index 64b61453f..5a61d1694 100644 --- a/src/spyglass/utils/mixins/restrict_by.py +++ b/src/spyglass/utils/mixins/restrict_by.py @@ -26,7 +26,7 @@ def unban_search_table(self, table): def see_banned_tables(self): """Print banned tables.""" - self._logger.info(f"Banned tables: {self._banned_search_tables}") + self._info_msg(f"Banned tables: {self._banned_search_tables}") def restrict_by( self, diff --git a/tests/decoding/test_clusterless.py b/tests/decoding/test_clusterless.py index 795cfeebc..b8e4a9f08 100644 --- a/tests/decoding/test_clusterless.py +++ b/tests/decoding/test_clusterless.py @@ -2,8 +2,6 @@ import pandas as pd import pytest -from tests.conftest import VERBOSE - @pytest.mark.very_slow def test_fetch_results(clusterless_pop, result_coordinates): @@ -77,10 +75,9 @@ def test_get_ahead(clusterless_pop): assert dist is not None, "Distance is None" -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy") -def test_insert_existing_group(caplog, group_unitwave): +def 
test_insert_existing_group(group_unitwave): file, group = group_unitwave.fetch1( "nwb_file_name", "waveform_features_group_name" ) - group_unitwave.create_group(file, group, ["dummy_data"]) - assert "already exists" in caplog.text, "No warning issued." + ret = group_unitwave.create_group(file, group, ["dummy_data"]) + assert ret is None, "Duplicate create_group should return None." diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 4437ca290..955d59b9a 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -1,8 +1,6 @@ import pytest from datajoint.utils import to_camel_case -from tests.conftest import VERBOSE - @pytest.fixture(scope="session") def leaf(lin_merge): @@ -275,18 +273,16 @@ def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg): assert len(graph_tables[table]() << restr) == expect_n, msg -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_ban_node(caplog, graph_tables): +def test_ban_node(graph_tables): search_restr = "sk_attr > 17" ParentNode = graph_tables["ParentNode"]() SkNode = graph_tables["SkNode"]() ParentNode.ban_search_table(SkNode) - ParentNode >> search_restr - assert "could not be applied" in caplog.text, "Found banned table." - - ParentNode.see_banned_tables() - assert "Banned tables" in caplog.text, "Banned tables not logged." + assert (ParentNode >> search_restr) is None, "Banned table still reachable." + assert ( + SkNode.full_table_name in ParentNode._banned_search_tables + ), "Banned table not in set." ParentNode.unban_search_table(SkNode) assert len(ParentNode >> search_restr) == 3, "Unban failed." @@ -297,8 +293,7 @@ def test_null_restrict_by(graph_tables): assert (PkNode >> True) == PkNode, "Null restriction failed." 
-@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_restrict_by_this_table(caplog, graph_tables): +def test_restrict_by_this_table(graph_tables): PkNode = graph_tables["PkNode"]() dist = (PkNode >> "pk_id > 4").restriction plain = (PkNode & "pk_id > 4").restriction @@ -342,12 +337,9 @@ def test_restr_invalid_err(graph_tables): len(PkNode << set(["parent_attr > 15", "parent_attr < 20"])) -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_restr_invalid(caplog, graph_tables): - graph_tables["PkNode"]() << "invalid_restr=1" - assert ( - "could not be applied" in caplog.text - ), "No warning logged on invalid restr." +def test_restr_invalid(graph_tables): + result = graph_tables["PkNode"]() << "invalid_restr=1" + assert result is None, "Invalid restriction should return None." @pytest.fixture(scope="session") diff --git a/tests/utils/test_merge.py b/tests/utils/test_merge.py index 6636544aa..324cab2a6 100644 --- a/tests/utils/test_merge.py +++ b/tests/utils/test_merge.py @@ -2,8 +2,6 @@ import pytest from datajoint.logging import logger as dj_logger -from tests.conftest import VERBOSE - @pytest.fixture(scope="function") def BadMerge(): @@ -32,11 +30,13 @@ class BadChild(SpyglassMixin, dj.Part): dj_logger.setLevel(prev_level) -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_nwb_table_missing(BadMerge, caplog, schema_test): +def test_nwb_table_missing(BadMerge, schema_test): + from spyglass.utils.dj_merge_tables import is_merge_table + schema_test(BadMerge) - txt = caplog.text - assert "non-default definition" in txt, "Warning not caught." + assert not is_merge_table( + BadMerge() + ), "BadMerge should fail merge-table check." @pytest.fixture(scope="function") @@ -63,11 +63,11 @@ def test_part_camel(merge_table): assert "_" not in example_part, "Camel case not applied." 
-@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_override_warning(caplog, merge_table): - _ = merge_table.parts(camel_case=True, as_objects=True)[0] - txt = caplog.text - assert "Overriding" in txt, "Warning not caught." +def test_override_warning(merge_table): + parts = merge_table.parts(camel_case=True, as_objects=True) + assert all( + isinstance(p, str) for p in parts + ), "as_objects=True should be overridden to return CamelCase strings." def test_merge_view(pos_merge): diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 374e1a7e5..6d6c79731 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -50,18 +50,17 @@ def test_null_file_like(schema_test, Mixin): assert len(ret) == len(Mixin()), "Null file_like not working." -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_bad_file_like(caplog, schema_test, Mixin): +def test_bad_file_like(schema_test, Mixin): schema_test(Mixin) - Mixin().file_like("BadName") - assert "No file_like field" in caplog.text, "No warning issued." + assert ( + Mixin().file_like("BadName") is None + ), "Expected None for missing field." -@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_insert_fail(caplog, common, mini_dict): +def test_insert_fail(common, mini_dict): this_key = dict(mini_dict, interval_list_name="BadName") - common.PositionSource().find_insert_fail(this_key) - assert "IntervalList: MISSING" in caplog.text, "No warning issued." + result = common.PositionSource().find_insert_fail(this_key) + assert "MISSING" in result, "Expected MISSING parent in result." def test_exp_summary(Nwbfile): @@ -90,16 +89,12 @@ def test_cautious_del_dry_run(Nwbfile, frequent_imports): ), "Dry run delete not working." 
-@pytest.mark.skipif(not VERBOSE, reason="No logging to test when quiet-spy.") -def test_empty_cautious_del(caplog, schema_test, Mixin): +def test_empty_cautious_del(schema_test, Mixin): schema_test(Mixin) mixin = Mixin() - prev_level = mixin._logger.level - mixin._logger.setLevel("INFO") - mixin.cautious_delete(safemode=False) mixin.cautious_delete(safemode=False) - assert "empty" in caplog.text, "No warning issued." - mixin._logger.setLevel(prev_level) + assert len(mixin) == 0, "Table should be empty after delete." + mixin.cautious_delete(safemode=False) # Should not raise on empty table def test_super_delete(schema_test, Mixin, common): From 4e6ed78793dbe6c36ef3f42c8170c4a38f9b79c5 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sun, 22 Feb 2026 15:46:21 +0100 Subject: [PATCH 24/30] Remove redundant log --- src/spyglass/utils/mixins/helpers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spyglass/utils/mixins/helpers.py b/src/spyglass/utils/mixins/helpers.py index 927f02b59..ee4acb9a7 100644 --- a/src/spyglass/utils/mixins/helpers.py +++ b/src/spyglass/utils/mixins/helpers.py @@ -77,9 +77,7 @@ def find_insert_fail(self, key): rets.append(f"{parent_name}:\n{query}") else: rets.append(f"{parent_name}: MISSING") - result = "\n".join(rets) - self._info_msg(result) - return result + return "\n".join(rets) @classmethod def _safe_context(cls): From 768d07072cd4f8778ec34a144208c844bcb8d832 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sun, 22 Feb 2026 15:54:28 +0100 Subject: [PATCH 25/30] Review feedback --- src/spyglass/utils/mixins/analysis_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/utils/mixins/analysis_builder.py b/src/spyglass/utils/mixins/analysis_builder.py index c66e482d2..ee3cb64ba 100644 --- a/src/spyglass/utils/mixins/analysis_builder.py +++ b/src/spyglass/utils/mixins/analysis_builder.py @@ -141,7 +141,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.register() except Exception as 
e: self._exception_occurred = True - self._exception_log = {self.analysis_file_name: e.__name__} + self._exception_log = {self.analysis_file_name: type(e)} if not test_mode: logger.error( f"Failed to register analysis file " From beb9c2001dbfd556fc4506d85aca2adeee2f0eef Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Sun, 22 Feb 2026 15:56:23 +0100 Subject: [PATCH 26/30] Update scripts/install.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- scripts/install.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/install.py b/scripts/install.py index 75581c9db..425bfa192 100755 --- a/scripts/install.py +++ b/scripts/install.py @@ -630,7 +630,7 @@ def create(self, env_file: str, force: bool = False) -> None: for kw in ["Solving", "Downloading", "Extracting"] ): print(".", end="", flush=True) - Console.print() + Console.print("") if process.returncode != 0: raise subprocess.CalledProcessError( From 0c67f6a8eb905ed13ac66418b8194fcaaf3ada49 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sun, 22 Feb 2026 15:59:35 +0100 Subject: [PATCH 27/30] Review feedback --- src/spyglass/spikesorting/v1/recording.py | 4 ++-- src/spyglass/utils/dj_graph.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 6ee2376bc..ad8953788 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -226,7 +226,7 @@ def make_fetch(self, key): def make_compute( self, key, nwb_file_name, sort_interval_valid_times - ) -> dict: + ) -> list: """Compute/save SpikeSortingRecording Returns @@ -248,7 +248,7 @@ def make_insert( nwb_file_name: str, file_dict: dict, sort_interval_valid_times: IntervalLike, - ) -> dict: + ) -> None: insert_key = dict(key, **file_dict) # INSERT: diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index dbcf82707..edd828613 100644 --- 
a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -1233,8 +1233,11 @@ def file_dict(self) -> Dict[str, List[str]]: return {t: self._get_node(t).get("files", []) for t in self.restr_ft} def _stored_files(self, as_dict=False) -> Dict[str, str] | Set[str]: - """Return dictionary of table names and files.""" - # Added for debugging + """Return dictionary of table names and files. + + Dictionary format is used for debugging and testing. Set format is used + for hashing and typical use. + """ self.cascade(warn=False) pairs = [ From e1d306db88e3b892601a8c74ee506d28e8738614 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 25 Feb 2026 09:15:24 +0100 Subject: [PATCH 28/30] Expand DLCProject.config_path to accommodate different base_dirs --- CHANGELOG.md | 6 ++++++ src/spyglass/position/v1/position_dlc_project.py | 5 ++++- tests/common/test_interval_helpers.py | 4 ++++ tests/conftest.py | 13 ++++++++++--- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b48e92bb6..e609ec32f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,11 @@ RecordingRecompute().alter() from spyglass.lfp.analysis.v1 import LFPBandV1 LFPBandV1().fix_1481() + +# Increase DLCProject.config_path length +from spyglass.position.v1.position_dlc_project import DLCProject + +DLCProject().alter() ``` ### Breaking Changes @@ -195,6 +200,7 @@ for label, interval_data in results.groupby("interval_labels"): - DLC parameter handling improvements and default value corrections #1379 - Fix ingestion nwb files with position objects but no spatial series #1405 - Ignore `percent_frames` when using `limit` in `DLCPosVideo` #1418 + - Increase `DLCProject.config_path` length #1534 - Spikesorting diff --git a/src/spyglass/position/v1/position_dlc_project.py b/src/spyglass/position/v1/position_dlc_project.py index b96e4f6d1..f9d2b3c2e 100644 --- a/src/spyglass/position/v1/position_dlc_project.py +++ b/src/spyglass/position/v1/position_dlc_project.py @@ 
-66,9 +66,12 @@ class DLCProject(SpyglassMixin, dj.Manual):
     -> LabTeam
     bodyparts : blob # list of bodyparts to label
     frames_per_video : int # number of frames to extract from each video
-    config_path : varchar(120) # path to config.yaml for model
+    config_path : varchar(255) # path to config.yaml for model
     """
 
+    # NOTE: #1534, config_path: varchar(120) -> varchar(255)
+    # to accommodate longer paths for nested projects.
+
     class BodyPart(SpyglassMixin, dj.Part):
         """Part table to hold bodyparts used in each project."""
 
diff --git a/tests/common/test_interval_helpers.py b/tests/common/test_interval_helpers.py
index 421cb3e44..61d30a69f 100644
--- a/tests/common/test_interval_helpers.py
+++ b/tests/common/test_interval_helpers.py
@@ -16,6 +16,10 @@ def cautious_interval(interval_list, mini_dict):
         valid_times=[[0, 1]],
         pipeline="",
     )
+    # Always clean up before inserting to handle --no-teardown reruns where
+    # test_cautious_insert_update may have changed the times to [[0, 2]]
+    pk = {k: insert[k] for k in interval_list.primary_key if k in insert}
+    (interval_list & pk).delete_quick()
     interval_list.insert1(insert, skip_duplicates=True)
     yield insert
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 8f3c970d8..e283b350d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -288,9 +288,13 @@ def pytest_addoption(parser):
     parser.addoption(
         "--base-dir",
         action="store",
-        default="./tests/_data/",
+        default=None,
         dest="base_dir",
-        help="Directory for local input file.",
+        help=(
+            "Directory for local input file. "
+            "Also reads SPYGLASS_BASE_DIR env var when unset. "
+            "Default: './tests/_data/'."
+        ),
     )
     parser.addoption(
         "--no-teardown",
@@ -339,7 +343,10 @@ def pytest_configure(config):
     NO_DLC = config.option.no_dlc
     pytest.NO_DLC = NO_DLC
 
-    BASE_DIR = Path(config.option.base_dir).absolute()
+    _base_dir = config.option.base_dir or os.environ.get(
+        "SPYGLASS_BASE_DIR", "./tests/_data/"
+    )
+    BASE_DIR = Path(_base_dir).expanduser().absolute()
     BASE_DIR.mkdir(parents=True, exist_ok=True)
     RAW_DIR = BASE_DIR / "raw"
     os.environ["SPYGLASS_BASE_DIR"] = str(BASE_DIR)

From a929a10e3de4c2f11ee154b215127b903e90fbcf Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Wed, 25 Feb 2026 12:41:14 +0100
Subject: [PATCH 29/30] No-teardown inconsistency fixes

---
 .../position/v1/position_dlc_project.py | 1 -
 tests/common/test_file_tracking.py | 38 +++++++++----------
 tests/common/test_interval.py | 5 ++-
 tests/decoding/conftest.py | 32 +++++++++++-----
 tests/position/v1/test_dlc_cohort.py | 2 +-
 tests/spikesorting/v0/test_recompute.py | 7 +++-
 tests/spikesorting/v1/test_recompute.py | 5 +++
 7 files changed, 56 insertions(+), 34 deletions(-)

diff --git a/src/spyglass/position/v1/position_dlc_project.py b/src/spyglass/position/v1/position_dlc_project.py
index f9d2b3c2e..a5710a9e8 100644
--- a/src/spyglass/position/v1/position_dlc_project.py
+++ b/src/spyglass/position/v1/position_dlc_project.py
@@ -70,7 +70,6 @@ class DLCProject(SpyglassMixin, dj.Manual):
     """

     # NOTE: #1534, config_path: varchar(120) -> varchar(255)
-    # to accommodate longer paths for nested projects.
class BodyPart(SpyglassMixin, dj.Part): """Part table to hold bodyparts used in each project.""" diff --git a/tests/common/test_file_tracking.py b/tests/common/test_file_tracking.py index 7eae9cbb6..df010649b 100644 --- a/tests/common/test_file_tracking.py +++ b/tests/common/test_file_tracking.py @@ -114,8 +114,8 @@ def test_check1_file_existing( "full_table_name": analysis_tbl.full_table_name, "analysis_file_name": analysis_file_name, } - if teardown: - (analysis_file_issues & key).delete_quick() + # Always clear before test to avoid stale entries from prior --no-teardown runs + (analysis_file_issues & key).delete_quick() # Check the file - should not find issues issue_found = analysis_file_issues.check1_file( @@ -134,14 +134,17 @@ def test_check_files_with_valid_file( analysis_file_name, analysis_tbl = test_analysis_file # Clear previous entries for this table - if teardown: - ( - analysis_file_issues - & {"full_table_name": analysis_tbl.full_table_name} - ).delete_quick() - - # Run check_files - should find no issues with valid file - issue_count = analysis_file_issues.check_files(analysis_tbl) + # Always clear before test to avoid stale entries from prior --no-teardown runs + ( + analysis_file_issues & {"full_table_name": analysis_tbl.full_table_name} + ).delete_quick() + + # Run check_files - should find no issues with the specific valid file. + # Restrict to this file only: the table is shared across test modules, so + # checking all entries would include files from other tests that may be + # missing (e.g. after --no-teardown reruns). 
+ restricted_tbl = analysis_tbl & {"analysis_file_name": analysis_file_name} + issue_count = analysis_file_issues.check_files(restricted_tbl) # Valid files should return 0 issues assert issue_count == 0 @@ -205,9 +208,8 @@ def test_check1_file_missing_from_disk( "analysis_file_name": analysis_file_name, } - # Clear any previous entries for this file - if teardown: - (analysis_file_issues & key).delete_quick() + # Always clear before test to avoid stale entries from prior --no-teardown runs + (analysis_file_issues & key).delete_quick() try: # Check the file - should detect it's missing @@ -244,9 +246,8 @@ def test_check1_file_filenotfound_error( "analysis_file_name": analysis_file_name, } - # Clear any previous entries for this file - if teardown: - (analysis_file_issues & key).delete_quick() + # Always clear before test to avoid stale entries from prior --no-teardown runs + (analysis_file_issues & key).delete_quick() # Mock get_abs_path to raise FileNotFoundError error_msg = "File path could not be constructed" @@ -283,9 +284,8 @@ def test_check1_file_checksum_error( "analysis_file_name": analysis_file_name, } - # Clear any previous entries for this file - if teardown: - (analysis_file_issues & key).delete_quick() + # Always clear before test to avoid stale entries from prior --no-teardown runs + (analysis_file_issues & key).delete_quick() # Mock get_abs_path to raise DataJointError (simulating checksum mismatch) error_msg = "Checksum mismatch for file" diff --git a/tests/common/test_interval.py b/tests/common/test_interval.py index 12b1d2486..b527bc11f 100644 --- a/tests/common/test_interval.py +++ b/tests/common/test_interval.py @@ -47,13 +47,14 @@ def test_plot_intervals(interval_list, mini_dict, start_time=0): intervals_fig = np.array(intervals_fig, dtype="object") # extract interval times from the IntervalList table - intervals_fetch = interval_list.fetch("valid_times") + intervals_fetch = interval_query.fetch("valid_times") all_equal = True for i_fig, i_fetch 
in zip(intervals_fig, intervals_fetch): # permit rounding errors up to the 4th decimal place, as a result of inaccuracies during unit conversions i_fig = np.round(i_fig.astype("float"), 4) - i_fetch = np.round(i_fetch.astype("float"), 4) + # Normalize i_fetch: handle Python lists and 1D [start, end] arrays + i_fetch = np.round(np.array(i_fetch, dtype="float").reshape(-1, 2), 4) if not array_equal(i_fig, i_fetch): all_equal = False diff --git a/tests/decoding/conftest.py b/tests/decoding/conftest.py index 32b5135f1..5e6eb4765 100644 --- a/tests/decoding/conftest.py +++ b/tests/decoding/conftest.py @@ -755,22 +755,36 @@ def _mock_load_results(filename): return xr.open_dataset(filename_str, engine="netcdf4") except (FileNotFoundError, OSError) as e: # OSError with "Unknown file format" means old pickle file exists - # FileNotFoundError means file doesn't exist at all + # (from a prior --no-teardown run where mock_netcdf_saves wrote + # pickle format). Fall back to pickle to keep reruns working. if "Unknown file format" in str(e): - raise FileNotFoundError( - f"Mock result has invalid format (likely old pickle file): {filename_str}. " - "Please delete old *_mocked.nc files from tests/_data/analysis/" - ) + import pickle + + try: + with open(filename_str, "rb") as f: + return pickle.load(f) + except Exception: + raise FileNotFoundError( + f"Mock result has invalid format: {filename_str}. 
" + "Please delete old *_mocked.nc files from tests/_data/analysis/" + ) raise FileNotFoundError(f"Mock result not found: {filename_str}") def _mock_load_model(filename): - """Load classifier from in-memory storage.""" + """Load classifier from in-memory storage or disk.""" filename_str = str(filename) if filename_str in mock_results_storage["classifiers"]: return mock_results_storage["classifiers"][filename_str] - raise FileNotFoundError( - f"Mock classifier not found in memory: {filename_str}" - ) + # Fall back to disk for --no-teardown reruns where in-memory is empty + import pickle + + try: + with open(filename_str, "rb") as f: + return pickle.load(f) + except (FileNotFoundError, OSError): + raise FileNotFoundError( + f"Mock classifier not found in memory or on disk: {filename_str}" + ) # Patch the detector base classes' load methods globally with ( diff --git a/tests/position/v1/test_dlc_cohort.py b/tests/position/v1/test_dlc_cohort.py index 668c5d40c..e4adc13b2 100644 --- a/tests/position/v1/test_dlc_cohort.py +++ b/tests/position/v1/test_dlc_cohort.py @@ -40,7 +40,7 @@ def test_cohort_error(cohort_tbls): dict(select_pk, bodyparts_params_dict=dict(bad_bp="bad_bp")) ) if select_tbl & select_pk: - select_tbl.delete(safemode=False) + (select_tbl & select_pk).delete(safemode=False) select_tbl.insert1(select_key, skip_duplicates=True) with pytest.raises(ValueError): diff --git a/tests/spikesorting/v0/test_recompute.py b/tests/spikesorting/v0/test_recompute.py index d620bb4c5..b9b90048b 100644 --- a/tests/spikesorting/v0/test_recompute.py +++ b/tests/spikesorting/v0/test_recompute.py @@ -73,10 +73,13 @@ def test_recompute_env(recomp_repop): assert ret, "Recompute failed" -def test_selection_restr(recomp_repop, user_env_tbl, recomp_selection): +def test_selection_restr(recomp_repop, recomp_selection): """Test that the selection env restriction works.""" _ = recomp_repop # Ensure recompute repop is used to load the recording - env_dict = user_env_tbl.this_env + # Use 
recomp_selection.env_dict (same source as this_env) for consistency. + # user_env_tbl.this_env is a cached_property that can become stale in full + # suite runs due to class-level _pip_custom dict contamination across tests. + env_dict = recomp_selection.env_dict manual_restr = recomp_selection & env_dict assert len(recomp_selection.this_env) == len( manual_restr diff --git a/tests/spikesorting/v1/test_recompute.py b/tests/spikesorting/v1/test_recompute.py index 1c7e24650..a02b4e7dc 100644 --- a/tests/spikesorting/v1/test_recompute.py +++ b/tests/spikesorting/v1/test_recompute.py @@ -45,6 +45,11 @@ def recomp_tbl(recomp_module, recomp_selection, spike_v1, pop_rec): """Fixture to ensure recompute table is loaded.""" _ = spike_v1, pop_rec # Ensure pop_rec is used to load the recording + # Re-populate selection if empty. After a prior --no-teardown run, + # make() calls remove_matched() to clean up entries post-recompute. + if not len(recomp_selection): + recomp_selection.attempt_all() + key = recomp_selection.fetch("KEY")[0] key["logged_at_creation"] = False # Prevent skip of recompute recomp_selection.update1(key) From 75e6b2c185ea18bef75e3f79d84776e2d00045e4 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 26 Feb 2026 10:46:11 +0100 Subject: [PATCH 30/30] Fetch tests from 1529 @samuelbray32 --- src/spyglass/utils/mixins/base.py | 4 ++- tests/utils/test_dj_helper_fn.py | 50 +++++++++++++++++++++++++++++++ tests/utils/test_mixin.py | 22 ++++++++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/spyglass/utils/mixins/base.py b/src/spyglass/utils/mixins/base.py index 08fc641d6..a2c834349 100644 --- a/src/spyglass/utils/mixins/base.py +++ b/src/spyglass/utils/mixins/base.py @@ -73,7 +73,9 @@ def _test_mode(self) -> bool: Avoids circular import. Prevents prompt on delete. Note: Using cached property b/c we don't expect test_mode to change - during runtime, and it avoids repeated lookups. + during runtime, and it avoids repeated lookups. 
Changing to @property + wouldn't reload the config. It would just re-fetch from the settings + module. Used by ... - BaseMixin._spyglass_version diff --git a/tests/utils/test_dj_helper_fn.py b/tests/utils/test_dj_helper_fn.py index aa063648f..a2278ddf6 100644 --- a/tests/utils/test_dj_helper_fn.py +++ b/tests/utils/test_dj_helper_fn.py @@ -9,3 +9,53 @@ def test_deprecation_factory(caplog, common): assert ( "Deprecation:" in caplog.text ), "No deprecation warning logged on migrated table." + + +@pytest.fixture(scope="module") +def str_to_bool(): + from spyglass.utils.dj_helper_fn import str_to_bool + + return str_to_bool + + +@pytest.mark.parametrize( + "input_str, expected", + [ + # Test truthy strings (case insensitive) + ("true", True), + ("True", True), + ("TRUE", True), + ("t", True), + ("T", True), + ("yes", True), + ("YES", True), + ("y", True), + ("Y", True), + ("1", True), + # Test falsy strings + ("false", False), + ("False", False), + ("FALSE", False), + ("no", False), + ("n", False), + ("0", False), + ("", False), + ("random_string", False), + # Test non-string inputs + (True, True), + (False, False), + (None, False), + (0, False), + (1, True), + ], +) +def test_str_to_bool(str_to_bool, input_str, expected): + """Test str_to_bool function handles various string inputs correctly. + + This test verifies that str_to_bool properly converts string values + to booleans, which is essential for fixing issue #1528 where string + "false" was incorrectly evaluated as True. + """ + assert ( + str_to_bool(input_str) == expected + ), f"Expected {expected} for input '{input_str}'" diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 6d6c79731..eb582854d 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -172,3 +172,25 @@ def test_mixin_del_orphans(dj_conn, Mixin, MixinChild): Mixin().delete_orphans(dry_run=False, safemode=False) post_del = Mixin().fetch1("id") assert post_del == 0, "Delete orphans not working." 
+ + +def test_test_mode_property_uses_settings(schema_test, Mixin): + """Test that _test_mode property uses spyglass.settings.config. + + Verifies fix for issue #1528 where string "false" was incorrectly + evaluated as True. The property should now use spyglass.settings.config.test_mode + which properly converts strings to booleans via str_to_bool(). + """ + schema_test(Mixin) + + # The _test_mode property should return a boolean + test_mode_value = Mixin()._test_mode + assert isinstance( + test_mode_value, bool + ), "_test_mode should return a boolean value" + + # In test environment, test_mode should be True + # (set by load_config fixture in conftest.py) + assert ( + test_mode_value is True + ), "_test_mode should be True in test environment"