mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[hops] Support unspecialized nn module for export hops (#164082)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164082 Approved by: https://github.com/tugsbayasgalan ghstack dependencies: #164079
This commit is contained in:
parent
1981ed4f60
commit
dc54ce7554
|
|
@ -1768,6 +1768,7 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
combine_fn_var,
|
||||
(
|
||||
variables.nn_module.NNModuleVariable,
|
||||
variables.nn_module.UnspecializedNNModuleVariable,
|
||||
variables.FunctoolsPartialVariable,
|
||||
),
|
||||
):
|
||||
|
|
@ -1776,7 +1777,13 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
f"or a graph module if we're re-exporting but got "
|
||||
f"{combine_fn.python_type()}. Please report an issue to PyTorch if you're seeing this."
|
||||
)
|
||||
return isinstance(combine_fn_var, variables.nn_module.NNModuleVariable)
|
||||
return isinstance(
|
||||
combine_fn_var,
|
||||
(
|
||||
variables.nn_module.NNModuleVariable,
|
||||
variables.nn_module.UnspecializedNNModuleVariable,
|
||||
),
|
||||
)
|
||||
|
||||
def arg_extractor(combine_fn, init, xs, additional_inputs):
|
||||
return combine_fn, init, xs, additional_inputs
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user