Skip to content

Commit 57f400d

Browse files
committed
[F] add is_model_builder into Registry to fix the fail of function build
1 parent 8f63a54 commit 57f400d

File tree

3 files changed

+60
-42
lines changed

3 files changed

+60
-42
lines changed

chameleon/modules/backbones/gpunet.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
from typing import List, Optional
23

34
import torch
@@ -125,45 +126,45 @@ def _replace_padding(model):
125126
out_indices=out_indices,
126127
)
127128

128-
@classmethod
129-
def build_gpunet_0(cls, **kwargs):
130-
return cls.build_gpunet(name='gpunet_0', **kwargs)
129+
# @classmethod
130+
# def build_gpunet_0(cls, **kwargs):
131+
# return cls.build_gpunet(name='gpunet_0', **kwargs)
131132

132-
@classmethod
133-
def build_gpunet_1(cls, **kwargs):
134-
return cls.build_gpunet(name='gpunet_1', **kwargs)
133+
# @classmethod
134+
# def build_gpunet_1(cls, **kwargs):
135+
# return cls.build_gpunet(name='gpunet_1', **kwargs)
135136

136-
@classmethod
137-
def build_gpunet_2(cls, **kwargs):
138-
return cls.build_gpunet(name='gpunet_2', **kwargs)
137+
# @classmethod
138+
# def build_gpunet_2(cls, **kwargs):
139+
# return cls.build_gpunet(name='gpunet_2', **kwargs)
139140

140-
@classmethod
141-
def build_gpunet_p0(cls, **kwargs):
142-
return cls.build_gpunet(name='gpunet_p0', **kwargs)
141+
# @classmethod
142+
# def build_gpunet_p0(cls, **kwargs):
143+
# return cls.build_gpunet(name='gpunet_p0', **kwargs)
143144

144-
@classmethod
145-
def build_gpunet_p1(cls, **kwargs):
146-
return cls.build_gpunet(name='gpunet_p1', **kwargs)
145+
# @classmethod
146+
# def build_gpunet_p1(cls, **kwargs):
147+
# return cls.build_gpunet(name='gpunet_p1', **kwargs)
147148

148-
@classmethod
149-
def build_gpunet_d1(cls, **kwargs):
150-
return cls.build_gpunet(name='gpunet_d1', **kwargs)
149+
# @classmethod
150+
# def build_gpunet_d1(cls, **kwargs):
151+
# return cls.build_gpunet(name='gpunet_d1', **kwargs)
151152

152-
@classmethod
153-
def build_gpunet_d2(cls, **kwargs):
154-
return cls.build_gpunet(name='gpunet_d2', **kwargs)
153+
# @classmethod
154+
# def build_gpunet_d2(cls, **kwargs):
155+
# return cls.build_gpunet(name='gpunet_d2', **kwargs)
155156

156157

157158
GPUNETs = {
158-
'GPUNet_0': GPUNet.build_gpunet_0,
159-
'GPUNet_1': GPUNet.build_gpunet_1,
160-
'GPUNet_2': GPUNet.build_gpunet_2,
161-
'GPUNet_p0': GPUNet.build_gpunet_p0,
162-
'GPUNet_p1': GPUNet.build_gpunet_p1,
163-
'GPUNet_d1': GPUNet.build_gpunet_d1,
164-
'GPUNet_d2': GPUNet.build_gpunet_d2,
159+
'GPUNet_0': partial(GPUNet.build_gpunet, name='gpunet_0'),
160+
'GPUNet_1': partial(GPUNet.build_gpunet, name='gpunet_1'),
161+
'GPUNet_2': partial(GPUNet.build_gpunet, name='gpunet_2'),
162+
'GPUNet_p0': partial(GPUNet.build_gpunet, name='gpunet_p0'),
163+
'GPUNet_p1': partial(GPUNet.build_gpunet, name='gpunet_p1'),
164+
'GPUNet_d1': partial(GPUNet.build_gpunet, name='gpunet_d1'),
165+
'GPUNet_d2': partial(GPUNet.build_gpunet, name='gpunet_d2'),
165166
}
166167

