Skip to content

Commit b2a5769

Browse files
feat(tidy3d): FXC-5311-enable-cached-loading-from-batch-data
1 parent 7464a49 commit b2a5769

File tree

6 files changed

+208
-7
lines changed

6 files changed

+208
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717
- Added `GaussianPort` and `AstigmaticGaussianPort` for S-matrix calculations using Gaussian beam sources and overlap monitors.
1818
- Added `symmetric_pseudo` option for `s_param_def` in `TerminalComponentModeler` which applies a scaling factor that ensures the S-matrix is symmetric in reciprocal systems.
1919
- Added deprecation warning for ``TemperatureMonitor`` and ``SteadyPotentialMonitor`` when ``unstructured`` parameter is not explicitly set. The default value of ``unstructured`` will change from ``False`` to ``True`` after the 2.11 release.
20+
- Added in-memory caching for downloaded batch results, configurable via ``config.batch_data_cache``.
2021

2122
### Breaking Changes
2223
- Added optional automatic extrusion of structures at the simulation boundaries into/through PML/Absorber layers via `extrude_structures` field in class `AbsorberSpec`.

docs/configuration/reference.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,29 @@ Controls the optional on-disk cache for simulation artifacts.
245245
- Maximum number of cached simulations retained. ``0`` means no limit and eviction falls back to size constraints.
246246

247247

