Update operator list for AutocastCPU (#68725)

Update operator list for AutocastCPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68725
Approved by: https://github.com/frank-wei
This commit is contained in:
ecao 2022-05-11 17:28:33 +00:00 committed by PyTorch MergeBot
parent 9edee09ed6
commit 5993cc0b3d
4 changed files with 16 additions and 101 deletions

View File

@ -494,18 +494,15 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL_CPU(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &), lower_precision_fp)
KERNEL_CPU(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp)
KERNEL_CPU(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), lower_precision_fp)
KERNEL_CPU(ADD_NS(matmul), "matmul", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
// fp32 cast policy
KERNEL_CPU(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(batch_norm), "batch_norm", Tensor (const Tensor &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, bool, double, double, bool), fp32)
KERNEL_CPU(ADD_NS(dropout), "dropout", Tensor (const Tensor &, double, bool), fp32)
KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32)
KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, at::OptionalIntArrayRef, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(_upsample_nearest_exact1d), "_upsample_nearest_exact1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
@ -524,14 +521,10 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d.vec", Tensor (const Tensor &, at::OptionalIntArrayRef, bool, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<double>, c10::optional<double>, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d.vec", Tensor (const Tensor &, at::OptionalIntArrayRef, bool, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(binary_cross_entropy), "binary_cross_entropy", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, int64_t), fp32)
KERNEL_CPU(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t), fp32)
KERNEL_CPU(ADD_NS(instance_norm), "instance_norm", Tensor (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, const c10::optional<Tensor>&, const c10::optional<Tensor>&, bool, double, double, bool), fp32)
KERNEL_CPU(ADD_NS(grid_sampler), "grid_sampler", Tensor(const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
KERNEL_CPU(ADD_NS(polar), "polar", Tensor(const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(multinomial), "multinomial", Tensor(const Tensor &, int64_t, bool, c10::optional<at::Generator>), fp32)
KERNEL_CPU(ADD_NS(poisson), "poisson", Tensor(const Tensor &, c10::optional<at::Generator>), fp32)
KERNEL_CPU(ADD_NS(fmod), "fmod.Tensor", Tensor(const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(fmod), "fmod.Scalar", Tensor(const Tensor &, const Scalar &), fp32)
KERNEL_CPU(ADD_NS(prod), "prod", Tensor(const Tensor &, c10::optional<at::ScalarType>), fp32)
@ -551,28 +544,23 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(cumsum), "cumsum.dimname", Tensor(const Tensor &, at::Dimname, c10::optional<at::ScalarType>), fp32)
KERNEL_CPU(ADD_NS(diag), "diag", Tensor(const Tensor &, int64_t), fp32)
KERNEL_CPU(ADD_NS(diagflat), "diagflat", Tensor(const Tensor &, int64_t), fp32)
KERNEL_CPU(ADD_NS(grid_sampler_2d), "grid_sampler_2d", Tensor(const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
KERNEL_CPU(ADD_NS(_grid_sampler_2d_cpu_fallback), "_grid_sampler_2d_cpu_fallback", Tensor(const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
KERNEL_CPU(ADD_NS(grid_sampler_3d), "grid_sampler_3d", Tensor(const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
KERNEL_CPU(ADD_NS(histc), "histc", Tensor(const Tensor &, int64_t, const at::Scalar &, const at::Scalar &), fp32)
KERNEL_CPU(ADD_NS(logcumsumexp), "logcumsumexp", Tensor(const Tensor &, int64_t), fp32)
KERNEL_CPU(ADD_NS(searchsorted), "searchsorted.Tensor", Tensor(const Tensor &, const Tensor &, bool, bool, const c10::optional<c10::string_view>, const c10::optional<Tensor> &), fp32)
KERNEL_CPU(ADD_NS(searchsorted), "searchsorted.Scalar", Tensor(const Tensor &, const at::Scalar &, bool, bool, const c10::optional<c10::string_view>, const c10::optional<Tensor> &), fp32)
KERNEL_CPU(ADD_NS(trace), "trace", Tensor(const Tensor &), fp32)
KERNEL_CPU(ADD_NS(tril), "tril", Tensor(const Tensor &, int64_t), fp32)
KERNEL_CPU(ADD_NS(triu), "triu", Tensor(const Tensor &, int64_t), fp32)
KERNEL_CPU(ADD_NS(vander), "vander", Tensor(const Tensor &, c10::optional<int64_t>, bool), fp32)
KERNEL_CPU(ADD_NS(view_as_complex), "view_as_complex", Tensor(const Tensor &), fp32)
KERNEL_CPU(ADD_NS(cholesky), "cholesky", Tensor(const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(cholesky_inverse), "cholesky_inverse", Tensor(const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(cholesky_solve), "cholesky_solve", Tensor(const Tensor &, const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(dot), "dot", Tensor(const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(inverse), "inverse", Tensor(const Tensor &), fp32)
KERNEL_CPU(ADD_NS(lu_solve), "lu_solve", Tensor(const Tensor &, const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(matrix_rank), "matrix_rank", Tensor(const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(orgqr), "orgqr", Tensor(const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(ormqr), "ormqr", Tensor(const Tensor &, const Tensor &, const Tensor &, bool, bool), fp32)
KERNEL_CPU(ADD_NS(pinverse), "pinverse", Tensor(const Tensor &, double), fp32)
KERNEL_CPU(ADD_NS(vdot), "vdot", Tensor(const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(im2col), "im2col", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(col2im), "col2im", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(max_pool3d), "max_pool3d", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool), fp32)
KERNEL_CPU(ADD_NS(max_unpool2d), "max_unpool2d", Tensor(const Tensor &, const Tensor &, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(max_unpool3d), "max_unpool3d", Tensor(const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef), fp32)
@ -583,17 +571,9 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(replication_pad2d), "replication_pad2d", Tensor(const Tensor &, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(replication_pad3d), "replication_pad3d", Tensor(const Tensor &, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(elu), "elu", Tensor(const Tensor &, const Scalar &, const Scalar &, const Scalar &), fp32)
KERNEL_CPU(ADD_NS(hardshrink), "hardshrink", Tensor(const Tensor &, const Scalar &), fp32)
KERNEL_CPU(ADD_NS(hardsigmoid), "hardsigmoid", Tensor(const Tensor &), fp32)
KERNEL_CPU(ADD_NS(hardswish), "hardswish", Tensor(const Tensor &), fp32)
KERNEL_CPU(ADD_NS(log_sigmoid), "log_sigmoid", Tensor(const Tensor &), fp32)
KERNEL_CPU(ADD_NS(prelu), "prelu", Tensor(const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(selu), "selu", Tensor(const Tensor &), fp32)
KERNEL_CPU(ADD_NS(celu), "celu", Tensor(const Tensor &, const Scalar &), fp32)
KERNEL_CPU(ADD_NS(softplus), "softplus", Tensor(const Tensor &, const Scalar &, const Scalar &), fp32)
KERNEL_CPU(ADD_NS(softshrink), "softshrink", Tensor(const Tensor &, const Scalar &), fp32)
KERNEL_CPU(ADD_NS(group_norm), "group_norm", Tensor(const Tensor &, int64_t, const c10::optional<Tensor> &, const c10::optional<Tensor> &, double, bool), fp32)
KERNEL_CPU(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
KERNEL_CPU(ADD_NS(mse_loss), "mse_loss", Tensor(const Tensor &, const Tensor &, int64_t), fp32)
KERNEL_CPU(ADD_NS(ctc_loss), "ctc_loss.IntList", Tensor(const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t, int64_t, bool), fp32)
KERNEL_CPU(ADD_NS(ctc_loss), "ctc_loss.Tensor", Tensor(const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
@ -620,6 +600,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(linalg_cond), "linalg_cond.p_str", Tensor(const Tensor &, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank", Tensor(const Tensor &, double, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.tol_tensor", Tensor(const Tensor &, const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.atol_rtol_tensor", Tensor(const Tensor &, const c10::optional<at::Tensor> &, const c10::optional<at::Tensor> &, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.atol_rtol_float", Tensor(const Tensor &, c10::optional<double>, c10::optional<double>, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_solve), "linalg_solve", Tensor(const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(linalg_cholesky), "linalg_cholesky", Tensor(const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(linalg_svdvals), "linalg_svdvals", Tensor(const Tensor &), fp32)
@ -632,30 +614,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(fake_quantize_per_tensor_affine), "fake_quantize_per_tensor_affine", Tensor (const Tensor &, double, int64_t, int64_t, int64_t), fp32)
KERNEL_CPU(ADD_NS(glu), "glu", Tensor (const Tensor &, int64_t), fp32)
m.impl(TORCH_SELECTIVE_NAME("aten::cummax"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor, Tensor> (const Tensor &, int64_t),
std::tuple<Tensor, Tensor> (const Tensor &, int64_t),
&ADD_NS(cummax)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::cummax.dimname"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor, Tensor> (const Tensor &, at::Dimname),
std::tuple<Tensor, Tensor> (const Tensor &, at::Dimname),
&ADD_NS(cummax)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::cummin"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor, Tensor> (const Tensor &, int64_t),
std::tuple<Tensor, Tensor> (const Tensor &, int64_t),
&ADD_NS(cummin)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::cummin.dimname"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor, Tensor> (const Tensor &, at::Dimname),
std::tuple<Tensor, Tensor> (const Tensor &, at::Dimname),
&ADD_NS(cummin)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::eig"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor, Tensor> (const Tensor &, bool),
@ -680,11 +638,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
std::tuple<Tensor, Tensor, Tensor> (const Tensor &, bool, bool),
&ADD_NS(_lu_with_info)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::lu_unpack"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor, Tensor, Tensor> (const Tensor &, const Tensor &, bool, bool),
std::tuple<Tensor, Tensor, Tensor> (const Tensor &, const Tensor &, bool, bool),
&ADD_NS(lu_unpack)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::qr"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
@ -722,17 +675,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &),
&ADD_NS(fractional_max_pool3d)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool1d"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
&ADD_NS(adaptive_max_pool1d)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool2d"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
std::tuple<Tensor, Tensor> (const Tensor &, IntArrayRef),
&ADD_NS(adaptive_max_pool2d)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool3d"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,

View File

@ -266,6 +266,7 @@ CPU Ops that can autocast to ``bfloat16``
``addmm``,
``addbmm``,
``linear``,
``matmul``,
``_convolution``
CPU Ops that can autocast to ``float32``
@ -275,11 +276,7 @@ CPU Ops that can autocast to ``float32``
``conv_transpose2d``,
``conv_transpose3d``,
``batch_norm``,
``dropout``,
``avg_pool1d``,
``avg_pool2d``,
``avg_pool3d``,
``gelu``,
``upsample_nearest1d``,
``_upsample_nearest_exact1d``,
``upsample_nearest2d``,
@ -290,12 +287,12 @@ CPU Ops that can autocast to ``float32``
``upsample_bilinear2d``,
``upsample_trilinear3d``,
``binary_cross_entropy``,
``binary_cross_entropy_with_logits``,
``instance_norm``,
``grid_sampler``,
``grid_sampler_2d``,
``_grid_sampler_2d_cpu_fallback``,
``grid_sampler_3d``,
``polar``,
``multinomial``,
``poisson``,
``fmod``,
``prod``,
``quantile``,
@ -309,16 +306,12 @@ CPU Ops that can autocast to ``float32``
``diagflat``,
``histc``,
``logcumsumexp``,
``searchsorted``,
``trace``,
``tril``,
``triu``,
``vander``,
``view_as_complex``,
``cholesky``,
``cholesky_inverse``,
``cholesky_solve``,
``dot``,
``inverse``,
``lu_solve``,
``matrix_rank``,
@ -326,9 +319,6 @@ CPU Ops that can autocast to ``float32``
``inverse``,
``ormqr``,
``pinverse``,
``vdot``,
``im2col``,
``col2im``,
``max_pool3d``,
``max_unpool2d``,
``max_unpool3d``,
@ -339,17 +329,9 @@ CPU Ops that can autocast to ``float32``
``replication_pad2d``,
``replication_pad3d``,
``elu``,
``hardshrink``,
``hardsigmoid``,
``hardswish``,
``log_sigmoid``,
``prelu``,
``selu``,
``celu``,
``softplus``,
``softshrink``,
``group_norm``,
``smooth_l1_loss``,
``mse_loss``,
``ctc_loss``,
``kl_div``,
@ -383,13 +365,10 @@ CPU Ops that can autocast to ``float32``
``linalg_tensorsolve``,
``fake_quantize_per_tensor_affine``,
``glu``,
``cummax``,
``cummin``,
``eig``,
``geqrf``,
``lstsq``,
``_lu_with_info``,
``lu_unpack``,
``qr``,
``solve``,
``svd``,
@ -397,8 +376,6 @@ CPU Ops that can autocast to ``float32``
``triangular_solve``,
``fractional_max_pool2d``,
``fractional_max_pool3d``,
``adaptive_max_pool1d``,
``adaptive_max_pool2d``,
``adaptive_max_pool3d``,
``multilabel_margin_loss_forward``,
``linalg_qr``,

View File

@ -104,8 +104,9 @@ class TestAutocastCPU(TestCase):
self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
def test_autocast_nn_bf16(self):
# NOTE(review): this diff view has lost its +/- markers. Judging by the hunk
# header (-104,8 +104,9, i.e. net +1 line), the two-line loop directly below
# appears to be the REMOVED version: it unpacks every nn_bf16 entry as a plain
# (op, args) pair and forwards no keyword arguments.
for op, args in self.autocast_lists.nn_bf16:
self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn)
# NOTE(review): the three-line loop below appears to be the ADDED replacement.
# Each entry is first split by self.args_maybe_kwargs into op, positional args
# and kwargs, and the kwargs are passed on through add_kwargs, so nn_bf16
# entries may carry a trailing kwargs dict. Both variants run the op with
# torch.bfloat16 against the torch._C._nn module.
for op_with_args in self.autocast_lists.nn_bf16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
def test_autocast_torch_fp32(self):
for op_with_args in self.autocast_lists.torch_fp32:

View File

@ -302,6 +302,7 @@ class AutocastCPUTestLists(object):
("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("mm", mat0_fp32 + mat1_fp32),
("matmul", mat0_fp32 + mat1_fp32),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
@ -316,20 +317,15 @@ class AutocastCPUTestLists(object):
("batch_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32),
"running_var": torch.rand((n), dtype=torch.float32), "training": False,
"momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}),
("dropout", dummy_bf16[2], {"p": 0.1, "train": False}),
("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
("instance_norm", dummy_bf16[1], {"weight": None, "bias": None, "running_mean": None,
"running_var": None, "use_input_stats": True,
"momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}),
]
self.nn_bf16 = [
("linear", mat0_fp32 + mat1_fp32),
("linear", mat0_fp32 + mat1_fp32, {}),
]
self.nn_fp32 = [
("avg_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}),
("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
("gelu", dummy_bf16[3], {"approximate": 'none'}),
("gelu", dummy_bf16[3], {"approximate": 'tanh'}),
("upsample_nearest1d", dummy_bf16[2], {"output_size": (n)}),
("upsample_nearest2d", dummy_bf16[3], {"output_size": (n, n)}),
("upsample_nearest3d", dummy_bf16[4], {"output_size": (n, n, n)}),
@ -339,7 +335,6 @@ class AutocastCPUTestLists(object):
("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
(torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
("smooth_l1_loss", mat0_bf16 + mat1_bf16),
]
self.torch_need_autocast_promote = [
("cat", (pointwise0_bf16 + pointwise1_fp32,)),