Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Enable UFMT on test/scripts and some files (#124137)
Part of: #123062

Ran lintrunner on:

- `test/scripts`
- `test/simulate_nccl_errors.py`
- `test/test_ao_sparsity.py`
- `test/test_autocast.py`
- `test/test_binary_ufuncs.py`
- `test/test_bundled_images.py`
- `test/test_bundled_inputs.py`
- `test/test_comparison_utils.py`
- `test/test_compile_benchmark_util.py`
- `test/test_complex.py`
- `test/test_cpp_api_parity.py`
- `test/test_cpp_extensions_aot.py`
- `test/test_cpp_extensions_jit.py`
- `test/test_cpp_extensions_open_device_registration.py`

Detail:

```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124137
Approved by: https://github.com/soulitzer
parent f0560f7b3b
commit b3504af56e
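For context: UFMT is lintrunner's combined formatter (black for layout, usort for import ordering), so every hunk in this commit is mechanical — quote normalization, wrapping at black's 88-column limit, and import sorting, with no behavior change. A minimal sketch of the kind of rewrite it applies (illustrative code, not taken from this diff):

```python
# Illustrative only: the "before" form mirrors the pre-UFMT style seen in
# the hunks below; the live code is the post-UFMT form.
import argparse

parser = argparse.ArgumentParser()

# Before (single quotes, one long line):
#   parser.add_argument('--nproc', type=int, help='Number of processes running tests')
# After (double quotes, black-style wrapping once a line exceeds 88 columns):
parser.add_argument(
    "--nproc",
    type=int,
    help="Number of processes running tests",
)
print(parser.parse_args(["--nproc", "4"]).nproc)  # -> 4
```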
.lintrunner.toml

```diff
@@ -1051,21 +1051,6 @@ exclude_patterns = [
     'test/quantization/fx/test_numeric_suite_fx.py',
     'test/quantization/fx/test_quantize_fx.py',
     'test/quantization/fx/test_subgraph_rewriter.py',
-    'test/scripts/cuda_memcheck_common.py',
-    'test/scripts/run_cuda_memcheck.py',
-    'test/simulate_nccl_errors.py',
-    'test/test_ao_sparsity.py',
-    'test/test_autocast.py',
-    'test/test_binary_ufuncs.py',
-    'test/test_bundled_images.py',
-    'test/test_bundled_inputs.py',
-    'test/test_comparison_utils.py',
-    'test/test_compile_benchmark_util.py',
-    'test/test_complex.py',
-    'test/test_cpp_api_parity.py',
-    'test/test_cpp_extensions_aot.py',
-    'test/test_cpp_extensions_jit.py',
-    'test/test_cpp_extensions_open_device_registration.py',
     'test/test_cuda.py',
     'test/test_cuda_expandable_segments.py',
     'test/test_cuda_multigpu.py',
```
test/scripts/cuda_memcheck_common.py

```diff
@@ -1,8 +1,10 @@
 # this file contains a simple parser that parses report
 # from cuda-memcheck
 
+
 class ParseError(Exception):
     """Whenever the simple parser is unable to parse the report, this exception will be raised"""
+
     pass
 
 
@@ -77,25 +79,25 @@ def parse(message):
     ========= ERROR SUMMARY: 4 errors
     """
     errors = []
-    HEAD = '========='
+    HEAD = "========="
     headlen = len(HEAD)
     started = False
     in_message = False
     message_lines = []
     lines = message.splitlines()
     for l in lines:
-        if l == HEAD + ' CUDA-MEMCHECK':
+        if l == HEAD + " CUDA-MEMCHECK":
             started = True
             continue
         if not started or not l.startswith(HEAD):
             continue
-        l = l[headlen + 1:]
-        if l.startswith('ERROR SUMMARY:'):
+        l = l[headlen + 1 :]
+        if l.startswith("ERROR SUMMARY:"):
             return Report(l, errors)
         if not in_message:
             in_message = True
             message_lines = [l]
-        elif l == '':
+        elif l == "":
             errors.append(Error(message_lines))
             in_message = False
         else:
```
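For orientation, `parse()` above turns raw cuda-memcheck output into a `Report`. A hypothetical usage sketch — the `num_errors`, `errors`, and `stack` attributes are inferred from how `run_cuda_memcheck.py` consumes the result below, not from this hunk:

```python
# Sketch of driving cuda_memcheck_common.parse; assumes the test/scripts
# directory is importable and that Report/Error expose the attributes that
# run_cuda_memcheck.py reads (num_errors, errors, stack).
import cuda_memcheck_common as cmc

raw = """========= CUDA-MEMCHECK
========= Invalid __global__ read of size 4
=========
========= ERROR SUMMARY: 1 error"""

report = cmc.parse(raw)
print(report.num_errors)       # parsed from the "ERROR SUMMARY" line
for err in report.errors:      # one Error per "========="-delimited block
    print("".join(err.stack))  # run_cuda_memcheck.py greps this for libcublas etc.
```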
test/scripts/run_cuda_memcheck.py

```diff
@@ -12,39 +12,62 @@ Example usage:
 Note that running cuda-memcheck could be very slow.
 """
 
-import asyncio
-import torch
-import multiprocessing
 import argparse
-import subprocess
-import tqdm
+import asyncio
+import multiprocessing
 import os
+import subprocess
 import sys
+
 import cuda_memcheck_common as cmc
+import torch
+import tqdm
 
 ALL_TESTS = []
 GPUS = torch.cuda.device_count()
 
 # parse arguments
 parser = argparse.ArgumentParser(description="Run isolated cuda-memcheck on unit tests")
-parser.add_argument('filename', help="the python file for a test, such as test_torch.py")
-parser.add_argument('timeout', type=int, help='kill the test if it does not terminate in a certain amount of seconds')
-parser.add_argument('--strict', action='store_true',
-                    help='Whether to show cublas/cudnn errors. These errors are ignored by default because'
-                         'cublas/cudnn does not run error-free under cuda-memcheck, and ignoring these errors')
-parser.add_argument('--nproc', type=int, default=multiprocessing.cpu_count(),
-                    help='Number of processes running tests, default to number of cores in the system')
-parser.add_argument('--gpus', default='all',
-                    help='GPU assignments for each process, it could be "all", or : separated list like "1,2:3,4:5,6"')
-parser.add_argument('--ci', action='store_true',
-                    help='Whether this script is executed in CI. When executed inside a CI, this script fails when '
-                         'an error is detected. Also, it will not show tqdm progress bar, but directly print the error'
-                         'to stdout instead.')
-parser.add_argument('--nohang', action='store_true', help='Treat timeout as success')
-parser.add_argument('--split', type=int, default=1, help='Split the job into pieces')
-parser.add_argument('--rank', type=int, default=0, help='Which piece this process should pick')
+parser.add_argument(
+    "filename", help="the python file for a test, such as test_torch.py"
+)
+parser.add_argument(
+    "timeout",
+    type=int,
+    help="kill the test if it does not terminate in a certain amount of seconds",
+)
+parser.add_argument(
+    "--strict",
+    action="store_true",
+    help="Whether to show cublas/cudnn errors. These errors are ignored by default because"
+    "cublas/cudnn does not run error-free under cuda-memcheck, and ignoring these errors",
+)
+parser.add_argument(
+    "--nproc",
+    type=int,
+    default=multiprocessing.cpu_count(),
+    help="Number of processes running tests, default to number of cores in the system",
+)
+parser.add_argument(
+    "--gpus",
+    default="all",
+    help='GPU assignments for each process, it could be "all", or : separated list like "1,2:3,4:5,6"',
+)
+parser.add_argument(
+    "--ci",
+    action="store_true",
+    help="Whether this script is executed in CI. When executed inside a CI, this script fails when "
+    "an error is detected. Also, it will not show tqdm progress bar, but directly print the error"
+    "to stdout instead.",
+)
+parser.add_argument("--nohang", action="store_true", help="Treat timeout as success")
+parser.add_argument("--split", type=int, default=1, help="Split the job into pieces")
+parser.add_argument(
+    "--rank", type=int, default=0, help="Which piece this process should pick"
+)
 args = parser.parse_args()
 
 
 # Filters that ignores cublas/cudnn errors
 # TODO (@zasdfgbnm): When can we remove this? Will cublas/cudnn run error-free under cuda-memcheck?
 def is_ignored_only(output):
@@ -56,32 +79,43 @@ def is_ignored_only(output):
         return False
     count_ignored_errors = 0
     for e in report.errors:
-        if 'libcublas' in ''.join(e.stack) or 'libcudnn' in ''.join(e.stack) or 'libcufft' in ''.join(e.stack):
+        if (
+            "libcublas" in "".join(e.stack)
+            or "libcudnn" in "".join(e.stack)
+            or "libcufft" in "".join(e.stack)
+        ):
             count_ignored_errors += 1
     return count_ignored_errors == report.num_errors
 
 
 # Set environment PYTORCH_CUDA_MEMCHECK=1 to allow skipping some tests
-os.environ['PYTORCH_CUDA_MEMCHECK'] = '1'
+os.environ["PYTORCH_CUDA_MEMCHECK"] = "1"
 
 # Discover tests:
 # To get a list of tests, run:
 # pytest --setup-only test/test_torch.py
 # and then parse the output
-proc = subprocess.Popen(['pytest', '--setup-only', args.filename], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+proc = subprocess.Popen(
+    ["pytest", "--setup-only", args.filename],
+    stdout=subprocess.PIPE,
+    stderr=subprocess.PIPE,
+)
 stdout, stderr = proc.communicate()
 lines = stdout.decode().strip().splitlines()
 for line in lines:
-    if '(fixtures used:' in line:
+    if "(fixtures used:" in line:
         line = line.strip().split()[0]
-        line = line[line.find('::') + 2:]
-        line = line.replace('::', '.')
+        line = line[line.find("::") + 2 :]
+        line = line.replace("::", ".")
         ALL_TESTS.append(line)
 
 
 # Do a simple filtering:
 # if 'cpu' or 'CPU' is in the name and 'cuda' or 'CUDA' is not in the name, then skip it
 def is_cpu_only(name):
     name = name.lower()
-    return ('cpu' in name) and "cuda" not in name
+    return ("cpu" in name) and "cuda" not in name
 
 
 ALL_TESTS = [x for x in ALL_TESTS if not is_cpu_only(x)]
@@ -101,7 +135,7 @@ ALL_TESTS = ALL_TESTS[start:end]
 # or as specified by the user
 progress = 0
 if not args.ci:
-    logfile = open('result.log', 'w')
+    logfile = open("result.log", "w")
     progressbar = tqdm.tqdm(total=len(ALL_TESTS))
 else:
     logfile = sys.stdout
@@ -110,53 +144,61 @@ else:
     class ProgressbarStub:
         def update(self, *args):
             return
+
     progressbar = ProgressbarStub()
 
 
 async def run1(coroutine_id):
     global progress
 
-    if args.gpus == 'all':
+    if args.gpus == "all":
         gpuid = coroutine_id % GPUS
     else:
-        gpu_assignments = args.gpus.split(':')
-        assert args.nproc == len(gpu_assignments), 'Please specify GPU assignment for each process, separated by :'
+        gpu_assignments = args.gpus.split(":")
+        assert args.nproc == len(
+            gpu_assignments
+        ), "Please specify GPU assignment for each process, separated by :"
         gpuid = gpu_assignments[coroutine_id]
 
     while progress < len(ALL_TESTS):
         test = ALL_TESTS[progress]
         progress += 1
-        cmd = f'CUDA_VISIBLE_DEVICES={gpuid} cuda-memcheck --error-exitcode 1 python {args.filename} {test}'
-        proc = await asyncio.create_subprocess_shell(cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
+        cmd = f"CUDA_VISIBLE_DEVICES={gpuid} cuda-memcheck --error-exitcode 1 python {args.filename} {test}"
+        proc = await asyncio.create_subprocess_shell(
+            cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
+        )
         try:
             stdout, stderr = await asyncio.wait_for(proc.communicate(), args.timeout)
         except asyncio.TimeoutError:
-            print('Timeout:', test, file=logfile)
+            print("Timeout:", test, file=logfile)
             proc.kill()
             if args.ci and not args.nohang:
                 sys.exit("Hang detected on cuda-memcheck")
         else:
             if proc.returncode == 0:
-                print('Success:', test, file=logfile)
+                print("Success:", test, file=logfile)
             else:
                 stdout = stdout.decode()
                 stderr = stderr.decode()
                 should_display = args.strict or not is_ignored_only(stdout)
                 if should_display:
-                    print('Fail:', test, file=logfile)
+                    print("Fail:", test, file=logfile)
                     print(stdout, file=logfile)
                    print(stderr, file=logfile)
                     if args.ci:
                         sys.exit("Failure detected on cuda-memcheck")
                 else:
-                    print('Ignored:', test, file=logfile)
+                    print("Ignored:", test, file=logfile)
         del proc
         progressbar.update(1)
 
 
 async def main():
     tasks = [asyncio.ensure_future(run1(i)) for i in range(args.nproc)]
     for t in tasks:
         await t
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     loop = asyncio.get_event_loop()
     loop.run_until_complete(main())
```
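`run1()`/`main()` above form a small coroutine worker pool: `args.nproc` workers share a `progress` cursor into `ALL_TESTS` (safe without locks because coroutines run on one thread) and each shells out once per test. A self-contained sketch of the same pattern, with `echo` standing in for cuda-memcheck:

```python
# Stand-alone sketch of the worker-pool pattern in run_cuda_memcheck.py;
# "echo" stands in for the real cuda-memcheck invocation.
import asyncio

TESTS = [f"test_{i}" for i in range(8)]
progress = 0  # shared cursor; the single-threaded event loop makes this safe

async def worker(worker_id: int) -> None:
    global progress
    while progress < len(TESTS):
        test = TESTS[progress]
        progress += 1
        proc = await asyncio.create_subprocess_shell(
            f"echo worker {worker_id} ran {test}",
            stdout=asyncio.subprocess.PIPE,
        )
        stdout, _ = await proc.communicate()
        print(stdout.decode().strip())

async def main() -> None:
    await asyncio.gather(*(worker(i) for i in range(4)))

asyncio.run(main())
```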
test/simulate_nccl_errors.py

```diff
@@ -1,22 +1,26 @@
-import torch.distributed as c10d
-import torch
 import argparse
-import os
 import logging
-logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
+import os
+
+import torch
+import torch.distributed as c10d
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='Simple script to simulate NCCL errors. The script is '
-        'supposed to be run on multiple different nodes simultaneously with '
-        'appropriate rank and world_size. The script run an allreduce() on '
-        'the rank 0 node and aborts all the other nodes to simulate an error '
-        'in NCCL')
-    parser.add_argument('addr', help='address of the master node to connect to.')
-    parser.add_argument('port', help='port of the master node to connect to.')
-    parser.add_argument('rank', help='rank of this node')
-    parser.add_argument('world_size', help='number of nodes in process group')
+        description="Simple script to simulate NCCL errors. The script is "
+        "supposed to be run on multiple different nodes simultaneously with "
+        "appropriate rank and world_size. The script run an allreduce() on "
+        "the rank 0 node and aborts all the other nodes to simulate an error "
+        "in NCCL"
+    )
+    parser.add_argument("addr", help="address of the master node to connect to.")
+    parser.add_argument("port", help="port of the master node to connect to.")
+    parser.add_argument("rank", help="rank of this node")
+    parser.add_argument("world_size", help="number of nodes in process group")
     args = parser.parse_args()
     rank = int(args.rank)
     world_size = int(args.world_size)
@@ -24,14 +28,14 @@ if __name__ == "__main__":
 
     store = c10d.TCPStore(args.addr, port, world_size, rank == 0)
     process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
-    logging.info('Running first allreduce')
+    logging.info("Running first allreduce")
     process_group.allreduce(torch.rand(10).cuda(rank)).wait()
     if rank == 0:
-        logging.info('Running second allreduce only on rank 0')
+        logging.info("Running second allreduce only on rank 0")
         work = process_group.allreduce(torch.rand(10).cuda(rank))
-        logging.info('Waiting for allreduce to complete...')
+        logging.info("Waiting for allreduce to complete...")
         work.wait()
-        logging.info('Second allreduce successful: %s', work.is_success())
+        logging.info("Second allreduce successful: %s", work.is_success())
     else:
-        logging.info('Aborting all other ranks.')
+        logging.info("Aborting all other ranks.")
         os.abort()
```
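The script rendezvouses through a `c10d.TCPStore` before constructing the NCCL process group. A single-process sketch of the store API it relies on (address and port are placeholders):

```python
# Minimal TCPStore round-trip; with world_size=1 this process is the master.
# simulate_nccl_errors.py shares one such store across all ranks.
import torch.distributed as c10d

# args: host, port, world_size, is_master (same positional form as the script)
store = c10d.TCPStore("127.0.0.1", 29500, 1, True)
store.set("status", "ready")
print(store.get("status"))  # b'ready'
```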
test/test_ao_sparsity.py

```diff
@@ -1,46 +1,59 @@
 # Owner(s): ["module: unknown"]
 
-from torch.testing._internal.common_utils import run_tests, IS_ARM64
-
 # Kernels
-from ao.sparsity.test_kernels import TestQuantizedSparseKernels  # noqa: F401
-from ao.sparsity.test_kernels import TestQuantizedSparseLayers  # noqa: F401
+from ao.sparsity.test_kernels import (  # noqa: F401 # noqa: F401
+    TestQuantizedSparseKernels,
+    TestQuantizedSparseLayers,
+)
 
 # Parametrizations
 from ao.sparsity.test_parametrization import TestFakeSparsity  # noqa: F401
 
+# Scheduler
+from ao.sparsity.test_scheduler import (  # noqa: F401 # noqa: F401
+    TestCubicScheduler,
+    TestScheduler,
+)
+
 # Sparsifier
-from ao.sparsity.test_sparsifier import TestBaseSparsifier  # noqa: F401
-from ao.sparsity.test_sparsifier import TestWeightNormSparsifier  # noqa: F401
-from ao.sparsity.test_sparsifier import TestNearlyDiagonalSparsifier  # noqa: F401
+from ao.sparsity.test_sparsifier import (  # noqa: F401 # noqa: F401 # noqa: F401
+    TestBaseSparsifier,
+    TestNearlyDiagonalSparsifier,
+    TestWeightNormSparsifier,
+)
 
 # Structured Pruning
-from ao.sparsity.test_structured_sparsifier import TestBaseStructuredSparsifier  # noqa: F401
-from ao.sparsity.test_structured_sparsifier import TestSaliencyPruner  # noqa: F401
-from ao.sparsity.test_structured_sparsifier import TestFPGMPruner  # noqa: F401
-
-# Scheduler
-from ao.sparsity.test_scheduler import TestScheduler  # noqa: F401
-from ao.sparsity.test_scheduler import TestCubicScheduler  # noqa: F401
+from ao.sparsity.test_structured_sparsifier import (  # noqa: F401 # noqa: F401 # noqa: F401
+    TestBaseStructuredSparsifier,
+    TestFPGMPruner,
+    TestSaliencyPruner,
+)
+from torch.testing._internal.common_utils import IS_ARM64, run_tests
 
 # Composability
 if not IS_ARM64:
-    from ao.sparsity.test_composability import TestComposability  # noqa: F401
-    from ao.sparsity.test_composability import TestFxComposability  # noqa: F401
+    from ao.sparsity.test_composability import (  # noqa: F401 # noqa: F401
+        TestComposability,
+        TestFxComposability,
+    )
 
-# Utilities
-from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions  # noqa: F401
-
-# Data Sparsifier
-from ao.sparsity.test_data_sparsifier import TestBaseDataSparsifier  # noqa: F401
-from ao.sparsity.test_data_sparsifier import TestNormDataSparsifiers  # noqa: F401
-from ao.sparsity.test_data_sparsifier import TestQuantizationUtils  # noqa: F401
+# Activation Sparsifier
+from ao.sparsity.test_activation_sparsifier import (  # noqa: F401
+    TestActivationSparsifier,
+)
 
 # Data Scheduler
 from ao.sparsity.test_data_scheduler import TestBaseDataScheduler  # noqa: F401
 
-# Activation Sparsifier
-from ao.sparsity.test_activation_sparsifier import TestActivationSparsifier  # noqa: F401
+# Data Sparsifier
+from ao.sparsity.test_data_sparsifier import (  # noqa: F401 # noqa: F401 # noqa: F401
+    TestBaseDataSparsifier,
+    TestNormDataSparsifiers,
+    TestQuantizationUtils,
+)
+
+# Utilities
+from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions  # noqa: F401
 
 if __name__ == "__main__":
     run_tests()
```
test/test_autocast.py

```diff
@@ -4,14 +4,20 @@ import collections
 import unittest
 
 import torch
-from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfTorchDynamo
 from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+    run_tests,
+    skipIfTorchDynamo,
+    TestCase,
+)
 from torch.utils._python_dispatch import TorchDispatchMode
 
 
 class TestAutocastCPU(TestCase):
     def setUp(self):
         super().setUp()
-        self.autocast_lists = AutocastCPUTestLists(torch.device('cpu'))
+        self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))
 
     def tearDown(self):
         del self.autocast_lists
@@ -49,18 +55,23 @@ class TestAutocastCPU(TestCase):
         if module is not None and hasattr(module, op):
             output = getattr(module, op)(*args, **add_kwargs)
             if isinstance(output, torch.Tensor):
-                self.assertTrue(out_type == output.dtype,
-                                f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}")
+                self.assertTrue(
+                    out_type == output.dtype,
+                    f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
+                )
         # Try Tensor.* variant:
         if hasattr(torch.Tensor, op):
             output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
             if isinstance(output_method, torch.Tensor):
-                self.assertTrue(out_type == output_method.dtype,
-                                "autocast for torch.{} produced {}, should produce torch.{}"
-                                .format(op, output_method.dtype, out_type))
+                self.assertTrue(
+                    out_type == output_method.dtype,
+                    f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
+                )
 
-        self.assertTrue((output is not None) or (output_method is not None),
-                        f"{op} not found as an attribute on either Tensor or the requested module {module}")
+        self.assertTrue(
+            (output is not None) or (output_method is not None),
+            f"{op} not found as an attribute on either Tensor or the requested module {module}",
+        )
 
         # Accounts for ops that return Tensors, iterables, and other non-Tensors.
         # For example, lstm_cell returns a tuple and equal returns bool.
@@ -76,7 +87,9 @@ class TestAutocastCPU(TestCase):
         if (output is not None) and (output_method is not None):
             self.assertTrue(type(output) == type(output_method))
             comparison = compare(output, output_method)
-            self.assertTrue(comparison, f"torch.{op} result did not match Tensor.{op} result")
+            self.assertTrue(
+                comparison, f"torch.{op} result did not match Tensor.{op} result"
+            )
 
         # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
         # as the C++-side autocasting, and should be bitwise accurate.
@@ -85,9 +98,13 @@ class TestAutocastCPU(TestCase):
             self.assertFalse(torch.is_autocast_cpu_enabled())
 
             if module is not None and hasattr(module, op):
-                control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
+                control = getattr(module, op)(
+                    *cast(args, run_as_type), **add_kwargs
+                )
             else:
-                control = getattr(args[0].to(run_as_type), op)(*cast(args[1:], run_as_type), **add_kwargs)
+                control = getattr(args[0].to(run_as_type), op)(
+                    *cast(args[1:], run_as_type), **add_kwargs
+                )
             self.assertTrue(type(output_to_compare) == type(control))
             comparison = compare(output_to_compare, control)
             self.assertTrue(comparison, f"torch.{op} result did not match control")
@@ -102,22 +119,51 @@ class TestAutocastCPU(TestCase):
 
     @skipIfTorchDynamo()
     def test_autocast_torch_expect_builtin_promote(self):
-        for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
+        for (
+            op,
+            args1,
+            args2,
+            out_type,
+        ) in self.autocast_lists.torch_expect_builtin_promote:
             self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
-            self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)
+            self._run_autocast_outofplace(
+                op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16
+            )
 
     @skipIfTorchDynamo()
     def test_autocast_methods_expect_builtin_promote(self):
-        for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
-            self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)
+        for (
+            op,
+            args1,
+            args2,
+            out_type,
+        ) in self.autocast_lists.methods_expect_builtin_promote:
+            self._run_autocast_outofplace(
+                op, args1, torch.float32, module=None, out_type=out_type
+            )
+            self._run_autocast_outofplace(
+                op,
+                args2,
+                torch.float32,
+                module=None,
+                out_type=out_type,
+                amp_dtype=torch.float16,
+            )
 
     @skipIfTorchDynamo()
     def test_autocast_torch_16(self):
         for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
-            self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
+            self._run_autocast_outofplace(
+                op, args, torch.bfloat16, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float16,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     @skipIfTorchDynamo()
     def test_autocast_nn_16(self):
@@ -139,8 +185,16 @@ class TestAutocastCPU(TestCase):
     def test_autocast_torch_fp32(self):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
-            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
+            self._run_autocast_outofplace(
+                op, args, torch.float32, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     @skipIfTorchDynamo()
     def test_autocast_nn_fp32(self):
@@ -162,11 +216,16 @@ class TestAutocastCPU(TestCase):
     def test_autocast_torch_need_autocast_promote(self):
         for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
             self._run_autocast_outofplace(op, args1, torch.float32)
-            self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)
+            self._run_autocast_outofplace(
+                op, args2, torch.float32, amp_dtype=torch.float16
+            )
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_autocast_rnn(self):
-        if torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported():
+        if (
+            torch.backends.mkldnn.is_available()
+            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
+        ):
             x = torch.randn(1, 2, 1)
             hx = torch.randn(2, 2, 1)
             cx = torch.randn(2, 2, 1)
@@ -182,9 +241,10 @@ class TestAutocastCPU(TestCase):
             m(x, (hx, cx))
 
     def test_autocast_disabled_with_fp32_dtype(self):
-        with torch.autocast(device_type='cpu', dtype=torch.float32, enabled=False):
+        with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
             _ = torch.ones(10)
 
+
 class CustomLinear(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, w_t):
@@ -194,13 +254,13 @@ class CustomLinear(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         x, w_t = ctx.saved_tensors
-        with torch.autocast(device_type='cuda'):
+        with torch.autocast(device_type="cuda"):
             dL_dX = torch.matmul(grad_output, w_t)
             dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
         return dL_dX, dL_dW
 
-class WeightDTypeCastCounterMode(TorchDispatchMode):
 
+class WeightDTypeCastCounterMode(TorchDispatchMode):
     def __init__(self, weight):
         super().__init__()
         self.dtype_cast_counter = 0
@@ -208,9 +268,9 @@ class WeightDTypeCastCounterMode(TorchDispatchMode):
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if (
-            func is torch.ops.aten._to_copy.default and
-            args[0] is self.weight and
-            kwargs['dtype'] is torch.float16
+            func is torch.ops.aten._to_copy.default
+            and args[0] is self.weight
+            and kwargs["dtype"] is torch.float16
         ):
             self.dtype_cast_counter += 1
         return func(*args, **kwargs)
@@ -224,6 +284,7 @@ class WeightDTypeCastCounterMode(TorchDispatchMode):
         torch.clear_autocast_cache = self.old_clear_cache
         return super().__exit__(exc_type, exc_val, exc_tb)
 
+
 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
 class TestAutocastGPU(TestCase):
     def test_cast_cache_is_global(self):
@@ -238,7 +299,7 @@ class TestAutocastGPU(TestCase):
         weight = torch.nn.Parameter(torch.randn(4, 3).cuda())
 
         with WeightDTypeCastCounterMode(weight) as mode:
-            with torch.autocast(device_type='cuda'):
+            with torch.autocast(device_type="cuda"):
                 output = CustomLinear.apply(data, weight)
                 s = output.sum()
             s.backward()
@@ -246,7 +307,6 @@ class TestAutocastGPU(TestCase):
         self.assertEqual(mode.dtype_cast_counter, 1)
 
     def test_cache_disabled(self):
-
         data = torch.randn(2, 3).cuda()
         weight = torch.nn.Parameter(torch.randn(4, 3).cuda())
 
@@ -255,7 +315,7 @@ class TestAutocastGPU(TestCase):
             torch._C._add_cached_tensor(weight)
 
             with WeightDTypeCastCounterMode(weight) as mode:
-                with torch.autocast(device_type='cuda'):
+                with torch.autocast(device_type="cuda"):
                     output = CustomLinear.apply(data, weight)
                     s = output.sum()
                 s.backward()
@@ -275,12 +335,12 @@ class TestTorchAutocast(TestCase):
         self.assertEqual(cpu_fast_dtype, torch.bfloat16)
 
     def test_invalid_device(self):
-        dev = 'not a real device'
-        msg = f'unsupported autocast device_type \'{dev}\''
+        dev = "not a real device"
+        msg = f"unsupported autocast device_type '{dev}'"
         with self.assertRaisesRegex(RuntimeError, msg):
             with torch.autocast(device_type=dev):
                 _ = torch.tensor(1)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     run_tests()
```
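The behavior these tests pin down is easiest to see in isolation: under CPU autocast, eligible ops run in the lower-precision dtype inside the context and in float32 outside it. A minimal, runnable illustration:

```python
# CPU autocast in isolation: matmul runs in bfloat16 inside the context.
import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = torch.mm(a, b)
print(out.dtype)             # torch.bfloat16
print(torch.mm(a, b).dtype)  # torch.float32 outside the context
```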
test/test_binary_ufuncs.py

```diff
@@ -1,75 +1,75 @@
 # Owner(s): ["module: tests"]
 
-import torch
-import numpy as np
-
 import itertools
-from itertools import chain
-from itertools import product
 import math
-import random
-from numbers import Number
-import warnings
 import operator
+import random
+import warnings
 from functools import partial
+from itertools import chain, product
+from numbers import Number
+
+import numpy as np
+
+import torch
 import torch.autograd.forward_ad as fwAD
 from torch import inf, nan
-from torch.testing._internal.common_utils import (
-    TestCase,
-    slowTest,
-    iter_indices,
-    run_tests,
-    gradcheck,
-    torch_to_numpy_dtype_dict,
-    numpy_to_torch_dtype_dict,
-    TEST_SCIPY,
-    set_default_dtype,
-    skipIfTorchDynamo,
-)
+from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
+    deviceCountAtLeast,
+    dtypes,
+    dtypesIfCPU,
+    dtypesIfCUDA,
     expectedFailureMeta,
     instantiate_device_type_tests,
-    onlyCUDA,
     onlyCPU,
-    dtypes,
-    dtypesIfCUDA,
-    dtypesIfCPU,
-    deviceCountAtLeast,
-    precisionOverride,
+    onlyCUDA,
     onlyNativeDeviceTypes,
-    skipIf,
-    ops,
     OpDTypes,
+    ops,
+    precisionOverride,
+    skipIf,
     skipMeta,
 )
-from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
-    all_types_and_complex_and,
     all_types_and,
-    integral_types,
+    all_types_and_complex_and,
     complex_types,
-    integral_types_and,
-    floating_types_and,
     floating_and_complex_types,
-    get_all_math_dtypes,
+    floating_types_and,
     get_all_int_dtypes,
+    get_all_math_dtypes,
+    integral_types,
+    integral_types_and,
 )
 from torch.testing._internal.common_methods_invocations import (
     binary_ufuncs,
     binary_ufuncs_and_refs,
-    generate_elementwise_binary_tensors,
-    generate_elementwise_binary_small_value_tensors,
-    generate_elementwise_binary_large_value_tensors,
-    generate_elementwise_binary_extremal_value_tensors,
     generate_elementwise_binary_broadcasting_tensors,
-    generate_elementwise_binary_with_scalar_samples,
+    generate_elementwise_binary_extremal_value_tensors,
+    generate_elementwise_binary_large_value_tensors,
+    generate_elementwise_binary_small_value_tensors,
+    generate_elementwise_binary_tensors,
     generate_elementwise_binary_with_scalar_and_type_promotion_samples,
+    generate_elementwise_binary_with_scalar_samples,
 )
+from torch.testing._internal.common_utils import (
+    gradcheck,
+    iter_indices,
+    numpy_to_torch_dtype_dict,
+    run_tests,
+    set_default_dtype,
+    skipIfTorchDynamo,
+    slowTest,
+    TEST_SCIPY,
+    TestCase,
+    torch_to_numpy_dtype_dict,
+)
 
 if TEST_SCIPY:
-    import scipy.special
     import scipy.integrate
+    import scipy.special
 
 
 # TODO: update to use opinfos consistently
 class TestBinaryUfuncs(TestCase):
@@ -269,7 +269,6 @@ class TestBinaryUfuncs(TestCase):
         )
         self._test_reference_numerics(dtype, op, gen, equal_nan=True)
 
-
     @ops(binary_ufuncs)
     def test_contig_vs_every_other(self, device, dtype, op):
         lhs = make_tensor(
@@ -487,7 +486,7 @@ class TestBinaryUfuncs(TestCase):
         )
 
         make_rhs_scalar_tensor = partial(
-            make_tensor, (), device='cpu', **op.rhs_make_tensor_kwargs
+            make_tensor, (), device="cpu", **op.rhs_make_tensor_kwargs
         )
 
         def _supported(dtypes):
@@ -777,10 +776,14 @@ class TestBinaryUfuncs(TestCase):
         # scalar x scalar
         # Note: result dtype is default float type
         if op.supports_two_python_scalars and _supported((torch.long, torch.float32)):
-            rhs_f_scalar = 2.
-            for lhs in (1, 1.):
+            rhs_f_scalar = 2.0
+            for lhs in (1, 1.0):
                 result = op(lhs, rhs_f_scalar)
-                expected_dtype = torch.get_default_dtype() if not op.always_returns_bool else torch.bool
+                expected_dtype = (
+                    torch.get_default_dtype()
+                    if not op.always_returns_bool
+                    else torch.bool
+                )
                 self.assertEqual(result.dtype, expected_dtype)
 
         # TODO: move to error input test
@@ -966,7 +969,6 @@ class TestBinaryUfuncs(TestCase):
 
     @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
     def test_div_rounding_nonfinite(self, device, dtype):
-
         # Compare division of special floating point values against NumPy
         num = torch.tensor(
             [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
@@ -1088,21 +1090,27 @@ class TestBinaryUfuncs(TestCase):
         # NOTE: the calculation still produces an error if the number is greater than
         # finfo.max / 2, but hopefully people realized that it's a dangerous region to work with
         finfo = torch.finfo(dtype)
-        nom_lst = [complex(finfo.min / 2, finfo.min / 2),
-                   complex(finfo.max / 2, finfo.max / 2),
-                   complex(finfo.tiny, finfo.tiny),
-                   complex(finfo.tiny, 0.0),
-                   complex(0.0, 0.0)]
-        denom_lst = [complex(finfo.min / 2, finfo.min / 2),
-                     complex(finfo.max / 2, finfo.max / 2),
-                     complex(finfo.tiny, finfo.tiny),
-                     complex(0.0, finfo.tiny),
-                     complex(finfo.tiny, finfo.tiny)]
-        expected_lst = [complex(1.0, 0.0),
-                        complex(1.0, 0.0),
-                        complex(1.0, 0.0),
-                        complex(0.0, -1.0),
-                        complex(0.0, 0.0)]
+        nom_lst = [
+            complex(finfo.min / 2, finfo.min / 2),
+            complex(finfo.max / 2, finfo.max / 2),
+            complex(finfo.tiny, finfo.tiny),
+            complex(finfo.tiny, 0.0),
+            complex(0.0, 0.0),
+        ]
+        denom_lst = [
+            complex(finfo.min / 2, finfo.min / 2),
+            complex(finfo.max / 2, finfo.max / 2),
+            complex(finfo.tiny, finfo.tiny),
+            complex(0.0, finfo.tiny),
+            complex(finfo.tiny, finfo.tiny),
+        ]
+        expected_lst = [
+            complex(1.0, 0.0),
+            complex(1.0, 0.0),
+            complex(1.0, 0.0),
+            complex(0.0, -1.0),
+            complex(0.0, 0.0),
+        ]
         nom = torch.tensor(nom_lst, dtype=dtype, device=device)
         denom = torch.tensor(denom_lst, dtype=dtype, device=device)
         expected = torch.tensor(expected_lst, dtype=dtype, device=device)
@@ -1146,7 +1154,10 @@ class TestBinaryUfuncs(TestCase):
         # test that multi-d out doesn't trigger segfault
         arg1 = (torch.ones(2, 1, device=device), torch.ones(1, device=device))
         arg2 = (torch.ones(2, device=device), torch.ones(1, 1, device=device))
-        outs = (torch.ones(2, 1, 1, 1, device=device), torch.ones(2, 2, 2, 2, device=device))
+        outs = (
+            torch.ones(2, 1, 1, 1, device=device),
+            torch.ones(2, 2, 2, 2, device=device),
+        )
 
         for a1, a2, o in zip(arg1, arg2, outs):
             with warnings.catch_warnings(record=True) as w:
@@ -1360,12 +1371,16 @@ class TestBinaryUfuncs(TestCase):
             self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4)
         else:
             self._do_pow_for_exponents(m1, exponents, math.pow, None)
-            will_raise_error = dtype is torch.half and torch.device(device).type == 'cpu'
+            will_raise_error = (
+                dtype is torch.half and torch.device(device).type == "cpu"
+            )
             if will_raise_error:
                 # On CPU,
                 # Half Tensor with complex exponents leads to computation dtype
                 # of ComplexHalf for which this ops is not supported yet
-                with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
+                with self.assertRaisesRegex(
+                    RuntimeError, "not implemented for 'ComplexHalf'"
+                ):
                     self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
             else:
                 self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
@@ -1663,10 +1678,15 @@ class TestBinaryUfuncs(TestCase):
             # of ComplexHalf for which this ops is not supported yet
             # NOTE: pow has fast-path when base is 1 which supports
             # ComplexHalf
-            will_raise_error = torch.device(device).type == 'cpu' and \
-                dtype is torch.half and base != (1 + 0j)
+            will_raise_error = (
+                torch.device(device).type == "cpu"
+                and dtype is torch.half
+                and base != (1 + 0j)
+            )
             if will_raise_error:
-                with self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'"):
+                with self.assertRaisesRegex(
+                    RuntimeError, "not implemented for 'ComplexHalf'"
+                ):
                     self._test_pow(base, first_exp)
                     self._test_pow(base, second_exp)
             else:
@@ -2028,9 +2048,7 @@ class TestBinaryUfuncs(TestCase):
                 tmp //= b_t
                 self.assertEqual(tmp.item(), expected_ifloordiv)
 
-            self.assertEqual(
-                scripted_floor_divide__scalar(a_t), math.floor(a / 5)
-            )
+            self.assertEqual(scripted_floor_divide__scalar(a_t), math.floor(a / 5))
 
     # Tests binary op equivalence with Python builtin ops
     # Also tests that reverse operations are equivalent to forward ops
@@ -2042,7 +2060,6 @@ class TestBinaryUfuncs(TestCase):
             (operator.mul, torch.mul),
             (operator.truediv, torch.div),
         ):
-
             for a, b in product(range(-10, 10), range(-10, 10)):
                 for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                     a = op(a)
@@ -3143,11 +3160,19 @@ class TestBinaryUfuncs(TestCase):
         bits = iinfo.bits
         low = iinfo.min
         high = iinfo.max
-        exact_dtype = dtype != torch.uint8  # numpy changes dtype from uint8 to int16 for some out-of-limits shift values
+        exact_dtype = (
+            dtype != torch.uint8
+        )  # numpy changes dtype from uint8 to int16 for some out-of-limits shift values
         for input in (
-            torch.tensor([-1, 0, 1], device=device, dtype=dtype),  # small for non-vectorized operation
-            torch.tensor([low, high], device=device, dtype=dtype),  # small for non-vectorized operation
-            make_tensor((64, 64, 64), low=low, high=high, device=device, dtype=dtype),  # large for vectorized operation
+            torch.tensor(
+                [-1, 0, 1], device=device, dtype=dtype
+            ),  # small for non-vectorized operation
+            torch.tensor(
+                [low, high], device=device, dtype=dtype
+            ),  # small for non-vectorized operation
+            make_tensor(
+                (64, 64, 64), low=low, high=high, device=device, dtype=dtype
+            ),  # large for vectorized operation
         ):
             shift_left_expected = torch.zeros_like(input)
             shift_right_expected = torch.clamp(input, -1, 0)
@@ -3158,7 +3183,8 @@ class TestBinaryUfuncs(TestCase):
                 lambda x: x << shift,
                 lambda x: np.left_shift(x, shift),
                 input,
-                exact_dtype=exact_dtype, msg=f"<< {shift}"
+                exact_dtype=exact_dtype,
+                msg=f"<< {shift}",
             )
             shift_right = input >> shift
             self.assertEqual(shift_right, shift_right_expected, msg=f">> {shift}")
@@ -3166,7 +3192,8 @@ class TestBinaryUfuncs(TestCase):
                 lambda x: x >> shift,
                 lambda x: np.right_shift(x, shift),
                 input,
-                exact_dtype=exact_dtype, msg=f">> {shift}"
+                exact_dtype=exact_dtype,
+                msg=f">> {shift}",
             )
 
     @onlyNativeDeviceTypes
@@ -3448,6 +3475,7 @@ class TestBinaryUfuncs(TestCase):
             # numpy has not implemented logaddexp for complex
             def _ref_func(x, y):
                 return scipy.special.logsumexp(np.stack((x, y), axis=0), axis=0)
+
             ref_func = _ref_func
             our_func = torch.logaddexp
         else:
@@ -3488,9 +3516,11 @@ class TestBinaryUfuncs(TestCase):
         )
         _test_helper(a, b)
 
-    @skipIfTorchDynamo() # complex infs/nans differ under Dynamo/Inductor
+    @skipIfTorchDynamo()  # complex infs/nans differ under Dynamo/Inductor
     @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16)
-    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128)
+    @dtypes(
+        torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128
+    )
     def test_logaddexp(self, device, dtype):
         self._test_logaddexp(device, dtype, base2=False)
 
@@ -3818,7 +3848,13 @@ class TestBinaryUfuncs(TestCase):
             b_16 = b.to(dtype=lowp_dtype)
             actual_16 = a_16.atan2(b_16)
             self.assertEqual(actual_16, actual.to(dtype=lowp_dtype))
-            self.assertEqual(expected, actual_16.view(-1), exact_dtype=False, rtol=rtol, atol=atol)
+            self.assertEqual(
+                expected,
+                actual_16.view(-1),
+                exact_dtype=False,
+                rtol=rtol,
+                atol=atol,
+            )
 
         _test_atan2_with_size((2, 2), device)
         _test_atan2_with_size((3, 3), device)
@@ -3886,7 +3922,6 @@ class TestBinaryUfuncs(TestCase):
 
     @skipIf(not TEST_SCIPY, "Scipy required for the test.")
     def test_cumulative_trapezoid(self, device):
-
         import scipy.integrate
 
         if hasattr(scipy.integrate, "cumulative_trapezoid"):
@@ -4034,7 +4069,6 @@ class TestBinaryUfuncs(TestCase):
             torch.Tensor.float_power,
             torch.Tensor.float_power_,
         ):
-
             # Case of Tensor x Tensor
             if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
                 with self.assertRaisesRegex(
@@ -4431,6 +4465,7 @@ tensor_binary_ops = [
     # '__divmod__', '__rdivmod__', '__idivmod__',
 ]
 
+
 # Test that binary math operations return NotImplemented for unknown types.
 def generate_not_implemented_tests(cls):
     class UnknownType:
```
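Of the ops reformatted above, logaddexp illustrates why these tests compare against SciPy references: it evaluates log(exp(a) + exp(b)) without underflow, where the naive formula fails:

```python
# torch.logaddexp computes log(exp(a) + exp(b)) stably; the naive formula
# underflows exp(-1000.) to zero and returns -inf instead.
import torch

a = torch.tensor([-1000.0, 0.0])
b = torch.tensor([-1000.0, 0.0])
print(torch.logaddexp(a, b))                   # tensor([-999.3069, 0.6931])
print(torch.log(torch.exp(a) + torch.exp(b)))  # tensor([-inf, 0.6931])
```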
test/test_bundled_images.py

```diff
@@ -1,25 +1,29 @@
 #!/usr/bin/env python3
 # Owner(s): ["oncall: mobile"]
 
+import io
+
+import cv2
 import torch
 import torch.utils.bundled_inputs
-import io
-import cv2
 from torch.testing._internal.common_utils import TestCase
 
 torch.ops.load_library("//caffe2/torch/fb/operators:decode_bundled_image")
 
+
 def model_size(sm):
     buffer = io.BytesIO()
     torch.jit.save(sm, buffer)
     return len(buffer.getvalue())
 
+
 def save_and_load(sm):
     buffer = io.BytesIO()
     torch.jit.save(sm, buffer)
     buffer.seek(0)
     return torch.jit.load(buffer)
 
+
 """Return an InflatableArg that contains a tensor of the compressed image and the way to decode it
 
 keyword arguments:
@@ -27,6 +31,8 @@ def save_and_load(sm):
     if in NCHW format, N should be 1
     quality -- the quality needed to compress the image
 """
+
+
 def bundle_jpeg_image(img_tensor, quality):
     # turn NCHW to HWC
     if img_tensor.dim() == 4:
@@ -37,9 +43,12 @@ def bundle_jpeg_image(img_tensor, quality):
     _, enc_img = cv2.imencode(".JPEG", pixels, encode_param)
     enc_img_tensor = torch.from_numpy(enc_img)
     enc_img_tensor = torch.flatten(enc_img_tensor).byte()
-    obj = torch.utils.bundled_inputs.InflatableArg(enc_img_tensor, "torch.ops.fb.decode_bundled_image({})")
+    obj = torch.utils.bundled_inputs.InflatableArg(
+        enc_img_tensor, "torch.ops.fb.decode_bundled_image({})"
+    )
     return obj
 
+
 def get_tensor_from_raw_BGR(im) -> torch.Tensor:
     raw_data = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
     raw_data = torch.from_numpy(raw_data).float()
@@ -53,6 +62,7 @@ class TestBundledImages(TestCase):
         class SingleTensorModel(torch.nn.Module):
             def forward(self, arg):
                 return arg
+
         im = cv2.imread("caffe2/test/test_img/p1.jpg")
         tensor = torch.from_numpy(im)
         inflatable_arg = bundle_jpeg_image(tensor, 90)
```
test/test_bundled_inputs.py

```diff
@@ -3,11 +3,11 @@
 
 import io
 import textwrap
-from typing import List, Optional, Dict
+from typing import Dict, List, Optional
 
 import torch
 import torch.utils.bundled_inputs
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import run_tests, TestCase
 
 
 def model_size(sm):
@@ -24,7 +24,6 @@ def save_and_load(sm):
 
 
 class TestBundledInputs(TestCase):
-
     def test_single_tensors(self):
         class SingleTensorModel(torch.nn.Module):
             def forward(self, arg):
@@ -32,7 +31,7 @@ class TestBundledInputs(TestCase):
 
         sm = torch.jit.script(SingleTensorModel())
         original_size = model_size(sm)
-        get_expr : List[str] = []
+        get_expr: List[str] = []
         samples = [
             # Tensor with small numel and small storage.
             (torch.tensor([1]),),
@@ -50,7 +49,8 @@ class TestBundledInputs(TestCase):
             (torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0, torch.qint8),),
         ]
         torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-            sm, samples, get_expr)
+            sm, samples, get_expr
+        )
         # print(get_expr[0])
         # print(sm._generate_bundled_inputs.code)
 
@@ -80,18 +80,17 @@ class TestBundledInputs(TestCase):
         self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0)
         self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0)
 
-
     def test_large_tensor_with_inflation(self):
         class SingleTensorModel(torch.nn.Module):
             def forward(self, arg):
                 return arg
+
         sm = torch.jit.script(SingleTensorModel())
         sample_tensor = torch.randn(1 << 16)
         # We can store tensors with custom inflation functions regardless
         # of size, even if inflation is just the identity.
         sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
-        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-            sm, [(sample,)])
+        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(sample,)])
 
         loaded = save_and_load(sm)
         inflated = loaded.get_all_bundled_inputs()
@@ -99,17 +98,18 @@ class TestBundledInputs(TestCase):
 
         self.assertEqual(inflated[0][0], sample_tensor)
 
-
     def test_rejected_tensors(self):
         def check_tensor(sample):
             # Need to define the class in this scope to get a fresh type for each run.
             class SingleTensorModel(torch.nn.Module):
                 def forward(self, arg):
                     return arg
+
             sm = torch.jit.script(SingleTensorModel())
             with self.assertRaisesRegex(Exception, "Bundled input argument"):
                 torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-                    sm, [(sample,)])
+                    sm, [(sample,)]
+                )
 
         # Plain old big tensor.
         check_tensor(torch.randn(1 << 16))
@@ -120,7 +120,6 @@ class TestBundledInputs(TestCase):
         self.assertEqual(small_sparse.numel(), 2)
         check_tensor(small_sparse)
 
-
     def test_non_tensors(self):
         class StringAndIntModel(torch.nn.Module):
             def forward(self, fmt: str, num: int):
@@ -131,8 +130,7 @@ class TestBundledInputs(TestCase):
             ("first {}", 1),
             ("second {}", 2),
         ]
-        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-            sm, samples)
+        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, samples)
 
         loaded = save_and_load(sm)
         inflated = loaded.get_all_bundled_inputs()
@@ -162,23 +160,17 @@ class TestBundledInputs(TestCase):
             (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
         ]
         info = [
-            'Tensor with small numel and small storage.',
-            'Tensor with large numel and small storage.',
-            'Tensor with small numel and large storage.',
-            'Large zero tensor.',
-            'Large channels-last ones tensor.',
-            'Special encoding of random tensor.',
+            "Tensor with small numel and small storage.",
+            "Tensor with large numel and small storage.",
+            "Tensor with small numel and large storage.",
+            "Large zero tensor.",
+            "Large channels-last ones tensor.",
+            "Special encoding of random tensor.",
         ]
         torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
             mm,
-            inputs={
-                mm.forward : samples,
-                mm.foo : samples
-            },
-            info={
-                mm.forward : info,
-                mm.foo : info
-            }
+            inputs={mm.forward: samples, mm.foo: samples},
+            info={mm.forward: info, mm.foo: info},
         )
         loaded = save_and_load(mm)
         inflated = loaded.get_all_bundled_inputs()
@@ -194,15 +186,21 @@ class TestBundledInputs(TestCase):
 
         # Check helper that work on all functions
         all_info = loaded.get_bundled_inputs_functions_and_info()
-        self.assertEqual(set(all_info.keys()), {'forward', 'foo'})
-        self.assertEqual(all_info['forward']['get_inputs_function_name'], ['get_all_bundled_inputs_for_forward'])
-        self.assertEqual(all_info['foo']['get_inputs_function_name'], ['get_all_bundled_inputs_for_foo'])
-        self.assertEqual(all_info['forward']['info'], info)
-        self.assertEqual(all_info['foo']['info'], info)
+        self.assertEqual(set(all_info.keys()), {"forward", "foo"})
+        self.assertEqual(
+            all_info["forward"]["get_inputs_function_name"],
+            ["get_all_bundled_inputs_for_forward"],
+        )
+        self.assertEqual(
+            all_info["foo"]["get_inputs_function_name"],
+            ["get_all_bundled_inputs_for_foo"],
+        )
+        self.assertEqual(all_info["forward"]["info"], info)
+        self.assertEqual(all_info["foo"]["info"], info)
 
         # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs
         for func_name in all_info.keys():
-            input_func_name = all_info[func_name]['get_inputs_function_name'][0]
+            input_func_name = all_info[func_name]["get_inputs_function_name"][0]
             func_to_run = getattr(loaded, input_func_name)
             self.assertEqual(func_to_run(), samples)
 
@@ -220,16 +218,18 @@ class TestBundledInputs(TestCase):
             # inputs defined 2 ways so should fail
             with self.assertRaises(Exception):
                 mm = torch.jit.script(MultipleMethodModel())
-                definition = textwrap.dedent("""
+                definition = textwrap.dedent(
+                    """
                     def _generate_bundled_inputs_for_forward(self):
                         return []
-                    """)
+                    """
+                )
                 mm.define(definition)
                 torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                     mm,
                     inputs={
-                        mm.forward : samples,
-                        mm.foo : samples,
+                        mm.forward: samples,
+                        mm.foo: samples,
                     },
                 )
 
@@ -251,8 +251,8 @@ class TestBundledInputs(TestCase):
             torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
                 mm,
                 inputs={
-                    mm.forward : None,
-                    mm.foo : samples,
+                    mm.forward: None,
+                    mm.foo: samples,
                 },
             )
 
@@ -265,8 +265,7 @@ class TestBundledInputs(TestCase):
         with self.assertRaises(TypeError):
             m = torch.jit.script(SingleTensorModel())
             torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-                m,
-                inputs="foo"  # type: ignore[arg-type]
+                m, inputs="foo"  # type: ignore[arg-type]
             )
 
         # List of non tuples. Most common error using the api.
@@ -274,7 +273,9 @@ class TestBundledInputs(TestCase):
             m = torch.jit.script(SingleTensorModel())
             torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                 m,
-                inputs=[torch.ones(1, 2), ]  # type: ignore[list-item]
+                inputs=[
+                    torch.ones(1, 2),  # type: ignore[list-item]
+                ],
             )
 
     def test_double_augment_fail(self):
@@ -284,13 +285,13 @@ class TestBundledInputs(TestCase):
 
         m = torch.jit.script(SingleTensorModel())
         torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-            m,
-            inputs=[(torch.ones(1),)]
+            m, inputs=[(torch.ones(1),)]
         )
-        with self.assertRaisesRegex(Exception, "Models can only be augmented with bundled inputs once."):
+        with self.assertRaisesRegex(
+            Exception, "Models can only be augmented with bundled inputs once."
+        ):
             torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-                m,
-                inputs=[(torch.ones(1),)]
+                m, inputs=[(torch.ones(1),)]
             )
 
     def test_double_augment_non_mutator(self):
@@ -300,8 +301,7 @@ class TestBundledInputs(TestCase):
 
         m = torch.jit.script(SingleTensorModel())
         bundled_model = torch.utils.bundled_inputs.bundle_inputs(
-            m,
-            inputs=[(torch.ones(1),)]
+            m, inputs=[(torch.ones(1),)]
         )
         with self.assertRaises(AttributeError):
             m.get_all_bundled_inputs()
@@ -315,18 +315,15 @@ class TestBundledInputs(TestCase):
 
         m = torch.jit.script(SingleTensorModel())
         bundled_model = torch.utils.bundled_inputs.bundle_inputs(
-            m,
-            inputs={m.forward : [(torch.ones(1),)]}
+            m, inputs={m.forward: [(torch.ones(1),)]}
         )
         self.assertEqual(bundled_model.get_all_bundled_inputs(), [(torch.ones(1),)])
 
         bundled_model2 = torch.utils.bundled_inputs.bundle_inputs(
-            bundled_model,
-            inputs=[(torch.ones(2),)]
+            bundled_model, inputs=[(torch.ones(2),)]
         )
         self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)])
 
-
     def test_dict_args(self):
         class MyModel(torch.nn.Module):
             def forward(
@@ -396,7 +393,7 @@ class TestBundledInputs(TestCase):
             """,
         )
 
-        out : List[str] = []
+        out: List[str] = []
         sm = torch.jit.script(MyModel())
         original_size = model_size(sm)
         small_inputs = (
@@ -426,7 +423,10 @@ class TestBundledInputs(TestCase):
         inflated = loaded.get_all_bundled_inputs()
         self.assertEqual(len(inflated[0]), len(small_inputs))
 
-        methods, _ = torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
+        (
+            methods,
+            _,
+        ) = torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods(
             loaded
         )
 
@@ -439,5 +439,5 @@ class TestBundledInputs(TestCase):
         )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     run_tests()
```
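The bundled-inputs workflow these tests exercise follows one shape end to end: script a model, attach sample inputs, serialize, then recover the samples from the loaded module. Condensed from the calls in the diff above:

```python
# End-to-end bundled-inputs round trip, condensed from the tests above.
import io

import torch
import torch.utils.bundled_inputs

class Identity(torch.nn.Module):
    def forward(self, arg):
        return arg

sm = torch.jit.script(Identity())
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, [(torch.ones(1),)])

buffer = io.BytesIO()
torch.jit.save(sm, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
print(loaded.get_all_bundled_inputs())  # [(tensor([1.]),)]
```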
test/test_comparison_utils.py

```diff
@@ -2,7 +2,8 @@
 # Owner(s): ["module: internals"]
 
 import torch
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_utils import run_tests, TestCase
+
 
 class TestComparisonUtils(TestCase):
     def test_all_equal_no_assert(self):
@@ -32,5 +33,5 @@ class TestComparisonUtils(TestCase):
         torch._assert_tensor_metadata(t, [3], [1], torch.float)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     run_tests()
```
|
|||
|
|
@@ -1,17 +1,20 @@
 # Owner(s): ["module: dynamo"]

+import unittest
+
 import torch
 import torch._dynamo as torchdynamo
-from torch.testing._internal.common_utils import TestCase, run_tests, TEST_CUDA
-import unittest
+from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TestCase

 try:
     import tabulate  # noqa: F401  # type: ignore[import]
     from torch.utils.benchmark.utils.compile import bench_all
+
     HAS_TABULATE = True
 except ImportError:
     HAS_TABULATE = False

+
 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
 @unittest.skipIf(not HAS_TABULATE, "tabulate not available")
 class TestCompileBenchmarkUtil(TestCase):

@@ -28,10 +31,24 @@ class TestCompileBenchmarkUtil(TestCase):
         model = ToyModel().cuda()

         inference_table = bench_all(model, torch.ones(1024, 2, 2).cuda(), 5)
-        self.assertTrue("Inference" in inference_table and "Eager" in inference_table and "-" in inference_table)
+        self.assertTrue(
+            "Inference" in inference_table
+            and "Eager" in inference_table
+            and "-" in inference_table
+        )

-        training_table = bench_all(model, torch.ones(1024, 2, 2).cuda(), 5, optimizer=torch.optim.SGD(model.parameters(), lr=0.01))
-        self.assertTrue("Train" in training_table and "Eager" in training_table and "-" in training_table)
+        training_table = bench_all(
+            model,
+            torch.ones(1024, 2, 2).cuda(),
+            5,
+            optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
+        )
+        self.assertTrue(
+            "Train" in training_table
+            and "Eager" in training_table
+            and "-" in training_table
+        )

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     run_tests()

@@ -2,27 +2,31 @@

 import torch
 from torch.testing._internal.common_device_type import (
-    instantiate_device_type_tests,
     dtypes,
+    instantiate_device_type_tests,
     onlyCPU,
 )
-from torch.testing._internal.common_utils import TestCase, run_tests, set_default_dtype
 from torch.testing._internal.common_dtype import complex_types
+from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase

-devices = (torch.device('cpu'), torch.device('cuda:0'))
+devices = (torch.device("cpu"), torch.device("cuda:0"))
+

 class TestComplexTensor(TestCase):
     @dtypes(*complex_types())
     def test_to_list(self, device, dtype):
         # test that the complex float tensor has expected values and
         # there's no garbage value in the resultant list
-        self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])
+        self.assertEqual(
+            torch.zeros((2, 2), device=device, dtype=dtype).tolist(),
+            [[0j, 0j], [0j, 0j]],
+        )

     @dtypes(torch.float32, torch.float64, torch.float16)
     def test_dtype_inference(self, device, dtype):
         # issue: https://github.com/pytorch/pytorch/issues/36834
         with set_default_dtype(dtype):
-            x = torch.tensor([3., 3. + 5.j], device=device)
+            x = torch.tensor([3.0, 3.0 + 5.0j], device=device)
         if dtype == torch.float16:
             self.assertEqual(x.dtype, torch.chalf)
         elif dtype == torch.float32:

@@ -47,7 +51,9 @@ class TestComplexTensor(TestCase):
     @dtypes(*complex_types())
     def test_any(self, device, dtype):
         # issue: https://github.com/pytorch/pytorch/issues/120875
-        x = torch.tensor([0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype)
+        x = torch.tensor(
+            [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
+        )
         self.assertFalse(torch.any(x))

     @onlyCPU

@@ -57,67 +63,179 @@ class TestComplexTensor(TestCase):
         nan = float("nan")
         # Non-vectorized operations
         for a, b in (
-            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
-             torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype)),
-            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
-             torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype)),
-            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
-             torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype)),
+            (
+                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
+                torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype),
+            ),
+            (
+                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
+                torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype),
+            ),
+            (
+                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
+                torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype),
+            ),
         ):
             actual = torch.eq(a, b)
             expected = torch.tensor([False], device=device, dtype=torch.bool)
-            self.assertEqual(actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}")
+            self.assertEqual(
+                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.eq(a, a)
             expected = torch.tensor([True], device=device, dtype=torch.bool)
-            self.assertEqual(actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}")
+            self.assertEqual(
+                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.full_like(b, complex(2, 2))
             torch.eq(a, b, out=actual)
             expected = torch.tensor([complex(0)], device=device, dtype=dtype)
-            self.assertEqual(actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}")
+            self.assertEqual(
+                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.full_like(b, complex(2, 2))
             torch.eq(a, a, out=actual)
             expected = torch.tensor([complex(1)], device=device, dtype=dtype)
-            self.assertEqual(actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}")
+            self.assertEqual(
+                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
+            )

         # Vectorized operations
         for a, b in (
-            (torch.tensor([
-                -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(2.8871, nan), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j,
-                -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(nan, -3.2650), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j],
-                device=device, dtype=dtype),
-             torch.tensor([
-                -6.1278 - 8.5019j, 0.5886 + 8.8816j, complex(2.8871, nan), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j,
-                -6.1278 - 2.1172j, 5.1576 + 8.8816j, complex(nan, -3.2650), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j],
-                device=device, dtype=dtype)),
+            (
+                torch.tensor(
+                    [
+                        -0.0610 - 2.1172j,
+                        5.1576 + 5.4775j,
+                        complex(2.8871, nan),
+                        -6.6545 - 3.7655j,
+                        -2.7036 - 1.4470j,
+                        0.3712 + 7.989j,
+                        -0.0610 - 2.1172j,
+                        5.1576 + 5.4775j,
+                        complex(nan, -3.2650),
+                        -6.6545 - 3.7655j,
+                        -2.7036 - 1.4470j,
+                        0.3712 + 7.989j,
+                    ],
+                    device=device,
+                    dtype=dtype,
+                ),
+                torch.tensor(
+                    [
+                        -6.1278 - 8.5019j,
+                        0.5886 + 8.8816j,
+                        complex(2.8871, nan),
+                        6.3505 + 2.2683j,
+                        0.3712 + 7.9659j,
+                        0.3712 + 7.989j,
+                        -6.1278 - 2.1172j,
+                        5.1576 + 8.8816j,
+                        complex(nan, -3.2650),
+                        6.3505 + 2.2683j,
+                        0.3712 + 7.9659j,
+                        0.3712 + 7.989j,
+                    ],
+                    device=device,
+                    dtype=dtype,
+                ),
+            ),
         ):
             actual = torch.eq(a, b)
-            expected = torch.tensor([False, False, False, False, False, True,
-                                     False, False, False, False, False, True],
-                                    device=device, dtype=torch.bool)
-            self.assertEqual(actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}")
+            expected = torch.tensor(
+                [
+                    False,
+                    False,
+                    False,
+                    False,
+                    False,
+                    True,
+                    False,
+                    False,
+                    False,
+                    False,
+                    False,
+                    True,
+                ],
+                device=device,
+                dtype=torch.bool,
+            )
+            self.assertEqual(
+                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.eq(a, a)
-            expected = torch.tensor([True, True, False, True, True, True,
-                                     True, True, False, True, True, True],
-                                    device=device, dtype=torch.bool)
-            self.assertEqual(actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}")
+            expected = torch.tensor(
+                [
+                    True,
+                    True,
+                    False,
+                    True,
+                    True,
+                    True,
+                    True,
+                    True,
+                    False,
+                    True,
+                    True,
+                    True,
+                ],
+                device=device,
+                dtype=torch.bool,
+            )
+            self.assertEqual(
+                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.full_like(b, complex(2, 2))
             torch.eq(a, b, out=actual)
-            expected = torch.tensor([complex(0), complex(0), complex(0), complex(0), complex(0), complex(1),
-                                     complex(0), complex(0), complex(0), complex(0), complex(0), complex(1)],
-                                    device=device, dtype=dtype)
-            self.assertEqual(actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}")
+            expected = torch.tensor(
+                [
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(1),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(1),
+                ],
+                device=device,
+                dtype=dtype,
+            )
+            self.assertEqual(
+                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.full_like(b, complex(2, 2))
             torch.eq(a, a, out=actual)
-            expected = torch.tensor([complex(1), complex(1), complex(0), complex(1), complex(1), complex(1),
-                                     complex(1), complex(1), complex(0), complex(1), complex(1), complex(1)],
-                                    device=device, dtype=dtype)
-            self.assertEqual(actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}")
+            expected = torch.tensor(
+                [
+                    complex(1),
+                    complex(1),
+                    complex(0),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(0),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                ],
+                device=device,
+                dtype=dtype,
+            )
+            self.assertEqual(
+                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
+            )

     @onlyCPU
     @dtypes(*complex_types())

@@ -126,70 +244,183 @@ class TestComplexTensor(TestCase):
         nan = float("nan")
         # Non-vectorized operations
         for a, b in (
-            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
-             torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype)),
-            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
-             torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype)),
-            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
-             torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype)),
+            (
+                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
+                torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype),
+            ),
+            (
+                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
+                torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype),
+            ),
+            (
+                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
+                torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype),
+            ),
         ):
             actual = torch.ne(a, b)
             expected = torch.tensor([True], device=device, dtype=torch.bool)
-            self.assertEqual(actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}")
+            self.assertEqual(
+                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.ne(a, a)
             expected = torch.tensor([False], device=device, dtype=torch.bool)
-            self.assertEqual(actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}")
+            self.assertEqual(
+                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.full_like(b, complex(2, 2))
             torch.ne(a, b, out=actual)
             expected = torch.tensor([complex(1)], device=device, dtype=dtype)
-            self.assertEqual(actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}")
+            self.assertEqual(
+                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.full_like(b, complex(2, 2))
             torch.ne(a, a, out=actual)
             expected = torch.tensor([complex(0)], device=device, dtype=dtype)
-            self.assertEqual(actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}")
+            self.assertEqual(
+                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
+            )

         # Vectorized operations
         for a, b in (
-            (torch.tensor([
-                -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(2.8871, nan), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j,
-                -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(nan, -3.2650), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j],
-                device=device, dtype=dtype),
-             torch.tensor([
-                -6.1278 - 8.5019j, 0.5886 + 8.8816j, complex(2.8871, nan), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j,
-                -6.1278 - 2.1172j, 5.1576 + 8.8816j, complex(nan, -3.2650), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j],
-                device=device, dtype=dtype)),
+            (
+                torch.tensor(
+                    [
+                        -0.0610 - 2.1172j,
+                        5.1576 + 5.4775j,
+                        complex(2.8871, nan),
+                        -6.6545 - 3.7655j,
+                        -2.7036 - 1.4470j,
+                        0.3712 + 7.989j,
+                        -0.0610 - 2.1172j,
+                        5.1576 + 5.4775j,
+                        complex(nan, -3.2650),
+                        -6.6545 - 3.7655j,
+                        -2.7036 - 1.4470j,
+                        0.3712 + 7.989j,
+                    ],
+                    device=device,
+                    dtype=dtype,
+                ),
+                torch.tensor(
+                    [
+                        -6.1278 - 8.5019j,
+                        0.5886 + 8.8816j,
+                        complex(2.8871, nan),
+                        6.3505 + 2.2683j,
+                        0.3712 + 7.9659j,
+                        0.3712 + 7.989j,
+                        -6.1278 - 2.1172j,
+                        5.1576 + 8.8816j,
+                        complex(nan, -3.2650),
+                        6.3505 + 2.2683j,
+                        0.3712 + 7.9659j,
+                        0.3712 + 7.989j,
+                    ],
+                    device=device,
+                    dtype=dtype,
+                ),
+            ),
         ):
             actual = torch.ne(a, b)
-            expected = torch.tensor([True, True, True, True, True, False,
-                                     True, True, True, True, True, False],
-                                    device=device, dtype=torch.bool)
-            self.assertEqual(actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}")
+            expected = torch.tensor(
+                [
+                    True,
+                    True,
+                    True,
+                    True,
+                    True,
+                    False,
+                    True,
+                    True,
+                    True,
+                    True,
+                    True,
+                    False,
+                ],
+                device=device,
+                dtype=torch.bool,
+            )
+            self.assertEqual(
+                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.ne(a, a)
-            expected = torch.tensor([False, False, True, False, False, False,
-                                     False, False, True, False, False, False],
-                                    device=device, dtype=torch.bool)
-            self.assertEqual(actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}")
+            expected = torch.tensor(
+                [
+                    False,
+                    False,
+                    True,
+                    False,
+                    False,
+                    False,
+                    False,
+                    False,
+                    True,
+                    False,
+                    False,
+                    False,
+                ],
+                device=device,
+                dtype=torch.bool,
+            )
+            self.assertEqual(
+                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.full_like(b, complex(2, 2))
             torch.ne(a, b, out=actual)
-            expected = torch.tensor([complex(1), complex(1), complex(1), complex(1), complex(1), complex(0),
-                                     complex(1), complex(1), complex(1), complex(1), complex(1), complex(0)],
-                                    device=device, dtype=dtype)
-            self.assertEqual(actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}")
+            expected = torch.tensor(
+                [
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(0),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(1),
+                    complex(0),
+                ],
+                device=device,
+                dtype=dtype,
+            )
+            self.assertEqual(
+                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
+            )

             actual = torch.full_like(b, complex(2, 2))
             torch.ne(a, a, out=actual)
-            expected = torch.tensor([complex(0), complex(0), complex(1), complex(0), complex(0), complex(0),
-                                     complex(0), complex(0), complex(1), complex(0), complex(0), complex(0)],
-                                    device=device, dtype=dtype)
-            self.assertEqual(actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}")
+            expected = torch.tensor(
+                [
+                    complex(0),
+                    complex(0),
+                    complex(1),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                    complex(1),
+                    complex(0),
+                    complex(0),
+                    complex(0),
+                ],
+                device=device,
+                dtype=dtype,
+            )
+            self.assertEqual(
+                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
+            )

+
 instantiate_device_type_tests(TestComplexTensor, globals())

-if __name__ == '__main__':
+if __name__ == "__main__":
     TestCase._default_dtype_check_enabled = True
     run_tests()

@@ -4,26 +4,35 @@
 import os

 import torch
-import torch.testing._internal.common_utils as common
 import torch.testing._internal.common_nn as common_nn
+import torch.testing._internal.common_utils as common
+from cpp_api_parity import (
+    functional_impl_check,
+    module_impl_check,
+    sample_functional,
+    sample_module,
+)
 from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
 from cpp_api_parity.utils import is_torch_nn_functional_test
-from cpp_api_parity import module_impl_check, functional_impl_check, sample_module, sample_functional

 # NOTE: turn this on if you want to print source code of all C++ tests (e.g. for debugging purpose)
 PRINT_CPP_SOURCE = False

-devices = ['cpu', 'cuda']
+devices = ["cpu", "cuda"]

-PARITY_TABLE_PATH = os.path.join(os.path.dirname(__file__), 'cpp_api_parity', 'parity-tracker.md')
+PARITY_TABLE_PATH = os.path.join(
+    os.path.dirname(__file__), "cpp_api_parity", "parity-tracker.md"
+)

 parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)

+
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestCppApiParity(common.TestCase):
     module_test_params_map = {}
     functional_test_params_map = {}

+
     expected_test_params_dicts = []

 if not common.IS_ARM64:

@@ -35,27 +44,47 @@ if not common.IS_ARM64:
         (common_nn.criterion_tests, common_nn.CriterionTest),
     ]:
         for test_params_dict in test_params_dicts:
-            if test_params_dict.get('test_cpp_api_parity', True):
+            if test_params_dict.get("test_cpp_api_parity", True):
                 if is_torch_nn_functional_test(test_params_dict):
                     functional_impl_check.write_test_to_test_class(
-                        TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
+                        TestCppApiParity,
+                        test_params_dict,
+                        test_instance_class,
+                        parity_table,
+                        devices,
+                    )
                 else:
                     module_impl_check.write_test_to_test_class(
-                        TestCppApiParity, test_params_dict, test_instance_class, parity_table, devices)
+                        TestCppApiParity,
+                        test_params_dict,
+                        test_instance_class,
+                        parity_table,
+                        devices,
+                    )
                 expected_test_params_dicts.append(test_params_dict)

     # Assert that all NN module/functional test dicts appear in the parity test
-    assert len([name for name in TestCppApiParity.__dict__ if 'test_torch_nn_' in name]) == \
-        len(expected_test_params_dicts) * len(devices)
+    assert len(
+        [name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name]
+    ) == len(expected_test_params_dicts) * len(devices)

     # Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`.
     # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
-    assert len([name for name in TestCppApiParity.__dict__ if 'SampleModule' in name]) == 4
+    assert (
+        len([name for name in TestCppApiParity.__dict__ if "SampleModule" in name]) == 4
+    )
     # 4 == 2 (number of test dicts that are not skipped) * 2 (number of devices)
-    assert len([name for name in TestCppApiParity.__dict__ if 'sample_functional' in name]) == 4
+    assert (
+        len([name for name in TestCppApiParity.__dict__ if "sample_functional" in name])
+        == 4
+    )

-    module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
-    functional_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
+    module_impl_check.build_cpp_tests(
+        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
+    )
+    functional_impl_check.build_cpp_tests(
+        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
+    )

 if __name__ == "__main__":
     common.TestCase._default_dtype_check_enabled = True

@@ -1,20 +1,22 @@
 # Owner(s): ["module: cpp-extensions"]

-from itertools import repeat
 import os
 import re
-from typing import Union, get_args, get_origin
 import unittest
+from itertools import repeat
+from typing import get_args, get_origin, Union

-import torch.testing._internal.common_utils as common
-from torch.testing._internal.common_utils import IS_WINDOWS, skipIfTorchDynamo
-from torch.testing._internal.common_cuda import TEST_CUDA
 import torch
 import torch.backends.cudnn
+
+import torch.testing._internal.common_utils as common
 import torch.utils.cpp_extension
+from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_utils import IS_WINDOWS, skipIfTorchDynamo

 try:
     import pytest
+
     HAS_PYTEST = True
 except ImportError as e:
     HAS_PYTEST = False

@@ -141,11 +143,15 @@ class TestCppExtensionAOT(common.TestCase):
     @common.skipIfRocm
     @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
     @unittest.skipIf(not TEST_CUDA, "CUDA not found")
-    @unittest.skipIf(os.getenv('USE_NINJA', '0') == '0', "cuda extension with dlink requires ninja to build")
+    @unittest.skipIf(
+        os.getenv("USE_NINJA", "0") == "0",
+        "cuda extension with dlink requires ninja to build",
+    )
     def test_cuda_dlink_libs(self):
         from torch_test_cpp_extension import cuda_dlink
-        a = torch.randn(8, dtype=torch.float, device='cuda')
-        b = torch.randn(8, dtype=torch.float, device='cuda')
+
+        a = torch.randn(8, dtype=torch.float, device="cuda")
+        b = torch.randn(8, dtype=torch.float, device="cuda")
         ref = a + b
         test = cuda_dlink.add(a, b)
         self.assertEqual(test, ref)

@@ -164,6 +170,7 @@ class TestPybindTypeCasters(common.TestCase):
     second argument to `PYBIND11_TYPE_CASTER` should be the type we expect to
     receive in python, in these tests we verify this at run-time.
     """
+
     @staticmethod
     def expected_return_type(func):
         """

@@ -220,7 +227,9 @@ class TestPybindTypeCasters(common.TestCase):
                 break
         else:
             raise AssertionError(f"{val} is not an instance of {expected_types}")
-        self.assertFalse(expected_types, f"Missing functions for types {expected_types}")
+        self.assertFalse(
+            expected_types, f"Missing functions for types {expected_types}"
+        )

     def test_pybind_return_types(self):
         functions = [

@@ -248,29 +257,29 @@ class TestPybindTypeCasters(common.TestCase):
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestORTTensor(common.TestCase):
     def test_unregistered(self):
-        a = torch.arange(0, 10, device='cpu')
+        a = torch.arange(0, 10, device="cpu")
         with self.assertRaisesRegex(RuntimeError, "Could not run"):
-            b = torch.arange(0, 10, device='ort')
+            b = torch.arange(0, 10, device="ort")

     @skipIfTorchDynamo("dynamo cannot model ort device")
     def test_zeros(self):
-        a = torch.empty(5, 5, device='cpu')
-        self.assertEqual(a.device, torch.device('cpu'))
+        a = torch.empty(5, 5, device="cpu")
+        self.assertEqual(a.device, torch.device("cpu"))

-        b = torch.empty(5, 5, device='ort')
-        self.assertEqual(b.device, torch.device('ort', 0))
+        b = torch.empty(5, 5, device="ort")
+        self.assertEqual(b.device, torch.device("ort", 0))
         self.assertEqual(ort_extension.get_test_int(), 0)
         self.assertEqual(torch.get_default_dtype(), b.dtype)

-        c = torch.empty((5, 5), dtype=torch.int64, device='ort')
+        c = torch.empty((5, 5), dtype=torch.int64, device="ort")
         self.assertEqual(ort_extension.get_test_int(), 0)
         self.assertEqual(torch.int64, c.dtype)

     def test_add(self):
-        a = torch.empty(5, 5, device='ort', requires_grad=True)
+        a = torch.empty(5, 5, device="ort", requires_grad=True)
         self.assertEqual(ort_extension.get_test_int(), 0)

-        b = torch.empty(5, 5, device='ort')
+        b = torch.empty(5, 5, device="ort")
         self.assertEqual(ort_extension.get_test_int(), 0)

         c = a + b

@@ -279,9 +288,9 @@ class TestORTTensor(common.TestCase):
     def test_conv_backend_override(self):
         # To simplify tests, we use 4d input here to avoid doing view4d( which
         # needs more overrides) in _convolution.
-        input = torch.empty(2, 4, 10, 2, device='ort', requires_grad=True)
-        weight = torch.empty(6, 4, 2, 2, device='ort', requires_grad=True)
-        bias = torch.empty(6, device='ort')
+        input = torch.empty(2, 4, 10, 2, device="ort", requires_grad=True)
+        weight = torch.empty(6, 4, 2, 2, device="ort", requires_grad=True)
+        bias = torch.empty(6, device="ort")

         # Make sure forward is overriden
         out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)

@@ -299,7 +308,6 @@ class TestORTTensor(common.TestCase):

 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestRNGExtension(common.TestCase):
-
     def setUp(self):
         super().setUp()

@@ -310,7 +318,7 @@ class TestRNGExtension(common.TestCase):
         t = torch.empty(10, dtype=torch.int64).random_()
         self.assertNotEqual(t, fourty_two)

-        gen = torch.Generator(device='cpu')
+        gen = torch.Generator(device="cpu")
         t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
         self.assertNotEqual(t, fourty_two)

@@ -337,7 +345,6 @@ class TestRNGExtension(common.TestCase):
 @torch.testing._internal.common_utils.markDynamoStrictTest
 @unittest.skipIf(not TEST_CUDA, "CUDA not found")
 class TestTorchLibrary(common.TestCase):
-
     def test_torch_library(self):
         import torch_test_cpp_extension.torch_library  # noqa: F401

@@ -353,7 +360,7 @@ class TestTorchLibrary(common.TestCase):
         self.assertFalse(s(True, False))
         self.assertFalse(s(False, True))
         self.assertFalse(s(False, False))
-        self.assertIn('torch_library::logical_and', str(s.graph))
+        self.assertIn("torch_library::logical_and", str(s.graph))


 if __name__ == "__main__":

@@ -1,31 +1,38 @@
 # Owner(s): ["module: cpp-extensions"]

+import glob
 import os
+import re
 import shutil
+import subprocess
 import sys
+import tempfile
 import unittest
 import warnings
-import re
-import tempfile
-import subprocess
-import glob

-import torch.testing._internal.common_utils as common
-from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDA
 import torch
 import torch.backends.cudnn
-import torch.utils.cpp_extension
-from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
-from torch.testing._internal.common_utils import gradcheck
 import torch.multiprocessing as mp
-from torch.utils.cpp_extension import _TORCH_PATH, remove_extension_h_precompiler_headers, get_cxx_compiler, check_compiler_is_gcc
+
+import torch.testing._internal.common_utils as common
+import torch.utils.cpp_extension
+from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
+from torch.testing._internal.common_utils import gradcheck
+from torch.utils.cpp_extension import (
+    _TORCH_PATH,
+    check_compiler_is_gcc,
+    CUDA_HOME,
+    get_cxx_compiler,
+    remove_extension_h_precompiler_headers,
+    ROCM_HOME,
+)

 # define TEST_ROCM before changing TEST_CUDA
 TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
 TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
 TEST_MPS = torch.backends.mps.is_available()
 IS_WINDOWS = sys.platform == "win32"
-IS_LINUX = sys.platform.startswith('linux')
+IS_LINUX = sys.platform.startswith("linux")


 def remove_build_path():

@@ -73,9 +80,11 @@ class TestCppExtensionJIT(common.TestCase):
                 "cpp_extensions/jit_extension.cpp",
                 "cpp_extensions/jit_extension2.cpp",
             ],
-            extra_include_paths=["cpp_extensions",
-                                 "path / with spaces in it",
-                                 "path with quote'"],
+            extra_include_paths=[
+                "cpp_extensions",
+                "path / with spaces in it",
+                "path with quote'",
+            ],
             extra_cflags=["-g"],
             verbose=True,
         )

@@ -140,33 +149,39 @@ class TestCppExtensionJIT(common.TestCase):
     def _run_jit_cuda_archflags(self, flags, expected):
        # Compile an extension with given `flags`
        def _check_cuobjdump_output(expected_values, is_ptx=False):
-            elf_or_ptx = '--list-ptx' if is_ptx else '--list-elf'
-            lib_ext = '.pyd' if IS_WINDOWS else '.so'
+            elf_or_ptx = "--list-ptx" if is_ptx else "--list-elf"
+            lib_ext = ".pyd" if IS_WINDOWS else ".so"
             # Note, .extension name may include _v1, _v2, so first find exact name
-            ext_filename = glob.glob(os.path.join(temp_dir,
-                                                  'cudaext_archflag*' + lib_ext))[0]
-            command = ['cuobjdump', elf_or_ptx, ext_filename]
-            p = subprocess.Popen(command,
-                                 stdout=subprocess.PIPE,
-                                 stderr=subprocess.PIPE)
+            ext_filename = glob.glob(
+                os.path.join(temp_dir, "cudaext_archflag*" + lib_ext)
+            )[0]
+            command = ["cuobjdump", elf_or_ptx, ext_filename]
+            p = subprocess.Popen(
+                command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            )
             output, err = p.communicate()
             output = output.decode("ascii")
             err = err.decode("ascii")

-            if not p.returncode == 0 or not err == '':
-                raise AssertionError(f"Flags: {flags}\nReturncode: {p.returncode}\nStderr: {err}\n"
-                                     f"Output: {output} ")
+            if not p.returncode == 0 or not err == "":
+                raise AssertionError(
+                    f"Flags: {flags}\nReturncode: {p.returncode}\nStderr: {err}\n"
+                    f"Output: {output} "
+                )

-            actual_arches = sorted(re.findall(r'sm_\d\d', output))
-            expected_arches = sorted(['sm_' + xx for xx in expected_values])
-            self.assertEqual(actual_arches, expected_arches,
-                             msg=f"Flags: {flags}, Actual: {actual_arches}, Expected: {expected_arches}\n"
-                                 f"Stderr: {err}\nOutput: {output}")
+            actual_arches = sorted(re.findall(r"sm_\d\d", output))
+            expected_arches = sorted(["sm_" + xx for xx in expected_values])
+            self.assertEqual(
+                actual_arches,
+                expected_arches,
+                msg=f"Flags: {flags}, Actual: {actual_arches}, Expected: {expected_arches}\n"
+                f"Stderr: {err}\nOutput: {output}",
+            )

         temp_dir = tempfile.mkdtemp()
-        old_envvar = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
+        old_envvar = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
         try:
-            os.environ['TORCH_CUDA_ARCH_LIST'] = flags
+            os.environ["TORCH_CUDA_ARCH_LIST"] = flags

             params = {
                 "name": "cudaext_archflags",

@@ -209,9 +224,9 @@ class TestCppExtensionJIT(common.TestCase):
             shutil.rmtree(temp_dir)

             if old_envvar is None:
-                os.environ.pop('TORCH_CUDA_ARCH_LIST')
+                os.environ.pop("TORCH_CUDA_ARCH_LIST")
             else:
-                os.environ['TORCH_CUDA_ARCH_LIST'] = old_envvar
+                os.environ["TORCH_CUDA_ARCH_LIST"] = old_envvar

     @unittest.skipIf(not TEST_CUDA, "CUDA not found")
     @unittest.skipIf(TEST_ROCM, "disabled on rocm")

@@ -227,15 +242,18 @@ class TestCppExtensionJIT(common.TestCase):
         # expected values is length-2 tuple: (list of ELF, list of PTX)
         # note: there should not be more than one PTX value
         archflags = {
-            '': ([f'{capability[0]}{capability[1]}' for capability in capabilities], None),
-            "Maxwell+Tegra;6.1": (['53', '61'], None),
-            "Volta": (['70'], ['70']),
+            "": (
+                [f"{capability[0]}{capability[1]}" for capability in capabilities],
+                None,
+            ),
+            "Maxwell+Tegra;6.1": (["53", "61"], None),
+            "Volta": (["70"], ["70"]),
         }
-        archflags["7.5+PTX"] = (['75'], ['75'])
-        archflags["5.0;6.0+PTX;7.0;7.5"] = (['50', '60', '70', '75'], ['60'])
-        if int(torch.version.cuda.split('.')[0]) < 12:
+        archflags["7.5+PTX"] = (["75"], ["75"])
+        archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"])
+        if int(torch.version.cuda.split(".")[0]) < 12:
             # CUDA 12 drops compute capability < 5.0
-            archflags["Pascal 3.5"] = (['35', '60', '61'], None)
+            archflags["Pascal 3.5"] = (["35", "60", "61"], None)

         for flags, expected in archflags.items():
             try:

@@ -594,8 +612,13 @@ class TestCppExtensionJIT(common.TestCase):
         self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)

         # Make sure we can access these methods recursively.
-        self.assertEqual(len(list(sequential.parameters())), len(net.parameters()) * 2 + 1)
-        self.assertEqual(len(list(sequential.named_parameters())), len(net.named_parameters()) * 2 + 1)
+        self.assertEqual(
+            len(list(sequential.parameters())), len(net.parameters()) * 2 + 1
+        )
+        self.assertEqual(
+            len(list(sequential.named_parameters())),
+            len(net.named_parameters()) * 2 + 1,
+        )
         self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
         self.assertEqual(len(list(sequential.modules())), 8)

@@ -751,8 +774,9 @@ class TestCppExtensionJIT(common.TestCase):
         with self.assertRaises(RuntimeError) as e:
             torch.utils.cpp_extension.load_inline(
                 name="test_compilation_error_formatting",
-                cpp_sources="int main() { return 0 }")
-        pattern = r'.*(\\n|\\r).*'
+                cpp_sources="int main() { return 0 }",
+            )
+        pattern = r".*(\\n|\\r).*"
         self.assertNotRegex(str(e), pattern)

     def test_warning(self):

@@ -760,7 +784,7 @@ class TestCppExtensionJIT(common.TestCase):
         # symbol. But because of visibility and the fact that it lives in a
         # different compilation unit than pybind, this trips up ubsan even though
         # it is fine. "ubsan.supp" thus needs to contain "vptr:warn_mod.so".
-        source = '''
+        source = """
         // error_type:
         // 0: no error
         // 1: torch::TypeError

@@ -788,17 +812,19 @@ class TestCppExtensionJIT(common.TestCase):
             }
             return x.cos();
         }
-        '''
+        """

         # Ensure double type for hard-coded c name below
         t = torch.rand(2).double()
         cpp_tensor_name = r"CPUDoubleType"

         # Without error handling, the warnings cannot be catched
-        warn_mod = torch.utils.cpp_extension.load_inline(name='warn_mod',
-                                                         cpp_sources=[source],
-                                                         functions=['foo'],
-                                                         with_pytorch_error_handling=False)
+        warn_mod = torch.utils.cpp_extension.load_inline(
+            name="warn_mod",
+            cpp_sources=[source],
+            functions=["foo"],
+            with_pytorch_error_handling=False,
+        )

         with warnings.catch_warnings(record=True) as w:
             warn_mod.foo(t, 0)

@@ -808,7 +834,9 @@ class TestCppExtensionJIT(common.TestCase):
             warn_mod.foo(t, 1)
             self.assertEqual(len(w), 0)

-            with self.assertRaisesRegex(SystemError, "bad argument to internal function"):
+            with self.assertRaisesRegex(
+                SystemError, "bad argument to internal function"
+            ):
                 warn_mod.foo(t, 2)
             self.assertEqual(len(w), 0)

@@ -816,12 +844,12 @@ class TestCppExtensionJIT(common.TestCase):
             warn_mod.foo(t, 3)
             self.assertEqual(len(w), 0)

-
-        warn_mod = torch.utils.cpp_extension.load_inline(name='warn_mod',
-                                                         cpp_sources=[source],
-                                                         functions=['foo'],
-                                                         with_pytorch_error_handling=True)
-
+        warn_mod = torch.utils.cpp_extension.load_inline(
+            name="warn_mod",
+            cpp_sources=[source],
+            functions=["foo"],
+            with_pytorch_error_handling=True,
+        )

         with warnings.catch_warnings(record=True) as w:
             # Catched with no error should be detected

@@ -834,7 +862,9 @@ class TestCppExtensionJIT(common.TestCase):
             self.assertEqual(len(w), 2)

             # Catched with python error should also be detected
-            with self.assertRaisesRegex(SystemError, "bad argument to internal function"):
+            with self.assertRaisesRegex(
+                SystemError, "bad argument to internal function"
+            ):
                 warn_mod.foo(t, 2)
             self.assertEqual(len(w), 3)

@@ -859,7 +889,7 @@ class TestCppExtensionJIT(common.TestCase):
         self.assertEqual(len(w), 0)

     def test_autograd_from_cpp(self):
-        source = '''
+        source = """
         void run_back(at::Tensor x) {
             x.backward({});
         }

@@ -868,7 +898,7 @@ class TestCppExtensionJIT(common.TestCase):
             pybind11::gil_scoped_release no_gil;
             x.backward({});
         }
-        '''
+        """

         class MyFn(torch.autograd.Function):
             @staticmethod

@@ -879,14 +909,18 @@ class TestCppExtensionJIT(common.TestCase):
             def backward(ctx, gx):
                 return gx

-        test_backward_deadlock = torch.utils.cpp_extension.load_inline(name='test_backward_deadlock',
-                                                                       cpp_sources=[source],
-                                                                       functions=['run_back', 'run_back_no_gil'],)
+        test_backward_deadlock = torch.utils.cpp_extension.load_inline(
+            name="test_backward_deadlock",
+            cpp_sources=[source],
+            functions=["run_back", "run_back_no_gil"],
+        )

         # This used to deadlock
         inp = torch.rand(20, requires_grad=True)
         loss = MyFn.apply(inp).sum()
-        with self.assertRaisesRegex(RuntimeError, "The autograd engine was called while holding the GIL."):
+        with self.assertRaisesRegex(
+            RuntimeError, "The autograd engine was called while holding the GIL."
+        ):
             test_backward_deadlock.run_back(loss)

         inp = torch.rand(20, requires_grad=True)

@@ -936,7 +970,6 @@ class TestCppExtensionJIT(common.TestCase):
         with self.assertRaisesRegex(RuntimeError, msg):
             torch.func.grad(identity_m.identity)(t)

-
     def test_gen_extension_h_pch(self):
         if not IS_LINUX:
             return

@@ -973,5 +1006,6 @@ class TestCppExtensionJIT(common.TestCase):
         self.assertEqual(pch_exist, True)
         self.assertEqual(signature_exist, True)

+
 if __name__ == "__main__":
     common.run_tests()

@@ -3,14 +3,15 @@
 import os
 import shutil
 import sys
-from typing import Union
 import tempfile
 import unittest
+from typing import Union

+import torch
 import torch.testing._internal.common_utils as common
-from torch.testing._internal.common_utils import IS_ARM64, TEST_CUDA
-import torch
 import torch.utils.cpp_extension
+from torch.testing._internal.common_utils import IS_ARM64, TEST_CUDA
 from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

+

@@ -28,18 +29,19 @@ def remove_build_path():


 class DummyModule:
-
     @staticmethod
     def device_count() -> int:
         return 1

     @staticmethod
-    def get_rng_state(device: Union[int, str, torch.device] = 'foo') -> torch.Tensor:
+    def get_rng_state(device: Union[int, str, torch.device] = "foo") -> torch.Tensor:
         # create a tensor using our custom device object.
         return torch.empty(4, 4, device="foo")

     @staticmethod
-    def set_rng_state(new_state: torch.Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
+    def set_rng_state(
+        new_state: torch.Tensor, device: Union[int, str, torch.device] = "foo"
+    ) -> None:
         pass

     @staticmethod

@@ -50,11 +52,12 @@ class DummyModule:
     def current_device():
         return 0

+
 @unittest.skipIf(IS_ARM64, "Does not work on arm")
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestCppExtensionOpenRgistration(common.TestCase):
-    """Tests Open Device Registration with C++ extensions.
-    """
+    """Tests Open Device Registration with C++ extensions."""
+
     module = None

     def setUp(self):

@@ -89,7 +92,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):

     def test_open_device_registration(self):
         def test_base_device_registration():
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             self.assertFalse(self.module.custom_add_called())
             # create a tensor using our custom device object
             device = self.module.custom_device()

@@ -103,7 +106,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             z = x + y
             # check that it was called
             self.assertTrue(self.module.custom_add_called())
-            z_cpu = z.to(device='cpu')
+            z_cpu = z.to(device="cpu")
             # Check that our cross-device copy correctly copied the data to cpu
             self.assertTrue(z_cpu.is_cpu)
             self.assertFalse(z.is_cpu)

@@ -115,40 +118,45 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         def test_before_common_registration():
             # check that register module name should be the same as custom backend
             with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
-                torch._register_device_module('xxx', DummyModule)
+                torch._register_device_module("xxx", DummyModule)
             # check generator registered before using
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             with self.assertRaisesRegex(RuntimeError, "torch has no module of"):
                 with torch.random.fork_rng(device_type="foo"):
                     pass
             # check attributes before registered
-            self.assertFalse(hasattr(torch.Tensor, 'is_foo'))
-            self.assertFalse(hasattr(torch.Tensor, 'foo'))
-            self.assertFalse(hasattr(torch.TypedStorage, 'is_foo'))
-            self.assertFalse(hasattr(torch.TypedStorage, 'foo'))
-            self.assertFalse(hasattr(torch.UntypedStorage, 'is_foo'))
-            self.assertFalse(hasattr(torch.UntypedStorage, 'foo'))
-            self.assertFalse(hasattr(torch.nn.Module, 'foo'))
+            self.assertFalse(hasattr(torch.Tensor, "is_foo"))
+            self.assertFalse(hasattr(torch.Tensor, "foo"))
+            self.assertFalse(hasattr(torch.TypedStorage, "is_foo"))
+            self.assertFalse(hasattr(torch.TypedStorage, "foo"))
+            self.assertFalse(hasattr(torch.UntypedStorage, "is_foo"))
+            self.assertFalse(hasattr(torch.UntypedStorage, "foo"))
+            self.assertFalse(hasattr(torch.nn.Module, "foo"))

         def test_after_common_registration():
             # check attributes after registered
-            self.assertTrue(hasattr(torch.Tensor, 'is_foo'))
-            self.assertTrue(hasattr(torch.Tensor, 'foo'))
-            self.assertTrue(hasattr(torch.TypedStorage, 'is_foo'))
-            self.assertTrue(hasattr(torch.TypedStorage, 'foo'))
-            self.assertTrue(hasattr(torch.UntypedStorage, 'is_foo'))
-            self.assertTrue(hasattr(torch.UntypedStorage, 'foo'))
-            self.assertTrue(hasattr(torch.nn.Module, 'foo'))
+            self.assertTrue(hasattr(torch.Tensor, "is_foo"))
+            self.assertTrue(hasattr(torch.Tensor, "foo"))
+            self.assertTrue(hasattr(torch.TypedStorage, "is_foo"))
+            self.assertTrue(hasattr(torch.TypedStorage, "foo"))
+            self.assertTrue(hasattr(torch.UntypedStorage, "is_foo"))
+            self.assertTrue(hasattr(torch.UntypedStorage, "foo"))
+            self.assertTrue(hasattr(torch.nn.Module, "foo"))

         def test_common_registration():
             # first rename custom backend
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             # backend name can only rename once
-            with self.assertRaisesRegex(RuntimeError, "torch.register_privateuse1_backend()"):
-                torch.utils.rename_privateuse1_backend('xxx')
+            with self.assertRaisesRegex(
+                RuntimeError, "torch.register_privateuse1_backend()"
+            ):
+                torch.utils.rename_privateuse1_backend("xxx")
             # register foo module, torch.foo
-            torch._register_device_module('foo', DummyModule)
-            self.assertTrue(torch.utils.backend_registration._get_custom_mod_func("device_count")() == 1)
+            torch._register_device_module("foo", DummyModule)
+            self.assertTrue(
+                torch.utils.backend_registration._get_custom_mod_func("device_count")()
+                == 1
+            )
             with self.assertRaisesRegex(RuntimeError, "Try to call torch.foo"):
                 torch.utils.backend_registration._get_custom_mod_func("func_name_")
             # default set for_tensor and for_module are True, so only set for_storage is True

@@ -162,23 +170,29 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             # None of our CPU operations should call the custom add function.
             self.assertFalse(self.module.custom_add_called())
             # check generator registered before using
-            with self.assertRaisesRegex(RuntimeError,
-                                        "Please register a generator to the PrivateUse1 dispatch key"):
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Please register a generator to the PrivateUse1 dispatch key",
+            ):
                 gen_ = torch.Generator(device=device)
             self.module.register_generator_first()
             gen = torch.Generator(device=device)
             self.assertTrue(gen.device == device)
             # generator can be registered only once
-            with self.assertRaisesRegex(RuntimeError,
-                                        "Only can register a generator to the PrivateUse1 dispatch key once"):
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Only can register a generator to the PrivateUse1 dispatch key once",
+            ):
                 self.module.register_generator_second()
             self.module.register_hook()
             default_gen = self.module.default_generator(0)
-            self.assertTrue(default_gen.device.type == torch._C._get_privateuse1_backend_name())
+            self.assertTrue(
+                default_gen.device.type == torch._C._get_privateuse1_backend_name()
+            )

         def test_open_device_dispatchstub():
             # test kernels could be reused by privateuse1 backend through dispatchstub
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             input_data = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu")
             foo_input_data = input_data.to("foo")
             output_data = torch.abs(input_data)

@@ -202,10 +216,14 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             self.assertEqual(output_data, foo_output_data.cpu())

         def test_open_device_quantized():
-            torch.utils.rename_privateuse1_backend('foo')
-            input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
-            quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
-            self.assertEqual(quantized_tensor.device, torch.device('foo:0'))
+            torch.utils.rename_privateuse1_backend("foo")
+            input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to(
+                "foo"
+            )
+            quantized_tensor = torch.quantize_per_tensor(
+                input_data, 0.1, 10, torch.qint8
+            )
+            self.assertEqual(quantized_tensor.device, torch.device("foo:0"))
             self.assertEqual(quantized_tensor.dtype, torch.qint8)

         def test_open_device_random():

@@ -216,15 +234,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             device = self.module.custom_device()
             # check whether print tensor.type() meets the expectation
             dtypes = {
-                torch.bool: 'torch.foo.BoolTensor',
-                torch.double: 'torch.foo.DoubleTensor',
-                torch.float32: 'torch.foo.FloatTensor',
-                torch.half: 'torch.foo.HalfTensor',
-                torch.int32: 'torch.foo.IntTensor',
-                torch.int64: 'torch.foo.LongTensor',
-                torch.int8: 'torch.foo.CharTensor',
-                torch.short: 'torch.foo.ShortTensor',
-                torch.uint8: 'torch.foo.ByteTensor',
+                torch.bool: "torch.foo.BoolTensor",
+                torch.double: "torch.foo.DoubleTensor",
+                torch.float32: "torch.foo.FloatTensor",
+                torch.half: "torch.foo.HalfTensor",
+                torch.int32: "torch.foo.IntTensor",
+                torch.int64: "torch.foo.LongTensor",
+                torch.int8: "torch.foo.CharTensor",
+                torch.short: "torch.foo.ShortTensor",
+                torch.uint8: "torch.foo.ByteTensor",
             }
             for tt, dt in dtypes.items():
                 test_tensor = torch.empty(4, 4, dtype=tt, device=device)

@@ -284,9 +302,11 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             self.assertTrue(self.module.custom_storageImpl_called())

         def test_open_device_storage_pin_memory():
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
-                torch.utils.generate_methods_for_privateuse1_backend(for_tensor=False, for_module=False, for_storage=True)
+                torch.utils.generate_methods_for_privateuse1_backend(
+                    for_tensor=False, for_module=False, for_storage=True
+                )
             # Check if the pin_memory is functioning properly on custom device
             cpu_tensor = torch.empty(3)
             self.assertFalse(cpu_tensor.is_foo)

@@ -333,32 +353,42 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             self.assertFalse(cpu_untyped_storage_pinned.is_pinned())
             self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))
             self.assertTrue(cpu_untyped_storage_pinned.is_pinned(foo_device))
-            with self.assertRaisesRegex(TypeError, "positional arguments but 3 were given"):
+            with self.assertRaisesRegex(
+                TypeError, "positional arguments but 3 were given"
+            ):
                 cpu_untyped_storage_pinned.is_pinned("foo1", "foo2")

             # Test storage pin_memory on error device
             self.assertFalse(cpu_storage_pinned.is_pinned("hpu"))
-            with self.assertRaisesRegex(NotImplementedError, "with arguments from the 'HPU' backend"):
+            with self.assertRaisesRegex(
+                NotImplementedError, "with arguments from the 'HPU' backend"
+            ):
                 cpu_storage.pin_memory("hpu")
             self.assertFalse(cpu_untyped_storage_pinned.is_pinned("hpu"))
-            with self.assertRaisesRegex(NotImplementedError, "with arguments from the 'HPU' backend"):
+            with self.assertRaisesRegex(
+                NotImplementedError, "with arguments from the 'HPU' backend"
+            ):
                 cpu_untyped_storage.pin_memory("hpu")
             invalid_device = torch.device("hpu")
             self.assertFalse(cpu_untyped_storage_pinned.is_pinned(invalid_device))
-            with self.assertRaisesRegex(NotImplementedError, "with arguments from the 'HPU' backend"):
+            with self.assertRaisesRegex(
+                NotImplementedError, "with arguments from the 'HPU' backend"
+            ):
                 cpu_untyped_storage.pin_memory(invalid_device)

         def test_open_device_serialization():
             self.module.set_custom_device_index(-1)
-            storage = torch.UntypedStorage(4, device=torch.device('foo'))
-            self.assertEqual(torch.serialization.location_tag(storage), 'foo')
+            storage = torch.UntypedStorage(4, device=torch.device("foo"))
+            self.assertEqual(torch.serialization.location_tag(storage), "foo")

             self.module.set_custom_device_index(0)
-            storage = torch.UntypedStorage(4, device=torch.device('foo'))
-            self.assertEqual(torch.serialization.location_tag(storage), 'foo:0')
+            storage = torch.UntypedStorage(4, device=torch.device("foo"))
+            self.assertEqual(torch.serialization.location_tag(storage), "foo:0")

             cpu_storage = torch.empty(4, 4).storage()
-            foo_storage = torch.serialization.default_restore_location(cpu_storage, 'foo:0')
+            foo_storage = torch.serialization.default_restore_location(
+                cpu_storage, "foo:0"
+            )
             self.assertTrue(foo_storage.is_foo)
             # test tensor MetaData serialization
             x = torch.empty(4, 4).long()

@@ -369,7 +399,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):

             self.module.custom_serialization_registry()
             with tempfile.TemporaryDirectory() as tmpdir:
-                path = os.path.join(tmpdir, 'data.pt')
+                path = os.path.join(tmpdir, "data.pt")
                 torch.save(y, path)
                 z1 = torch.load(path)
                 # loads correctly onto the foo backend device

@@ -377,14 +407,14 @@ class TestCppExtensionOpenRgistration(common.TestCase):
                 # loads BackendMeta data correctly
                 self.assertTrue(self.module.check_backend_meta(z1))
                 # cross-backend
-                z2 = torch.load(path, map_location='cpu')
+                z2 = torch.load(path, map_location="cpu")
                 # loads correctly onto the cpu backend device
                 self.assertFalse(z2.is_foo)
                 # loads BackendMeta data correctly
                 self.assertFalse(self.module.check_backend_meta(z2))

         def test_open_device_storage_resize():
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             cpu_tensor = torch.randn([8])
             foo_tensor = cpu_tensor.foo()
             foo_storage = foo_tensor.storage()

@@ -392,11 +422,11 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             # Only register tensor resize_ function.
             foo_tensor.resize_(8)
             self.assertTrue(foo_storage.size() == 8)
-            with self.assertRaisesRegex(TypeError, 'Overflow'):
+            with self.assertRaisesRegex(TypeError, "Overflow"):
                 foo_tensor.resize_(8**29)

         def test_open_device_storage_type():
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             # test cpu float storage
             cpu_tensor = torch.randn([8]).float()
             cpu_storage = cpu_tensor.storage()

@@ -429,14 +459,14 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             torch.foo.FloatStorage = None

         def test_open_device_faketensor():
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             with torch._subclasses.fake_tensor.FakeTensorMode.push():
                 a = torch.empty(1, device="foo")
                 b = torch.empty(1, device="foo:0")
                 result = a + b

         def test_open_device_named_tensor():
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             a = torch.empty([2, 3, 4, 5], device="foo", names=["N", "C", "H", "W"])

         # Not an open registration test - this file is just very convenient

@@ -462,7 +492,9 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             out_ref.sum().backward()

             x_test = x_ref.clone().detach().requires_grad_(True)
-            f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
+            f_compiled = torch.compile(
+                torch.ops._test_funcs.custom_autograd_fn_aliasing
+            )
             out_test = f_compiled(x_test)
             out_test.sum().backward()

@@ -470,16 +502,18 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             self.assertEqual(x_ref.grad, x_test.grad)

         def test_open_device_scalar_type_fallback():
-            torch.utils.rename_privateuse1_backend('foo')
-            z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
-            z = torch.triu_indices(3, 3, device='foo')
+            torch.utils.rename_privateuse1_backend("foo")
+            z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(
+                torch.int64
+            )
+            z = torch.triu_indices(3, 3, device="foo")
             self.assertEqual(z_cpu, z)

         def test_open_device_tensor_type_fallback():
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             # create tensors located in custom device
-            x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to('foo')
-            y = torch.Tensor([1, 0, 2]).to('foo')
+            x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("foo")
+            y = torch.Tensor([1, 0, 2]).to("foo")
             # create result tensor located in cpu
             z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
             # Check that our device is correct.

@@ -491,14 +525,14 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             self.assertEqual(z_cpu, z)
             # call index op, which will fallback to cpu
             z_cpu = torch.Tensor([3, 1])
-            y = torch.Tensor([1, 0]).long().to('foo')
+            y = torch.Tensor([1, 0]).long().to("foo")
             z = x[y, y]
             self.assertEqual(z_cpu, z)

         def test_open_device_tensorlist_type_fallback():
-            torch.utils.rename_privateuse1_backend('foo')
+            torch.utils.rename_privateuse1_backend("foo")
             # create tensors located in custom device
-            v_foo = torch.Tensor([1, 2, 3]).to('foo')
+            v_foo = torch.Tensor([1, 2, 3]).to("foo")
             # create result tensor located in cpu
             z_cpu = torch.Tensor([2, 4, 6])
             # create tensorlist for foreach_add op