209 changes: 206 additions & 3 deletions tests/test_state_dict.py
@@ -13,6 +13,7 @@

import pytest
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torchstore as ts
@@ -22,9 +23,10 @@
get_model_state_dict,
get_optimizer_state_dict,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor import DTensor, Replicate, Shard
from torchstore.state_dict_utils import TensorMetadata, TorchStoreStateDict
from torchstore.utils import spawn_actors

from .utils import main, set_transport_type, transport_plus_strategy_params
@@ -35,6 +37,33 @@
MODEL_LINER_LENGTH = 10


Review comment (Contributor): I wonder if it's worth putting this function in a helper in tests/utils since it's used in multiple places?
https://github.com/meta-pytorch/torchstore/blob/main/tests/utils.py#L105


def _setup_process_group():
    """Set up a minimal distributed environment for DTensor testing."""

if not dist.is_initialized():
import os
import socket

# Find an available port dynamically
def _find_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# Initialize single-process group
dist.init_process_group(
backend="gloo", # CPU backend
rank=0,
world_size=1,
)
return True
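
As the review comment above suggests, this setup logic overlaps with the
helpers already in tests/utils.py. A minimal sketch of what a shared helper
might look like, assuming hypothetical names and placement (not part of this
PR):

    # tests/utils.py (hypothetical placement)
    import os
    import socket

    import torch.distributed as dist


    def find_free_port() -> int:
        # Bind to port 0 so the OS assigns an available ephemeral port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


    def setup_single_process_group(backend: str = "gloo") -> bool:
        # Initialize a single-process group if none is running; return True
        # only when this call performed the initialization.
        if dist.is_initialized():
            return False
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", str(find_free_port()))
        os.environ.setdefault("RANK", "0")
        os.environ.setdefault("WORLD_SIZE", "1")
        dist.init_process_group(backend=backend, rank=0, world_size=1)
        return True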


class UnitModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
@@ -169,7 +198,7 @@ async def test_state_dict(strategy_params, transport_type):

class Trainer(Actor):
# Monarch RDMA does not work outside of an actor, so we need
# to wrapp this test first
# to wrap this test first
# TODO: assert this within rdma buffer
def __init__(self) -> None:
self.rank = current_rank().rank
@@ -293,5 +322,179 @@ def _assert_equal_state_dict(state_dict1, state_dict2):
), f"{key=} {flattened_state_dict_1[key]=} {flattened_state_dict_2[key]=}"


def test_torchstore_state_dict():
"""Test TorchStoreStateDict class with various tensor types and reconstruction."""
_setup_process_group()
device_mesh = DeviceMesh("cpu", [0])

# Create a state dict with various tensor types and shapes
original_state_dict = {
# Scalar tensor (0D)
"scalar": torch.tensor(42.5, dtype=torch.float32),
# 1D tensors with different dtypes
"vector_float": torch.randn(10, dtype=torch.float32),
"vector_int": torch.randint(0, 100, (5,), dtype=torch.int64),
"vector_half": torch.randn(8, dtype=torch.float16),
# 2D tensors with different dtypes
"matrix_float": torch.randn(3, 4, dtype=torch.float32),
"matrix_double": torch.randn(2, 3, dtype=torch.float64),
"matrix_int": torch.randint(-50, 50, (4, 2), dtype=torch.int32),
# DTensors
"dtensor_replicate": DTensor.from_local(
torch.randn(4, 6, dtype=torch.float32), device_mesh, [Replicate()]
),
"dtensor_shard": DTensor.from_local(
torch.randn(3, 5, dtype=torch.float32), device_mesh, [Shard(0)]
),
# Nested structure
"model": {
"layer1": {
"weight": torch.randn(5, 3, dtype=torch.float32),
"bias": torch.randn(5, dtype=torch.float32),
"dtensor_weight": DTensor.from_local(
torch.randn(5, 3, dtype=torch.float32), device_mesh, [Replicate()]
),
},
"layer2": {
"weight": torch.randn(2, 5, dtype=torch.float32),
"bias": torch.randn(2, dtype=torch.float32),
},
},
# Mixed with non-tensor data
"metadata": {
"epoch": 10,
"learning_rate": 0.001,
"optimizer_state": torch.randn(3, 3, dtype=torch.float32),
},
# List with tensors (note: flattened state dict doesn't preserve list structure)
"layer_weights": [
torch.randn(2, 2, dtype=torch.float32),
torch.tensor(123, dtype=torch.int32),
],
}
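
    # Note on flattening: list entries typically surface as index-suffixed
    # keys (e.g. "layer_weights.0"), and nested dicts flatten to dot-joined
    # keys such as "model.layer1.weight" (assumed key format, not shown in
    # this diff).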

# Create TorchStoreStateDict
torchstore_state_dict = TorchStoreStateDict.from_state_dict(original_state_dict)

blob = torchstore_state_dict.tensor_blob

# 1. Flatten original state dict
original_flattened, _ = flatten_state_dict(original_state_dict)

# 2. Verify keys match between original flattened and torchstore flattened state dict
assert set(original_flattened.keys()) == set(
torchstore_state_dict.metadata_state_dict.keys()
), "Keys don't match between original and torchstore flattened state dicts"

# 3. Verify tensor references and calculate total size
_verify_metadata_state_dict(torchstore_state_dict, original_flattened)

# Calculate total size for blob verification
total_size = 0
for key, original_value in original_flattened.items():
if isinstance(original_value, torch.Tensor):
tensor_to_size = (
original_value._local_tensor
if hasattr(original_value, "_local_tensor")
else original_value
)
total_size += tensor_to_size.numel() * tensor_to_size.element_size()
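
    # For example, the 3x4 float32 "matrix_float" tensor contributes
    # 12 elements * 4 bytes = 48 bytes; DTensors contribute only their local
    # shards' bytes (equal to their global tensors on this single-rank mesh).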

# Verify tensor blob size matches total size
assert (
len(blob) == total_size
), f"Tensor blob size {len(blob)} doesn't match expected total size {total_size}"

# TODO: Add this code back
# # Reconstruct the state dict
# reconstructed_state_dict = torchstore_state_dict.to_state_dict()

# # Compare flattened versions - simpler than recursive comparison
# original_flattened, original_mapping = flatten_state_dict(original_state_dict)
# reconstructed_flattened, reconstructed_mapping = flatten_state_dict(
# reconstructed_state_dict
# )

# # Verify mappings are identical (structure preserved)
# assert (
# original_mapping == reconstructed_mapping
# ), "State dict structure mappings don't match"

# # Verify keys match
# assert set(original_flattened.keys()) == set(
# reconstructed_flattened.keys()
# ), "Flattened keys don't match"

# # Verify reconstruction using utility function
# _assert_equal_flattened_state_dict(original_flattened, reconstructed_flattened)

dist.destroy_process_group()


def _verify_metadata_state_dict(torchstore_state_dict, flattened_original):
"""Utility function to verify TensorMetadata objects in metadata_state_dict."""
for key, original_value in flattened_original.items():
torchstore_value = torchstore_state_dict.metadata_state_dict[key]

if isinstance(original_value, torch.Tensor):
if hasattr(original_value, "_local_tensor"): # DTensor check
# DTensor should be converted to TensorMetadata with tensor_slice
assert isinstance(torchstore_value, TensorMetadata)
assert (
torchstore_value.tensor_slice is not None
), f"DTensor at {key} should have tensor_slice"

# Verify local tensor metadata
local_tensor = original_value._local_tensor
assert torchstore_value.shape == tuple(local_tensor.shape)
assert torchstore_value.dtype == local_tensor.dtype
else:
# Regular tensor should not have tensor_slice
assert isinstance(torchstore_value, TensorMetadata)
assert (
torchstore_value.tensor_slice is None
), f"Regular tensor at {key} should not have tensor_slice"
assert torchstore_value.shape == tuple(original_value.shape)
assert torchstore_value.dtype == original_value.dtype


def _assert_equal_flattened_state_dict(flattened_original, flattened_reconstructed):
"""Utility function to verify reconstructed state dict matches original."""
for key, original_value in flattened_original.items():
reconstructed_value = flattened_reconstructed[key]

if hasattr(original_value, "_local_tensor"): # DTensor check
# Should be reconstructed as DTensor
assert hasattr(
reconstructed_value, "_local_tensor"
), f"Expected DTensor for {key}"

# Verify local tensor data matches
assert torch.equal(
original_value._local_tensor, reconstructed_value._local_tensor
), f"Local tensor data mismatch for {key}"

# Verify global shape matches
assert (
original_value.shape == reconstructed_value.shape
), f"Global shape mismatch for {key}"

# Verify placements match
assert (
original_value.placements == reconstructed_value.placements
), f"Placements mismatch for {key}"

elif isinstance(original_value, torch.Tensor):
# Regular tensors should remain the same
assert torch.equal(
original_value, reconstructed_value
), f"Regular tensor mismatch for {key}"
else:
# Non-tensor values should be preserved
assert (
original_value == reconstructed_value
), f"Non-tensor value mismatch for {key}"


if __name__ == "__main__":
main(__file__)
35 changes: 21 additions & 14 deletions torchstore/api.py
@@ -9,9 +9,15 @@
import torch
from monarch.actor import get_or_spawn_controller

import torchstore.state_dict_utils
from torchstore.client import LocalClient
from torchstore.controller import Controller

from torchstore.state_dict_utils import (
get_state_dict as get_state_dict_util,
put_state_dict as put_state_dict_util,
put_state_dict_batch,
tssd_enabled,
)
from torchstore.storage_volume import StorageVolume
from torchstore.strategy import (
ControllerStorageVolumes,
@@ -25,7 +31,7 @@
DEFAULT_TORCHSTORE_NAME: str = "TorchStore"

# cache for local clients
_local_clent_map: Dict[str, LocalClient] = {}
_local_client_map: Dict[str, LocalClient] = {}


async def initialize(
@@ -93,14 +99,14 @@ async def shutdown(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
"""
controller = await _controller(store_name)
await controller.teardown.call()
global _local_clent_map
_local_clent_map = {}
global _local_client_map
_local_client_map = {}


def reset_client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
"""Reset the local client for a given store. Useful for refreshing client state after shutdown."""
global _local_clent_map
_local_clent_map.pop(store_name, None)
global _local_client_map
_local_client_map.pop(store_name, None)


async def _controller(store_name: str = DEFAULT_TORCHSTORE_NAME) -> Controller:
@@ -123,8 +129,8 @@ async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient:
>>> store_client = await client()
>>> await store_client.put("my_key", tensor)
"""
if store_name in _local_clent_map:
return _local_clent_map[store_name]
if store_name in _local_client_map:
return _local_client_map[store_name]

controller = await _controller(store_name)
controller_strategy = await controller.get_controller_strategy.call_one()
@@ -133,7 +139,7 @@ async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient:
controller=controller,
strategy=controller_strategy,
)
_local_clent_map[store_name] = local_client
_local_client_map[store_name] = local_client

return local_client

@@ -280,9 +286,10 @@ async def put_state_dict(
>>> await put_state_dict(model.state_dict(), "model_checkpoint")
"""
cl = await client(store_name)
await torchstore.state_dict_utils.put_state_dict(
store=cl, state_dict=state_dict, key=key
)
if tssd_enabled():
await put_state_dict_batch(store=cl, state_dict=state_dict, key=key)
else:
await put_state_dict_util(store=cl, state_dict=state_dict, key=key)
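
Note: tssd_enabled() is imported from torchstore.state_dict_utils, and its
implementation is not part of this diff. A plausible sketch, assuming an
environment-variable feature flag (the variable name below is hypothetical):

    import os

    def tssd_enabled() -> bool:
        # Hypothetical toggle for the batched TorchStoreStateDict path; the
        # real flag lives in torchstore/state_dict_utils.py.
        return os.environ.get("TORCHSTORE_TSSD_ENABLED", "0") == "1"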


async def get_state_dict(
@@ -307,6 +314,6 @@ async def get_state_dict(
>>> model.load_state_dict(state_dict)
"""
cl = await client(store_name)
return await torchstore.state_dict_utils.get_state_dict(
cl, key, user_state_dict, strict
return await get_state_dict_util(
store=cl, key=key, user_state_dict=user_state_dict, strict=strict
)