pytorch/torch/testing/_internal/common_fx2trt.py
import unittest
from typing import Callable, List, Tuple
import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
import torch
import torch.fx
from fx2trt_oss.fx import (
    TRTInterpreter,
    InputTensorSpec,
    TRTModule,
)
from fx2trt_oss.fx.passes.pass_utils import chain_passes
from fx2trt_oss.fx.utils import LowerPrecision
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.passes import shape_prop
from torch.testing._internal.common_utils import TestCase


def fetch_attr(mod, target):
    """
    Fetch an attribute from the ``Module`` hierarchy of ``mod``.

    Args:
        mod (torch.nn.Module): The root module to fetch the attribute from.
        target (str): The fully-qualified name of the attribute to fetch.

    Returns:
        Any: The value of the attribute.
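
    Example (illustrative; ``submod`` and ``linear`` are hypothetical
    submodule names under ``mod``)::

        weight = fetch_attr(mod, "submod.linear.weight")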
"""
target_atoms = target.split(".")
attr_itr = mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr


@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available")
class TRTTestCase(TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(3)

    def run_test(
        self,
        mod,
        inputs,
        expected_ops,
        unexpected_ops,
        interpreter,
        rtol,
        atol,
        precision=LowerPrecision.FP32,
    ):
        with torch.no_grad():
            cuda_inputs = []
            for i in inputs:
                cuda_inputs.append(i.cuda())

            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)
            if unexpected_ops:
                # assert_unexpected_op returns False when an unexpected op is found
                # in the traced graph, so its result must be asserted explicitly.
                self.assertTrue(self.assert_unexpected_op(mod, unexpected_ops))

            interpreter_result = interpreter.run(lower_precision=precision)
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
            )

            ref_outputs = mod(*inputs)
            outputs = trt_mod(*cuda_inputs)
            if isinstance(outputs, torch.Tensor):
                ref_outputs = [ref_outputs]
                outputs = [outputs]
            for out, ref in zip(outputs, ref_outputs):
                if not isinstance(ref, torch.Tensor):
                    ref = torch.tensor([ref])
                ref = ref.cpu()  # to_dtype test has cases with gpu output
                torch.testing.assert_allclose(out.cpu(), ref, rtol=rtol, atol=atol)

    def run_test_custom_compare_results(
        self,
        mod,
        inputs,
        expected_ops,
        interpreter,
        comparators: List[Tuple[Callable, List]],
        fp16_mode=False,
    ):
"""
Runs the test and compares the result using the provided comparators.
The size of comparators must be equal to the number of outputs from 'mod'.
mod - a model to run.
inputs - a list of the model inputs.
expected ops - a list of ops that should be verified.
interpreter - used for converting the model to TRT.
comparators - a list of (func, args) pairs corresponding to each of
the module outputs. usage: func(x, y, *args)
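
        Example (illustrative; ``cos_sim`` is a hypothetical comparator that
        checks cosine similarity between the TRT and reference outputs)::

            def cos_sim(x, y, threshold):
                sim = torch.nn.functional.cosine_similarity(
                    x.flatten().float(), y.flatten().float(), dim=0
                )
                return sim.item() >= threshold

            self.run_test_custom_compare_results(
                mod, inputs, expected_ops, interp, comparators=[(cos_sim, [0.99])]
            )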
"""
with torch.no_grad():
cuda_inputs = []
for i in inputs:
cuda_inputs.append(i.cuda())
mod.eval()
if len(expected_ops):
self.assert_has_op(mod, expected_ops)
interpreter_result = interpreter.run(lower_precision=LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32)
trt_mod = TRTModule(
interpreter_result.engine,
interpreter_result.input_names,
interpreter_result.output_names,
)
res_trt = trt_mod(*cuda_inputs).cpu()
res_cpu = mod(*inputs)
assert len(res_trt) == len(res_cpu)
assert len(res_cpu) == len(comparators)
for output_trt, output_cpu, comparator in zip(
res_trt, res_cpu, comparators
):
comp_func = comparator[0]
args = comparator[1]
self.assertTrue(comp_func(output_trt, output_cpu, *args))

    def run_test_with_error(self, mod, inputs, interpreter, expect_error):
        with self.assertRaises(expect_error):
            with torch.no_grad():
                cuda_inputs = []
                for i in inputs:
                    cuda_inputs.append(i.cuda())

                mod.eval()
                interpreter.run(lower_precision=LowerPrecision.FP32)

    def assert_has_op(self, mod, ops):
        ops_in_mod = set()
        for node in mod.graph.nodes:
            if node.op == "call_module":
                ops_in_mod.add(type(fetch_attr(mod, node.target)))
            elif node.op in {"call_function", "call_method"}:
                ops_in_mod.add(node.target)

        self.assertTrue(
            ops_in_mod >= ops, f"expected ops {ops}, actual ops {ops_in_mod}"
        )

    def assert_unexpected_op(self, mod, ops):
        """Return False if any of the given ops appears in ``mod``'s graph, else True."""
        for node in mod.graph.nodes:
            if node.op == "call_module":
                if type(fetch_attr(mod, node.target)) in ops:
                    return False
            elif node.op in {"call_function", "call_method"}:
                if node.target in ops:
                    return False
        return True


class VanillaTestCase(TRTTestCase):
    def run_test(self, mod, inputs, expected_ops, rtol=1e-05, atol=1e-06):
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol)

    def run_test_custom_compare_results(
        self,
        mod,
        inputs,
        expected_ops,
        interpreter,
        comparators: List[Tuple[Callable, List]],
        fp16_mode=False,
    ):
        # `interpreter` is ignored; vanilla tests build their own interpreter below.
        # Note: this differs from the internal version; the test cases need to be
        # fixed once the internal call sites are refactored to use this file.
        mod = torch.fx.symbolic_trace(mod)
        shape_prop.ShapeProp(mod).propagate(*inputs)
        mod = NormalizeArgs(mod).transform()
        interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
        super().run_test_custom_compare_results(
            mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode
        )


class AccTestCase(TRTTestCase):
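    """
    Base class for acc-tracer based converter tests.

    Example (illustrative sketch; ``TestReLU`` and ``TestModule`` are hypothetical,
    and ``acc_ops`` is assumed to be importable from
    ``fx2trt_oss.tracer.acc_tracer``)::

        from fx2trt_oss.tracer.acc_tracer import acc_ops

        class TestReLU(AccTestCase):
            def test_relu(self):
                class TestModule(torch.nn.Module):
                    def forward(self, x):
                        return torch.relu(x)

                self.run_test(
                    TestModule(), [torch.randn(1, 2, 3)], expected_ops={acc_ops.relu}
                )
    """
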
    def run_test(
        self,
        mod,
        inputs,
        expected_ops,
        unexpected_ops=None,
        apply_passes=None,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
        test_explicit_precision=False,
        rtol=1e-03,
        atol=1e-03,
        precision=LowerPrecision.FP32,
    ):
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if apply_passes is not None:
            pass_tracer = chain_passes(*apply_passes)
            mod = pass_tracer(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test(
                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
            )

        if test_explicit_batch_dim:
            interp = TRTInterpreter(
                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
            )
            super().run_test(
                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
            )

        if test_explicit_precision:
            interp = TRTInterpreter(
                mod,
                InputTensorSpec.from_tensors(inputs),
                explicit_precision=test_explicit_precision,
            )
            super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)

            interp = TRTInterpreter(
                mod,
                InputTensorSpec.from_tensors(inputs),
                explicit_batch_dimension=True,
                explicit_precision=test_explicit_precision,
            )
            super().run_test(
                mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
            )

    def run_test_with_assert_error(
        self,
        mod,
        inputs,
        expect_error,
        test_explicit_batch_dim=True,
        test_implicit_batch_dim=True,
    ):
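        """
        Trace ``mod`` with the acc tracer and check that building the TRT engine
        raises ``expect_error``.

        Example (illustrative; ``MyModule`` is a hypothetical module expected to
        fail TRT lowering)::

            self.run_test_with_assert_error(
                MyModule(), [torch.randn(2, 3)], expect_error=RuntimeError
            )
        """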
        mod.eval()
        mod = acc_tracer.trace(mod, inputs)

        if test_implicit_batch_dim:
            interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs))
            super().run_test_with_error(mod, inputs, interp, expect_error)

        if test_explicit_batch_dim:
            interp = TRTInterpreter(
                mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
            )
            super().run_test_with_error(mod, inputs, interp, expect_error)

    def run_test_with_dynamic_shape(
        self,
        mod,
        input_specs,
        expected_ops,
        unexpected_ops=None,
        rtol=1e-03,
        atol=1e-03,
    ):
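        """
        Run a converter test with dynamic input shapes described by ``input_specs``.

        Example (illustrative sketch; the ``InputTensorSpec`` fields are assumed from
        fx2trt's spec API, with ``shape_ranges`` holding (min, opt, max) shape tuples
        and -1 marking dynamic dimensions, and ``MyModule`` being hypothetical)::

            input_specs = [
                InputTensorSpec(
                    shape=(-1, 3, -1, -1),
                    dtype=torch.float32,
                    shape_ranges=[((1, 3, 128, 128), (2, 3, 224, 224), (4, 3, 256, 256))],
                ),
            ]
            self.run_test_with_dynamic_shape(
                MyModule(), input_specs, expected_ops={acc_ops.relu}
            )
        """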
        mod.eval()
        inputs = InputTensorSpec.create_inputs_from_specs(input_specs)
        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True)
        super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol)