[hops] Support unspecialized nn module for export hops (#164082)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164082
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #164079
This commit is contained in:
Animesh Jain 2025-09-28 15:37:32 -07:00 committed by PyTorch MergeBot
parent 1981ed4f60
commit dc54ce7554

View File

@ -1768,6 +1768,7 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
combine_fn_var,
(
variables.nn_module.NNModuleVariable,
variables.nn_module.UnspecializedNNModuleVariable,
variables.FunctoolsPartialVariable,
),
):
@ -1776,7 +1777,13 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"or a graph module if we're re-exporting but got "
f"{combine_fn.python_type()}. Please report an issue to PyTorch if you're seeing this."
)
return isinstance(combine_fn_var, variables.nn_module.NNModuleVariable)
return isinstance(
combine_fn_var,
(
variables.nn_module.NNModuleVariable,
variables.nn_module.UnspecializedNNModuleVariable,
),
)
def arg_extractor(combine_fn, init, xs, additional_inputs):
return combine_fn, init, xs, additional_inputs