mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Remove torch._export.export (#119095)
XLA changes: https://github.com/pytorch/xla/pull/6486 Test Plan: CI Differential Revision: D53316196 Pull Request resolved: https://github.com/pytorch/pytorch/pull/119095 Approved by: https://github.com/ydwu4, https://github.com/zhxchen17, https://github.com/tugsbayasgalan, https://github.com/avikchaudhuri, https://github.com/jerryzh168
This commit is contained in:
parent
a7754b2b60
commit
0827510fd3
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
|||
6a0cb712f6335d6b5996e686ddec4a541e4b6ba5
|
||||
fba464b199559f61faa720de8bf64cf955cfdce7
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ Example::
|
|||
def forward(self, x, y):
|
||||
return x + y
|
||||
|
||||
mod = torch._export.export(MyModule())
|
||||
mod = torch.export.export(MyModule())
|
||||
print(mod.graph)
|
||||
|
||||
The above is the textual representation of a Graph, with each line being a node.
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class TensorParallelTest(DTensorTestBase):
|
|||
inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
|
||||
with torch.no_grad():
|
||||
res = model(*inputs)
|
||||
exported_program = torch._export.export(
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
inputs,
|
||||
constraints=None,
|
||||
|
|
@ -111,7 +111,7 @@ class TensorParallelTest(DTensorTestBase):
|
|||
|
||||
with torch.inference_mode():
|
||||
res = model(*inputs)
|
||||
exported_program = torch._export.export(
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
inputs,
|
||||
constraints=None,
|
||||
|
|
@ -148,7 +148,7 @@ class TensorParallelTest(DTensorTestBase):
|
|||
|
||||
with torch.inference_mode():
|
||||
res = model(*inputs)
|
||||
exported_program = torch._export.export(
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
inputs,
|
||||
constraints=None,
|
||||
|
|
|
|||
|
|
@ -21,9 +21,9 @@ from functorch.experimental.control_flow import cond
|
|||
from torch._dynamo import config
|
||||
from torch._dynamo.exc import UserError
|
||||
from torch._dynamo.testing import normalize_gm
|
||||
from torch._export import dynamic_dim
|
||||
from torch._higher_order_ops.out_dtype import out_dtype
|
||||
from torch._subclasses import fake_tensor
|
||||
from torch.export import dynamic_dim
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ConstraintViolationError,
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ class TestPassInfra(TestCase):
|
|||
input_tensor1 = torch.tensor(5.0)
|
||||
input_tensor2 = torch.tensor(6.0)
|
||||
|
||||
ep_before = torch._export.export(my_module, (input_tensor1, input_tensor2))
|
||||
ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))
|
||||
from torch.fx.passes.infra.pass_base import PassResult
|
||||
|
||||
def modify_input_output_pass(gm):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import unittest
|
|||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import torch._export as export
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
|
||||
from torch.ao.quantization.observer import (
|
||||
HistogramObserver,
|
||||
|
|
@ -102,7 +102,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
|
|||
|
||||
# program capture
|
||||
m = copy.deepcopy(m_eager)
|
||||
m = export.capture_pre_autograd_graph(
|
||||
m = capture_pre_autograd_graph(
|
||||
m,
|
||||
example_inputs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import unittest
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch._export as export
|
||||
import torch._export
|
||||
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
||||
from torch.ao.quantization.quantizer import Quantizer
|
||||
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
||||
|
|
@ -64,7 +64,7 @@ class TestMetaDataPorting(QuantizationTestCase):
|
|||
def _test_quant_tag_preservation_through_decomp(
|
||||
self, model, example_inputs, from_node_to_tags
|
||||
):
|
||||
ep = export.export(model, example_inputs)
|
||||
ep = torch.export.export(model, example_inputs)
|
||||
found_tags = True
|
||||
not_found_nodes = ""
|
||||
for from_node, tag in from_node_to_tags.items():
|
||||
|
|
@ -102,7 +102,7 @@ class TestMetaDataPorting(QuantizationTestCase):
|
|||
|
||||
# program capture
|
||||
m = copy.deepcopy(m_eager)
|
||||
m = export.capture_pre_autograd_graph(
|
||||
m = torch._export.capture_pre_autograd_graph(
|
||||
m,
|
||||
example_inputs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import torch
|
|||
import torch._dynamo
|
||||
import torch._inductor
|
||||
import torch._inductor.decomposition
|
||||
import torch._export
|
||||
from torch._higher_order_ops.out_dtype import out_dtype
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import (
|
||||
|
|
@ -62,7 +61,7 @@ class TestOutDtypeOp(TestCase):
|
|||
weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
|
||||
m = M(weight)
|
||||
x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
|
||||
ep = torch._export.export(
|
||||
ep = torch.export.export(
|
||||
m,
|
||||
(x,),
|
||||
)
|
||||
|
|
@ -121,14 +120,15 @@ class TestOutDtypeOp(TestCase):
|
|||
self.assertTrue(torch.allclose(numerical_res, gm(*inp)))
|
||||
|
||||
def test_out_dtype_non_functional(self):
|
||||
def f(x, y):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return out_dtype(
|
||||
torch.ops.aten.add_.Tensor, torch.int32, x, y
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
|
||||
_ = torch._export.export(
|
||||
f, (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
|
||||
_ = torch.export.export(
|
||||
M(), (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
|
||||
)
|
||||
|
||||
def test_out_dtype_non_op_overload(self):
|
||||
|
|
|
|||
|
|
@ -220,36 +220,6 @@ def capture_pre_autograd_graph(
|
|||
return module
|
||||
|
||||
|
||||
def export(
|
||||
f: Callable,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
constraints: Optional[List[Constraint]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
preserve_module_call_signature: Tuple[str, ...] = (),
|
||||
) -> ExportedProgram:
|
||||
from torch.export._trace import _export
|
||||
warnings.warn("This function is deprecated. Please use torch.export.export instead.")
|
||||
|
||||
if constraints is not None:
|
||||
warnings.warn(
|
||||
"Using `constraints` to specify dynamic shapes for export is DEPRECATED "
|
||||
"and will not be supported in the future. "
|
||||
"Please use `dynamic_shapes` instead (see docs on `torch.export.export`).",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return _export(
|
||||
f,
|
||||
args,
|
||||
kwargs,
|
||||
constraints,
|
||||
strict=strict,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
)
|
||||
|
||||
|
||||
def save(
|
||||
ep: ExportedProgram,
|
||||
f: Union[str, os.PathLike, io.BytesIO],
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
|||
from typing import Tuple, Dict, Optional, List
|
||||
|
||||
import torch
|
||||
from torch._export import export
|
||||
from torch.export import export
|
||||
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
||||
from torch._export.pass_infra.node_metadata import NodeMetadata
|
||||
from torch._export.pass_infra.proxy_value import ProxyValue
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ import math
|
|||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from torch._export import ExportedProgram
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
|
||||
from torch.export import ExportedProgram
|
||||
from torch.utils._pytree import (
|
||||
_register_pytree_node,
|
||||
Context,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user