Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00

[BE][PYFMT] migrate PYFMT for torch/_[a-h]*/ to ruff format (#144551)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144551
Approved by: https://github.com/ezyang
ghstack dependencies: #148186

This commit is contained in:
parent 9642c75689
commit 162ca185ff
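The diff below is almost entirely a mechanical reformat from the black style to ruff format. The single most common rewrite is the layout of multi-line assert statements; as an illustrative sketch (the variable names are placeholders, not taken from the patch):

# Before (black): the condition is wrapped in parentheses and the message
# trails the closing parenthesis.
fake_tensor_mode = None
current_mode = None
assert (
    fake_tensor_mode is None or fake_tensor_mode is current_mode
), "Multiple fake tensor mode detected."

# After (ruff format): the condition stays on the assert line and the long
# message is parenthesized instead.
assert fake_tensor_mode is None or fake_tensor_mode is current_mode, (
    "Multiple fake tensor mode detected."
)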
@@ -47,12 +47,9 @@ USE_BLACK_FILELIST = re.compile(
"test/[p-z]*/**",
# torch/**
# torch/_[a-c]*/**
"torch/_[a-c]*/**",
# torch/_[e-h]*/**
"torch/_[e-h]*/**",
# torch/_i*/**
# torch/_[j-z]*/**
"torch/_[j-z]*/**",
# torch/[a-c]*/**
"torch/a[a-n]*/**",
"torch/a[p-z]*/**",
@@ -624,9 +624,9 @@ class TS2FXGraphConverter:
self.fx_graph, name, self.is_top_level_graph()
)
elif name in self.name_to_constant:
assert isinstance(
self.name_to_constant[name], torch.ScriptObject
), "Input conversion only handles ScriptObject"
assert isinstance(self.name_to_constant[name], torch.ScriptObject), (
"Input conversion only handles ScriptObject"
)
normalized_name = normalize_name(name)
self.input_specs.append(
InputSpec(

@@ -661,9 +661,7 @@ class TS2FXGraphConverter:
def to_float_tensor(t):
return t.to(dtype=torch.float).item()

inp_list = [
self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
] # noqa: C416
inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416
fx_node = self.fx_graph.call_function(
to_float_tensor,
tuple(inp_list),

@@ -749,9 +747,7 @@ class TS2FXGraphConverter:
self.name_to_constant[name] = value

def convert_prim_CallMethod(self, node: torch._C.Node):
inp_list = [
self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
] # noqa: C416
inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416
fx_node = self.fx_graph.call_method(
node.s("name"),
tuple(inp_list),
@@ -783,9 +779,9 @@ class TS2FXGraphConverter:
self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn)
else:
if attr_fqn not in self.name_to_non_tensor_attribute_node:
self.name_to_non_tensor_attribute_node[
attr_fqn
] = self.name_to_non_tensor_attribute[attr_fqn]
self.name_to_non_tensor_attribute_node[attr_fqn] = (
self.name_to_non_tensor_attribute[attr_fqn]
)
self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[
attr_fqn
]

@@ -850,15 +846,15 @@ class TS2FXGraphConverter:
k = self.get_fx_value_by_ir_value(inp)
else:
v = self.get_fx_value_by_ir_value(inp)
assert (
k is not None and v is not None
), "DictConstruct has an empty key value pair."
assert k is not None and v is not None, (
"DictConstruct has an empty key value pair."
)
output_dict[k] = v
k, v = None, None

assert (
k is None and v is None
), "DictConstruct has an odd number of elements (violating our assumption)."
assert k is None and v is None, (
"DictConstruct has an odd number of elements (violating our assumption)."
)

output_name = node.output().debugName()
self.name_to_node[output_name] = output_dict
@@ -1124,9 +1120,9 @@ class TS2FXGraphConverter:
), # + 1 because the 0th element is the condition.
)
global_argument_index = global_arguments.index(name)
fx_block_args[
i + node.outputsSize() + global_argument_index
] = self.name_to_node[name]
fx_block_args[i + node.outputsSize() + global_argument_index] = (
self.name_to_node[name]
)

def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
for block in if_node.blocks():

@@ -1545,9 +1541,9 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
for spec in ep.graph_signature.input_specs:
# Mark as constant tensors for erroneously traced buffers.
if spec.kind == InputKind.BUFFER and spec.target in name_to_constant:
assert isinstance(
name_to_constant[spec.target], torch.Tensor
), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
assert isinstance(name_to_constant[spec.target], torch.Tensor), (
f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
)
spec.kind = InputKind.CONSTANT_TENSOR
spec.persistent = None
ep.verifier().check(ep)
@@ -169,7 +169,10 @@ def fakify(
return t

if isinstance(t, _IntWrapper):
if t.dynamism is not None and t.dynamism.type in (_DimHintType.DYNAMIC, _DimHintType.AUTO): # type: ignore[union-attr]
if t.dynamism is not None and t.dynamism.type in ( # type: ignore[union-attr]
_DimHintType.DYNAMIC,
_DimHintType.AUTO,
):
symint = mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr]
t.val, source, DimDynamic.DYNAMIC
)

@@ -252,8 +252,11 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
else:
raise ExportPassBaseError(f"Unsupported target type: {target}")

def get_attr(
self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override]
def get_attr( # type: ignore[override]
self,
target: str,
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
) -> Argument:
return super().get_attr(target, args, kwargs)

@@ -265,8 +268,11 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
) -> None:
raise ExportPassBaseError("call_module is not supported.")

def call_method(
self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override]
def call_method( # type: ignore[override]
self,
target: str,
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
) -> None:
raise ExportPassBaseError("call_method is not supported.")
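Another recurring pattern, visible in the get_attr and call_method hunks above: a signature that no longer fits on one line is split one parameter per line, and the trailing `# type: ignore[override]` comment is moved onto the `def` line so it still covers the override error after the reflow. A minimal, self-contained sketch (the Base/Derived classes and the error message are illustrative, not from the patch):

from typing import Any

class Base:
    def call_method(self, target: str, args: tuple[Any, ...]) -> Any:
        return None

class Derived(Base):
    # Before: def call_method(self, target: str, args: tuple[Any, ...]) -> None:  # type: ignore[override]
    # After: one parameter per line, with the ignore comment kept on the def line.
    def call_method(  # type: ignore[override]
        self,
        target: str,
        args: tuple[Any, ...],
    ) -> None:
        raise RuntimeError("call_method is not supported.")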
@@ -426,13 +432,17 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
def call_submodule(
self, graph_module: fx.GraphModule, inputs: tuple[Argument, ...]
) -> PassResult:
prev_tracer, self.tracer = self.tracer, self.ExportTracer(
self, graph_module.graph._codegen
prev_tracer, self.tracer = (
self.tracer,
self.ExportTracer(self, graph_module.graph._codegen),
)
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
interpreter = self.ExportInterpreter(self, graph_module)
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
prev_interpreter, self.interpreter = (
self.interpreter,
torch.fx.Interpreter( # type: ignore[assignment]
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
),
)
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
with fx_traceback.preserve_node_meta():

@@ -458,9 +468,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
fake_tensor_mode = None
for i in inputs:
if isinstance(i, FakeTensor):
assert (
fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
), "Multiple fake tensor mode detected."
assert fake_tensor_mode is None or fake_tensor_mode is i.fake_mode, (
"Multiple fake tensor mode detected."
)
fake_tensor_mode = i.fake_mode
if fake_tensor_mode is None:
self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
@@ -16,12 +16,15 @@ def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) ->
"""
for node in gm.graph.nodes:
if node.op == "call_function" and str(node.target) in ops_to_guard:
with _set_node_metadata_hook(
gm,
functools.partial(
_node_metadata_hook, stack_trace=node.meta.get("stack_trace")
with (
_set_node_metadata_hook(
gm,
functools.partial(
_node_metadata_hook, stack_trace=node.meta.get("stack_trace")
),
),
), gm.graph.inserting_before(node):
gm.graph.inserting_before(node),
):
for arg in (*node.args, *node.kwargs.values()):
if isinstance(arg, torch.fx.Node) and isinstance(
arg.meta.get("val"), torch.Tensor
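The insert_custom_op_guards hunk above shows a third recurring rewrite: several context managers that used to hang off one `with` line are grouped into a parenthesized `with (...)` block (a form documented since Python 3.10). A small sketch with standard-library context managers standing in for _set_node_metadata_hook and gm.graph.inserting_before:

import contextlib

# Before: both managers share one logical `with` line, relying on the calls'
# own parentheses for wrapping.
with contextlib.suppress(KeyError), contextlib.nullcontext():
    {}["missing"]  # KeyError is swallowed by suppress()

# After (ruff format): the managers are grouped in parentheses, one per line.
with (
    contextlib.suppress(KeyError),
    contextlib.nullcontext(),
):
    {}["missing"]  # same behaviour, clearer layout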
@@ -373,7 +373,7 @@ def lift_constants_pass(

def rewrite_script_object_meta(
gm: torch.fx.GraphModule,
) -> dict[str, _ConstantAttributeType,]:
) -> dict[str, _ConstantAttributeType]:
"""When tracing, we produce a graph with FakeScriptObject in the
meta["val"].
@ -107,20 +107,20 @@ def _dump_dynamic_shapes(
|
|||
would generate the following output:
|
||||
```
|
||||
{
|
||||
'dynamic_shapes': (
|
||||
"dynamic_shapes": (
|
||||
[
|
||||
['dx', 4],
|
||||
['dx + 1', 4],
|
||||
["dx", 4],
|
||||
["dx + 1", 4],
|
||||
],
|
||||
['_DimHint.STATIC'],
|
||||
['_DimHint.STATIC', '_DimHint.STATIC'],
|
||||
["_DimHint.STATIC"],
|
||||
["_DimHint.STATIC", "_DimHint.STATIC"],
|
||||
None,
|
||||
),
|
||||
'dims': {
|
||||
'dx': {
|
||||
'min': 4,
|
||||
'max': 16,
|
||||
'derived': ['dx + 1'],
|
||||
"dims": {
|
||||
"dx": {
|
||||
"min": 4,
|
||||
"max": 16,
|
||||
"derived": ["dx + 1"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -149,7 +149,7 @@ def _dump_dynamic_shapes(
|
|||
return out
|
||||
|
||||
def _track_dim_from_dims(
|
||||
val: Union[None, int, _DimHint, Dim]
|
||||
val: Union[None, int, _DimHint, Dim],
|
||||
) -> Union[None, int, str]:
|
||||
"""
|
||||
Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
|
||||
|
|
@ -295,7 +295,7 @@ def _load_dynamic_shapes(
|
|||
dim_cache[_expr] = ddim # cache derived dims
|
||||
|
||||
def deserialize_shape(
|
||||
val: Union[None, int, str]
|
||||
val: Union[None, int, str],
|
||||
) -> Union[None, int, Dim, _DimHint]:
|
||||
if val is None or isinstance(val, int):
|
||||
return val
|
||||
|
|
|
|||
|
|
@ -129,13 +129,13 @@ def _staged_schema():
|
|||
t, cpp_type, thrift_type = dump_type(f.type, 0)
|
||||
ret = {"type": t}
|
||||
cpp_default: Optional[str] = None
|
||||
assert (
|
||||
typing.get_origin(f.type) == Annotated
|
||||
), f"Field {f.name} must be annotated with an integer id."
|
||||
assert typing.get_origin(f.type) == Annotated, (
|
||||
f"Field {f.name} must be annotated with an integer id."
|
||||
)
|
||||
thrift_id = f.type.__metadata__[0]
|
||||
assert (
|
||||
type(thrift_id) is int
|
||||
), f"Field {f.name} must be annotated with an integer id."
|
||||
assert type(thrift_id) is int, (
|
||||
f"Field {f.name} must be annotated with an integer id."
|
||||
)
|
||||
|
||||
value = dataclasses.MISSING
|
||||
if f.default is not dataclasses.MISSING:
|
||||
|
|
@ -173,9 +173,7 @@ def _staged_schema():
|
|||
|
||||
def _handle_int_enum(name, ty):
|
||||
yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}
|
||||
cpp_enum_defs[
|
||||
name
|
||||
] = f"""
|
||||
cpp_enum_defs[name] = f"""
|
||||
enum class {name} {{
|
||||
{chr(10).join([f" {x.name} = {x.value}," for x in ty])}
|
||||
}};
|
||||
|
|
@ -240,14 +238,17 @@ enum {name} {{
|
|||
|
||||
from_json_def = f"""{{
|
||||
{name} nlohmann_json_default_obj;
|
||||
{chr(10).join(
|
||||
[f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});'
|
||||
for name, f in cpp_fields.items()])}
|
||||
{
|
||||
chr(10).join(
|
||||
[
|
||||
f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});'
|
||||
for name, f in cpp_fields.items()
|
||||
]
|
||||
)
|
||||
}
|
||||
}}
|
||||
"""
|
||||
cpp_class_defs[
|
||||
name
|
||||
] = f"""
|
||||
cpp_class_defs[name] = f"""
|
||||
class {name} {{
|
||||
private:
|
||||
{field_decls}
|
||||
|
|
@ -262,9 +263,7 @@ class {name} {{
|
|||
cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}")
|
||||
cpp_type_decls.append(f"class {name};")
|
||||
|
||||
thrift_type_defs[
|
||||
name
|
||||
] = f"""
|
||||
thrift_type_defs[name] = f"""
|
||||
struct {name} {{
|
||||
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
|
||||
}}"""
|
||||
|
|
@ -307,9 +306,7 @@ struct {name} {{
|
|||
]
|
||||
)
|
||||
|
||||
cpp_class_defs[
|
||||
name
|
||||
] = f"""
|
||||
cpp_class_defs[name] = f"""
|
||||
class {name} {{
|
||||
struct Void {{}};
|
||||
|
||||
|
|
@ -352,9 +349,7 @@ inline void parseEnum(std::string_view s, {name}::Tag& t) {{
|
|||
"""
|
||||
cpp_type_decls.append(f"class {name};")
|
||||
|
||||
thrift_type_defs[
|
||||
name
|
||||
] = f"""
|
||||
thrift_type_defs[name] = f"""
|
||||
union {name} {{
|
||||
{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())}
|
||||
}}"""
|
||||
|
|
|
|||
|
|
@ -322,9 +322,9 @@ def _reconstruct_fake_tensor(
|
|||
json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8"))
|
||||
tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta)
|
||||
# Find the current fake mode
|
||||
assert (
|
||||
_CURRENT_DESERIALIZER is not None
|
||||
), "Need access to current deserializer state"
|
||||
assert _CURRENT_DESERIALIZER is not None, (
|
||||
"Need access to current deserializer state"
|
||||
)
|
||||
fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta)
|
||||
if is_parameter:
|
||||
fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment]
|
||||
|
|
@ -337,9 +337,9 @@ def serialize_torch_artifact(
|
|||
if artifact is None:
|
||||
return b""
|
||||
|
||||
assert (
|
||||
FakeTensor not in copyreg.dispatch_table
|
||||
), "Refusing to stomp on existing FakeTensor reducer"
|
||||
assert FakeTensor not in copyreg.dispatch_table, (
|
||||
"Refusing to stomp on existing FakeTensor reducer"
|
||||
)
|
||||
try:
|
||||
copyreg.pickle(FakeTensor, _reduce_fake_tensor)
|
||||
buffer = io.BytesIO()
|
||||
|
|
@ -356,7 +356,7 @@ def serialize_torch_artifact(
|
|||
|
||||
|
||||
def deserialize_torch_artifact(
|
||||
serialized: Union[dict[str, Any], tuple[Any, ...], bytes]
|
||||
serialized: Union[dict[str, Any], tuple[Any, ...], bytes],
|
||||
):
|
||||
if isinstance(serialized, (dict, tuple)):
|
||||
return serialized
|
||||
|
|
@ -415,7 +415,7 @@ def _symbol_index(sym: sympy.Symbol, sym_type: SymT):
|
|||
|
||||
|
||||
def serialize_range_constraints(
|
||||
range_constraints: dict[sympy.Symbol, ValueRanges]
|
||||
range_constraints: dict[sympy.Symbol, ValueRanges],
|
||||
) -> dict[str, RangeConstraint]:
|
||||
return {
|
||||
str(k): RangeConstraint(
|
||||
|
|
@ -499,9 +499,9 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
graph_input = Argument.create(
|
||||
as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)
|
||||
)
|
||||
self.graph_state.custom_obj_values[
|
||||
node.name
|
||||
] = self.serialize_script_obj_meta(val)
|
||||
self.graph_state.custom_obj_values[node.name] = (
|
||||
self.serialize_script_obj_meta(val)
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}")
|
||||
self.graph_state.inputs.append(graph_input)
|
||||
|
|
@ -627,9 +627,9 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
)
|
||||
elif type(node.target) in _serialization_registry:
|
||||
# Sanity check for unhandled serialization.
|
||||
assert (
|
||||
type(node.target) in _serialization_registry
|
||||
), f"{type(node.target)} is not supported in export serialization."
|
||||
assert type(node.target) in _serialization_registry, (
|
||||
f"{type(node.target)} is not supported in export serialization."
|
||||
)
|
||||
|
||||
handler = _serialization_registry[type(node.target)]
|
||||
namespace = handler.namespace()
|
||||
|
|
@ -1295,9 +1295,9 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
f"but somehow previously was found to have field names {field_names}."
|
||||
)
|
||||
else:
|
||||
self.treespec_namedtuple_fields[
|
||||
serialized_type_name
|
||||
] = NamedTupleDef(field_names=ts.context._fields)
|
||||
self.treespec_namedtuple_fields[serialized_type_name] = (
|
||||
NamedTupleDef(field_names=ts.context._fields)
|
||||
)
|
||||
|
||||
for child in ts.children_specs:
|
||||
store_namedtuple_fields(child)
|
||||
|
|
@ -1516,9 +1516,9 @@ class GraphModuleSerializer(metaclass=Final):
|
|||
|
||||
idx_to_name = {}
|
||||
for user in node.users:
|
||||
assert (
|
||||
user.target is operator.getitem
|
||||
), f"User node {user} of {node} is incorrect"
|
||||
assert user.target is operator.getitem, (
|
||||
f"User node {user} of {node} is incorrect"
|
||||
)
|
||||
idx_to_name[user.args[1]] = user.name
|
||||
|
||||
for idx, _ in enumerate(meta_val):
|
||||
|
|
@ -3529,9 +3529,9 @@ def register_extension(
|
|||
extension_handler: type[ExtensionHandler],
|
||||
):
|
||||
"""Register custom de/serialization method for a node with non-standard type."""
|
||||
assert issubclass(
|
||||
extension_handler, ExtensionHandler
|
||||
), f"Expected ExtensionHandler, got {extension_handler}."
|
||||
assert issubclass(extension_handler, ExtensionHandler), (
|
||||
f"Expected ExtensionHandler, got {extension_handler}."
|
||||
)
|
||||
assert op_type not in _serialization_registry, f"{op_type} is already registered."
|
||||
assert isinstance(op_type, type) # Maybe a good idea to enforce this first.
|
||||
assert not (
|
||||
|
|
|
|||
|
|
@ -18,9 +18,9 @@ class _UnionTag(str):
|
|||
def __eq__(self, cmp) -> bool:
|
||||
assert isinstance(cmp, str)
|
||||
other = str(cmp)
|
||||
assert other in _get_field_names(
|
||||
self._cls
|
||||
), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
|
||||
assert other in _get_field_names(self._cls), (
|
||||
f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
|
||||
)
|
||||
return str(self) == other
|
||||
|
||||
def __hash__(self):
|
||||
|
|
@ -43,7 +43,10 @@ class _Union:
|
|||
return obj
|
||||
|
||||
def __post_init__(self):
|
||||
assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc]
|
||||
assert not any(
|
||||
f.name in ("type", "_type", "create", "value")
|
||||
for f in fields(self) # type: ignore[arg-type, misc]
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
|
|
|
|||
|
|
@ -448,9 +448,9 @@ def register_dataclass_as_pytree_node(
|
|||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
return_none_fields: bool = False,
|
||||
) -> None:
|
||||
assert dataclasses.is_dataclass(
|
||||
cls
|
||||
), f"Only dataclasses can be registered with this function: {cls}"
|
||||
assert dataclasses.is_dataclass(cls), (
|
||||
f"Only dataclasses can be registered with this function: {cls}"
|
||||
)
|
||||
|
||||
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
|
||||
flattened = []
|
||||
|
|
@ -644,11 +644,14 @@ def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None:
|
|||
continue
|
||||
|
||||
if (tensor_val := node.args[0].meta.get("val")) is not None:
|
||||
with gm.graph.inserting_before(node), _set_node_metadata_hook(
|
||||
gm,
|
||||
functools.partial(
|
||||
_node_metadata_hook,
|
||||
stack_trace=node.meta.get("stack_trace"),
|
||||
with (
|
||||
gm.graph.inserting_before(node),
|
||||
_set_node_metadata_hook(
|
||||
gm,
|
||||
functools.partial(
|
||||
_node_metadata_hook,
|
||||
stack_trace=node.meta.get("stack_trace"),
|
||||
),
|
||||
),
|
||||
):
|
||||
gm.graph.call_function(
|
||||
|
|
@ -1342,6 +1345,7 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
|
|||
|
||||
import torch
|
||||
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -1350,12 +1354,15 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
|
|||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
torch._export.utils.register_module_as_pytree_node(InputDataClass)
|
||||
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x, m):
|
||||
return m(x) + x
|
||||
|
||||
|
||||
ep = torch.export.export(Mod(), (torch.randn(3), Module()))
|
||||
print(ep)
|
||||
|
||||
|
|
|
|||
|
|
@ -195,9 +195,9 @@ def mark_subclass_constructor_exportable_experimental(constructor_subclass):
|
|||
for mode in torch_function_mode_stack
|
||||
if isinstance(mode, PreDispatchTorchFunctionMode)
|
||||
]
|
||||
assert (
|
||||
len(pre_dispatch_tf_modes) <= 1
|
||||
), f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}"
|
||||
assert len(pre_dispatch_tf_modes) <= 1, (
|
||||
f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}"
|
||||
)
|
||||
|
||||
if len(pre_dispatch_tf_modes) == 0:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -96,9 +96,9 @@ class GraphInfoProvider:
|
|||
@property
|
||||
def recomputable_node_only_graph(self) -> nx.DiGraph:
|
||||
if self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] is None:
|
||||
self._lazily_initialized_graphs[
|
||||
self.__RECOMPUTABLE_NODE_ONLY_GRAPH
|
||||
] = self._create_recomputable_node_only_graph()
|
||||
self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] = (
|
||||
self._create_recomputable_node_only_graph()
|
||||
)
|
||||
return self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH]
|
||||
|
||||
@property
|
||||
|
|
@ -119,17 +119,17 @@ class GraphInfoProvider:
|
|||
@property
|
||||
def full_joint_nx_graph(self) -> nx.DiGraph:
|
||||
if self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] is None:
|
||||
self._lazily_initialized_graphs[
|
||||
self.__FULL_NX_JOINT_GRAPH
|
||||
] = self._create_full_joint_graph()
|
||||
self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] = (
|
||||
self._create_full_joint_graph()
|
||||
)
|
||||
return self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH]
|
||||
|
||||
@property
|
||||
def simplified_fx_joint_graph(self) -> Graph:
|
||||
if self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] is None:
|
||||
self._lazily_initialized_graphs[
|
||||
self.__SIMPLIFIED_FX_JOINT_GRAPH
|
||||
] = self._recreate_psuedo_joint_graph()
|
||||
self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] = (
|
||||
self._recreate_psuedo_joint_graph()
|
||||
)
|
||||
return self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH]
|
||||
|
||||
def get_non_ac_peak_memory(self) -> float:
|
||||
|
|
@ -285,9 +285,7 @@ class GraphInfoProvider:
|
|||
float(
|
||||
self.recomputable_node_only_graph_with_larger_graph_context.nodes[
|
||||
node
|
||||
][
|
||||
"memory"
|
||||
]
|
||||
]["memory"]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
"""
|
||||
Utils for caching the outputs of AOTAutograd
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
|
@ -425,16 +426,13 @@ class InductorOutput(Generic[TOut], ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def pre_save(self) -> None:
|
||||
...
|
||||
def pre_save(self) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def load(self, example_inputs) -> TOut:
|
||||
...
|
||||
def load(self, example_inputs) -> TOut: ...
|
||||
|
||||
@abstractmethod
|
||||
def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut:
|
||||
...
|
||||
def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -598,7 +596,9 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoa
|
|||
# See note [Wrapping bw_compiler in disable]
|
||||
# This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
|
||||
# But since on cache hit we do not call the bw_compiler, we need to reapply the disable
|
||||
return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass") # type: ignore[return-value]
|
||||
return torch._dynamo.disable( # type: ignore[return-value]
|
||||
compiled_bw, reason="do not trace generated backwards pass"
|
||||
)
|
||||
|
||||
|
||||
# Forward types don't have any extra parameters, so this is just a TypeAlias, in essence
|
||||
|
|
@ -617,7 +617,9 @@ class BundledCompiledBackward(
|
|||
# See note [Wrapping bw_compiler in disable]
|
||||
# This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
|
||||
# But since on cache hit we do not call the bw_compiler, we need to reapply the disable
|
||||
return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass") # type: ignore[return-value]
|
||||
return torch._dynamo.disable( # type: ignore[return-value]
|
||||
compiled_bw, reason="do not trace generated backwards pass"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -1053,10 +1055,10 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
|||
cache_key, debug_lines = autograd_cache_key(
|
||||
gm, args, aot_config, fx_config
|
||||
)
|
||||
entry: Optional[
|
||||
GenericAOTAutogradCacheEntry
|
||||
] = AOTAutogradCache._lookup(
|
||||
cache_key, local, remote, args, cache_info, aot_config
|
||||
entry: Optional[GenericAOTAutogradCacheEntry] = (
|
||||
AOTAutogradCache._lookup(
|
||||
cache_key, local, remote, args, cache_info, aot_config
|
||||
)
|
||||
)
|
||||
if entry is not None:
|
||||
compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
|
||||
|
|
@ -1081,9 +1083,8 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
|||
# FXGraphCache and AOTAutogradCache?
|
||||
# get_metrics_context().increment(...)
|
||||
if (
|
||||
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
|
||||
time_saved_ns
|
||||
)
|
||||
ephemeral_increase
|
||||
:= add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
|
||||
) != 0:
|
||||
cache_info["ephemeral_timeout_increase"] = ephemeral_increase
|
||||
|
||||
|
|
@ -1311,9 +1312,9 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
|||
return None
|
||||
|
||||
if remote:
|
||||
remote_cache: Optional[
|
||||
RemoteCache[JsonDataTy]
|
||||
] = AOTAutogradCache.get_remote_cache()
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]] = (
|
||||
AOTAutogradCache.get_remote_cache()
|
||||
)
|
||||
if remote_cache is not None:
|
||||
time_taken_ms = int(
|
||||
(entry.forward_time_taken_ns + entry.backward_time_taken_ns) // 1e6
|
||||
|
|
|
|||
|
|
@ -571,9 +571,9 @@ from a multi-output view call"
|
|||
output_type = (
|
||||
OutputType.alias_of_intermediate_save_as_output
|
||||
)
|
||||
intermediate_base_tensor_id_to_output_idx[
|
||||
id(o._base)
|
||||
] = new_out_idx
|
||||
intermediate_base_tensor_id_to_output_idx[id(o._base)] = (
|
||||
new_out_idx
|
||||
)
|
||||
intermediate_bases.append(o._base)
|
||||
elif (
|
||||
# See https://github.com/pytorch/pytorch/issues/100348 for this case.
|
||||
|
|
|
|||
|
|
@ -46,11 +46,14 @@ aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
|
|||
def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
|
||||
# FunctionalTensorMode must be enabled here.
|
||||
# See Note [Accessing .grad_fn on FunctionalTensor]
|
||||
with enable_python_dispatcher(), FunctionalTensorMode(
|
||||
pre_dispatch=aot_config.pre_dispatch,
|
||||
export=aot_config.is_export,
|
||||
# Allow token discovery for joint fn tracing as tokens can be used in backward.
|
||||
_allow_token_discovery=True,
|
||||
with (
|
||||
enable_python_dispatcher(),
|
||||
FunctionalTensorMode(
|
||||
pre_dispatch=aot_config.pre_dispatch,
|
||||
export=aot_config.is_export,
|
||||
# Allow token discovery for joint fn tracing as tokens can be used in backward.
|
||||
_allow_token_discovery=True,
|
||||
),
|
||||
):
|
||||
fx_g = make_fx(
|
||||
f,
|
||||
|
|
@ -238,9 +241,9 @@ def aot_dispatch_base_graph(
|
|||
|
||||
# TODO: should factor this into a separate function for export that always only returns just the graph.
|
||||
if aot_config.is_export:
|
||||
assert (
|
||||
maybe_subclass_meta is None
|
||||
), "aot_export_module does not support tensor subclass inputs for now."
|
||||
assert maybe_subclass_meta is None, (
|
||||
"aot_export_module does not support tensor subclass inputs for now."
|
||||
)
|
||||
return fw_module, saved_updated_flat_args_subclasses_desugared, maybe_subclass_meta
|
||||
|
||||
|
||||
|
|
@ -332,7 +335,7 @@ def aot_dispatch_autograd_graph(
|
|||
# when we need to manually detach() some inputs in the forward.
|
||||
# Higher order ops might eventually need to do the same.
|
||||
if aot_config.is_export:
|
||||
assert (
|
||||
maybe_subclass_meta is None
|
||||
), "aot_export_module does not support tensor subclass inputs for now."
|
||||
assert maybe_subclass_meta is None, (
|
||||
"aot_export_module does not support tensor subclass inputs for now."
|
||||
)
|
||||
return fx_g, saved_updated_joint_inputs, maybe_subclass_meta
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ This file contains utilities related to functionalization in AOTAutograd:
|
|||
3. regenerating/replaying views from their base
|
||||
4. checking if a graph is functional i.e. whether it contains any mutation ops
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -452,14 +453,14 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
|
|||
# this is mostly a hack to avoid failing XLA tests.
|
||||
# See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
|
||||
if "set_buffer_donor_" not in str(n.args[0]):
|
||||
assert (
|
||||
n.args[0] in placeholders
|
||||
), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
|
||||
assert n.args[0] in placeholders, (
|
||||
f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
|
||||
)
|
||||
mutation_count += 1
|
||||
else:
|
||||
assert (
|
||||
not n.target._schema.is_mutable
|
||||
), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
|
||||
assert not n.target._schema.is_mutable, (
|
||||
f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
|
||||
)
|
||||
return mutation_count
|
||||
|
||||
|
||||
|
|
@ -472,9 +473,9 @@ def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None:
|
|||
if n.target is torch.ops.aten.copy_.default:
|
||||
# Can only copy_ into an input, and can only do so once
|
||||
if "set_buffer_donor_" not in str(n.args[0]):
|
||||
assert (
|
||||
n.args[0] in placeholders
|
||||
), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
|
||||
assert n.args[0] in placeholders, (
|
||||
f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
|
||||
)
|
||||
placeholders.remove(n.args[0])
|
||||
copy_from_node = n.args[1]
|
||||
# Pre-condition: every node has a "stack_trace" field in its meta,
|
||||
|
|
|
|||
|
|
@ -823,9 +823,9 @@ def create_wrap_fn(fn, args):
|
|||
from .functional_utils import from_fun, has_data_mutation, to_fun
|
||||
|
||||
def assert_no_mutation(t):
|
||||
assert not has_data_mutation(
|
||||
t
|
||||
), "Saved tensors hooks with inputs mutations are not allowed"
|
||||
assert not has_data_mutation(t), (
|
||||
"Saved tensors hooks with inputs mutations are not allowed"
|
||||
)
|
||||
|
||||
@wraps(fn)
|
||||
def _wrapper(*args):
|
||||
|
|
@ -1110,9 +1110,11 @@ def maybe_inline_graph_saved_tensors_hooks(
|
|||
# Inserting packed sym scalars before first saved tensor input.
|
||||
# Inserting packed tensors before last saved tensor input.
|
||||
# Saved tensor inputs between them will be removed.
|
||||
with bw_g.inserting_before(
|
||||
bw_g_inputs[0]
|
||||
) if is_sym else bw_g.inserting_before(bw_g_input):
|
||||
with (
|
||||
bw_g.inserting_before(bw_g_inputs[0])
|
||||
if is_sym
|
||||
else bw_g.inserting_before(bw_g_input)
|
||||
):
|
||||
new_n = bw_g.placeholder(new_node_name)
|
||||
assert new_n.name == new_node_name
|
||||
new_n.meta = copy.copy(out_n.meta)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ This module defines runtime wrappers, which, based on previous analysis attempts
|
|||
3. handle functionalized randomness
|
||||
4. deduplicate inputs and consolidate views into their bases (see input_output_analysis)
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections
|
||||
import contextlib
|
||||
|
|
@ -318,9 +319,9 @@ def _create_runtime_wrapper(
|
|||
for info in runtime_metadata.output_info
|
||||
)
|
||||
|
||||
def record_runtime_wrapper_prologue_enter() -> (
|
||||
Optional[AbstractContextManager[None]]
|
||||
):
|
||||
def record_runtime_wrapper_prologue_enter() -> Optional[
|
||||
AbstractContextManager[None]
|
||||
]:
|
||||
if (
|
||||
torch.autograd.profiler._is_profiler_enabled
|
||||
and dynamo_config.record_runtime_overhead
|
||||
|
|
@ -950,9 +951,9 @@ class AOTDedupeWrapper(CompilerWrapper):
|
|||
keep_arg_mask.append(True)
|
||||
add_dupe_map.append(j)
|
||||
j += 1
|
||||
assert (
|
||||
len(add_dupe_map) == duped_arg_len
|
||||
), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
|
||||
assert len(add_dupe_map) == duped_arg_len, (
|
||||
f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
|
||||
)
|
||||
|
||||
self.keep_arg_mask = keep_arg_mask
|
||||
self.add_dupe_map = add_dupe_map
|
||||
|
|
@ -996,9 +997,9 @@ class AOTDedupeWrapper(CompilerWrapper):
|
|||
keep_input_mutations=fw_metadata.keep_input_mutations,
|
||||
is_train=fw_metadata.is_train,
|
||||
)(*deduped_flat_args)
|
||||
assert (
|
||||
ref_fw_metadata == updated_fw_metadata
|
||||
), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
|
||||
assert ref_fw_metadata == updated_fw_metadata, (
|
||||
f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}"
|
||||
)
|
||||
|
||||
return wrapped_flat_fn, deduped_flat_args, updated_fw_metadata
|
||||
|
||||
|
|
@ -1397,14 +1398,14 @@ def merge_view_inputs(
|
|||
# The "inputs that are aliased but have different differentiable bases" case
|
||||
# is more complicated and hopefully pretty rare. Not currently handled.
|
||||
if not is_inference:
|
||||
assert _are_differentiable_views(
|
||||
view1, view2
|
||||
), "aot_autograd() does not yet handle non-differentiable view input mutations."
|
||||
assert _are_differentiable_views(view1, view2), (
|
||||
"aot_autograd() does not yet handle non-differentiable view input mutations."
|
||||
)
|
||||
# Regenerating views when reinterpreting complex / real tensors seems non-trivial,
|
||||
# not handling for now
|
||||
assert _same_dtype_views(
|
||||
view1, view2
|
||||
), "aot_autograd() does not yet handle input mutations on views with different dtypes."
|
||||
assert _same_dtype_views(view1, view2), (
|
||||
"aot_autograd() does not yet handle input mutations on views with different dtypes."
|
||||
)
|
||||
non_none_bases = [
|
||||
fwd_inputs[i]._base
|
||||
for i in aliased_input_indices
|
||||
|
|
@ -1451,13 +1452,13 @@ def merge_view_inputs(
|
|||
# Case where all of the aliases require gradients, and have the same _base.
|
||||
synthetic_base = non_none_bases[0]
|
||||
for other_base in non_none_bases[1:]:
|
||||
assert (
|
||||
other_base is synthetic_base
|
||||
), "aot_autograd() does not yet handle non-differentiable view input mutations."
|
||||
assert other_base is synthetic_base, (
|
||||
"aot_autograd() does not yet handle non-differentiable view input mutations."
|
||||
)
|
||||
for alias in aliases_with_none_bases:
|
||||
assert (
|
||||
alias is synthetic_base
|
||||
), "aot_autograd() does not yet handle non-differentiable view input mutations."
|
||||
assert alias is synthetic_base, (
|
||||
"aot_autograd() does not yet handle non-differentiable view input mutations."
|
||||
)
|
||||
base_args.append(synthetic_base)
|
||||
for curr_view_idx in aliased_input_indices:
|
||||
curr_view = fwd_inputs[curr_view_idx]
|
||||
|
|
@ -2286,9 +2287,9 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
|||
@staticmethod
|
||||
def _backward_impl(ctx, all_args):
|
||||
# compiled autograd reimplements this function at proxy_call_aot_backward
|
||||
assert (
|
||||
not backward_state_indices
|
||||
), "BackwardState requires CompiledAutograd"
|
||||
assert not backward_state_indices, (
|
||||
"BackwardState requires CompiledAutograd"
|
||||
)
|
||||
ctx.maybe_clear_saved_tensors()
|
||||
|
||||
saved_tensors_use_once = (
|
||||
|
|
|
|||
|
|
@ -689,7 +689,7 @@ class ViewAndMutationMeta:
|
|||
and len(self.traced_tangents) == len(other.traced_tangents)
|
||||
and all(
|
||||
x.shape == y.shape and x.dtype == y.dtype
|
||||
for x, y, in zip(self.traced_tangents, other.traced_tangents)
|
||||
for x, y in zip(self.traced_tangents, other.traced_tangents)
|
||||
)
|
||||
and self.num_backward_tokens == other.num_backward_tokens
|
||||
)
|
||||
|
|
@ -726,9 +726,9 @@ class SubclassMeta:
|
|||
# in case we made incorrect assumptions about the subclass-ness of our grad_outputs
|
||||
#
|
||||
# Optional field because we don't compute for inference graphs
|
||||
grad_input_metas: Optional[
|
||||
list[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
] = None
|
||||
grad_input_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = (
|
||||
None
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
# The fields in this class get set after its construction.
|
||||
|
|
|
|||
|
|
@ -383,9 +383,9 @@ def wrap_tensor_subclasses(
|
|||
return wrapped_args + activations
|
||||
return tuple(list(wrapped_args) + list(activations))
|
||||
else:
|
||||
assert (
|
||||
len(unwrapped_args) == num_args_tallied
|
||||
), f"Expected {len(unwrapped_args)} == {num_args_tallied}"
|
||||
assert len(unwrapped_args) == num_args_tallied, (
|
||||
f"Expected {len(unwrapped_args)} == {num_args_tallied}"
|
||||
)
|
||||
return tuple(wrapped_args)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -320,15 +320,17 @@ def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any:
|
|||
def traced_joint(
|
||||
primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset
|
||||
):
|
||||
with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
|
||||
"torch.cuda.set_rng_state", override_set_rng_state
|
||||
with (
|
||||
patch("torch.cuda.get_rng_state", override_get_rng_state),
|
||||
patch("torch.cuda.set_rng_state", override_set_rng_state),
|
||||
):
|
||||
return append_rng_offsets(func(primals, tangents))
|
||||
|
||||
def traced_forward(*primals_fwd_seed_fwd_base_offset):
|
||||
# The signature is (*primals, seed, offset)
|
||||
with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
|
||||
"torch.cuda.set_rng_state", override_set_rng_state
|
||||
with (
|
||||
patch("torch.cuda.get_rng_state", override_get_rng_state),
|
||||
patch("torch.cuda.set_rng_state", override_set_rng_state),
|
||||
):
|
||||
return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2]))
|
||||
|
||||
|
|
@ -450,15 +452,15 @@ def create_functionalized_fn(
|
|||
|
||||
# Ban metadata mutations on fw inputs during the bw
|
||||
if not inpt_info.mutates_metadata:
|
||||
assert (
|
||||
not joint_mutates_metadata
|
||||
), "Found a graph input that had its metadata mutated in the backward. This is not supported"
|
||||
assert not joint_mutates_metadata, (
|
||||
"Found a graph input that had its metadata mutated in the backward. This is not supported"
|
||||
)
|
||||
|
||||
# Ban storage resizing on fw inputs during the bw
|
||||
if not inpt_info.mutation_inductor_storage_resize:
|
||||
assert not was_inductor_storage_resized(
|
||||
f_inpt
|
||||
), "Found a graph input that had storage resizing in the backward. This is not supported"
|
||||
assert not was_inductor_storage_resized(f_inpt), (
|
||||
"Found a graph input that had storage resizing in the backward. This is not supported"
|
||||
)
|
||||
|
||||
# Allow data mutations on fw inputs during the bw, but only if they do not require grad
|
||||
# So we can guarantee that we can keep the mutations in the graph
|
||||
|
|
@ -470,7 +472,10 @@ def create_functionalized_fn(
|
|||
# Not banning here mutations on inpt_info.requires_grad -
|
||||
# we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
|
||||
# Add node meta for copy_ for partitioner that this node should be in backward graph.
|
||||
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
|
||||
with (
|
||||
torch.fx.traceback.preserve_node_meta(),
|
||||
set_partitioner_tag_must_be_in_backward(),
|
||||
):
|
||||
before.copy_(after)
|
||||
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
|
||||
idx
|
||||
|
|
@ -485,7 +490,9 @@ def create_functionalized_fn(
|
|||
):
|
||||
assert not has_metadata_mutation(
|
||||
f_inpt, before, check_only_storage_mutation=False
|
||||
), "Found an input to the backward that had metadata mutated during the backward pass. This is not supported"
|
||||
), (
|
||||
"Found an input to the backward that had metadata mutated during the backward pass. This is not supported"
|
||||
)
|
||||
if has_data_mutation(f_inpt):
|
||||
can_be_in_graph = _check_if_mutation_can_be_in_graph(
|
||||
keep_input_mutations=True,
|
||||
|
|
@ -503,9 +510,9 @@ def create_functionalized_fn(
|
|||
),
|
||||
requires_grad=f_inpt.requires_grad,
|
||||
)
|
||||
assert (
|
||||
can_be_in_graph
|
||||
), "a backward input that had data mutated in an autograd-aware way. This is not supported"
|
||||
assert can_be_in_graph, (
|
||||
"a backward input that had data mutated in an autograd-aware way. This is not supported"
|
||||
)
|
||||
# Perform the input mutation
|
||||
with torch.fx.traceback.preserve_node_meta():
|
||||
before.copy_(after)
|
||||
|
|
@ -621,8 +628,10 @@ def create_functionalized_fn(
|
|||
if inpt_old.is_inference():
|
||||
maybe_preserve_vc = nullcontext()
|
||||
else:
|
||||
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
|
||||
inpt_old # type: ignore[assignment]
|
||||
maybe_preserve_vc = (
|
||||
torch.autograd._unsafe_preserve_version_counter(
|
||||
inpt_old # type: ignore[assignment]
|
||||
)
|
||||
)
|
||||
with torch.no_grad(), maybe_preserve_vc:
|
||||
inpt_old.copy_(inpt_new)
|
||||
|
|
@ -889,9 +898,12 @@ def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
|
|||
# https://github.com/pytorch/pytorch/issues/103569
|
||||
|
||||
def functional_call(*args, **kwargs):
|
||||
with stateless._reparametrize_module(
|
||||
mod, pytree.tree_unflatten(args[:params_len], params_spec)
|
||||
), maybe_disable_thunkify():
|
||||
with (
|
||||
stateless._reparametrize_module(
|
||||
mod, pytree.tree_unflatten(args[:params_len], params_spec)
|
||||
),
|
||||
maybe_disable_thunkify(),
|
||||
):
|
||||
if isinstance(mod, torch.fx.GraphModule):
|
||||
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
|
|
|
|||
|
|
@ -140,9 +140,9 @@ def call_func_at_runtime_with_args(
|
|||
class PytreeThunk:
|
||||
spec: Optional[pytree.TreeSpec] = None
|
||||
# These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
|
||||
is_simple: Optional[
|
||||
bool
|
||||
] = None # if the output spec is a tuple/list, we won't bother unflattening it.
|
||||
is_simple: Optional[bool] = (
|
||||
None # if the output spec is a tuple/list, we won't bother unflattening it.
|
||||
)
|
||||
is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec
|
||||
|
||||
def set(self, spec: pytree.TreeSpec) -> None:
|
||||
|
|
@ -335,12 +335,12 @@ def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
|
|||
|
||||
num_erased_inputs = len(input_token_nodes)
|
||||
|
||||
assert (
|
||||
num_erased_inputs == expected_num_erased
|
||||
), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}"
|
||||
assert (
|
||||
num_erased_outs == expected_num_erased
|
||||
), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}"
|
||||
assert num_erased_inputs == expected_num_erased, (
|
||||
f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}"
|
||||
)
|
||||
assert num_erased_outs == expected_num_erased, (
|
||||
f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}"
|
||||
)
|
||||
|
||||
module.recompile()
|
||||
|
||||
|
|
|
|||
|
|
@ -454,8 +454,7 @@ class AOTDispatchCompiler(Protocol):
|
|||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
example_inputs: Sequence[InputType],
|
||||
) -> Any:
|
||||
...
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
# TODO: bikeshed on this name
|
||||
|
|
@ -637,13 +636,14 @@ def _create_aot_dispatcher_function(
|
|||
# If any saved tensor hooks are active, we **don't** want to trace them.
|
||||
# Instead, we'll let them run at runtime, around the custom autograd.Function
|
||||
# that we generate in torch.compile.
|
||||
with torch.autograd.set_multithreading_enabled(
|
||||
False
|
||||
), preserve_rng_state(), (
|
||||
fake_mode
|
||||
), (
|
||||
python_dispatcher_mode
|
||||
), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
|
||||
with (
|
||||
torch.autograd.set_multithreading_enabled(False),
|
||||
preserve_rng_state(),
|
||||
fake_mode,
|
||||
python_dispatcher_mode,
|
||||
PhiloxStateTracker(),
|
||||
torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(),
|
||||
):
|
||||
from torch._library.fake_class_registry import (
|
||||
FakeScriptObject,
|
||||
maybe_to_fake_obj,
|
||||
|
|
@ -756,7 +756,7 @@ def _create_aot_dispatcher_function(
|
|||
if fw_metadata.num_intermediate_bases > 0:
|
||||
assert not req_subclass_dispatch, f"""\
|
||||
torch.compile is currently being used with tensor subclass inputs:
|
||||
{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs
|
||||
{",".join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs
|
||||
that alias one another, which is currently unsupported in the subclass use case. If you run into this,
|
||||
please file a github issue"""
|
||||
|
||||
|
|
@ -899,7 +899,7 @@ def aot_function(
|
|||
A simple example usage of :func:`aot_function` is as follows. This example
|
||||
will print the forward and backward graphs of the function ``fn``
|
||||
|
||||
>>> fn = lambda x : x.sin().cos()
|
||||
>>> fn = lambda x: x.sin().cos()
|
||||
>>> def print_compile_fn(fx_module, args):
|
||||
>>> print(fx_module)
|
||||
>>> return fx_module
|
||||
|
|
@ -1425,9 +1425,7 @@ We require the output marked as the loss (at index {output_loss_index}) to be a
|
|||
output_gradients = []
|
||||
for a, grad in zip(args, gradients):
|
||||
if isinstance(a, torch.Tensor) and a.requires_grad:
|
||||
assert (
|
||||
grad is not None
|
||||
), """\
|
||||
assert grad is not None, """\
|
||||
Found a parameter that did not receive a gradient.
|
||||
"This is most likely a bug, but if this needs to be supported please comment on this Github issue:
|
||||
https://github.com/pytorch/pytorch/issues/101192
|
||||
|
|
@ -1540,7 +1538,9 @@ def aot_export_joint_simple(
|
|||
if config.debug_assert:
|
||||
# Smoke test that after partitioning, we can run the forward without any calling convention changes.
|
||||
fw_module, _bw_module = aot_config.default_partition( # noqa: F821
|
||||
fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821
|
||||
fx_g,
|
||||
args,
|
||||
num_fwd_outputs=len(fw_metadata.output_infos), # noqa: F821
|
||||
)
|
||||
# Attempt to run the fw_module with the original user inputs
|
||||
fake_mode = detect_fake_mode(args)
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ def vmap(
|
|||
doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully
|
||||
rummaging through docs, use :func:`vmap` to construct a new function.
|
||||
|
||||
>>> torch.dot # [D], [D] -> []
|
||||
>>> torch.dot # [D], [D] -> []
|
||||
>>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N]
|
||||
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
|
||||
>>> batched_dot(x, y)
|
||||
|
|
@ -104,7 +104,7 @@ def vmap(
|
|||
>>> weights = torch.randn(feature_size, requires_grad=True)
|
||||
>>>
|
||||
>>> def model(feature_vec):
|
||||
>>> # Very simple linear model with activation
|
||||
>>> # Very simple linear model with activation
|
||||
>>> return feature_vec.dot(weights).relu()
|
||||
>>>
|
||||
>>> examples = torch.randn(batch_size, feature_size)
|
||||
|
|
@ -120,7 +120,7 @@ def vmap(
|
|||
|
||||
>>> # Setup
|
||||
>>> N = 5
|
||||
>>> f = lambda x: x ** 2
|
||||
>>> f = lambda x: x**2
|
||||
>>> x = torch.randn(N, requires_grad=True)
|
||||
>>> y = f(x)
|
||||
>>> I_N = torch.eye(N)
|
||||
|
|
@ -137,43 +137,49 @@ def vmap(
|
|||
|
||||
:func:`vmap` can also be nested, producing an output with multiple batched dimensions
|
||||
|
||||
>>> torch.dot # [D], [D] -> []
|
||||
>>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
|
||||
>>> torch.dot # [D], [D] -> []
|
||||
>>> batched_dot = torch.vmap(
|
||||
... torch.vmap(torch.dot)
|
||||
... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
|
||||
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
|
||||
>>> batched_dot(x, y) # tensor of size [2, 3]
|
||||
>>> batched_dot(x, y) # tensor of size [2, 3]
|
||||
|
||||
If the inputs are not batched along the first dimension, ``in_dims`` specifies
|
||||
the dimension that each inputs are batched along as
|
||||
|
||||
>>> torch.dot # [N], [N] -> []
|
||||
>>> torch.dot # [N], [N] -> []
|
||||
>>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
|
||||
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
|
||||
>>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension
|
||||
>>> batched_dot(
|
||||
... x, y
|
||||
... ) # output is [5] instead of [2] if batched along the 0th dimension
|
||||
|
||||
If there are multiple inputs each of which is batched along different dimensions,
|
||||
``in_dims`` must be a tuple with the batch dimension for each input as
|
||||
|
||||
>>> torch.dot # [D], [D] -> []
|
||||
>>> torch.dot # [D], [D] -> []
|
||||
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
|
||||
>>> x, y = torch.randn(2, 5), torch.randn(5)
|
||||
>>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None
|
||||
>>> batched_dot(
|
||||
... x, y
|
||||
... ) # second arg doesn't have a batch dim because in_dim[1] was None
|
||||
|
||||
If the input is a Python struct, ``in_dims`` must be a tuple containing a struct
|
||||
matching the shape of the input:
|
||||
|
||||
>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
|
||||
>>> f = lambda dict: torch.dot(dict["x"], dict["y"])
|
||||
>>> x, y = torch.randn(2, 5), torch.randn(5)
|
||||
>>> input = {'x': x, 'y': y}
|
||||
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
|
||||
>>> input = {"x": x, "y": y}
|
||||
>>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
|
||||
>>> batched_dot(input)
|
||||
|
||||
By default, the output is batched along the first dimension. However, it can be batched
|
||||
along any dimension by using ``out_dims``
|
||||
|
||||
>>> f = lambda x: x ** 2
|
||||
>>> f = lambda x: x**2
|
||||
>>> x = torch.randn(2, 5)
|
||||
>>> batched_pow = torch.vmap(f, out_dims=1)
|
||||
>>> batched_pow(x) # [5, 2]
|
||||
>>> batched_pow(x) # [5, 2]
|
||||
|
||||
For any function that uses kwargs, the returned function will not batch the kwargs but will
|
||||
accept kwargs
|
||||
|
|
@ -184,7 +190,7 @@ def vmap(
|
|||
>>>
|
||||
>>> batched_pow = torch.vmap(fn)
|
||||
>>> assert torch.allclose(batched_pow(x), x * 4)
|
||||
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
|
||||
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
|
||||
|
||||
.. note::
|
||||
vmap does not provide general autobatching or handle variable-length
|
||||
|
|
@ -337,7 +343,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
|
|||
>>> batch_size, feature_size = 3, 5
|
||||
>>>
|
||||
>>> def model(weights, feature_vec):
|
||||
>>> # Very simple linear model with activation
|
||||
>>> # Very simple linear model with activation
|
||||
>>> assert feature_vec.dim() == 1
|
||||
>>> return feature_vec.dot(weights).relu()
|
||||
>>>
|
||||
|
|
@ -349,7 +355,9 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
|
|||
>>> examples = torch.randn(batch_size, feature_size)
|
||||
>>> targets = torch.randn(batch_size)
|
||||
>>> inputs = (weights, examples, targets)
|
||||
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
|
||||
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
|
||||
... *inputs
|
||||
... )
|
||||
|
||||
Example of using ``grad`` with ``has_aux`` and ``argnums``:
|
||||
|
||||
|
|
|
|||
|
|
@ -185,8 +185,12 @@ def benchmark_utilization(
|
|||
```
|
||||
def f(a):
|
||||
return a.sum()
|
||||
|
||||
|
||||
a = torch.rand(2**20, device="cuda")
|
||||
utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace")
|
||||
utilization, mm_conv_utilization = benchmark_utilization(
|
||||
f, a, "tmp", trace_file_name="tmp_chrome_trace"
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -150,13 +150,13 @@ class DebugInterpreter(fx.Interpreter):
|
|||
def check(nv, rv, desc):
|
||||
assert callable(desc)
|
||||
assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
|
||||
assert (
|
||||
subst_symint_tuple(nv.size()) == rv.size()
|
||||
), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
|
||||
assert subst_symint_tuple(nv.size()) == rv.size(), (
|
||||
f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
|
||||
)
|
||||
same_strides = check_significant_strides(nv, rv)
|
||||
assert (
|
||||
same_strides
|
||||
), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"
|
||||
assert same_strides, (
|
||||
f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"
|
||||
)
|
||||
|
||||
r = super().run_node(n)
|
||||
if "val" in n.meta:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
"""
|
||||
Global flags for aot autograd
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
|
|
|||
|
|
@@ -233,7 +233,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> grad = vjpfunc(torch.tensor(1.0))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

However, :func:`vjp` can support functions with multiple outputs by

@@ -248,9 +248,9 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
:func:`vjp` can even support outputs being Python structs

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> f = lambda x: {"first": x.sin(), "second": x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

@@ -274,7 +274,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0))

.. note::
    Using PyTorch ``torch.no_grad`` together with ``vjp``.

@@ -930,8 +930,7 @@ def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
    return
if not isinstance(output, tuple):
    raise RuntimeError(
        f"{api}: Expected output of f to be a Tensor or Tensors, got "
        f"{type(output)}"
        f"{api}: Expected output of f to be a Tensor or Tensors, got {type(output)}"
    )
if len(output) == 0:
    raise RuntimeError(

@@ -1023,10 +1022,10 @@ def jvp(

>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> f = lambda x: x * torch.tensor([1.0, 2.0, 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.0),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
>>> assert torch.allclose(grad, torch.tensor([1.0, 2, 3]))

:func:`jvp` can support functions with multiple inputs by passing in the
tangents for each of the inputs
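The doctest edits in the `vjp` and `jvp` hunks above are purely lexical: ruff format normalizes float literals (`1.` becomes `1.0`) and prefers double quotes, so the expected doctest output is unchanged. A tiny runnable check of that equivalence (plain Python, nothing PyTorch-specific assumed):

```python
old_style = {'first': 1., 'second': 3.14}   # literals as the old docstrings spelled them
new_style = {"first": 1.0, "second": 3.14}  # literals as ruff format spells them
assert old_style == new_style  # same values, so the examples still pass
```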
@@ -60,7 +60,10 @@ def functional_call(

.. code-block:: python

    a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries
    a = (
        {"weight": torch.ones(1, 1)},
        {"buffer": torch.zeros(1)},
    ) # two separate dictionaries
    mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer
    print(mod.weight) # tensor(...)
    print(mod.buffer) # tensor(...)

@@ -83,10 +86,12 @@ def functional_call(
    t = torch.randn(4, 3)
    model = nn.Linear(3, 3)


    def compute_loss(params, x, t):
        y = functional_call(model, params, x)
        return nn.functional.mse_loss(y, t)


    grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

.. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the

@@ -179,9 +184,11 @@ def stack_module_state(
    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, 3)


    def wrapper(params, buffers, data):
        return torch.func.functional_call(models[0], (params, buffers), data)


    params, buffers = stack_module_state(models)
    output = vmap(wrapper, (0, 0, None))(params, buffers, data)

@@ -192,6 +199,8 @@ def stack_module_state(
.. code-block:: python

    import torch.nn as nn


    class Foo(nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()

@@ -202,6 +211,7 @@ def stack_module_state(
        def forward(self, x):
            return self.l2(self.l1(x))


    num_models = 5
    in_features, out_features = 3, 3
    models = [Foo(in_features, out_features) for i in range(num_models)]
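The `functional_call` docstring being reflowed above already contains the parameter-dict gradient recipe; for reference, a self-contained runnable version of that same example is sketched below (shapes and the MSE loss follow the docstring, everything else is standard `torch.func` usage):

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)


def compute_loss(params, x, t):
    # Run the module with an explicit parameter dict instead of its own weights.
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)


# Gradients arrive as a dict keyed like named_parameters(), not in .grad fields.
grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
print({name: g.shape for name, g in grad_weights.items()})
```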
@@ -374,10 +374,12 @@ def make_functional(
    model = nn.Linear(3, 3)
    func, params = make_functional(model)


    def compute_loss(params, x, t):
        y = func(params, x)
        return nn.functional.mse_loss(y, t)


    grad_weights = grad(compute_loss)(params, x, t)

If the model has any buffers, please use :func:`make_functional_with_buffers` instead.

@@ -443,10 +445,12 @@ def make_functional_with_buffers(
    model = nn.Linear(3, 3)
    func, params, buffers = make_functional_with_buffers(model)


    def compute_loss(params, buffers, x, t):
        y = func(params, buffers, x)
        return nn.functional.mse_loss(y, t)


    grad_weights = grad(compute_loss)(params, buffers, x, t)

Args:

@@ -469,7 +473,7 @@ def make_functional_with_buffers(


def transpose_stack(
    tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...]
    tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...],
) -> tuple[Tensor, ...]:
    tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
    results = tuple(
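The `transpose_stack` hunk shows another recurring change: when a signature with a single parameter is split over several lines, ruff format adds a trailing comma after that parameter. A minimal sketch of just the signature shape (the name and placeholder body are illustrative, not the real implementation):

```python
from torch import Tensor


def transpose_stack_sketch(
    tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...],  # trailing comma added by ruff format
) -> tuple[tuple[Tensor, ...], ...]:
    # Placeholder body: flip the nesting so per-model tuples become per-parameter tuples.
    return tuple(zip(*tuple_of_tuple_of_tensors))
```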
@@ -229,9 +229,9 @@ def _extract_graph_with_inputs_outputs(
if isinstance(x, fx.Node):
    if x not in env:
        raise RuntimeError(f"Node {x} couldn't be found in env")
    assert not isinstance(
        env[x], InvalidNodeBase
    ), f"Node {x} was invalid, but is output"
    assert not isinstance(env[x], InvalidNodeBase), (
        f"Node {x} was invalid, but is output"
    )
    output_values.append(env[x])
else:
    output_values.append(x)
@@ -449,10 +449,10 @@ def perform_quantization(
    args=(clamp_max_scaled_node, quant_type),
    name="fp8_quant_" + str(node.name),
)
quant_activation_node.meta[
    "val"
] = torch.ops.prims.convert_element_type.default(
    clamp_max_scaled_node.meta["val"], quant_type
quant_activation_node.meta["val"] = (
    torch.ops.prims.convert_element_type.default(
        clamp_max_scaled_node.meta["val"], quant_type
    )
)
quant_activation_node.meta["tensor_meta"] = extract_tensor_metadata(
    quant_activation_node.meta["val"]

@@ -567,10 +567,10 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
    args=(node, quant_type),
    name="fp8_quant_" + str(node.name),
)
quant_node.meta[
    "val"
] = torch.ops.prims.convert_element_type.default(
    node.meta["val"], quant_type
quant_node.meta["val"] = (
    torch.ops.prims.convert_element_type.default(
        node.meta["val"], quant_type
    )
)
quant_node.meta["tensor_meta"] = extract_tensor_metadata(
    quant_node.meta["val"]

@@ -578,7 +578,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
node_to_quant[node] = quant_node
# only update the return node args, and remain all other users unchanged
output_updated_args = [
    node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs # type: ignore[union-attr]
    node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs
]
# add the scale nodes to the ouput find the first sym_node in the output
idx = find_first_sym_node(output_updated_args)

@@ -617,10 +617,10 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
    torch.ops.prims.convert_element_type.default,
    args=(node, dequant_type),
)
activation_node.meta[
    "val"
] = torch.ops.prims.convert_element_type.default(
    node.meta["val"], dequant_type
activation_node.meta["val"] = (
    torch.ops.prims.convert_element_type.default(
        node.meta["val"], dequant_type
    )
)
activation_node.meta["tensor_meta"] = extract_tensor_metadata(
    activation_node.meta["val"]

@@ -633,18 +633,18 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
divided_target_node_32.meta["val"] = torch.ops.aten.div.Tensor(
    activation_node.meta["val"], scale_node.meta["val"]
)
divided_target_node_32.meta[
    "tensor_meta"
] = extract_tensor_metadata(divided_target_node_32.meta["val"])
divided_target_node_32.meta["tensor_meta"] = (
    extract_tensor_metadata(divided_target_node_32.meta["val"])
)
with graph.inserting_after(divided_target_node_32):
    dequant_node = graph.call_function(
        torch.ops.prims.convert_element_type.default,
        args=(divided_target_node_32, dequant_type),
    )
    dequant_node.meta[
        "val"
    ] = torch.ops.prims.convert_element_type.default(
        divided_target_node_32.meta["val"], dequant_type
    dequant_node.meta["val"] = (
        torch.ops.prims.convert_element_type.default(
            divided_target_node_32.meta["val"], dequant_type
        )
    )
    dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
        dequant_node.meta["val"]

@@ -656,10 +656,10 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None:
    args=(node, dequant_type),
    name="dequant_" + str(node.name),
)
dequant_node.meta[
    "val"
] = torch.ops.prims.convert_element_type.default(
    node.meta["val"], dequant_type
dequant_node.meta["val"] = (
    torch.ops.prims.convert_element_type.default(
        node.meta["val"], dequant_type
    )
)
dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
    dequant_node.meta["val"]
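The fp8 quantization hunks above all apply the same rewrap: where Black split the assignment target across lines (`node.meta[` / `"val"` / `] = ...`), ruff format keeps the subscripted target on one line and parenthesizes the right-hand side instead. A toy before/after sketch (the dict and the `convert` helper are made up for illustration; `convert` stands in for `torch.ops.prims.convert_element_type.default`):

```python
def convert(value, dtype):
    # Stand-in for the real conversion op; just returns a tagged pair.
    return (value, dtype)


meta = {}

# Black:
#     meta[
#         "val"
#     ] = convert(
#         some_long_expression, dtype
#     )
# ruff format:
meta["val"] = (
    convert(
        "some_long_expression", "float8"
    )
)
print(meta["val"])
```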
@@ -4,6 +4,7 @@
From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0
Try to keep this list in sync with that.
"""

import operator
@@ -156,6 +156,9 @@ def call_delegate_functionalize(
    )
with ctx.redispatch_to_next():
    res = aoti_call_delegate(
        lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args # type: ignore[arg-type]
        lowered_module,
        original_gm,
        unwrapped_weight_args, # type: ignore[arg-type]
        unwrapped_input_args, # type: ignore[arg-type]
    )
    return ctx.wrap_tensors(res)
@ -31,9 +31,9 @@ aten = torch._ops.ops.aten
|
|||
|
||||
|
||||
def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
|
||||
assert (
|
||||
len(args) == 2 * num_leaves
|
||||
), f"Combin_fn received wrong number of arguments, expected {2 * num_leaves}, but got {len(args)}"
|
||||
assert len(args) == 2 * num_leaves, (
|
||||
f"Combin_fn received wrong number of arguments, expected {2 * num_leaves}, but got {len(args)}"
|
||||
)
|
||||
lhs = pytree.tree_unflatten(args[:num_leaves], spec)
|
||||
rhs = pytree.tree_unflatten(args[num_leaves:], spec)
|
||||
return combine_fn(lhs, rhs)
|
||||
|
|
@ -79,9 +79,9 @@ class AssociativeScanOp(HigherOrderOperator):
|
|||
# the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785
|
||||
# Once this issue is resolved, the assertion should only allow tuples
|
||||
# and the tuple cast should be removed
|
||||
assert isinstance(
|
||||
additional_inputs, (tuple, list)
|
||||
), "additional_inputs must be a tuple."
|
||||
assert isinstance(additional_inputs, (tuple, list)), (
|
||||
"additional_inputs must be a tuple."
|
||||
)
|
||||
additional_inputs = (
|
||||
tuple(additional_inputs)
|
||||
if isinstance(additional_inputs, list)
|
||||
|
|
@ -134,6 +134,7 @@ def associative_scan(
|
|||
def add(x: torch.Tensor, y: torch.Tensor):
|
||||
return x + y
|
||||
|
||||
|
||||
cumsum = associative_scan(add, x, dim)
|
||||
|
||||
"""
|
||||
|
|
@ -377,9 +378,9 @@ def trace_associative_scan(
|
|||
|
||||
assert outputs is not None
|
||||
outputs = pytree.tree_leaves(outputs)
|
||||
assert len(outputs) == len(
|
||||
xs
|
||||
), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"
|
||||
assert len(outputs) == len(xs), (
|
||||
f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"
|
||||
)
|
||||
|
||||
xs_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
|
||||
first_slice_copy(x) for x in xs
|
||||
|
|
|
|||
|
|
@ -522,7 +522,8 @@ def do_auto_functionalize(
|
|||
)
|
||||
with ctx.redispatch_to_next():
|
||||
unwrapped_outs = auto_functionalized(
|
||||
op, **unwrapped_kwargs # type: ignore[arg-type]
|
||||
op,
|
||||
**unwrapped_kwargs, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# List of the name of args that get mutated (according to the schema)
|
||||
|
|
@ -704,7 +705,8 @@ def do_auto_functionalize_v2(
|
|||
|
||||
with ctx.redispatch_to_next():
|
||||
unwrapped_outs = auto_functionalized_v2(
|
||||
op, **auto_func_kwargs # type: ignore[arg-type]
|
||||
op,
|
||||
**auto_func_kwargs, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
unwrapped_actual_out: Union[Any, tuple[Any]] = (
|
||||
|
|
@ -716,9 +718,9 @@ def do_auto_functionalize_v2(
|
|||
)
|
||||
|
||||
if isinstance(op, HigherOrderOperator):
|
||||
assert (
|
||||
len(schema.returns) > 0
|
||||
), f"hop is expected to return at least one output {schema}."
|
||||
assert len(schema.returns) > 0, (
|
||||
f"hop is expected to return at least one output {schema}."
|
||||
)
|
||||
assert len(unwrapped_actual_out) == len(schema.returns)
|
||||
else:
|
||||
if len(schema.returns) == 0:
|
||||
|
|
|
|||
|
|
@ -40,11 +40,14 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
|
|||
def __init__(self):
|
||||
return super().__init__("invoke_quant")
|
||||
|
||||
|
||||
invoke_quant = InvokeQuant()
|
||||
|
||||
|
||||
def g(x):
|
||||
return x.sin().cos()
|
||||
|
||||
|
||||
@torch.compile(backend="aot_eager")
|
||||
def f(x):
|
||||
return invoke_quant(g, x, scheme="nf4")
|
||||
|
|
@ -113,7 +116,10 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
|
|||
|
||||
out = self(subgraph, *operands, **kwargs)
|
||||
return track_tensor_tree(
|
||||
out, out_proxy, constant=None, tracer=proxy_mode.tracer # type: ignore[arg-type]
|
||||
out,
|
||||
out_proxy,
|
||||
constant=None,
|
||||
tracer=proxy_mode.tracer, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def _call_FakeTensorMode(self, mode, subgraph, *operands, **kwargs):
|
||||
|
|
@ -230,7 +236,11 @@ class BaseHOPFunction(torch.autograd.Function):
|
|||
kwargs = ctx.kwargs
|
||||
|
||||
# TODO: Something special needs to happen with min cut partitioner
|
||||
with suspend_functionalization(), disable_functional_mode(), torch.enable_grad():
|
||||
with (
|
||||
suspend_functionalization(),
|
||||
disable_functional_mode(),
|
||||
torch.enable_grad(),
|
||||
):
|
||||
with disable_proxy_modes_tracing():
|
||||
from .invoke_subgraph import create_fw_bw_graph
|
||||
from .utils import _from_fun
|
||||
|
|
|
|||
|
|
@ -107,8 +107,12 @@ def cond(
|
|||
|
||||
def true_fn(x: torch.Tensor):
|
||||
return x.cos()
|
||||
|
||||
|
||||
def false_fn(x: torch.Tensor):
|
||||
return x.sin()
|
||||
|
||||
|
||||
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
|
||||
|
||||
Restrictions:
|
||||
|
|
@ -182,7 +186,11 @@ def cond(
|
|||
def _cond_op_wrapper(*args, **kwargs):
|
||||
return cond_op(*args, **kwargs)
|
||||
|
||||
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
|
||||
with (
|
||||
_set_compilation_env(),
|
||||
torch._dynamo.utils.disable_cache_limit(),
|
||||
_temp_remove_pre_dispatch_torch_function_mode(),
|
||||
):
|
||||
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
||||
if metadata_mode:
|
||||
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
|
||||
|
|
@ -248,9 +256,9 @@ def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
|
|||
|
||||
|
||||
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
assert isinstance(
|
||||
operands, (list, tuple)
|
||||
), f"Cond operands must be a list or tuple of tensors and SymInts {operands}"
|
||||
assert isinstance(operands, (list, tuple)), (
|
||||
f"Cond operands must be a list or tuple of tensors and SymInts {operands}"
|
||||
)
|
||||
|
||||
true_graph = reenter_make_fx(true_fn)(*operands)
|
||||
false_graph = reenter_make_fx(false_fn)(*operands)
|
||||
|
|
@ -297,9 +305,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
|||
|
||||
@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def cond_op_dense(pred, true_fn, false_fn, operands):
|
||||
assert all(
|
||||
isinstance(o, (torch.Tensor, int)) for o in operands
|
||||
), f"Dense implementation operands must be a list of tensors and ints {operands}"
|
||||
assert all(isinstance(o, (torch.Tensor, int)) for o in operands), (
|
||||
f"Dense implementation operands must be a list of tensors and ints {operands}"
|
||||
)
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
|
||||
if pred:
|
||||
|
|
@ -627,9 +635,9 @@ def _merge_output(
|
|||
|
||||
if _maybe_expr(a_val) in a_stride_expr:
|
||||
a_expr = a_stride_expr[_maybe_expr(a_val)]
|
||||
assert (
|
||||
b_stride_expr[_maybe_expr(b_val)] == a_expr
|
||||
), f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}"
|
||||
assert b_stride_expr[_maybe_expr(b_val)] == a_expr, (
|
||||
f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}"
|
||||
)
|
||||
merged_strides[i] = a_expr
|
||||
else:
|
||||
if a_val == 1:
|
||||
|
|
@ -686,12 +694,12 @@ def cond_func(ctx, pred, true_fn, false_fn, inputs):
|
|||
|
||||
@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
|
||||
def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
|
||||
assert isinstance(
|
||||
inputs, (list, tuple)
|
||||
), "Cond inputs must be a list or tuple of tensors"
|
||||
assert all(
|
||||
isinstance(i, torch.Tensor) for i in inputs
|
||||
), "Cond inputs must be a list of tensors"
|
||||
assert isinstance(inputs, (list, tuple)), (
|
||||
"Cond inputs must be a list or tuple of tensors"
|
||||
)
|
||||
assert all(isinstance(i, torch.Tensor) for i in inputs), (
|
||||
"Cond inputs must be a list of tensors"
|
||||
)
|
||||
|
||||
pred_is_batched = isinstance(pred, torch.Tensor) and is_batchedtensor(pred)
|
||||
pred_ = get_unwrapped(pred) if pred_is_batched else pred
|
||||
|
|
|
|||
|
|
@ -240,9 +240,9 @@ def handle_effects(
|
|||
key = get_effect_key(op, args, kwargs)
|
||||
assert key is not None
|
||||
if key not in tokens:
|
||||
assert (
|
||||
allow_token_discovery
|
||||
), f"Could not find a token for effect {key} which came from the function {op}"
|
||||
assert allow_token_discovery, (
|
||||
f"Could not find a token for effect {key} which came from the function {op}"
|
||||
)
|
||||
proxy_tensor_mode = torch._C._get_dispatch_mode(
|
||||
torch._C._TorchDispatchModeKey.PROXY
|
||||
)
|
||||
|
|
|
|||
|
|
@ -49,7 +49,10 @@ def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
|
|||
if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
|
||||
return e
|
||||
return get_proxy_slot(
|
||||
cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy # type: ignore[attr-defined]
|
||||
cast(torch.Tensor, e),
|
||||
proxy_mode.tracer,
|
||||
e,
|
||||
lambda e: e.proxy, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if not is_lowered_module(lowered_module):
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ def _construct_strides(
|
|||
) -> Sequence[int]:
|
||||
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
|
||||
# Initialize strides
|
||||
assert len(sizes) == len(
|
||||
fill_order
|
||||
), "Length of sizes must match the length of the fill order"
|
||||
assert len(sizes) == len(fill_order), (
|
||||
"Length of sizes must match the length of the fill order"
|
||||
)
|
||||
strides = [0] * len(sizes)
|
||||
|
||||
# Start with stride 1 for the innermost dimension
|
||||
|
|
@ -594,7 +594,7 @@ def create_fw_bw_graph(
|
|||
*other_buffers: tuple[Tensor, ...],
|
||||
) -> tuple[Tensor, ...]:
|
||||
def fw_with_masks(
|
||||
*args: tuple[Tensor, ...]
|
||||
*args: tuple[Tensor, ...],
|
||||
) -> tuple[tuple[Tensor], tuple[bool]]:
|
||||
fw_out = score_mod(*args)
|
||||
out_requires_grad = fw_out.requires_grad
|
||||
|
|
@ -633,9 +633,9 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
|||
for buffer in mask_mod_other_buffers
|
||||
if isinstance(buffer, torch.Tensor)
|
||||
)
|
||||
assert (
|
||||
not any_buffer_requires_grad
|
||||
), "Captured buffers from mask mod that require grad are not supported."
|
||||
assert not any_buffer_requires_grad, (
|
||||
"Captured buffers from mask mod that require grad are not supported."
|
||||
)
|
||||
ctx._fw_graph = fw_graph
|
||||
ctx._joint_graph = joint_graph
|
||||
ctx._mask_graph = block_mask[-1]
|
||||
|
|
@ -671,7 +671,11 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
|||
return out, logsumexp
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]: # type: ignore[override]
|
||||
def backward( # type: ignore[override]
|
||||
ctx: Any,
|
||||
grad_out: Tensor,
|
||||
grad_logsumexp: Tensor,
|
||||
) -> tuple[Optional[Tensor], ...]:
|
||||
fw_args = saved_tensors_and_symints(ctx)
|
||||
(
|
||||
query,
|
||||
|
|
@ -939,9 +943,9 @@ def sdpa_dense_backward(
|
|||
actual_grad_value.copy_(grad_value)
|
||||
|
||||
if Bq != Bkv:
|
||||
assert (
|
||||
Bq > 1 and Bkv == 1
|
||||
), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}"
|
||||
assert Bq > 1 and Bkv == 1, (
|
||||
f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}"
|
||||
)
|
||||
|
||||
actual_grad_key = torch.sum(actual_grad_key, 0, keepdim=True)
|
||||
actual_grad_value = torch.sum(actual_grad_value, 0, keepdim=True)
|
||||
|
|
|
|||
|
|
@ -38,8 +38,7 @@ class HintsWrapper(HigherOrderOperator):
|
|||
|
||||
if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args):
|
||||
raise RuntimeError(
|
||||
"args must be a tuple of tensors, ints, floats, or bools, got "
|
||||
f"{args}"
|
||||
f"args must be a tuple of tensors, ints, floats, or bools, got {args}"
|
||||
)
|
||||
|
||||
if not isinstance(kwargs, dict):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
|
||||
import contextlib
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
|
|
@ -70,9 +69,9 @@ class InvokeSubgraphHOP(HigherOrderOperator):
|
|||
identifier: Optional[str],
|
||||
*operands,
|
||||
):
|
||||
assert identifier is None or isinstance(
|
||||
identifier, str
|
||||
), "identifier must be a None or a string"
|
||||
assert identifier is None or isinstance(identifier, str), (
|
||||
"identifier must be a None or a string"
|
||||
)
|
||||
|
||||
assert all(
|
||||
isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
|
||||
|
|
@ -128,7 +127,11 @@ def invoke_subgraph_placeholder(func, *args, **kwargs):
|
|||
def _invoke_subgraph_placeholder_wrapper(func, args):
|
||||
return invoke_subgraph_placeholder(func, *args)
|
||||
|
||||
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
|
||||
with (
|
||||
_set_compilation_env(),
|
||||
torch._dynamo.utils.disable_cache_limit(),
|
||||
_temp_remove_pre_dispatch_torch_function_mode(),
|
||||
):
|
||||
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
||||
if metadata_mode:
|
||||
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
|
||||
|
|
|
|||
|
|
@ -36,9 +36,9 @@ aten = torch._ops.ops.aten
|
|||
def wrap_combine_fn_flat(
|
||||
*args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves
|
||||
):
|
||||
assert len(args) == (
|
||||
num_init_leaves + num_inp_leaves
|
||||
), f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
|
||||
assert len(args) == (num_init_leaves + num_inp_leaves), (
|
||||
f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
|
||||
)
|
||||
carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init)
|
||||
xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs)
|
||||
return combine_fn(carry, xs)
|
||||
|
|
@ -73,13 +73,13 @@ def mask_list(
|
|||
# If other is None, then the elements of the `inp` list where the mask is False are removed
|
||||
# If other is not None, then the elements of the `inp` list where the mask is False are
|
||||
# replaced with the elements of the `other` list
|
||||
assert len(mask) == len(
|
||||
inp
|
||||
), "The length of the mask needs to be identical to the length of the input"
|
||||
assert len(mask) == len(inp), (
|
||||
"The length of the mask needs to be identical to the length of the input"
|
||||
)
|
||||
if other is not None:
|
||||
assert len(inp) == len(
|
||||
other
|
||||
), "If an input and an other list is provided, they need to have the same length"
|
||||
assert len(inp) == len(other), (
|
||||
"If an input and an other list is provided, they need to have the same length"
|
||||
)
|
||||
return [i if m else o for m, i, o in zip(mask, inp, other)]
|
||||
else:
|
||||
return [i for m, i in zip(mask, inp) if m]
|
||||
|
|
@ -97,9 +97,9 @@ def first_slice_copy_with_grad(li: list[Any]) -> list[Any]:
|
|||
|
||||
def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]:
|
||||
it = iter(iterable)
|
||||
assert sum(chunk_sizes) == len(
|
||||
iterable
|
||||
), "the sum of all chunks needs to match the length of the iterable."
|
||||
assert sum(chunk_sizes) == len(iterable), (
|
||||
"the sum of all chunks needs to match the length of the iterable."
|
||||
)
|
||||
return [list(itertools.islice(it, size)) for size in chunk_sizes]
|
||||
|
||||
|
||||
|
|
@ -166,6 +166,7 @@ def scan(
|
|||
# clone the output to avoid output-output aliasing
|
||||
return next_carry, y.clone()
|
||||
|
||||
|
||||
i0 = torch.zeros(1)
|
||||
xs = torch.arange(5)
|
||||
# returns torch.tensor([10.]), torch.tensor([[0], [1.], [3.], [6.], [10.]])
|
||||
|
|
@ -262,9 +263,9 @@ class ScanOp(HigherOrderOperator):
|
|||
# the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785
|
||||
# Once this issue is resolved, the assertion should only allow tuples
|
||||
# and the tuple cast should be removed
|
||||
assert isinstance(
|
||||
additional_inputs, (tuple, list)
|
||||
), "additional_inputs must be a tuple."
|
||||
assert isinstance(additional_inputs, (tuple, list)), (
|
||||
"additional_inputs must be a tuple."
|
||||
)
|
||||
additional_inputs = (
|
||||
tuple(additional_inputs)
|
||||
if isinstance(additional_inputs, list)
|
||||
|
|
|
|||
|
|
@ -35,9 +35,9 @@ class HopArgumentInfoGen:
|
|||
kw_only: bool = False,
|
||||
) -> HopArgumentInfo:
|
||||
if default_value is not None:
|
||||
assert type(example_value) == type(
|
||||
default_value
|
||||
), f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}"
|
||||
assert type(example_value) == type(default_value), (
|
||||
f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}"
|
||||
)
|
||||
|
||||
return HopArgumentInfo(
|
||||
name=name,
|
||||
|
|
@ -207,12 +207,12 @@ class CFunctionSchemaGen:
|
|||
args.append(CArgumentGen.from_hop_argument_info(i, arg_info))
|
||||
|
||||
# NOTE: we want the output to always be a single argument with torch._C.TupleType.
|
||||
assert isinstance(
|
||||
out_argument_info.example_value, tuple
|
||||
), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}"
|
||||
assert (
|
||||
not out_argument_info.is_mutated
|
||||
), "out_argument_info.is_mutated should always be set to False."
|
||||
assert isinstance(out_argument_info.example_value, tuple), (
|
||||
f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}"
|
||||
)
|
||||
assert not out_argument_info.is_mutated, (
|
||||
"out_argument_info.is_mutated should always be set to False."
|
||||
)
|
||||
rets = None
|
||||
if len(out_argument_info.example_value) == 1:
|
||||
rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)]
|
||||
|
|
|
|||
|
|
@ -81,9 +81,9 @@ def enable_torchbind_tracing():
|
|||
torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign]
|
||||
yield
|
||||
finally:
|
||||
assert (
|
||||
KNOWN_TYPES.pop() is torch.ScriptObject
|
||||
), "Someone else messed with KNOWN_TYPES during tracing, exploding."
|
||||
assert KNOWN_TYPES.pop() is torch.ScriptObject, (
|
||||
"Someone else messed with KNOWN_TYPES during tracing, exploding."
|
||||
)
|
||||
torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign]
|
||||
|
||||
|
||||
|
|
@ -127,9 +127,9 @@ def inner(mode, *args, **kwargs):
|
|||
|
||||
ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
|
||||
if "val" not in out_proxy.node.meta:
|
||||
assert out is None or isinstance(
|
||||
out, (int, float, bool)
|
||||
), "Currently, only these constant dtypes are supported to be returned from torchbind methods."
|
||||
assert out is None or isinstance(out, (int, float, bool)), (
|
||||
"Currently, only these constant dtypes are supported to be returned from torchbind methods."
|
||||
)
|
||||
out_proxy.node.meta["val"] = out
|
||||
return ret
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def create_tma_experimental_metadata(
|
|||
|
||||
|
||||
def maybe_unpack_tma_experimental_metadata(
|
||||
tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata]
|
||||
tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata],
|
||||
) -> Optional[tuple[list[IntLikeType], list[IntLikeType], IntLikeType]]:
|
||||
if not tma_meta or len(tma_meta) != 2:
|
||||
return None
|
||||
|
|
@ -109,7 +109,7 @@ def create_tma_stable_metadata(
|
|||
|
||||
|
||||
def maybe_unpack_tma_stable_metadata(
|
||||
tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata]
|
||||
tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata],
|
||||
) -> Optional[tuple[list[IntLikeType]]]:
|
||||
if not tma_meta or len(tma_meta) != 2:
|
||||
return None
|
||||
|
|
@ -1122,7 +1122,8 @@ def trace_triton_kernel_wrapper(
|
|||
out = func_overload(**node_args)
|
||||
|
||||
proxy_args = pytree.tree_map(
|
||||
proxy_mode.tracer.unwrap_proxy, node_args # type: ignore[union-attr]
|
||||
proxy_mode.tracer.unwrap_proxy, # type: ignore[union-attr]
|
||||
node_args,
|
||||
)
|
||||
out_proxy = proxy_mode.tracer.create_proxy(
|
||||
"call_function",
|
||||
|
|
@ -1660,9 +1661,9 @@ class TritonHOPifier:
|
|||
|
||||
# Update the kwargs in each config
|
||||
# maybe_unpack_heuristic_result raises unsupported if the value is non-constant
|
||||
new_configs[config_idx].__dict__["kwargs"][
|
||||
kwarg_key
|
||||
] = self.maybe_unpack_heuristic_result(heuristic_result)
|
||||
new_configs[config_idx].__dict__["kwargs"][kwarg_key] = (
|
||||
self.maybe_unpack_heuristic_result(heuristic_result)
|
||||
)
|
||||
|
||||
iter_kernel = iter_kernel.fn
|
||||
assert isinstance(iter_kernel, JITFunction)
|
||||
|
|
@ -1742,9 +1743,9 @@ class TritonHOPifier:
|
|||
for config in new_configs:
|
||||
for name in special_param_names:
|
||||
if name not in config.__dict__["kwargs"]:
|
||||
assert (
|
||||
name in config.__dict__
|
||||
), f"{name} must be in autotuning configs to be used as a kernel parameter"
|
||||
assert name in config.__dict__, (
|
||||
f"{name} must be in autotuning configs to be used as a kernel parameter"
|
||||
)
|
||||
config.__dict__["kwargs"][name] = config.__dict__[name]
|
||||
updated = True
|
||||
|
||||
|
|
|
|||
|
|
@ -115,9 +115,9 @@ def reenter_make_fx(fn):
|
|||
|
||||
@functools.wraps(fn)
|
||||
def wrapped(*args):
|
||||
assert (
|
||||
_CURRENT_MAKE_FX_TRACER is not None
|
||||
), "Cannot reenter make_fx when we're not under a make_fx tracing session"
|
||||
assert _CURRENT_MAKE_FX_TRACER is not None, (
|
||||
"Cannot reenter make_fx when we're not under a make_fx tracing session"
|
||||
)
|
||||
return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
|
||||
_maybe_run_with_interpreter(fn), *args
|
||||
)
|
||||
|
|
@ -323,20 +323,22 @@ def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations):
|
|||
|
||||
def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False):
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
|
||||
(_, _, _),
|
||||
inp_mutation,
|
||||
) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
|
||||
|
||||
return len(inp_mutation) > 0
|
||||
|
||||
|
||||
def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
|
||||
(
|
||||
inp_inp_alias_map,
|
||||
inp_out_alias_map,
|
||||
out_out_alias_map,
|
||||
), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
|
||||
(
|
||||
inp_inp_alias_map,
|
||||
inp_out_alias_map,
|
||||
out_out_alias_map,
|
||||
),
|
||||
inp_mutation,
|
||||
) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
|
||||
return (
|
||||
any(
|
||||
(
|
||||
|
|
@ -392,9 +394,7 @@ def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch):
|
|||
graph_module, inputs_fake, pre_dispatch=pre_dispatch
|
||||
)
|
||||
if aliases:
|
||||
raise RuntimeError(
|
||||
f"{name} might be aliasing the input or the output!"
|
||||
) # noqa: F541
|
||||
raise RuntimeError(f"{name} might be aliasing the input or the output!") # noqa: F541
|
||||
if inp_mutation:
|
||||
raise RuntimeError(f"{name} might be modifying the input!") # noqa: F541
|
||||
|
||||
|
|
@ -504,9 +504,9 @@ def prepare_fw_with_masks_all_requires_grad(fn):
|
|||
# replaced with an all-zero tensor for better optimization
|
||||
def unmask_none_gradients(grads, operands):
|
||||
allowed_types = (torch.Tensor, int, torch.SymInt)
|
||||
assert all(
|
||||
isinstance(o, allowed_types) for o in operands
|
||||
), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"
|
||||
assert all(isinstance(o, allowed_types) for o in operands), (
|
||||
f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"
|
||||
)
|
||||
|
||||
unmasked_grads = []
|
||||
for g, o in zip(grads, operands):
|
||||
|
|
@ -762,7 +762,9 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]])
|
|||
allowed_types = (torch.Tensor, int, torch.SymInt)
|
||||
assert all(
|
||||
isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args
|
||||
), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
|
||||
), (
|
||||
f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
|
||||
)
|
||||
|
||||
|
||||
# TODO: Return a more detailed information as to which node
|
||||
|
|
@ -793,7 +795,11 @@ def check_input_alias_and_mutation_return_outputs(
|
|||
# This function can be called under autograd, functional, proxy and fake tensor mode.
|
||||
# We need to return either a fake tensor or a real tensor depending on the mode.
|
||||
# to detect the input mutation/aliasing.
|
||||
with disable_proxy_modes_tracing(), disable_functional_mode(), suspend_functionalization():
|
||||
with (
|
||||
disable_proxy_modes_tracing(),
|
||||
disable_functional_mode(),
|
||||
suspend_functionalization(),
|
||||
):
|
||||
|
||||
def _from_functional_tensor(t: torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(t, FunctionalTensor) or torch._is_functional_tensor(t):
|
||||
|
|
@ -928,13 +934,11 @@ F = TypeVar("F", bound=Callable)
|
|||
|
||||
|
||||
@overload
|
||||
def register_fake(hop, fn: None = None) -> Callable[[F], F]:
|
||||
...
|
||||
def register_fake(hop, fn: None = None) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def register_fake(hop, fn: F) -> F:
|
||||
...
|
||||
def register_fake(hop, fn: F) -> F: ...
|
||||
|
||||
|
||||
def register_fake(hop, fn=None):
|
||||
|
|
|
|||
|
|
@ -202,12 +202,12 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
|
|||
while pred := cond_fn(*carried_vals, *additional_inputs):
|
||||
_validate_cond_output(pred)
|
||||
out = body_fn(*carried_vals, *additional_inputs)
|
||||
assert isinstance(
|
||||
out, tuple
|
||||
), f"body_fn should return a tuple but got {type(out)}"
|
||||
assert len(out) == len(
|
||||
carried_inputs
|
||||
), "body_fn should return the same number of elements as carried_inputs"
|
||||
assert isinstance(out, tuple), (
|
||||
f"body_fn should return a tuple but got {type(out)}"
|
||||
)
|
||||
assert len(out) == len(carried_inputs), (
|
||||
"body_fn should return the same number of elements as carried_inputs"
|
||||
)
|
||||
carried_vals = out
|
||||
return carried_vals
|
||||
|
||||
|
|
@ -230,9 +230,9 @@ def _find_or_create_fake_mode() -> FakeTensorMode:
|
|||
def _create_unbacked_symint(
|
||||
fake_mode: FakeTensorMode, ignore_fresh_unbacked_symbols: bool
|
||||
) -> torch.SymInt:
|
||||
assert (
|
||||
fake_mode is not None and fake_mode.shape_env is not None
|
||||
), "Must provide a fake_mode with shape_env."
|
||||
assert fake_mode is not None and fake_mode.shape_env is not None, (
|
||||
"Must provide a fake_mode with shape_env."
|
||||
)
|
||||
ctx = (
|
||||
contextlib.nullcontext()
|
||||
if not ignore_fresh_unbacked_symbols
|
||||
|
|
|
|||
|
|
@ -28,8 +28,7 @@ def custom_op(
|
|||
mutates_args: Union[str, Iterable[str]],
|
||||
device_types: device_types_t = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Callable[[Callable[..., object]], "CustomOpDef"]:
|
||||
...
|
||||
) -> Callable[[Callable[..., object]], "CustomOpDef"]: ...
|
||||
|
||||
|
||||
@overload
|
||||
|
|
@ -41,8 +40,7 @@ def custom_op(
|
|||
mutates_args: Union[str, Iterable[str]],
|
||||
device_types: device_types_t = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> "CustomOpDef":
|
||||
...
|
||||
) -> "CustomOpDef": ...
|
||||
|
||||
|
||||
@exposed_in("torch.library")
|
||||
|
|
@ -448,10 +446,10 @@ class CustomOpDef:
|
|||
>>>
|
||||
>>> @nonzero.register_fake
|
||||
>>> def _(x):
|
||||
>>> # Number of nonzero-elements is data-dependent.
|
||||
>>> # Since we cannot peek at the data in an abstract impl,
|
||||
>>> # we use the ctx object to construct a new symint that
|
||||
>>> # represents the data-dependent size.
|
||||
>>> # Number of nonzero-elements is data-dependent.
|
||||
>>> # Since we cannot peek at the data in an abstract impl,
|
||||
>>> # we use the ctx object to construct a new symint that
|
||||
>>> # represents the data-dependent size.
|
||||
>>> ctx = torch.library.get_ctx()
|
||||
>>> nnz = ctx.new_dynamic_size()
|
||||
>>> shape = [nnz, x.dim()]
|
||||
|
|
@ -561,7 +559,7 @@ class CustomOpDef:
|
|||
>>>
|
||||
>>> x = torch.randn(3, requires_grad=True)
|
||||
>>> y = numpy_sin(x)
|
||||
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
|
||||
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
|
||||
>>> assert torch.allclose(grad_x, x.cos())
|
||||
>>>
|
||||
>>> # Example with a keyword-only arg
|
||||
|
|
@ -581,7 +579,7 @@ class CustomOpDef:
|
|||
>>>
|
||||
>>> x = torch.randn(3, requires_grad=True)
|
||||
>>> y = numpy_mul(x, val=3.14)
|
||||
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
|
||||
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
|
||||
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
|
||||
|
||||
"""
|
||||
|
|
@ -919,7 +917,7 @@ def get_library_allowing_overwrite(
|
|||
|
||||
|
||||
def _maybe_get_opdef(
|
||||
op: Union[CustomOpDef, _ops.OpOverload, str]
|
||||
op: Union[CustomOpDef, _ops.OpOverload, str],
|
||||
) -> Optional[CustomOpDef]:
|
||||
if isinstance(op, CustomOpDef):
|
||||
return op
|
||||
|
|
|
|||
|
|
@ -150,13 +150,13 @@ def infer_schema(
|
|||
"the arguments that are mutated or the string 'unknown'. "
|
||||
)
|
||||
if schema_type.startswith("Tensor"):
|
||||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
|
||||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
|
||||
elif name in mutates_args:
|
||||
if not schema_type.startswith("Tensor"):
|
||||
error_fn(
|
||||
f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
|
||||
)
|
||||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
|
||||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
|
||||
seen_args.add(name)
|
||||
if param.default is inspect.Parameter.empty:
|
||||
params.append(f"{schema_type} {name}")
|
||||
|
|
|
|||
|
|
@ -1242,12 +1242,12 @@ def trace_structured(
|
|||
"frame_compile_id",
|
||||
"attempt",
|
||||
]
|
||||
assert callable(
|
||||
metadata_fn
|
||||
), f"metadata_fn should be callable, but got {type(metadata_fn)}"
|
||||
assert callable(
|
||||
payload_fn
|
||||
), f"payload_fn should be callable, but got {type(payload_fn)}"
|
||||
assert callable(metadata_fn), (
|
||||
f"metadata_fn should be callable, but got {type(metadata_fn)}"
|
||||
)
|
||||
assert callable(payload_fn), (
|
||||
f"payload_fn should be callable, but got {type(payload_fn)}"
|
||||
)
|
||||
# trace_log never propagates and is ALWAYS DEBUG, so also check that there
|
||||
# are handlers instead of checking the log level
|
||||
if trace_log.handlers:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Utilities for converting data types into structured JSON for dumping.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
""" Define analogs of numpy dtypes supported by pytorch.
|
||||
"""Define analogs of numpy dtypes supported by pytorch.
|
||||
Define the scalar types and supported dtypes and numpy <--> torch dtype mappings.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
Here `dtype` is always a torch.dtype, this module knows nothing about
|
||||
scalar types, wrapper dtypes or anything like that. PyTorch only.
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
Things imported from here have numpy-compatible signatures but operate on
|
||||
pytorch tensors.
|
||||
"""
|
||||
|
||||
# Contents of this module ends up in the main namespace via _funcs.py
|
||||
# where type annotations are used in conjunction with the @normalizer decorator.
|
||||
from __future__ import annotations
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
|
||||
"""
|
||||
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc
|
||||
"""Implementation of reduction operations, to be wrapped into arrays, dtypes etc
|
||||
in the 'public' layer.
|
||||
|
||||
Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
"""Assorted utilities, which do not need anything other then torch and stdlib.
|
||||
"""
|
||||
"""Assorted utilities, which do not need anything other then torch and stdlib."""
|
||||
|
||||
import operator
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ NumPy has strict guarantees on reproducibility etc; here we don't give any.
|
|||
Q: default dtype is float64 in numpy
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
Utility function to facilitate testing.
|
||||
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import operator
|
||||
|
|
@ -167,7 +168,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True):
|
|||
|
||||
Examples
|
||||
--------
|
||||
>>> np.testing.assert_equal([4,5], [4,6])
|
||||
>>> np.testing.assert_equal([4, 5], [4, 6])
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError:
|
||||
|
|
@ -298,8 +299,12 @@ def print_assert_equal(test_string, actual, desired):
|
|||
|
||||
Examples
|
||||
--------
|
||||
>>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]) # doctest: +SKIP
|
||||
>>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2]) # doctest: +SKIP
|
||||
>>> np.testing.print_assert_equal(
|
||||
... "Test XYZ of func xyz", [0, 1], [0, 1]
|
||||
... ) # doctest: +SKIP
|
||||
>>> np.testing.print_assert_equal(
|
||||
... "Test XYZ of func xyz", [0, 1], [0, 2]
|
||||
... ) # doctest: +SKIP
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: Test XYZ of func xyz failed
|
||||
|
|
@ -377,8 +382,9 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True):
|
|||
ACTUAL: 2.3333333333333
|
||||
DESIRED: 2.33333334
|
||||
|
||||
>>> assert_almost_equal(np.array([1.0,2.3333333333333]),
|
||||
... np.array([1.0,2.33333334]), decimal=9)
|
||||
>>> assert_almost_equal(
|
||||
... np.array([1.0, 2.3333333333333]), np.array([1.0, 2.33333334]), decimal=9
|
||||
... )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError:
|
||||
|
|
@ -487,11 +493,19 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True
|
|||
|
||||
Examples
|
||||
--------
|
||||
>>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) # doctest: +SKIP
|
||||
>>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, # doctest: +SKIP
|
||||
... significant=8)
|
||||
>>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, # doctest: +SKIP
|
||||
... significant=8)
|
||||
>>> np.testing.assert_approx_equal(
|
||||
... 0.12345677777777e-20, 0.1234567e-20
|
||||
... ) # doctest: +SKIP
|
||||
>>> np.testing.assert_approx_equal(
|
||||
... 0.12345670e-20,
|
||||
... 0.12345671e-20, # doctest: +SKIP
|
||||
... significant=8,
|
||||
... )
|
||||
>>> np.testing.assert_approx_equal(
|
||||
... 0.12345670e-20,
|
||||
... 0.12345672e-20, # doctest: +SKIP
|
||||
... significant=8,
|
||||
... )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError:
|
||||
|
|
@ -501,7 +515,7 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True
|
|||
|
||||
the evaluated condition that raises the exception is
|
||||
|
||||
>>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
|
||||
>>> abs(0.12345670e-20 / 1e-21 - 0.12345672e-20 / 1e-21) >= 10 ** -(8 - 1)
|
||||
True
|
||||
|
||||
"""
|
||||
|
|
@ -776,15 +790,16 @@ def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False):
|
|||
--------
|
||||
The first assert does not raise an exception:
|
||||
|
||||
>>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
|
||||
... [np.exp(0),2.33333, np.nan])
|
||||
>>> np.testing.assert_array_equal(
|
||||
... [1.0, 2.33333, np.nan], [np.exp(0), 2.33333, np.nan]
|
||||
... )
|
||||
|
||||
Use `assert_allclose` or one of the nulp (number of floating point values)
|
||||
functions for these cases instead:
|
||||
|
||||
>>> np.testing.assert_allclose([1.0,np.pi,np.nan],
|
||||
... [1, np.sqrt(np.pi)**2, np.nan],
|
||||
... rtol=1e-10, atol=0)
|
||||
>>> np.testing.assert_allclose(
|
||||
... [1.0, np.pi, np.nan], [1, np.sqrt(np.pi) ** 2, np.nan], rtol=1e-10, atol=0
|
||||
... )
|
||||
|
||||
As mentioned in the Notes section, `assert_array_equal` has special
|
||||
handling for scalars. Here the test checks that each value in `x` is 3:
|
||||
|
|
@ -809,7 +824,7 @@ def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False):
|
|||
The `strict` parameter also ensures that the array data types match:
|
||||
|
||||
>>> x = np.array([2, 2, 2])
|
||||
>>> y = np.array([2., 2., 2.], dtype=np.float32)
|
||||
>>> y = np.array([2.0, 2.0, 2.0], dtype=np.float32)
|
||||
>>> np.testing.assert_array_equal(x, y, strict=True)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
|
|
@ -881,11 +896,11 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True):
|
|||
--------
|
||||
the first assert does not raise an exception
|
||||
|
||||
>>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
|
||||
... [1.0,2.333,np.nan])
|
||||
>>> np.testing.assert_array_almost_equal([1.0, 2.333, np.nan], [1.0, 2.333, np.nan])
|
||||
|
||||
>>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
|
||||
... [1.0,2.33339,np.nan], decimal=5)
|
||||
>>> np.testing.assert_array_almost_equal(
|
||||
... [1.0, 2.33333, np.nan], [1.0, 2.33339, np.nan], decimal=5
|
||||
... )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError:
|
||||
|
|
@ -897,8 +912,9 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True):
|
|||
x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64)
|
||||
y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64)
|
||||
|
||||
>>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
|
||||
... [1.0,2.33333, 5], decimal=5)
|
||||
>>> np.testing.assert_array_almost_equal(
|
||||
... [1.0, 2.33333, np.nan], [1.0, 2.33333, 5], decimal=5
|
||||
... )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError:
|
||||
|
|
@ -1054,8 +1070,8 @@ def assert_string_equal(actual, desired):
|
|||
|
||||
Examples
|
||||
--------
|
||||
>>> np.testing.assert_string_equal('abc', 'abc') # doctest: +SKIP
|
||||
>>> np.testing.assert_string_equal('abc', 'abcd') # doctest: +SKIP
|
||||
>>> np.testing.assert_string_equal("abc", "abc") # doctest: +SKIP
|
||||
>>> np.testing.assert_string_equal("abc", "abcd") # doctest: +SKIP
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
...
|
||||
|
|
@ -1341,11 +1357,11 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
|
|||
|
||||
Examples
|
||||
--------
|
||||
>>> x = np.array([1., 1e-10, 1e-20])
|
||||
>>> x = np.array([1.0, 1e-10, 1e-20])
|
||||
>>> eps = np.finfo(x.dtype).eps
|
||||
>>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) # doctest: +SKIP
|
||||
>>> np.testing.assert_array_almost_equal_nulp(x, x * eps / 2 + x) # doctest: +SKIP
|
||||
|
||||
>>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) # doctest: +SKIP
|
||||
>>> np.testing.assert_array_almost_equal_nulp(x, x * eps + x) # doctest: +SKIP
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: X and Y are not equal to 1 ULP (max is 2)
|
||||
|
|
@ -1404,7 +1420,7 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
|
|||
|
||||
Examples
|
||||
--------
|
||||
>>> a = np.linspace(0., 1., 100)
|
||||
>>> a = np.linspace(0.0, 1.0, 100)
|
||||
>>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP
|
||||
|
||||
"""
|
||||
|
|
@ -1562,7 +1578,7 @@ def assert_warns(warning_class, *args, **kwargs):
|
|||
>>> import warnings
|
||||
>>> def deprecated_func(num):
|
||||
... warnings.warn("Please upgrade", DeprecationWarning)
|
||||
... return num*num
|
||||
... return num * num
|
||||
>>> with np.testing.assert_warns(DeprecationWarning):
|
||||
... assert deprecated_func(4) == 16
|
||||
>>> # or passing a func
|
||||
|
|
@ -1663,19 +1679,29 @@ def _gen_alignment_data(dtype=float32, type="binary", max_size=24):
|
|||
yield out, inp(), ufmt % (o, o, s, dtype, "out of place")
|
||||
d = inp()
|
||||
yield d, d, ufmt % (o, o, s, dtype, "in place")
|
||||
yield out[1:], inp()[:-1], ufmt % (
|
||||
o + 1,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
yield (
|
||||
out[1:],
|
||||
inp()[:-1],
|
||||
ufmt
|
||||
% (
|
||||
o + 1,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
),
|
||||
)
|
||||
yield out[:-1], inp()[1:], ufmt % (
|
||||
o,
|
||||
o + 1,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
yield (
|
||||
out[:-1],
|
||||
inp()[1:],
|
||||
ufmt
|
||||
% (
|
||||
o,
|
||||
o + 1,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
),
|
||||
)
|
||||
yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased")
|
||||
yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased")
|
||||
|
|
@ -1691,53 +1717,89 @@ def _gen_alignment_data(dtype=float32, type="binary", max_size=24):
|
|||
yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1")
|
||||
d = inp2()
|
||||
yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2")
|
||||
yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % (
|
||||
o + 1,
|
||||
o,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
yield (
|
||||
out[1:],
|
||||
inp1()[:-1],
|
||||
inp2()[:-1],
|
||||
bfmt
|
||||
% (
|
||||
o + 1,
|
||||
o,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
),
|
||||
)
|
||||
yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % (
|
||||
o,
|
||||
o + 1,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
yield (
|
||||
out[:-1],
|
||||
inp1()[1:],
|
||||
inp2()[:-1],
|
||||
bfmt
|
||||
% (
|
||||
o,
|
||||
o + 1,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
),
|
||||
)
|
||||
yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % (
|
||||
o,
|
||||
o,
|
||||
o + 1,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
yield (
|
||||
out[:-1],
|
||||
inp1()[:-1],
|
||||
inp2()[1:],
|
||||
bfmt
|
||||
% (
|
||||
o,
|
||||
o,
|
||||
o + 1,
|
||||
s - 1,
|
||||
dtype,
|
||||
"out of place",
|
||||
),
|
||||
)
|
||||
yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % (
|
||||
o + 1,
|
||||
o,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"aliased",
|
||||
yield (
|
||||
inp1()[1:],
|
||||
inp1()[:-1],
|
||||
inp2()[:-1],
|
||||
bfmt
|
||||
% (
|
||||
o + 1,
|
||||
o,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"aliased",
|
||||
),
|
||||
)
|
||||
yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % (
|
||||
o,
|
||||
o + 1,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"aliased",
|
||||
yield (
|
||||
inp1()[:-1],
|
||||
inp1()[1:],
|
||||
inp2()[:-1],
|
||||
bfmt
|
||||
% (
|
||||
o,
|
||||
o + 1,
|
||||
o,
|
||||
s - 1,
|
||||
dtype,
|
||||
"aliased",
|
||||
),
|
||||
)
|
||||
yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % (
|
||||
o,
|
||||
o,
|
||||
o + 1,
|
||||
s - 1,
|
||||
dtype,
|
||||
"aliased",
|
||||
yield (
|
||||
inp1()[:-1],
|
||||
inp1()[:-1],
|
||||
inp2()[1:],
|
||||
bfmt
|
||||
% (
|
||||
o,
|
||||
o,
|
||||
o + 1,
|
||||
s - 1,
|
||||
dtype,
|
||||
"aliased",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1818,9 +1880,10 @@ class clear_and_catch_warnings(warnings.catch_warnings):
|
|||
--------
|
||||
>>> import warnings
|
||||
>>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP
|
||||
... modules=[np.core.fromnumeric]):
|
||||
... warnings.simplefilter('always')
|
||||
... warnings.filterwarnings('ignore', module='np.core.fromnumeric')
|
||||
... modules=[np.core.fromnumeric]
|
||||
... ):
|
||||
... warnings.simplefilter("always")
|
||||
... warnings.filterwarnings("ignore", module="np.core.fromnumeric")
|
||||
... # do something that raises a warning but ignore those in
|
||||
... # np.core.fromnumeric
|
||||
"""
|
||||
|
|
@ -1918,6 +1981,8 @@ class suppress_warnings:
|
|||
|
||||
sup = np.testing.suppress_warnings()
|
||||
sup.filter(module=np.ma.core) # module must match exactly
|
||||
|
||||
|
||||
@sup
|
||||
def some_function():
|
||||
# do something which causes a warning in np.ma.core
|
||||
|
|
|
|||
|
|
@ -2513,7 +2513,11 @@ def _full_aten(
|
|||
) -> Tensor:
|
||||
# Note that Mypy thinks torch.full can't accept a complex fill_value
|
||||
return torch.full(
|
||||
shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type]
|
||||
shape,
|
||||
fill_value,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
requires_grad=requires_grad, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2556,7 +2560,11 @@ def _full_like_aten(
|
|||
) -> Tensor:
|
||||
# Note that Mypy thinks torch.full can't accept a complex fill_value
|
||||
return torch.full_like(
|
||||
a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type]
|
||||
a,
|
||||
fill_value,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
requires_grad=requires_grad, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -340,9 +340,9 @@ def register_graphsafe_run_with_rng_state_op():
|
|||
@graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect)
|
||||
def impl_backend_select(op, *args, rng_state=None, **kwargs):
|
||||
device = get_device(args, kwargs)
|
||||
assert (
|
||||
device == "cuda"
|
||||
), f"GraphSafe RNG operations only supported for CUDA, got {device}"
|
||||
assert device == "cuda", (
|
||||
f"GraphSafe RNG operations only supported for CUDA, got {device}"
|
||||
)
|
||||
return impl_cuda(op, *args, rng_state=rng_state, **kwargs)
|
||||
|
||||
@graphsafe_run_with_rng_state.py_impl(FakeTensorMode)
|
||||
|
|
|
|||
|
|
@ -33,17 +33,13 @@ if TYPE_CHECKING:
    import sympy


class _WorksWithInt(typing.Protocol):
    def __add__(self, other: Any) -> typing.Self:
        ...
    def __add__(self, other: Any) -> typing.Self: ...

    def __radd__(self, other: Any) -> typing.Self:
        ...
    def __radd__(self, other: Any) -> typing.Self: ...

    def __mul__(self, other: Any) -> typing.Self:
        ...
    def __mul__(self, other: Any) -> typing.Self: ...

    def __rmul__(self, other: Any) -> typing.Self:
        ...
    def __rmul__(self, other: Any) -> typing.Self: ...


_IntLikeT = TypeVar("_IntLikeT", bound=_WorksWithInt)

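The Protocol hunk above shows another recurring mechanical change: ruff format joins a bare `...` stub body onto the signature line. A minimal sketch with a made-up protocol (not part of the patch):

import typing


class _Readable(typing.Protocol):  # hypothetical example, for illustration only
    # Before (Black) the stub body sat on its own line:
    #     def read(self) -> bytes:
    #         ...
    # After (ruff format) the ellipsis joins the signature line.
    def read(self) -> bytes: ...
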
@ -292,9 +288,7 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
        # can assume x is not 0 in expected_stride equation. This make the check consistent with
        # make_contiguous_strides_for. If we make a tensor and used strides from make_contiguous_strides_for
        # and then called definitely_contiguous we should get True.
        expected_stride *= (
            x if is_nested_int(x) else sym_max(x, 1)
        )  # type:ignore[assignment]
        expected_stride *= x if is_nested_int(x) else sym_max(x, 1)  # type:ignore[assignment]

    return True

@ -912,7 +906,7 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
# Extracts dimensions that might be passed either as a list/tuple or as varargs.
# A typical case is Tensor.permute .
def extract_dims_from_varargs(
    dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]]
    dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]],
) -> DimsSequenceType:
    if dims and isinstance(dims[0], Sequence):
        assert len(dims) == 1

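The one-line signature changes in this and the next few hunks only add a trailing comma to a lone parameter that already sits on its own line, the "magic trailing comma" that tells the formatter to keep the wrapped layout. A rough sketch with a hypothetical helper (not from the patch):

from typing import Union


def as_list(
    value: Union[int, list[int]],  # the added trailing comma keeps the wrapped layout
) -> list[int]:
    # Hypothetical helper, used only to illustrate the formatting change.
    return value if isinstance(value, list) else [value]


print(as_list(3))  # prints [3]
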
@ -1234,7 +1228,7 @@ def get_higher_dtype(
    assert b is None or isinstance(b, (torch.dtype, TensorLike, Number))

    def _extract_dtype(
        x: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
        x: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
    ) -> Optional[torch.dtype]:
        if x is None:
            return None

@ -1452,7 +1446,7 @@ class RETURN_TYPE(Enum):

# TODO: when NumberType contains the sym types, can simplify this
def number_type(
    x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool]
    x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool],
) -> type:
    if isinstance(x, torch.SymInt):
        return int

@ -1708,9 +1702,7 @@ def make_contiguous_strides_for(
    strides = []
    for l in reversed(shape):
        strides.append(multiplier)
        multiplier *= (
            l if is_nested_int(l) else sym_max(l, 1)
        )  # type:ignore[assignment]
        multiplier *= l if is_nested_int(l) else sym_max(l, 1)  # type:ignore[assignment]

    result = tuple(reversed(strides))

@ -1860,7 +1852,9 @@ def compute_required_storage_length(

    >>> # xdoctest: +SKIP(failing)
    >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
    >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
    >>> size = compute_required_storage_length(
    ...     t2.shape, t2.stride(), t2.storage_offset()
    ... )
    >>> size == t.storage().size()
    True

@ -1870,7 +1864,9 @@ def compute_required_storage_length(
    >>> slice.storage().size()
    100

    >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
    >>> compute_required_storage_length(
    ...     slice.shape, slice.stride(), slice.storage_offset()
    ... )
    40

    """

@ -316,8 +316,7 @@ def out_wrapper(
                    and len(result) == len(out_names)  # type: ignore[arg-type]
                )
                or (
                    fn.__name__ == "unbind"
                    and isinstance(result, (list, tuple))  # type: ignore[arg-type]
                    fn.__name__ == "unbind" and isinstance(result, (list, tuple))  # type: ignore[arg-type]
                )
            )
            # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829

@ -342,9 +341,15 @@ def out_wrapper(
                assert isinstance(out, TensorLike)
                # These two operations are done in-place
                _maybe_resize_out(
                    out, result.shape, maybe_compute_memory_format(result)  # type: ignore[union-attr]
                    out,
                    result.shape,  # type: ignore[union-attr]
                    maybe_compute_memory_format(result),
                )
                _safe_copy_out(
                    copy_from=result,  # type: ignore[arg-type]
                    copy_to=out,
                    exact_dtype=exact_dtype,
                )
                _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                if fn.__name__ != "unbind":
                    assert isinstance(out, tuple)  # type: ignore[arg-type]

@ -385,7 +390,8 @@ def out_wrapper(
        params = sorted(params, key=lambda p: p.kind)

        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
            parameters=params,
            return_annotation=return_type,  # type: ignore[arg-type]
        )

        _fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))

@ -400,7 +406,9 @@ def out_wrapper(
        # Add an indicator attribute that can be used in special cases
        # where having a function wrapped by `out_wrapper` is not desirable e.g.
        # jit
        _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"  # type: ignore[attr-defined]
        _fn._torch_decompositions_out_wrapper = (  # type: ignore[attr-defined]
            f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"
        )

        return _fn

@ -313,7 +313,8 @@ def _canonicalize_fft_shape_and_dim_args(

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)  # type: ignore[possibly-undefined]
            s if s != -1 else input_sizes[d]
            for (s, d) in zip(shape, ret_dims)  # type: ignore[possibly-undefined]
        )
    elif dim is None:
        # No shape, no dim

@ -310,7 +310,7 @@ def strobelight(
        profiler = StrobelightCLIFunctionProfiler(**kwargs)

    def strobelight_inner(
        work_function: Callable[_P, _R]
        work_function: Callable[_P, _R],
    ) -> Callable[_P, Optional[_R]]:
        @functools.wraps(work_function)
        def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:

@ -129,9 +129,9 @@ def _is_tensor_constructor(func: OpOverload):
def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    def impl_decorator(op_impl):
        if isinstance(run_impl_check, OpOverload):
            assert (
                run_impl_check not in op_implementations_dict
            ), f"duplicate registration: {run_impl_check}"
            assert run_impl_check not in op_implementations_dict, (
                f"duplicate registration: {run_impl_check}"
            )
            op_implementations_dict[run_impl_check] = op_impl
        elif isinstance(run_impl_check, (list, tuple)):
            for op in run_impl_check:

@ -575,25 +575,25 @@ def assert_tensor_metadata(
    layout=None,
) -> None:
    if sizes is not None:
        assert (
            t.size() == sizes
        ), f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}"
        assert t.size() == sizes, (
            f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}"
        )
    if strides is not None:
        assert (
            t.stride() == strides
        ), f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}"
        assert t.stride() == strides, (
            f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}"
        )
    if dtype is not None:
        assert (
            t.dtype == dtype
        ), f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}"
        assert t.dtype == dtype, (
            f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}"
        )
    if layout is not None:
        assert (
            t.layout == layout
        ), f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}"
        assert t.layout == layout, (
            f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}"
        )
    if device is not None:
        assert (
            t.device == device
        ), f"Tensor device mismatch! Expected: {device}, Got: {t.device}"
        assert t.device == device, (
            f"Tensor device mismatch! Expected: {device}, Got: {t.device}"
        )


# NB: this must be ordered after local_scalar_dense

@ -1091,7 +1091,9 @@ def get_fast_op_impls():
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(
        make_fast_binary_impl(torch._refs.mul)
    )  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(
            torch._refs.div,

@ -496,9 +496,9 @@ class FakeTensorConverter:
        pytype: Optional[type[torch.Tensor]] = None,
        dispatch_keys: Optional[torch.DispatchKeySet] = None,
    ) -> FakeTensor:
        assert (
            t.device.type == "meta"
        ), f"tensor's device must be `meta`, got {t.device.type} instead"
        assert t.device.type == "meta", (
            f"tensor's device must be `meta`, got {t.device.type} instead"
        )
        # This is a bit abusive (this is not the "real" tensor) but whatever,
        # the meta tensor should be fresh so there's no way to get it wrong
        maybe_memo = self._get_memo(t)

@ -1594,7 +1594,10 @@ class FakeTensorMode(TorchDispatchMode):
        if torch.Tag.dynamic_output_shape in func.tags:
            if func is aten.index.Tensor:
                _, new_kwargs = normalize_function(  # type: ignore[misc]
                    func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True  # type: ignore[arg-type]
                    func,
                    args=args,  # type: ignore[arg-type]
                    kwargs=kwargs,  # type: ignore[arg-type]
                    normalize_to_only_use_kwargs=True,
                )
                for index in new_kwargs["indices"]:
                    # index calls nonzero for bool or int8 tensors, and

@ -2136,9 +2139,7 @@ class FakeTensorMode(TorchDispatchMode):
            try:
                _check_fake_real_vals(s_fake, s_real)
            except MetadataMismatchError as exc:
                if (
                    torch._functorch.config.generate_fake_kernels_from_real_mismatches
                ):
                if torch._functorch.config.generate_fake_kernels_from_real_mismatches:
                    dtrace_structured(
                        "mismatched_fake_kernel",
                        metadata_fn=lambda: {

@ -2311,9 +2312,9 @@ class FakeTensorMode(TorchDispatchMode):
            and not flat_arg_fake_tensors
            and not device_conversion_skip_const_prop
        ):
            assert all(
                t.constant is not None for t in flat_arg_fake_tensors
            ), f"{func} should not have fake inputs without constants"
            assert all(t.constant is not None for t in flat_arg_fake_tensors), (
                f"{func} should not have fake inputs without constants"
            )
            const_flat_args = [
                a.constant if self.is_our_fake(a) else a for a in flat_args
            ]

@ -2538,9 +2539,7 @@ class FakeTensorMode(TorchDispatchMode):

        if real_out is not nil:
            # cross check fake/real outputs, and optionally override fake kernel mismatches
            if (
                not torch._functorch.config.generate_fake_kernels_from_real_mismatches
            ):
            if not torch._functorch.config.generate_fake_kernels_from_real_mismatches:
                self._maybe_infer_fake_kernel_from_pytree_out(
                    func,
                    (args, kwargs),

@ -2924,7 +2923,10 @@ class FakeTensorMode(TorchDispatchMode):
            schema_info = get_schema_info(func)
            if any_constant and schema_info.is_mutable():
                _, new_kwargs = normalize_function(  # type: ignore[misc]
                    func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True  # type: ignore[arg-type]
                    func,
                    args=args,  # type: ignore[arg-type]
                    kwargs=kwargs,  # type: ignore[arg-type]
                    normalize_to_only_use_kwargs=True,
                )
                for k, v in new_kwargs.items():
                    k = k if (k != "input" or schema_info.has_argument(k)) else "self"

@ -2948,9 +2950,9 @@ class FakeTensorMode(TorchDispatchMode):
        if static_shapes is None:
            static_shapes = self.static_shapes
        if static_shapes:
            assert (
                symbolic_context is None
            ), "cannot set both static_shapes and symbolic_context"
            assert symbolic_context is None, (
                "cannot set both static_shapes and symbolic_context"
            )
            shape_env = None
        return self.fake_tensor_converter.from_real_tensor(
            self,

@ -102,7 +102,7 @@ def is_sdpa_error(func, idx, e):


def try_convert_fake_to_real(
    ten_list: list[Union[FakeTensor, Any]]
    ten_list: list[Union[FakeTensor, Any]],
) -> list[Union[FakeTensor, torch.Tensor, Any]]:
    """
    Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up

@ -266,9 +266,9 @@ class CrossRefFakeMode(TorchDispatchMode):
            if fake_r is not None:
                r_flat = pytree.tree_leaves(r)
                f_flat = pytree.tree_leaves(fake_r)
                assert len(f_flat) == len(
                    r_flat
                ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
                assert len(f_flat) == len(r_flat), (
                    f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
                )

                if self.check_aliasing:
                    _check_alias_info(

@ -279,9 +279,9 @@ class CrossRefFakeMode(TorchDispatchMode):
                    zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
                ):
                    r_is_ten = isinstance(r_out, torch.Tensor)
                    assert r_is_ten == isinstance(
                        f_out, torch.Tensor
                    ), f"{context} mismatched number of tensor outputs"
                    assert r_is_ten == isinstance(f_out, torch.Tensor), (
                        f"{context} mismatched number of tensor outputs"
                    )
                    if r_is_ten:
                        try:
                            _check_fake_real_tensors(

@ -357,7 +357,9 @@ class MetaTensorDescriber:

        maybe_functorch_stack = None
        if is_functorch_wrapped:
            with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack:
            with (
                torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
            ) as maybe_functorch_stack:
                pass

        attrs = None

@ -517,8 +519,7 @@ class ViewFunc(Generic[_TensorT]):
        new_base: _TensorT,
        symint_visitor_fn: Optional[Callable[[int], int]] = None,
        tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None,
    ) -> _TensorT:
        ...
    ) -> _TensorT: ...

    @staticmethod
    def from_tensor(t: torch.Tensor) -> ViewFunc:

@ -574,8 +575,7 @@ class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
    def __call__(
        self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
    ) -> _TensorT_cov:
        ...
    ) -> _TensorT_cov: ...


class _MetaTensorCallbackKwargs(TypedDict, total=False):

@ -592,8 +592,7 @@ class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
        arg: Callable[[], torch.Tensor],
        /,
        **kwargs: Unpack[_MetaTensorCallbackKwargs],
    ) -> _TensorT_cov:
        ...
    ) -> _TensorT_cov: ...


@dataclass(frozen=True)

@ -785,9 +784,9 @@ class MetaConverter(Generic[_TensorT]):
        ] = weakref.WeakValueDictionary()
        # Maps MetaTensorId to torch.Tensor (typically a meta tensor or
        # FakeTensor)
        self.tensor_memo: weakref.WeakValueDictionary[
            MetaTensorId, _TensorT
        ] = weakref.WeakValueDictionary()
        self.tensor_memo: weakref.WeakValueDictionary[MetaTensorId, _TensorT] = (
            weakref.WeakValueDictionary()
        )
        self.hit = 0
        self.miss = 0
        self.del_hook = None

@ -1772,9 +1771,9 @@ class MetaConverter(Generic[_TensorT]):
            # subclasses. Relevant test is
            # DynamicShapesFunctionTests::test_add_dynamic_shapes in
            # test/dynamo/test_dynamic_shapes.py
            maybe_fake_mgr: AbstractContextManager[
                None
            ] = contextlib.nullcontext()
            maybe_fake_mgr: AbstractContextManager[None] = (
                contextlib.nullcontext()
            )
            from torch._subclasses.fake_tensor import (
                in_kernel_invocation_manager,
                maybe_get_fake_mode,
