[export] Remove torch._export.export (#119095)

XLA changes: https://github.com/pytorch/xla/pull/6486

Test Plan: CI

Differential Revision: D53316196

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119095
Approved by: https://github.com/ydwu4, https://github.com/zhxchen17, https://github.com/tugsbayasgalan, https://github.com/avikchaudhuri, https://github.com/jerryzh168
This commit is contained in:
Angela Yi 2024-02-08 21:22:04 +00:00 committed by PyTorch MergeBot
parent a7754b2b60
commit 0827510fd3
11 changed files with 23 additions and 53 deletions

View File

@ -1 +1 @@
6a0cb712f6335d6b5996e686ddec4a541e4b6ba5
fba464b199559f61faa720de8bf64cf955cfdce7

View File

@ -110,7 +110,7 @@ Example::
def forward(self, x, y):
return x + y
mod = torch._export.export(MyModule())
mod = torch.export.export(MyModule())
print(mod.graph)
The above is the textual representation of a Graph, with each line being a node.

View File

@ -72,7 +72,7 @@ class TensorParallelTest(DTensorTestBase):
inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
with torch.no_grad():
res = model(*inputs)
exported_program = torch._export.export(
exported_program = torch.export.export(
model,
inputs,
constraints=None,
@ -111,7 +111,7 @@ class TensorParallelTest(DTensorTestBase):
with torch.inference_mode():
res = model(*inputs)
exported_program = torch._export.export(
exported_program = torch.export.export(
model,
inputs,
constraints=None,
@ -148,7 +148,7 @@ class TensorParallelTest(DTensorTestBase):
with torch.inference_mode():
res = model(*inputs)
exported_program = torch._export.export(
exported_program = torch.export.export(
model,
inputs,
constraints=None,

View File

@ -21,9 +21,9 @@ from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch._dynamo.exc import UserError
from torch._dynamo.testing import normalize_gm
from torch._export import dynamic_dim
from torch._higher_order_ops.out_dtype import out_dtype
from torch._subclasses import fake_tensor
from torch.export import dynamic_dim
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,

View File

@ -124,7 +124,7 @@ class TestPassInfra(TestCase):
input_tensor1 = torch.tensor(5.0)
input_tensor2 = torch.tensor(6.0)
ep_before = torch._export.export(my_module, (input_tensor1, input_tensor2))
ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))
from torch.fx.passes.infra.pass_base import PassResult
def modify_input_output_pass(gm):

View File

@ -4,7 +4,7 @@ import unittest
from typing import Any, Dict
import torch
import torch._export as export
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.observer import (
HistogramObserver,
@ -102,7 +102,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
# program capture
m = copy.deepcopy(m_eager)
m = export.capture_pre_autograd_graph(
m = capture_pre_autograd_graph(
m,
example_inputs,
)

View File

@ -5,7 +5,7 @@ import unittest
from typing import List
import torch
import torch._export as export
import torch._export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
@ -64,7 +64,7 @@ class TestMetaDataPorting(QuantizationTestCase):
def _test_quant_tag_preservation_through_decomp(
self, model, example_inputs, from_node_to_tags
):
ep = export.export(model, example_inputs)
ep = torch.export.export(model, example_inputs)
found_tags = True
not_found_nodes = ""
for from_node, tag in from_node_to_tags.items():
@ -102,7 +102,7 @@ class TestMetaDataPorting(QuantizationTestCase):
# program capture
m = copy.deepcopy(m_eager)
m = export.capture_pre_autograd_graph(
m = torch._export.capture_pre_autograd_graph(
m,
example_inputs,
)

View File

@ -5,7 +5,6 @@ import torch
import torch._dynamo
import torch._inductor
import torch._inductor.decomposition
import torch._export
from torch._higher_order_ops.out_dtype import out_dtype
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
@ -62,7 +61,7 @@ class TestOutDtypeOp(TestCase):
weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
m = M(weight)
x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
ep = torch._export.export(
ep = torch.export.export(
m,
(x,),
)
@ -121,14 +120,15 @@ class TestOutDtypeOp(TestCase):
self.assertTrue(torch.allclose(numerical_res, gm(*inp)))
def test_out_dtype_non_functional(self):
def f(x, y):
class M(torch.nn.Module):
def forward(self, x, y):
return out_dtype(
torch.ops.aten.add_.Tensor, torch.int32, x, y
)
with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
_ = torch._export.export(
f, (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
_ = torch.export.export(
M(), (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
)
def test_out_dtype_non_op_overload(self):

View File

@ -220,36 +220,6 @@ def capture_pre_autograd_graph(
return module
def export(
f: Callable,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
constraints: Optional[List[Constraint]] = None,
*,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
from torch.export._trace import _export
warnings.warn("This function is deprecated. Please use torch.export.export instead.")
if constraints is not None:
warnings.warn(
"Using `constraints` to specify dynamic shapes for export is DEPRECATED "
"and will not be supported in the future. "
"Please use `dynamic_shapes` instead (see docs on `torch.export.export`).",
DeprecationWarning,
stacklevel=2,
)
return _export(
f,
args,
kwargs,
constraints,
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
)
def save(
ep: ExportedProgram,
f: Union[str, os.PathLike, io.BytesIO],

View File

@ -3,7 +3,7 @@ from collections import defaultdict
from typing import Tuple, Dict, Optional, List
import torch
from torch._export import export
from torch.export import export
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue

View File

@ -3,9 +3,9 @@ import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
import torch
from torch._export import ExportedProgram
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram
from torch.utils._pytree import (
_register_pytree_node,
Context,