[BE][PYFMT] migrate PYFMT for torch/_[a-h]*/ to ruff format (#144551)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144551
Approved by: https://github.com/ezyang
ghstack dependencies: #148186
Xuehai Pan 2025-06-24 17:24:16 +08:00 committed by PyTorch MergeBot
parent 9642c75689
commit 162ca185ff
71 changed files with 824 additions and 642 deletions
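Most of the hunks below are mechanical restyling produced by the formatter switch. Two patterns account for the bulk of the churn: multi-line `assert` statements, where ruff format keeps the condition on one line and parenthesizes the message, and stacked context managers, which ruff wraps into a single parenthesized `with (...)` block. A minimal before/after sketch of both patterns, using placeholder code rather than lines from this diff:

```python
# Hedged sketch of the two dominant restyling patterns in this commit.
# Both formatters only wrap once a line exceeds the length limit; the code
# here is shortened for readability, so the wrapped forms are illustrative.
from contextlib import nullcontext

value = {"k": 1}  # placeholder data, not taken from the diff

# black-style wrapping (before this migration)
assert isinstance(
    value, dict
), "expected a dict"

with nullcontext(), nullcontext():
    pass

# ruff-format wrapping (after this migration)
assert isinstance(value, dict), (
    "expected a dict"
)

with (  # parenthesized context managers (Python 3.10+ syntax)
    nullcontext(),
    nullcontext(),
):
    pass
```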

View File

@ -47,12 +47,9 @@ USE_BLACK_FILELIST = re.compile(
"test/[p-z]*/**",
# torch/**
# torch/_[a-c]*/**
"torch/_[a-c]*/**",
# torch/_[e-h]*/**
"torch/_[e-h]*/**",
# torch/_i*/**
# torch/_[j-z]*/**
"torch/_[j-z]*/**",
# torch/[a-c]*/**
"torch/a[a-n]*/**",
"torch/a[p-z]*/**",

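The hunk above drops the `"torch/_[a-c]*/**"` and `"torch/_[e-h]*/**"` entries from `USE_BLACK_FILELIST`, so paths under those directories stop matching the black allowlist and fall through to ruff format. A rough sketch of the effect, treating the entries as glob patterns (the real linter compiles them into a regex via `re.compile`; the helper name below is hypothetical):

```python
# Illustrative only: shows which paths still match the remaining black
# allowlist entries after this change. The fnmatch semantics here are an
# assumption; the actual linter builds a regex from these patterns.
from fnmatch import fnmatchcase

remaining_black_patterns = [
    "torch/_[j-z]*/**",
    "torch/a[a-n]*/**",
    "torch/a[p-z]*/**",
]

def still_uses_black(path: str) -> bool:
    # hypothetical helper, not part of the lintrunner adapter
    return any(fnmatchcase(path, pat) for pat in remaining_black_patterns)

print(still_uses_black("torch/_export/serde/schema.py"))  # False -> now ruff format
print(still_uses_black("torch/_prims/__init__.py"))       # True  -> still black
```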
View File

@ -624,9 +624,9 @@ class TS2FXGraphConverter:
self.fx_graph, name, self.is_top_level_graph()
)
elif name in self.name_to_constant:
assert isinstance(
self.name_to_constant[name], torch.ScriptObject
), "Input conversion only handles ScriptObject"
assert isinstance(self.name_to_constant[name], torch.ScriptObject), (
"Input conversion only handles ScriptObject"
)
normalized_name = normalize_name(name)
self.input_specs.append(
InputSpec(
@ -661,9 +661,7 @@ class TS2FXGraphConverter:
def to_float_tensor(t):
return t.to(dtype=torch.float).item()
inp_list = [
self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
] # noqa: C416
inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416
fx_node = self.fx_graph.call_function(
to_float_tensor,
tuple(inp_list),
@ -749,9 +747,7 @@ class TS2FXGraphConverter:
self.name_to_constant[name] = value
def convert_prim_CallMethod(self, node: torch._C.Node):
inp_list = [
self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
] # noqa: C416
inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416
fx_node = self.fx_graph.call_method(
node.s("name"),
tuple(inp_list),
@ -783,9 +779,9 @@ class TS2FXGraphConverter:
self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn)
else:
if attr_fqn not in self.name_to_non_tensor_attribute_node:
self.name_to_non_tensor_attribute_node[
attr_fqn
] = self.name_to_non_tensor_attribute[attr_fqn]
self.name_to_non_tensor_attribute_node[attr_fqn] = (
self.name_to_non_tensor_attribute[attr_fqn]
)
self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[
attr_fqn
]
@ -850,15 +846,15 @@ class TS2FXGraphConverter:
k = self.get_fx_value_by_ir_value(inp)
else:
v = self.get_fx_value_by_ir_value(inp)
assert (
k is not None and v is not None
), "DictConstruct has an empty key value pair."
assert k is not None and v is not None, (
"DictConstruct has an empty key value pair."
)
output_dict[k] = v
k, v = None, None
assert (
k is None and v is None
), "DictConstruct has an odd number of elements (violating our assumption)."
assert k is None and v is None, (
"DictConstruct has an odd number of elements (violating our assumption)."
)
output_name = node.output().debugName()
self.name_to_node[output_name] = output_dict
@ -1124,9 +1120,9 @@ class TS2FXGraphConverter:
), # + 1 because the 0th element is the condition.
)
global_argument_index = global_arguments.index(name)
fx_block_args[
i + node.outputsSize() + global_argument_index
] = self.name_to_node[name]
fx_block_args[i + node.outputsSize() + global_argument_index] = (
self.name_to_node[name]
)
def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
for block in if_node.blocks():
@ -1545,9 +1541,9 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
for spec in ep.graph_signature.input_specs:
# Mark as constant tensors for erroneously traced buffers.
if spec.kind == InputKind.BUFFER and spec.target in name_to_constant:
assert isinstance(
name_to_constant[spec.target], torch.Tensor
), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
assert isinstance(name_to_constant[spec.target], torch.Tensor), (
f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
)
spec.kind = InputKind.CONSTANT_TENSOR
spec.persistent = None
ep.verifier().check(ep)

View File

@ -169,7 +169,10 @@ def fakify(
return t
if isinstance(t, _IntWrapper):
if t.dynamism is not None and t.dynamism.type in (_DimHintType.DYNAMIC, _DimHintType.AUTO): # type: ignore[union-attr]
if t.dynamism is not None and t.dynamism.type in ( # type: ignore[union-attr]
_DimHintType.DYNAMIC,
_DimHintType.AUTO,
):
symint = mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr]
t.val, source, DimDynamic.DYNAMIC
)

View File

@ -252,8 +252,11 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
else:
raise ExportPassBaseError(f"Unsupported target type: {target}")
def get_attr(
self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override]
def get_attr( # type: ignore[override]
self,
target: str,
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
) -> Argument:
return super().get_attr(target, args, kwargs)
@ -265,8 +268,11 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
) -> None:
raise ExportPassBaseError("call_module is not supported.")
def call_method(
self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override]
def call_method( # type: ignore[override]
self,
target: str,
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
) -> None:
raise ExportPassBaseError("call_method is not supported.")
@ -426,13 +432,17 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
def call_submodule(
self, graph_module: fx.GraphModule, inputs: tuple[Argument, ...]
) -> PassResult:
prev_tracer, self.tracer = self.tracer, self.ExportTracer(
self, graph_module.graph._codegen
prev_tracer, self.tracer = (
self.tracer,
self.ExportTracer(self, graph_module.graph._codegen),
)
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
interpreter = self.ExportInterpreter(self, graph_module)
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
prev_interpreter, self.interpreter = (
self.interpreter,
torch.fx.Interpreter( # type: ignore[assignment]
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
),
)
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
with fx_traceback.preserve_node_meta():
@ -458,9 +468,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
fake_tensor_mode = None
for i in inputs:
if isinstance(i, FakeTensor):
assert (
fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
), "Multiple fake tensor mode detected."
assert fake_tensor_mode is None or fake_tensor_mode is i.fake_mode, (
"Multiple fake tensor mode detected."
)
fake_tensor_mode = i.fake_mode
if fake_tensor_mode is None:
self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)

View File

@ -16,12 +16,15 @@ def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) ->
"""
for node in gm.graph.nodes:
if node.op == "call_function" and str(node.target) in ops_to_guard:
with _set_node_metadata_hook(
gm,
functools.partial(
_node_metadata_hook, stack_trace=node.meta.get("stack_trace")
with (
_set_node_metadata_hook(
gm,
functools.partial(
_node_metadata_hook, stack_trace=node.meta.get("stack_trace")
),
),
), gm.graph.inserting_before(node):
gm.graph.inserting_before(node),
):
for arg in (*node.args, *node.kwargs.values()):
if isinstance(arg, torch.fx.Node) and isinstance(
arg.meta.get("val"), torch.Tensor

View File

@ -373,7 +373,7 @@ def lift_constants_pass(
def rewrite_script_object_meta(
gm: torch.fx.GraphModule,
) -> dict[str, _ConstantAttributeType,]:
) -> dict[str, _ConstantAttributeType]:
"""When tracing, we produce a graph with FakeScriptObject in the
meta["val"].

View File

@ -107,20 +107,20 @@ def _dump_dynamic_shapes(
would generate the following output:
```
{
'dynamic_shapes': (
"dynamic_shapes": (
[
['dx', 4],
['dx + 1', 4],
["dx", 4],
["dx + 1", 4],
],
['_DimHint.STATIC'],
['_DimHint.STATIC', '_DimHint.STATIC'],
["_DimHint.STATIC"],
["_DimHint.STATIC", "_DimHint.STATIC"],
None,
),
'dims': {
'dx': {
'min': 4,
'max': 16,
'derived': ['dx + 1'],
"dims": {
"dx": {
"min": 4,
"max": 16,
"derived": ["dx + 1"],
},
},
}
@ -149,7 +149,7 @@ def _dump_dynamic_shapes(
return out
def _track_dim_from_dims(
val: Union[None, int, _DimHint, Dim]
val: Union[None, int, _DimHint, Dim],
) -> Union[None, int, str]:
"""
Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
@ -295,7 +295,7 @@ def _load_dynamic_shapes(
dim_cache[_expr] = ddim # cache derived dims
def deserialize_shape(
val: Union[None, int, str]
val: Union[None, int, str],
) -> Union[None, int, Dim, _DimHint]:
if val is None or isinstance(val, int):
return val

View File

@ -129,13 +129,13 @@ def _staged_schema():
t, cpp_type, thrift_type = dump_type(f.type, 0)
ret = {"type": t}
cpp_default: Optional[str] = None
assert (
typing.get_origin(f.type) == Annotated
), f"Field {f.name} must be annotated with an integer id."
assert typing.get_origin(f.type) == Annotated, (
f"Field {f.name} must be annotated with an integer id."
)
thrift_id = f.type.__metadata__[0]
assert (
type(thrift_id) is int
), f"Field {f.name} must be annotated with an integer id."
assert type(thrift_id) is int, (
f"Field {f.name} must be annotated with an integer id."
)
value = dataclasses.MISSING
if f.default is not dataclasses.MISSING:
@ -173,9 +173,7 @@ def _staged_schema():
def _handle_int_enum(name, ty):
yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
cpp_enum_defs[
name
] = f"""
cpp_enum_defs[name] = f"""
enum class {name} {{
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
}};
@ -240,14 +238,17 @@ enum {name} {{
from_json_def = f"""{{
{name} nlohmann_json_default_obj;
{chr(10).join(
[f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});'
for name, f in cpp_fields.items()])}
{
chr(10).join(
[
f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});'
for name, f in cpp_fields.items()
]
)
}
}}
"""
cpp_class_defs[
name
] = f"""
cpp_class_defs[name] = f"""
class {name} {{
private:
{field_decls}
@ -262,9 +263,7 @@ class {name} {{
cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}")
cpp_type_decls.append(f"class {name};")
thrift_type_defs[
name
] = f"""
thrift_type_defs[name] = f"""
struct {name} {{
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
}}"""
@ -307,9 +306,7 @@ struct {name} {{
]
)
cpp_class_defs[
name
] = f"""
cpp_class_defs[name] = f"""
class {name} {{
struct Void {{}};
@ -352,9 +349,7 @@ inline void parseEnum(std::string_view s, {name}::Tag& t) {{
"""
cpp_type_decls.append(f"class {name};")
thrift_type_defs[
name
] = f"""
thrift_type_defs[name] = f"""
union {name} {{
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
}}"""

View File

@ -322,9 +322,9 @@ def _reconstruct_fake_tensor(
json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8"))
tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta)
# Find the current fake mode
assert (
_CURRENT_DESERIALIZER is not None
), "Need access to current deserializer state"
assert _CURRENT_DESERIALIZER is not None, (
"Need access to current deserializer state"
)
fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta)
if is_parameter:
fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment]
@ -337,9 +337,9 @@ def serialize_torch_artifact(
if artifact is None:
return b""
assert (
FakeTensor not in copyreg.dispatch_table
), "Refusing to stomp on existing FakeTensor reducer"
assert FakeTensor not in copyreg.dispatch_table, (
"Refusing to stomp on existing FakeTensor reducer"
)
try:
copyreg.pickle(FakeTensor, _reduce_fake_tensor)
buffer = io.BytesIO()
@ -356,7 +356,7 @@ def serialize_torch_artifact(
def deserialize_torch_artifact(
serialized: Union[dict[str, Any], tuple[Any, ...], bytes]
serialized: Union[dict[str, Any], tuple[Any, ...], bytes],
):
if isinstance(serialized, (dict, tuple)):
return serialized
@ -415,7 +415,7 @@ def _symbol_index(sym: sympy.Symbol, sym_type: SymT):
def serialize_range_constraints(
range_constraints: dict[sympy.Symbol, ValueRanges]
range_constraints: dict[sympy.Symbol, ValueRanges],
) -> dict[str, RangeConstraint]:
return {
str(k): RangeConstraint(
@ -499,9 +499,9 @@ class GraphModuleSerializer(metaclass=Final):
graph_input = Argument.create(
as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)
)
self.graph_state.custom_obj_values[
node.name
] = self.serialize_script_obj_meta(val)
self.graph_state.custom_obj_values[node.name] = (
self.serialize_script_obj_meta(val)
)
else:
raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}")
self.graph_state.inputs.append(graph_input)
@ -627,9 +627,9 @@ class GraphModuleSerializer(metaclass=Final):
)
elif type(node.target) in _serialization_registry:
# Sanity check for unhandled serialization.
assert (
type(node.target) in _serialization_registry
), f"{type(node.target)} is not supported in export serialization."
assert type(node.target) in _serialization_registry, (
f"{type(node.target)} is not supported in export serialization."
)
handler = _serialization_registry[type(node.target)]
namespace = handler.namespace()
@ -1295,9 +1295,9 @@ class GraphModuleSerializer(metaclass=Final):
f"but somehow previously was found to have field names {field_names}."
)
else:
self.treespec_namedtuple_fields[
serialized_type_name
] = NamedTupleDef(field_names=ts.context._fields)
self.treespec_namedtuple_fields[serialized_type_name] = (
NamedTupleDef(field_names=ts.context._fields)
)
for child in ts.children_specs:
store_namedtuple_fields(child)
@ -1516,9 +1516,9 @@ class GraphModuleSerializer(metaclass=Final):
idx_to_name = {}
for user in node.users:
assert (
user.target is operator.getitem
), f"User node {user} of {node} is incorrect"
assert user.target is operator.getitem, (
f"User node {user} of {node} is incorrect"
)
idx_to_name[user.args[1]] = user.name
for idx, _ in enumerate(meta_val):
@ -3529,9 +3529,9 @@ def register_extension(
extension_handler: type[ExtensionHandler],
):
"""Register custom de/serialization method for a node with non-standard type."""
assert issubclass(
extension_handler, ExtensionHandler
), f"Expected ExtensionHandler, got {extension_handler}."
assert issubclass(extension_handler, ExtensionHandler), (
f"Expected ExtensionHandler, got {extension_handler}."
)
assert op_type not in _serialization_registry, f"{op_type} is already registered."
assert isinstance(op_type, type) # Maybe a good idea to enforce this first.
assert not (

View File

@ -18,9 +18,9 @@ class _UnionTag(str):
def __eq__(self, cmp) -> bool:
assert isinstance(cmp, str)
other = str(cmp)
assert other in _get_field_names(
self._cls
), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
assert other in _get_field_names(self._cls), (
f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
)
return str(self) == other
def __hash__(self):
@ -43,7 +43,10 @@ class _Union:
return obj
def __post_init__(self):
assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc]
assert not any(
f.name in ("type", "_type", "create", "value")
for f in fields(self) # type: ignore[arg-type, misc]
)
@property
def type(self) -> str:

View File

@ -448,9 +448,9 @@ def register_dataclass_as_pytree_node(
from_dumpable_context: Optional[FromDumpableContextFn] = None,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(
cls
), f"Only dataclasses can be registered with this function: {cls}"
assert dataclasses.is_dataclass(cls), (
f"Only dataclasses can be registered with this function: {cls}"
)
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
flattened = []
@ -644,11 +644,14 @@ def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None:
continue
if (tensor_val := node.args[0].meta.get("val")) is not None:
with gm.graph.inserting_before(node), _set_node_metadata_hook(
gm,
functools.partial(
_node_metadata_hook,
stack_trace=node.meta.get("stack_trace"),
with (
gm.graph.inserting_before(node),
_set_node_metadata_hook(
gm,
functools.partial(
_node_metadata_hook,
stack_trace=node.meta.get("stack_trace"),
),
),
):
gm.graph.call_function(
@ -1342,6 +1345,7 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
import torch
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
@ -1350,12 +1354,15 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
def forward(self, x):
return self.linear(x)
torch._export.utils.register_module_as_pytree_node(InputDataClass)
class Mod(torch.nn.Module):
def forward(self, x, m):
return m(x) + x
ep = torch.export.export(Mod(), (torch.randn(3), Module()))
print(ep)

View File

@ -195,9 +195,9 @@ def mark_subclass_constructor_exportable_experimental(constructor_subclass):
for mode in torch_function_mode_stack
if isinstance(mode, PreDispatchTorchFunctionMode)
]
assert (
len(pre_dispatch_tf_modes) <= 1
), f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}"
assert len(pre_dispatch_tf_modes) <= 1, (
f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}"
)
if len(pre_dispatch_tf_modes) == 0:
return

View File

@ -96,9 +96,9 @@ class GraphInfoProvider:
@property
def recomputable_node_only_graph(self) -> nx.DiGraph:
if self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] is None:
self._lazily_initialized_graphs[
self.__RECOMPUTABLE_NODE_ONLY_GRAPH
] = self._create_recomputable_node_only_graph()
self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] = (
self._create_recomputable_node_only_graph()
)
return self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH]
@property
@ -119,17 +119,17 @@ class GraphInfoProvider:
@property
def full_joint_nx_graph(self) -> nx.DiGraph:
if self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] is None:
self._lazily_initialized_graphs[
self.__FULL_NX_JOINT_GRAPH
] = self._create_full_joint_graph()
self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] = (
self._create_full_joint_graph()
)
return self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH]
@property
def simplified_fx_joint_graph(self) -> Graph:
if self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] is None:
self._lazily_initialized_graphs[
self.__SIMPLIFIED_FX_JOINT_GRAPH
] = self._recreate_psuedo_joint_graph()
self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] = (
self._recreate_psuedo_joint_graph()
)
return self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH]
def get_non_ac_peak_memory(self) -> float:
@ -285,9 +285,7 @@ class GraphInfoProvider:
float(
self.recomputable_node_only_graph_with_larger_graph_context.nodes[
node
][
"memory"
]
]["memory"]
)
)
)

View File

@ -2,6 +2,7 @@
"""
Utils for caching the outputs of AOTAutograd
"""
from __future__ import annotations
import base64
@ -425,16 +426,13 @@ class InductorOutput(Generic[TOut], ABC):
"""
@abstractmethod
def pre_save(self) -> None:
...
def pre_save(self) -> None: ...
@abstractmethod
def load(self, example_inputs) -> TOut:
...
def load(self, example_inputs) -> TOut: ...
@abstractmethod
def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut:
...
def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ...
@dataclass
@ -598,7 +596,9 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoa
# See note [Wrapping bw_compiler in disable]
# This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
# But since on cache hit we do not call the bw_compiler, we need to reapply the disable
return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass") # type: ignore[return-value]
return torch._dynamo.disable( # type: ignore[return-value]
compiled_bw, reason="do not trace generated backwards pass"
)
# Forward types don't have any extra parameters, so this is just a TypeAlias, in essence
@ -617,7 +617,9 @@ class BundledCompiledBackward(
# See note [Wrapping bw_compiler in disable]
# This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
# But since on cache hit we do not call the bw_compiler, we need to reapply the disable
return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass") # type: ignore[return-value]
return torch._dynamo.disable( # type: ignore[return-value]
compiled_bw, reason="do not trace generated backwards pass"
)
@dataclass
@ -1053,10 +1055,10 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
cache_key, debug_lines = autograd_cache_key(
gm, args, aot_config, fx_config
)
entry: Optional[
GenericAOTAutogradCacheEntry
] = AOTAutogradCache._lookup(
cache_key, local, remote, args, cache_info, aot_config
entry: Optional[GenericAOTAutogradCacheEntry] = (
AOTAutogradCache._lookup(
cache_key, local, remote, args, cache_info, aot_config
)
)
if entry is not None:
compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
@ -1081,9 +1083,8 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
# FXGraphCache and AOTAutogradCache?
# get_metrics_context().increment(...)
if (
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
time_saved_ns
)
ephemeral_increase
:= add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
) != 0:
cache_info["ephemeral_timeout_increase"] = ephemeral_increase
@ -1311,9 +1312,9 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
return None
if remote:
remote_cache: Optional[
RemoteCache[JsonDataTy]
] = AOTAutogradCache.get_remote_cache()
remote_cache: Optional[RemoteCache[JsonDataTy]] = (
AOTAutogradCache.get_remote_cache()
)
if remote_cache is not None:
time_taken_ms = int(
(entry.forward_time_taken_ns + entry.backward_time_taken_ns) // 1e6

View File

@ -571,9 +571,9 @@ from a multi-output view call"
output_type = (
OutputType.alias_of_intermediate_save_as_output
)
intermediate_base_tensor_id_to_output_idx[
id(o._base)
] = new_out_idx
intermediate_base_tensor_id_to_output_idx[id(o._base)] = (
new_out_idx
)
intermediate_bases.append(o._base)
elif (
# See https://github.com/pytorch/pytorch/issues/100348 for this case.

View File

@ -46,11 +46,14 @@ aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
# FunctionalTensorMode must be enabled here.
# See Note [Accessing .grad_fn on FunctionalTensor]
with enable_python_dispatcher(), FunctionalTensorMode(
pre_dispatch=aot_config.pre_dispatch,
export=aot_config.is_export,
# Allow token discovery for joint fn tracing as tokens can be used in backward.
_allow_token_discovery=True,
with (
enable_python_dispatcher(),
FunctionalTensorMode(
pre_dispatch=aot_config.pre_dispatch,
export=aot_config.is_export,
# Allow token discovery for joint fn tracing as tokens can be used in backward.
_allow_token_discovery=True,
),
):
fx_g = make_fx(
f,
@ -238,9 +241,9 @@ def aot_dispatch_base_graph(
# TODO: should factor this into a separate function for export that always only returns just the graph.
if aot_config.is_export:
assert (
maybe_subclass_meta is None
), "aot_export_module does not support tensor subclass inputs for now."
assert maybe_subclass_meta is None, (
"aot_export_module does not support tensor subclass inputs for now."
)
return fw_module, saved_updated_flat_args_subclasses_desugared, maybe_subclass_meta
@ -332,7 +335,7 @@ def aot_dispatch_autograd_graph(
# when we need to manually detach() some inputs in the forward.
# Higher order ops might eventually need to do the same.
if aot_config.is_export:
assert (
maybe_subclass_meta is None
), "aot_export_module does not support tensor subclass inputs for now."
assert maybe_subclass_meta is None, (
"aot_export_module does not support tensor subclass inputs for now."
)
return fx_g, saved_updated_joint_inputs, maybe_subclass_meta

View File

@ -6,6 +6,7 @@ This file contains utilities related to functionalization in AOTAutograd:
3. regenerating/replaying views from their base
4. checking if a graph is functional i.e. whether it contains any mutation ops
"""
from __future__ import annotations
from dataclasses import dataclass
@ -452,14 +453,14 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
# this is mostly a hack to avoid failing XLA tests.
# See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
if "set_buffer_donor_" not in str(n.args[0]):
assert (
n.args[0] in placeholders
), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
assert n.args[0] in placeholders, (
f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
)
mutation_count += 1
else:
assert (
not n.target._schema.is_mutable
), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
assert not n.target._schema.is_mutable, (
f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
)
return mutation_count
@ -472,9 +473,9 @@ def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None:
if n.target is torch.ops.aten.copy_.default:
# Can only copy_ into an input, and can only do so once
if "set_buffer_donor_" not in str(n.args[0]):
assert (
n.args[0] in placeholders
), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
assert n.args[0] in placeholders, (
f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
)
placeholders.remove(n.args[0])
copy_from_node = n.args[1]
# Pre-condition: every node has a "stack_trace" field in its meta,

View File

@ -823,9 +823,9 @@ def create_wrap_fn(fn, args):
from .functional_utils import from_fun, has_data_mutation, to_fun
def assert_no_mutation(t):
assert not has_data_mutation(
t
), "Saved tensors hooks with inputs mutations are not allowed"
assert not has_data_mutation(t), (
"Saved tensors hooks with inputs mutations are not allowed"
)
@wraps(fn)
def _wrapper(*args):
@ -1110,9 +1110,11 @@ def maybe_inline_graph_saved_tensors_hooks(
# Inserting packed sym scalars before first saved tensor input.
# Inserting packed tensors before last saved tensor input.
# Saved tensor inputs between them will be removed.
with bw_g.inserting_before(
bw_g_inputs[0]
) if is_sym else bw_g.inserting_before(bw_g_input):
with (
bw_g.inserting_before(bw_g_inputs[0])
if is_sym
else bw_g.inserting_before(bw_g_input)
):
new_n = bw_g.placeholder(new_node_name)
assert new_n.name == new_node_name
new_n.meta = copy.copy(out_n.meta)

View File

@ -6,6 +6,7 @@ This module defines runtime wrappers, which, based on previous analysis attempts
3. handle functionalized randomness
4. deduplicate inputs and consolidate views into their bases (see input_output_analysis)
"""
import builtins
import collections
import contextlib
@ -318,9 +319,9 @@ def _create_runtime_wrapper(
for info in runtime_metadata.output_info
)
def record_runtime_wrapper_prologue_enter() -> (
Optional[AbstractContextManager[None]]
):
def record_runtime_wrapper_prologue_enter() -> Optional[
AbstractContextManager[None]
]:
if (
torch.autograd.profiler._is_profiler_enabled
and dynamo_config.record_runtime_overhead
@ -950,9 +951,9 @@ class AOTDedupeWrapper(CompilerWrapper):
keep_arg_mask.append(True)
add_dupe_map.append(j)
j += 1
assert (
len(add_dupe_map) == duped_arg_len
), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
assert len(add_dupe_map) == duped_arg_len, (
f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
)
self.keep_arg_mask = keep_arg_mask
self.add_dupe_map = add_dupe_map
@ -996,9 +997,9 @@ class AOTDedupeWrapper(CompilerWrapper):
keep_input_mutations=fw_metadata.keep_input_mutations,
is_train=fw_metadata.is_train,
)(*deduped_flat_args)
assert (
ref_fw_metadata == updated_fw_metadata
), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
assert ref_fw_metadata == updated_fw_metadata, (
f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
)
return wrapped_flat_fn, deduped_flat_args, updated_fw_metadata
@ -1397,14 +1398,14 @@ def merge_view_inputs(
# The "inputs that are aliased but have different differentiable bases" case
# is more complicated and hopefully pretty rare. Not currently handled.
if not is_inference:
assert _are_differentiable_views(
view1, view2
), "aot_autograd() does not yet handle non-differentiable view input mutations."
assert _are_differentiable_views(view1, view2), (
"aot_autograd() does not yet handle non-differentiable view input mutations."
)
# Regenerating views when reinterpreting complex / real tensors seems non-trivial,
# not handling for now
assert _same_dtype_views(
view1, view2
), "aot_autograd() does not yet handle input mutations on views with different dtypes."
assert _same_dtype_views(view1, view2), (
"aot_autograd() does not yet handle input mutations on views with different dtypes."
)
non_none_bases = [
fwd_inputs[i]._base
for i in aliased_input_indices
@ -1451,13 +1452,13 @@ def merge_view_inputs(
# Case where all of the aliases require gradients, and have the same _base.
synthetic_base = non_none_bases[0]
for other_base in non_none_bases[1:]:
assert (
other_base is synthetic_base
), "aot_autograd() does not yet handle non-differentiable view input mutations."
assert other_base is synthetic_base, (
"aot_autograd() does not yet handle non-differentiable view input mutations."
)
for alias in aliases_with_none_bases:
assert (
alias is synthetic_base
), "aot_autograd() does not yet handle non-differentiable view input mutations."
assert alias is synthetic_base, (
"aot_autograd() does not yet handle non-differentiable view input mutations."
)
base_args.append(synthetic_base)
for curr_view_idx in aliased_input_indices:
curr_view = fwd_inputs[curr_view_idx]
@ -2286,9 +2287,9 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
@staticmethod
def _backward_impl(ctx, all_args):
# compiled autograd reimplements this function at proxy_call_aot_backward
assert (
not backward_state_indices
), "BackwardState requires CompiledAutograd"
assert not backward_state_indices, (
"BackwardState requires CompiledAutograd"
)
ctx.maybe_clear_saved_tensors()
saved_tensors_use_once = (

View File

@ -689,7 +689,7 @@ class ViewAndMutationMeta:
and len(self.traced_tangents) == len(other.traced_tangents)
and all(
x.shape == y.shape and x.dtype == y.dtype
for x, y, in zip(self.traced_tangents, other.traced_tangents)
for x, y in zip(self.traced_tangents, other.traced_tangents)
)
and self.num_backward_tokens == other.num_backward_tokens
)
@ -726,9 +726,9 @@ class SubclassMeta:
# in case we made incorrect assumptions about the subclass-ness of our grad_outputs
#
# Optional field because we don't compute for inference graphs
grad_input_metas: Optional[
list[Union[PlainTensorMeta, SubclassCreationMeta]]
] = None
grad_input_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = (
None
)
def __init__(self) -> None:
# The fields in this class get set after its construction.

View File

@ -383,9 +383,9 @@ def wrap_tensor_subclasses(
return wrapped_args + activations
return tuple(list(wrapped_args) + list(activations))
else:
assert (
len(unwrapped_args) == num_args_tallied
), f"Expected {len(unwrapped_args)} == {num_args_tallied}"
assert len(unwrapped_args) == num_args_tallied, (
f"Expected {len(unwrapped_args)} == {num_args_tallied}"
)
return tuple(wrapped_args)

View File

@ -320,15 +320,17 @@ def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any:
def traced_joint(
primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset
):
with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
"torch.cuda.set_rng_state", override_set_rng_state
with (
patch("torch.cuda.get_rng_state", override_get_rng_state),
patch("torch.cuda.set_rng_state", override_set_rng_state),
):
return append_rng_offsets(func(primals, tangents))
def traced_forward(*primals_fwd_seed_fwd_base_offset):
# The signature is (*primals, seed, offset)
with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
"torch.cuda.set_rng_state", override_set_rng_state
with (
patch("torch.cuda.get_rng_state", override_get_rng_state),
patch("torch.cuda.set_rng_state", override_set_rng_state),
):
return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2]))
@ -450,15 +452,15 @@ def create_functionalized_fn(
# Ban metadata mutations on fw inputs during the bw
if not inpt_info.mutates_metadata:
assert (
not joint_mutates_metadata
), "Found a graph input that had its metadata mutated in the backward. This is not supported"
assert not joint_mutates_metadata, (
"Found a graph input that had its metadata mutated in the backward. This is not supported"
)
# Ban storage resizing on fw inputs during the bw
if not inpt_info.mutation_inductor_storage_resize:
assert not was_inductor_storage_resized(
f_inpt
), "Found a graph input that had storage resizing in the backward. This is not supported"
assert not was_inductor_storage_resized(f_inpt), (
"Found a graph input that had storage resizing in the backward. This is not supported"
)
# Allow data mutations on fw inputs during the bw, but only if they do not require grad
# So we can guarantee that we can keep the mutations in the graph
@ -470,7 +472,10 @@ def create_functionalized_fn(
# Not banning here mutations on inpt_info.requires_grad -
# we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
# Add node meta for copy_ for partitioner that this node should be in backward graph.
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
with (
torch.fx.traceback.preserve_node_meta(),
set_partitioner_tag_must_be_in_backward(),
):
before.copy_(after)
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
idx
@ -485,7 +490,9 @@ def create_functionalized_fn(
):
assert not has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
), "Found an input to the backward that had metadata mutated during the backward pass. This is not supported"
), (
"Found an input to the backward that had metadata mutated during the backward pass. This is not supported"
)
if has_data_mutation(f_inpt):
can_be_in_graph = _check_if_mutation_can_be_in_graph(
keep_input_mutations=True,
@ -503,9 +510,9 @@ def create_functionalized_fn(
),
requires_grad=f_inpt.requires_grad,
)
assert (
can_be_in_graph
), "a backward input that had data mutated in an autograd-aware way. This is not supported"
assert can_be_in_graph, (
"a backward input that had data mutated in an autograd-aware way. This is not supported"
)
# Perform the input mutation
with torch.fx.traceback.preserve_node_meta():
before.copy_(after)
@ -621,8 +628,10 @@ def create_functionalized_fn(
if inpt_old.is_inference():
maybe_preserve_vc = nullcontext()
else:
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
maybe_preserve_vc = (
torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
)
)
with torch.no_grad(), maybe_preserve_vc:
inpt_old.copy_(inpt_new)
@ -889,9 +898,12 @@ def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
# https://github.com/pytorch/pytorch/issues/103569
def functional_call(*args, **kwargs):
with stateless._reparametrize_module(
mod, pytree.tree_unflatten(args[:params_len], params_spec)
), maybe_disable_thunkify():
with (
stateless._reparametrize_module(
mod, pytree.tree_unflatten(args[:params_len], params_spec)
),
maybe_disable_thunkify(),
):
if isinstance(mod, torch.fx.GraphModule):
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(

View File

@ -140,9 +140,9 @@ def call_func_at_runtime_with_args(
class PytreeThunk:
spec: Optional[pytree.TreeSpec] = None
# These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
is_simple: Optional[
bool
] = None # if the output spec is a tuple/list, we won't bother unflattening it.
is_simple: Optional[bool] = (
None # if the output spec is a tuple/list, we won't bother unflattening it.
)
is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec
def set(self, spec: pytree.TreeSpec) -> None:
@ -335,12 +335,12 @@ def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
num_erased_inputs = len(input_token_nodes)
assert (
num_erased_inputs == expected_num_erased
), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}"
assert (
num_erased_outs == expected_num_erased
), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}"
assert num_erased_inputs == expected_num_erased, (
f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}"
)
assert num_erased_outs == expected_num_erased, (
f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}"
)
module.recompile()

View File

@ -454,8 +454,7 @@ class AOTDispatchCompiler(Protocol):
self,
gm: torch.fx.GraphModule,
example_inputs: Sequence[InputType],
) -> Any:
...
) -> Any: ...
# TODO: bikeshed on this name
@ -637,13 +636,14 @@ def _create_aot_dispatcher_function(
# If any saved tensor hooks are active, we **don't** want to trace them.
# Instead, we'll let them run at runtime, around the custom autograd.Function
# that we generate in torch.compile.
with torch.autograd.set_multithreading_enabled(
False
), preserve_rng_state(), (
fake_mode
), (
python_dispatcher_mode
), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
with (
torch.autograd.set_multithreading_enabled(False),
preserve_rng_state(),
fake_mode,
python_dispatcher_mode,
PhiloxStateTracker(),
torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(),
):
from torch._library.fake_class_registry import (
FakeScriptObject,
maybe_to_fake_obj,
@ -756,7 +756,7 @@ def _create_aot_dispatcher_function(
if fw_metadata.num_intermediate_bases > 0:
assert not req_subclass_dispatch, f"""\
torch.compile is currently being used with tensor subclass inputs:
{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs
{",".join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs
that alias one another, which is currently unsupported in the subclass use case. If you run into this,
please file a github issue"""
@ -899,7 +899,7 @@ def aot_function(
A simple example usage of :func:`aot_function` is as follows. This example
will print the forward and backward graphs of the function ``fn``
>>> fn = lambda x : x.sin().cos()
>>> fn = lambda x: x.sin().cos()
>>> def print_compile_fn(fx_module, args):
>>> print(fx_module)
>>> return fx_module
@ -1425,9 +1425,7 @@ We require the output marked as the loss (at index {output_loss_index}) to be a
output_gradients = []
for a, grad in zip(args, gradients):
if isinstance(a, torch.Tensor) and a.requires_grad:
assert (
grad is not None
), """\
assert grad is not None, """\
Found a parameter that did not receive a gradient.
"This is most likely a bug, but if this needs to be supported please comment on this Github issue:
https://github.com/pytorch/pytorch/issues/101192
@ -1540,7 +1538,9 @@ def aot_export_joint_simple(
if config.debug_assert:
# Smoke test that after partitioning, we can run the forward without any calling convention changes.
fw_module, _bw_module = aot_config.default_partition( # noqa: F821
fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821
fx_g,
args,
num_fwd_outputs=len(fw_metadata.output_infos), # noqa: F821
)
# Attempt to run the fw_module with the original user inputs
fake_mode = detect_fake_mode(args)

View File

@ -92,7 +92,7 @@ def vmap(
doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully
rummaging through docs, use :func:`vmap` to construct a new function.
>>> torch.dot # [D], [D] -> []
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
@ -104,7 +104,7 @@ def vmap(
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>> # Very simple linear model with activation
>>> # Very simple linear model with activation
>>> return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
@ -120,7 +120,7 @@ def vmap(
>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> f = lambda x: x**2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
@ -137,43 +137,49 @@ def vmap(
:func:`vmap` can also be nested, producing an output with multiple batched dimensions
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.vmap(
... torch.vmap(torch.dot)
... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y) # tensor of size [2, 3]
>>> batched_dot(x, y) # tensor of size [2, 3]
If the inputs are not batched along the first dimension, ``in_dims`` specifies
the dimension that each inputs are batched along as
>>> torch.dot # [N], [N] -> []
>>> torch.dot # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension
>>> batched_dot(
... x, y
... ) # output is [5] instead of [2] if batched along the 0th dimension
If there are multiple inputs each of which is batched along different dimensions,
``in_dims`` must be a tuple with the batch dimension for each input as
>>> torch.dot # [D], [D] -> []
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None
>>> batched_dot(
... x, y
... ) # second arg doesn't have a batch dim because in_dim[1] was None
If the input is a Python struct, ``in_dims`` must be a tuple containing a struct
matching the shape of the input:
>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
>>> f = lambda dict: torch.dot(dict["x"], dict["y"])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> input = {"x": x, "y": y}
>>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
>>> batched_dot(input)
By default, the output is batched along the first dimension. However, it can be batched
along any dimension by using ``out_dims``
>>> f = lambda x: x ** 2
>>> f = lambda x: x**2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x) # [5, 2]
>>> batched_pow(x) # [5, 2]
For any function that uses kwargs, the returned function will not batch the kwargs but will
accept kwargs
@ -184,7 +190,7 @@ def vmap(
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
.. note::
vmap does not provide general autobatching or handle variable-length
@ -337,7 +343,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>> # Very simple linear model with activation
>>> # Very simple linear model with activation
>>> assert feature_vec.dim() == 1
>>> return feature_vec.dot(weights).relu()
>>>
@ -349,7 +355,9 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
... *inputs
... )
Example of using ``grad`` with ``has_aux`` and ``argnums``:

View File

@ -185,8 +185,12 @@ def benchmark_utilization(
```
def f(a):
return a.sum()
a = torch.rand(2**20, device="cuda")
utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace")
utilization, mm_conv_utilization = benchmark_utilization(
f, a, "tmp", trace_file_name="tmp_chrome_trace"
)
```
Args:

View File

@ -150,13 +150,13 @@ class DebugInterpreter(fx.Interpreter):
def check(nv, rv, desc):
assert callable(desc)
assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
assert (
subst_symint_tuple(nv.size()) == rv.size()
), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
assert subst_symint_tuple(nv.size()) == rv.size(), (
f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
)
same_strides = check_significant_strides(nv, rv)
assert (
same_strides
), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"
assert same_strides, (
f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"
)
r = super().run_node(n)
if "val" in n.meta:

View File

@ -7,6 +7,7 @@
"""
Global flags for aot autograd
"""
import os
import sys
from typing import Literal, Optional, TYPE_CHECKING

View File

@ -233,7 +233,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> grad = vjpfunc(torch.tensor(1.0))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))
However, :func:`vjp` can support functions with multiple outputs by
@ -248,9 +248,9 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
:func:`vjp` can even support outputs being Python structs
>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> f = lambda x: {"first": x.sin(), "second": x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
@ -274,7 +274,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0))
.. note::
Using PyTorch ``torch.no_grad`` together with ``vjp``.
@ -930,8 +930,7 @@ def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
return
if not isinstance(output, tuple):
raise RuntimeError(
f"{api}: Expected output of f to be a Tensor or Tensors, got "
f"{type(output)}"
f"{api}: Expected output of f to be a Tensor or Tensors, got {type(output)}"
)
if len(output) == 0:
raise RuntimeError(
@ -1023,10 +1022,10 @@ def jvp(
>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> f = lambda x: x * torch.tensor([1.0, 2.0, 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.0),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
>>> assert torch.allclose(grad, torch.tensor([1.0, 2, 3]))
:func:`jvp` can support functions with multiple inputs by passing in the
tangents for each of the inputs

View File

@ -60,7 +60,10 @@ def functional_call(
.. code-block:: python
a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries
a = (
{"weight": torch.ones(1, 1)},
{"buffer": torch.zeros(1)},
) # two separate dictionaries
mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer
print(mod.weight) # tensor(...)
print(mod.buffer) # tensor(...)
@ -83,10 +86,12 @@ def functional_call(
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
def compute_loss(params, x, t):
y = functional_call(model, params, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
.. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
@ -179,9 +184,11 @@ def stack_module_state(
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
def wrapper(params, buffers, data):
return torch.func.functional_call(models[0], (params, buffers), data)
params, buffers = stack_module_state(models)
output = vmap(wrapper, (0, 0, None))(params, buffers, data)
@ -192,6 +199,8 @@ def stack_module_state(
.. code-block:: python
import torch.nn as nn
class Foo(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
@ -202,6 +211,7 @@ def stack_module_state(
def forward(self, x):
return self.l2(self.l1(x))
num_models = 5
in_features, out_features = 3, 3
models = [Foo(in_features, out_features) for i in range(num_models)]

View File

@ -374,10 +374,12 @@ def make_functional(
model = nn.Linear(3, 3)
func, params = make_functional(model)
def compute_loss(params, x, t):
y = func(params, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(params, x, t)
If the model has any buffers, please use :func:`make_functional_with_buffers` instead.
@ -443,10 +445,12 @@ def make_functional_with_buffers(
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)
def compute_loss(params, buffers, x, t):
y = func(params, buffers, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(params, buffers, x, t)
Args:
@ -469,7 +473,7 @@ def make_functional_with_buffers(
def transpose_stack(
tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...]
tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...],
) -> tuple[Tensor, ...]:
tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
results = tuple(

View File

@ -229,9 +229,9 @@ def _extract_graph_with_inputs_outputs(
if isinstance(x, fx.Node):
if x not in env:
raise RuntimeError(f"Node {x} couldn't be found in env")
assert not isinstance(
env[x], InvalidNodeBase
), f"Node {x} was invalid, but is output"
assert not isinstance(env[x], InvalidNodeBase), (
f"Node {x} was invalid, but is output"
)
output_values.append(env[x])
else:
output_values.append(x)
@ -449,10 +449,10 @@ def perform_quantization(
args=(clamp_max_scaled_node, quant_type),
name="fp8_quant_" + str(node.name),
)
quant_activation_node.meta[
"val"
] = torch.ops.prims.convert_element_type.default(
clamp_max_scaled_node.meta["val"], quant_type
quant_activation_node.meta["val"] = (
torch.ops.prims.convert_element_type.default(
clamp_max_scaled_node.meta["val"], quant_type
)
)
quant_activation_node.meta["tensor_meta"] = extract_tensor_metadata(
quant_activation_node.meta["val"]
@ -567,10 +567,10 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
args=(node, quant_type),
name="fp8_quant_" + str(node.name),
)
quant_node.meta[
"val"
] = torch.ops.prims.convert_element_type.default(
node.meta["val"], quant_type
quant_node.meta["val"] = (
torch.ops.prims.convert_element_type.default(
node.meta["val"], quant_type
)
)
quant_node.meta["tensor_meta"] = extract_tensor_metadata(
quant_node.meta["val"]
@ -578,7 +578,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
node_to_quant[node] = quant_node
# only update the return node args, and remain all other users unchanged
output_updated_args = [
node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs # type: ignore[union-attr]
node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs
]
# add the scale nodes to the ouput find the first sym_node in the output
idx = find_first_sym_node(output_updated_args)
@ -617,10 +617,10 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
torch.ops.prims.convert_element_type.default,
args=(node, dequant_type),
)
activation_node.meta[
"val"
] = torch.ops.prims.convert_element_type.default(
node.meta["val"], dequant_type
activation_node.meta["val"] = (
torch.ops.prims.convert_element_type.default(
node.meta["val"], dequant_type
)
)
activation_node.meta["tensor_meta"] = extract_tensor_metadata(
activation_node.meta["val"]
@ -633,18 +633,18 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
divided_target_node_32.meta["val"] = torch.ops.aten.div.Tensor(
activation_node.meta["val"], scale_node.meta["val"]
)
divided_target_node_32.meta[
"tensor_meta"
] = extract_tensor_metadata(divided_target_node_32.meta["val"])
divided_target_node_32.meta["tensor_meta"] = (
extract_tensor_metadata(divided_target_node_32.meta["val"])
)
with graph.inserting_after(divided_target_node_32):
dequant_node = graph.call_function(
torch.ops.prims.convert_element_type.default,
args=(divided_target_node_32, dequant_type),
)
dequant_node.meta[
"val"
] = torch.ops.prims.convert_element_type.default(
divided_target_node_32.meta["val"], dequant_type
dequant_node.meta["val"] = (
torch.ops.prims.convert_element_type.default(
divided_target_node_32.meta["val"], dequant_type
)
)
dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
dequant_node.meta["val"]
@ -656,10 +656,10 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
args=(node, dequant_type),
name="dequant_" + str(node.name),
)
dequant_node.meta[
"val"
] = torch.ops.prims.convert_element_type.default(
node.meta["val"], dequant_type
dequant_node.meta["val"] = (
torch.ops.prims.convert_element_type.default(
node.meta["val"], dequant_type
)
)
dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
dequant_node.meta["val"]

View File

@ -4,6 +4,7 @@
From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0
Try to keep this list in sync with that.
"""
import operator

View File

@ -156,6 +156,9 @@ def call_delegate_functionalize(
)
with ctx.redispatch_to_next():
res = aoti_call_delegate(
lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args # type: ignore[arg-type]
lowered_module,
original_gm,
unwrapped_weight_args, # type: ignore[arg-type]
unwrapped_input_args, # type: ignore[arg-type]
)
return ctx.wrap_tensors(res)

View File

@ -31,9 +31,9 @@ aten = torch._ops.ops.aten
def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
assert (
len(args) == 2 * num_leaves
), f"Combin_fn received wrong number of arguments, expected {2 * num_leaves}, but got {len(args)}"
assert len(args) == 2 * num_leaves, (
f"Combin_fn received wrong number of arguments, expected {2 * num_leaves}, but got {len(args)}"
)
lhs = pytree.tree_unflatten(args[:num_leaves], spec)
rhs = pytree.tree_unflatten(args[num_leaves:], spec)
return combine_fn(lhs, rhs)
@ -79,9 +79,9 @@ class AssociativeScanOp(HigherOrderOperator):
# the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785
# Once this issue is resolved, the assertion should only allow tuples
# and the tuple cast should be removed
assert isinstance(
additional_inputs, (tuple, list)
), "additional_inputs must be a tuple."
assert isinstance(additional_inputs, (tuple, list)), (
"additional_inputs must be a tuple."
)
additional_inputs = (
tuple(additional_inputs)
if isinstance(additional_inputs, list)
@ -134,6 +134,7 @@ def associative_scan(
def add(x: torch.Tensor, y: torch.Tensor):
return x + y
cumsum = associative_scan(add, x, dim)
"""
@ -377,9 +378,9 @@ def trace_associative_scan(
assert outputs is not None
outputs = pytree.tree_leaves(outputs)
assert len(outputs) == len(
xs
), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"
assert len(outputs) == len(xs), (
f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"
)
xs_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
first_slice_copy(x) for x in xs

View File

@ -522,7 +522,8 @@ def do_auto_functionalize(
)
with ctx.redispatch_to_next():
unwrapped_outs = auto_functionalized(
op, **unwrapped_kwargs # type: ignore[arg-type]
op,
**unwrapped_kwargs, # type: ignore[arg-type]
)
# List of the name of args that get mutated (according to the schema)
@ -704,7 +705,8 @@ def do_auto_functionalize_v2(
with ctx.redispatch_to_next():
unwrapped_outs = auto_functionalized_v2(
op, **auto_func_kwargs # type: ignore[arg-type]
op,
**auto_func_kwargs, # type: ignore[arg-type]
)
unwrapped_actual_out: Union[Any, tuple[Any]] = (
@ -716,9 +718,9 @@ def do_auto_functionalize_v2(
)
if isinstance(op, HigherOrderOperator):
assert (
len(schema.returns) > 0
), f"hop is expected to return at least one output {schema}."
assert len(schema.returns) > 0, (
f"hop is expected to return at least one output {schema}."
)
assert len(unwrapped_actual_out) == len(schema.returns)
else:
if len(schema.returns) == 0:


@ -40,11 +40,14 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
def __init__(self):
return super().__init__("invoke_quant")
invoke_quant = InvokeQuant()
def g(x):
return x.sin().cos()
@torch.compile(backend="aot_eager")
def f(x):
return invoke_quant(g, x, scheme="nf4")
@ -113,7 +116,10 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
out = self(subgraph, *operands, **kwargs)
return track_tensor_tree(
out, out_proxy, constant=None, tracer=proxy_mode.tracer # type: ignore[arg-type]
out,
out_proxy,
constant=None,
tracer=proxy_mode.tracer, # type: ignore[arg-type]
)
def _call_FakeTensorMode(self, mode, subgraph, *operands, **kwargs):
@ -230,7 +236,11 @@ class BaseHOPFunction(torch.autograd.Function):
kwargs = ctx.kwargs
# TODO: Something special needs to happen with min cut partitioner
with suspend_functionalization(), disable_functional_mode(), torch.enable_grad():
with (
suspend_functionalization(),
disable_functional_mode(),
torch.enable_grad(),
):
with disable_proxy_modes_tracing():
from .invoke_subgraph import create_fw_bw_graph
from .utils import _from_fun


@ -107,8 +107,12 @@ def cond(
def true_fn(x: torch.Tensor):
return x.cos()
def false_fn(x: torch.Tensor):
return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
Restrictions:
@ -182,7 +186,11 @@ def cond(
def _cond_op_wrapper(*args, **kwargs):
return cond_op(*args, **kwargs)
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
with (
_set_compilation_env(),
torch._dynamo.utils.disable_cache_limit(),
_temp_remove_pre_dispatch_torch_function_mode(),
):
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
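The rewrite above shows how ruff format handles a with statement whose context managers overflow the line: the managers are wrapped in parentheses and given one line each, rather than being left on a single long line. A standalone sketch with hypothetical context managers:

from contextlib import contextmanager

@contextmanager
def entered(name: str):
    # stand-in for _set_compilation_env(), disable_cache_limit(), etc.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

# Parenthesized context managers, one per line, as ruff format lays them out.
with (
    entered("compilation_env"),
    entered("cache_limit"),
    entered("pre_dispatch_torch_function_mode"),
):
    print("inside all three")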
@ -248,9 +256,9 @@ def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(
operands, (list, tuple)
), f"Cond operands must be a list or tuple of tensors and SymInts {operands}"
assert isinstance(operands, (list, tuple)), (
f"Cond operands must be a list or tuple of tensors and SymInts {operands}"
)
true_graph = reenter_make_fx(true_fn)(*operands)
false_graph = reenter_make_fx(false_fn)(*operands)
@ -297,9 +305,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_op_dense(pred, true_fn, false_fn, operands):
assert all(
isinstance(o, (torch.Tensor, int)) for o in operands
), f"Dense implementation operands must be a list of tensors and ints {operands}"
assert all(isinstance(o, (torch.Tensor, int)) for o in operands), (
f"Dense implementation operands must be a list of tensors and ints {operands}"
)
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
if pred:
@ -627,9 +635,9 @@ def _merge_output(
if _maybe_expr(a_val) in a_stride_expr:
a_expr = a_stride_expr[_maybe_expr(a_val)]
assert (
b_stride_expr[_maybe_expr(b_val)] == a_expr
), f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}"
assert b_stride_expr[_maybe_expr(b_val)] == a_expr, (
f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}"
)
merged_strides[i] = a_expr
else:
if a_val == 1:
@ -686,12 +694,12 @@ def cond_func(ctx, pred, true_fn, false_fn, inputs):
@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
assert isinstance(
inputs, (list, tuple)
), "Cond inputs must be a list or tuple of tensors"
assert all(
isinstance(i, torch.Tensor) for i in inputs
), "Cond inputs must be a list of tensors"
assert isinstance(inputs, (list, tuple)), (
"Cond inputs must be a list or tuple of tensors"
)
assert all(isinstance(i, torch.Tensor) for i in inputs), (
"Cond inputs must be a list of tensors"
)
pred_is_batched = isinstance(pred, torch.Tensor) and is_batchedtensor(pred)
pred_ = get_unwrapped(pred) if pred_is_batched else pred


@ -240,9 +240,9 @@ def handle_effects(
key = get_effect_key(op, args, kwargs)
assert key is not None
if key not in tokens:
assert (
allow_token_discovery
), f"Could not find a token for effect {key} which came from the function {op}"
assert allow_token_discovery, (
f"Could not find a token for effect {key} which came from the function {op}"
)
proxy_tensor_mode = torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
)


@ -49,7 +49,10 @@ def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
return e
return get_proxy_slot(
cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy # type: ignore[attr-defined]
cast(torch.Tensor, e),
proxy_mode.tracer,
e,
lambda e: e.proxy, # type: ignore[attr-defined]
)
if not is_lowered_module(lowered_module):


@ -38,9 +38,9 @@ def _construct_strides(
) -> Sequence[int]:
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
# Initialize strides
assert len(sizes) == len(
fill_order
), "Length of sizes must match the length of the fill order"
assert len(sizes) == len(fill_order), (
"Length of sizes must match the length of the fill order"
)
strides = [0] * len(sizes)
# Start with stride 1 for the innermost dimension
@ -594,7 +594,7 @@ def create_fw_bw_graph(
*other_buffers: tuple[Tensor, ...],
) -> tuple[Tensor, ...]:
def fw_with_masks(
*args: tuple[Tensor, ...]
*args: tuple[Tensor, ...],
) -> tuple[tuple[Tensor], tuple[bool]]:
fw_out = score_mod(*args)
out_requires_grad = fw_out.requires_grad
@ -633,9 +633,9 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
for buffer in mask_mod_other_buffers
if isinstance(buffer, torch.Tensor)
)
assert (
not any_buffer_requires_grad
), "Captured buffers from mask mod that require grad are not supported."
assert not any_buffer_requires_grad, (
"Captured buffers from mask mod that require grad are not supported."
)
ctx._fw_graph = fw_graph
ctx._joint_graph = joint_graph
ctx._mask_graph = block_mask[-1]
@ -671,7 +671,11 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
return out, logsumexp
@staticmethod
def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]: # type: ignore[override]
def backward( # type: ignore[override]
ctx: Any,
grad_out: Tensor,
grad_logsumexp: Tensor,
) -> tuple[Optional[Tensor], ...]:
fw_args = saved_tensors_and_symints(ctx)
(
query,
@ -939,9 +943,9 @@ def sdpa_dense_backward(
actual_grad_value.copy_(grad_value)
if Bq != Bkv:
assert (
Bq > 1 and Bkv == 1
), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}"
assert Bq > 1 and Bkv == 1, (
f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}"
)
actual_grad_key = torch.sum(actual_grad_key, 0, keepdim=True)
actual_grad_value = torch.sum(actual_grad_value, 0, keepdim=True)


@ -38,8 +38,7 @@ class HintsWrapper(HigherOrderOperator):
if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args):
raise RuntimeError(
"args must be a tuple of tensors, ints, floats, or bools, got "
f"{args}"
f"args must be a tuple of tensors, ints, floats, or bools, got {args}"
)
if not isinstance(kwargs, dict):


@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import contextlib
from contextlib import nullcontext
from dataclasses import dataclass, field
@ -70,9 +69,9 @@ class InvokeSubgraphHOP(HigherOrderOperator):
identifier: Optional[str],
*operands,
):
assert identifier is None or isinstance(
identifier, str
), "identifier must be a None or a string"
assert identifier is None or isinstance(identifier, str), (
"identifier must be a None or a string"
)
assert all(
isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
@ -128,7 +127,11 @@ def invoke_subgraph_placeholder(func, *args, **kwargs):
def _invoke_subgraph_placeholder_wrapper(func, args):
return invoke_subgraph_placeholder(func, *args)
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
with (
_set_compilation_env(),
torch._dynamo.utils.disable_cache_limit(),
_temp_remove_pre_dispatch_torch_function_mode(),
):
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend = make_eager_backend_with_torch_function_mode(metadata_mode)


@ -36,9 +36,9 @@ aten = torch._ops.ops.aten
def wrap_combine_fn_flat(
*args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves
):
assert len(args) == (
num_init_leaves + num_inp_leaves
), f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
assert len(args) == (num_init_leaves + num_inp_leaves), (
f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
)
carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init)
xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs)
return combine_fn(carry, xs)
@ -73,13 +73,13 @@ def mask_list(
# If other is None, then the elements of the `inp` list where the mask is False are removed
# If other is not None, then the elements of the `inp` list where the mask is False are
# replaced with the elements of the `other` list
assert len(mask) == len(
inp
), "The length of the mask needs to be identical to the length of the input"
assert len(mask) == len(inp), (
"The length of the mask needs to be identical to the length of the input"
)
if other is not None:
assert len(inp) == len(
other
), "If an input and an other list is provided, they need to have the same length"
assert len(inp) == len(other), (
"If an input and an other list is provided, they need to have the same length"
)
return [i if m else o for m, i, o in zip(mask, inp, other)]
else:
return [i for m, i in zip(mask, inp) if m]
@ -97,9 +97,9 @@ def first_slice_copy_with_grad(li: list[Any]) -> list[Any]:
def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]:
it = iter(iterable)
assert sum(chunk_sizes) == len(
iterable
), "the sum of all chunks needs to match the length of the iterable."
assert sum(chunk_sizes) == len(iterable), (
"the sum of all chunks needs to match the length of the iterable."
)
return [list(itertools.islice(it, size)) for size in chunk_sizes]
@ -166,6 +166,7 @@ def scan(
# clone the output to avoid output-output aliasing
return next_carry, y.clone()
i0 = torch.zeros(1)
xs = torch.arange(5)
# returns torch.tensor([10.]), torch.tensor([[0], [1.], [3.], [6.], [10.]])
@ -262,9 +263,9 @@ class ScanOp(HigherOrderOperator):
# the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785
# Once this issue is resolved, the assertion should only allow tuples
# and the tuple cast should be removed
assert isinstance(
additional_inputs, (tuple, list)
), "additional_inputs must be a tuple."
assert isinstance(additional_inputs, (tuple, list)), (
"additional_inputs must be a tuple."
)
additional_inputs = (
tuple(additional_inputs)
if isinstance(additional_inputs, list)


@ -35,9 +35,9 @@ class HopArgumentInfoGen:
kw_only: bool = False,
) -> HopArgumentInfo:
if default_value is not None:
assert type(example_value) == type(
default_value
), f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}"
assert type(example_value) == type(default_value), (
f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}"
)
return HopArgumentInfo(
name=name,
@ -207,12 +207,12 @@ class CFunctionSchemaGen:
args.append(CArgumentGen.from_hop_argument_info(i, arg_info))
# NOTE: we want the output to always be a single argument with torch._C.TupleType.
assert isinstance(
out_argument_info.example_value, tuple
), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}"
assert (
not out_argument_info.is_mutated
), "out_argument_info.is_mutated should always be set to False."
assert isinstance(out_argument_info.example_value, tuple), (
f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}"
)
assert not out_argument_info.is_mutated, (
"out_argument_info.is_mutated should always be set to False."
)
rets = None
if len(out_argument_info.example_value) == 1:
rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)]


@ -81,9 +81,9 @@ def enable_torchbind_tracing():
torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign]
yield
finally:
assert (
KNOWN_TYPES.pop() is torch.ScriptObject
), "Someone else messed with KNOWN_TYPES during tracing, exploding."
assert KNOWN_TYPES.pop() is torch.ScriptObject, (
"Someone else messed with KNOWN_TYPES during tracing, exploding."
)
torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign]
@ -127,9 +127,9 @@ def inner(mode, *args, **kwargs):
ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
if "val" not in out_proxy.node.meta:
assert out is None or isinstance(
out, (int, float, bool)
), "Currently, only these constant dtypes are supported to be returned from torchbind methods."
assert out is None or isinstance(out, (int, float, bool)), (
"Currently, only these constant dtypes are supported to be returned from torchbind methods."
)
out_proxy.node.meta["val"] = out
return ret


@ -93,7 +93,7 @@ def create_tma_experimental_metadata(
def maybe_unpack_tma_experimental_metadata(
tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata]
tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata],
) -> Optional[tuple[list[IntLikeType], list[IntLikeType], IntLikeType]]:
if not tma_meta or len(tma_meta) != 2:
return None
@ -109,7 +109,7 @@ def create_tma_stable_metadata(
def maybe_unpack_tma_stable_metadata(
tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata]
tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata],
) -> Optional[tuple[list[IntLikeType]]]:
if not tma_meta or len(tma_meta) != 2:
return None
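The two signature changes above add a trailing comma to a parameter that already sits alone on its own line; this "magic trailing comma" keeps the signature expanded on later formatting runs. A sketch under assumed, simplified types (the real TMA metadata classes are not reproduced here):

from typing import Optional, Union

MetaA = tuple  # hypothetical stand-ins for the real metadata types
MetaB = list

def maybe_unpack(
    meta: Union[MetaA, MetaB],  # trailing comma added by the new formatting
) -> Optional[int]:
    if not meta or len(meta) != 2:
        return None
    return len(meta)

print(maybe_unpack((1, 2)))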
@ -1122,7 +1122,8 @@ def trace_triton_kernel_wrapper(
out = func_overload(**node_args)
proxy_args = pytree.tree_map(
proxy_mode.tracer.unwrap_proxy, node_args # type: ignore[union-attr]
proxy_mode.tracer.unwrap_proxy, # type: ignore[union-attr]
node_args,
)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function",
@ -1660,9 +1661,9 @@ class TritonHOPifier:
# Update the kwargs in each config
# maybe_unpack_heuristic_result raises unsupported if the value is non-constant
new_configs[config_idx].__dict__["kwargs"][
kwarg_key
] = self.maybe_unpack_heuristic_result(heuristic_result)
new_configs[config_idx].__dict__["kwargs"][kwarg_key] = (
self.maybe_unpack_heuristic_result(heuristic_result)
)
iter_kernel = iter_kernel.fn
assert isinstance(iter_kernel, JITFunction)
@ -1742,9 +1743,9 @@ class TritonHOPifier:
for config in new_configs:
for name in special_param_names:
if name not in config.__dict__["kwargs"]:
assert (
name in config.__dict__
), f"{name} must be in autotuning configs to be used as a kernel parameter"
assert name in config.__dict__, (
f"{name} must be in autotuning configs to be used as a kernel parameter"
)
config.__dict__["kwargs"][name] = config.__dict__[name]
updated = True


@ -115,9 +115,9 @@ def reenter_make_fx(fn):
@functools.wraps(fn)
def wrapped(*args):
assert (
_CURRENT_MAKE_FX_TRACER is not None
), "Cannot reenter make_fx when we're not under a make_fx tracing session"
assert _CURRENT_MAKE_FX_TRACER is not None, (
"Cannot reenter make_fx when we're not under a make_fx tracing session"
)
return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
_maybe_run_with_interpreter(fn), *args
)
@ -323,20 +323,22 @@ def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations):
def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False):
(
_,
_,
_,
), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
(_, _, _),
inp_mutation,
) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
return len(inp_mutation) > 0
def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
(
inp_inp_alias_map,
inp_out_alias_map,
out_out_alias_map,
), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
(
inp_inp_alias_map,
inp_out_alias_map,
out_out_alias_map,
),
inp_mutation,
) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
return (
any(
(
@ -392,9 +394,7 @@ def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch):
graph_module, inputs_fake, pre_dispatch=pre_dispatch
)
if aliases:
raise RuntimeError(
f"{name} might be aliasing the input or the output!"
) # noqa: F541
raise RuntimeError(f"{name} might be aliasing the input or the output!") # noqa: F541
if inp_mutation:
raise RuntimeError(f"{name} might be modifying the input!") # noqa: F541
@ -504,9 +504,9 @@ def prepare_fw_with_masks_all_requires_grad(fn):
# replaced with an all-zero tensor for better optimization
def unmask_none_gradients(grads, operands):
allowed_types = (torch.Tensor, int, torch.SymInt)
assert all(
isinstance(o, allowed_types) for o in operands
), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"
assert all(isinstance(o, allowed_types) for o in operands), (
f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"
)
unmasked_grads = []
for g, o in zip(grads, operands):
@ -762,7 +762,9 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])
allowed_types = (torch.Tensor, int, torch.SymInt)
assert all(
isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args
), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
), (
f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
)
# TODO: Return a more detailed information as to which node
@ -793,7 +795,11 @@ def check_input_alias_and_mutation_return_outputs(
# This function can be called under autograd, functional, proxy and fake tensor mode.
# We need to return either a fake tensor or a real tensor depending on the mode.
# to detect the input mutation/aliasing.
with disable_proxy_modes_tracing(), disable_functional_mode(), suspend_functionalization():
with (
disable_proxy_modes_tracing(),
disable_functional_mode(),
suspend_functionalization(),
):
def _from_functional_tensor(t: torch.Tensor) -> torch.Tensor:
if isinstance(t, FunctionalTensor) or torch._is_functional_tensor(t):
@ -928,13 +934,11 @@ F = TypeVar("F", bound=Callable)
@overload
def register_fake(hop, fn: None = None) -> Callable[[F], F]:
...
def register_fake(hop, fn: None = None) -> Callable[[F], F]: ...
@overload
def register_fake(hop, fn: F) -> F:
...
def register_fake(hop, fn: F) -> F: ...
def register_fake(hop, fn=None):
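The hunk above shows another recurring change: functions whose body is just an ellipsis, typically @overload stubs, have the ... collapsed onto the def line. A brief runnable sketch with a hypothetical function:

from typing import Union, overload

@overload
def scale(x: int) -> int: ...
@overload
def scale(x: float) -> float: ...

def scale(x: Union[int, float]) -> Union[int, float]:
    # single runtime implementation behind the two typing-only overloads
    return x * 2

print(scale(3), scale(1.5))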


@ -202,12 +202,12 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
while pred := cond_fn(*carried_vals, *additional_inputs):
_validate_cond_output(pred)
out = body_fn(*carried_vals, *additional_inputs)
assert isinstance(
out, tuple
), f"body_fn should return a tuple but got {type(out)}"
assert len(out) == len(
carried_inputs
), "body_fn should return the same number of elements as carried_inputs"
assert isinstance(out, tuple), (
f"body_fn should return a tuple but got {type(out)}"
)
assert len(out) == len(carried_inputs), (
"body_fn should return the same number of elements as carried_inputs"
)
carried_vals = out
return carried_vals
@ -230,9 +230,9 @@ def _find_or_create_fake_mode() -> FakeTensorMode:
def _create_unbacked_symint(
fake_mode: FakeTensorMode, ignore_fresh_unbacked_symbols: bool
) -> torch.SymInt:
assert (
fake_mode is not None and fake_mode.shape_env is not None
), "Must provide a fake_mode with shape_env."
assert fake_mode is not None and fake_mode.shape_env is not None, (
"Must provide a fake_mode with shape_env."
)
ctx = (
contextlib.nullcontext()
if not ignore_fresh_unbacked_symbols


@ -28,8 +28,7 @@ def custom_op(
mutates_args: Union[str, Iterable[str]],
device_types: device_types_t = None,
schema: Optional[str] = None,
) -> Callable[[Callable[..., object]], "CustomOpDef"]:
...
) -> Callable[[Callable[..., object]], "CustomOpDef"]: ...
@overload
@ -41,8 +40,7 @@ def custom_op(
mutates_args: Union[str, Iterable[str]],
device_types: device_types_t = None,
schema: Optional[str] = None,
) -> "CustomOpDef":
...
) -> "CustomOpDef": ...
@exposed_in("torch.library")
@ -448,10 +446,10 @@ class CustomOpDef:
>>>
>>> @nonzero.register_fake
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an abstract impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an abstract impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>> ctx = torch.library.get_ctx()
>>> nnz = ctx.new_dynamic_size()
>>> shape = [nnz, x.dim()]
@ -561,7 +559,7 @@ class CustomOpDef:
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
@ -581,7 +579,7 @@ class CustomOpDef:
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
"""
@ -919,7 +917,7 @@ def get_library_allowing_overwrite(
def _maybe_get_opdef(
op: Union[CustomOpDef, _ops.OpOverload, str]
op: Union[CustomOpDef, _ops.OpOverload, str],
) -> Optional[CustomOpDef]:
if isinstance(op, CustomOpDef):
return op


@ -150,13 +150,13 @@ def infer_schema(
"the arguments that are mutated or the string 'unknown'. "
)
if schema_type.startswith("Tensor"):
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
elif name in mutates_args:
if not schema_type.startswith("Tensor"):
error_fn(
f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
)
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
seen_args.add(name)
if param.default is inspect.Parameter.empty:
params.append(f"{schema_type} {name}")
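The change above puts a space before the colon in schema_type[len('Tensor') :]: when a slice bound is a non-trivial expression such as a call, the formatter treats the colon like a binary operator and spaces it symmetrically, following PEP 8. That the expression sits inside an f-string suggests ruff is formatting f-string interiors here as well, which the previous black configuration appears to have left untouched. A minimal sketch of the slicing rule on a hypothetical string:

schema_type = "Tensor(a0!) self"

head = schema_type[:6]               # simple bound: compact form
tail = schema_type[len("Tensor") :]  # complex bound: spaced colon

print(head, "|", tail)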


@ -1242,12 +1242,12 @@ def trace_structured(
"frame_compile_id",
"attempt",
]
assert callable(
metadata_fn
), f"metadata_fn should be callable, but got {type(metadata_fn)}"
assert callable(
payload_fn
), f"payload_fn should be callable, but got {type(payload_fn)}"
assert callable(metadata_fn), (
f"metadata_fn should be callable, but got {type(metadata_fn)}"
)
assert callable(payload_fn), (
f"payload_fn should be callable, but got {type(payload_fn)}"
)
# trace_log never propagates and is ALWAYS DEBUG, so also check that there
# are handlers instead of checking the log level
if trace_log.handlers:


@ -1,6 +1,7 @@
"""
Utilities for converting data types into structured JSON for dumping.
"""
import inspect
import os
import traceback


@ -1,8 +1,9 @@
# mypy: ignore-errors
""" Define analogs of numpy dtypes supported by pytorch.
"""Define analogs of numpy dtypes supported by pytorch.
Define the scalar types and supported dtypes and numpy <--> torch dtype mappings.
"""
import builtins
import torch


@ -5,6 +5,7 @@
Here `dtype` is always a torch.dtype, this module knows nothing about
scalar types, wrapper dtypes or anything like that. PyTorch only.
"""
from collections import namedtuple
import torch


@ -5,6 +5,7 @@
Things imported from here have numpy-compatible signatures but operate on
pytorch tensors.
"""
# Contents of this module ends up in the main namespace via _funcs.py
# where type annotations are used in conjunction with the @normalizer decorator.
from __future__ import annotations


@ -1,7 +1,7 @@
# mypy: ignore-errors
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
"""
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on."""
from __future__ import annotations
import functools


@ -1,10 +1,11 @@
# mypy: ignore-errors
""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc
"""Implementation of reduction operations, to be wrapped into arrays, dtypes etc
in the 'public' layer.
Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
"""
from __future__ import annotations
import functools


@ -1,7 +1,6 @@
# mypy: ignore-errors
"""Assorted utilities, which do not need anything other then torch and stdlib.
"""
"""Assorted utilities, which do not need anything other then torch and stdlib."""
import operator
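Several docstring-only hunks above normalize module docstrings: the stray space after the opening triple quotes is dropped, and a short docstring whose closing quotes sat on a separate line is collapsed onto one line. A tiny sketch of the normalized shape (hypothetical module, not from this diff):

"""Assorted helpers for illustration; mirrors the normalized docstring layout."""

import operator

def add_one(x: int) -> int:
    """Return x + 1 (single-line docstring, no space after the opening quotes)."""
    return operator.add(x, 1)

print(add_one(41))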


@ -7,6 +7,7 @@ NumPy has strict guarantees on reproducibility etc; here we don't give any.
Q: default dtype is float64 in numpy
"""
from __future__ import annotations
import functools


@ -4,6 +4,7 @@
Utility function to facilitate testing.
"""
import contextlib
import gc
import operator
@ -167,7 +168,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True):
Examples
--------
>>> np.testing.assert_equal([4,5], [4,6])
>>> np.testing.assert_equal([4, 5], [4, 6])
Traceback (most recent call last):
...
AssertionError:
@ -298,8 +299,12 @@ def print_assert_equal(test_string, actual, desired):
Examples
--------
>>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]) # doctest: +SKIP
>>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2]) # doctest: +SKIP
>>> np.testing.print_assert_equal(
... "Test XYZ of func xyz", [0, 1], [0, 1]
... ) # doctest: +SKIP
>>> np.testing.print_assert_equal(
... "Test XYZ of func xyz", [0, 1], [0, 2]
... ) # doctest: +SKIP
Traceback (most recent call last):
...
AssertionError: Test XYZ of func xyz failed
@ -377,8 +382,9 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True):
ACTUAL: 2.3333333333333
DESIRED: 2.33333334
>>> assert_almost_equal(np.array([1.0,2.3333333333333]),
... np.array([1.0,2.33333334]), decimal=9)
>>> assert_almost_equal(
... np.array([1.0, 2.3333333333333]), np.array([1.0, 2.33333334]), decimal=9
... )
Traceback (most recent call last):
...
AssertionError:
@ -487,11 +493,19 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True
Examples
--------
>>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) # doctest: +SKIP
>>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, # doctest: +SKIP
... significant=8)
>>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, # doctest: +SKIP
... significant=8)
>>> np.testing.assert_approx_equal(
... 0.12345677777777e-20, 0.1234567e-20
... ) # doctest: +SKIP
>>> np.testing.assert_approx_equal(
... 0.12345670e-20,
... 0.12345671e-20, # doctest: +SKIP
... significant=8,
... )
>>> np.testing.assert_approx_equal(
... 0.12345670e-20,
... 0.12345672e-20, # doctest: +SKIP
... significant=8,
... )
Traceback (most recent call last):
...
AssertionError:
@ -501,7 +515,7 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True
the evaluated condition that raises the exception is
>>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
>>> abs(0.12345670e-20 / 1e-21 - 0.12345672e-20 / 1e-21) >= 10 ** -(8 - 1)
True
"""
@ -776,15 +790,16 @@ def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False):
--------
The first assert does not raise an exception:
>>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
... [np.exp(0),2.33333, np.nan])
>>> np.testing.assert_array_equal(
... [1.0, 2.33333, np.nan], [np.exp(0), 2.33333, np.nan]
... )
Use `assert_allclose` or one of the nulp (number of floating point values)
functions for these cases instead:
>>> np.testing.assert_allclose([1.0,np.pi,np.nan],
... [1, np.sqrt(np.pi)**2, np.nan],
... rtol=1e-10, atol=0)
>>> np.testing.assert_allclose(
... [1.0, np.pi, np.nan], [1, np.sqrt(np.pi) ** 2, np.nan], rtol=1e-10, atol=0
... )
As mentioned in the Notes section, `assert_array_equal` has special
handling for scalars. Here the test checks that each value in `x` is 3:
@ -809,7 +824,7 @@ def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False):
The `strict` parameter also ensures that the array data types match:
>>> x = np.array([2, 2, 2])
>>> y = np.array([2., 2., 2.], dtype=np.float32)
>>> y = np.array([2.0, 2.0, 2.0], dtype=np.float32)
>>> np.testing.assert_array_equal(x, y, strict=True)
Traceback (most recent call last):
...
@ -881,11 +896,11 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True):
--------
the first assert does not raise an exception
>>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
... [1.0,2.333,np.nan])
>>> np.testing.assert_array_almost_equal([1.0, 2.333, np.nan], [1.0, 2.333, np.nan])
>>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
... [1.0,2.33339,np.nan], decimal=5)
>>> np.testing.assert_array_almost_equal(
... [1.0, 2.33333, np.nan], [1.0, 2.33339, np.nan], decimal=5
... )
Traceback (most recent call last):
...
AssertionError:
@ -897,8 +912,9 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True):
x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64)
y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64)
>>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
... [1.0,2.33333, 5], decimal=5)
>>> np.testing.assert_array_almost_equal(
... [1.0, 2.33333, np.nan], [1.0, 2.33333, 5], decimal=5
... )
Traceback (most recent call last):
...
AssertionError:
@ -1054,8 +1070,8 @@ def assert_string_equal(actual, desired):
Examples
--------
>>> np.testing.assert_string_equal('abc', 'abc') # doctest: +SKIP
>>> np.testing.assert_string_equal('abc', 'abcd') # doctest: +SKIP
>>> np.testing.assert_string_equal("abc", "abc") # doctest: +SKIP
>>> np.testing.assert_string_equal("abc", "abcd") # doctest: +SKIP
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
...
@ -1341,11 +1357,11 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
Examples
--------
>>> x = np.array([1., 1e-10, 1e-20])
>>> x = np.array([1.0, 1e-10, 1e-20])
>>> eps = np.finfo(x.dtype).eps
>>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) # doctest: +SKIP
>>> np.testing.assert_array_almost_equal_nulp(x, x * eps / 2 + x) # doctest: +SKIP
>>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) # doctest: +SKIP
>>> np.testing.assert_array_almost_equal_nulp(x, x * eps + x) # doctest: +SKIP
Traceback (most recent call last):
...
AssertionError: X and Y are not equal to 1 ULP (max is 2)
@ -1404,7 +1420,7 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
Examples
--------
>>> a = np.linspace(0., 1., 100)
>>> a = np.linspace(0.0, 1.0, 100)
>>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP
"""
@ -1562,7 +1578,7 @@ def assert_warns(warning_class, *args, **kwargs):
>>> import warnings
>>> def deprecated_func(num):
... warnings.warn("Please upgrade", DeprecationWarning)
... return num*num
... return num * num
>>> with np.testing.assert_warns(DeprecationWarning):
... assert deprecated_func(4) == 16
>>> # or passing a func
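The docstring hunks in this file also reformat the code inside doctests ([4,5] becomes [4, 5], long calls are re-wrapped under ... continuation lines), which suggests docstring code formatting is enabled in the ruff configuration; the configuration itself is not visible in this part of the diff, so that is an assumption. A small runnable doctest written in the re-wrapped style:

def almost_equal(a: float, b: float, tol: float = 1e-9) -> bool:
    """Return True when a and b differ by less than tol.

    >>> almost_equal(
    ...     0.12345670e-20, 0.12345671e-20, tol=1e-27
    ... )
    True
    """
    return abs(a - b) < tol

if __name__ == "__main__":
    import doctest

    doctest.testmod()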
@ -1663,19 +1679,29 @@ def _gen_alignment_data(dtype=float32, type="binary", max_size=24):
yield out, inp(), ufmt % (o, o, s, dtype, "out of place")
d = inp()
yield d, d, ufmt % (o, o, s, dtype, "in place")
yield out[1:], inp()[:-1], ufmt % (
o + 1,
o,
s - 1,
dtype,
"out of place",
yield (
out[1:],
inp()[:-1],
ufmt
% (
o + 1,
o,
s - 1,
dtype,
"out of place",
),
)
yield out[:-1], inp()[1:], ufmt % (
o,
o + 1,
s - 1,
dtype,
"out of place",
yield (
out[:-1],
inp()[1:],
ufmt
% (
o,
o + 1,
s - 1,
dtype,
"out of place",
),
)
yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased")
yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased")
@ -1691,53 +1717,89 @@ def _gen_alignment_data(dtype=float32, type="binary", max_size=24):
yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1")
d = inp2()
yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2")
yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % (
o + 1,
o,
o,
s - 1,
dtype,
"out of place",
yield (
out[1:],
inp1()[:-1],
inp2()[:-1],
bfmt
% (
o + 1,
o,
o,
s - 1,
dtype,
"out of place",
),
)
yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % (
o,
o + 1,
o,
s - 1,
dtype,
"out of place",
yield (
out[:-1],
inp1()[1:],
inp2()[:-1],
bfmt
% (
o,
o + 1,
o,
s - 1,
dtype,
"out of place",
),
)
yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % (
o,
o,
o + 1,
s - 1,
dtype,
"out of place",
yield (
out[:-1],
inp1()[:-1],
inp2()[1:],
bfmt
% (
o,
o,
o + 1,
s - 1,
dtype,
"out of place",
),
)
yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % (
o + 1,
o,
o,
s - 1,
dtype,
"aliased",
yield (
inp1()[1:],
inp1()[:-1],
inp2()[:-1],
bfmt
% (
o + 1,
o,
o,
s - 1,
dtype,
"aliased",
),
)
yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % (
o,
o + 1,
o,
s - 1,
dtype,
"aliased",
yield (
inp1()[:-1],
inp1()[1:],
inp2()[:-1],
bfmt
% (
o,
o + 1,
o,
s - 1,
dtype,
"aliased",
),
)
yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % (
o,
o,
o + 1,
s - 1,
dtype,
"aliased",
yield (
inp1()[:-1],
inp1()[:-1],
inp2()[1:],
bfmt
% (
o,
o,
o + 1,
s - 1,
dtype,
"aliased",
),
)
@ -1818,9 +1880,10 @@ class clear_and_catch_warnings(warnings.catch_warnings):
--------
>>> import warnings
>>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP
... modules=[np.core.fromnumeric]):
... warnings.simplefilter('always')
... warnings.filterwarnings('ignore', module='np.core.fromnumeric')
... modules=[np.core.fromnumeric]
... ):
... warnings.simplefilter("always")
... warnings.filterwarnings("ignore", module="np.core.fromnumeric")
... # do something that raises a warning but ignore those in
... # np.core.fromnumeric
"""
@ -1918,6 +1981,8 @@ class suppress_warnings:
sup = np.testing.suppress_warnings()
sup.filter(module=np.ma.core) # module must match exactly
@sup
def some_function():
# do something which causes a warning in np.ma.core


@ -2513,7 +2513,11 @@ def _full_aten(
) -> Tensor:
# Note that Mypy thinks torch.full can't accept a complex fill_value
return torch.full(
shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type]
shape,
fill_value,
dtype=dtype,
device=device,
requires_grad=requires_grad, # type: ignore[arg-type]
)
@ -2556,7 +2560,11 @@ def _full_like_aten(
) -> Tensor:
# Note that Mypy thinks torch.full can't accept a complex fill_value
return torch.full_like(
a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type]
a,
fill_value,
dtype=dtype,
device=device,
requires_grad=requires_grad, # type: ignore[arg-type]
)


@ -340,9 +340,9 @@ def register_graphsafe_run_with_rng_state_op():
@graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(op, *args, rng_state=None, **kwargs):
device = get_device(args, kwargs)
assert (
device == "cuda"
), f"GraphSafe RNG operations only supported for CUDA, got {device}"
assert device == "cuda", (
f"GraphSafe RNG operations only supported for CUDA, got {device}"
)
return impl_cuda(op, *args, rng_state=rng_state, **kwargs)
@graphsafe_run_with_rng_state.py_impl(FakeTensorMode)


@ -33,17 +33,13 @@ if TYPE_CHECKING:
import sympy
class _WorksWithInt(typing.Protocol):
def __add__(self, other: Any) -> typing.Self:
...
def __add__(self, other: Any) -> typing.Self: ...
def __radd__(self, other: Any) -> typing.Self:
...
def __radd__(self, other: Any) -> typing.Self: ...
def __mul__(self, other: Any) -> typing.Self:
...
def __mul__(self, other: Any) -> typing.Self: ...
def __rmul__(self, other: Any) -> typing.Self:
...
def __rmul__(self, other: Any) -> typing.Self: ...
_IntLikeT = TypeVar("_IntLikeT", bound=_WorksWithInt)
@ -292,9 +288,7 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
# can assume x is not 0 in expected_stride equation. This make the check consistent with
# make_contiguous_strides_for. If we make a tensor and used strides from make_contiguous_strides_for
# and then called definitely_contiguous we should get True.
expected_stride *= (
x if is_nested_int(x) else sym_max(x, 1)
) # type:ignore[assignment]
expected_stride *= x if is_nested_int(x) else sym_max(x, 1) # type:ignore[assignment]
return True
@ -912,7 +906,7 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
# Extracts dimensions that might be passed either as a list/tuple or as varargs.
# A typical case is Tensor.permute .
def extract_dims_from_varargs(
dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]]
dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]],
) -> DimsSequenceType:
if dims and isinstance(dims[0], Sequence):
assert len(dims) == 1
@ -1234,7 +1228,7 @@ def get_higher_dtype(
assert b is None or isinstance(b, (torch.dtype, TensorLike, Number))
def _extract_dtype(
x: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
x: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
) -> Optional[torch.dtype]:
if x is None:
return None
@ -1452,7 +1446,7 @@ class RETURN_TYPE(Enum):
# TODO: when NumberType contains the sym types, can simplify this
def number_type(
x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool]
x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool],
) -> type:
if isinstance(x, torch.SymInt):
return int
@ -1708,9 +1702,7 @@ def make_contiguous_strides_for(
strides = []
for l in reversed(shape):
strides.append(multiplier)
multiplier *= (
l if is_nested_int(l) else sym_max(l, 1)
) # type:ignore[assignment]
multiplier *= l if is_nested_int(l) else sym_max(l, 1) # type:ignore[assignment]
result = tuple(reversed(strides))
@ -1860,7 +1852,9 @@ def compute_required_storage_length(
>>> # xdoctest: +SKIP(failing)
>>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
>>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
>>> size = compute_required_storage_length(
... t2.shape, t2.stride(), t2.storage_offset()
... )
>>> size == t.storage().size()
True
@ -1870,7 +1864,9 @@ def compute_required_storage_length(
>>> slice.storage().size()
100
>>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
>>> compute_required_storage_length(
... slice.shape, slice.stride(), slice.storage_offset()
... )
40
"""


@ -316,8 +316,7 @@ def out_wrapper(
and len(result) == len(out_names) # type: ignore[arg-type]
)
or (
fn.__name__ == "unbind"
and isinstance(result, (list, tuple)) # type: ignore[arg-type]
fn.__name__ == "unbind" and isinstance(result, (list, tuple)) # type: ignore[arg-type]
)
)
# unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
@ -342,9 +341,15 @@ def out_wrapper(
assert isinstance(out, TensorLike)
# These two operations are done in-place
_maybe_resize_out(
out, result.shape, maybe_compute_memory_format(result) # type: ignore[union-attr]
out,
result.shape, # type: ignore[union-attr]
maybe_compute_memory_format(result),
)
_safe_copy_out(
copy_from=result, # type: ignore[arg-type]
copy_to=out,
exact_dtype=exact_dtype,
)
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type]
else:
if fn.__name__ != "unbind":
assert isinstance(out, tuple) # type: ignore[arg-type]
@ -385,7 +390,8 @@ def out_wrapper(
params = sorted(params, key=lambda p: p.kind)
_fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=params, return_annotation=return_type # type: ignore[arg-type]
parameters=params,
return_annotation=return_type, # type: ignore[arg-type]
)
_fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
@ -400,7 +406,9 @@ def out_wrapper(
# Add an indicator attribute that can be used in special cases
# where having a function wrapped by `out_wrapper` is not desirable e.g.
# jit
_fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" # type: ignore[attr-defined]
_fn._torch_decompositions_out_wrapper = ( # type: ignore[attr-defined]
f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"
)
return _fn


@ -313,7 +313,8 @@ def _canonicalize_fft_shape_and_dim_args(
# Translate any -1 values in shape to the default length
ret_shape = tuple(
s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined]
s if s != -1 else input_sizes[d]
for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined]
)
elif dim is None:
# No shape, no dim


@ -310,7 +310,7 @@ def strobelight(
profiler = StrobelightCLIFunctionProfiler(**kwargs)
def strobelight_inner(
work_function: Callable[_P, _R]
work_function: Callable[_P, _R],
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:


@ -129,9 +129,9 @@ def _is_tensor_constructor(func: OpOverload):
def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
def impl_decorator(op_impl):
if isinstance(run_impl_check, OpOverload):
assert (
run_impl_check not in op_implementations_dict
), f"duplicate registration: {run_impl_check}"
assert run_impl_check not in op_implementations_dict, (
f"duplicate registration: {run_impl_check}"
)
op_implementations_dict[run_impl_check] = op_impl
elif isinstance(run_impl_check, (list, tuple)):
for op in run_impl_check:
@ -575,25 +575,25 @@ def assert_tensor_metadata(
layout=None,
) -> None:
if sizes is not None:
assert (
t.size() == sizes
), f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}"
assert t.size() == sizes, (
f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}"
)
if strides is not None:
assert (
t.stride() == strides
), f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}"
assert t.stride() == strides, (
f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}"
)
if dtype is not None:
assert (
t.dtype == dtype
), f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}"
assert t.dtype == dtype, (
f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}"
)
if layout is not None:
assert (
t.layout == layout
), f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}"
assert t.layout == layout, (
f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}"
)
if device is not None:
assert (
t.device == device
), f"Tensor device mismatch! Expected: {device}, Got: {t.device}"
assert t.device == device, (
f"Tensor device mismatch! Expected: {device}, Got: {t.device}"
)
# NB: this must be ordered after local_scalar_dense
@ -1091,7 +1091,9 @@ def get_fast_op_impls():
register_fast_op_impl(torch.ops.aten.sub.Tensor)(
make_fast_binary_impl(torch._refs.sub)
)
register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type]
register_fast_op_impl(torch.ops.aten.mul.Tensor)(
make_fast_binary_impl(torch._refs.mul)
) # type: ignore[has-type]
register_fast_op_impl(torch.ops.aten.div.Tensor)(
make_fast_binary_impl(
torch._refs.div,


@ -496,9 +496,9 @@ class FakeTensorConverter:
pytype: Optional[type[torch.Tensor]] = None,
dispatch_keys: Optional[torch.DispatchKeySet] = None,
) -> FakeTensor:
assert (
t.device.type == "meta"
), f"tensor's device must be `meta`, got {t.device.type} instead"
assert t.device.type == "meta", (
f"tensor's device must be `meta`, got {t.device.type} instead"
)
# This is a bit abusive (this is not the "real" tensor) but whatever,
# the meta tensor should be fresh so there's no way to get it wrong
maybe_memo = self._get_memo(t)
@ -1594,7 +1594,10 @@ class FakeTensorMode(TorchDispatchMode):
if torch.Tag.dynamic_output_shape in func.tags:
if func is aten.index.Tensor:
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
func,
args=args, # type: ignore[arg-type]
kwargs=kwargs, # type: ignore[arg-type]
normalize_to_only_use_kwargs=True,
)
for index in new_kwargs["indices"]:
# index calls nonzero for bool or int8 tensors, and
@ -2136,9 +2139,7 @@ class FakeTensorMode(TorchDispatchMode):
try:
_check_fake_real_vals(s_fake, s_real)
except MetadataMismatchError as exc:
if (
torch._functorch.config.generate_fake_kernels_from_real_mismatches
):
if torch._functorch.config.generate_fake_kernels_from_real_mismatches:
dtrace_structured(
"mismatched_fake_kernel",
metadata_fn=lambda: {
@ -2311,9 +2312,9 @@ class FakeTensorMode(TorchDispatchMode):
and not flat_arg_fake_tensors
and not device_conversion_skip_const_prop
):
assert all(
t.constant is not None for t in flat_arg_fake_tensors
), f"{func} should not have fake inputs without constants"
assert all(t.constant is not None for t in flat_arg_fake_tensors), (
f"{func} should not have fake inputs without constants"
)
const_flat_args = [
a.constant if self.is_our_fake(a) else a for a in flat_args
]
@ -2538,9 +2539,7 @@ class FakeTensorMode(TorchDispatchMode):
if real_out is not nil:
# cross check fake/real outputs, and optionally override fake kernel mismatches
if (
not torch._functorch.config.generate_fake_kernels_from_real_mismatches
):
if not torch._functorch.config.generate_fake_kernels_from_real_mismatches:
self._maybe_infer_fake_kernel_from_pytree_out(
func,
(args, kwargs),
@ -2924,7 +2923,10 @@ class FakeTensorMode(TorchDispatchMode):
schema_info = get_schema_info(func)
if any_constant and schema_info.is_mutable():
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
func,
args=args, # type: ignore[arg-type]
kwargs=kwargs, # type: ignore[arg-type]
normalize_to_only_use_kwargs=True,
)
for k, v in new_kwargs.items():
k = k if (k != "input" or schema_info.has_argument(k)) else "self"
@ -2948,9 +2950,9 @@ class FakeTensorMode(TorchDispatchMode):
if static_shapes is None:
static_shapes = self.static_shapes
if static_shapes:
assert (
symbolic_context is None
), "cannot set both static_shapes and symbolic_context"
assert symbolic_context is None, (
"cannot set both static_shapes and symbolic_context"
)
shape_env = None
return self.fake_tensor_converter.from_real_tensor(
self,


@ -102,7 +102,7 @@ def is_sdpa_error(func, idx, e):
def try_convert_fake_to_real(
ten_list: list[Union[FakeTensor, Any]]
ten_list: list[Union[FakeTensor, Any]],
) -> list[Union[FakeTensor, torch.Tensor, Any]]:
"""
Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up
@ -266,9 +266,9 @@ class CrossRefFakeMode(TorchDispatchMode):
if fake_r is not None:
r_flat = pytree.tree_leaves(r)
f_flat = pytree.tree_leaves(fake_r)
assert len(f_flat) == len(
r_flat
), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
assert len(f_flat) == len(r_flat), (
f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
)
if self.check_aliasing:
_check_alias_info(
@ -279,9 +279,9 @@ class CrossRefFakeMode(TorchDispatchMode):
zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
):
r_is_ten = isinstance(r_out, torch.Tensor)
assert r_is_ten == isinstance(
f_out, torch.Tensor
), f"{context} mismatched number of tensor outputs"
assert r_is_ten == isinstance(f_out, torch.Tensor), (
f"{context} mismatched number of tensor outputs"
)
if r_is_ten:
try:
_check_fake_real_tensors(


@ -357,7 +357,9 @@ class MetaTensorDescriber:
maybe_functorch_stack = None
if is_functorch_wrapped:
with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack:
with (
torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
) as maybe_functorch_stack:
pass
attrs = None
@ -517,8 +519,7 @@ class ViewFunc(Generic[_TensorT]):
new_base: _TensorT,
symint_visitor_fn: Optional[Callable[[int], int]] = None,
tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None,
) -> _TensorT:
...
) -> _TensorT: ...
@staticmethod
def from_tensor(t: torch.Tensor) -> ViewFunc:
@ -574,8 +575,7 @@ class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
def __call__(
self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
) -> _TensorT_cov:
...
) -> _TensorT_cov: ...
class _MetaTensorCallbackKwargs(TypedDict, total=False):
@ -592,8 +592,7 @@ class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
arg: Callable[[], torch.Tensor],
/,
**kwargs: Unpack[_MetaTensorCallbackKwargs],
) -> _TensorT_cov:
...
) -> _TensorT_cov: ...
@dataclass(frozen=True)
@ -785,9 +784,9 @@ class MetaConverter(Generic[_TensorT]):
] = weakref.WeakValueDictionary()
# Maps MetaTensorId to torch.Tensor (typically a meta tensor or
# FakeTensor)
self.tensor_memo: weakref.WeakValueDictionary[
MetaTensorId, _TensorT
] = weakref.WeakValueDictionary()
self.tensor_memo: weakref.WeakValueDictionary[MetaTensorId, _TensorT] = (
weakref.WeakValueDictionary()
)
self.hit = 0
self.miss = 0
self.del_hook = None
@ -1772,9 +1771,9 @@ class MetaConverter(Generic[_TensorT]):
# subclasses. Relevant test is
# DynamicShapesFunctionTests::test_add_dynamic_shapes in
# test/dynamo/test_dynamic_shapes.py
maybe_fake_mgr: AbstractContextManager[
None
] = contextlib.nullcontext()
maybe_fake_mgr: AbstractContextManager[None] = (
contextlib.nullcontext()
)
from torch._subclasses.fake_tensor import (
in_kernel_invocation_manager,
maybe_get_fake_mode,