Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[4/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort functorch (#127125)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes were generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127125
Approved by: https://github.com/Skylion007
ghstack dependencies: #127122, #127123, #127124
This commit is contained in:
parent 35ea5c6b22
commit a28bfb5ed5
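For context, a rough illustration of the effect (an illustrative sketch only; the imports below are not copied from any particular hunk in this diff): once `functorch` is listed under `[tool.usort.known].first_party`, usort treats `functorch` imports as first-party and splits them out of the block that holds the `torch` imports, so a typical test file regroups roughly like this:

    # Before: functorch sorted into the same block as the torch imports
    import torch
    from functorch import make_fx
    from torch.testing._internal.common_utils import run_tests

    # After: functorch treated as first-party and moved into its own block
    import torch
    from torch.testing._internal.common_utils import run_tests

    from functorch import make_fx

The per-file hunks below (file names are not shown in this mirror view) are exactly this kind of mechanical regrouping, produced by the `lintrunner` command quoted above.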
@@ -1,10 +1,11 @@
 import torch
 import torch.fx as fx
-from functorch import make_fx

 from torch._functorch.compile_utils import fx_graph_cse
 from torch.profiler import profile, ProfilerActivity
+
+from functorch import make_fx


 def profile_it(f, inp):
     for _ in range(5):

@@ -4,6 +4,7 @@ from functools import partial
 import numpy as np
 import pandas as pd
 import torch
+
 from functorch.compile import pointwise_operator

 WRITE_CSV = False

@@ -3,11 +3,11 @@ import time
 import torch
 import torch.nn as nn
 import torchvision.models as models
-
-from functorch import grad, make_functional, vmap
 from opacus import PrivacyEngine
 from opacus.utils.module_modification import convert_batchnorm_modules
+
+from functorch import grad, make_functional, vmap

 device = "cuda"
 batch_size = 128
 torch.manual_seed(0)

@@ -2,6 +2,7 @@ import time

 import torch
 import torch.utils
+
 from functorch.compile import aot_function, tvm_compile

 a = torch.randn(2000, 1, 4, requires_grad=True)

@@ -2,6 +2,7 @@ import timeit

 import torch
 import torch.nn as nn
+
 from functorch.compile import compiled_module, tvm_compile


@@ -8,6 +8,7 @@ import time

 import torch
 import torch.nn as nn
+
 from functorch import make_functional
 from functorch.compile import nnc_jit

@@ -7,6 +7,7 @@
 import time

 import torch
+
 from functorch import grad, make_fx
 from functorch.compile import nnc_jit

@@ -38,10 +38,11 @@ import pandas as pd
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-from functorch import make_functional_with_buffers
 from support.omniglot_loaders import OmniglotNShot
 from torch import nn
+
+from functorch import make_functional_with_buffers

 mpl.use("Agg")
 plt.style.use("bmh")


@@ -8,10 +8,11 @@ import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from functorch import grad, make_functional, vmap
 from torch import nn
 from torch.nn import functional as F
+
+from functorch import grad, make_functional, vmap

 mpl.use("Agg")



@@ -34,11 +34,11 @@ include_trailing_comma = true


 [tool.usort.known]
-first_party = ["caffe2", "torchgen", "test"]
+first_party = ["caffe2", "torchgen", "functorch", "test"]
 standard_library = ["typing_extensions"]

 [tool.usort.kown]
-first_party = ["torch", "functorch"]
+first_party = ["torch"]


 [tool.ruff]

@@ -70,9 +70,10 @@ bw_graph_cell = [None]
 fw_compiler = functools.partial(extract_graph, graph_cell=fw_graph_cell)
 bw_compiler = functools.partial(extract_graph, graph_cell=bw_graph_cell)

-from functorch.compile import min_cut_rematerialization_partition
 from torch._dynamo.backends.common import aot_autograd
+
+from functorch.compile import min_cut_rematerialization_partition

 aot_eager_graph = aot_autograd(
     fw_compiler=fw_compiler,
     bw_compiler=bw_compiler,

@@ -10,13 +10,13 @@ import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
 import torch.distributed._tensor as dt
 import torch.distributed.distributed_c10d as c10d
-
-from functorch import make_fx
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.distributed.fake_pg import FakeStore
 from torch.utils._triton import has_triton
+
+from functorch import make_fx

 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)

@@ -11,7 +11,6 @@ import torch._dynamo.test_case
 import torch._functorch.config
 import torch.distributed as dist
 import torch.utils.checkpoint
-from functorch.compile import min_cut_rematerialization_partition
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.testing import CompileCounterWithBackend
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
@@ -20,6 +19,8 @@ from torch.testing._internal.inductor_utils import HAS_CUDA
 from torch.testing._internal.two_tensor import TwoTensor
 from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
+
+from functorch.compile import min_cut_rematerialization_partition

 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
 requires_distributed = functools.partial(
     unittest.skipIf, not dist.is_available(), "requires distributed"

@@ -722,9 +722,10 @@ class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
            return (out,)

        def compile_submod(input_mod, args):
-           from functorch.compile import nop
            from torch._functorch.aot_autograd import aot_module_simplified
+
+           from functorch.compile import nop

            class WrapperModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()

@@ -240,9 +240,10 @@ class TestCustomBackendAPI(torch._dynamo.test_case.TestCase):
         self.assertTrue(backend_run)

     def test_aot_autograd_api(self):
-        from functorch.compile import make_boxed_func
         from torch._dynamo.backends.common import aot_autograd
+
+        from functorch.compile import make_boxed_func

         backend_run = False

         def my_compiler(gm, example_inputs):

@@ -3,12 +3,13 @@
 import unittest

 import torch
-from functorch import make_fx
 from torch._dynamo import debug_utils
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.test_case import TestCase
 from torch.testing._internal.inductor_utils import HAS_CUDA
+
+from functorch import make_fx

 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

 f32 = torch.float32

@@ -17,7 +17,6 @@ import torch
 import torch._dynamo
 import torch._dynamo.test_case
 import torch._dynamo.testing
-from functorch.experimental.control_flow import cond
 from torch._dynamo import config
 from torch._dynamo.exc import UserError
 from torch._dynamo.testing import normalize_gm
@@ -34,6 +33,8 @@ from torch.fx.experimental.symbolic_shapes import (
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import TEST_CUDA
+
+from functorch.experimental.control_flow import cond


 class ExportTests(torch._dynamo.test_case.TestCase):
     # TODO(voz): Refactor to a shared test function.

@@ -7,8 +7,6 @@ import sys
 import unittest
 import warnings

-import functorch.experimental.control_flow as control_flow
-
 import torch
 import torch._dynamo.config as config

@@ -34,6 +32,8 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.inductor_utils import HAS_CUDA
 from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
+
+import functorch.experimental.control_flow as control_flow


 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

@@ -8,11 +8,12 @@ import torch
 import torch._dynamo
 import torch._dynamo.test_case
 import torch._dynamo.testing
-from functorch.compile import nop
 from torch._dynamo import compiled_autograd
 from torch._functorch.aot_autograd import aot_module_simplified
 from torch.utils.hooks import RemovableHandle
+
+from functorch.compile import nop


 def compiler_fn(gm):
     return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)

@@ -166,9 +166,10 @@ bw_graph = [None]


 def aot_graph_capture_backend(gm, args):
-    from functorch.compile import min_cut_rematerialization_partition
     from torch._functorch.aot_autograd import aot_module_simplified
+
+    from functorch.compile import min_cut_rematerialization_partition

     def fw_compiler(gm, _):
         fw_graph[0] = gm
         return gm

@@ -15,7 +15,6 @@ from typing import Dict, List
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn.functional as F
-from functorch.experimental.control_flow import cond, map
 from torch import Tensor
 from torch._dynamo.test_case import TestCase
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
@@ -59,6 +58,8 @@ from torch.utils._pytree import (
     treespec_loads,
 )

+from functorch.experimental.control_flow import cond, map
+
 try:
     from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -3,13 +3,14 @@ import copy
 import unittest

 import torch
-from functorch.experimental import control_flow
 from torch._dynamo.eval_frame import is_dynamo_supported
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
 from torch.export import export
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
+
+from functorch.experimental import control_flow


 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
 class TestPassInfra(TestCase):

@@ -11,7 +11,6 @@ from re import escape
 from typing import List, Set

 import torch
-from functorch.experimental.control_flow import cond
 from torch._dynamo.eval_frame import is_dynamo_supported
 from torch._export.non_strict_utils import (
     _fakify_script_objects,
@@ -54,6 +53,8 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.torchbind_impls import init_torchbind_implementations
 from torch.utils import _pytree as pytree

+from functorch.experimental.control_flow import cond
+

 def count_call_function(graph: torch.fx.Graph, target: torch.ops.OpOverload) -> int:
     count = 0

@@ -10,7 +10,6 @@ from typing import Any, List

 import torch
 import torch._dynamo as torchdynamo
-from functorch.experimental.control_flow import cond, map
 from torch import Tensor
 from torch._export.utils import (
     get_buffer,
@@ -52,6 +51,8 @@ from torch.utils._pytree import (
     treespec_loads,
 )

+from functorch.experimental.control_flow import cond, map
+

 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestUnflatten(TestCase):

@@ -2,7 +2,6 @@
 import unittest

 import torch
-from functorch.experimental import control_flow
 from torch import Tensor
 from torch._dynamo.eval_frame import is_dynamo_supported

@@ -11,6 +10,8 @@ from torch.export import export
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
 from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase

+from functorch.experimental import control_flow
+

 @unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
 class TestVerifier(TestCase):

@@ -6,9 +6,10 @@
 import math

 import torch
-from functorch.dim import cat, dimlists, dims, softmax
 from torch import nn

+from functorch.dim import cat, dimlists, dims, softmax
+

 class Linear(nn.Linear):
     def forward(self, input):

@@ -11,13 +11,14 @@ from collections import namedtuple

 import torch
 import torch.utils._pytree as pytree
-from functorch import vmap
 from functorch_additional_op_db import additional_op_db
 from torch.testing._internal.autograd_function_db import autograd_function_db
 from torch.testing._internal.common_device_type import toleranceOverride
 from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
 from torch.testing._internal.common_modules import module_db
+
+from functorch import vmap

 IS_FBCODE = os.getenv("FUNCTORCH_TEST_FBCODE") == "1"


@@ -20,22 +20,6 @@ import torch._dynamo as torchdynamo
 import torch.nn as nn
 import torch.utils._pytree as pytree
 from common_utils import decorate, decorateForModules, skip, skipOps, xfail
-from functorch import grad, jacrev, make_fx, vjp, vmap
-from functorch.compile import (
-    aot_function,
-    aot_module,
-    compiled_function,
-    compiled_module,
-    default_decompositions,
-    default_partition,
-    get_aot_compilation_context,
-    make_boxed_compiler,
-    memory_efficient_fusion,
-    min_cut_rematerialization_partition,
-    nnc_jit,
-    nop,
-)
-from functorch.experimental import control_flow
 from torch._decomp import decomposition_table
 from torch._functorch.aot_autograd import (
     aot_export_joint_simple,
@@ -77,6 +61,23 @@ from torch.testing._internal.optests import (
 )
 from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode

+from functorch import grad, jacrev, make_fx, vjp, vmap
+from functorch.compile import (
+    aot_function,
+    aot_module,
+    compiled_function,
+    compiled_module,
+    default_decompositions,
+    default_partition,
+    get_aot_compilation_context,
+    make_boxed_compiler,
+    memory_efficient_fusion,
+    min_cut_rematerialization_partition,
+    nnc_jit,
+    nop,
+)
+from functorch.experimental import control_flow
+
 USE_TORCHVISION = False
 try:
     import torchvision

@@ -5,8 +5,6 @@ import unittest

 import torch
 import torch.utils._pytree as pytree
-from functorch.experimental import control_flow
-from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
 from torch._higher_order_ops.while_loop import while_loop
 from torch._subclasses.functional_tensor import (
     CppFunctionalizeAPI,
@@ -27,6 +25,9 @@ from torch.testing._internal.common_utils import (
     TestCase,
 )

+from functorch.experimental import control_flow
+from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
+

 # TODO: pull these helpers from AOTAutograd later
 def to_fun(t):

@@ -14,6 +14,13 @@ import torch
 from attn_ft import BertSelfAttention as BertSelfAttentionA, Linear
 from attn_positional import BertSelfAttention as BertSelfAttentionB

+from torch.testing._internal.common_utils import (
+    run_tests,
+    skipIfTorchDynamo,
+    TEST_CUDA,
+    TestCase,
+)
+
 from functorch._C import dim as _C
 from functorch.dim import (
     Dim,
@@ -25,13 +32,6 @@ from functorch.dim import (
     Tensor,
 )

-from torch.testing._internal.common_utils import (
-    run_tests,
-    skipIfTorchDynamo,
-    TEST_CUDA,
-    TestCase,
-)
-
 try:
     from torchvision.models import resnet18
 except ImportError:

@@ -15,8 +15,6 @@ import unittest
 import warnings
 from functools import partial, wraps

-import functorch
-
 # NB: numpy is a testing dependency!
 import numpy as np
 import torch
@@ -24,21 +22,6 @@ import torch.autograd.forward_ad as fwAD
 import torch.nn as nn
 import torch.nn.functional as F
 from common_utils import expectedFailureIf
-from functorch import (
-    combine_state_for_ensemble,
-    grad,
-    grad_and_value,
-    hessian,
-    jacfwd,
-    jacrev,
-    jvp,
-    make_functional,
-    make_functional_with_buffers,
-    make_fx,
-    vjp,
-    vmap,
-)
-from functorch.experimental import functionalize, replace_all_batch_norm_modules_
 from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
 from torch._dynamo import allow_in_graph
 from torch._functorch.eager_transforms import _slice_argnums
@@ -81,6 +64,23 @@ from torch.testing._internal.common_utils import (

 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

+import functorch
+from functorch import (
+    combine_state_for_ensemble,
+    grad,
+    grad_and_value,
+    hessian,
+    jacfwd,
+    jacrev,
+    jvp,
+    make_functional,
+    make_functional_with_buffers,
+    make_fx,
+    vjp,
+    vmap,
+)
+from functorch.experimental import functionalize, replace_all_batch_norm_modules_
+
 USE_TORCHVISION = False
 try:
     import torchvision  # noqa: F401

@@ -8,12 +8,13 @@ from typing import Callable
 import torch
 import torch.fx as fx
 import torch.nn as nn
-from functorch import make_fx
-from functorch.compile import memory_efficient_fusion
 from torch._functorch.compile_utils import fx_graph_cse
 from torch.nn import functional as F
 from torch.testing._internal.common_utils import run_tests, TestCase
+
+from functorch import make_fx
+from functorch.compile import memory_efficient_fusion

 HAS_CUDA = torch.cuda.is_available()


@@ -1,11 +1,12 @@
 # Owner(s): ["module: functorch"]

 import torch
-from functorch import make_fx
-from functorch.compile import minifier
 from torch._functorch.compile_utils import get_outputs, get_placeholders
 from torch.testing._internal.common_utils import run_tests, TestCase
+
+from functorch import make_fx
+from functorch.compile import minifier


 class TestMinifier(TestCase):
     def test_has_mul_minifier(self):

@@ -29,7 +29,6 @@ from common_utils import (
     tol2,
     xfail,
 )
-from functorch import grad, jacfwd, jacrev, vjp, vmap
 from functorch_additional_op_db import additional_op_db
 from torch import Tensor
 from torch._functorch.eager_transforms import _as_tuple, jvp
@@ -62,6 +61,8 @@ from torch.testing._internal.opinfo.core import SampleInput
 from torch.utils import _pytree as pytree
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

+from functorch import grad, jacfwd, jacrev, vjp, vmap
+
 aten = torch.ops.aten


@@ -27,6 +27,8 @@ SOFTWARE.
 from typing import Any, Callable, Dict
 from unittest import mock

+from torch.testing._internal.common_utils import run_tests, TestCase
+
 from functorch.einops._parsing import (
     _ellipsis,
     AnonymousAxis,
@@ -34,7 +36,6 @@ from functorch.einops._parsing import (
     ParsedExpression,
     validate_rearrange_expressions,
 )
-from torch.testing._internal.common_utils import run_tests, TestCase

 mock_anonymous_axis_eq: Callable[[AnonymousAxis, object], bool] = (
     lambda self, other: isinstance(other, AnonymousAxis) and self.value == other.value

@@ -29,9 +29,10 @@ from typing import List, Tuple

 import numpy as np
 import torch
-from functorch.einops import rearrange
 from torch.testing._internal.common_utils import run_tests, TestCase
+
+from functorch.einops import rearrange

 identity_patterns: List[str] = [
     "...->...",
     "a b c d e-> a b c d e",

@@ -18,7 +18,6 @@ from collections import namedtuple
 from typing import OrderedDict
 from unittest.case import skipIf

-import functorch
 import torch
 import torch.nn.functional as F
 from common_utils import (
@@ -36,8 +35,6 @@ from common_utils import (
     tol1,
     xfail,
 )
-from functorch import grad, grad_and_value, jacfwd, jvp, vjp, vmap
-from functorch.experimental import chunk_vmap
 from functorch_additional_op_db import additional_op_db
 from torch import Tensor
 from torch._C._functorch import reshape_dim_into, reshape_dim_outof
@@ -68,6 +65,10 @@ from torch.testing._internal.common_utils import (
 )
 from torch.utils import _pytree as pytree

+import functorch
+from functorch import grad, grad_and_value, jacfwd, jvp, vjp, vmap
+from functorch.experimental import chunk_vmap
+
 FALLBACK_REGEX = "There is a performance drop"


@@ -106,9 +106,10 @@ class EfficientConvBNEvalTemplate(TestCase):
         def test_conv_bn_eval(
             test_class, use_bias, module, sync_bn, decompose_nn_module
         ):
-            from functorch import make_fx
             from torch._dispatch.python import enable_python_dispatcher
+
+            from functorch import make_fx

             kwargs = {"kernel_size": 3, "stride": 2} if module[0] != nn.Linear else {}
             mod_eager = test_class(
                 module[0],

@@ -2,8 +2,6 @@
 import contextlib
 from unittest.mock import patch

-import functorch
-
 import torch
 import torch._inductor.config as config
 import torch.autograd
@@ -30,6 +28,8 @@ from torch.testing._internal.common_utils import skipIfRocm
 # Defines all the kernels for tests
 from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda

+import functorch
+
 if HAS_CUDA:
     from torch.testing._internal.triton_utils import add_kernel

@@ -107,7 +107,6 @@ class KernelTests(torch._inductor.test_case.TestCase):
     @requires_cuda
     @skipIfRocm
     def test_triton_kernel_functionalize(self):
-        from functorch import make_fx
         from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
         from torch._subclasses.functional_tensor import (
             CppFunctionalizeAPI,
@@ -115,6 +114,8 @@ class KernelTests(torch._inductor.test_case.TestCase):
             PythonFunctionalizeAPI,
         )

+        from functorch import make_fx
+
         kernel_side_table.reset_table()

         def f(x, output):

@@ -13,13 +13,14 @@ import torch._custom_ops as custom_ops

 import torch.testing._internal.optests as optests
 import torch.utils.cpp_extension
-from functorch import make_fx
 from torch import Tensor
 from torch._custom_op.impl import custom_op, CustomOp, infer_schema
 from torch._utils_internal import get_file_path_2
 from torch.testing._internal import custom_op_db
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.custom_op_db import numpy_nonzero
+
+from functorch import make_fx

 from typing import *  # noqa: F403
 import numpy as np

@@ -5,9 +5,9 @@ import functools
 from importlib import import_module
 from typing import Any, List, Optional

-from functorch.compile import min_cut_rematerialization_partition
-
 import torch
+
+from functorch.compile import min_cut_rematerialization_partition
 from torch import _guards
 from torch._functorch import config as functorch_config
 from torch._functorch.compilers import ts_compile

@@ -11,10 +11,10 @@ from itertools import count
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 from unittest import mock

-from functorch.compile import min_cut_rematerialization_partition
-
 import torch.fx
 import torch.utils._pytree as pytree
+
+from functorch.compile import min_cut_rematerialization_partition
 from torch._dynamo import (
     compiled_autograd,
     config as dynamo_config,

@@ -14,9 +14,9 @@ import subprocess
 from typing import Any, Dict, List, Optional
 from unittest.mock import patch

-from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
-
 import torch
+
+from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
 from torch import fx as fx

 from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug

@@ -5,8 +5,6 @@ from dataclasses import dataclass
 from functools import partial, wraps
 from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union

-from functorch import make_fx
-
 import torch
 import torch.distributed as dist

@@ -15,6 +13,8 @@ import torch.distributed._functional_collectives
 import torch.nn as nn
 import torch.utils._pytree as pytree

+from functorch import make_fx
+
 from torch import fx
 from torch._decomp.decompositions import native_layer_norm_backward

@@ -6,10 +6,10 @@ import logging
 import operator
 from typing import Callable, List, Optional, Set, Tuple

-from functorch import make_fx
-
 import torch
+
+from functorch import make_fx

 from torch._inductor.compile_fx import compile_fx_inner
 from torch._inductor.decomposition import select_decomp_table

@@ -1,11 +1,11 @@
 from operator import itemgetter
 from typing import List

-from functorch.compile import make_boxed_func
-
 import torch
 import torch.fx
 import torch.nn as nn
+
+from functorch.compile import make_boxed_func
 from torch._functorch.compilers import aot_module
 from torch._inductor.decomposition import select_decomp_table
 from torch.distributed._tensor import DTensor