167168

168169
for k, v in GPUNETs.items():
169-
BACKBONES.register_module(name=k, module=v)
170+
BACKBONES.register_module(name=k, module=v, is_model_builder=True)
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1+
from functools import partial
2+
13
import timm
24
import torch.nn as nn
35

46
from ...registry import BACKBONES
57

8+
models = timm.list_models()
69

7-
class Timm:
8-
@staticmethod
9-
def build_model(*args, **kwargs) -> nn.Module:
10-
return timm.create_model(*args, **kwargs)
11-
12-
13-
timm_models = timm.list_models()
14-
for name in timm_models:
15-
BACKBONES.register_module(f'timm_{name}', module=Timm.build_model)
10+
for name in models:
11+
BACKBONES.register_module(
12+
f'timm_{name}',
13+
module=partial(timm.create_model, model_name=name),
14+
is_model_builder=True,
15+
)

chameleon/registry/registry.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def build_from_cfg(cfg: dict, registry: "Registry") -> Any:
2121
kwargs = cfg.copy()
2222
name = kwargs.pop('name')
2323
obj_cls = registry.get(name)
24-
if inspect.isclass(obj_cls) or inspect.ismethod(obj_cls):
24+
is_model_builder = registry.is_model_builder(name)
25+
26+
if inspect.isclass(obj_cls) or is_model_builder:
2527
obj = obj_cls(**kwargs)
2628
else:
2729
obj = obj_cls
@@ -33,6 +35,7 @@ class Registry:
3335
def __init__(self, name: str):
3436
self._name = name
3537
self._module_dict: Dict[str, Type] = dict()
38+
self._type_dict: Dict[str, Type] = dict()
3639

3740
def __len__(self):
3841
return len(self._module_dict)
@@ -73,14 +76,26 @@ def get(self, key: str) -> Optional[Type]:
7376

7477
return obj_cls
7578

79+
def is_model_builder(self, key: str) -> bool:
80+
if not isinstance(key, str):
81+
raise TypeError(f'key must be a str, but got {type(key)}')
82+
83+
is_model_builder = self._type_dict.get(key, None)
84+
85+
if is_model_builder is None:
86+
raise KeyError(f'{key} is not in the {self.name} registry')
87+
88+
return is_model_builder
89+
7690
def build(self, cfg: dict) -> Any:
7791
return build_from_cfg(cfg, registry=self)
7892

7993
def _register_module(
8094
self,
8195
module: Type,
8296
module_name: Optional[Union[str, List[str]]] = None,
83-
force: bool = False
97+
force: bool = False,
98+
is_model_builder: bool = False,
8499
) -> None:
85100
if not callable(module):
86101
raise TypeError(f'module must be a callable, but got {type(module)}')
@@ -94,12 +109,14 @@ def _register_module(
94109
existed_module = self.module_dict[name]
95110
raise KeyError(f'{name} is already registered in {self.name} at {existed_module.__module__}')
96111
self._module_dict[name] = module
112+
self._type_dict[name] = is_model_builder
97113

98114
def register_module(
99115
self,
100116
name: str = None,
101117
force: bool = False,
102118
module: Optional[Type] = None,
119+
is_model_builder: bool = False,
103120
) -> Union[type, Callable]:
104121

105122
if not (name is None or isinstance(name, str)):
@@ -110,12 +127,12 @@ def register_module(
110127

111128
# use it as a normal method: x.register_module(module=SomeClass)
112129
if module is not None:
113-
self._register_module(module=module, module_name=name, force=force)
130+
self._register_module(module=module, module_name=name, force=force, is_model_builder=is_model_builder)
114131
return module
115132

116133
# use it as a decorator: @x.register_module()
117134
def _register(module):
118-
self._register_module(module=module, module_name=name, force=force)
135+
self._register_module(module=module, module_name=name, force=force, is_model_builder=is_model_builder)
119136
return module
120137

121138
return _register

0 commit comments

Comments
 (0)