mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Remove torch._export.export (#119095)
XLA changes: https://github.com/pytorch/xla/pull/6486 Test Plan: CI Differential Revision: D53316196 Pull Request resolved: https://github.com/pytorch/pytorch/pull/119095 Approved by: https://github.com/ydwu4, https://github.com/zhxchen17, https://github.com/tugsbayasgalan, https://github.com/avikchaudhuri, https://github.com/jerryzh168
This commit is contained in:
parent
a7754b2b60
commit
0827510fd3
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
||||||
6a0cb712f6335d6b5996e686ddec4a541e4b6ba5
|
fba464b199559f61faa720de8bf64cf955cfdce7
|
||||||
|
|
|
||||||
|
|
@ -110,7 +110,7 @@ Example::
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
mod = torch._export.export(MyModule())
|
mod = torch.export.export(MyModule())
|
||||||
print(mod.graph)
|
print(mod.graph)
|
||||||
|
|
||||||
The above is the textual representation of a Graph, with each line being a node.
|
The above is the textual representation of a Graph, with each line being a node.
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ class TensorParallelTest(DTensorTestBase):
|
||||||
inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
|
inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
res = model(*inputs)
|
res = model(*inputs)
|
||||||
exported_program = torch._export.export(
|
exported_program = torch.export.export(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
constraints=None,
|
constraints=None,
|
||||||
|
|
@ -111,7 +111,7 @@ class TensorParallelTest(DTensorTestBase):
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
res = model(*inputs)
|
res = model(*inputs)
|
||||||
exported_program = torch._export.export(
|
exported_program = torch.export.export(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
constraints=None,
|
constraints=None,
|
||||||
|
|
@ -148,7 +148,7 @@ class TensorParallelTest(DTensorTestBase):
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
res = model(*inputs)
|
res = model(*inputs)
|
||||||
exported_program = torch._export.export(
|
exported_program = torch.export.export(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
constraints=None,
|
constraints=None,
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,9 @@ from functorch.experimental.control_flow import cond
|
||||||
from torch._dynamo import config
|
from torch._dynamo import config
|
||||||
from torch._dynamo.exc import UserError
|
from torch._dynamo.exc import UserError
|
||||||
from torch._dynamo.testing import normalize_gm
|
from torch._dynamo.testing import normalize_gm
|
||||||
from torch._export import dynamic_dim
|
|
||||||
from torch._higher_order_ops.out_dtype import out_dtype
|
from torch._higher_order_ops.out_dtype import out_dtype
|
||||||
from torch._subclasses import fake_tensor
|
from torch._subclasses import fake_tensor
|
||||||
|
from torch.export import dynamic_dim
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.fx.experimental.symbolic_shapes import (
|
from torch.fx.experimental.symbolic_shapes import (
|
||||||
ConstraintViolationError,
|
ConstraintViolationError,
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,7 @@ class TestPassInfra(TestCase):
|
||||||
input_tensor1 = torch.tensor(5.0)
|
input_tensor1 = torch.tensor(5.0)
|
||||||
input_tensor2 = torch.tensor(6.0)
|
input_tensor2 = torch.tensor(6.0)
|
||||||
|
|
||||||
ep_before = torch._export.export(my_module, (input_tensor1, input_tensor2))
|
ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))
|
||||||
from torch.fx.passes.infra.pass_base import PassResult
|
from torch.fx.passes.infra.pass_base import PassResult
|
||||||
|
|
||||||
def modify_input_output_pass(gm):
|
def modify_input_output_pass(gm):
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import unittest
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._export as export
|
from torch._export import capture_pre_autograd_graph
|
||||||
|
|
||||||
from torch.ao.quantization.observer import (
|
from torch.ao.quantization.observer import (
|
||||||
HistogramObserver,
|
HistogramObserver,
|
||||||
|
|
@ -102,7 +102,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
|
||||||
|
|
||||||
# program capture
|
# program capture
|
||||||
m = copy.deepcopy(m_eager)
|
m = copy.deepcopy(m_eager)
|
||||||
m = export.capture_pre_autograd_graph(
|
m = capture_pre_autograd_graph(
|
||||||
m,
|
m,
|
||||||
example_inputs,
|
example_inputs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._export as export
|
import torch._export
|
||||||
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
||||||
from torch.ao.quantization.quantizer import Quantizer
|
from torch.ao.quantization.quantizer import Quantizer
|
||||||
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
||||||
|
|
@ -64,7 +64,7 @@ class TestMetaDataPorting(QuantizationTestCase):
|
||||||
def _test_quant_tag_preservation_through_decomp(
|
def _test_quant_tag_preservation_through_decomp(
|
||||||
self, model, example_inputs, from_node_to_tags
|
self, model, example_inputs, from_node_to_tags
|
||||||
):
|
):
|
||||||
ep = export.export(model, example_inputs)
|
ep = torch.export.export(model, example_inputs)
|
||||||
found_tags = True
|
found_tags = True
|
||||||
not_found_nodes = ""
|
not_found_nodes = ""
|
||||||
for from_node, tag in from_node_to_tags.items():
|
for from_node, tag in from_node_to_tags.items():
|
||||||
|
|
@ -102,7 +102,7 @@ class TestMetaDataPorting(QuantizationTestCase):
|
||||||
|
|
||||||
# program capture
|
# program capture
|
||||||
m = copy.deepcopy(m_eager)
|
m = copy.deepcopy(m_eager)
|
||||||
m = export.capture_pre_autograd_graph(
|
m = torch._export.capture_pre_autograd_graph(
|
||||||
m,
|
m,
|
||||||
example_inputs,
|
example_inputs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ import torch
|
||||||
import torch._dynamo
|
import torch._dynamo
|
||||||
import torch._inductor
|
import torch._inductor
|
||||||
import torch._inductor.decomposition
|
import torch._inductor.decomposition
|
||||||
import torch._export
|
|
||||||
from torch._higher_order_ops.out_dtype import out_dtype
|
from torch._higher_order_ops.out_dtype import out_dtype
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
|
|
@ -62,7 +61,7 @@ class TestOutDtypeOp(TestCase):
|
||||||
weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
|
weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
|
||||||
m = M(weight)
|
m = M(weight)
|
||||||
x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
|
x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
|
||||||
ep = torch._export.export(
|
ep = torch.export.export(
|
||||||
m,
|
m,
|
||||||
(x,),
|
(x,),
|
||||||
)
|
)
|
||||||
|
|
@ -121,14 +120,15 @@ class TestOutDtypeOp(TestCase):
|
||||||
self.assertTrue(torch.allclose(numerical_res, gm(*inp)))
|
self.assertTrue(torch.allclose(numerical_res, gm(*inp)))
|
||||||
|
|
||||||
def test_out_dtype_non_functional(self):
|
def test_out_dtype_non_functional(self):
|
||||||
def f(x, y):
|
class M(torch.nn.Module):
|
||||||
return out_dtype(
|
def forward(self, x, y):
|
||||||
torch.ops.aten.add_.Tensor, torch.int32, x, y
|
return out_dtype(
|
||||||
)
|
torch.ops.aten.add_.Tensor, torch.int32, x, y
|
||||||
|
)
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
|
with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
|
||||||
_ = torch._export.export(
|
_ = torch.export.export(
|
||||||
f, (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
|
M(), (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_out_dtype_non_op_overload(self):
|
def test_out_dtype_non_op_overload(self):
|
||||||
|
|
|
||||||
|
|
@ -220,36 +220,6 @@ def capture_pre_autograd_graph(
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
def export(
|
|
||||||
f: Callable,
|
|
||||||
args: Tuple[Any, ...],
|
|
||||||
kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
constraints: Optional[List[Constraint]] = None,
|
|
||||||
*,
|
|
||||||
strict: bool = True,
|
|
||||||
preserve_module_call_signature: Tuple[str, ...] = (),
|
|
||||||
) -> ExportedProgram:
|
|
||||||
from torch.export._trace import _export
|
|
||||||
warnings.warn("This function is deprecated. Please use torch.export.export instead.")
|
|
||||||
|
|
||||||
if constraints is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"Using `constraints` to specify dynamic shapes for export is DEPRECATED "
|
|
||||||
"and will not be supported in the future. "
|
|
||||||
"Please use `dynamic_shapes` instead (see docs on `torch.export.export`).",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return _export(
|
|
||||||
f,
|
|
||||||
args,
|
|
||||||
kwargs,
|
|
||||||
constraints,
|
|
||||||
strict=strict,
|
|
||||||
preserve_module_call_signature=preserve_module_call_signature,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
ep: ExportedProgram,
|
ep: ExportedProgram,
|
||||||
f: Union[str, os.PathLike, io.BytesIO],
|
f: Union[str, os.PathLike, io.BytesIO],
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
||||||
from typing import Tuple, Dict, Optional, List
|
from typing import Tuple, Dict, Optional, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._export import export
|
from torch.export import export
|
||||||
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
||||||
from torch._export.pass_infra.node_metadata import NodeMetadata
|
from torch._export.pass_infra.node_metadata import NodeMetadata
|
||||||
from torch._export.pass_infra.proxy_value import ProxyValue
|
from torch._export.pass_infra.proxy_value import ProxyValue
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ import math
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch._export import ExportedProgram
|
|
||||||
from torch._subclasses.fake_tensor import FakeTensor
|
from torch._subclasses.fake_tensor import FakeTensor
|
||||||
|
|
||||||
|
from torch.export import ExportedProgram
|
||||||
from torch.utils._pytree import (
|
from torch.utils._pytree import (
|
||||||
_register_pytree_node,
|
_register_pytree_node,
|
||||||
Context,
|
Context,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user