42 changes: 20 additions & 22 deletions tests/test_state_dict.py
@@ -26,7 +26,7 @@
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor, Replicate, Shard
-from torchstore.state_dict_utils import TensorReference, TorchStoreStateDict
+from torchstore.state_dict_utils import TensorMetadata, TorchStoreStateDict
from torchstore.utils import spawn_actors

from .utils import main, set_transport_type, transport_plus_strategy_params
@@ -41,13 +41,17 @@ def _setup_process_group():
"""Set up minimal distributed environment for DTensor testing."""

if not dist.is_initialized():
-        # Set minimal environment variables for single process
import os
+        import socket

+        # Find an available port dynamically
+        def _find_free_port():
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("", 0))
+                return s.getsockname()[1]
+
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
-        os.environ.setdefault(
-            "MASTER_PORT", "29501"
-        )  # Different port to avoid conflicts
+        os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

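Binding a socket to port 0 makes the OS hand back an unused ephemeral port, which avoids the hard-coded-port collisions the old code worked around with a comment. A standalone sketch of the same pattern (hypothetical helper name; note the small race window between closing the probe socket and the process group actually binding the port):

```python
import socket


def find_free_port() -> int:
    """Ask the kernel for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0 means "any free port"
        return s.getsockname()[1]


print(find_free_port())  # e.g. 51873
```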
@@ -217,7 +221,6 @@ async def do_test(self):
"optimizer": optimizer.state_dict(),
}
await ts.put_state_dict(state_dict, "v0")
-        print(state_dict)
fetched_state_dict = await ts.get_state_dict("v0")
return state_dict, fetched_state_dict

@@ -318,7 +321,7 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}"


-def _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed):
+def _assert_equal_flattened_state_dict(flattened_original, flattened_reconstructed):
"""Utility function to verify reconstructed state dict matches original."""
for key, original_value in flattened_original.items():
reconstructed_value = flattened_reconstructed[key]
@@ -398,17 +401,14 @@ def test_torchstore_state_dict():
# Create TorchStoreStateDict
torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)

-    # Verify blob properties
-    blob = torchstore_state_dict.tensor_blob
-    assert blob.dtype == torch.uint8, f"Expected uint8 blob, got {blob.dtype}"
-    assert blob.dim() == 1, f"Expected 1D blob, got {blob.dim()}D"

# 1. Flatten original state dict
original_flattened, _ = flatten_state_dict(original_state_dict)

# 2. Verify keys match between original flattened and torchstore flattened state dict
assert set(original_flattened.keys()) == set(
-        torchstore_state_dict.flattened_state_dict.keys()
+        torchstore_state_dict.metadata_state_dict.keys()
), "Keys don't match between original and torchstore flattened state dicts"

# 3. Verify tensor references and calculate total size
@@ -450,18 +450,18 @@ def test_torchstore_state_dict():
), "Flattened keys don't match"

# Verify reconstruction using utility function
-    _verify_reconstructed_state_dict(original_flattened, reconstructed_flattened)
+    _assert_equal_flattened_state_dict(original_flattened, reconstructed_flattened)


def _verify_tensor_references(torchstore_state_dict, flattened_original):
"""Utility function to verify TensorReference objects in flattened state dict."""
"""Utility function to verify TensorMetadata objects in flattened state dict."""
for key, original_value in flattened_original.items():
-        torchstore_value = torchstore_state_dict.flattened_state_dict[key]
+        torchstore_value = torchstore_state_dict.metadata_state_dict[key]

if isinstance(original_value, torch.Tensor):
if hasattr(original_value, "_local_tensor"): # DTensor check
-                # DTensor should be converted to TensorReference with tensor_slice
-                assert isinstance(torchstore_value, TensorReference)
+                # DTensor should be converted to TensorMetadata with tensor_slice
+                assert isinstance(torchstore_value, TensorMetadata)
assert (
torchstore_value.tensor_slice is not None
), f"DTensor at {key} should have tensor_slice"
@@ -478,7 +478,7 @@ def _verify_tensor_references(torchstore_state_dict, flattened_original):
assert torchstore_value.dtype == local_tensor.dtype
else:
# Regular tensor should not have tensor_slice
-                assert isinstance(torchstore_value, TensorReference)
+                assert isinstance(torchstore_value, TensorMetadata)
assert (
torchstore_value.tensor_slice is None
), f"Regular tensor at {key} should not have tensor_slice"
@@ -521,7 +521,7 @@ def test_torchstore_state_dict_with_dtensor():

# Verify reconstruction using utility function
flattened_reconstructed, _ = flatten_state_dict(reconstructed_state_dict)
-    _verify_reconstructed_state_dict(flattened_original, flattened_reconstructed)
+    _assert_equal_flattened_state_dict(flattened_original, flattened_reconstructed)

dist.destroy_process_group()

@@ -584,9 +584,9 @@ async def test_state_dict_with_dtensor(self):
# Verify DTensor metadata is preserved
flattened_original, _ = flatten_state_dict(original_state_dict)
for key, value in flattened_original.items():
-            ref = torchstore_sd.flattened_state_dict.get(key)
+            ref = torchstore_sd.metadata_state_dict.get(key)
if isinstance(value, DTensor):
-                assert isinstance(ref, TensorReference)
+                assert isinstance(ref, TensorMetadata)
assert ref.tensor_slice is not None
assert ref.device_mesh is not None
assert ref.placements is not None
@@ -653,8 +653,6 @@ async def test_torchstore_state_dict_dtensor_distributed(
)

results = await actors.test_state_dict_with_dtensor.call()
-        for coord, (rank, status) in results:
-            assert status == "success", f"Actor rank {rank} failed"

finally:
if actors is not None:
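The equality helpers exercised above reduce to a key-wise comparison of flattened state dicts. A minimal sketch of that idea (a hypothetical standalone version, not the exact test helper):

```python
import torch


def assert_equal_flattened(sd1: dict, sd2: dict) -> None:
    """Key-wise comparison of two flattened state dicts."""
    assert set(sd1) == set(sd2), "key sets differ"
    for key, expected in sd1.items():
        actual = sd2[key]
        if isinstance(expected, torch.Tensor):
            assert torch.equal(expected, actual), f"tensor mismatch at {key!r}"
        else:
            assert expected == actual, f"value mismatch at {key!r}"
```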
45 changes: 29 additions & 16 deletions torchstore/api.py
@@ -7,11 +7,18 @@
from typing import Any, Dict, List, Optional, Union

import torch

-import torchstore.state_dict_utils
from monarch.actor import get_or_spawn_controller

from torchstore.client import LocalClient
from torchstore.controller import Controller

+from torchstore.state_dict_utils import (
+    get_state_dict as get_state_dict_util,
+    get_state_dict_batch,
+    put_state_dict as put_state_dict_util,
+    put_state_dict_batch,
+    tssd_enabled,
+)
from torchstore.storage_volume import StorageVolume
from torchstore.strategy import (
ControllerStorageVolumes,
@@ -25,7 +32,7 @@
DEFAULT_TORCHSTORE_NAME: str = "TorchStore"

# cache for local clients
-_local_clent_map: Dict[str, LocalClient] = {}
+_local_client_map: Dict[str, LocalClient] = {}


async def initialize(
@@ -93,14 +100,14 @@ async def shutdown(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
"""
controller = await _controller(store_name)
await controller.teardown.call()
-    global _local_clent_map
-    _local_clent_map = {}
+    global _local_client_map
+    _local_client_map = {}


def reset_client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
"""Reset the local client for a given store. Useful for refreshing client state after shutdown."""
-    global _local_clent_map
-    _local_clent_map.pop(store_name, None)
+    global _local_client_map
+    _local_client_map.pop(store_name, None)


async def _controller(store_name: str = DEFAULT_TORCHSTORE_NAME) -> Controller:
@@ -123,8 +130,8 @@ async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient:
>>> store_client = await client()
>>> await store_client.put("my_key", tensor)
"""
-    if store_name in _local_clent_map:
-        return _local_clent_map[store_name]
+    if store_name in _local_client_map:
+        return _local_client_map[store_name]

controller = await _controller(store_name)
controller_strategy = await controller.get_controller_strategy.call_one()
Expand All @@ -133,7 +140,7 @@ async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient:
controller=controller,
strategy=controller_strategy,
)
-    _local_clent_map[store_name] = local_client
+    _local_client_map[store_name] = local_client

return local_client

@@ -280,9 +287,10 @@ async def put_state_dict(
>>> await put_state_dict(model.state_dict(), "model_checkpoint")
"""
cl = await client(store_name)
-    await torchstore.state_dict_utils.put_state_dict_batch(
-        store=cl, state_dict=state_dict, key=key
-    )
+    if tssd_enabled():
+        await put_state_dict_batch(store=cl, state_dict=state_dict, key=key)
+    else:
+        await put_state_dict_util(store=cl, state_dict=state_dict, key=key)


async def get_state_dict(
@@ -307,6 +315,11 @@ async def get_state_dict(
>>> model.load_state_dict(state_dict)
"""
cl = await client(store_name)
-    return await torchstore.state_dict_utils.get_state_dict(
-        cl, key, user_state_dict, strict
-    )
+    if tssd_enabled():
+        return await get_state_dict_batch(
+            store=cl, key=key, user_state_dict=user_state_dict, strict=strict
+        )
+    else:
+        return await get_state_dict_util(
+            store=cl, key=key, user_state_dict=user_state_dict, strict=strict
+        )
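With this dispatch, callers keep the same two entry points regardless of which path `tssd_enabled()` selects. A minimal round-trip sketch, assuming default `initialize()` arguments suffice for a single process (exact setup may differ):

```python
import asyncio

import torch
import torchstore as ts


async def main() -> None:
    await ts.initialize()  # assumption: defaults work single-process
    model = torch.nn.Linear(4, 4)

    # Same interface whether the batched (tssd) or per-key path runs.
    await ts.put_state_dict(model.state_dict(), "model_checkpoint")
    fetched = await ts.get_state_dict("model_checkpoint")
    model.load_state_dict(fetched)

    await ts.shutdown()


asyncio.run(main())
```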
30 changes: 30 additions & 0 deletions torchstore/client.py
@@ -13,6 +13,7 @@

from torchstore.controller import ObjectType
from torchstore.logging import LatencyTracker
+from torchstore.state_dict_utils import unpack_metadata_state_dict
from torchstore.strategy import TorchStoreStrategy
from torchstore.transport import Pipe, Request, TensorSlice
from torchstore.transport.buffers import TransportContext
@@ -116,6 +117,35 @@ async def get(
latency_tracker.track_e2e()
return fetched_tensor

+    @torch.no_grad
+    async def get_batch(self, key_prefix: str, keys: list[str]) -> dict[str, Any]:
+        """Fetch multiple tensors at once.
+
+        Args:
+            key_prefix: Prefix to prepend to each key (e.g., "v0/").
+            keys: List of keys to fetch.
+
+        Returns:
+            Dictionary mapping keys to their values (tensors or objects).
+        """
+        logger.debug(f"Batch fetching {len(keys)} keys with prefix {key_prefix}")
+        latency_tracker = LatencyTracker(f"get_batch:{len(keys)}_keys")
+
+        # For now, we assume all keys are on the same storage volume
+        # (which is true when using SingletonStrategy or batch put)
+        storage_volume, _ = self.strategy.select_storage_volume()
+        pipe = Pipe(storage_volume)
+
+        tensor_blob, metadata_state_dict = await pipe.get_batch_from_storage_volume(
+            key_prefix, keys
+        )
+
+        # Unpack the metadata state dict to get actual tensors
+        result = unpack_metadata_state_dict(metadata_state_dict, tensor_blob)
+
+        latency_tracker.track_e2e()
+        return result
+
async def keys(self, prefix: str | None = None) -> list[str]:
"""
Get all keys that match the given prefix.
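A hypothetical call site for the new `get_batch`, assuming all listed keys were written under a shared `"v0/"` prefix onto a single storage volume (the constraint noted in the method's comment):

```python
from torchstore.api import client


async def fetch_params(names: list[str]) -> dict:
    store_client = await client()
    # One storage-volume round trip for all keys, instead of one get() per key.
    return await store_client.get_batch(key_prefix="v0/", keys=names)


# e.g. params = await fetch_params(["layer1.weight", "layer1.bias"])
```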