Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dockers/ubuntu-cuda/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,5 @@ RUN \
pip list && \
python -c "import sys; ver = sys.version_info ; assert f'{ver.major}.{ver.minor}' == '$PYTHON_VERSION', ver" && \
python -c "import torch; print(f'PyTorch=={torch.__version__} with {torch.cuda.device_count()} GPUs')" && \
python -c "import nvfuser; print(f'nvFuser=={nvfuser.version()}')" && \
python -c "import nvfuser_direct as nvfuser; print(f'nvFuser=={nvfuser.version()}')" && \
python -c "import triton; print(f'Triton=={triton.__version__}')"
2 changes: 1 addition & 1 deletion docs/source/basic/inspecting_traces.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ This will print the following::
# cuda version: 12.1
# nvfuser version: 0.2.8
import torch
from nvfuser import FusionDefinition, DataType
from nvfuser_direct import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
Expand Down
13 changes: 1 addition & 12 deletions thunder/executors/nvfuserex.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,12 @@ def nvfuser_version() -> LooseVersion | None:
try:
import nvfuser_direct
except ImportError:
try:
import nvfuser
except ImportError:
pass
else:
if hasattr(nvfuser, "version"):
return LooseVersion(nvfuser.version())
else:
# NOTE: This import of nvFuser may or may not have version info
return LooseVersion("0.0.0")
return None
else:
if hasattr(nvfuser_direct, "version"):
return LooseVersion(nvfuser_direct.version())
else:
return LooseVersion("0.0.0")
# NOTE This occurs when nvFuser couldn't be imported
return None


def required_nvfuser_version() -> LooseVersion:
Expand Down
28 changes: 10 additions & 18 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,16 @@
# NOTE This impl file is here because nvFuser may not be available, so it's imported conditionally
# by nvfuserex.py when nvFuser is available.

DIRECT_BINDINGS_SUPPORTED_VERSION = LooseVersion("0.2.34")
DTENSOR_SUPPORTED_VERSION = LooseVersion("0.2.28")
if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION:
import nvfuser_direct as nvfuser
from nvfuser_direct import (
DataType,
FusionDefinition,
multidevice,
ParallelType,
execute_with_dtensors,
compute_tensor_descriptor as nv_compute_td,
)
else:
if nvfuser_version() >= DTENSOR_SUPPORTED_VERSION:
from nvfuser_direct import FusionDefinition as DirectFusionDefinition
from nvfuser_direct import multidevice, ParallelType, execute_with_dtensors
import nvfuser
from nvfuser import DataType, FusionDefinition, compute_tensor_descriptor as nv_compute_td
import nvfuser_direct as nvfuser
from nvfuser_direct import (
DataType,
FusionDefinition,
multidevice,
ParallelType,
execute_with_dtensors,
compute_tensor_descriptor as nv_compute_td,
)

#
# Helper functions
Expand Down Expand Up @@ -394,7 +386,7 @@ def check_dtensor_tracing_and_runtime_metadata(inp):
lambda: "nvfuser: Expected runtime and tracing metadata to be the same for DTensor.",
)

fd = FusionDefinition() if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION else DirectFusionDefinition()
fd = FusionDefinition()
# Device may be set in one of the "factory" methods like full, iota, or uniform
# NOTE: This should be called before defining because a factory method may look-up at `_selected_device` while being defined.
fd._selected_device = None
Expand Down
2 changes: 1 addition & 1 deletion thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,7 @@ def foo(x):
@pytest.mark.skip(reason="https://github.com/Lightning-AI/lightning-thunder/issues/2546")
@requiresCUDA
def test_WallTime_KernelTime():
from nvfuser import FusionDefinition, DataType
from nvfuser_direct import FusionDefinition, DataType

def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
Expand Down
Loading