[FEA] Slow unbucketize permute operation in SequenceEmbeddingsAllToAll for row-wise sharding #296

@z52527

Description

Background

Currently, DynamicEmb has a custom input_dist implementation (RwSparseFeaturesDist in input_dist.py) but still relies on TorchRec's original output_dist implementation. This causes:

  1. Performance issue: the unbucketize_permute operation in TorchRec's output distribution is slow, especially for non-contiguous distribution patterns such as round-robin (a sketch of what this permute computes follows this list)
  2. Limited customization: Cannot optimize the output distribution without modifying TorchRec source code
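
For context, here is a minimal PyTorch sketch of the semantics involved: after row-wise bucketization, embeddings come back from the all-to-all in bucket (rank) order, and the unbucketize permute gathers them back into the original input order. The shapes and the round-robin construction below are illustrative assumptions, not DynamicEmb's actual data layout.

```python
import torch

# Illustrative only: shapes and the permute construction are assumptions.
# Embeddings arrive from SequenceEmbeddingsAllToAll in bucket (rank) order.
num_ids, dim, world_size = 6, 4, 2
emb_bucket_order = torch.randn(num_ids, dim)

# Round-robin bucketing (id i -> rank i % world_size) gives a maximally
# non-contiguous mapping from bucket order back to original order.
unbucketize_permute = (
    torch.arange(num_ids).reshape(world_size, num_ids // world_size).t().reshape(-1)
)  # [0, 3, 1, 4, 2, 5] for 6 ids on 2 ranks

# The slow step: a row gather that restores the original input order.
emb_original_order = emb_bucket_order.index_select(0, unbucketize_permute)
```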

Objective

Port TorchRec's output distribution classes into the DynamicEmb library, enabling future performance optimizations without modifying TorchRec source code.

Tasks

PR 1: Port output distribution classes to DynamicEmb

  • Create dynamicemb/output_dist.py with:
    • RwSequenceEmbeddingDist
    • RwPooledEmbeddingDist
  • Update dynamicemb/planner/rw_sharding.py to override the create_output_dist() methods (a hypothetical sketch follows this list)
  • Verify with existing tests (test_sequence_embedding_fw.py, test_pooled_embedding_fw.py)
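
A minimal sketch of the intended override, with heavy assumptions: the constructor arguments, attribute names, and class bodies below are placeholders, not the real TorchRec or DynamicEmb interfaces.

```python
import torch

# Hypothetical sketch of the PR 1 wiring; all names and signatures here are
# placeholder assumptions, not the actual TorchRec/DynamicEmb interfaces.

class RwSequenceEmbeddingDist(torch.nn.Module):
    """Ported copy of TorchRec's row-wise sequence output dist.

    Initially behavior-identical to TorchRec's; PR 2's kernel plugs in here.
    """

    def __init__(self, dim_sum_per_rank):
        super().__init__()
        self._dim_sum_per_rank = dim_sum_per_rank


class RwSequenceEmbeddingSharding:
    """Stand-in for the sharding class in dynamicemb/planner/rw_sharding.py."""

    def __init__(self, dim_sum_per_rank):
        self._dim_sum_per_rank = dim_sum_per_rank

    def create_output_dist(self, device=None):
        # Return DynamicEmb's ported class instead of TorchRec's, so later
        # optimizations do not require patching TorchRec source.
        return RwSequenceEmbeddingDist(self._dim_sum_per_rank)
```

The same pattern would apply to RwPooledEmbeddingDist on the pooled path.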

PR 2: Optimize unbucketize permute with custom kernel

  • Design optimized data format for permute tensor
  • Implement CUDA kernel for efficient unbucketize operation
  • Integrate with output_dist.py
  • Benchmark and validate the performance improvement against an eager reference (see the sketch below)
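
One way to validate PR 2, sketched under assumptions: the custom CUDA kernel should produce exactly the same output as a plain eager gather, so an eager reference like the one below (function name and sizes are hypothetical) can serve as both correctness oracle and benchmark baseline.

```python
import torch

def unbucketize_reference(bucketed_emb: torch.Tensor,
                          permute: torch.Tensor) -> torch.Tensor:
    """Eager reference the custom kernel must match.

    bucketed_emb: [num_ids, dim] embeddings in bucket (rank) order.
    permute: [num_ids] mapping from original position to bucket-order row.
    """
    return bucketed_emb.index_select(0, permute)

# Round-robin is the worst case called out above: the gather reads rows
# strided across the whole tensor. Sizes below are illustrative.
if torch.cuda.is_available():
    n, d, world = 1_000_000, 128, 8
    emb = torch.randn(n, d, device="cuda")
    permute = (
        torch.arange(n, device="cuda").reshape(world, n // world).t().reshape(-1)
    )
    torch.cuda.synchronize()
    out = unbucketize_reference(emb, permute)
    torch.cuda.synchronize()
```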
