5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
1 change: 1 addition & 0 deletions requirements.txt
@@ -38,6 +38,7 @@ scipy==1.14.0
setuptools==72.1.0
six==1.16.0
sympy==1.13.1
tabulate==0.9.0
threadpoolctl==3.5.0
tqdm==4.66.5
trimesh==4.4.4
52 changes: 52 additions & 0 deletions scripts/profile_data_loader.py
@@ -0,0 +1,52 @@
import argparse
import logging
import os
from neural_mesh_simplification.data.data_profiler import DataLoaderProfiler


def main():
parser = argparse.ArgumentParser(description="Profile the data loader.")
parser.add_argument(
"--data-dir",
type=str,
required=True,
help="The directory where the dataset is stored.",
)
parser.add_argument(
"--batch-size",
type=int,
required=False,
default=32,
help="The batch size to use.",
)
parser.add_argument(
"--num-workers",
type=int,
required=False,
default=os.cpu_count(),
help="The number of worker processes to use.",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="Whether to shuffle the data.",
)
args = parser.parse_args()

# Configure logging
logging.basicConfig(level=logging.DEBUG)
# Set trimesh logger to INFO level to suppress debug messages
logging.getLogger("trimesh").setLevel(logging.INFO)

profiler = DataLoaderProfiler(
data_dir=args.data_dir,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=args.shuffle,
)
summary = profiler.profile()
profiler.log_summary(summary)


if __name__ == "__main__":
main()
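
For reference, a typical invocation of this script might look like the following; the dataset path and worker count are illustrative, not taken from the repository:

python scripts/profile_data_loader.py --data-dir data/processed --batch-size 16 --num-workers 4 --shuffle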
234 changes: 234 additions & 0 deletions src/neural_mesh_simplification/data/data_profiler.py
@@ -0,0 +1,234 @@
import logging
import os
import time
import psutil
import numpy as np
import torch
from torch_geometric.loader import DataLoader
from neural_mesh_simplification.data.dataset import MeshSimplificationDataset
from tqdm import tqdm
from tabulate import tabulate


class DataLoaderProfiler:
def __init__(
self,
data_dir: str,
batch_size: int = 32,
num_workers: int = 8,
shuffle: bool = False,
follow_batch_fields: list[str] | None = None,
log_interval: int = 10,
):
"""
Initialize the profiler with dataset parameters.

Args:
data_dir (str): Path to the dataset directory.
batch_size (int): Number of samples per batch.
num_workers (int): Number of worker processes.
shuffle (bool): Whether to shuffle the dataset.
follow_batch_fields (list, optional): List of batch attributes to follow.
log_interval (int): Frequency of logging progress (in batches).
"""
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.follow_batch_fields = follow_batch_fields or ["x", "pos"]
self.log_interval = log_interval

self.dataset = MeshSimplificationDataset(data_dir=data_dir)
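        # follow_batch makes PyG emit per-node assignment vectors (e.g.
        # x_batch, pos_batch) that map each node in a batch to its sample.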
self.dataloader = DataLoader(
self.dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
follow_batch=self.follow_batch_fields,
)
self.logger = logging.getLogger(__name__)
self.logger.info(
f"Dataset loaded from {data_dir} with {len(self.dataset)} samples."
)

def compute_batch_memory_usage(self, batch) -> float:
"""
Compute the memory usage in MB for a given batch.

Args:
batch: A batch from the data loader.

Returns:
float: Memory usage in megabytes.
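
        Note:
            Only tensor-valued attributes are counted. For intuition, a
            float32 tensor of shape [1000, 3] contributes 1000 * 3 * 4
            bytes, roughly 0.011 MB.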
"""
memory_bytes = 0
for key in batch.keys():
value = getattr(batch, key)
if torch.is_tensor(value):
memory_bytes += value.element_size() * value.nelement()
return memory_bytes / (1024 * 1024)

def profile(self) -> dict:
"""
Profile the data loader over all batches and return summary statistics.

Returns:
dict: A dictionary containing batch statistics and memory usage metrics.
"""
batch_times = []
batch_memory = []
vertices_per_batch = []
faces_per_batch = []
edges_per_batch = []

process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss / (1024 * 1024)
        start_time = time.perf_counter()
        batch_start = start_time

for batch_idx, batch in enumerate(
tqdm(self.dataloader, desc="Profiling batches", leave=False)
):
            # Time elapsed since the end of the previous iteration covers the
            # loader's fetch of this batch, which is the cost being profiled.
            batch_times.append(time.perf_counter() - batch_start)

num_vertices = getattr(batch, "num_nodes", 0)
num_faces = (
batch.face.shape[1]
if hasattr(batch, "face")
and batch.face is not None
and batch.face.numel() > 0
else 0
)
num_edges = (
batch.edge_index.shape[1]
if hasattr(batch, "edge_index")
and batch.edge_index is not None
and batch.edge_index.numel() > 0
else 0
)

batch_memory.append(self.compute_batch_memory_usage(batch))
vertices_per_batch.append(num_vertices)
faces_per_batch.append(num_faces)
edges_per_batch.append(num_edges)

            batch_start = time.perf_counter()

total_time = time.perf_counter() - start_time
final_memory = process.memory_info().rss / (1024 * 1024)

summary = {
"total_time": total_time,
"avg_batch_time": np.mean(batch_times) if batch_times else 0,
"memory_increase": final_memory - initial_memory,
"avg_vertices": np.mean(vertices_per_batch) if vertices_per_batch else 0,
"avg_faces": np.mean(faces_per_batch) if faces_per_batch else 0,
"avg_edges": np.mean(edges_per_batch) if edges_per_batch else 0,
"avg_memory": np.mean(batch_memory) if batch_memory else 0,
"min_batch_memory": np.min(batch_memory) if batch_memory else 0,
"max_batch_memory": np.max(batch_memory) if batch_memory else 0,
"vertices_range": (
np.min(vertices_per_batch) if vertices_per_batch else 0,
np.max(vertices_per_batch) if vertices_per_batch else 0,
),
"faces_range": (
np.min(faces_per_batch) if faces_per_batch else 0,
np.max(faces_per_batch) if faces_per_batch else 0,
),
"edges_range": (
np.min(edges_per_batch) if edges_per_batch else 0,
np.max(edges_per_batch) if edges_per_batch else 0,
),
}
return summary

def log_summary(self, summary: dict) -> None:
"""
Log the summary statistics in a tabulated format.

Args:
summary (dict): Dictionary containing the profiling results.
"""
self.logger.info("\n\n========= Dataset Profiling Results =========\n")

# Timing and Memory Overview
timing_memory_data = [
["Total Processing Time", f"{summary['total_time']:.2f} s"],
["Average Batch Time", f"{summary['avg_batch_time']:.4f} s"],
["Memory Usage Increase", f"{summary['memory_increase']:.2f} MB"],
]
self.logger.info("\nTiming and Memory Overview:")
self.logger.info(
"\n"
+ tabulate(
timing_memory_data,
tablefmt="psql",
colalign=("left", "right"),
maxcolwidths=[25, 15],
)
+ "\n"
)

# Batch Statistics
batch_stats_data = [
["Metric", "Average", "Min", "Max"],
[
"Vertices",
f"{summary['avg_vertices']:,.1f}",
f"{summary['vertices_range'][0]:,d}",
f"{summary['vertices_range'][1]:,d}",
],
[
"Faces",
f"{summary['avg_faces']:,.1f}",
f"{summary['faces_range'][0]:,d}",
f"{summary['faces_range'][1]:,d}",
],
[
"Edges",
f"{summary['avg_edges']:,.1f}",
f"{summary['edges_range'][0]:,d}",
f"{summary['edges_range'][1]:,d}",
],
[
"Memory (MB)",
f"{summary['avg_memory']:.2f}",
f"{summary['min_batch_memory']:.2f}",
f"{summary['max_batch_memory']:.2f}",
],
]
self.logger.info("\nBatch Statistics:")
self.logger.info(
"\n"
+ tabulate(
batch_stats_data,
headers="firstrow",
tablefmt="psql",
colalign=("left", "right", "right", "right"),
maxcolwidths=[15, 12, 12, 12],
)
+ "\n"
)

        # Dataset configuration overview
        dataset_info = [
            ["Configuration", "Value"],
            ["Total Samples", f"{len(self.dataset):,d}"],
            ["Batch Size", f"{self.batch_size}"],
            ["Number of Workers", f"{self.num_workers}"],
            ["Shuffle Enabled", "Yes" if self.shuffle else "No"],
        ]
self.logger.info("\nDataset Configuration:")
self.logger.info(
"\n"
+ tabulate(
dataset_info,
headers="firstrow",
tablefmt="psql",
colalign=("left", "right"),
maxcolwidths=[20, 10],
)
+ "\n"
)
37 changes: 16 additions & 21 deletions src/neural_mesh_simplification/data/dataset.py
@@ -75,29 +75,24 @@ def augment_mesh(mesh: trimesh.Trimesh) -> Trimesh | None:

     return mesh


 def mesh_to_tensor(mesh: trimesh.Trimesh) -> Data:
-    """Convert a mesh to tensor representation including graph structure."""
+    """
+    Convert a trimesh.Trimesh to a torch_geometric Data object.
+
+    Vertices are converted to float coordinates. Faces are converted
+    to long tensors in COO format. A valid edge_index is computed from
+    the mesh's unique edges.
+    """
     if mesh is None:
         return None
 
-    # Convert vertices and faces to tensors
-    vertices_tensor = torch.tensor(mesh.vertices, dtype=torch.float32)
-    faces_tensor = torch.tensor(mesh.faces, dtype=torch.long)
-
-    # Build graph structure
-    G = build_graph_from_mesh(mesh)
-
-    # Create edge index tensor
-    edge_index = torch.tensor(list(G.edges), dtype=torch.long).t().contiguous()
-
-    # Create Data object
-    data = Data(
-        x=vertices_tensor,
-        pos=vertices_tensor,
-        edge_index=edge_index,
-        face=faces_tensor.t(),
-        num_nodes=len(mesh.vertices),
-    )
+    vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
+    faces = torch.tensor(mesh.faces, dtype=torch.long)
+
+    # Extract the mesh's unique edges and convert to a numpy array first
+    edges = np.array(mesh.edges_unique)
+    # Convert to a tensor and transpose to shape [2, num_edges]
+    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
+
+    data = Data(
+        x=vertices,
+        pos=vertices.clone(),
+        edge_index=edge_index,
+        face=faces.t(),
+        num_nodes=len(mesh.vertices),
+    )
 
     return data
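
For context, a minimal sketch of how mesh_to_tensor might be exercised outside the dataset class; the mesh path is illustrative, not a file shipped with the repository:

import trimesh
from neural_mesh_simplification.data.dataset import mesh_to_tensor

mesh = trimesh.load("data/example.obj", force="mesh")  # illustrative path
data = mesh_to_tensor(mesh)
# pos is [num_vertices, 3], edge_index is [2, num_edges], face is [3, num_faces]
print(data.num_nodes, data.edge_index.shape, data.face.shape)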