Skip to content

Commit 438f643

Browse files
authored
feat: Fix perf issue in Map.get (#341)
Should fix #326
1 parent 639cc63 commit 438f643

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

python/tvm_ffi/_ffi_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import Any, Callable, TYPE_CHECKING
2424
if TYPE_CHECKING:
2525
from collections.abc import Mapping, Sequence
26-
from tvm_ffi import Module
26+
from tvm_ffi import Module, Object
2727
from tvm_ffi.access_path import AccessPath
2828
# isort: on
2929
# fmt: on
@@ -50,6 +50,8 @@ def Map(*args: Any) -> Any: ...
5050
def MapCount(_0: Mapping[Any, Any], _1: Any, /) -> int: ...
5151
def MapForwardIterFunctor(_0: Mapping[Any, Any], /) -> Callable[..., Any]: ...
5252
def MapGetItem(_0: Mapping[Any, Any], _1: Any, /) -> Any: ...
53+
def MapGetItemOrMissing(_0: Mapping[Any, Any], _1: Any, /) -> Any: ...
54+
def MapGetMissingObject() -> Object: ...
5355
def MapSize(_0: Mapping[Any, Any], /) -> int: ...
5456
def ModuleClearImports(_0: Module, /) -> None: ...
5557
def ModuleGetFunction(_0: Module, _1: str, _2: bool, /) -> Callable[..., Any] | None: ...
@@ -95,6 +97,8 @@ def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...
9597
"MapCount",
9698
"MapForwardIterFunctor",
9799
"MapGetItem",
100+
"MapGetItemOrMissing",
101+
"MapGetMissingObject",
98102
"MapSize",
99103
"ModuleClearImports",
100104
"ModuleGetFunction",

python/tvm_ffi/container.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777
V = TypeVar("V")
7878
_DefaultT = TypeVar("_DefaultT")
7979

80+
MISSING = _ffi_api.MapGetMissingObject()
81+
8082

8183
def getitem_helper(
8284
obj: Any,
@@ -254,12 +256,11 @@ def __contains__(self, item: object) -> bool:
254256
if not isinstance(item, tuple) or len(item) != 2:
255257
return False
256258
key, value = item
257-
try:
258-
existing_value = self._backend_map[key]
259-
except KeyError:
259+
actual_value = self._backend_map.get(key, MISSING)
260+
if actual_value is MISSING:
260261
return False
261-
else:
262-
return existing_value == value
262+
# TODO(@junrus): Is `__eq__` the right method to use here?
263+
return actual_value == value
263264

264265

265266
@register_object("ffi.Map")
@@ -349,10 +350,10 @@ def get(self, key: K, default: V | _DefaultT | None = None) -> V | _DefaultT | N
349350
The result value.
350351
351352
"""
352-
try:
353-
return self[key]
354-
except KeyError:
353+
ret = _ffi_api.MapGetItemOrMissing(self, key)
354+
if MISSING.same_as(ret):
355355
return default
356+
return ret
356357

357358
def __repr__(self) -> str:
358359
"""Return a string representation of the map."""

src/ffi/container.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ class MapForwardIterFunctor {
5555
ffi::MapObj::iterator end_;
5656
};
5757

58+
ObjectRef GetMissingObject() {
59+
static ObjectRef missing_obj(make_object<Object>());
60+
return missing_obj;
61+
}
62+
5863
TVM_FFI_STATIC_INIT_BLOCK() {
5964
namespace refl = tvm::ffi::reflection;
6065
refl::GlobalDef()
@@ -81,8 +86,17 @@ TVM_FFI_STATIC_INIT_BLOCK() {
8186
[](const ffi::MapObj* n, const Any& k) -> int64_t {
8287
return static_cast<int64_t>(n->count(k));
8388
})
84-
.def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) -> ffi::Function {
85-
return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end()));
89+
.def("ffi.MapForwardIterFunctor",
90+
[](const ffi::MapObj* n) -> ffi::Function {
91+
return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), n->end()));
92+
})
93+
.def("ffi.MapGetMissingObject", GetMissingObject)
94+
.def("ffi.MapGetItemOrMissing", [](const ffi::MapObj* n, const Any& k) -> Any {
95+
try {
96+
return n->at(k);
97+
} catch (const tvm::ffi::Error& e) {
98+
return GetMissingObject();
99+
}
86100
});
87101
}
88102
} // namespace ffi

0 commit comments

Comments
 (0)