# Summary

This PR adds a new higher-order op: `templated_attention`. This op is designed to extend the functionality of `torch.nn.functional.scaled_dot_product_attention`. PyTorch ships efficient pre-written fused-attention kernels, but users who want to modify how scores are computed (a substep inside attention) have traditionally had to write their own attention kernel. One such modification to attention scores that is not currently supported by the top-level SDPA op is [Attention with Linear Biases (ALiBi)](https://arxiv.org/abs/2108.12409). This higher-order op instead accepts a callable (`score_mod`) which, through torch.compile, is used to instantiate an efficient attention kernel.

### Details

This HOP uses the existing FX and HOP infrastructure to capture the user's `score_mod` function and convert it to an FX graph module. Inductor then consumes this HOP, which carries an `ir.Subgraph` input, and inlines the lowered subgraph into a Triton kernel that performs fused attention with the modification to the scores matrix applied in place.

### API

A `score_mod` function should have the following signature:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

This function receives five parameters:
- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `token_q`, `token_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shape (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16). The score_mod function is vectorized over every element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, 8th query element, and 9th key element would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples

```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Lets create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.
def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Lets call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```
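Since ALiBi is the motivating example above, here is a minimal sketch of what an ALiBi-style `score_mod` could look like with this API. The per-head slope below is a simplified assumption (real ALiBi derives a geometric series of slopes from the head count), so treat this as an illustration rather than a reference implementation:

```Python
# A minimal ALiBi-style score_mod sketch; the slope formula is a simplified
# assumption and not part of this PR.
def alibi_bias(score, batch, head, token_q, token_kv):
    # Penalize attention scores in proportion to the query/key distance,
    # with a per-head scale.
    slope = 1.0 / (2 ** (head + 1))
    return score - slope * torch.abs(token_q - token_kv)

alibi_out = templated_attention(query, key, value, score_mod=alibi_bias)
```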
### Future Work

- This PR currently supports the forward pass only. However, a Triton kernel for the backward pass, for score modifications that do not rely on external buffers, has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel improvements: there have been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top-level SDPA API; we leave that as a follow-up once this is more stable.
- Should we error on CPU?
- There are some issues with dynamic shapes.
- Capturing free variables and lifting them to inputs of the subgraph is not working correctly today.

### Performance

Comparisons generated by this benchmark:

| Type    | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod     | dtype          |
|---------|---------|------------|-----------|-----------|-----------|----------|---------------|----------------|
| Average | 5.412   |            |           |           |           |          |               |                |
| Max     | 8.882   | 16         | 16        | 4096      | 4096      | 64       | relative_bias | torch.bfloat16 |
| Min     | 3.645   | 8          | 16        | 512       | 512       | 64       | causal_mask   | torch.bfloat16 |
| Min     | 0.345   | 1          | 16        | 1024      | 1024      | 64       | pathological  | torch.bfloat16 |

For reference:

| Configuration | Forward Time (µs) | Backend | Speedup |
|---------------|-------------------|---------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608 | Templated Attention | 1.0 |
| Compiled SDPA (No Mask) | 9928 | Math | 2.75x |
| Compiled SDPA (With Mask) | 11898 | Math | 3.29x |
| Compiled SDPA (With Mask) | 8704 | Memory Efficient Attention | 2.42x |
| Compiled SDPA (No Mask) | 2548 | FlashAttention2 | 0.706x |

The speedups measure compiled templated attention against different calls to torch.nn.functional.scaled_dot_product_attention.
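The `score_mod` variants named in the sweep (`causal_mask`, `relative_bias`, `head_bias`, `pathological`) are not defined in this description. As a hedged illustration only, the first two might look roughly like the following under the API above; these definitions are assumptions, not the benchmark's actual code:

```Python
# Hypothetical sketches of two of the benchmarked score_mod names; the real
# benchmark definitions are not included in this description.
def causal_mask(score, batch, head, token_q, token_kv):
    # Keep the score only where the key position does not exceed the query position.
    return torch.where(token_q >= token_kv, score, float("-inf"))

def relative_bias(score, batch, head, token_q, token_kv):
    # Bias the score by the signed distance between key and query positions.
    return score + (token_kv - token_q)
```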
<details>
<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

| batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | eager_time (µs) | compiled_time (µs) | speedup |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 331.444 | 67.221 | 4.931 |
| 1 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 335.300 | 64.187 | 5.224 |
| 1 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 352.039 | 63.806 | 5.517 |
| 1 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 371.699 | 711.349 | 0.523 |
| 1 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 333.488 | 86.455 | 3.857 |
| 1 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 322.363 | 82.469 | 3.909 |
| 1 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 349.967 | 82.233 | 4.256 |
| 1 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 486.359 | 1412.453 | 0.344 |
| 1 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 2794.597 | 551.188 | 5.070 |
| 1 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 3965.150 | 513.101 | 7.728 |
| 1 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 2408.013 | 504.759 | 4.771 |
| 1 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 6850.531 | 16733.675 | 0.409 |
| 8 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 441.939 | 123.576 | 3.576 |
| 8 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 560.379 | 116.710 | 4.801 |
| 8 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 421.172 | 115.825 | 3.636 |
| 8 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 994.492 | 2132.806 | 0.466 |
| 8 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 1436.430 | 309.495 | 4.641 |
| 8 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 1892.216 | 290.186 | 6.521 |
| 8 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 1360.665 | 282.956 | 4.809 |
| 8 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 3525.532 | 8359.702 | 0.422 |
| 8 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 22026.839 | 3864.604 | 5.700 |
| 8 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 31262.746 | 3609.551 | 8.661 |
| 8 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 20219.079 | 3480.402 | 5.809 |
| 8 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 54654.647 | 116652.357 | 0.469 |
| 16 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 820.606 | 188.683 | 4.349 |
| 16 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 1058.362 | 179.295 | 5.903 |
| 16 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 784.372 | 175.714 | 4.464 |
| 16 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 1890.792 | 4212.877 | 0.449 |
| 16 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 2781.830 | 557.017 | 4.994 |
| 16 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 3694.050 | 525.249 | 7.033 |
| 16 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 2634.164 | 507.613 | 5.189 |
| 16 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 6959.917 | 15331.116 | 0.454 |
| 16 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 43889.096 | 7582.018 | 5.789 |
| 16 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 62784.293 | 7075.846 | 8.873 |
| 16 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 40308.606 | 6829.587 | 5.902 |
| 16 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 108892.137 | 233090.953 | 0.467 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
# mypy: ignore-errors

r"""This file is allowed to initialize CUDA context when imported."""

import functools

import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
import inspect
import contextlib
import os


CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()


TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
if TEST_WITH_ROCM:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
else:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))

TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)

SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))

def evaluate_gfx_arch_exact(matching_arch):
    if not torch.cuda.is_available():
        return False
    gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
    arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
    return arch == matching_arch

GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))

def evaluate_platform_supports_flash_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return not IS_WINDOWS and SM80OrLater
    return False

PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM)
# TODO(eqy): gate this against a cuDNN version
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM and
                                                  torch.backends.cuda.cudnn_sdp_enabled())
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)

PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM

PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)

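# Usage sketch (an assumed, typical caller pattern; the test name is hypothetical):
#     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash attention not supported on this platform")
#     def test_flash_attention(self, device, dtype):
#         ...
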
if TEST_NUMBA:
    try:
        import numba.cuda
        TEST_NUMBA_CUDA = numba.cuda.is_available()
    except Exception as e:
        TEST_NUMBA_CUDA = False
        TEST_NUMBA = False
else:
    TEST_NUMBA_CUDA = False

# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False


# after this call, CUDA context and RNG must have been initialized on each GPU
def initialize_cuda_context_rng():
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # initialize cuda context and rng for memory tests
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device=f"cuda:{i}")
        __cuda_ctx_rng_initialized = True

# Test whether hardware TF32 math mode is enabled. It is enabled only on:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split('.')[0]) < 11:
        return False
    return True


@contextlib.contextmanager
def tf32_off():
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precision

# This is a wrapper that wraps a test to run it twice: once with
# allow_tf32=True and once with allow_tf32=False. When running with
# allow_tf32=True, it will use reduced precision as specified by the
# argument. For example:
#     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
#     @tf32_on_and_off(0.005)
#     def test_matmul(self, device, dtype):
#         a = ...; b = ...;
#         c = torch.matmul(a, b)
#         self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA
# with a CUDA >= 11 build on an >= Ampere architecture, the matmul runs both with
# TF32 mode on and with TF32 mode off; when TF32 mode is on, the assertEqual uses
# reduced precision to check values.
#
# This decorator can be used for functions with or without device/dtype, such as:
#     @tf32_on_and_off(0.005)
#     def test_my_op(self)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, device)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, device, dtype)
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, dtype)
# If neither device nor dtype is specified, it checks whether the system has an Ampere device;
# if device is specified, it checks whether the device is CUDA;
# if dtype is specified, it checks whether the dtype is float32 or complex64.
# TF32 and fp32 differ only when all three checks pass.
def tf32_on_and_off(tf32_precision=1e-5):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            cond = tf32_is_not_fp32()
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
            if cond:
                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped
    return wrapper


# This is a wrapper that wraps a test to run it with TF32 turned off.
# This wrapper is designed to be used when a test uses matmul or convolutions
# but the purpose of that test is not testing matmul or convolutions.
# Disabling TF32 will enforce torch.float tensors to be always computed
# at full precision.
def with_tf32_off(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with tf32_off():
            return f(*args, **kwargs)

    return wrapped

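# Usage sketch (the decorated test below is hypothetical):
#     @with_tf32_off
#     def test_op_that_happens_to_use_matmul(self, device, dtype):
#         ...  # matmuls/convolutions inside run at full float32 precision
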
def _get_magma_version():
    if 'Magma' not in torch.__config__.show():
        return (0, 0)
    position = torch.__config__.show().find('Magma ')
    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
    return tuple(int(x) for x in version_str.split("."))

def _get_torch_cuda_version():
    if torch.version.cuda is None:
        return (0, 0)
    cuda_version = str(torch.version.cuda)
    return tuple(int(x) for x in cuda_version.split("."))

def _get_torch_rocm_version():
    if not TEST_WITH_ROCM:
        return (0, 0)
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]  # ignore git sha
    return tuple(int(x) for x in rocm_version.split("."))

def _check_cusparse_generic_available():
    return not TEST_WITH_ROCM

def _check_hipsparse_generic_available():
    if not TEST_WITH_ROCM:
        return False

    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]  # ignore git sha
    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
    return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))


TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()

# Shared by test_torch.py and test_multigpu.py
def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    # Create a module+optimizer that will use scaling, and a control module+optimizer
    # that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
    mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    with torch.no_grad():
        for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
            s.copy_(c)

    kwargs = {"lr": 1.0}
    if optimizer_kwargs is not None:
        kwargs.update(optimizer_kwargs)
    opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
    opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)

    return mod_control, mod_scaling, opt_control, opt_scaling

# Shared by test_torch.py, test_cuda.py and test_multigpu.py
def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]

    loss_fn = torch.nn.MSELoss().to(device)

    skip_iter = 2

    return _create_scaling_models_optimizers(
        device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
    ) + (data, loss_fn, skip_iter)

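# The returned tuple unpacks as (mirroring the return statements above):
#     mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = _create_scaling_case()
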
# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
    assert not torch.cuda.is_initialized()