mirror of https://github.com/zebrajr/pytorch.git · synced 2025-12-07 12:21:27 +01:00
Summary: X-link: https://github.com/pytorch/fx2trt/pull/62

1. support type_as()
2. improve power with method
3. improve to_dtype(torch.Tensor)
4. support float(), half(), int()
5. support isinf(). Not possible to support isnan()
6. support torch.any()
7. support rsqrt()

(Note: this ignores all push blocking failures!)

Test Plan: Test passed
Reviewed By: 842974287
Differential Revision: D35516820
fbshipit-source-id: 7f4b7c1b833034121f7643565bd4d1b4133319bf
(cherry picked from commit 67b454f91dcad0678b57d78dae14cce21d32b84a)
264 lines · 9.2 KiB · Python
import unittest
from typing import Callable, List, Tuple

import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
import torch
import torch.fx
from fx2trt_oss.fx import (
    TRTInterpreter,
    InputTensorSpec,
    TRTModule,
)
from fx2trt_oss.fx.passes.pass_utils import chain_passes
from fx2trt_oss.fx.utils import LowerPrecision
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.passes import shape_prop
from torch.testing._internal.common_utils import TestCase

def fetch_attr(mod, target):
    """
    Fetch an attribute from the ``Module`` hierarchy of ``mod``.

    Args:
        target (str): The fully-qualified name of the attribute to fetch

    Return:
        Any: The value of the attribute.
    """
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr

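
# Hedged usage sketch (not part of the original file): ``fetch_attr`` is how
# ``assert_has_op`` below resolves the submodule that a ``call_module`` node
# points at. The traced module here is an illustrative assumption.
def _example_fetch_attr(traced: torch.fx.GraphModule):
    for node in traced.graph.nodes:
        if node.op == "call_module":
            # node.target is the dotted path of the submodule, e.g. "sub.linear"
            submod = fetch_attr(traced, node.target)
            print(node.target, type(submod))
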

@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available")
class TRTTestCase(TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(3)

    def run_test(
        self,
        mod,
        inputs,
        expected_ops,
        unexpected_ops,
        interpreter,
        rtol,
        atol,
        precision=LowerPrecision.FP32,
    ):
        with torch.no_grad():
            cuda_inputs = []
            for i in inputs:
                cuda_inputs.append(i.cuda())

            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)
            if unexpected_ops:
                # assert_unexpected_op returns False when a banned op is found,
                # so assert on its result; calling it bare would never fail.
                self.assertTrue(self.assert_unexpected_op(mod, unexpected_ops))

            interpreter_result = interpreter.run(lower_precision=precision)
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
            )

            ref_outputs = mod(*inputs)
            outputs = trt_mod(*cuda_inputs)

            if isinstance(outputs, torch.Tensor):
                ref_outputs = [ref_outputs]
                outputs = [outputs]
            for out, ref in zip(outputs, ref_outputs):
                if not isinstance(ref, torch.Tensor):
                    ref = torch.tensor([ref])
                ref = ref.cpu()  # to_dtype test has cases with gpu output
                torch.testing.assert_allclose(out.cpu(), ref, rtol=rtol, atol=atol)

    def run_test_custom_compare_results(
        self,
        mod,
        inputs,
        expected_ops,
        interpreter,
        comparators: List[Tuple[Callable, List]],
        fp16_mode=False,
    ):
        """
        Runs the test and compares the results using the provided comparators.
        The number of comparators must equal the number of outputs from 'mod'.

        mod - a model to run.
        inputs - a list of the model inputs.
        expected_ops - a list of ops that should be verified.
        interpreter - used for converting the model to TRT.
        comparators - a list of (func, args) pairs corresponding to each of
            the module outputs. Usage: func(x, y, *args)
        """
        with torch.no_grad():
            cuda_inputs = []
            for i in inputs:
                cuda_inputs.append(i.cuda())

            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
                lower_precision=LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
            )
            res_trt = trt_mod(*cuda_inputs).cpu()
            res_cpu = mod(*inputs)
            assert len(res_trt) == len(res_cpu)
            assert len(res_cpu) == len(comparators)
            for output_trt, output_cpu, comparator in zip(
                res_trt, res_cpu, comparators
            ):
                comp_func = comparator[0]
                args = comparator[1]
                self.assertTrue(comp_func(output_trt, output_cpu, *args))

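    # Hedged example (illustrative, not from the original file): each entry in
    # ``comparators`` is a (func, extra_args) pair applied to one module output
    # as func(trt_output, cpu_output, *extra_args). For instance,
    #
    #     comparators = [(torch.allclose, [1e-03, 1e-03])]  # rtol, atol
    #
    # checks a single output with torch.allclose at the given tolerances.
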
    def run_test_with_error(self, mod, inputs, interpreter, expect_error):
        with self.assertRaises(expect_error):
            with torch.no_grad():
                cuda_inputs = []
                for i in inputs:
                    cuda_inputs.append(i.cuda())

                mod.eval()
                interpreter.run(lower_precision=LowerPrecision.FP32)

    def assert_has_op(self, mod, ops):
        ops_in_mod = set()

        for node in mod.graph.nodes:
            if node.op == "call_module":
                ops_in_mod.add(type(fetch_attr(mod, node.target)))
            elif node.op in {"call_function", "call_method"}:
                ops_in_mod.add(node.target)

        self.assertTrue(
            ops_in_mod >= ops, f"expected ops {ops}, actual ops {ops_in_mod}"
        )

    def assert_unexpected_op(self, mod, ops):
        """
        Return False if any op in ``ops`` appears in ``mod``'s graph,
        True otherwise. Callers should assert on the returned value.
        """
        for node in mod.graph.nodes:
            if node.op == "call_module":
                if type(fetch_attr(mod, node.target)) in ops:
                    return False
            elif node.op in {"call_function", "call_method"}:
                if node.target in ops:
                    return False
        return True


class VanillaTestCase(TRTTestCase):
    def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol)

    def run_test_custom_compare_results(
        self,
        mod,
        inputs,
        expected_ops,
        interpreter,
        comparators: List[Tuple[Callable, List]],
        fp16_mode=False,
    ):
        # interpreter is ignored; we do not need it for vanilla tests.
        # Note this differs from the internal version; we need to fix the
        # test cases after we refactor the internal callsites to use this file.
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test_custom_compare_results(
            mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
        )

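
# Hedged usage sketch (illustrative, not part of the original file): a toy
# converter test built on VanillaTestCase. The module below and the choice of
# torch.relu as the expected op are assumptions made for illustration.
class _ExampleVanillaTest(VanillaTestCase):
    def test_relu(self):
        class Relu(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        # expected_ops is matched against the call_function/call_method
        # targets that assert_has_op collects from the traced graph.
        self.run_test(Relu(), [torch.randn(2, 3)], expected_ops={torch.relu})
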

class AccTestCase(TRTTestCase):
    def run_test(
        self,
        mod,
        inputs,
        expected_ops,
        unexpected_ops=None,
        apply_passes=None,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
        test_explicit_precision=False,
        rtol=1e-03,
        atol=1e-03,
        precision=LowerPrecision.FP32,
    ):
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if apply_passes is not None:
            pass_tracer = chain_passes(*apply_passes)
            mod = pass_tracer(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test(
                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
            )

        if test_explicit_batch_dim:
            interp = TRTInterpreter(
                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
            )
            super().run_test(
                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
            )

        if test_explicit_precision:
            interp = TRTInterpreter(
                mod,
                InputTensorSpec.from_tensors(inputs),
                explicit_precision=test_explicit_precision,
            )
            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)

            interp = TRTInterpreter(
                mod,
                InputTensorSpec.from_tensors(inputs),
                explicit_batch_dimension=True,
                explicit_precision=test_explicit_precision,
            )
            super().run_test(
                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
            )

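    # Hedged usage sketch (illustrative, not part of the original file): a
    # converter test for one of the newly supported ops from this diff, e.g.
    # torch.isinf. The acc_ops import path and the acc_ops.isinf target are
    # assumptions based on this repo's layout.
    #
    #     import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
    #
    #     class TestIsInf(AccTestCase):
    #         def test_isinf(self):
    #             class IsInf(torch.nn.Module):
    #                 def forward(self, x):
    #                     return torch.isinf(x)
    #
    #             x = torch.randn(2, 3)
    #             x[0, 0] = float("inf")
    #             self.run_test(IsInf(), [x], expected_ops={acc_ops.isinf})
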
    def run_test_with_assert_error(
        self,
        mod,
        inputs,
        expect_error,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
    ):
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test_with_error(mod, inputs, interp, expect_error)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(
                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
            )
            super().run_test_with_error(mod, inputs, interp, expect_error)

    def run_test_with_dynamic_shape(
        self,
        mod,
        input_specs,
        expected_ops,
        unexpected_ops=None,
        rtol=1e-03,
        atol=1e-03,
    ):
        mod.eval()
        inputs = InputTensorSpec.create_inputs_from_specs(input_specs)
        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True)
        super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)
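
# Hedged usage sketch (illustrative, not part of the original file): driving
# ``run_test_with_dynamic_shape`` with specs that carry (min, opt, max) shape
# ranges. The InputTensorSpec keyword arguments and the toy module below are
# assumptions; check fx2trt_oss.fx.InputTensorSpec for the real signature.
class _ExampleDynamicShapeTest(AccTestCase):
    def test_reshape_dynamic_batch(self):
        class Reshape(torch.nn.Module):
            def forward(self, x):
                return x.reshape(-1, 6)

        input_specs = [
            InputTensorSpec(
                shape=(-1, 2, 3),  # -1 marks the dynamic (batch) dimension
                dtype=torch.float32,
                shape_ranges=[((1, 2, 3), (4, 2, 3), (8, 2, 3))],
            ),
        ]
        # expected_ops is left empty in this sketch; a real test would assert
        # on the relevant acc_ops target (e.g. acc_ops.reshape).
        self.run_test_with_dynamic_shape(Reshape(), input_specs, expected_ops={})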