[Code Clean] Clean asserts in torch/autograd. (#165627)
Replaces 78 assert statements across 10 files in torch.autograd with explicit if-checks that raise AssertionError, so the checks cannot be stripped by Python's -O flag. This keeps error checking active in optimized builds.

Partially fixes #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165627
Approved by: https://github.com/albanD
This commit is contained in:
parent 1bcd736f91
commit 850ba8c96d
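Every change in the diff below follows the same pattern: a bare assert is replaced by an explicit check that raises AssertionError, because CPython removes assert statements entirely when run with -O (or PYTHONOPTIMIZE). A minimal sketch of the before/after pattern; the function and variable names here are illustrative, not taken from the diff:

import torch

def check_output_before(out):
    # Stripped under `python -O`: the whole assert disappears in optimized mode.
    assert isinstance(out, torch.Tensor)

def check_output_after(out):
    # Survives `python -O`: the check is ordinary code, not an assert.
    if not isinstance(out, torch.Tensor):
        raise AssertionError("Expected output to be a torch.Tensor")

Raising AssertionError (rather than, say, TypeError) keeps the exception type callers observe identical to what the original asserts produced.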
@@ -113,7 +113,8 @@ def _make_grads(
 # circular import
 from torch.nested._internal.nested_tensor import NestedTensor

-assert isinstance(out, torch.Tensor)
+if not isinstance(out, torch.Tensor):
+    raise AssertionError("Expected output to be a torch.Tensor")
 out_dtype = out.dtype
 out_is_nested = out.is_nested
 out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
@@ -129,13 +130,15 @@ def _make_grads(
 # singleton int to represent jagged dimension, so that size() call
 # on nested tensor works.
 if out_is_cpp_nested:
-    assert isinstance(out, torch.Tensor)
+    if not isinstance(out, torch.Tensor):
+        raise AssertionError("Expected output to be a torch.Tensor.")
     shape_matches = torch.is_same_size(out, first_grad)
 else:
     # We need to do a regular size check, without going through
     # the operator, to be able to handle unbacked symints
     # (expect_true ensures we can deal with unbacked)
-    assert out_size is not None
+    if out_size is None:
+        raise AssertionError("Expected out_size to be set.")
     shape_matches = expect_true(sym_eq(out_size, first_grad.size()))

 if not shape_matches:
@@ -191,10 +194,12 @@ def _make_grads(
 elif grad is None:
     if isinstance(out, graph.GradientEdge) or out.requires_grad: # type: ignore[attr-defined]
         if isinstance(out, graph.GradientEdge):
-            assert out_size is not None
+            if out_size is None:
+                raise AssertionError("Expected out_size to be set.")
             out_numel_is_1 = all(o == 1 for o in out_size)
         else:
-            assert isinstance(out, torch.Tensor)
+            if not isinstance(out, torch.Tensor):
+                raise AssertionError("Expected output to be a torch.Tensor")
             out_numel_is_1 = out.numel() == 1
         if not out_numel_is_1:
             raise RuntimeError(
@@ -207,8 +212,10 @@ def _make_grads(
     )
 raise RuntimeError(msg)
 if isinstance(out, graph.GradientEdge):
-    assert out_size is not None
-    assert out_device is not None
+    if out_size is None:
+        raise AssertionError("Expected out_size to be set.")
+    if out_device is None:
+        raise AssertionError("Expected out_device to be set.")
     new_grads.append(
         torch.ones(
             out_size,
@@ -217,7 +224,8 @@ def _make_grads(
         )
     )
 else:
-    assert isinstance(out, torch.Tensor)
+    if not isinstance(out, torch.Tensor):
+        raise AssertionError("Expected output to be a torch.Tensor")
     new_grads.append(
         torch.ones_like(out, memory_format=torch.preserve_format)
     )

@@ -65,5 +65,8 @@ class Resize(Function):
 @staticmethod
 # pyrefly: ignore # bad-override
 def backward(ctx, grad_output):
-    assert grad_output.numel() == ctx.numel
+    if grad_output.numel() != ctx.numel:
+        raise AssertionError(
+            f"Expected grad_output to have {ctx.numel} elements, but got {grad_output.numel()}"
+        )
     return grad_output.contiguous().view(ctx.input_sizes), None

@@ -146,7 +146,8 @@ class FunctionCtx:

 """
 for tensor in tensors:
-    assert isinstance(tensor, torch.Tensor) or tensor is None, (
+    if not (isinstance(tensor, torch.Tensor) or tensor is None):
+        raise AssertionError(
             "save_for_forward expects all arguments to be tensors; you should "
             "save non-tensors as attributes on ctx."
         )

@@ -54,7 +54,8 @@ def _tuple_postprocess(res, to_unpack):
 # - invert _as_tuple when res should match the inp given to _as_tuple
 # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
 if isinstance(to_unpack, tuple):
-    assert len(to_unpack) == 2
+    if len(to_unpack) != 2:
+        raise AssertionError("Expected to_unpack tuple to have exactly 2 elements")
     if not to_unpack[1]:
         res = tuple(el[0] for el in res)
     if not to_unpack[0]:
@@ -174,11 +175,17 @@ def _autograd_grad(
 ):
     # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them.
     # This has the extra constraint that inputs has to be a tuple
-    assert isinstance(outputs, tuple)
+    if not isinstance(outputs, tuple):
+        raise AssertionError("Expected outputs to be a tuple")
     if grad_outputs is None:
         grad_outputs = (None,) * len(outputs)
-    assert isinstance(grad_outputs, tuple)
-    assert len(outputs) == len(grad_outputs)
+    if not isinstance(grad_outputs, tuple):
+        raise AssertionError("Expected grad_outputs to be a tuple")
+    if len(outputs) != len(grad_outputs):
+        raise AssertionError(
+            f"Expected outputs and grad_outputs to have the same length, "
+            f"but got {len(outputs)} and {len(grad_outputs)}"
+        )

     new_outputs: tuple[torch.Tensor, ...] = ()
     new_grad_outputs: tuple[torch.Tensor, ...] = ()
@@ -489,8 +496,13 @@ def _construct_standard_basis_for(
 # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
 # for context behind this function. All the pre-conditions are guarded for
 # in torch.autograd.functional.jacobian.
-assert len(tensors) == len(tensor_numels)
-assert len(tensors) > 0
+if len(tensors) != len(tensor_numels):
+    raise AssertionError(
+        f"Expected tensors and tensor_numels to have the same length, "
+        f"but got {len(tensors)} and {len(tensor_numels)}"
+    )
+if len(tensors) == 0:
+    raise AssertionError("Expected at least one tensor")
 total_numel = sum(tensor_numels)
 chunks = tuple(
     tensor.new_zeros(total_numel, tensor_numel)
@@ -664,7 +676,8 @@ def jacobian(
 >>> jac.shape
 torch.Size([4, 2, 4, 2])
 """
-assert strategy in ("forward-mode", "reverse-mode"), (
+if strategy not in ("forward-mode", "reverse-mode"):
+    raise AssertionError(
         'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
         'function has more outputs than inputs, "forward-mode" tends to be more performant. '
         'Otherwise, prefer to use "reverse-mode".'
@@ -932,10 +945,13 @@ def hessian(
 [0., 6.]])))
 """
 is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
-assert outer_jacobian_strategy in (
+if outer_jacobian_strategy not in (
     "forward-mode",
     "reverse-mode",
-), 'Expected strategy to be either "forward-mode" or "reverse-mode".'
+):
+    raise AssertionError(
+        'Expected strategy to be either "forward-mode" or "reverse-mode".'
+    )

 def ensure_single_output_function(*inp):
     out = func(*inp)

@@ -408,7 +408,8 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):

 def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
     self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
-    assert isinstance(self.tensors, tuple)
+    if not isinstance(self.tensors, tuple):
+        raise AssertionError("Expected tensors to be a tuple")
     self.prev_versions = tuple(t._version for t in self.tensors)

 def __enter__(self) -> None:

@@ -363,8 +363,15 @@ def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
 # sparse compressed tensors don't implement sub/add/copy_
 # yet. However, in non-masked semantics context entry and v
 # have the same sparse indices ...
-assert entry.layout == v.layout, (entry.layout, v.layout)
-assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape)
+if entry.layout != v.layout:
+    raise AssertionError(
+        f"Expected entry and v to have the same layout, but got {entry.layout} and {v.layout}"
+    )
+if entry._nnz() != v._nnz():
+    raise AssertionError(
+        f"Expected entry and v to have the same nnz, but got {entry._nnz()} and {v._nnz()} "
+        f"with entry shape {entry.shape}"
+    )
 # ... the finite differencing can be performed on values only:
 entry = entry.values()
 v = v.values()
@@ -403,13 +410,15 @@ def _compute_numerical_jvps_wrt_specific_input(
         jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j)
     )
     for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup):
-        assert not ds_dx.is_complex()
+        if ds_dx.is_complex():
+            raise AssertionError("Expected ds_dx to be real-valued, not complex")
         # conjugate wirtinger derivative
         conj_w_d = ds_dx + ds_dy * 1j
         jvps.append(conj_w_d)
 else:
     for ds_dx in ds_dx_tup: # R -> R or (R -> C for the forward AD case)
-        assert is_forward_ad or not ds_dx.is_complex()
+        if not is_forward_ad and ds_dx.is_complex():
+            raise AssertionError("Expected ds_dx to be real-valued, not complex.")
         jvps.append(ds_dx)
 return jvps

@@ -456,12 +465,14 @@ def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None
 # Check that the returned outputs don't have different dtype or shape when you
 # perturb the input
 on_index = f"on index {idx} " if idx is not None else ""
-assert output1.shape == output2.shape, (
+if output1.shape != output2.shape:
+    raise AssertionError(
         f"Expected `func` to return outputs with the same shape"
         f" when inputs are perturbed {on_index}by {eps}, but got:"
         f" shapes {output1.shape} and {output2.shape}."
     )
-assert output1.dtype == output2.dtype, (
+if output1.dtype != output2.dtype:
+    raise AssertionError(
         f"Expected `func` to return outputs with the same dtype"
         f" when inputs are perturbed {on_index}by {eps}, but got:"
         f" dtypes {output1.dtype} and {output2.dtype}."
@@ -478,7 +489,8 @@ def get_numerical_jacobian_wrt_specific_input(
 # is equivalent to a single col of the Jacobian matrix of fn.
 jacobian_cols: dict[int, list[torch.Tensor]] = {}
 input = inputs[input_idx] if input is None else input
-assert input.requires_grad
+if not input.requires_grad:
+    raise AssertionError("Expected input to have requires_grad=True")
 for x, idx, d_idx in _iter_tensor(input):
     wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x)
     input_to_perturb = x[idx]
@@ -687,7 +699,11 @@ def _get_numerical_vJu(
 # Filter out the Ju for non floating point outputs
 filtered_Ju = []
 func_out = _as_tuple(func_out)
-assert len(all_Ju) == len(func_out)
+if len(all_Ju) != len(func_out):
+    raise AssertionError(
+        f"Expected all_Ju and func_out to have the same length, "
+        f"but got {len(all_Ju)} and {len(func_out)}"
+    )
 for Ju, output in zip(all_Ju, func_out):
     if _is_float_or_complex_tensor(output):
         filtered_Ju.append(Ju)
@@ -734,7 +750,11 @@ def _stack_and_check_tensors(
     out_jacobian[:, j].zero_()
 else:
     dense = tensor.to_dense() if tensor.layout != torch.strided else tensor
-    assert out_jacobian[:, j].numel() == dense.numel()
+    if out_jacobian[:, j].numel() != dense.numel():
+        raise AssertionError(
+            f"Expected out_jacobian column to have {dense.numel()} elements, "
+            f"but got {out_jacobian[:, j].numel()}"
+        )
     out_jacobian[:, j] = dense.reshape(-1)
 return out_jacobians, correct_grad_sizes, correct_grad_types

@@ -1061,7 +1081,8 @@ Expected:

 def _test_batched_grad_forward_ad(func, inputs) -> bool:
     fwAD = torch.autograd.forward_ad # To avoid early import issues (do we need this?)
-    assert isinstance(inputs, tuple)
+    if not isinstance(inputs, tuple):
+        raise AssertionError("Expected inputs to be a tuple")

     for input_idx, current_input in enumerate(inputs):
         if not (is_tensor_like(current_input) and current_input.requires_grad):
@@ -1641,7 +1662,10 @@ def _slow_gradcheck(


 def _dot_with_type_promotion(u, v):
-    assert u.dim() == 1 and v.dim() == 1
+    if u.dim() != 1 or v.dim() != 1:
+        raise AssertionError(
+            f"Expected u and v to be 1D tensors, but got dims {u.dim()} and {v.dim()}"
+        )
     return (u * v).sum()


@@ -1908,7 +1932,8 @@ def _fast_gradcheck(
 )
 # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well
 if use_forward_ad:
-    assert all_v is None
+    if all_v is not None:
+        raise AssertionError("Expected all_v to be None.")
     analytical_vJu = _get_analytical_jacobian_forward_ad(
         func,
         inputs,
@@ -2036,13 +2061,16 @@ def gradcheck(
 ``True`` if all differences satisfy allclose condition

 """
-assert check_forward_ad or check_backward_ad, (
+if not (check_forward_ad or check_backward_ad):
+    raise AssertionError(
         "Expected at least one of check_forward_ad or check_backward_ad to be True"
     )
-assert not (check_batched_grad and not check_backward_ad), (
+if check_batched_grad and not check_backward_ad:
+    raise AssertionError(
         "Setting check_batched_grad=True requires check_backward_ad to be True"
     )
-assert not (check_batched_forward_grad and not check_forward_ad), (
+if check_batched_forward_grad and not check_forward_ad:
+    raise AssertionError(
         "Setting check_batched_forward_grad=True requires check_forward_ad to be True"
     )
 args = locals().copy()
@@ -2189,13 +2217,16 @@ def gradgradcheck(
 Returns:
 True if all differences satisfy allclose condition
 """
-assert check_fwd_over_rev or check_rev_over_rev, (
+if not (check_fwd_over_rev or check_rev_over_rev):
+    raise AssertionError(
         "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True"
     )
-assert not (check_undefined_grad and not check_rev_over_rev), (
+if check_undefined_grad and not check_rev_over_rev:
+    raise AssertionError(
         "Setting check_undefined_grad=True requires check_rev_over_rev to be True"
    )
-assert not (check_batched_grad and not check_rev_over_rev), (
+if check_batched_grad and not check_rev_over_rev:
+    raise AssertionError(
         "Setting check_batched_grad=True requires check_rev_over_rev to be True"
     )
 # TODO: do we want to test this too?

@@ -187,7 +187,8 @@ def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node:
     node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr]
 else:
     node = t.grad_fn
-    assert node is not None
+    if node is None:
+        raise AssertionError("Expected gradient function to be set")
 return node


@@ -528,7 +529,8 @@ def register_multi_grad_hook(
 def inner_hook(grad: torch.Tensor) -> None:
     nonlocal count, nb_calls, buffer, fn
     id = torch._C._current_graph_task_id()
-    assert id != -1, (
+    if id == -1:
+        raise AssertionError(
             "expected this hook to be called inside a backward call"
         )
     count[id] = count.get(id, 0)
@@ -546,7 +548,8 @@ def register_multi_grad_hook(

 buffer[id][idx] = grad

-assert nb_calls is not None
+if nb_calls is None:
+    raise AssertionError("Expected nb_calls to be set")
 if curr_count == nb_calls - 1:
     fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
     fn(buffer[id])
@@ -566,7 +569,10 @@ def register_multi_grad_hook(
 def wrapped_fn(grad: torch.Tensor) -> None:
     nonlocal ran_hook
     id = torch._C._current_graph_task_id()
-    assert id != -1, "expected this hook to be called inside a backward call"
+    if id == -1:
+        raise AssertionError(
+            "expected this hook to be called inside a backward call"
+        )
     with lock:
         prev, ran_hook[id] = ran_hook[id], True
         if prev:
@@ -662,11 +668,13 @@ class _swap_with_cloned(saved_tensors_hooks):
     "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
     "in which the graph was originally recorded."
 )
-assert _allow_mutation_on_saved_tensors_enabled, error_msg
+if not _allow_mutation_on_saved_tensors_enabled:
+    raise AssertionError(error_msg)
 if handle in ctx.cloned:
     res = ctx.cloned[handle]
 else:
-    assert handle in ctx.original, error_msg
+    if handle not in ctx.original:
+        raise AssertionError(error_msg)
     res = ctx.original[handle]
 return res

@@ -255,7 +255,8 @@ class profile:
 self.custom_trace_id_callback = custom_trace_id_callback
 self.trace_id = ""
 if not self.use_cpu:
-    assert use_kineto, (
+    if not use_kineto:
+        raise AssertionError(
             "Device-only events supported only with Kineto (use_kineto=True)"
         )

@@ -289,22 +290,26 @@ class profile:
 self.profiler_kind = ProfilerState.KINETO
 if self.use_device == "cuda":
     if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
-        assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
+        if not self.use_cpu:
+            raise AssertionError("Legacy CUDA profiling requires use_cpu=True")
         self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
     else:
         self.kineto_activities.add(ProfilerActivity.CUDA)
 elif self.use_device == "xpu":
-    assert use_kineto and ProfilerActivity.XPU in _supported_activities(), (
+    if not (use_kineto and ProfilerActivity.XPU in _supported_activities()):
+        raise AssertionError(
             "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices."
         )
     self.kineto_activities.add(ProfilerActivity.XPU)
 elif self.use_device == "mtia":
-    assert use_kineto and ProfilerActivity.MTIA in _supported_activities(), (
+    if not (use_kineto and ProfilerActivity.MTIA in _supported_activities()):
+        raise AssertionError(
             "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices."
         )
     self.kineto_activities.add(ProfilerActivity.MTIA)
 elif self.use_device == "hpu":
-    assert use_kineto and ProfilerActivity.HPU in _supported_activities(), (
+    if not (use_kineto and ProfilerActivity.HPU in _supported_activities()):
+        raise AssertionError(
             "Legacy HPU profiling is not supported. Requires use_kineto=True on HPU devices."
         )
     self.kineto_activities.add(ProfilerActivity.HPU)
@@ -313,16 +318,16 @@ class profile:
     not use_kineto
     or ProfilerActivity.PrivateUse1 not in _supported_activities()
 ):
-    assert self.use_cpu, (
+    if not self.use_cpu:
+        raise AssertionError(
             "Legacy custombackend profiling requires use_cpu=True"
         )
     self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
 else:
     self.kineto_activities.add(ProfilerActivity.PrivateUse1)

-assert len(self.kineto_activities) > 0, (
-    "No activities specified for the profiler"
-)
+if len(self.kineto_activities) == 0:
+    raise AssertionError("No activities specified for the profiler")

 def default_trace_id(self):
     # Generate a UUID
@@ -472,7 +477,8 @@ class profile:
     top_level_events_only=False,
 ):
     self._ensure_function_events()
-    assert self._function_events is not None
+    if self._function_events is None:
+        raise AssertionError("Expected profiling results")
     return self._function_events.table(
         sort_by=sort_by,
         row_limit=row_limit,
@@ -500,8 +506,10 @@ class profile:

 def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
     self._ensure_function_events()
-    assert self._function_events is not None, "Expected profiling results"
-    assert self.with_stack, "export_stacks() requires with_stack=True"
+    if self._function_events is None:
+        raise AssertionError("Expected profiling results")
+    if not self.with_stack:
+        raise AssertionError("export_stacks() requires with_stack=True")
     return self._function_events.export_stacks(path, metric)

 def toggle_collection_dynamic(
@@ -519,7 +527,8 @@ class profile:
     group_by_overload_name=False,
 ):
     self._ensure_function_events()
-    assert self._function_events is not None, "Expected profiling results"
+    if self._function_events is None:
+        raise AssertionError("Expected profiling results")
     return self._function_events.key_averages(
         group_by_input_shape, group_by_stack_n, group_by_overload_name
     )
@@ -528,7 +537,8 @@ class profile:

 def total_average(self):
     self._ensure_function_events()
-    assert self._function_events is not None, "Expected profiling results"
+    if self._function_events is None:
+        raise AssertionError("Expected profiling results")
     return self._function_events.total_average()

 total_average.__doc__ = EventList.total_average.__doc__
@@ -540,7 +550,8 @@ class profile:
 The total time is a sum of all self times across all the events.
 """
 self._ensure_function_events()
-assert self._function_events is not None
+if self._function_events is None:
+    raise AssertionError("Expected profiling results")
 return self._function_events.self_cpu_time_total

 def _parse_kineto_results(self, result: _ProfilerResult):
@@ -796,7 +807,8 @@ class record_function(_ContextDecorator):

 # Local variable is needed by TorchScript to refine Optional[T] to T
 record = self.record
-assert record is not None
+if record is None:
+    raise AssertionError("Expected record to be set")

 # TODO: Too slow with __torch_function__ handling enabled
 # See https://github.com/pytorch/pytorch/issues/76410
@@ -833,7 +845,8 @@ class record_function(_ContextDecorator):

 # Local variable is needed by TorchScript to refine Optional[T] to T
 record = self.record
-assert record is not None
+if record is None:
+    raise AssertionError("Expected record to be set")

 # TODO: Too slow with __torch_function__ handling enabled
 # See https://github.com/pytorch/pytorch/issues/76410
@@ -1124,7 +1137,8 @@ def parse_nvprof_trace(path):
 for row in conn.execute(kernel_query):
     unique.see(row["marker_id"], row["runtime_id"])
     # 211 is cudaKernelLaunch for cuda >= 9.2
-    assert row["cbid"] == 211
+    if row["cbid"] != 211:
+        raise AssertionError(f"Expected cbid to be 211, but got {row['cbid']}")
     evt = functions_map[row["marker_id"]]
     evt.append_kernel(
         row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"]

@@ -137,7 +137,8 @@ class profile:
     top_level_events_only=False,
 ):
     self._check_finish()
-    assert self.function_events is not None
+    if self.function_events is None:
+        raise AssertionError("Expected profiling results")
     return self.function_events.table(
         sort_by=sort_by,
         row_limit=row_limit,
@@ -152,27 +153,32 @@ class profile:

 def export_chrome_trace(self, path):
     self._check_finish()
-    assert self.function_events is not None
+    if self.function_events is None:
+        raise AssertionError("Expected profiling results")
     return self.function_events.export_chrome_trace(path)

 export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

 def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
     self._check_finish()
-    assert self.function_events is not None, "Expected profiling results"
-    assert self.with_stack, "export_stacks() requires with_stack=True"
+    if self.function_events is None:
+        raise AssertionError("Expected profiling results")
+    if not self.with_stack:
+        raise AssertionError("export_stacks() requires with_stack=True")
     return self.function_events.export_stacks(path, metric)

 def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
     self._check_finish()
-    assert self.function_events is not None, "Expected profiling results"
+    if self.function_events is None:
+        raise AssertionError("Expected profiling results")
     return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)

 key_averages.__doc__ = EventList.key_averages.__doc__

 def total_average(self):
     self._check_finish()
-    assert self.function_events is not None, "Expected profiling results"
+    if self.function_events is None:
+        raise AssertionError("Expected profiling results")
     return self.function_events.total_average()

 total_average.__doc__ = EventList.total_average.__doc__
@@ -181,7 +187,8 @@ class profile:
 def self_cpu_time_total(self):
     """Return CPU time as the sum of self times across all events."""
     self._check_finish()
-    assert self.function_events is not None
+    if self.function_events is None:
+        raise AssertionError("Expected profiling results")
     return self.function_events.self_cpu_time_total


@@ -199,7 +206,8 @@ def _parse_legacy_records(thread_records):
 if start_record is None and name == "__start_profile":
     start_record = record

-assert start_record is not None and not start_record.is_remote()
+if start_record is None or start_record.is_remote():
+    raise AssertionError("Expected a valid local start_record")

 for thread_record_list in thread_records:
     # accumulated memory allocations per handle
@@ -233,10 +241,11 @@ def _parse_legacy_records(thread_records):
     cpu_memory_allocs[record_key] = 0
     cuda_memory_allocs[record_key] = 0
 elif record.kind() == "pop":
-    assert (
-        record_key in range_starts
-    ), f"""Expected record with key {record_key} to exist in range_starts.
-This means that the pop event did not have a corresponding push."""
+    if record_key not in range_starts:
+        raise AssertionError(
+            f"Expected record with key {record_key} to exist in range_starts. "
+            "This means that the pop event did not have a corresponding push."
+        )

     start = range_starts[record_key]

@@ -282,7 +291,11 @@ def _parse_legacy_records(thread_records):
 elif record.kind() == "memory_alloc":
     num_open_handles_cpu = len(cpu_memory_allocs)
     num_open_handles_cuda = len(cuda_memory_allocs)
-    assert num_open_handles_cpu == num_open_handles_cuda
+    if num_open_handles_cpu != num_open_handles_cuda:
+        raise AssertionError(
+            f"Expected CPU and CUDA memory allocation handles to match, "
+            f"but got {num_open_handles_cpu} CPU and {num_open_handles_cuda} CUDA"
+        )
     for handle in cpu_memory_allocs.keys():
         cpu_memory_allocs[handle] += record.cpu_memory_usage()
     for handle in cuda_memory_allocs.keys():

@@ -130,7 +130,8 @@ class EventList(list):
     current_events.pop()
 else:
     parent.append_cpu_child(event)
-    assert event.cpu_parent is None, (
+    if event.cpu_parent is not None:
+        raise AssertionError(
             f"There is already a CPU parent event for {event.key}"
         )
     event.set_cpu_parent(parent)
@@ -157,7 +158,10 @@ class EventList(list):
 for evt in self:
     p = bw_parent(evt)
     if p is not None:
-        assert p.fwd_thread is not None
+        if p.fwd_thread is None:
+            raise AssertionError(
+                "Expected fwd_thread to be set for backward parent"
+            )
         t = (p.sequence_nr, p.fwd_thread)
         evt.stack = fwd_stacks.get(t, [])

@@ -322,7 +326,10 @@ class EventList(list):
 Returns:
     An EventList containing FunctionEventAvg objects.
 """
-assert self._tree_built
+if not self._tree_built:
+    raise AssertionError(
+        "Expected tree to be built before calling key_averages"
+    )
 stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)

 def get_key(
@@ -392,7 +399,8 @@ def _format_time(time_us):
 def _format_time_share(time_us, total_time_us):
     """Define how to format time in FunctionEvent."""
     if total_time_us == 0:
-        assert time_us == 0, f"Expected time_us == 0 but got {time_us}"
+        if time_us != 0:
+            raise AssertionError(f"Expected time_us == 0 but got {time_us}")
         return "NaN"
     return f"{time_us * 100.0 / total_time_us:.2f}%"

@@ -537,7 +545,8 @@ class FunctionEvent(FormattedTimesMixin):
 self.metadata_json = metadata_json

 def append_kernel(self, name, device, duration):
-    assert self.device_type == DeviceType.CPU
+    if self.device_type != DeviceType.CPU:
+        raise AssertionError("Expected device_type to be CPU")
     self.kernels.append(Kernel(name, device, duration))

 def append_cpu_child(self, child):
@@ -546,9 +555,12 @@ class FunctionEvent(FormattedTimesMixin):
 One is supposed to append only direct children to the event to have
 correct self cpu time being reported.
 """
-assert self.device_type == DeviceType.CPU
-assert isinstance(child, FunctionEvent)
-assert child.device_type == DeviceType.CPU
+if self.device_type != DeviceType.CPU:
+    raise AssertionError("Expected device_type to be CPU")
+if not isinstance(child, FunctionEvent):
+    raise AssertionError("Expected child to be a FunctionEvent")
+if child.device_type != DeviceType.CPU:
+    raise AssertionError("Expected child device_type to be CPU")
 self.cpu_children.append(child)

 def set_cpu_parent(self, parent):
@@ -558,9 +570,12 @@ class FunctionEvent(FormattedTimesMixin):
 the child's range interval is completely inside the parent's. We use
 this connection to determine the event is from top-level op or not.
 """
-assert self.device_type == DeviceType.CPU
-assert isinstance(parent, FunctionEvent)
-assert parent.device_type == DeviceType.CPU
+if self.device_type != DeviceType.CPU:
+    raise AssertionError("Expected device_type to be CPU")
+if not isinstance(parent, FunctionEvent):
+    raise AssertionError("Expected parent to be a FunctionEvent")
+if parent.device_type != DeviceType.CPU:
+    raise AssertionError("Expected parent device_type to be CPU")
 self.cpu_parent = parent

 # Note: async events don't have children, are not used when computing 'self'
@@ -618,12 +633,15 @@ class FunctionEvent(FormattedTimesMixin):
     # each legacy cpu events has a single (fake) kernel
     return sum(kinfo.duration for kinfo in self.kernels)
 else:
-    assert self.device_type in [
+    if self.device_type not in [
         DeviceType.CUDA,
         DeviceType.PrivateUse1,
         DeviceType.MTIA,
         DeviceType.HPU,
-    ]
+    ]:
+        raise AssertionError(
+            f"Expected device_type to be CUDA, PrivateUse1, MTIA, or HPU, but got {self.device_type}"
+        )
     return self.time_range.elapsed_us()

 @property
@@ -643,12 +661,15 @@ class FunctionEvent(FormattedTimesMixin):
     child.device_time_total for child in self.cpu_children
 )
 else:
-    assert self.device_type in [
+    if self.device_type not in [
         DeviceType.CUDA,
         DeviceType.PrivateUse1,
         DeviceType.MTIA,
         DeviceType.HPU,
-    ]
+    ]:
+        raise AssertionError(
+            f"Expected device_type to be CUDA, PrivateUse1, MTIA, or HPU, but got {self.device_type}"
+        )
     return self.device_time_total

 @property
@@ -726,8 +747,14 @@ class FunctionEventAvg(FormattedTimesMixin):
 self.use_device = other.use_device
 self.is_user_annotation = other.is_user_annotation

-assert isinstance(other, (FunctionEvent, FunctionEventAvg))
-assert other.key == self.key
+if not isinstance(other, (FunctionEvent, FunctionEventAvg)):
+    raise AssertionError(
+        "Expected other to be a FunctionEvent or FunctionEventAvg"
+    )
+if other.key != self.key:
+    raise AssertionError(
+        f"Expected keys to match, but got {other.key} vs {self.key}"
+    )

 self.cpu_time_total += other.cpu_time_total
 self.device_time_total += other.device_time_total
@@ -974,10 +1001,14 @@ def _build_table(
     "TFLOPs",
     "PFLOPs",
 ]
-assert flops > 0
+if flops <= 0:
+    raise AssertionError(f"Expected flops to be positive, but got {flops}")
 # pyrefly: ignore # no-matching-overload
 log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
-assert log_flops >= 0 and log_flops < len(flop_headers)
+if not (log_flops >= 0 and log_flops < len(flop_headers)):
+    raise AssertionError(
+        f"Expected log_flops to be in range [0, {len(flop_headers)}), but got {log_flops}"
+    )
 return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])

 add_column(name_column_width)