pytorch/torch/_dynamo/testing.py
Adnan Akhundov 809ff3b274 Add host-side Triton TMA support to Dynamo (#137677)
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of the time of writing, this is not part of the PT2 OSS Triton pin (although it has been back-ported internally). The OSS Triton pin update should be done in December 2024.
- To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), `TMADescriptorVariable` (for TMA descriptor object). This is to maintain the path back from the Triton kernel to the Tensor from which the TMA descriptor has been created.
- The newly introduced variables have `reconstruct` methods used in case of graph breaks.
- The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for what the captured HOP arguments look like.
- In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel by the underlying Tensors, to be able to track the input/output relationships in terms of Tensors.
- In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors, so that downstream AOTAutograd can perform functionalization as required.
- JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.

Differential Revision: [D64404928](https://our.internmc.facebook.com/intern/diff/D64404928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137677
Approved by: https://github.com/zou3519
2024-10-16 02:18:48 +00:00

477 lines
14 KiB
Python

import contextlib
import dis
import functools
import logging
import os.path
import random
import re
import sys
import types
import unittest
from typing import (
Any,
Callable,
Dict,
List,
Optional,
overload,
Sequence,
Tuple,
TypeVar,
Union,
)
from unittest.mock import patch
import torch
from torch import fx
from torch._dynamo.backends.debugging import aot_eager
from torch._dynamo.output_graph import OutputGraph
from . import config, eval_frame, optimize_assert, reset
from .bytecode_transformation import (
create_instruction,
debug_checks,
is_generator,
transform_code_object,
)
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .utils import same
np: Optional[types.ModuleType] = None
try:
import numpy as np
except ModuleNotFoundError:
np = None
# Module-level alias of eval_frame.unsupported.
unsupported = eval_frame.unsupported
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a module-level value for tests to read — confirm before removing.
three = 3
# Logger named after this module.
log = logging.getLogger(__name__)
def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return a detached copy of ``x`` that keeps its ``requires_grad`` flag.

    ``None`` is passed straight through.
    """
    if x is not None:
        duplicate = x.detach().clone()
        return duplicate.requires_grad_(x.requires_grad)
    return None
def remove_optimized_module_prefix(name: str) -> str:
    """Drop a leading ``_orig_mod.`` from ``name``; other names are returned as-is."""
    leading_prefix = r"^_orig_mod[.]"
    return re.sub(leading_prefix, "", name)
def collect_results(
    model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> List[Any]:
    """Gather the observable results of one training step into a flat list.

    The returned list contains, in order: ``prediction``, ``loss``, a dict of
    parameter gradients keyed by ``"<name>.grad"``, a dict of parameters, a
    dict of buffers, and finally the ``.grad`` of every tensor found in
    ``example_inputs`` (tuples and lists are searched one level deep).
    Parameter and buffer names have the optimized-module prefix removed when
    ``model`` is an ``eval_frame.OptimizedModule``.
    """

    def canonical(raw_name: str) -> str:
        # Only consulted per parameter/buffer, exactly like the inline check
        # it replaces.
        if isinstance(model, eval_frame.OptimizedModule):
            return remove_optimized_module_prefix(raw_name)
        return raw_name

    grads = {}
    params = {}
    for raw_name, param in model.named_parameters():
        key = canonical(raw_name)
        # Treat None and zero grad as same
        grads[key + ".grad"] = (
            torch.zeros_like(param) if param.grad is None else param.grad
        )
        params[key] = param

    buffers = {
        canonical(raw_name): buffer for raw_name, buffer in model.named_buffers()
    }

    results: List[Any] = [prediction, loss, grads, params, buffers]
    for example in example_inputs:
        # A bare input is handled like a one-element group.
        group = example if isinstance(example, (tuple, list)) else (example,)
        results.extend(
            item.grad for item in group if isinstance(item, torch.Tensor)
        )
    return results
def requires_bwd_pass(out: Any) -> bool:
    """Tell whether ``out`` contains a tensor that requires grad.

    A tensor reports its own ``requires_grad``; lists and tuples are searched
    recursively; ``None`` and ints never need a backward pass.

    Raises:
        NotImplementedError: if ``out`` is of any other type.
    """
    if isinstance(out, torch.Tensor):
        return out.requires_grad
    if isinstance(out, (list, tuple)):
        for item in out:
            if requires_bwd_pass(item):
                return True
        return False
    if out is None or isinstance(out, int):
        return False
    raise NotImplementedError("Don't know how to reduce", type(out))
@overload
def reduce_to_scalar_loss(out: torch.Tensor) -> torch.Tensor:
    ...


@overload
def reduce_to_scalar_loss(
    out: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]
) -> float:
    ...


def reduce_to_scalar_loss(out: Any) -> Union[torch.Tensor, float]:
    """Collapse a model output into a single scalar usable as a loss.

    Tensors are averaged; lists, tuples and dict values are reduced
    recursively and then averaged; objects whose class is named
    ``MaskedLMOutput``, ``Seq2SeqLMOutput`` or
    ``CausalLMOutputWithCrossAttentions`` are reduced through ``.logits``;
    objects whose class is named ``SquashedNormal`` through ``.mean.sum()``.

    Raises:
        NotImplementedError: if ``out`` matches none of the cases above.
    """
    if isinstance(out, torch.Tensor):
        # Mean does not work on integer tensors
        return out.sum() / out.numel()
    if isinstance(out, (list, tuple)):
        total = sum(reduce_to_scalar_loss(item) for item in out)
        return total / len(out)

    # The class-name checks come before the dict check, as in the original
    # branch order.
    kind = type(out).__name__
    if kind in (
        "MaskedLMOutput",
        "Seq2SeqLMOutput",
        "CausalLMOutputWithCrossAttentions",
    ):
        return reduce_to_scalar_loss(out.logits)
    if kind == "SquashedNormal":
        return out.mean.sum()
    if isinstance(out, dict):
        total = sum(reduce_to_scalar_loss(value) for value in out.values())
        return total / len(out.keys())
    raise NotImplementedError("Don't know how to reduce", type(out))
def debug_dir() -> str:
    """Return the path of the ``debug`` directory, creating it when missing.

    The directory is ``../debug`` relative to the directory containing this
    file.

    Returns:
        The (non-normalized) path of the directory.
    """
    path = os.path.join(os.path.dirname(__file__), "../debug")
    # exist_ok=True removes the window between "does it exist?" and "create
    # it" in which a concurrent caller could create the directory and make
    # os.mkdir raise FileExistsError.
    os.makedirs(path, exist_ok=True)
    return path
def debug_dump(name: str, code: types.CodeType, extra: str = "") -> None:
    """Write a disassembly report for ``code`` to the file ``name`` in ``debug_dir()``.

    The file holds the code object's info summary, its disassembly and then
    ``extra``, each followed by a blank line (the last by a single newline).
    """
    target = os.path.join(debug_dir(), name)
    with open(target, "w") as fd:
        summary = dis.Bytecode(code).info()
        listing = dis.Bytecode(code).dis()
        fd.write(f"{summary}\n\n{listing}\n\n{extra}\n")
def debug_insert_nops(
    frame: types.FrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0
) -> Optional[GuardedCode]:
    """Frame callback used to debug jump updates.

    Rewrites the frame's code object so that it starts with two extra ``NOP``
    instructions and returns the result wrapped in a ``GuardedCode``.
    Returns ``None`` without transforming anything when the frame's code
    object is a generator.

    ``cache_size``, ``hooks``, ``_`` and ``skip`` are accepted but not used in
    the body.
    """

    def insert_nops(instructions: List[Any], code_options: Any) -> None:
        # Prepend two NOPs to the instruction list; code_options is untouched.
        instructions.insert(0, create_instruction("NOP"))
        instructions.insert(0, create_instruction("NOP"))

    if is_generator(frame.f_code):
        return None

    debug_checks(frame.f_code)
    code = transform_code_object(frame.f_code, insert_nops)
    # This graph is only used to construct the CheckFunctionManager below.
    graph = OutputGraph(
        code_options={},
        compiler_fn=None,
        root_tx=None,
        export=False,
        export_constraints=None,
        frame_state={"_id": 0},
        # TODO: shouldn't this be f_locals/f_globals from frame?
        # NOTE(review): locals() is this function's own local namespace, so
        # renaming or adding locals above changes what is passed here.
        local_scope=locals(),
        global_scope=globals(),
        f_code=frame.f_code,
        torch_function_mode_stack=[],
    )

    return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
class CompileCounter:
    """Backend callable that counts compiled graphs and the call nodes in them.

    ``frame_count`` goes up by one per invocation; ``op_count`` goes up by the
    number of graph nodes whose ``op`` contains ``"call"``.
    """

    def __init__(self) -> None:
        self.frame_count = self.op_count = 0

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        """Record ``gm`` in the counters and return ``gm.forward`` unchanged."""
        self.frame_count += 1
        self.op_count += sum(1 for node in gm.graph.nodes if "call" in node.op)
        return gm.forward

    def clear(self) -> None:
        """Reset both counters to zero."""
        self.frame_count = self.op_count = 0
class CompileCounterWithBackend:
    """Counting backend that records each graph and delegates to a named backend.

    Every graph module received is appended to ``graphs``; ``frame_count`` and
    ``op_count`` are updated as in ``CompileCounter``; compilation itself is
    done by the backend looked up under ``backend``.
    """

    def __init__(self, backend: str) -> None:
        self.backend = backend
        self.frame_count = 0
        self.op_count = 0
        self.graphs: List[torch.fx.GraphModule] = []

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        """Update the counters, record ``gm`` and compile it with the backend."""
        from .backends.registry import lookup_backend

        self.frame_count += 1
        self.op_count += sum(1 for node in gm.graph.nodes if "call" in node.op)
        self.graphs.append(gm)
        compile_fn = lookup_backend(self.backend)
        return compile_fn(gm, example_inputs)
class EagerAndRecordGraphs:
    """Equivalent to ``backend="eager"``, but also records graphs.

    Each graph module received is appended to ``graphs`` so tests can assert
    on it afterwards.
    """

    def __init__(self) -> None:
        self.graphs: List[torch.fx.GraphModule] = []

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        """Remember ``gm`` and return ``gm.forward`` unchanged."""
        recorded = self.graphs
        recorded.append(gm)
        return gm.forward
class AOTEagerAndRecordGraphs:
    """Equivalent to ``backend="aot_eager"``, but also records graphs.

    The graph modules passed to the forward and backward compilers are
    appended to ``graphs`` so tests can assert on them afterwards.
    """

    def __init__(self) -> None:
        self.graphs: List[torch.fx.GraphModule] = []

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        """Run ``aot_eager`` with compilers that record each graph they see."""

        def save_graph(gm: torch.fx.GraphModule, *args: Any, **kwargs: Any) -> Any:
            self.graphs.append(gm)
            return gm.forward

        # The same recorder serves as both the forward and backward compiler.
        compilers = {"fw_compiler": save_graph, "bw_compiler": save_graph}
        return aot_eager(gm, example_inputs, **compilers)
def strip_comment(code: str) -> str:
    """Delete every line that consists only of optional spaces and a ``#`` comment.

    The line's trailing newline is removed with it; comments that follow code
    on the same line are left alone.
    """
    comment_only_line = r"(?m)^ *#.*\n?"
    return re.sub(comment_only_line, "", code)
def remove_trailing_space(code: str) -> str:
    """Strip trailing whitespace from every ``\\n``-separated line of ``code``."""
    trimmed = []
    for line in code.split("\n"):
        trimmed.append(line.rstrip())
    return "\n".join(trimmed)
def normalize_gm(gm_str: str) -> str:
    """Normalize printed graph-module text so it compares equal across systems."""
    # Comment lines carry file paths, which may differ from system to
    # system, so they are dropped first.
    without_comments = strip_comment(gm_str)
    return remove_trailing_space(without_comments)
def empty_line_normalizer(code: str) -> str:
    """Normalize code by collapsing each run of ``\\r``/``\\n`` into one ``\\n``.

    This removes empty lines.
    """
    return re.sub(r"[\r\n]+", "\n", code)
def standard_test(
    self: Any,
    fn: Callable[..., Any],
    nargs: int,
    expected_ops: Optional[int] = None,
    expected_ops_dynamic: Optional[int] = None,
    expected_frame_count: int = 1,
) -> None:
    """Check that the optimized ``fn`` matches eager ``fn`` and compiles as expected.

    ``fn`` is called eagerly and through ``optimize_assert`` on two sets of
    ``nargs`` random 10x10 tensors (each set twice through the optimized
    function).  Results must match the eager ones, the number of compiled
    frames must equal ``expected_frame_count`` and, when an expectation is
    given, the op count must equal it.  ``expected_ops_dynamic`` replaces
    ``expected_ops`` when ``config.assume_static_by_default`` is off.
    """
    if expected_ops_dynamic is not None and not config.assume_static_by_default:
        expected_ops = expected_ops_dynamic

    counter = CompileCounter()

    def fresh_args() -> List[torch.Tensor]:
        return [torch.randn(10, 10) for _ in range(nargs)]

    args1 = fresh_args()
    args2 = fresh_args()
    correct1 = fn(*args1)
    correct2 = fn(*args2)

    reset()
    opt_fn = optimize_assert(counter)(fn)
    val1a = opt_fn(*args1)
    val2a = opt_fn(*args2)
    val1b = opt_fn(*args1)
    val2b = opt_fn(*args2)
    reset()

    comparisons = (
        (val1a, correct1),
        (val1b, correct1),
        (val2a, correct2),
        (val2b, correct2),
    )
    for got, want in comparisons:
        self.assertTrue(same(got, want))
    self.assertEqual(counter.frame_count, expected_frame_count)
    if expected_ops is not None:
        self.assertEqual(counter.op_count, expected_ops)
def dummy_fx_compile(
    gm: fx.GraphModule, example_inputs: List[torch.Tensor]
) -> Callable[..., Any]:
    """No-op compiler: return ``gm.forward``; ``example_inputs`` is ignored."""
    forward = gm.forward
    return forward
def format_speedup(
    speedup: float,
    pvalue: float,
    is_correct: bool = True,
    pvalue_threshold: float = 0.1,
) -> str:
    """Format a benchmark speedup for display.

    Returns ``"ERROR"`` when ``is_correct`` is false.  Otherwise the speedup
    is printed with three decimals, followed by ``SAME`` when ``pvalue``
    exceeds ``pvalue_threshold`` and by the p-value (two decimals) when not.
    """
    if is_correct:
        if pvalue > pvalue_threshold:
            return f"{speedup:.3f}x SAME"
        return f"{speedup:.3f}x p={pvalue:.2f}"
    return "ERROR"
def rand_strided(
size: Sequence[int],
stride: Sequence[int],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
extra_size: int = 0,
) -> torch.Tensor:
needed_size = (
sum((shape - 1) * stride for shape, stride in zip(size, stride))
+ 1
+ extra_size
)
if dtype.is_floating_point:
if dtype.itemsize == 1:
"""
normal distribution kernel is not implemented for fp8..
Workaround that by creating a fp16 tensor and then cast.
"""
buffer = torch.randn(needed_size, dtype=torch.float16, device=device).to(
dtype=dtype
)
else:
buffer = torch.randn(needed_size, dtype=dtype, device=device)
else:
buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device)
return torch.as_strided(buffer, size, stride)
_T = TypeVar("_T")


def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]:
    """Wrap ``fn`` so that every call runs with the given patches applied.

    Each entry of ``patches`` is a ``(target, attribute_name, value)`` triple
    handed to ``unittest.mock.patch.object``.  The patches are entered in
    order before ``fn`` is called and undone when the call finishes.
    """

    @functools.wraps(fn)
    def _patched(*args: Any, **kwargs: Any) -> _T:
        with contextlib.ExitStack() as active:
            for target, attribute, replacement in patches:
                active.enter_context(patch.object(target, attribute, replacement))
            return fn(*args, **kwargs)

    return _patched
def make_test_cls_with_patches(
    cls: type,
    cls_prefix: str,
    fn_suffix: str,
    *patches: Any,
    xfail_prop: Optional[str] = None,
    decorator: Callable[[Callable[..., Any]], Callable[..., Any]] = lambda x: x,
) -> type:
    """Build a copy of test class ``cls`` whose tests run under ``patches``.

    The new class is named ``cls_prefix + cls.__name__`` and shares the bases
    of ``cls``.  Every callable attribute named ``test_*`` is wrapped so it
    runs with ``patches`` applied, renamed with ``fn_suffix`` appended, marked
    as an expected failure when the original carries the ``xfail_prop``
    attribute, and finally passed through ``decorator``.  Non-callable
    ``test_*`` attributes are copied under their own name; any other
    attribute is copied only when the new class does not already have it.
    """
    patched_cls = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    patched_cls.__qualname__ = patched_cls.__name__

    for name in dir(cls):
        if not name.startswith("test_"):
            # NB: Doesn't handle slots correctly, but whatever
            if not hasattr(patched_cls, name):
                setattr(patched_cls, name, getattr(cls, name))
            continue

        member = getattr(cls, name)
        if not callable(member):
            setattr(patched_cls, name, getattr(cls, name))
            continue

        new_name = f"{name}{fn_suffix}"
        new_fn = _make_fn_with_patches(member, *patches)
        new_fn.__name__ = new_name
        if xfail_prop is not None and hasattr(member, xfail_prop):
            new_fn = unittest.expectedFailure(new_fn)
        setattr(patched_cls, new_name, decorator(new_fn))

    return patched_cls
# test Python 3.11+ specific features
def skipIfNotPy311(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Skip ``fn`` unless running on Python 3.11 or newer.

    Args:
        fn: the test function to decorate.

    Returns:
        ``fn`` itself on Python 3.11+, otherwise a version of it that
        unittest reports as skipped with the reason "Requires Python 3.11+".
    """
    if sys.version_info >= (3, 11):
        return fn
    # unittest.skip takes the reason and returns the decorator.  Passing the
    # function directly leaves the skip without a reason and differs from
    # how the sibling skip helpers in this file report theirs.
    return unittest.skip("Requires Python 3.11+")(fn)
def skipIfNotPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Skip ``fn`` on Python older than 3.12; return it untouched otherwise."""
    if sys.version_info < (3, 12):
        return unittest.skip("Requires Python 3.12+")(fn)
    return fn
def xfailIfPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Mark ``fn`` as an expected failure on Python 3.12+; otherwise return it as-is."""
    return unittest.expectedFailure(fn) if sys.version_info >= (3, 12) else fn
def skipIfPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Skip ``fn`` on Python 3.12 and newer; return it untouched on older versions."""
    on_py312_or_newer = sys.version_info >= (3, 12)
    return unittest.skip("Not supported in Python 3.12+")(fn) if on_py312_or_newer else fn
def requiresPy310(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Skip ``fn`` on Python older than 3.10; return it untouched otherwise."""
    if sys.version_info < (3, 10):
        return unittest.skip("Requires Python 3.10+")(fn)
    return fn
# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Tag ``fn`` with ``_expected_failure_dynamic = True`` and return it."""
    setattr(fn, "_expected_failure_dynamic", True)
    return fn
# Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py
def expectedFailureCodegenDynamic(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Tag ``fn`` with ``_expected_failure_codegen_dynamic = True`` and return it."""
    setattr(fn, "_expected_failure_codegen_dynamic", True)
    return fn
# Controls test generated in test/inductor/test_cpp_wrapper.py
def expectedFailureDynamicWrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Tag ``fn`` with ``_expected_failure_dynamic_wrapper = True`` and return it."""
    setattr(fn, "_expected_failure_dynamic_wrapper", True)
    return fn
def reset_rng_state(use_xla: bool = False) -> None:
    """Reseed the random number generators with the fixed seed 1337.

    Seeds torch and the stdlib ``random`` module, numpy when it was
    importable, and — when ``use_xla`` is true — the RNG state of the XLA
    device through ``torch_xla``.
    """
    seed = 1337
    torch.manual_seed(seed)
    random.seed(seed)
    if np:
        np.random.seed(seed)
    if not use_xla:
        return

    # Imported lazily: torch_xla is only needed for XLA runs.
    import torch_xla.core.xla_model as xm

    xm.set_rng_state(seed, str(xm.xla_device()))