 import sys
 import warnings
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
     Dict,
+    Mapping,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +871,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")

     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
         device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)
@@ -884,7 +891,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name

         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +905,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         # Get kernel options for the device
@@ -909,7 +916,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repos = property_repos.repos
@@ -919,7 +926,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo_with_mode = _select_repository(
@@ -932,7 +939,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo, repo_mode = repo_with_mode
@@ -951,6 +958,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )

         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1045,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")

     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if: difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )

     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters

+    params: Mapping[str, inspect.Parameter]
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = params.copy()
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,15 +1094,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None


-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )

-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
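A small sketch of the resulting split, assuming a CPU model: the concrete `torch.device` feeds state creation, while the coarse type (with `cuda` refined to `rocm` on HIP builds of PyTorch) drives repository selection. Both lines below are hypothetical inlined mirrors of the two helpers:

import torch
from torch import nn

model = nn.Linear(4, 4)  # parameters live on the CPU here

device = next(model.parameters()).device  # what _find_device returns
dev_type = "rocm" if device.type == "cuda" and torch.version.hip else device.type
print(device, dev_type)  # cpu cpu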
@@ -1103,6 +1129,7 @@ def _find_capability() -> int:

 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1155,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)

-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)  # type: ignore[attr-defined]
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]

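The stateful branch binds a fresh closure per module, so `state` is created once on the right device and captured for every subsequent call, while `self` remains the original module. A self-contained sketch of the same pattern, with a hypothetical layer standing in for a kernel layer:

from types import MethodType

import torch
from torch import nn

class ScaledReLU(nn.Module):
    # Hypothetical stateful layer used only for this sketch.

    @staticmethod
    def create_state(device: torch.device, module: nn.Module) -> torch.Tensor:
        return torch.tensor(2.0, device=device)

    def forward_with_state(self, state, x):
        return torch.relu(x) * state

module = nn.ReLU()
state = ScaledReLU.create_state(torch.device("cpu"), module)

def forward(self, *args, **kwargs):
    # `state` is captured by the closure; `self` is still the wrapped module.
    return ScaledReLU.forward_with_state(self, state, *args, **kwargs)

module.forward = MethodType(forward, module)
print(module(torch.tensor([-1.0, 3.0])))  # tensor([0., 6.])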

 def _validate_layer_has_mode(
@@ -1179,3 +1216,21 @@ def _get_layer_memoize(
     _CACHED_LAYER[repo] = layer

     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
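To see what the member-set difference yields in practice, a small sketch with two hypothetical layers; note that overriding `forward` does not show up in the difference, because that name already exists on `nn.Module`:

import inspect

from torch import nn

def unique_layer_members(layer):
    # Same set difference as `_unique_layer_members` above.
    torch_members = {name for name, _ in inspect.getmembers(nn.Module)}
    return {name for name, _ in inspect.getmembers(layer)} - torch_members

class Stateless(nn.Module):
    def forward(self, x):
        return x

class Stateful(nn.Module):
    @staticmethod
    def create_state(device, module):
        return None

    def forward_with_state(self, state, x):
        return x

print(unique_layer_members(Stateless))  # set()
print(unique_layer_members(Stateful))   # {'create_state', 'forward_with_state'} (set order varies)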