248+
Batch Data Cache (``config.batch_data_cache``)
249+
----------------------------------------------
250+
251+
Controls the in-memory cache used when accessing entries in ``BatchData``.
252+
253+
.. list-table::
254+
:header-rows: 1
255+
:widths: 24 18 10 48
256+
257+
* - Option
258+
- Default
259+
- Persisted
260+
- Description
261+
* - ``enabled``
262+
- ``True``
263+
- Yes
264+
- Cache batch results in memory when all task data files are below the size threshold.
265+
* - ``max_total_size_gb``
266+
- ``1.0``
267+
- Yes
268+
- Cache batch task data only when the combined size of all task data files is at or below this threshold. ``0`` disables caching.
269+
270+
248271
Plugins (``config.plugins``)
249272
----------------------------
250273

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Tests for BatchData in-memory caching."""
2+
3+
from __future__ import annotations
4+
5+
from pathlib import Path
6+
7+
import tidy3d as td
8+
from tidy3d.web.api import container as web_container
9+
10+
11+
def _write_bytes(path: Path, size: int) -> None:
12+
path.write_bytes(b"0" * size)
13+
14+
15+
def test_batch_data_caches_small_files(monkeypatch, tmp_path):
    """A second access to the same task returns the cached object and skips the web calls."""
    task_paths = {
        "task1": str(tmp_path / "task1.hdf5"),
        "task2": str(tmp_path / "task2.hdf5"),
    }
    task_ids = {"task1": "task-1", "task2": "task-2"}
    # Two tiny files: total size (3 bytes) is far below the 1 GB threshold.
    for size, key in ((1, "task1"), (2, "task2")):
        _write_bytes(Path(task_paths[key]), size)

    monkeypatch.setattr(td.config.batch_data_cache, "enabled", True)
    monkeypatch.setattr(td.config.batch_data_cache, "max_total_size_gb", 1.0)

    calls = {"load": 0, "info": 0}
    sentinels = (object(), object())

    def fake_load(*_args, **_kwargs):
        value = sentinels[calls["load"]]
        calls["load"] = calls["load"] + 1
        return value

    def fake_get_info(*_args, **_kwargs):
        calls["info"] = calls["info"] + 1

    monkeypatch.setattr(web_container.web, "load", fake_load)
    monkeypatch.setattr(web_container.web, "get_info", fake_get_info)

    batch_data = td.web.BatchData(
        task_paths=task_paths,
        task_ids=task_ids,
        is_downloaded=True,
    )

    first = batch_data["task1"]
    second = batch_data["task1"]

    # Same object both times; load/get_info hit exactly once.
    assert first is second
    assert calls["load"] == 1
    assert calls["info"] == 1
54+
55+
56+
def test_batch_data_skips_cache_when_any_file_is_large(monkeypatch, tmp_path):
    """When the combined file size exceeds the threshold, every access reloads from disk."""
    task_paths = {
        "task1": str(tmp_path / "task1.hdf5"),
        "task2": str(tmp_path / "task2.hdf5"),
    }
    task_ids = {"task1": "task-1", "task2": "task-2"}
    for size, key in ((1, "task1"), (2, "task2")):
        _write_bytes(Path(task_paths[key]), size)

    # Threshold of exactly 2 bytes: the combined size (3 bytes) exceeds it.
    threshold_gb = 2 / (1024**3)
    monkeypatch.setattr(td.config.batch_data_cache, "enabled", True)
    monkeypatch.setattr(td.config.batch_data_cache, "max_total_size_gb", threshold_gb)

    calls = {"load": 0, "info": 0}
    sentinels = (object(), object())

    def fake_load(*_args, **_kwargs):
        value = sentinels[calls["load"]]
        calls["load"] = calls["load"] + 1
        return value

    def fake_get_info(*_args, **_kwargs):
        calls["info"] = calls["info"] + 1

    monkeypatch.setattr(web_container.web, "load", fake_load)
    monkeypatch.setattr(web_container.web, "get_info", fake_get_info)

    batch_data = td.web.BatchData(
        task_paths=task_paths,
        task_ids=task_ids,
        is_downloaded=True,
    )

    first = batch_data["task1"]
    second = batch_data["task1"]

    # Distinct objects both times; load/get_info hit on every access.
    assert first is not second
    assert calls["load"] == 2
    assert calls["info"] == 2

tidy3d/config/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ flowchart LR
5050

5151
## Module Reference
5252

53-
- `sections.py` - Pydantic models for built-in sections (logging, simulation, microwave, adjoint, web, local cache, plugin container) registered via `register_section`. The bundled models inherit from the internal `ConfigSection` helper, but external code can use plain `BaseModel` subclasses. Optional handlers perform side effects. Fields mark persistence with `json_schema_extra={"persist": True}`.
53+
- `sections.py` - Pydantic models for built-in sections (logging, simulation, microwave, adjoint, web, local cache, in-memory batch data cache, plugin container) registered via `register_section`. The bundled models inherit from the internal `ConfigSection` helper, but external code can use plain `BaseModel` subclasses. Optional handlers perform side effects. Fields mark persistence with `json_schema_extra={"persist": True}`.
5454
- `registry.py` - Stores section and handler registries and notifies the attached manager so new entries appear immediately.
5555
- `manager.py` - `ConfigManager` caches validated models, tracks runtime overrides per profile, filters persisted fields, exposes helpers such as `plugins`, `profiles`, and `format`. `SectionAccessor` routes attribute access to `update_section`.
5656
- `loader.py` - Resolves the config directory, loads `config.toml` and `profiles/<name>.toml`, parses environment overrides, and writes atomically through `serializer.build_document`.

tidy3d/config/sections.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,27 @@ def _serialize_directory(self, value: Path) -> str:
514514
return str(value)
515515

516516

517+
class BatchDataCacheConfig(ConfigSection):
    """Settings controlling in-memory caching for batch data."""

    # Master switch for the in-memory BatchData cache.
    # json_schema_extra {"persist": True} marks the field as persisted to the
    # user's config file.
    enabled: bool = Field(
        True,
        title="Enable batch data cache",
        description="Cache batch results in memory when files are below the size threshold.",
        json_schema_extra={"persist": True},
    )

    # Combined-size threshold in GB for all task data files; 0 disables caching.
    max_total_size_gb: NonNegativeFloat = Field(
        1.0,
        title="Maximum total batch data size (GB)",
        description=(
            "Cache batch task data only when the combined size of all task data files is at or "
            "below this threshold. Set to 0 to disable."
        ),
        json_schema_extra={"persist": True},
    )
536+
537+
517538
@register_section("plugins")
518539
class PluginsContainer(ConfigSection):
519540
"""Container that holds plugin-specific configuration sections."""
@@ -527,10 +548,12 @@ class PluginsContainer(ConfigSection):
527548
register_section("web")(WebConfig)
528549
register_handler("web")(apply_web)
529550
register_section("local_cache")(LocalCacheConfig)
551+
register_section("batch_data_cache")(BatchDataCacheConfig)
530552

531553

532554
__all__ = [
533555
"AdjointConfig",
556+
"BatchDataCacheConfig",
534557
"LocalCacheConfig",
535558
"LoggingConfig",
536559
"MicrowaveConfig",

tidy3d/web/api/container.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
TimeElapsedColumn,
2828
)
2929

30+
from tidy3d._runtime import WASM_BUILD
3031
from tidy3d.components.base import Tidy3dBaseModel, cached_property
3132
from tidy3d.components.mode.mode_solver import ModeSolver
3233
from tidy3d.components.types import annotate_type
3334
from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType
35+
from tidy3d.config import config
3436
from tidy3d.exceptions import DataError
3537
from tidy3d.log import get_logging_console, log
3638
from tidy3d.web.api import webapi as web
@@ -617,24 +619,80 @@ class BatchData(Tidy3dBaseModel, Mapping):
617619
description="Whether the simulation data was downloaded before.",
618620
)
619621

622+
    # In-memory cache of loaded results, keyed by task name.
    _data_cache: dict[TaskName, WorkflowDataType] = PrivateAttr(default_factory=dict)
    # Memoized caching decision: True/False once decided, None while undecided
    # (e.g. a task file has not been downloaded yet, so sizes cannot be summed).
    _cache_enabled: Optional[bool] = PrivateAttr(default=None)
624+
625+
    def _should_cache_data(self) -> bool:
        """Return True when in-memory caching should be enabled for batch data.

        The decision is memoized in ``_cache_enabled`` as a tri-state:
        ``True``/``False`` once decided, and ``None`` while undecided (a task
        file was not found on disk, so the total size could not be computed
        yet and the check will run again on the next call).
        """
        # Serve a previously memoized decision.
        if self._cache_enabled is not None:
            return self._cache_enabled

        # Pessimistically memoize False; set True at the end on success, or
        # reset to None below when a file is not yet available.
        self._cache_enabled = False
        if WASM_BUILD:
            return False

        try:
            cache_config = config.batch_data_cache
        except AttributeError:
            # ``config`` has no ``batch_data_cache`` section; treat as disabled.
            return False
        if not cache_config.enabled:
            return False

        # Convert the configured GB threshold to bytes; 0 disables caching.
        max_bytes = int(cache_config.max_total_size_gb * (1024**3))
        if max_bytes <= 0:
            return False

        # Sum on-disk sizes of all task files, bailing out as soon as the
        # running total exceeds the threshold.
        total_size = 0
        for task_path in self.task_paths.values():
            try:
                file_size = Path(task_path).stat().st_size
            except FileNotFoundError:  # not downloaded yet
                # Leave the decision open so it is re-evaluated after download.
                self._cache_enabled = None
                return False
            total_size += file_size
            if total_size > max_bytes:
                return False

        self._cache_enabled = True
        return True
658+
620659
    def load_sim_data(self, task_name: str) -> WorkflowDataType:
        """Load a simulation data object from file by task name.

        When ``config.batch_data_cache.enabled`` is ``True`` and the total size of all task
        files stays under the configured threshold, the loaded object is cached in
        memory for subsequent accesses.
        """
        cache_enabled = self._should_cache_data()
        # Serve from the in-memory cache when a cached entry exists.
        if cache_enabled and task_name in self._data_cache:
            return self._data_cache[task_name]

        task_data_path = Path(self.task_paths[task_name])
        task_id = self.task_ids[task_name]
        from_cache = self.cached_tasks[task_name] if self.cached_tasks else False
        if not from_cache:
            # Query task info first when the data does not come from the local cache.
            web.get_info(task_id)

        data = web.load(
            task_id=None if from_cache else task_id,
            path=task_data_path,
            verbose=False,
            replace_existing=not (from_cache or self.is_downloaded),
            lazy=self.lazy,
        )

        # The earlier check may have been inconclusive (``_cache_enabled`` is
        # ``None``) because a task file was missing; re-evaluate now that
        # ``web.load`` has had a chance to fetch it.
        if not cache_enabled and self._cache_enabled is None:
            cache_enabled = self._should_cache_data()
        if cache_enabled:
            self._data_cache[task_name] = data
        return data
689+
636690
    def __getitem__(self, task_name: TaskName) -> WorkflowDataType:
        """Get the simulation data object for a given ``task_name``.

        When ``config.batch_data_cache.enabled`` is ``True`` and the batch data size is within
        the configured threshold, the result is cached in memory.
        """
        return self.load_sim_data(task_name)
639697

640698
def __iter__(self) -> Iterator[TaskName]:
@@ -811,9 +869,10 @@ def run(
811869
>>> for task_name, sim_data in batch_data.items(): # doctest: +SKIP
812870
... # do something with data. # doctest: +SKIP
813871
814-
``batch_data`` does not store all of the data objects in memory,
815-
rather it iterates over the task names and loads the corresponding
816-
data from file one by one. If no file exists for that task, it downloads it.
872+
``batch_data`` iterates over task names and loads the corresponding data
873+
from file one by one. When ``config.batch_data_cache.enabled`` is ``True`` and the
874+
total size of all task files is below ``config.batch_data_cache.max_total_size_gb``,
875+
accessed results are cached in memory to avoid repeated loads.
817876
"""
818877
loaded = [job.load_if_cached for job in self.jobs.values()]
819878
self._check_path_dir(path_dir)

0 commit comments

Comments
 (0)