format some aotautograd-related files in functorch with black (#83240)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83240
Approved by: https://github.com/ezyang
This commit is contained in:
parent 408fa38f33 · commit 016fcca243
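Note: black normalizes formatting rather than behavior — 88-column lines, double quotes, and "magic trailing comma" handling, where a call or literal that does not fit on one line is exploded one element per line. A minimal hypothetical illustration of the style (not taken from this diff):

    def some_function(a, b, c):
        return a + b + c

    # Before black: hand-wrapped continuation lines, no trailing comma
    result = some_function(1, 2,
                           3)

    # After black: a call that exceeds the line limit (or carries a
    # trailing comma) is split with one argument per line
    result = some_function(
        1,
        2,
        3,
    )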
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -720,6 +720,8 @@ include_patterns = [
     'torch/_subclasses/**/*.py',
     'torch/_*.py',
     'torchgen/**/*.py',
+    'functorch/functorch/_src/aot_autograd.py',
+    'functorch/functorch/_src/compilers.py',
 ]
 command = [
     'python3',
--- a/functorch/functorch/_src/aot_autograd.py
+++ b/functorch/functorch/_src/aot_autograd.py
@@ -1,30 +1,34 @@
+import warnings
 from contextlib import contextmanager, nullcontext
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
 import torch
-import torch.nn as nn
-from torch import Tensor
-from functorch import make_fx
-from torch.fx import immutable_collections, Interpreter
 import torch.fx.traceback as fx_traceback
-from torch._subclasses import FakeTensorMode
+import torch.nn as nn
+import torch.nn.utils.stateless as stateless
 import torch.utils._pytree as pytree
 import torch.utils.dlpack
-from torch.nn.utils import _stateless
+from torch import Tensor
+from torch._subclasses import FakeTensorMode
+from torch.fx import immutable_collections, Interpreter
+
+from functorch import make_fx
 from functorch._C import CompileCache
 from functorch.experimental import functionalize
 from . import config
 from .decompositions import register_decomposition
+from .named_members_polyfill import _named_buffers, _named_parameters
 from .partitioners import default_partition
-from .named_members_polyfill import _named_parameters, _named_buffers
-from typing import Callable, List, Dict, Any, Tuple, Optional
-from functools import wraps
-import warnings
 
 try:
     from torchdynamo import disable as disable_torchdynamo
 except ImportError:
 
     def disable_torchdynamo(x):
         return x
 
 
 pytree._register_pytree_node(
     immutable_collections.immutable_list,
     lambda x: (list(x), None),
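Note: the `_register_pytree_node` call above is what lets torch's pytree utilities traverse FX's immutable list type like an ordinary container; `pytree.tree_map`, used throughout this file, then applies a function to every leaf of a nested structure and rebuilds the same shape around the results. A small sketch of that behavior:

    import torch
    import torch.utils._pytree as pytree

    args = [torch.ones(2), (torch.zeros(3), 4)]
    doubled = pytree.tree_map(
        lambda x: x * 2 if isinstance(x, torch.Tensor) else x, args
    )
    print(doubled)  # [tensor([2., 2.]), (tensor([0., 0., 0.]), 8)]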
@@ -148,20 +152,25 @@ def track_graph_compiling(graph_name, increment_index=False):
         nth_graph += 1
     graph_being_compiled = None
+
 
 def make_boxed_func(f):
     def g(args):
         return f(*args)
+
     g._boxed_call = True
     return g
+
 
 def make_boxed_compiler(compiler):
     @wraps(compiler)
     def f(fx_g, inps):
         out_f = compiler(fx_g, inps)
         fx_g = make_boxed_func(out_f)
         return fx_g
+
     return f
+
 
 def call_func_with_args(f, args, steal_args=False):
     if not steal_args:
         args = list(args)
@@ -178,6 +187,7 @@ def call_func_with_args(f, args, steal_args=False):
         )
     return normalize_as_list(f(*args))
+
 
 def create_aot_autograd_function(
     flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
 ):
@@ -212,19 +222,30 @@ def create_aot_autograd_function(
         if compiled_fw is None:
             flat_tensor_args = pytree.tree_map(
                 lambda x: x.detach().requires_grad_(x.requires_grad)
-                if isinstance(x, Tensor) else x, flat_tensor_args
+                if isinstance(x, Tensor)
+                else x,
+                flat_tensor_args,
             )
-            fake_mode = FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
+            fake_mode = (
+                FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
+            )
             with preserve_rng_state(), fake_mode as mode:
                 # Set input tensors that require grad to leaves
                 fake_flat_tensor_args = pytree.tree_map(
-                    lambda x: mode.from_tensor(x) if mode else x
-                    if isinstance(x, Tensor) else x, flat_tensor_args
+                    lambda x: mode.from_tensor(x)
+                    if mode
+                    else x
+                    if isinstance(x, Tensor)
+                    else x,
+                    flat_tensor_args,
                 )
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*fake_flat_tensor_args)
                 out = pytree.tree_map(
-                    lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    lambda x: x.detach().contiguous()
+                    if isinstance(x, Tensor)
+                    else x,
+                    out,
                 )
 
             if isinstance(out, (list, tuple)):
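Note: the block above traces through detached copies of the inputs, optionally under FakeTensorMode so tracing allocates no real storage; `mode.from_tensor` converts a real tensor into a fake one that carries only metadata (shape, dtype, device). A rough standalone sketch — using `FakeTensorMode()` as a context manager, which may differ from the `FakeTensorMode.push()` API this version of the file relies on:

    import torch
    from torch._subclasses import FakeTensorMode

    mode = FakeTensorMode()
    real = torch.randn(4, 4)
    fake = mode.from_tensor(real)  # metadata-only copy of the real tensor
    with mode:
        out = fake @ fake  # shape propagation only, no real compute
    print(out.shape)  # torch.Size([4, 4])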
@@ -233,7 +254,10 @@ def create_aot_autograd_function(
                 num_outs = 1
 
             joint_inputs = (fake_flat_tensor_args, out)
-            aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+            aot_decompositions = {
+                **aot_autograd_decompositions,
+                **decompositions,
+            }
             with torch.set_grad_enabled(grad_state):
                 fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                     *joint_inputs
@@ -244,6 +268,7 @@ def create_aot_autograd_function(
             # fake fn to make functionalize happy
             def fake_fn(primals, tangents):
                 return fx_g(primals, tangents)
+
             fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
 
         if config.debug_joint:
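Note: `functionalize` re-traces the joint graph here so that in-place mutations in the trace are replaced by out-of-place equivalents before partitioning. A small sketch with the same imports this file uses:

    import torch
    from functorch import make_fx
    from functorch.experimental import functionalize

    def f(x):
        y = x.clone()
        y.add_(1)  # in-place mutation
        return y

    # the printed graph contains only out-of-place ops
    fx_g = make_fx(functionalize(f))(torch.randn(3))
    print(fx_g.code)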
@@ -255,7 +280,6 @@ def create_aot_autograd_function(
         if config.debug_graphs:
             print(fw_module.code, bw_module.code)
 
-
         with track_graph_compiling("forward"):
             compiled_fw = fw_compiler(fw_module, flat_tensor_args)
 
@@ -621,7 +645,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
 
     def functional_call(named_params, named_buffers, *args, **kwargs):
         params_and_buffers = {**named_params, **named_buffers}
-        return _stateless.functional_call(mod, params_and_buffers, args, kwargs)
+        return stateless.functional_call(mod, params_and_buffers, args, kwargs)
 
     compiled_f = aot_function(functional_call, *args, **kwargs)
 
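Note: this hunk migrates from the underscored `torch.nn.utils._stateless` module to its public successor `torch.nn.utils.stateless`. `functional_call` runs a module's forward with a caller-supplied dict of parameters and buffers, which is what lets AOTAutograd trace them as ordinary inputs. A minimal sketch:

    import torch
    import torch.nn as nn
    import torch.nn.utils.stateless as stateless

    mod = nn.Linear(3, 1)
    params_and_buffers = {
        **dict(mod.named_parameters()),
        **dict(mod.named_buffers()),
    }
    x = torch.randn(2, 3)
    # forward runs against the supplied tensors, not mod's own attributes
    out = stateless.functional_call(mod, params_and_buffers, (x,))
    print(out.shape)  # torch.Size([2, 1])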
@@ -663,7 +687,7 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
     params_len = len(params_flat)
 
     def functional_call(*args, **kwargs):
-        with _stateless.reparametrize_module(
+        with stateless._reparametrize_module(
             mod, pytree.tree_unflatten(args[:params_len], params_spec)
         ):
             if isinstance(mod, torch.fx.GraphModule):
@@ -706,13 +730,16 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
     compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs)
 
     if top_kwargs:
+
         def forward(*args, **kwargs):
             return compiled_f(
                 *params_flat,
                 *args,
                 **kwargs,
             )
+
     else:
+
         def forward(*args):
             return compiled_f(
                 *params_flat,
--- a/functorch/functorch/_src/compilers.py
+++ b/functorch/functorch/_src/compilers.py
@@ -1,18 +1,23 @@
-import torch
-import torch.fx as fx
-import torch.nn as nn
-from functools import partial
-from typing import Callable, Optional, Tuple, Union
-
-from .aot_autograd import aot_function, aot_module, make_boxed_compiler
-from .decompositions import get_decompositions
-from .partitioners import draw_graph, min_cut_rematerialization_partition, default_partition
-from .compile_utils import strip_overloads
-import copy
-import logging
-import os
-import pickle
-import random
+import copy
+import logging
+import os
+import pickle
+import random
+from functools import partial
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.fx as fx
+import torch.nn as nn
+
+from .aot_autograd import aot_function, aot_module, make_boxed_compiler
+from .compile_utils import strip_overloads
+from .decompositions import get_decompositions
+from .partitioners import (
+    default_partition,
+    draw_graph,
+    min_cut_rematerialization_partition,
+)
 
 
 # These canonicalizations are needed here (and not decompositions), as the ops
@@ -43,8 +48,12 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
     strip_overloads(fx_g)
 
     for node in fx_g.graph.nodes:
-        if (node.target == torch.ops.aten._to_copy and len(node.args) == 1
-                and len(node.kwargs) == 1 and 'dtype' in node.kwargs):
+        if (
+            node.target == torch.ops.aten._to_copy
+            and len(node.args) == 1
+            and len(node.kwargs) == 1
+            and "dtype" in node.kwargs
+        ):
             node.target = torch.ops.aten.to
 
     for node in fx_g.graph.nodes:
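Note: the `_to_copy` fixup above is a standard FX graph rewrite: scan `fx_g.graph.nodes`, retarget the nodes that match, then lint and recompile. The same pattern on a toy graph (an illustrative relu-to-tanh rewrite, not the aten case above):

    import torch
    import torch.fx as fx

    def f(x):
        return torch.relu(x)

    g = fx.symbolic_trace(f)
    for node in g.graph.nodes:
        if node.op == "call_function" and node.target == torch.relu:
            node.target = torch.tanh
    g.graph.lint()
    g.recompile()
    print(g(torch.tensor([-1.0, 1.0])))  # tanh applied, not relu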
@@ -55,7 +64,6 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
             new_kwargs[k] = v
         node.kwargs = new_kwargs
 
-
     fx_g.graph.lint()
 
     fx_g.recompile()
@@ -140,7 +148,9 @@ def print_compile(fx_g, _):
 
 
 def memory_efficient_fusion(
-    fn: Union[Callable, nn.Module], static_argnums: Optional[Tuple[int]] = None, **kwargs
+    fn: Union[Callable, nn.Module],
+    static_argnums: Optional[Tuple[int]] = None,
+    **kwargs,
 ):
     """
     Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
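Note: `memory_efficient_fusion` is the user-facing entry point in this file, wiring `aot_function`/`aot_module` to the min-cut rematerialization partitioner and the TorchScript compiler defined above. A hedged usage sketch — meaningful fusion assumes a CUDA device with an nvFuser-capable build:

    import torch
    from functorch.compile import memory_efficient_fusion

    def pointwise_chain(x):
        return torch.sigmoid(torch.tanh(x) + 1)

    fused = memory_efficient_fusion(pointwise_chain)
    x = torch.randn(1024, device="cuda", requires_grad=True)
    out = fused(x)        # forward runs through the compiled graph
    out.sum().backward()  # backward was compiled the same way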
@@ -218,7 +228,7 @@ def get_inputs(input_data_path):
     Return a random input for the given inputs meta generated from _save_fx_default.
     """
     inputs = []
-    with (open(input_data_path, 'rb')) as f:
+    with (open(input_data_path, "rb")) as f:
         inputs_meta = pickle.load(f)
     inputs = []
     for meta in inputs_meta:
@@ -227,7 +237,16 @@ def get_inputs(input_data_path):
             input = type(random.rand())
         else:
             type, shape, stride, dtype, device = meta
-            if dtype in {torch.int, torch.int32, torch.int64, torch.bool, torch.int, torch.uint8, int, float}:
+            if dtype in {
+                torch.int,
+                torch.int32,
+                torch.int64,
+                torch.bool,
+                torch.int,
+                torch.uint8,
+                int,
+                float,
+            }:
                 input = torch.randint(0, 1, shape, dtype=dtype, device=device)
             else:
                 input = torch.rand(shape, dtype=dtype, device=device)
@@ -260,16 +279,21 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
             input_meta += get_input_meta(args[1])
             return input_meta
         for arg in args:
-            if(type(arg) == int or type(arg) == float):
+            if type(arg) == int or type(arg) == float:
                 input_meta.append((type(arg),))
             else:
-                input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device))
+                input_meta.append(
+                    (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
+                )
         return input_meta
 
     def graph_saver_helper(gm_to_save, args, type_name):
         global graph_index
         if len(gm_to_save.graph.nodes) == 0:
-            logging.log(logging.WARNING, f"No nodes in graph {current_name}_{type_name}_{graph_index}.")
+            logging.log(
+                logging.WARNING,
+                f"No nodes in graph {current_name}_{type_name}_{graph_index}.",
+            )
             return
 
         gm = copy.deepcopy(gm_to_save)
@@ -281,10 +305,21 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
         isExist = os.path.exists(f"{folder_name}/{current_name}")
         if not isExist:
             os.makedirs(f"{folder_name}/{current_name}")
-        gm.to_folder(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}")
-        pickle.dump(input_meta, open(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", "wb"))  # noqa: E501
+        gm.to_folder(
+            f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
+        )
+        pickle.dump(
+            input_meta,
+            open(
+                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input",  # noqa: B950
+                "wb",
+            ),
+        )  # noqa: E501
         if dump_example_input:
-            torch.save(args, f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt")  # noqa: E501
+            torch.save(
+                args,
+                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt",  # noqa: B950
+            )  # noqa: E501
 
     def graph_saver_forward(gm, fw_args):
         graph_saver_helper(gm, fw_args, "forward")
@@ -300,10 +335,13 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
         graph_saver_helper(gm, joint_args, "joint")
         return default_partition(gm, joint_args)
 
-    return aot_module_simplified(gm, fw_compiler=graph_saver_forward,
-                                 bw_compiler=graph_saver_backward,
-                                 partition_fn=graph_saver_joint,
-                                 decompositions=default_decompositions)
+    return aot_module_simplified(
+        gm,
+        fw_compiler=graph_saver_forward,
+        bw_compiler=graph_saver_backward,
+        partition_fn=graph_saver_joint,
+        decompositions=default_decompositions,
+    )
 
 
 def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
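Note: `_save_fx_default` is one instance of the general hook used throughout this file: the aot_* entry points accept arbitrary `fw_compiler`/`bw_compiler` callbacks that receive each captured `fx.GraphModule` plus example inputs and must return a callable. A minimal hedged sketch that prints the graphs instead of pickling them:

    import torch
    from functorch.compile import aot_function

    def inspect_graph(fx_g, example_inputs):
        print(fx_g.code)  # dump the captured forward or backward graph
        return fx_g       # a GraphModule is itself a valid callable

    def f(x):
        return (x.sin() + x.cos()).sum()

    g = aot_function(f, fw_compiler=inspect_graph, bw_compiler=inspect_graph)
    g(torch.randn(4, requires_grad=True)).backward()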