Use 'is' in callable comparisons (#166624)
Just like we use `is/is not` for class comparisons, it is generally advised to use `is/is not` for comparisons against torch functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166624
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007
This commit is contained in:
parent 639a0b1239
commit 694db5f549
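The change below is mechanical and repeats a single idiom: wherever an FX node's `target` (or a plain callable such as `dist.all_gather`) is compared against a specific torch function, operator, or OpOverload, `==`/`!=` is replaced with `is`/`is not`. A minimal sketch of the idiom outside the PyTorch sources follows; the helper `count_relu_calls` and `ToyModule` are illustrative only and are not part of this commit. Because `node.target` holds the callable object itself, an identity check is sufficient and sidesteps any overloaded `__eq__`.

import torch
import torch.fx


def count_relu_calls(gm: torch.fx.GraphModule) -> int:
    # node.target stores the callable object itself (torch.relu, operator.getitem,
    # an OpOverload, ...), so an identity check with `is` is enough and avoids
    # dispatching to a potentially overloaded __eq__.
    return sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target is torch.relu
    )


class ToyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


print(count_relu_calls(torch.fx.symbolic_trace(ToyModule())))  # prints 1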
@@ -369,7 +369,7 @@ def relu_compile_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             raise ReluCompileError
     return gm
 
@@ -379,7 +379,7 @@ def relu_runtime_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             node.target = torch._assert
             node.args = (False, "ReluRuntimeError")
     gm.recompile()
@@ -391,7 +391,7 @@ def relu_accuracy_error_TESTING_ONLY(
     gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
 ) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
-        if node.target == torch.relu:
+        if node.target is torch.relu:
             node.target = torch.add
             node.args = (node.args[0], 1)
     gm.recompile()
@@ -573,7 +573,7 @@ class AutogradCompilerInstance:
                 result.name = make_unique(node.name)
                 value_remap[node] = result
             elif node.op == "call_function":
-                if node.target == torch.ops.aten.view.default:
+                if node.target is torch.ops.aten.view.default:
                     # this aot bwd graph is being lazily compiled
                     # we must manually apply the view_to_reshape post grad pass
                     # since it was already applied to the aot fwd, and baked into the gradients
@@ -1241,7 +1241,7 @@ class AutogradCompilerInstance:
             to_append = []
             hook_block = [node]  # contain the hook and hook args getitem
             for n in input_nodes:
-                if n.op == "call_function" and n.target == operator.getitem:
+                if n.op == "call_function" and n.target is operator.getitem:
                     to_append.append(n.args[0])
                     to_remove.append(n)
                     hook_block.append(n)
@@ -1314,7 +1314,7 @@ class AutogradCompilerInstance:
         # find the corresponding acc_grad node
         acc_grad_node = None
         for n in list(param_node.users.keys()):
-            if n.op == "call_function" and n.target == call_accumulate_grad:
+            if n.op == "call_function" and n.target is call_accumulate_grad:
                 acc_grad_node = n
                 break
 
@@ -1369,7 +1369,7 @@ class AutogradCompilerInstance:
         for n in list(param_node.users.keys()):
             if (
                 n.op == "call_function"
-                and n.target == call_hook
+                and n.target is call_hook
                 and n.kwargs.get("hook_type", None) == "post_acc_grad_hook"
             ):
                 post_acc_grad_hook_node = n
@@ -455,7 +455,7 @@ def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule
     for node in model.graph.nodes:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.prims.convert_element_type.default
+            and node.target is torch.ops.prims.convert_element_type.default
         ):
             assert len(node.args) == 2
             if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
@@ -2930,12 +2930,12 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
         hasattr(proxy.node.target, "__name__")
         and proxy.node.target.__name__ == "set_state"
         and isinstance(proxy.node.target.__self__, torch._C.Generator)
-        or proxy.node.target == torch.random.set_rng_state
+        or proxy.node.target is torch.random.set_rng_state
     ):
         return TorchInGraphFunctionVariable(proxy.node.target)
     elif (
-        proxy.node.target == torch._C._DisableFuncTorch
-        or proxy.node.target == torch.cuda._is_in_bad_fork
+        proxy.node.target is torch._C._DisableFuncTorch
+        or proxy.node.target is torch.cuda._is_in_bad_fork
     ):
         return UserDefinedObjectVariable(example_value)
     elif istype(example_value, torch.Size) and all(
@@ -2595,7 +2595,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
         from torch.utils.checkpoint import noop_context_fn
 
         context_fn = None
-        if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn:
+        if "context_fn" in kwargs and kwargs["context_fn"] is not noop_context_fn:
             ctx = kwargs.pop("context_fn")
             if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable):
                 context_fn = ctx.fn
@@ -48,7 +48,7 @@ class CollectTracepointsPass(PassBase):
         for node in module.graph.nodes:
             if node.op != "call_function":
                 continue
-            if node.target == torch.ops.higher_order._export_tracepoint:
+            if node.target is torch.ops.higher_order._export_tracepoint:
                 kind = node.kwargs["kind"]
                 if kind == "module_call_outputs":
                     nn_module_stack = node.meta["nn_module_stack"]
@@ -64,7 +64,7 @@ class CollectTracepointsPass(PassBase):
         for node in reversed(module.graph.nodes):
             if node.op != "call_function":
                 continue
-            if node.target == torch.ops.higher_order._export_tracepoint:
+            if node.target is torch.ops.higher_order._export_tracepoint:
                 kind = node.kwargs["kind"]
                 if kind == "module_call_inputs":
                     nn_module_stack = node.meta["nn_module_stack"]
@@ -94,7 +94,7 @@ class CollectTracepointsPass(PassBase):
         for node in module.graph.nodes:
             if node.op != "call_function":
                 continue
-            if node.target == torch.ops.higher_order._export_tracepoint:
+            if node.target is torch.ops.higher_order._export_tracepoint:
                 # There's some subtlety worth noting. Here fqn corresponds to
                 # the call name, whereas path corresponds to the module name.
                 # They are not necessarily the same! When a submodule is shared
@@ -65,7 +65,7 @@ class ConstantFolder(torch.fx.Interpreter):
 
     def is_impure(self, node: torch.fx.Node) -> bool:
         if (
-            node.target == torch.ops.prims.convert_element_type.default
+            node.target is torch.ops.prims.convert_element_type.default
             and node.args[0].op == "get_attr"  # type: ignore[union-attr]
             and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
             and node.args[1] == torch.bfloat16
@@ -34,7 +34,7 @@ def _is_enter_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]:
     return (
         node
        and node.op == "call_function"
-        and node.target == torch.amp.autocast_mode._enter_autocast
+        and node.target is torch.amp.autocast_mode._enter_autocast
     )
 
 
@@ -42,7 +42,7 @@ def _is_exit_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]:
     return (
         node
         and node.op == "call_function"
-        and node.target == torch.amp.autocast_mode._exit_autocast
+        and node.target is torch.amp.autocast_mode._exit_autocast
     )
 
 
@@ -22,7 +22,7 @@ def _is_set_grad_enabled_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]
     return (
         node
         and node.op == "call_function"
-        and node.target == torch._C._set_grad_enabled
+        and node.target is torch._C._set_grad_enabled
     )
 
 
@@ -88,7 +88,7 @@ def _staged_schema():
                 f"std::optional<{cpp_type}>",
                 f"optional {thrift_type}",
             )
        elif o == Annotated:
-        elif o == Annotated:
+        elif o is Annotated:
            return dump_type(t.__origin__, level)
         else:
             raise AssertionError(f"Type {t} is not supported in export schema.")
@@ -249,7 +249,7 @@ def maybe_to_fresh_input(idx, t, meta):
 def is_with_effects(node):
     return (
         node.op == "call_function"
-        and node.target == torch.ops.higher_order.with_effects
+        and node.target is torch.ops.higher_order.with_effects
     )
 
 
@@ -1256,7 +1256,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
     # Build the graph op-by-op by starting from the node all the way to the end
     # copy_ can be not using tangents at all, we must copy it.
     for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]:
-        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default:
+        if node.op == "call_function" and node.target is torch.ops.aten.copy_.default:
             insert_node_in_graph(node)
 
     for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
@@ -1596,7 +1596,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
         if node.op == "output":
             continue
 
-        is_copy_ = node.target == torch.ops.aten.copy_.default
+        is_copy_ = node.target is torch.ops.aten.copy_.default
         if is_copy_:
             if _has_tag_must_be_in_backward(node):
                 has_mutation_in_bw.add(node.args[0])
@@ -1393,7 +1393,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
     for idx, node in enumerate(node_list):
         if (
             node.op == "call_function"
-            and node.target == torch.ops.inductor.resize_storage_bytes_.default
+            and node.target is torch.ops.inductor.resize_storage_bytes_.default
         ):
             assert node.args[0].op == "placeholder", f"""\
 Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
@@ -1444,7 +1444,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that
     # Find all eligible unsharded params and their corresponding graph intermediates.
     unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list)
     for idx, node in enumerate(node_list):
-        if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
+        if node.op == "call_function" and node.target is torch.ops.fsdp.copy_.default:
             fsdp_copy_node = node
             unsharded_param = node.args[0]
             assert unsharded_param.op == "placeholder", f"""
@@ -1456,8 +1456,8 @@ Offending node: {unsharded_param}. Graph: {graph}
 
     def is_allowed_mutation(node):
         return (
-            node.target == torch.ops.fsdp.copy_.default
-            or node.target == torch.ops.inductor.resize_storage_bytes_.default
+            node.target is torch.ops.fsdp.copy_.default
+            or node.target is torch.ops.inductor.resize_storage_bytes_.default
         )
 
     def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params):
@@ -1558,7 +1558,7 @@ Graph: {graph}
     for node in node_list:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.inductor.resize_storage_bytes_.default
+            and node.target is torch.ops.inductor.resize_storage_bytes_.default
             and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes
         ):
             graph.erase_node(node)
@@ -122,7 +122,7 @@ class ConstantFolder(torch.fx.Interpreter):
     def is_impure(self, node: torch.fx.node.Node) -> bool:
        def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
             return (
-                node.target == torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
+                node.target is torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
                 and isinstance(node.args[0], torch.fx.Node)
                 and "val" in node.args[0].meta
                 and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
@@ -132,7 +132,7 @@ class ConstantFolder(torch.fx.Interpreter):
         if (
             is_woq_int8_pattern(node)
             or (
-                node.target == torch.ops.aten.permute.default
+                node.target is torch.ops.aten.permute.default
                 and len(node.users) == 1
                 and is_woq_int8_pattern(next(iter(node.users)))
             )
@@ -591,7 +591,7 @@ def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) ->
     )
 
     def is_mm(node: torch.fx.Node) -> bool:
-        return node.target == torch.ops.aten.mm.default
+        return node.target is torch.ops.aten.mm.default
 
     # the inner MM
     inner_mm = match.nodes[-1]
@@ -122,28 +122,28 @@ def bucket_reduce_scatter(
 def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+        and node.target is torch.ops._c10d_functional.all_gather_into_tensor.default
     )
 
 
 def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default
+        and node.target is torch.ops._c10d_functional.reduce_scatter_tensor.default
     )
 
 
 def is_wait_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.wait_tensor.default
+        and node.target is torch.ops._c10d_functional.wait_tensor.default
     )
 
 
 def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.all_reduce.default
+        and node.target is torch.ops._c10d_functional.all_reduce.default
     )
 
 
@@ -315,7 +315,7 @@ class UniformValueConstantFolder(ConstantFolder):
         # single-elem attrs
         if node.op == "get_attr" or (
             node.op == "call_function"
-            and node.target == torch.ops.aten.lift_fresh_copy.default
+            and node.target is torch.ops.aten.lift_fresh_copy.default
         ):
             out = super(ConstantFolder, self).run_node(node)
             if isinstance(out, torch.Tensor) and out.numel() == 1:
@@ -737,7 +737,7 @@ def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             input_node = node.kwargs["input"]
             if (
                 input_node.op == "call_function"
-                and input_node.target == torch.nn.functional.linear
+                and input_node.target is torch.nn.functional.linear
             ):
                 normalized = NormalizedLinearNode(input_node)
                 input = normalized.get_input()
@@ -679,7 +679,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
             if copy_node is not None:
                 replace_dict[copy_node] = copy_node.args[0]
             node.target = inplaceable_op.inplace_op
-        elif node.target == torch.ops.higher_order.auto_functionalized_v2:
+        elif node.target is torch.ops.higher_order.auto_functionalized_v2:
             _mutable_op = node.args[0]
             kwargs = node.kwargs
 
@@ -696,7 +696,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
                 # auto_functionalized into clones + a mutable op; this metadata
                 # tells the decomp to only clone the following inputs
                 node.meta["only_clone_these_tensors"] = new_bases_to_clone
-        elif node.target == torch.ops.higher_order.auto_functionalized:
+        elif node.target is torch.ops.higher_order.auto_functionalized:
             _mutable_op = node.args[0]
             from torch._higher_order_ops.auto_functionalize import get_mutable_args
 
@@ -127,12 +127,12 @@ def _get_dim(node: Any):
     if "dim" in node.kwargs:
         assert isinstance(node.kwargs["dim"], int)
         return node.kwargs["dim"]
-    if node.target == torch.unbind:
+    if node.target is torch.unbind:
         if len(node.args) == 2:
             assert isinstance(node.args[-1], int)
             return node.args[-1]
         return 0  # defaults to dim=0
-    if node.target == torch.split:
+    if node.target is torch.split:
         if len(node.args) == 3:
             assert isinstance(node.args[-1], int)
             return node.args[-1]
@@ -351,7 +351,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
         cat_node.args == new_args
         and cat_node.kwargs == new_kwargs
         and cat_node.op == "call_function"
-        and cat_node.target == torch.cat
+        and cat_node.target is torch.cat
     ):
         return
 
@@ -866,7 +866,7 @@ class SplitCatSimplifier:
             cat_dim = get_arg_value(user_node, 1, "dim")
             transform_params: list[_TransformParam] = []
             for user_input in user_inputs:
-                if split_dim == cat_dim and user_node.target == torch.cat:
+                if split_dim == cat_dim and user_node.target is torch.cat:
                     # No transform needed
                     transform_params.append((None, None, None, None))
                 elif isinstance(user_input, tuple):  # Split being simplified
@@ -888,7 +888,7 @@ class SplitCatSimplifier:
                         (unflatten_params, movedim_params, None, None)
                     )
                 elif (
-                    user_node.target == torch.stack or split_dim != cat_dim
+                    user_node.target is torch.stack or split_dim != cat_dim
                 ):  # We need to unsqueeze inputs not coming through split
                     transform_params.append((None, None, (cat_dim,), None))
                 else:  # Non-split inputs
@@ -1107,9 +1107,9 @@ class SplitCatSimplifier:
             )
 
             if (
-                user_node.target == torch.cat
+                user_node.target is torch.cat
                 and split_dim != cat_dim
-                and split_node.target == torch.split
+                and split_node.target is torch.split
             ):
                 with graph.inserting_after(new_cat_node):
                     new_cat_node_meta = new_cat_node.meta["example_value"]
@@ -1225,13 +1225,13 @@ class UnbindCatRemover(SplitCatSimplifier):
                 (split_dim, cat_dim) if split_dim != cat_dim else None
             )
             flatten_params = None
-            if user_node.target == torch.cat:
+            if user_node.target is torch.cat:
                 flatten_params = (cat_dim, cat_dim + 1)
                 transform_params.append(
                     (None, movedim_params, None, flatten_params)
                 )
             elif (
-                user_node.target == torch.stack
+                user_node.target is torch.stack
             ):  # We need to unsqueeze inputs not coming through unbind into cat
                 transform_params.append((None, None, (cat_dim,), None))
             else:  # Non-unbind inputs
@@ -1298,13 +1298,13 @@ def merge_split_squeeze(
     match: Match, split_input: torch.fx.Node, split_sizes: list[int], dim: int
 ):
     graph = match.graph
-    split = next(node for node in match.nodes if node.target == torch.split)
+    split = next(node for node in match.nodes if node.target is torch.split)
     if not all(s == 1 for s in split_sizes):
         return
     if isinstance(dim, Sequence):
         return
     next_users = find_next_users(split)
-    if not all(node.target == torch.squeeze for node in next_users):
+    if not all(node.target is torch.squeeze for node in next_users):
         return
     with graph.inserting_before(match.output_node()):
         unbind = graph.call_function(
@@ -1364,7 +1364,7 @@ getitem_unbind = ListOf(
     pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"),
 )
 def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     UnbindCatRemover().remove_unbind(match.graph, unbind_node)
 
 
@@ -1431,7 +1431,7 @@ reshape_getitem_split = ListOf(
 def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     # pyrefly: ignore [bad-argument-type]
     SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
 
@@ -1518,7 +1518,7 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
     graph = match.graph
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_input, _split_size, split_dim = _get_split_args_default(split_node)
     # if the cat and split have different dims, return
     # Find the next users (i.e. users after the getitem)
@@ -1625,7 +1625,7 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
     graph = match.graph
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     _split_input, _split_size, split_dim = _get_split_args_default(split_node)
     # if the cat and split have different dims, return
     # Find the next users (i.e. users after the getitem)
@@ -1904,7 +1904,7 @@ def merge_select_cat_aten(match: Match, *args, **kwargs):
         # get the select nodes from the node
         select_nodes = list(node_input.users.keys())
         for cat_node in list(node.users.keys()):
-            if cat_node.target == torch.ops.aten.cat.default:
+            if cat_node.target is torch.ops.aten.cat.default:
                 cat_dim = get_arg_value(cat_node, 1, "dim")
                 cat_inputs = get_arg_value(cat_node, 0, "tensors")
                 # check all select nodes has same slice dim
@@ -2020,7 +2020,7 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs):
     parent_of_select_node = get_arg_value(select_nodes[0], 0, "input")
     # check the target of select_nodes are the same
     if not all(
-        select_node.target == torch.ops.aten.select for select_node in select_nodes
+        select_node.target is torch.ops.aten.select for select_node in select_nodes
     ):
         return
     # check the select nodes come from the same parent node
@@ -2357,7 +2357,7 @@ def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.No
 def split_cat_to_slices(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_nodes = [node for node in match.nodes if node.target == torch.split]
+    split_nodes = [node for node in match.nodes if node.target is torch.split]
     if split_nodes:
         split_node = next(node for node in split_nodes)
     else:
@@ -2438,7 +2438,7 @@ def split_cat_to_slices(match: Match, split_sections: list[int], dim: int):
     pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"),
 )
 def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     graph = match.graph
     # get the cat_node and check its inputs and meta data
     next_users = find_next_users(unbind_node)
@@ -2614,7 +2614,7 @@ def convert_reshape_cat_arg_to_stack(
 def split_stack_to_cats(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_dim = get_arg_value(split_node, 2, "dim") or 0
     graph = match.graph
     threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
@@ -2685,7 +2685,7 @@ def split_stack_to_cats(match: Match, split_sections: list[int], dim: int):
     pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"),
 )
 def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int):
-    unbind_node = next(node for node in match.nodes if node.target == torch.unbind)
+    unbind_node = next(node for node in match.nodes if node.target is torch.unbind)
     graph = match.graph
     # get the cat_node and check its inputs and meta data
     next_users = find_next_users(unbind_node)
@@ -2785,10 +2785,10 @@ def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]:
     pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"),
 )
 def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
-    split_node = next(node for node in match.nodes if node.target == torch.split)
+    split_node = next(node for node in match.nodes if node.target is torch.split)
     split_dim = _get_dim(split_node)
     split_users = list(split_node.users.keys())
-    stack_nodes = [node for node in match.nodes if node.target == torch.stack]
+    stack_nodes = [node for node in match.nodes if node.target is torch.stack]
     graph = match.graph
     threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[
         "move_reshape_out_of_split_stack_pass"
@@ -2926,7 +2926,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
     split_node = next(
         node
        for node in match.nodes
-        if node.target == torch.ops.aten.split_with_sizes.default
+        if node.target is torch.ops.aten.split_with_sizes.default
     )
     split_input, split_section, split_dim = _get_split_args_default(split_node)
     split_users = list(split_node.users.keys())
@@ -2936,7 +2936,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
     if not is_sorted_and_consecutive(getitem_indices):  # type: ignore[arg-type]
         return
     cat_nodes = [
-        node for node in match.nodes if node.target == torch.ops.aten.cat.default
+        node for node in match.nodes if node.target is torch.ops.aten.cat.default
     ]
     graph = match.graph
     for cat_node in cat_nodes:
@@ -2950,7 +2950,7 @@ def move_view_after_cat(match: Match, *args, **kwargs):
             continue
         # check if the cat inputs are all the view nodes
         if not all(
-            view_node.target == torch.ops.aten.reshape.default
+            view_node.target is torch.ops.aten.reshape.default
             for view_node in cat_inputs
         ):
             continue
@@ -5849,7 +5849,7 @@ class ExternKernel(InputsKernel):
         if shape_env := V.fake_mode.shape_env:
             node_meta_val = V.current_node.meta.get("val")
             ctx: AbstractContextManager[None] = nullcontext()
-            if V.current_node.target == torch._higher_order_ops.effects.with_effects:
+            if V.current_node.target is torch._higher_order_ops.effects.with_effects:
                 # remove the first effect token in meta["val"] and meta["unbacked_bindings"]
                 node_meta_val = node_meta_val[1]
                 ctx = _remove_effect_token_unbacked_bindings(V.current_node)
@@ -677,7 +677,7 @@ def _convolution(
 
 
 def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
-    assert fx_node.target == torch.ops.aten.convolution.default
+    assert fx_node.target is torch.ops.aten.convolution.default
     if V.graph.layout_opt:
         return args, kwargs
     else:
@@ -478,7 +478,7 @@ def validate_joint_graph(joint_graph: torch.fx.Graph):
     for node in joint_graph.nodes:
         if (
             node.op == "call_function"
-            and node.target == torch.ops.flex_lib.zeros_and_scatter.default
+            and node.target is torch.ops.flex_lib.zeros_and_scatter.default
         ):
             for user in node.users:
                 if user.op != "output":
@@ -3922,7 +3922,7 @@ def index_put_impl_(self, indices, values, accumulate, check, may_realize=False)
                 isinstance(indice, ir.StorageBox)
                 and isinstance(indice.data, ir.ExternKernel)
                 and getattr(indice.data, "fx_node", None)
-                and indice.data.fx_node.target == torch.ops.aten.randperm.default
+                and indice.data.fx_node.target is torch.ops.aten.randperm.default
             )
         return False
 
@@ -437,7 +437,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
             subclass_type = type(arg)
             if (
                 subclass_type.__torch_dispatch__
-                == torch._C._disabled_torch_dispatch_impl
+                is torch._C._disabled_torch_dispatch_impl
             ):
                 continue
 
@@ -1050,7 +1050,7 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
         raise AssertionError(f"Expected exactly 1, got {len(shadow_n.users)}")
     quant_node = next(iter(shadow_n.users.keys()))
     new_args: Any = None
-    if quant_node.target == torch.quantize_per_channel:
+    if quant_node.target is torch.quantize_per_channel:
         _weight, scale_node, zp_node, axis, dtype = quant_node.args
         scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
         zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
@@ -204,7 +204,7 @@ def get_node_input_qparams(
 
     if prev_node.op == "call_function":
         # quantize - read the args directly
-        if prev_node.target == torch.quantize_per_tensor:
+        if prev_node.target is torch.quantize_per_tensor:
             return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
         elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
             return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)
@@ -1181,7 +1181,7 @@ def special_pattern_replacement(model: GraphModule):
     modules = dict(model.named_modules(remove_duplicate=False))
     for n in model.graph.nodes:
         q_node = n
-        is_quantize = q_node.target == torch.quantize_per_tensor
+        is_quantize = q_node.target is torch.quantize_per_tensor
         is_to_fp16 = (
             q_node.op == "call_method"
             and q_node.target == "to"
@@ -139,10 +139,10 @@ def _get_lstm_with_individually_observed_parts(
         mul_count = 0
         for node in cell.graph.nodes:
             op_index: Optional[tuple[Callable, int]] = None  # e.g. (torch.add, 1)
-            if node.target == torch.add:
+            if node.target is torch.add:
                 op_index = (torch.add, add_count)
                 add_count += 1
-            elif node.target == torch.mul:
+            elif node.target is torch.mul:
                 op_index = (torch.mul, mul_count)
                 mul_count += 1
             else:
@@ -205,7 +205,7 @@ def _get_reference_quantized_lstm_module(
         # on custom module input/output dtypes, and (2) expand support for complex
        # input/output structures.
         for node in cell.graph.nodes:
-            if node.target == torch.quantize_per_tensor:
+            if node.target is torch.quantize_per_tensor:
                 arg = node.args[0]
                 # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1])
                 if arg.target == "x" or (
@@ -455,9 +455,9 @@ def _maybe_insert_input_observers_for_node(
         # gelu has a has an approximate kwarg that persist in exported graph.
         # This is just a work around for these.
         if not (
-            node.target == torch.ops.aten.clone.default
-            or node.target == torch.ops.aten.zeros_like.default
-            or node.target == torch.ops.aten.gelu.default
+            node.target is torch.ops.aten.clone.default
+            or node.target is torch.ops.aten.zeros_like.default
+            or node.target is torch.ops.aten.gelu.default
             or len(node.kwargs) == 0
         ):
            raise AssertionError(" expecting kwargs for aten op IR to be empty")
@@ -939,7 +939,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
     # remove in place add from batchnorm tracking training stats
     for node in m.graph.nodes:
         if (
-            node.target == torch.ops.aten.add_.Tensor
+            node.target is torch.ops.aten.add_.Tensor
            and node.args[0].op == "get_attr"
             and node.args[1] == 1
             and (
@@ -97,10 +97,10 @@ def _find_q_dq_node_for_user(
 def _is_sym_size_node(node: Node):
     return (
         node.op == "call_function"
-        and node.target == torch.ops.aten.sym_size.default
-        or node.target == torch.ops.aten.sym_numel.default
-        or node.target == torch.ops.aten.sym_numel
-        or node.target == torch.ops.aten.sym_size
+        and node.target is torch.ops.aten.sym_size.default
+        or node.target is torch.ops.aten.sym_numel.default
+        or node.target is torch.ops.aten.sym_numel
+        or node.target is torch.ops.aten.sym_size
     )
 
 
@@ -228,7 +228,7 @@ def fold_bn_weights_into_conv_node(
     bn_b = _get_tensor_constant_from_node(bn_args[2], m)
     bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
     bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
-    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
+    if bn_node.target is torch.ops.aten._native_batch_norm_legit_no_training.default:
         eps_arg_index = 6
     elif _is_supported_batch_norm_for_training(bn_node):
         eps_arg_index = 7
@@ -268,7 +268,7 @@ def fold_bn_weights_into_conv_node(
     # native_batch_norm has 3 outputs, we expect getitem calls on the output
     # and we want to replace the uses of getitem 0 with the output of conv
     #
-    if bn_node.target == torch.ops.aten.batch_norm.default:
+    if bn_node.target is torch.ops.aten.batch_norm.default:
         # With the new training ir, instead of batch_norm + getitem,
         # we only have the batch_norm node.
         #
@@ -377,7 +377,7 @@ def _get_aten_graph_module_for_pattern(
     for node in aten_pattern.graph.nodes:  # type: ignore[union-attr]
         if (
             node.op == "call_function"
-            and node.target == torch.ops.aten.copy_.default
+            and node.target is torch.ops.aten.copy_.default
             and len(node.users) == 0
         ):
             aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]
@@ -77,7 +77,7 @@ class EmbeddingQuantizer(Quantizer):
             # just as an example of alternate ways of annotating
             if (
                 node.op == "call_function"
-                and node.target == torch.ops.aten.embedding.default
+                and node.target is torch.ops.aten.embedding.default
             ):
                 if embedding_config.config.weight is None:
                     raise ValueError(
@@ -607,7 +607,7 @@ class X86InductorQuantizer(Quantizer):
             _annotate_nodes_not_quantize(linear_node)
             return
         input_qspec_map = {}
-        assert linear_node.target == torch.ops.aten.linear.default
+        assert linear_node.target is torch.ops.aten.linear.default
        has_bias = len(linear_node.args) == 3
         input_index = 0
         weight_index = 1
@@ -1396,7 +1396,7 @@ class X86InductorQuantizer(Quantizer):
         """  # noqa: B950
         edge_or_node: tuple[Node, Node]
         if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
-            if node.target == torch.ops.aten.max_pool2d.default:
+            if node.target is torch.ops.aten.max_pool2d.default:
                 maxpool_node = node
                 if not _is_all_annotated(
                     [
@@ -112,7 +112,7 @@ class XPUInductorQuantizer(X86InductorQuantizer):
         node: Node,
     ) -> None:
         if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
-            if node.target == torch.ops.aten.max_pool2d.default:
+            if node.target is torch.ops.aten.max_pool2d.default:
                 return
             else:
                 input_node = node.all_input_nodes[0]
@@ -111,7 +111,7 @@ def auto_quantize(func, qtype, quant_loss=None):
        async_op = kwargs.get("async_op", False)
         if async_op is True:
             raise RuntimeError("The async_op=True mode is not supported yet.")
-        if func == dist.all_gather:
+        if func is dist.all_gather:
             tensors = args[0]
             input_tensors = _quantize_tensor(args[1], qtype)
             out_tensors = _quantize_tensor_list(tensors, qtype)
@@ -121,7 +121,7 @@ def auto_quantize(func, qtype, quant_loss=None):
             ):
                 tensors[i] = t
 
-        elif func == dist.all_to_all:
+        elif func is dist.all_to_all:
             tensors = args[0]
             input_tensors = _quantize_tensor_list(args[1], qtype)
             out_tensors = _quantize_tensor_list(tensors, qtype)
@@ -17,7 +17,7 @@ from torch.fx import Graph
 
 def remove_self_clone(graph: Graph) -> None:
     for node in graph.nodes:
-        if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
+        if node.target is torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
 
@@ -69,7 +69,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
     flatten_node = curr_module_users[0]
     assert (
         flatten_node.op == "call_function"
-        and flatten_node.target == fx_pytree.tree_flatten_spec
+        and flatten_node.target is fx_pytree.tree_flatten_spec
     )
 
     flatten_getitem_users = _get_getitem_users(flatten_node)
@@ -85,7 +85,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
     unflatten_node = next(iter(flatten_getitem_users))
     if not (
         unflatten_node.op == "call_function"
-        and unflatten_node.target == pytree.tree_unflatten
+        and unflatten_node.target is pytree.tree_unflatten
     ):
         log.debug(
             "Flatten node %s's user is not a pytree.tree_unflatten. "
@@ -222,7 +222,7 @@ def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
     that has the same target and args, but with the _export_root stripped from path.
     """
     for node in gm.graph.nodes:
-        if node.target == torch.ops.higher_order._export_tracepoint:
+        if node.target is torch.ops.higher_order._export_tracepoint:
             if "path" in node.kwargs:
                 path = _strip_root(node.kwargs["path"])
                 with gm.graph.inserting_before(node):
@@ -50,7 +50,7 @@ def _remove_detach_pass(
        if node.op != "call_function":
             continue
         if (
-            node.target == torch.ops.aten.detach.default
+            node.target is torch.ops.aten.detach.default
             and len(node.users) == 1
             and next(iter(node.users)).target == torch.ops.aten.detach.default
         ):
@@ -78,7 +78,7 @@ def move_to_device_pass(
 
         if (
             node.op == "call_function"
-            and node.target == torch.ops.aten.to.device
+            and node.target is torch.ops.aten.to.device
         ):
             args = list(node.args)
             # pyrefly: ignore [unsupported-operation]
@@ -814,7 +814,7 @@ def bisect(shape_env):
     # Bisection happens on the assertion nodes of the recorded FX graph for
     # dynamic shapes.
     assert_nodes = [
-        node for node in shape_env.graph.nodes if node.target == torch._assert
+        node for node in shape_env.graph.nodes if node.target is torch._assert
     ]
 
     # Preparing the indices for binary search.
@@ -15,7 +15,7 @@ class CudaGraphsSupport(OperatorSupport):
        if node.op not in CALLABLE_NODE_OPS:
             return False
 
-        if node.target == torch.ops.aten.embedding_dense_backward.default:
+        if node.target is torch.ops.aten.embedding_dense_backward.default:
            return False
 
        if node.target == operator.getitem:
@@ -326,18 +326,18 @@ def split_module(
        instantiate_node_partition_mapping(node)
 
        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
-            if node.target == torch._C._set_grad_enabled:
+            if node.target is torch._C._set_grad_enabled:
                assert len(node.args) == 1
                assert isinstance(node.args[0], bool)
                active_grad = node
                grad_regions[active_grad] = set({split_callback(node)})
-            elif node.target == torch.amp._enter_autocast:
+            elif node.target is torch.amp._enter_autocast:
                # Should all be python constants
                assert all(not isinstance(arg, Node) for arg in node.args)
                active_autocasts.add(node)
                autocast_regions[node] = set({split_callback(node)})
                autocast_exits[node] = None
-            elif node.target == torch.amp._exit_autocast:
+            elif node.target is torch.amp._exit_autocast:
                assert len(node.args) == 1
                autocast_regions[node.args[0]].add(split_callback(node))
                active_autocasts.remove(node.args[0])
@@ -1501,7 +1501,7 @@ def _checkpoint_without_reentrant_generator(
    unpack_error_cb = None
 
    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
-        if context_fn != noop_context_fn:
+        if context_fn is not noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )