[fx2trt] Issue warnings instead of errors if there are possible const folding opportunities (#71031)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71031
During the conversion stage, we might create some constants when the size op is called and the size is static. Raising an error here causes problems for this case. Generally speaking, it doesn't hurt to allow skipping const folding.
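For context, here is a minimal sketch (not from the PR; the module name and shapes are made up for illustration) of how a static size op can leave both operands of a binary elementwise op constant by the time the converter runs:

import torch
import torch.nn as nn

class SizeRatio(nn.Module):
    """Hypothetical module: with a static input shape, x.size(1) and x.size(2)
    become plain Python ints during tracing, so the division below reaches the
    binary elementwise converter with two constant operands and no TRT tensor."""

    def forward(self, x):
        ratio = x.size(1) / x.size(2)  # constant / constant when the shape is static
        return x * ratio               # the folded constant then feeds a normal elementwise op

m = SizeRatio()
out = m(torch.randn(1, 8, 4))  # ratio == 2.0; eager run just sanity-checks the module

With the old behavior this pattern raised a RuntimeError during conversion; with this change it only warns and evaluates the constant expression in Python.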
Test Plan:
Test with D33483843 on shufflenet.
Added unit tests.
Reviewed By: wushirong
Differential Revision: D33484183
fbshipit-source-id: 5b32c06297e56965befd7e83fe8ca273e3665cee
(cherry picked from commit e6b79bd3dd)
This commit is contained in:
parent 61713acb07
commit 2dbbb1a921
@@ -9,19 +9,21 @@ from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 
+NEED_TEST_BOTH_CONSTANTS_CASE = True
+
 elementwise_ops = [
-    ((lambda x, y: x + y), acc_ops.add),
-    ((lambda x, y: x - y), acc_ops.sub),
-    ((lambda x, y: x / y), acc_ops.div),
-    ((lambda x, y: x // y), acc_ops.floor_div),
-    ((lambda x, y: torch.div(x, y, rounding_mode="trunc")), acc_ops.trunc_div),
-    ((lambda x, y: torch.div(x, y, rounding_mode="floor")), acc_ops.floor_div),
-    ((lambda x, y: torch.div(x, y)), acc_ops.div),
+    ((lambda x, y: x + y), acc_ops.add, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: x - y), acc_ops.sub, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: x / y), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: x // y), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: torch.div(x, y, rounding_mode="trunc")), acc_ops.trunc_div, not NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: torch.div(x, y, rounding_mode="floor")), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: torch.div(x, y)), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE),
     # torch.floor_divide rounds result toward zero, rather than -Inf.
     # https://github.com/pytorch/pytorch/issues/43874
-    ((lambda x, y: torch.floor_divide(x, y)), acc_ops.trunc_div),
-    ((lambda x, y: x * y), acc_ops.mul),
-    (torch.pow, acc_ops.pow),
+    ((lambda x, y: torch.floor_divide(x, y)), acc_ops.trunc_div, not NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: x * y), acc_ops.mul, NEED_TEST_BOTH_CONSTANTS_CASE),
+    (torch.pow, acc_ops.pow, not NEED_TEST_BOTH_CONSTANTS_CASE),
 ]
 
 
@@ -42,7 +44,7 @@ class TestBinaryOpConverters(AccTestCase):
         self.run_test(m, inputs, expected_ops={expected_op})
 
     @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
-    def test_elementwise_ops_constant(self, name, orig_op: Callable, expected_op):
+    def test_elementwise_ops_with_one_constant(self, name, orig_op: Callable, expected_op):
         class TestModule(nn.Module):
             def __init__(self, orig_op):
                 super().__init__()
@@ -57,6 +59,24 @@ class TestBinaryOpConverters(AccTestCase):
         inputs = [torch.randn(2, 2)]
         self.run_test(m, inputs, expected_ops={expected_op})
 
+    @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]])
+    def test_elementwise_op_with_both_constants(self, name, orig_op: Callable, expected_op):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.constant0 = torch.nn.Parameter(torch.randn(1))
+                self.constant1 = torch.nn.Parameter(torch.randn(1))
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                const = self.orig_op(self.constant0, self.constant1)
+                return self.orig_op(x, const)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randn(2, 2)]
+        self.run_test(m, inputs, expected_ops={expected_op})
+
+
     @parameterized.expand(
         [
             (
@@ -1,4 +1,6 @@
-from typing import Any, Tuple, Sequence, Union, List, Optional, Dict
+import operator
+import warnings
+from typing import Any, Tuple, Sequence, Union, List, Optional, Dict, Callable
 
 import numpy as np
 import tensorrt as trt
@@ -324,11 +326,12 @@ def add_binary_elementwise_layer(
     name: str
 ) -> TRTTensor:
     """
-    This function adds a TensorRT elementwise layer. We only allow at most one
-    operand to not be a trt tensor, otherwise, we should const fold it first.
-    If any operand is not a trt tensor, we make it a trt constant layer which
-    has the same type as the other trt tensor. Then we broadcast these two inputs
-    to have the same number of dimensions.
+    This function adds a TensorRT elementwise layer. We allow both operands to be
+    constant (not a trt tensor) because in implicit batch dimension mode, we could
+    introduce constant via .size() op. Other scenario should be const folded first.
+    If any operand is not a trt tensor, we make it a trt constant layer which has
+    the same type as the other trt tensor. Then we broadcast these two inputs to
+    have the same number of dimensions.
 
     Limitation:
         If we are using implicit batch dim mode, the operand that is not a trt
@@ -357,8 +360,9 @@ def add_binary_elementwise_layer(
         dtype = torch_dtype_from_trt(rhs_val.dtype)
         is_rhs_trt_tensor = True
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
-        raise RuntimeError(f"Both operands of the binary elementwise op {name}"
-                           "are constant. In this case, please consider constant fold the model first.")
+        warnings.warn(f"Both operands of the binary elementwise op {name} "
+                      "are constant. In this case, please consider constant fold the model first.")
+        return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)
 
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", dtype)
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", dtype)
@@ -614,3 +618,18 @@ def trunc_div(
         trt.ElementWiseOperation.PROD, target, f"{name}_output")
 
     return output
+
+
+def get_python_op_from_trt_elementwise_op(trt_op: TRTElementWiseOp) -> Callable[[Any, Any], Any]:
+    if trt_op == trt.ElementWiseOperation.SUM:
+        return operator.add
+    elif trt_op == trt.ElementWiseOperation.PROD:
+        return operator.mul
+    elif trt_op == trt.ElementWiseOperation.SUB:
+        return operator.sub
+    elif trt_op == trt.ElementWiseOperation.DIV:
+        return operator.truediv
+    elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
+        return operator.floordiv
+    else:
+        raise RuntimeError(f"{trt_op} is not supported yet!")
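For illustration only (not part of the diff): a TensorRT-free sketch of what the new fallback path does when both operands are constants. The names _PY_OP and fold_constant_binary are invented for this sketch; in the actual change the mapping is get_python_op_from_trt_elementwise_op above and the warning lives in add_binary_elementwise_layer.

import operator
import warnings

# Stand-in for the trt.ElementWiseOperation -> Python operator mapping (sketch only).
_PY_OP = {
    "SUM": operator.add,
    "PROD": operator.mul,
    "SUB": operator.sub,
    "DIV": operator.truediv,
    "FLOOR_DIV": operator.floordiv,
}

def fold_constant_binary(op_name, lhs, rhs, name="binary_op"):
    """Mimic the new behavior: warn instead of raising, then evaluate the op in Python."""
    warnings.warn(f"Both operands of the binary elementwise op {name} "
                  "are constant. In this case, please consider constant fold the model first.")
    return _PY_OP[op_name](lhs, rhs)

print(fold_constant_binary("DIV", 8, 4))  # 2.0, e.g. a size(1) / size(2) constant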
@@ -9,6 +9,7 @@ if hasattr(trt, "__version__"):
     TRTPluginFieldCollection = trt.PluginFieldCollection
     TRTPlugin = trt.IPluginV2
     TRTDataType = trt.DataType
+    TRTElementWiseOp = trt.ElementWiseOperation
 else:
     TRTNetwork = "trt.INetworkDefinition"
     TRTTensor = "trt.tensorrt.ITensor"
@@ -16,6 +17,7 @@ else:
     TRTPluginFieldCollection = "trt.PluginFieldCollection"
     TRTPlugin = "trt.IPluginV2"
     TRTDataType = "trt.DataType"
+    TRTElementWiseOp = "trt.ElementWiseOperation"
 
 Shape = Sequence[int]
 ShapeRange = Tuple[Shape, Shape, Shape]