[fx2trt] Issue warnings instead of error if there's possible const folding opportunities (#71031)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71031

During the conversion stage, we may create constants when the size op is called and the size is static. Raising an error when both operands of a binary elementwise op are constant breaks this case, so we now issue a warning and evaluate the op with the corresponding Python operator instead. In general, it does not hurt to leave such ops un-const-folded.
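To make the scenario concrete, here is a minimal, hypothetical module of the kind described above (the module and its name are made up for illustration and are not taken from this diff). With a static input shape, x.size(1) and x.size(2) are plain Python ints by the time the multiplication reaches add_binary_elementwise_layer, so both operands are constants; previously this raised a RuntimeError, and with this change it only warns and evaluates the product in Python.

import torch
import torch.nn as nn

class FlattenLastTwoDims(nn.Module):
    # Illustrative only: with a static shape, both size() results are Python ints,
    # so the multiply below becomes a binary elementwise op on two constants.
    def forward(self, x):
        return x.view(x.size(0), x.size(1) * x.size(2))

m = FlattenLastTwoDims()
print(m(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 12])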

Test Plan:
Test with D33483843 on shufflenet.

Added unit tests.

Reviewed By: wushirong

Differential Revision: D33484183

fbshipit-source-id: 5b32c06297e56965befd7e83fe8ca273e3665cee
(cherry picked from commit e6b79bd3dd)
Authored by Shiyan Deng on 2022-01-19 15:09:54 -08:00; committed by PyTorch MergeBot
parent 61713acb07
commit 2dbbb1a921
3 changed files with 60 additions and 19 deletions

View File

@@ -9,19 +9,21 @@ from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 
+NEED_TEST_BOTH_CONSTANTS_CASE = True
+
 elementwise_ops = [
-    ((lambda x, y: x + y), acc_ops.add),
-    ((lambda x, y: x - y), acc_ops.sub),
-    ((lambda x, y: x / y), acc_ops.div),
-    ((lambda x, y: x // y), acc_ops.floor_div),
-    ((lambda x, y: torch.div(x, y, rounding_mode="trunc")), acc_ops.trunc_div),
-    ((lambda x, y: torch.div(x, y, rounding_mode="floor")), acc_ops.floor_div),
-    ((lambda x, y: torch.div(x, y)), acc_ops.div),
+    ((lambda x, y: x + y), acc_ops.add, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: x - y), acc_ops.sub, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: x / y), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: x // y), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: torch.div(x, y, rounding_mode="trunc")), acc_ops.trunc_div, not NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: torch.div(x, y, rounding_mode="floor")), acc_ops.floor_div, NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: torch.div(x, y)), acc_ops.div, NEED_TEST_BOTH_CONSTANTS_CASE),
     # torch.floor_divide rounds result toward zero, rather than -Inf.
     # https://github.com/pytorch/pytorch/issues/43874
-    ((lambda x, y: torch.floor_divide(x, y)), acc_ops.trunc_div),
-    ((lambda x, y: x * y), acc_ops.mul),
-    (torch.pow, acc_ops.pow),
+    ((lambda x, y: torch.floor_divide(x, y)), acc_ops.trunc_div, not NEED_TEST_BOTH_CONSTANTS_CASE),
+    ((lambda x, y: x * y), acc_ops.mul, NEED_TEST_BOTH_CONSTANTS_CASE),
+    (torch.pow, acc_ops.pow, not NEED_TEST_BOTH_CONSTANTS_CASE),
 ]

@@ -42,7 +44,7 @@ class TestBinaryOpConverters(AccTestCase):
         self.run_test(m, inputs, expected_ops={expected_op})
 
     @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
-    def test_elementwise_ops_constant(self, name, orig_op: Callable, expected_op):
+    def test_elementwise_ops_with_one_constant(self, name, orig_op: Callable, expected_op):
         class TestModule(nn.Module):
             def __init__(self, orig_op):
                 super().__init__()

@@ -57,6 +59,24 @@ class TestBinaryOpConverters(AccTestCase):
         inputs = [torch.randn(2, 2)]
         self.run_test(m, inputs, expected_ops={expected_op})
 
+    @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]])
+    def test_elementwise_op_with_both_constants(self, name, orig_op: Callable, expected_op):
+        class TestModule(nn.Module):
+            def __init__(self, orig_op):
+                super().__init__()
+                self.constant0 = torch.nn.Parameter(torch.randn(1))
+                self.constant1 = torch.nn.Parameter(torch.randn(1))
+                self.orig_op = orig_op
+
+            def forward(self, x):
+                const = self.orig_op(self.constant0, self.constant1)
+                return self.orig_op(x, const)
+
+        m = TestModule(orig_op)
+        inputs = [torch.randn(2, 2)]
+        self.run_test(m, inputs, expected_ops={expected_op})
+
     @parameterized.expand(
         [
             (

View File

@@ -1,4 +1,6 @@
-from typing import Any, Tuple, Sequence, Union, List, Optional, Dict
+import operator
+import warnings
+from typing import Any, Tuple, Sequence, Union, List, Optional, Dict, Callable
 
 import numpy as np
 import tensorrt as trt

@@ -324,11 +326,12 @@ def add_binary_elementwise_layer(
     name: str
 ) -> TRTTensor:
     """
-    This function adds a TensorRT elementwise layer. We only allow at most one
-    operand to not be a trt tensor, otherwise, we should const fold it first.
-    If any operand is not a trt tensor, we make it a trt constant layer which
-    has the same type as the other trt tensor. Then we broadcast these two inputs
-    to have the same number of dimensions.
+    This function adds a TensorRT elementwise layer. We allow both operands to be
+    constant (not a trt tensor) because in implicit batch dimension mode, we could
+    introduce constant via .size() op. Other scenario should be const folded first.
+    If any operand is not a trt tensor, we make it a trt constant layer which has
+    the same type as the other trt tensor. Then we broadcast these two inputs to
+    have the same number of dimensions.
 
     Limitation:
         If we are using implicit batch dim mode, the operand that is not a trt

@@ -357,8 +360,9 @@ def add_binary_elementwise_layer(
         dtype = torch_dtype_from_trt(rhs_val.dtype)
         is_rhs_trt_tensor = True
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
-        raise RuntimeError(f"Both operands of the binary elementwise op {name}"
-                           "are constant. In this case, please consider constant fold the model first.")
+        warnings.warn(f"Both operands of the binary elementwise op {name} "
+                      "are constant. In this case, please consider constant fold the model first.")
+        return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)
 
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", dtype)
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", dtype)

@@ -614,3 +618,18 @@ def trunc_div(
         trt.ElementWiseOperation.PROD, target, f"{name}_output")
 
     return output
+
+
+def get_python_op_from_trt_elementwise_op(trt_op: TRTElementWiseOp) -> Callable[[Any, Any], Any]:
+    if trt_op == trt.ElementWiseOperation.SUM:
+        return operator.add
+    elif trt_op == trt.ElementWiseOperation.PROD:
+        return operator.mul
+    elif trt_op == trt.ElementWiseOperation.SUB:
+        return operator.sub
+    elif trt_op == trt.ElementWiseOperation.DIV:
+        return operator.truediv
+    elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
+        return operator.floordiv
+    else:
+        raise RuntimeError(f"{trt_op} is not supported yet!")
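As a quick, illustrative sanity check of the fallback above (assuming get_python_op_from_trt_elementwise_op, defined in the hunk above, is importable; its module path is intentionally omitted here):

import operator
import tensorrt as trt

# Assumes get_python_op_from_trt_elementwise_op from the diff above is in scope.
# SUM maps to operator.add, so two constant operands fold to a plain Python number.
py_op = get_python_op_from_trt_elementwise_op(trt.ElementWiseOperation.SUM)
assert py_op is operator.add
assert py_op(2, 3) == 5

This is what lets add_binary_elementwise_layer return a plain Python value, rather than a TRT tensor, when both operands are constant.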

View File

@@ -9,6 +9,7 @@ if hasattr(trt, "__version__"):
     TRTPluginFieldCollection = trt.PluginFieldCollection
     TRTPlugin = trt.IPluginV2
     TRTDataType = trt.DataType
+    TRTElementWiseOp = trt.ElementWiseOperation
 else:
     TRTNetwork = "trt.INetworkDefinition"
     TRTTensor = "trt.tensorrt.ITensor"

@@ -16,6 +17,7 @@ else:
     TRTPluginFieldCollection = "trt.PluginFieldCollection"
     TRTPlugin = "trt.IPluginV2"
     TRTDataType = "trt.DataType"
+    TRTElementWiseOp = "trt.ElementWiseOperation"
 
 Shape = Sequence[int]
 ShapeRange = Tuple[Shape, Shape, Shape]