format some aotautograd-related files in functorch with black (#83240)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83240
Approved by: https://github.com/ezyang
Horace He 2022-08-13 11:34:10 +00:00 committed by PyTorch MergeBot
parent 408fa38f33
commit 016fcca243
3 changed files with 114 additions and 47 deletions
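The diff does two things: it adds the two functorch files to the black include_patterns in the lint configuration, and it applies the resulting formatting (black line wrapping plus import reordering) to those files. As a rough illustration of what that lint coverage amounts to, here is a minimal, hypothetical Python sketch; the glob expansion and the direct python3 -m black invocation are assumptions for illustration, not the repository's actual lint adapter.

# Hypothetical sketch, not part of the diff: expand the include_patterns added
# below and run black over each matching file, roughly what the lint job does.
# Assumes black is installed and the script is run from the repository root.
import subprocess
from pathlib import Path

include_patterns = [
    "functorch/functorch/_src/aot_autograd.py",
    "functorch/functorch/_src/compilers.py",
]

for pattern in include_patterns:
    for path in Path(".").glob(pattern):
        subprocess.run(["python3", "-m", "black", str(path)], check=True)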

.lintrunner.toml

@@ -720,6 +720,8 @@ include_patterns = [
'torch/_subclasses/**/*.py',
'torch/_*.py',
'torchgen/**/*.py',
'functorch/functorch/_src/aot_autograd.py',
'functorch/functorch/_src/compilers.py',
]
command = [
'python3',

functorch/functorch/_src/aot_autograd.py

@@ -1,30 +1,34 @@
import warnings
from contextlib import contextmanager, nullcontext
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from functorch import make_fx
from torch.fx import immutable_collections, Interpreter
import torch.fx.traceback as fx_traceback
from torch._subclasses import FakeTensorMode
import torch.nn as nn
import torch.nn.utils.stateless as stateless
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch.nn.utils import _stateless
from torch import Tensor
from torch._subclasses import FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from functorch import make_fx
from functorch._C import CompileCache
from functorch.experimental import functionalize
from . import config
from .decompositions import register_decomposition
from .named_members_polyfill import _named_buffers, _named_parameters
from .partitioners import default_partition
from .named_members_polyfill import _named_parameters, _named_buffers
from typing import Callable, List, Dict, Any, Tuple, Optional
from functools import wraps
import warnings
try:
from torchdynamo import disable as disable_torchdynamo
except ImportError:
def disable_torchdynamo(x):
return x
pytree._register_pytree_node(
immutable_collections.immutable_list,
lambda x: (list(x), None),
@@ -148,20 +152,25 @@ def track_graph_compiling(graph_name, increment_index=False):
nth_graph += 1
graph_being_compiled = None
def make_boxed_func(f):
def g(args):
return f(*args)
g._boxed_call = True
return g
def make_boxed_compiler(compiler):
@wraps(compiler)
def f(fx_g, inps):
out_f = compiler(fx_g, inps)
fx_g = make_boxed_func(out_f)
return fx_g
return f
def call_func_with_args(f, args, steal_args=False):
if not steal_args:
args = list(args)
@@ -178,6 +187,7 @@ def call_func_with_args(f, args, steal_args=False):
)
return normalize_as_list(f(*args))
def create_aot_autograd_function(
flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
):
@@ -212,19 +222,30 @@ def create_aot_autograd_function(
if compiled_fw is None:
flat_tensor_args = pytree.tree_map(
lambda x: x.detach().requires_grad_(x.requires_grad)
if isinstance(x, Tensor) else x, flat_tensor_args
if isinstance(x, Tensor)
else x,
flat_tensor_args,
)
fake_mode = (
FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
)
fake_mode = FakeTensorMode.push() if config.use_fake_tensor else nullcontext()
with preserve_rng_state(), fake_mode as mode:
# Set input tensors that require grad to leaves
fake_flat_tensor_args = pytree.tree_map(
lambda x: mode.from_tensor(x) if mode else x
if isinstance(x, Tensor) else x, flat_tensor_args
lambda x: mode.from_tensor(x)
if mode
else x
if isinstance(x, Tensor)
else x,
flat_tensor_args,
)
with torch.set_grad_enabled(grad_state):
out = flat_fn(*fake_flat_tensor_args)
out = pytree.tree_map(
lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
lambda x: x.detach().contiguous()
if isinstance(x, Tensor)
else x,
out,
)
if isinstance(out, (list, tuple)):
@@ -233,7 +254,10 @@ def create_aot_autograd_function(
num_outs = 1
joint_inputs = (fake_flat_tensor_args, out)
aot_decompositions = {**aot_autograd_decompositions, **decompositions}
aot_decompositions = {
**aot_autograd_decompositions,
**decompositions,
}
with torch.set_grad_enabled(grad_state):
fx_g = make_fx(joint_forward_backward, aot_decompositions)(
*joint_inputs
@@ -244,6 +268,7 @@ def create_aot_autograd_function(
# fake fn to make functionalize happy
def fake_fn(primals, tangents):
return fx_g(primals, tangents)
fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
if config.debug_joint:
@@ -255,7 +280,6 @@ def create_aot_autograd_function(
if config.debug_graphs:
print(fw_module.code, bw_module.code)
with track_graph_compiling("forward"):
compiled_fw = fw_compiler(fw_module, flat_tensor_args)
@@ -621,7 +645,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
def functional_call(named_params, named_buffers, *args, **kwargs):
params_and_buffers = {**named_params, **named_buffers}
return _stateless.functional_call(mod, params_and_buffers, args, kwargs)
return stateless.functional_call(mod, params_and_buffers, args, kwargs)
compiled_f = aot_function(functional_call, *args, **kwargs)
@@ -663,7 +687,7 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
params_len = len(params_flat)
def functional_call(*args, **kwargs):
with _stateless.reparametrize_module(
with stateless._reparametrize_module(
mod, pytree.tree_unflatten(args[:params_len], params_spec)
):
if isinstance(mod, torch.fx.GraphModule):
@@ -706,13 +730,16 @@ def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs)
if top_kwargs:
def forward(*args, **kwargs):
return compiled_f(
*params_flat,
*args,
**kwargs,
)
else:
def forward(*args):
return compiled_f(
*params_flat,

functorch/functorch/_src/compilers.py

@@ -1,18 +1,23 @@
import torch
import torch.fx as fx
import torch.nn as nn
from functools import partial
from typing import Callable, Optional, Tuple, Union
from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .decompositions import get_decompositions
from .partitioners import draw_graph, min_cut_rematerialization_partition, default_partition
from .compile_utils import strip_overloads
import copy
import logging
import os
import pickle
import random
import copy
import logging
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
import torch.fx as fx
import torch.nn as nn
from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .decompositions import get_decompositions
from .partitioners import (
default_partition,
draw_graph,
min_cut_rematerialization_partition,
)
# These canonicalizations are needed here (and not decompositions), as the ops
@@ -43,8 +48,12 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
strip_overloads(fx_g)
for node in fx_g.graph.nodes:
if (node.target == torch.ops.aten._to_copy and len(node.args) == 1
and len(node.kwargs) == 1 and 'dtype' in node.kwargs):
if (
node.target == torch.ops.aten._to_copy
and len(node.args) == 1
and len(node.kwargs) == 1
and "dtype" in node.kwargs
):
node.target = torch.ops.aten.to
for node in fx_g.graph.nodes:
@@ -55,7 +64,6 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
new_kwargs[k] = v
node.kwargs = new_kwargs
fx_g.graph.lint()
fx_g.recompile()
@@ -140,7 +148,9 @@ def print_compile(fx_g, _):
def memory_efficient_fusion(
fn: Union[Callable, nn.Module], static_argnums: Optional[Tuple[int]] = None, **kwargs
fn: Union[Callable, nn.Module],
static_argnums: Optional[Tuple[int]] = None,
**kwargs,
):
"""
Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
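For context, a minimal usage sketch of this API follows (not part of the diff; the example function, sizes, and the CUDA device are illustrative assumptions, and the default TorchScript/nvFuser backend generally expects a GPU build).

# Illustrative only: compile a small pointwise function with AOTAutograd's
# fusion-friendly defaults via memory_efficient_fusion. Names, shapes, and the
# device are made up for this sketch.
import torch
from functorch.compile import memory_efficient_fusion

def fn(x, y):
    return (torch.sin(x) + torch.cos(y)).sum()

fused_fn = memory_efficient_fusion(fn)
x = torch.randn(1024, device="cuda", requires_grad=True)
y = torch.randn(1024, device="cuda", requires_grad=True)
fused_fn(x, y).backward()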
@@ -218,7 +228,7 @@ def get_inputs(input_data_path):
Return a random input for the given inputs meta generated from _save_fx_default.
"""
inputs = []
with (open(input_data_path, 'rb')) as f:
with (open(input_data_path, "rb")) as f:
inputs_meta = pickle.load(f)
inputs = []
for meta in inputs_meta:
@@ -227,7 +237,16 @@ def get_inputs(input_data_path):
input = type(random.rand())
else:
type, shape, stride, dtype, device = meta
if dtype in {torch.int, torch.int32, torch.int64, torch.bool, torch.int, torch.uint8, int, float}:
if dtype in {
torch.int,
torch.int32,
torch.int64,
torch.bool,
torch.int,
torch.uint8,
int,
float,
}:
input = torch.randint(0, 1, shape, dtype=dtype, device=device)
else:
input = torch.rand(shape, dtype=dtype, device=device)
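As a small, hypothetical sketch of how this pairs with _save_fx_default further down (the folder, model name, and graph index in the path are made up; only the .input suffix and directory layout follow the f-strings later in this diff):

# Hypothetical round trip, using the get_inputs defined in this module:
# _save_fx_default pickles per-graph input metadata under
# "<folder>/<name>/<name>_<type>_<index>/<name>_<type>_<index>.input", and
# get_inputs rebuilds random tensors with matching shape/stride/dtype/device.
example_inputs = get_inputs("dumps/resnet/resnet_forward_0/resnet_forward_0.input")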
@@ -260,16 +279,21 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
input_meta += get_input_meta(args[1])
return input_meta
for arg in args:
if(type(arg) == int or type(arg) == float):
if type(arg) == int or type(arg) == float:
input_meta.append((type(arg),))
else:
input_meta.append((type(arg), arg.shape, arg.stride(), arg.dtype, arg.device))
input_meta.append(
(type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
)
return input_meta
def graph_saver_helper(gm_to_save, args, type_name):
global graph_index
if len(gm_to_save.graph.nodes) == 0:
logging.log(logging.WARNING, f"No nodes in graph {current_name}_{type_name}_{graph_index}.")
logging.log(
logging.WARNING,
f"No nodes in graph {current_name}_{type_name}_{graph_index}.",
)
return
gm = copy.deepcopy(gm_to_save)
@@ -281,10 +305,21 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
isExist = os.path.exists(f"{folder_name}/{current_name}")
if not isExist:
os.makedirs(f"{folder_name}/{current_name}")
gm.to_folder(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}")
pickle.dump(input_meta, open(f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", "wb")) # noqa: E501
gm.to_folder(
f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
)
pickle.dump(
input_meta,
open(
f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", # noqa: B950
"wb",
),
) # noqa: E501
if dump_example_input:
torch.save(args, f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt") # noqa: E501
torch.save(
args,
f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950
) # noqa: E501
def graph_saver_forward(gm, fw_args):
graph_saver_helper(gm, fw_args, "forward")
@@ -300,10 +335,13 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
graph_saver_helper(gm, joint_args, "joint")
return default_partition(gm, joint_args)
return aot_module_simplified(gm, fw_compiler=graph_saver_forward,
bw_compiler=graph_saver_backward,
partition_fn=graph_saver_joint,
decompositions=default_decompositions)
return aot_module_simplified(
gm,
fw_compiler=graph_saver_forward,
bw_compiler=graph_saver_backward,
partition_fn=graph_saver_joint,
decompositions=default_decompositions,
)
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):