Fix flake8 B028 warnings (#166224)

This PR fixes flake8 B028 warnings by specifying `stacklevel=2` in `warnings.warn` calls. With the corrected stack level, warnings point at the calling code rather than PyTorch's internal `warnings.warn` line, so users get more context about where a PyTorch warning was triggered.
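
As a minimal, hypothetical sketch of what `stacklevel=2` changes (not code from this PR): with the default `stacklevel=1`, Python attributes the warning to the `warnings.warn` call inside the library, whereas `stacklevel=2` attributes it to the caller's line, which is the location users actually need to see.

```python
import warnings


def deprecated_helper():
    # Hypothetical library function. With the default stacklevel=1 the warning
    # would be reported against this warnings.warn line; stacklevel=2 makes it
    # point at the caller's line instead, which is what flake8 B028 asks for.
    warnings.warn("deprecated_helper() is deprecated", FutureWarning, stacklevel=2)


# The file and line number shown in the warning now refer to this call site.
deprecated_helper()
```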

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166224
Approved by: https://github.com/ezyang
Yuanyuan Chen 2025-10-26 06:18:52 +00:00 committed by PyTorch MergeBot
parent f863550192
commit a60d9e1f6d
147 changed files with 598 additions and 311 deletions

View File

@@ -2653,7 +2653,8 @@ def compile(
     if torch.compiler.is_exporting():
         warnings.warn(
             "You are calling torch.compile inside torch.export region. "
-            "To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)"
+            "To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)",
+            stacklevel=2,
         )
     from torch._higher_order_ops.utils import setup_compilation_env

View File

@@ -55,6 +55,7 @@ def warn_deprecated():
         "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please "
         "use the equivalent torch.library API instead.",
         DeprecationWarning,
+        stacklevel=2,
     )

View File

@@ -704,7 +704,8 @@ class TS2FXGraphConverter:
         # In a sense, the converter now becomes an stateful interpreter
         warnings.warn(
             "Converting aten::append.t, which is a inplace mutation of the list. "
-            "This makes the converter non-functional: the result depends on the order of the append nodes being converter!"
+            "This makes the converter non-functional: the result depends on the order of the append nodes being converter!",
+            stacklevel=2,
         )
         args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
@@ -1471,7 +1472,8 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
         for k, tensor in self.ts_model.state_dict().items():  # type: ignore[union-attr]
             if k not in ep.state_dict:
                 warnings.warn(
-                    f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram."
+                    f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram.",
+                    stacklevel=2,
                 )
                 ep.state_dict[k] = tensor

View File

@@ -51,7 +51,8 @@ def _generate_inputs_for_submodules(
         model(*args, **kwargs)
     except Exception as e:
         warnings.warn(
-            f"Failed to generate submodule inputs because of the following error:\n{e}"
+            f"Failed to generate submodule inputs because of the following error:\n{e}",
+            stacklevel=2,
         )
     finally:
         for h in handles:

View File

@@ -321,5 +321,6 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
         warnings.warn(
             f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. "
             "Such attributes must be registered as buffers using the `register_buffer` API "
-            "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
+            "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer).",
+            stacklevel=2,
         )

View File

@@ -137,7 +137,8 @@ def call_func_at_runtime_with_args(
         warnings.warn(
             "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
             "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
-            "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
+            "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.",
+            stacklevel=2,
         )
         out = normalize_as_list(f(*args))
         return out

View File

@@ -518,7 +518,8 @@ def do_auto_functionalize(
     if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
         warnings.warn(
             "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
-            "Please consider using a different name for this argument to avoid potential issues."
+            "Please consider using a different name for this argument to avoid potential issues.",
+            stacklevel=2,
         )
     with ctx.redispatch_to_next():
         unwrapped_outs = auto_functionalized(
@@ -691,7 +692,8 @@ def do_auto_functionalize_v2(
     if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
         warnings.warn(
             "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
-            "Please consider using a different name for this argument to avoid potential issues."
+            "Please consider using a different name for this argument to avoid potential issues.",
+            stacklevel=2,
         )
     all_basis_unwrapped = ctx.unwrap_tensors(all_bases)

View File

@@ -196,7 +196,8 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
             "Aliasing is not supported for HOP subgraph.\n"
             f"{subgraph.print_readable(print_output=False)}\n"
             f"Alias info: inp-inp alias: {inp_inp_alias}, inp-out alias: {inp_out_alias}, out-out alias{out_out_alias}"
-            f"This may lead to silent incorrectness."
+            f"This may lead to silent incorrectness.",
+            stacklevel=2,
         )
         schema_gen = HopSchemaGenerator(self)

View File

@@ -177,6 +177,7 @@ def cond(
             "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches."
             " If you want torch.cond to preserve two branches, please make the predicate a boolean tensor or a SymBool.",
             UserWarning,
+            stacklevel=2,
         )
         # This is the eager case. We can just run the true or false branch.
         if pred:

View File

@@ -859,6 +859,7 @@ def ignore(drop=False, **kwargs):
         warnings.warn(
             "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
             "call on compilation. Use torch.jit.unused now. {}",
+            stacklevel=2,
             category=FutureWarning,
         )
@@ -867,6 +868,7 @@ def ignore(drop=False, **kwargs):
         warnings.warn(
             "ignore(True) has been deprecated. TorchScript will now drop the function "
             "call on compilation. Use torch.jit.unused now. {}",
+            stacklevel=2,
             category=FutureWarning,
         )
@@ -992,7 +994,8 @@ def _check_overload_body(func):
         # Parsing the function definition can raise an OSError if source is unavailable.
         # Since this is just an initial check, just raise a warning if this is the case.
         warnings.warn(
-            f"Unable to retrieve source for @torch.jit._overload function: {func}."
+            f"Unable to retrieve source for @torch.jit._overload function: {func}.",
+            stacklevel=2,
         )
         return
@@ -1385,7 +1388,8 @@ def check_empty_containers(obj) -> None:
             "calling torch.jit.isinstance in eager mode. For "
             "example, List[int] would become list and "
             "therefore falsely return True for List[float] or"
-            " List[str]."
+            " List[str].",
+            stacklevel=2,
         )

View File

@@ -2137,7 +2137,8 @@ def alert_not_deterministic(caller: str):
             f"{caller} does not have a deterministic implementation, but you set "
             f"'torch.use_deterministic_algorithms(True, warn_only=True)'. "
             f"You can file an issue at https://github.com/pytorch/pytorch/issues "
-            f"to help us prioritize adding deterministic support for this operation."
+            f"to help us prioritize adding deterministic support for this operation.",
+            stacklevel=2,
         )
     else:
         torch._check(

View File

@@ -180,7 +180,7 @@ def _resize_output_check(out: TensorLikeType, shape: ShapeType):
             "be resized unless they have zero elements. "
             "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
         )
-        warnings.warn(msg)
+        warnings.warn(msg, stacklevel=2)
     return True

View File

@@ -3729,7 +3729,8 @@ def istft(
     if end > expected_output_signal_len:
         warnings.warn(
             "The length of signal is shorter than the length parameter. Result is being "
-            + "padded with zeros in the tail. Please check your center and hop_length settings"
+            + "padded with zeros in the tail. Please check your center and hop_length settings",
+            stacklevel=2,
         )
         y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0)
     return y

View File

@@ -405,7 +405,8 @@ class FunctionalTensorMode(TorchDispatchMode):
                 warnings.warn(
                     f"At pre-dispatch tracing, we assume that any custom op marked with "
                     f"CompositeImplicitAutograd and have functional schema are safe to not decompose. "
-                    f"Found {func} to be one such op."
+                    f"Found {func} to be one such op.",
+                    stacklevel=2,
                 )
                 return False
         return True

View File

@@ -350,7 +350,8 @@ class Tensor(torch._C.TensorBase):
         # hypothesis is that no one cares for meta tensors.
         if skip_data:
             warnings.warn(
-                "Serializing tensors on the meta device under skip_data context manager is a no-op"
+                "Serializing tensors on the meta device under skip_data context manager is a no-op",
+                stacklevel=2,
             )
         arg_meta = (
             self.dtype,
@@ -1033,7 +1034,7 @@ class Tensor(torch._C.TensorBase):
     def resize(self, *sizes):
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.resize, (self,), self, *sizes)
-        warnings.warn("non-inplace resize is deprecated")
+        warnings.warn("non-inplace resize is deprecated", stacklevel=2)
         from torch.autograd._functions import Resize
         return Resize.apply(self, sizes)
@@ -1041,7 +1042,7 @@ class Tensor(torch._C.TensorBase):
     def resize_as(self, tensor):
        if has_torch_function_variadic(self, tensor):
            return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor)
-        warnings.warn("non-inplace resize_as is deprecated")
+        warnings.warn("non-inplace resize_as is deprecated", stacklevel=2)
        from torch.autograd._functions import Resize
        return Resize.apply(self, tensor.size())

View File

@@ -118,7 +118,7 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
         message = "{}() got an unexpected keyword argument '{}'"
         argument = list(kwargs.keys()).pop()
         raise TypeError(message.format(function_name, argument))
-    warnings.warn("'async' is deprecated; use 'non_blocking'")
+    warnings.warn("'async' is deprecated; use 'non_blocking'", stacklevel=2)
     return kwargs["async"]

View File

@@ -555,7 +555,8 @@ class Unpickler:
                 f"Detected pickle protocol {self.proto} in the checkpoint, which was "
                 "not the default pickle protocol used by `torch.load` (2). The weights_only "
                 "Unpickler might not support all instructions implemented by this protocol, "
-                "please file an issue for adding support if you encounter this."
+                "please file an issue for adding support if you encounter this.",
+                stacklevel=2,
             )
         elif key[0] == STOP[0]:
             rc = self.stack.pop()

View File

@@ -267,7 +267,8 @@ class autocast:
             and torch.cuda.amp.common.amp_definitely_not_available()
         ):
             warnings.warn(
-                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
+                "User provided device_type of 'cuda', but CUDA is not available. Disabling",
+                stacklevel=2,
             )
             enabled = False
         if cache_enabled is not None:
@@ -281,42 +282,42 @@ class autocast:
                 error_message += (
                     ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                 )
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
         elif self.device == "mtia":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
                 error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
                 error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
         elif self.device == "maia":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
                 error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
                 error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
         elif self.device == "xpu":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
                 error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                 error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
         elif self.device == "ipu":
             supported_dtypes = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtypes:
                 error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                 error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
         elif self.device == "hpu":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
                 error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                 error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
         elif self.device == self.custom_backend_name:
             supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
@@ -326,7 +327,7 @@ class autocast:
                 error_message += (
                     ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                 )
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
         elif self.device == "cuda":
             if (
@@ -344,7 +345,7 @@ class autocast:
                     "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
                     "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
                 )
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
             elif self.fast_dtype == torch.bfloat16:
                 if not torch.backends.mps.is_macos_or_newer(14, 0):
@@ -352,7 +353,7 @@ class autocast:
                         "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
                         "on macOS versions below 14. Disabling autocast."
                     )
-                    warnings.warn(error_message)
+                    warnings.warn(error_message, stacklevel=2)
                     enabled = False
         elif self.device == "xla":
             supported_dtype = [torch.float16, torch.bfloat16]
@@ -361,7 +362,7 @@ class autocast:
                 error_message += (
                     "XLA Autocast only supports dtype of torch.bfloat16 currently."
                 )
-                warnings.warn(error_message)
+                warnings.warn(error_message, stacklevel=2)
                 enabled = False
         self._enabled = enabled

View File

@@ -422,6 +422,7 @@ class GradScaler:
                 "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
                 "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
                 FutureWarning,
+                stacklevel=2,
             )
             kwargs_.update({"grad_scaler": self})
         else:

View File

@@ -469,14 +469,16 @@ class LSTM(torch.nn.Module):
             warnings.warn(
                 "dropout option for quantizable LSTM is ignored. "
                 "If you are training, please, use nn.LSTM version "
-                "followed by `prepare` step."
+                "followed by `prepare` step.",
+                stacklevel=2,
             )
             if num_layers == 1:
                 warnings.warn(
                     "dropout option adds dropout after all but last "
                     "recurrent layer, so non-zero dropout expects "
                     f"num_layers greater than 1, but got dropout={dropout} "
-                    f"and num_layers={num_layers}"
+                    f"and num_layers={num_layers}",
+                    stacklevel=2,
                 )
         layers = [

View File

@@ -68,7 +68,8 @@ class Conv1d(nnq.Conv1d):
         reduce_range=True,
     ):
         warnings.warn(
-            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
+            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended",  # noqa: B950
+            stacklevel=2,
         )
         factory_kwargs = {"device": device, "dtype": dtype}
         kernel_size = _single(kernel_size)
@@ -155,7 +156,8 @@ class Conv2d(nnq.Conv2d):
     ):
         warnings.warn(
             f"The current implementation of the {self._get_name()} module "
-            "has poor numerical accuracy and its use is not recommended"
+            "has poor numerical accuracy and its use is not recommended",
+            stacklevel=2,
         )
         factory_kwargs = {"device": device, "dtype": dtype}
         kernel_size = _pair(kernel_size)
@@ -239,7 +241,8 @@ class Conv3d(nnq.Conv3d):
         dtype=None,
     ):
         warnings.warn(
-            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
+            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended",  # noqa: B950
+            stacklevel=2,
        )
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        factory_kwargs = {"device": device, "dtype": dtype}
@@ -330,7 +333,8 @@ class ConvTranspose1d(nnq.ConvTranspose1d):
         dtype=None,
     ):
         warnings.warn(
-            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
+            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended",  # noqa: B950
+            stacklevel=2,
        )
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
@@ -412,7 +416,8 @@ class ConvTranspose2d(nnq.ConvTranspose2d):
         dtype=None,
     ):
         warnings.warn(
-            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
+            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended",  # noqa: B950
+            stacklevel=2,
        )
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
@@ -494,7 +499,8 @@ class ConvTranspose3d(nnq.ConvTranspose3d):
         dtype=None,
     ):
         warnings.warn(
-            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
+            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended",  # noqa: B950
+            stacklevel=2,
        )
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(

View File

@@ -136,7 +136,8 @@ class RNNBase(torch.nn.Module):
                 "dropout option adds dropout after all but last "
                 "recurrent layer, so non-zero dropout expects "
                 f"num_layers greater than 1, but got dropout={dropout} and "
-                f"num_layers={num_layers}"
+                f"num_layers={num_layers}",
+                stacklevel=2,
             )
         if mode == "LSTM":

View File

@@ -724,7 +724,8 @@ def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=
     affects the outputs.
     """
     warnings.warn(
-        "nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead."
+        "nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead.",
+        stacklevel=2,
     )
     return interpolate(input, size, scale_factor, mode, align_corners)
@@ -749,7 +750,8 @@ def upsample_bilinear(input, size=None, scale_factor=None):
     """
     # DeprecationWarning is ignored by default
     warnings.warn(
-        "nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead."
+        "nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead.",
+        stacklevel=2,
     )
     return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True)
@@ -774,6 +776,7 @@ def upsample_nearest(input, size=None, scale_factor=None):
     """
     # DeprecationWarning is ignored by default
     warnings.warn(
-        "nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead."
+        "nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead.",
+        stacklevel=2,
     )
     return interpolate(input, size, scale_factor, mode="nearest")

View File

@@ -322,7 +322,8 @@ class PReLU(torch.nn.Module):
         observer(float_wt)
         if observer.dtype != torch.quint8:
             warn(
-                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
+                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}",
+                stacklevel=2,
             )
         wt_scale, wt_zp = observer.calculate_qparams()
         qweight = torch.quantize_per_tensor(
@@ -339,7 +340,8 @@ class PReLU(torch.nn.Module):
         observer(float_wt)
         if observer.dtype != torch.quint8:
             warn(
-                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
+                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}",
+                stacklevel=2,
             )
         wt_scale, wt_zp = observer.calculate_qparams()
         qweight = torch.quantize_per_tensor(

View File

@@ -213,7 +213,8 @@ class ActivationSparsifier:
         if name in self.data_groups:  # unregister layer if already present
             warnings.warn(
-                "layer already attached to the sparsifier, deregistering the layer and registering with new config"
+                "layer already attached to the sparsifier, deregistering the layer and registering with new config",
+                stacklevel=2,
             )
             self.unregister_layer(name=name)

View File

@@ -158,6 +158,7 @@ class BaseDataScheduler:
                 "initialization. Please, make sure to call `data_sparsifier.step()` before "
                 "`scheduler.step()`.",
                 UserWarning,
+                stacklevel=2,
             )
         # Just check if there were two first scheduler.step() calls before sparsifier.step()
@@ -167,6 +168,7 @@ class BaseDataScheduler:
                 "You have to make sure you run the data_sparsifier.step() BEFORE any "
                 "calls to the scheduler.step().",
                 UserWarning,
+                stacklevel=2,
             )
         self._step_count += 1

View File

@@ -105,7 +105,8 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
         if name in self.state:
             # If the named data already exists - replace
             warnings.warn(
-                "Replacing existing data of the same name. - Did you mean a different name?"
+                "Replacing existing data of the same name. - Did you mean a different name?",
+                stacklevel=2,
             )
             # reuse old config

View File

@@ -74,6 +74,7 @@ class StepSLScheduler(BaseDataScheduler):
                 "To get the last learning rate computed by the scheduler, "
                 "please use `get_last_lr()`.",
                 UserWarning,
+                stacklevel=2,
             )
         data_groups = self.data_sparsifier.data_groups
         if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):

View File

@@ -92,7 +92,8 @@ class BaseScheduler:
         if not self._get_sl_called_within_step:
             warnings.warn(
                 "To get the last sparsity level computed by the scheduler, "
-                "please use `get_last_sl()`."
+                "please use `get_last_sl()`.",
+                stacklevel=2,
             )
         raise NotImplementedError
@@ -124,6 +125,7 @@ class BaseScheduler:
                 "initialization. Please, make sure to call `sparsifier.step()` before "
                 "`scheduler.step()`.",
                 UserWarning,
+                stacklevel=2,
             )
         # Just check if there were two first scheduler.step() calls before sparsifier.step()
@@ -133,6 +135,7 @@ class BaseScheduler:
                 "You have to make sure you run the sparsifier.step() BEFORE any "
                 "calls to the scheduler.step().",
                 UserWarning,
+                stacklevel=2,
             )
         self._step_count += 1

View File

@@ -90,7 +90,8 @@ class CubicSL(BaseScheduler):
         if not self._get_sl_called_within_step:
             warnings.warn(
                 "To get the last sparsity level computed by the scheduler, "
-                "please use `get_last_sl()`."
+                "please use `get_last_sl()`.",
+                stacklevel=2,
             )
         return [
             self.sparsity_compute_fn(

View File

@@ -56,7 +56,8 @@ class LambdaSL(BaseScheduler):
         if not self._get_sl_called_within_step:
             warnings.warn(
                 "To get the last sparsity level computed by the scheduler, "
-                "please use `get_last_sl()`."
+                "please use `get_last_sl()`.",
+                stacklevel=2,
             )
         return [
             base_sl * lmbda(self.last_epoch)

View File

@@ -121,7 +121,8 @@ class _InputEqualizationObserver(nn.Module):
         ):
             warnings.warn(
                 "Must call calculate_equalization_scale before calling calculate_scaled_minmax. "
-                + "Will not scale the next quantization observer."
+                + "Will not scale the next quantization observer.",
+                stacklevel=2,
             )
             return None, None
@@ -226,7 +227,8 @@ def calculate_equalization_scale(
     ):
         warnings.warn(
             "Must run observer before calling calculate_equalization_scale. "
-            + "Returning default equalization scale torch.tensor(1)."
+            + "Returning default equalization scale torch.tensor(1).",
+            stacklevel=2,
         )
        return torch.tensor(1)

View File

@@ -597,7 +597,8 @@ def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> No
             _maybe_recursive_remove_dequantize(arg_element, node, graph)
     else:
         warnings.warn(
-            f"Unsupported node type in recursive remove dequantize: {type(arg)}"
+            f"Unsupported node type in recursive remove dequantize: {type(arg)}",
+            stacklevel=2,
         )
@@ -1197,7 +1198,8 @@ def convert(
                 _maybe_recursive_remove_dequantize(output, return_node, model.graph)
             else:
                 warnings.warn(
-                    f"Unsupported node type for output_quantized_idxs: {type(output)}"
+                    f"Unsupported node type for output_quantized_idxs: {type(output)}",
+                    stacklevel=2,
                 )
         elif node.op == "call_module":
             mod = _get_module(node, modules)

View File

@@ -1055,7 +1055,9 @@ def _maybe_insert_input_equalization_observers_for_node(
         return
     if is_branch:
-        warnings.warn(f"Cannot equalize {node} because it is part of a branch.")
+        warnings.warn(
+            f"Cannot equalize {node} because it is part of a branch.", stacklevel=2
+        )
         return
     new_args = []

View File

@@ -890,7 +890,8 @@ def _qconfig_satisfies_dtype_config_constraints(
         if backend_quant_min is not None and backend_quant_max is not None:
             if app_quant_min is None or app_quant_max is None:
                 warnings.warn(
-                    f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}"
+                    f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}",
+                    stacklevel=2,
                 )
                 return False
             elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
@@ -898,20 +899,23 @@ def _qconfig_satisfies_dtype_config_constraints(
                     f"QConfig {debug_string} quantization range must fall within the backend's:\n"
                     f"QConfig range = ({app_quant_min}, {app_quant_max}), "
                     f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
-                    f"ignoring {qconfig}"
+                    f"ignoring {qconfig}",
+                    stacklevel=2,
                 )
                 return False
         # check scale min
         if backend_scale_min is not None:
             if app_scale_min is None:
                 warnings.warn(
-                    f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}"
+                    f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}",
+                    stacklevel=2,
                 )
                 return False
             if app_scale_min < backend_scale_min:
                 warnings.warn(
                     f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
-                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
+                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}",
+                    stacklevel=2,
                 )
                 return False
         # check fixed scale and zero point
@@ -935,7 +939,8 @@ def _qconfig_satisfies_dtype_config_constraints(
         ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize):
             warnings.warn(
                 f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
-                f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
+                f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}",
+                stacklevel=2,
             )
             return False
         if (
@@ -945,7 +950,8 @@ def _qconfig_satisfies_dtype_config_constraints(
             warnings.warn(
                 f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
                 f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
-                f"ignoring {qconfig}.\n{suggestion_str}"
+                f"ignoring {qconfig}.\n{suggestion_str}",
+                stacklevel=2,
             )
             return False
         return True

View File

@@ -245,7 +245,8 @@ class UniformQuantizationObserverBase(ObserverBase):
         if reduce_range:
             warnings.warn(
                 "Please use quant_min and quant_max to specify the range for observers. \
-                reduce_range will be deprecated in a future release of PyTorch."
+                reduce_range will be deprecated in a future release of PyTorch.",
+                stacklevel=2,
             )
         self.reduce_range = reduce_range
         self.register_buffer("eps", torch.tensor([eps], **factory_kwargs))
@@ -829,7 +830,8 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
                     self.max_val.resize_(val.shape)
                 else:
                     warnings.warn(
-                        f"Observer load_from_state_dict got unexpected name {name}"
+                        f"Observer load_from_state_dict got unexpected name {name}",
+                        stacklevel=2,
                     )
                 # For torchscript module we need to update the attributes here since we do not
                 # call the `_load_from_state_dict` function defined module.py
@@ -840,7 +842,8 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
                     self.max_val.copy_(val)
                 else:
                     warnings.warn(
-                        f"Observer load_from_state_dict got unexpected name {name}"
+                        f"Observer load_from_state_dict got unexpected name {name}",
+                        stacklevel=2,
                     )
             elif strict:
                 missing_keys.append(key)
@@ -1289,7 +1292,9 @@ class HistogramObserver(UniformQuantizationObserverBase):
         # want to make our quantization range infinite
        # and in practice those values will be clamped
        if x_min == -torch.inf or x_max == torch.inf:
-            warnings.warn("torch.inf detected in input tensor, ignoring input")
+            warnings.warn(
+                "torch.inf detected in input tensor, ignoring input", stacklevel=2
+            )
            x = x[x.abs() != torch.inf]
            if x.numel() == 0:
                return x_orig
@@ -1345,7 +1350,8 @@ class HistogramObserver(UniformQuantizationObserverBase):
         if is_uninitialized:
             warnings.warn(
                 "must run observer before calling calculate_qparams.\
-                Returning default scale and zero point "
+                Returning default scale and zero point ",
+                stacklevel=2,
             )
             return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor(
                 [0], device=self.min_val.device.type
@@ -1509,7 +1515,8 @@ class PlaceholderObserver(ObserverBase):
             warnings.warn(
                 "Please use `is_dynamic` instead of `compute_dtype`. \
                 `compute_dtype` will be deprecated in a future release \
-                of PyTorch."
+                of PyTorch.",
+                stacklevel=2,
             )
     def forward(self, x):

View File

@@ -292,7 +292,8 @@ def get_default_qconfig(backend="x86", version=0):
         if not torch.cpu._is_vnni_supported():
             warnings.warn(
                 "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues "
-                "on CPU without Vector Neural Network Instruction support."
+                "on CPU without Vector Neural Network Instruction support.",
+                stacklevel=2,
             )
         qconfig = QConfig(
             activation=HistogramObserver.with_args(reduce_range=False),

View File

@@ -392,7 +392,8 @@ def prepare(
         warnings.warn(
             "None of the submodule got qconfig applied. Make sure you "
             "passed correct configuration through `qconfig_dict` or "
-            "by assigning the `.qconfig` attribute directly on submodules"
+            "by assigning the `.qconfig` attribute directly on submodules",
+            stacklevel=2,
         )
     _add_observer_(

View File

@@ -372,6 +372,7 @@ def _config_checker(method: Callable) -> Callable:
         if quantizer._need_skip_config(quantization_config):
             warnings.warn(
                 f"Skip the quantization config for {name}.",
+                stacklevel=2,
             )
             return quantizer
         return method(quantizer, name, quantization_config)
@@ -464,7 +465,10 @@ class X86InductorQuantizer(Quantizer):
             current_mode.qat_state is not None
             and current_mode.qat_state != quantization_config.is_qat
         ):
-            warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.")
+            warnings.warn(
+                "Mixed QAT and Non-QAT quantization config is not supported.",
+                stacklevel=2,
+            )
             need_skip = True
         if current_mode.dynamic_state is not None:
             input_activation_spec = quantization_config.input_activation
@@ -473,14 +477,15 @@ class X86InductorQuantizer(Quantizer):
                 and current_mode.dynamic_state != input_activation_spec.is_dynamic
             ):
                 warnings.warn(
-                    "Mixed dynamic and static quantization config is not supported."
+                    "Mixed dynamic and static quantization config is not supported.",
+                    stacklevel=2,
                 )
                 need_skip = True
         return need_skip
     def set_global(self, quantization_config: QuantizationConfig):
         if self._need_skip_config(quantization_config):
-            warnings.warn("Skip the global quantization config.")
+            warnings.warn("Skip the global quantization config.", stacklevel=2)
             return self
         self.global_config = quantization_config
         return self
@@ -489,7 +494,8 @@ class X86InductorQuantizer(Quantizer):
         if not isinstance(self.global_config, QuantizationConfig):
             warnings.warn(
                 "The global_config for X86InductorQuantizer is currently invalid. \
-                Please ensure that you use set_global to establish the global quantization configuration."
+                Please ensure that you use set_global to establish the global quantization configuration.",
+                stacklevel=2,
             )
         return self.global_config
@@ -508,7 +514,8 @@ class X86InductorQuantizer(Quantizer):
             )
         else:
             warnings.warn(
-                f"function: Unable to customize quantization config for {function_type} by X86InductorQuantizer."
+                f"function: Unable to customize quantization config for {function_type} by X86InductorQuantizer.",
+                stacklevel=2,
             )
         return self
@@ -525,7 +532,8 @@ class X86InductorQuantizer(Quantizer):
             )
         else:
             warnings.warn(
-                f"Module: Unable to customize quantization config for {module_type} by X86InductorQuantizer."
+                f"Module: Unable to customize quantization config for {module_type} by X86InductorQuantizer.",
+                stacklevel=2,
            )
        return self
@@ -551,7 +559,8 @@ class X86InductorQuantizer(Quantizer):
             self.operator_type_qconfig[operator_type] = quantization_config
         else:
             warnings.warn(
-                f"operator: Unable to quantize {operator} by X86InductorQuantizer."
+                f"operator: Unable to quantize {operator} by X86InductorQuantizer.",
+                stacklevel=2,
            )
        return self
@@ -1317,7 +1326,8 @@ class X86InductorQuantizer(Quantizer):
         if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
             if quantization_config is not None:
                 warnings.warn(
-                    f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}."
+                    f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}.",
+                    stacklevel=2,
                 )
             return

View File

@@ -427,7 +427,8 @@ def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
     if min_val.numel() == 0 or max_val.numel() == 0:
         warnings.warn(
             "must run observer before calling calculate_qparams. "
-            + "Returning default values."
+            + "Returning default values.",
+            stacklevel=2,
         )
         return False
@@ -435,7 +436,8 @@ def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
     if min_val == float("inf") and max_val == float("-inf"):
         warnings.warn(
             "must run observer before calling calculate_qparams. "
-            + "Returning default values."
+            + "Returning default values.",
+            stacklevel=2,
         )
         return False
@@ -806,7 +808,8 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
     """
     if {torch.device("cpu"), torch.device("meta")} == devices:
         warnings.warn(
-            "Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'."
+            "Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.",
+            stacklevel=2,
         )
        devices = {torch.device("cpu")}
    ""

View File

@@ -944,7 +944,8 @@ def _check_inputs(tupled_inputs) -> bool:
                 f"Input #{idx} requires gradient and "
                 "is not a double precision floating point or complex. "
                 "This check will likely fail if all the inputs are "
-                "not of double precision floating point or complex. "
+                "not of double precision floating point or complex. ",
+                stacklevel=2,
             )
         if inp.is_sparse:
             content = inp._values()
@@ -1325,7 +1326,8 @@ def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
         "Backwards compatibility: New undefined gradient support checking "
         "feature is enabled by default, but it may break existing callers "
        "of this function. If this is true for you, you can call this "
-        'function with "check_undefined_grad=False" to disable the feature'
+        'function with "check_undefined_grad=False" to disable the feature',
+        stacklevel=2,
    )
    def check_undefined_grad_support(output_to_check):

View File

@@ -265,22 +265,24 @@ class profile:
         if _get_privateuse1_backend_name() != "privateuseone":
             VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name())
         if self.use_device not in VALID_DEVICE_OPTIONS:
-            warn(f"The {self.use_device} is not a valid device option.")
+            warn(
+                f"The {self.use_device} is not a valid device option.", stacklevel=2
+            )
             self.use_device = None
         if self.use_device == "cuda" and not torch.cuda.is_available():
-            warn("CUDA is not available, disabling CUDA profiling")
+            warn("CUDA is not available, disabling CUDA profiling", stacklevel=2)
             self.use_cuda = False
             self.use_device = None
         if self.use_device == "xpu" and not torch.xpu.is_available():
-            warn("XPU is not available, disabling XPU profiling")
+            warn("XPU is not available, disabling XPU profiling", stacklevel=2)
             self.use_device = None
         if self.use_device == "hpu" and not (
             hasattr(torch, "hpu") and torch.hpu.is_available()
         ):
-            warn("HPU is not available, disabling HPU profiling")
+            warn("HPU is not available, disabling HPU profiling", stacklevel=2)
             self.use_device = None
         self.kineto_activities = set()
@@ -1224,7 +1226,8 @@ class KinetoStepTracker:
         if delta > 1:
             warn(
                 "Profiler step count has increased more than 1 - "
-                f"current_step = {cls._current_step} step dict = {cls._step_dict}"
+                f"current_step = {cls._current_step} step dict = {cls._step_dict}",
+                stacklevel=2,
             )
         for _ in range(delta):
             _kineto_step()

View File

@@ -118,7 +118,8 @@ def is_acceptable(tensor):
     if not is_available():
         warnings.warn(
             "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
-            "PyTorch making sure the library is visible to the build system."
+            "PyTorch making sure the library is visible to the build system.",
+            stacklevel=2,
         )
         return False
     if not _init():
@@ -127,7 +128,8 @@ def is_acceptable(tensor):
                 libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get(
                     sys.platform, "LD_LIBRARY_PATH"
                 )
-            )
+            ),
+            stacklevel=2,
         )
         return False
     return True

View File

@ -293,7 +293,8 @@ def _check_capability():
min_arch % 10, min_arch % 10,
max_arch // 10, max_arch // 10,
max_arch % 10, max_arch % 10,
) ),
stacklevel=2,
) )
matched_arches = "" matched_arches = ""
for arch, arch_info in CUDA_ARCHES_SUPPORTED.items(): for arch, arch_info in CUDA_ARCHES_SUPPORTED.items():
@ -303,7 +304,9 @@ def _check_capability():
): ):
matched_arches += f" {arch}" matched_arches += f" {arch}"
if matched_arches != "": if matched_arches != "":
warnings.warn(matched_cuda_warn.format(matched_arches)) warnings.warn(
matched_cuda_warn.format(matched_arches), stacklevel=2
)
def _check_cubins(): def _check_cubins():
@ -328,7 +331,8 @@ If you want to use the {} GPU with PyTorch, please check the instructions at htt
warnings.warn( warnings.warn(
incompatible_device_warn.format( incompatible_device_warn.format(
device_name, capability, " ".join(arch_list), device_name device_name, capability, " ".join(arch_list), device_name
) ),
stacklevel=2,
) )
@ -818,7 +822,9 @@ def _raw_device_count_amdsmi() -> int:
try: try:
amdsmi.amdsmi_init() amdsmi.amdsmi_init()
except amdsmi.AmdSmiException as e: except amdsmi.AmdSmiException as e:
warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}") warnings.warn(
f"Can't initialize amdsmi - Error code: {e.err_code}", stacklevel=2
)
return -1 return -1
socket_handles = amdsmi.amdsmi_get_processor_handles() socket_handles = amdsmi.amdsmi_get_processor_handles()
return len(socket_handles) return len(socket_handles)
@ -831,12 +837,12 @@ def _raw_device_count_nvml() -> int:
nvml_h = CDLL("libnvidia-ml.so.1") nvml_h = CDLL("libnvidia-ml.so.1")
rc = nvml_h.nvmlInit() rc = nvml_h.nvmlInit()
if rc != 0: if rc != 0:
warnings.warn("Can't initialize NVML") warnings.warn("Can't initialize NVML", stacklevel=2)
return -1 return -1
dev_count = c_int(-1) dev_count = c_int(-1)
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
if rc != 0: if rc != 0:
warnings.warn("Can't get nvml device count") warnings.warn("Can't get nvml device count", stacklevel=2)
return -1 return -1
del nvml_h del nvml_h
return dev_count.value return dev_count.value
@ -850,27 +856,27 @@ def _raw_device_uuid_amdsmi() -> Optional[list[str]]:
try: try:
amdsmi.amdsmi_init() amdsmi.amdsmi_init()
except amdsmi.AmdSmiException: except amdsmi.AmdSmiException:
warnings.warn("Can't initialize amdsmi") warnings.warn("Can't initialize amdsmi", stacklevel=2)
return None return None
try: try:
socket_handles = amdsmi.amdsmi_get_processor_handles() socket_handles = amdsmi.amdsmi_get_processor_handles()
dev_count = len(socket_handles) dev_count = len(socket_handles)
except amdsmi.AmdSmiException: except amdsmi.AmdSmiException:
warnings.warn("Can't get amdsmi device count") warnings.warn("Can't get amdsmi device count", stacklevel=2)
return None return None
uuids: list[str] = [] uuids: list[str] = []
for idx in range(dev_count): for idx in range(dev_count):
try: try:
handler = amdsmi.amdsmi_get_processor_handles()[idx] handler = amdsmi.amdsmi_get_processor_handles()[idx]
except amdsmi.AmdSmiException: except amdsmi.AmdSmiException:
warnings.warn("Cannot get amd device handler") warnings.warn("Cannot get amd device handler", stacklevel=2)
return None return None
try: try:
uuid = amdsmi.amdsmi_get_gpu_asic_info(handler)["asic_serial"][ uuid = amdsmi.amdsmi_get_gpu_asic_info(handler)["asic_serial"][
2: 2:
] # Removes 0x prefix from serial ] # Removes 0x prefix from serial
except amdsmi.AmdSmiException: except amdsmi.AmdSmiException:
warnings.warn("Cannot get uuid for amd device") warnings.warn("Cannot get uuid for amd device", stacklevel=2)
return None return None
uuids.append( uuids.append(
str(uuid).lower() str(uuid).lower()
@ -885,25 +891,25 @@ def _raw_device_uuid_nvml() -> Optional[list[str]]:
nvml_h = CDLL("libnvidia-ml.so.1") nvml_h = CDLL("libnvidia-ml.so.1")
rc = nvml_h.nvmlInit() rc = nvml_h.nvmlInit()
if rc != 0: if rc != 0:
warnings.warn("Can't initialize NVML") warnings.warn("Can't initialize NVML", stacklevel=2)
return None return None
dev_count = c_int(-1) dev_count = c_int(-1)
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
if rc != 0: if rc != 0:
warnings.warn("Can't get nvml device count") warnings.warn("Can't get nvml device count", stacklevel=2)
return None return None
uuids: list[str] = [] uuids: list[str] = []
for idx in range(dev_count.value): for idx in range(dev_count.value):
dev_id = c_void_p() dev_id = c_void_p()
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
if rc != 0: if rc != 0:
warnings.warn("Can't get device handle") warnings.warn("Can't get device handle", stacklevel=2)
return None return None
buf_len = 96 buf_len = 96
buf = create_string_buffer(buf_len) buf = create_string_buffer(buf_len)
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
if rc != 0: if rc != 0:
warnings.warn("Can't get device UUID") warnings.warn("Can't get device UUID", stacklevel=2)
return None return None
uuids.append(buf.raw.decode("ascii").strip("\0")) uuids.append(buf.raw.decode("ascii").strip("\0"))
del nvml_h del nvml_h
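
A hypothetical test sketch for helpers like the ones above: recording warnings makes the new attribution observable, because each recorded WarningMessage carries the filename and line number of the frame selected by stacklevel. flaky_helper is a stand-in name, not a function from this diff:

import warnings

def flaky_helper():
    # stand-in for a device-probing helper that warns and bails out on failure
    warnings.warn("Can't initialize NVML", stacklevel=2)
    return None

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    flaky_helper()

assert caught and "NVML" in str(caught[0].message)
# caught[0].filename and caught[0].lineno refer to the flaky_helper() call above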

View File

@ -492,6 +492,7 @@ def reset_max_memory_allocated(device: "Device" = None) -> None:
"torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, " "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
"which resets /all/ peak memory stats.", "which resets /all/ peak memory stats.",
FutureWarning, FutureWarning,
stacklevel=2,
) )
return reset_peak_memory_stats(device=device) return reset_peak_memory_stats(device=device)
@ -518,6 +519,7 @@ def reset_max_memory_cached(device: "Device" = None) -> None:
"torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, " "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
"which resets /all/ peak memory stats.", "which resets /all/ peak memory stats.",
FutureWarning, FutureWarning,
stacklevel=2,
) )
return reset_peak_memory_stats(device=device) return reset_peak_memory_stats(device=device)
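
The two hunks above emit FutureWarning; one practical consequence of keeping these as ordinary warnings (sketched below with standard-library calls only; the message pattern is copied from the hunk above and is only illustrative) is that callers can escalate or silence them through the usual filters:

import warnings

# escalate deprecation-style FutureWarnings to errors, e.g. in a test suite
warnings.filterwarnings("error", category=FutureWarning)

# or silence one specific message by matching on its text
warnings.filterwarnings(
    "ignore",
    message=".*reset_max_memory_allocated now calls.*",
    category=FutureWarning,
)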

View File

@ -14,7 +14,7 @@ SUM = 0 # ncclRedOp_t
def is_available(tensors): def is_available(tensors):
if not hasattr(torch._C, "_nccl_all_reduce"): if not hasattr(torch._C, "_nccl_all_reduce"):
warnings.warn("PyTorch is not compiled with NCCL support") warnings.warn("PyTorch is not compiled with NCCL support", stacklevel=2)
return False return False
devices = set() devices = set()

View File

@ -626,7 +626,8 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
else: else:
warnings.warn( warnings.warn(
"Offline tuning is not supported for this GEMM. Use online tuning instead. " "Offline tuning is not supported for this GEMM. Use online tuning instead. "
+ f"Skipped tuning for: {untuned_gemm[1]}" + f"Skipped tuning for: {untuned_gemm[1]}",
stacklevel=2,
) )
return return
@ -644,7 +645,8 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
if m == 1 or n == 1 or k == 1: if m == 1 or n == 1 or k == 1:
warnings.warn( warnings.warn(
"Offline tuning is not support for this GEMM. Use online tuning instead. " "Offline tuning is not support for this GEMM. Use online tuning instead. "
+ f"Skipped tuning for: {untuned_gemm[1]}" + f"Skipped tuning for: {untuned_gemm[1]}",
stacklevel=2,
) )
return return
@ -747,7 +749,7 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
matA = matA.t() matA = matA.t()
torch.nn.functional.linear(X, matA, bias) torch.nn.functional.linear(X, matA, bias)
else: else:
warnings.warn(f"error: unknown op {op_sig}") warnings.warn(f"error: unknown op {op_sig}", stacklevel=2)
def _check_tuning_assertions() -> None: def _check_tuning_assertions() -> None:
@ -756,7 +758,7 @@ def _check_tuning_assertions() -> None:
""" """
if is_enabled() is False: if is_enabled() is False:
warnings.warn("TunableOp was disabled. Trying to enable now.") warnings.warn("TunableOp was disabled. Trying to enable now.", stacklevel=2)
enable(True) enable(True)
assert is_enabled() is True assert is_enabled() is True
assert tuning_is_enabled() is True assert tuning_is_enabled() is True

View File

@ -23,7 +23,8 @@ try:
from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
except Exception: except Exception:
warnings.warn( warnings.warn(
"Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly",
stacklevel=2,
) )
def is_torchdynamo_compiling(): # type: ignore[misc] def is_torchdynamo_compiling(): # type: ignore[misc]

View File

@ -470,7 +470,8 @@ class ShardedTensor(ShardedTensorBase):
src = shard.tensor.flatten() src = shard.tensor.flatten()
if src.nelement() == 0: if src.nelement() == 0:
warnings.warn( warnings.warn(
"Gathering a tensor with zero elements on rank " + str(rank) "Gathering a tensor with zero elements on rank " + str(rank),
stacklevel=2,
) )
continue continue
shard_offset = shard_placement[shard.metadata][1] shard_offset = shard_placement[shard.metadata][1]
@ -671,7 +672,8 @@ class ShardedTensor(ShardedTensorBase):
if device_to.index != current_idx: if device_to.index != current_idx:
warnings.warn( warnings.warn(
"ShardedTensor.to only move tensor to its current device" "ShardedTensor.to only move tensor to its current device"
"If you want to put to different device, use `reshard` instead." "If you want to put to different device, use `reshard` instead.",
stacklevel=2,
) )
device_to = torch.device(current_idx) device_to = torch.device(current_idx)

View File

@ -182,7 +182,8 @@ class ModTracker:
warnings.formatwarning = custom_formatwarning warnings.formatwarning = custom_formatwarning
warnings.warn( warnings.warn(
"The module hierarchy tracking maybe be messed up." "The module hierarchy tracking maybe be messed up."
" Please file a bug to PyTorch, if it is the case." " Please file a bug to PyTorch, if it is the case.",
stacklevel=2,
) )
if name not in self.parents: if name not in self.parents:
self._active_module_cnt[name] = 1 self._active_module_cnt[name] = 1

View File

@ -257,7 +257,8 @@ class Join:
f"{self._rank} has at least {WARN_THRESHOLD} " f"{self._rank} has at least {WARN_THRESHOLD} "
f"fewer inputs than other currently-active ranks. " f"fewer inputs than other currently-active ranks. "
"This level of skew could lead to performance " "This level of skew could lead to performance "
"degradation during training." "degradation during training.",
stacklevel=2,
) )
# Shadow the all-reduce in non-joined processes # Shadow the all-reduce in non-joined processes
num_nonjoined_procs = self._get_num_nonjoined_procs() num_nonjoined_procs = self._get_num_nonjoined_procs()

View File

@ -101,7 +101,8 @@ class PeriodicModelAverager(ModelAverager):
"When period is 1, no need to use model averaging because the communication cost " "When period is 1, no need to use model averaging because the communication cost "
"of all-reducing parameters will be no less than the cost of all-reducing gradients " "of all-reducing parameters will be no less than the cost of all-reducing gradients "
"by DistributedDataParallel in the backward pass. Therefore, only " "by DistributedDataParallel in the backward pass. Therefore, only "
"DistributedDataParallel should be used for this case." "DistributedDataParallel should be used for this case.",
stacklevel=2,
) )
self.period = period self.period = period

View File

@ -114,7 +114,8 @@ class HierarchicalModelAverager(averagers.ModelAverager):
"no need to use model averaging because the communication cost " "no need to use model averaging because the communication cost "
"of all-reducing parameters will be no less than the cost of all-reducing gradients " "of all-reducing parameters will be no less than the cost of all-reducing gradients "
"by DistributedDataParallel in the backward pass. Therefore, only " "by DistributedDataParallel in the backward pass. Therefore, only "
"DistributedDataParallel should be used for this case." "DistributedDataParallel should be used for this case.",
stacklevel=2,
) )
overall_group_size = dist.get_world_size(group=self.process_group) overall_group_size = dist.get_world_size(group=self.process_group)
if list(period_group_size_dict.values())[-1] != overall_group_size: if list(period_group_size_dict.values())[-1] != overall_group_size:

View File

@ -660,7 +660,8 @@ class _FileSystemWriter(StorageWriter):
warnings.warn( warnings.warn(
f"Detected an existing checkpoint in {self.path}, overwriting since {self.overwrite=}." f"Detected an existing checkpoint in {self.path}, overwriting since {self.overwrite=}."
" Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to" " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to"
" maintain this functionality or False to raise when an existing checkpoint is found." " maintain this functionality or False to raise when an existing checkpoint is found.",
stacklevel=2,
) )
else: else:
raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.") raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.")

View File

@ -290,6 +290,7 @@ def _verify_options(
"will be removed in 2.5. This feature can be achieved by manually " "will be removed in 2.5. This feature can be achieved by manually "
"filtering out the state_dict returned from get_state_dict.", "filtering out the state_dict returned from get_state_dict.",
FutureWarning, FutureWarning,
stacklevel=2,
) )
if optim_only and not optims: if optim_only and not optims:
raise RuntimeError( raise RuntimeError(
@ -1234,6 +1235,7 @@ def _unflatten_model_state_dict(
"feature, please preprocessing the model_state_dict to achieve the " "feature, please preprocessing the model_state_dict to achieve the "
"same functionality.", "same functionality.",
FutureWarning, FutureWarning,
stacklevel=2,
) )
cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict) cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict)
new_state_dict: dict[str, ValueType] = {} new_state_dict: dict[str, ValueType] = {}

View File

@ -158,7 +158,8 @@ def load(
no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized()) no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized())
if no_dist: if no_dist:
warnings.warn( warnings.warn(
"torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process." "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process.",
stacklevel=2,
) )
with _profile(): with _profile():
@ -365,7 +366,8 @@ def _load_state_dict_from_keys(
no_dist = not (dist.is_available() and dist.is_initialized()) no_dist = not (dist.is_available() and dist.is_initialized())
if no_dist: if no_dist:
warnings.warn( warnings.warn(
"torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process." "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process.",
stacklevel=2,
) )
storage_reader = cast( storage_reader = cast(

View File

@ -182,7 +182,8 @@ def save(
no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized()) no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized())
if no_dist: if no_dist:
warnings.warn( warnings.warn(
"torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to save in a single process." "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to save in a single process.",
stacklevel=2,
) )
with _profile(): with _profile():
@ -414,7 +415,8 @@ def _save_state_dict(
warnings.warn( warnings.warn(
"The function definition for SavePlanner.set_up_planner has been updated" "The function definition for SavePlanner.set_up_planner has been updated"
" to include the storage_meta argument. Please update your implementation" " to include the storage_meta argument. Please update your implementation"
" to include this parameter." " to include this parameter.",
stacklevel=2,
) )
planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type] planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type]
else: else:

View File

@ -461,7 +461,8 @@ def _api_bc_check(func):
if len(args) == 2: if len(args) == 2:
warnings.warn( warnings.warn(
f"The argument order of {func.__name__} has been changed. " f"The argument order of {func.__name__} has been changed. "
"Please check the document to avoid future breakages." "Please check the document to avoid future breakages.",
stacklevel=2,
) )
sig = inspect.signature(func) sig = inspect.signature(func)
kwonlyargs = [ kwonlyargs = [

View File

@ -85,7 +85,8 @@ else:
# We keep this function for backward compatibility. # We keep this function for backward compatibility.
warnings.warn( warnings.warn(
"This get_root_mesh API will be deprecated soon." "This get_root_mesh API will be deprecated soon."
"Please use `get_root_mesh` inside DeviceMesh instead." "Please use `get_root_mesh` inside DeviceMesh instead.",
stacklevel=2,
) )
if not device_mesh: if not device_mesh:
return device_mesh return device_mesh
@ -108,7 +109,8 @@ else:
) -> list["DeviceMesh"]: ) -> list["DeviceMesh"]:
warnings.warn( warnings.warn(
"This _get_all_submeshes API will be deprecated soon." "This _get_all_submeshes API will be deprecated soon."
"Please use `_get_all_submeshes` inside DeviceMesh instead." "Please use `_get_all_submeshes` inside DeviceMesh instead.",
stacklevel=2,
) )
return device_mesh._get_all_submeshes(mesh_dim_name) return device_mesh._get_all_submeshes(mesh_dim_name)
@ -329,7 +331,8 @@ else:
"It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that " "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that "
"the underlying communicator (i.e. NCCL) can be initialized properly. " "the underlying communicator (i.e. NCCL) can be initialized properly. "
"Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the " "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the "
"device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. " "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. ",
stacklevel=2,
) )
# heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
# NOTE: This device selection would only work for homogeneous hardware. # NOTE: This device selection would only work for homogeneous hardware.
@ -766,7 +769,8 @@ else:
warnings.warn( warnings.warn(
"You are attempting to slice a submesh from another submesh. While we support this operation, " "You are attempting to slice a submesh from another submesh. While we support this operation, "
"it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. " "it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. "
"If not, this may result in some ranks receiving the submesh while others encounter errors." "If not, this may result in some ranks receiving the submesh while others encounter errors.",
stacklevel=2,
) )
slice_from_root = False slice_from_root = False
@ -803,7 +807,8 @@ else:
elif name in flatten_name_to_root_layout: elif name in flatten_name_to_root_layout:
warnings.warn( warnings.warn(
"Slicing a flattened dim from root mesh will be deprecated in PT 2.11. " "Slicing a flattened dim from root mesh will be deprecated in PT 2.11. "
"Users need to bookkeep the flattened mesh directly. " "Users need to bookkeep the flattened mesh directly. ",
stacklevel=2,
) )
layout_sliced.append(flatten_name_to_root_layout[name]) layout_sliced.append(flatten_name_to_root_layout[name])

View File

@ -352,7 +352,8 @@ class Backend(str): # noqa: SLOT000
warnings.warn( warnings.warn(
f"Device capability of {name} unspecified, assuming `cpu` and " f"Device capability of {name} unspecified, assuming `cpu` and "
"`cuda` or `xpu`. Please specify it via the `devices` argument of " "`cuda` or `xpu`. Please specify it via the `devices` argument of "
"`register_backend`." "`register_backend`.",
stacklevel=2,
) )
Backend.backend_capability[name.lower()] = ( Backend.backend_capability[name.lower()] = (
["cpu", "cuda", "xpu"] if torch.xpu.is_available() else ["cpu", "cuda"] ["cpu", "cuda", "xpu"] if torch.xpu.is_available() else ["cpu", "cuda"]
@ -427,7 +428,8 @@ class BackendConfig:
warnings.warn( warnings.warn(
f"Device capability of {backend} unknown, assuming `cpu` and " f"Device capability of {backend} unknown, assuming `cpu` and "
"`cuda`. You can specify it in `device:backend` format in " "`cuda`. You can specify it in `device:backend` format in "
"`init_process_group` call." "`init_process_group` call.",
stacklevel=2,
) )
backend_val = Backend(backend) backend_val = Backend(backend)
self.device_backend_map = { self.device_backend_map = {
@ -751,7 +753,8 @@ def _get_default_timeout(backend: Backend) -> timedelta:
# TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was
# changed to be a warning. We should fix the moco model. # changed to be a warning. We should fix the moco model.
warnings.warn( warnings.warn(
"Attempted to get default timeout for nccl backend, but NCCL support is not compiled" "Attempted to get default timeout for nccl backend, but NCCL support is not compiled",
stacklevel=2,
) )
return default_pg_timeout return default_pg_timeout
return default_pg_nccl_timeout return default_pg_nccl_timeout
@ -802,6 +805,7 @@ def _get_object_coll_device(group: Optional[ProcessGroup] = None) -> str:
f"You are using a Backend {type(group)} as a ProcessGroup. " f"You are using a Backend {type(group)} as a ProcessGroup. "
"This usage is deprecated since PyTorch 2.0. Please use a public API " "This usage is deprecated since PyTorch 2.0. Please use a public API "
"of PyTorch Distributed instead.", "of PyTorch Distributed instead.",
stacklevel=2,
) )
# Provide backward compatibility to cases where `group` passed in is # Provide backward compatibility to cases where `group` passed in is
# actually a Backend (like `ProcessGroupGloo`) rather than a # actually a Backend (like `ProcessGroupGloo`) rather than a
@ -868,7 +872,8 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device
"backward-compatiblity reason. If you need to find a device for object " "backward-compatiblity reason. If you need to find a device for object "
"collectives, please use `_get_object_coll_device`. If you need to query " "collectives, please use `_get_object_coll_device`. If you need to query "
"the device types supported by group, please use " "the device types supported by group, please use "
"`_device_capability(group)`. " "`_device_capability(group)`. ",
stacklevel=2,
) )
group = group or _get_default_group() group = group or _get_default_group()
@ -910,7 +915,8 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device
warnings.warn( warnings.warn(
"Multiple backends are registered with this ProcessGroup. We cannot " "Multiple backends are registered with this ProcessGroup. We cannot "
f"determine which one is the default. Returning {rv}. " f"determine which one is the default. Returning {rv}. "
"Please consider using other APIs." "Please consider using other APIs.",
stacklevel=2,
) )
return rv return rv
@ -1010,7 +1016,8 @@ def _warn_not_in_group(op_name) -> None:
global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank() global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank()
warnings.warn( warnings.warn(
f"Running {op_name} on global rank {global_rank} which does not " f"Running {op_name} on global rank {global_rank} which does not "
"belong to the given group." "belong to the given group.",
stacklevel=2,
) )
@ -1557,7 +1564,9 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): elif is_gloo_available() and isinstance(backend, ProcessGroupGloo):
backends.add(backend) # type: ignore[arg-type] backends.add(backend) # type: ignore[arg-type]
if len(backends) == 0: if len(backends) == 0:
warnings.warn("Set timeout is now only supported for either nccl or gloo.") warnings.warn(
"Set timeout is now only supported for either nccl or gloo.", stacklevel=2
)
for backend in backends: for backend in backends:
backend._set_default_timeout(timeout) backend._set_default_timeout(timeout)
@ -1758,7 +1767,8 @@ def init_process_group(
warnings.warn( warnings.warn(
f"For MPI backend, world_size ({world_size}) and rank ({rank}) " f"For MPI backend, world_size ({world_size}) and rank ({rank}) "
"are ignored since they are assigned by the " "are ignored since they are assigned by the "
"MPI runtime." "MPI runtime.",
stacklevel=2,
) )
default_pg, _ = _new_process_group_helper( default_pg, _ = _new_process_group_helper(
@ -2038,7 +2048,8 @@ def _new_process_group_helper(
if backend_options._timeout != timeout: if backend_options._timeout != timeout:
warnings.warn( warnings.warn(
"backend_options._timeout was specified, " "backend_options._timeout was specified, "
"but timeout kwarg has a default value that will always override it. " "but timeout kwarg has a default value that will always override it. ",
stacklevel=2,
) )
else: else:
# default backend_options for NCCL # default backend_options for NCCL
@ -2259,7 +2270,8 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
if pg in _world.pg_coalesce_state.keys(): if pg in _world.pg_coalesce_state.keys():
warnings.warn( warnings.warn(
"Some coalesced collectives haven't been launched when " "Some coalesced collectives haven't been launched when "
"ProcessGroup is destroyed. They will be cleaned." "ProcessGroup is destroyed. They will be cleaned.",
stacklevel=2,
) )
del _world.pg_coalesce_state[pg] del _world.pg_coalesce_state[pg]
@ -2349,7 +2361,8 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
if pg in _world.pg_coalesce_state.keys(): if pg in _world.pg_coalesce_state.keys():
warnings.warn( warnings.warn(
"Some coalesced collectives haven't been launched when " "Some coalesced collectives haven't been launched when "
"ProcessGroup is aborted. They will be cleaned." "ProcessGroup is aborted. They will be cleaned.",
stacklevel=2,
) )
del _world.pg_coalesce_state[pg] del _world.pg_coalesce_state[pg]
@ -4919,7 +4932,8 @@ def barrier(
if group.rank() == 0: if group.rank() == 0:
warnings.warn( # warn only once warnings.warn( # warn only once
"barrier(): using the device under current context. " "barrier(): using the device under current context. "
"You can specify `device_id` in `init_process_group` to mute this warning." "You can specify `device_id` in `init_process_group` to mute this warning.",
stacklevel=2,
) )
work = group.barrier(opts=opts) work = group.barrier(opts=opts)
@ -5001,6 +5015,7 @@ def monitored_barrier(
warnings.warn( warnings.warn(
"Please specify timeout arg as a timedelta. " "Please specify timeout arg as a timedelta. "
f"Converting current value of {timeout} assuming it represents seconds", f"Converting current value of {timeout} assuming it represents seconds",
stacklevel=2,
) )
timeout = timedelta(seconds=timeout) timeout = timedelta(seconds=timeout)

View File

@ -106,6 +106,7 @@ class WorkerSpec:
warnings.warn( warnings.warn(
"WorkerSpec.fn will be deprecated," "WorkerSpec.fn will be deprecated,"
" please use WorkerSpec.entrypoint instead", " please use WorkerSpec.entrypoint instead",
stacklevel=2,
category=DeprecationWarning, category=DeprecationWarning,
) )
self.entrypoint = self.fn self.entrypoint = self.fn
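
As in the WorkerSpec hunk above, category and stacklevel are both keyword arguments of warnings.warn, so they may be passed in either order. A small sketch of the same pattern with illustrative names:

import warnings

def old_api():
    warnings.warn(
        "old_api() is deprecated, please use new_api() instead",
        category=DeprecationWarning,
        stacklevel=2,
    )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_api()

assert issubclass(caught[0].category, DeprecationWarning)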

View File

@ -52,7 +52,9 @@ class ErrorHandler:
try: try:
faulthandler.enable(all_threads=True) faulthandler.enable(all_threads=True)
except Exception as e: except Exception as e:
warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}") warnings.warn(
f"Unable to enable fault handler. {type(e).__name__}: {e}", stacklevel=2
)
def _write_error_file(self, file_path: str, error_msg: str) -> None: def _write_error_file(self, file_path: str, error_msg: str) -> None:
"""Write error message to the file.""" """Write error message to the file."""
@ -60,7 +62,9 @@ class ErrorHandler:
with open(file_path, "w") as fp: with open(file_path, "w") as fp:
fp.write(error_msg) fp.write(error_msg)
except Exception as e: except Exception as e:
warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}") warnings.warn(
f"Unable to write error to file. {type(e).__name__}: {e}", stacklevel=2
)
def record_exception(self, e: BaseException) -> None: def record_exception(self, e: BaseException) -> None:
""" """

View File

@ -65,5 +65,6 @@ def _derive_module_name(depth: int = 1) -> Optional[str]:
warnings.warn( warnings.warn(
f"Error deriving logger module name, using <None>. Exception: {e}", f"Error deriving logger module name, using <None>. Exception: {e}",
RuntimeWarning, RuntimeWarning,
stacklevel=2,
) )
return None return None

View File

@ -336,7 +336,8 @@ def _get_param_to_fqns(
warnings.warn( warnings.warn(
"FlatParameter is being traversed more than once. " "FlatParameter is being traversed more than once. "
"This case should only happen when using " "This case should only happen when using "
"DistributedModelParallel with FullyShardedDataParallel." "DistributedModelParallel with FullyShardedDataParallel.",
stacklevel=2,
) )
param_to_fqns[param] = global_fqns param_to_fqns[param] = global_fqns
elif not dedup_shared_params: elif not dedup_shared_params:

View File

@ -299,7 +299,8 @@ class _ExecOrderData:
warnings.warn( warnings.warn(
"Forward order differs from that of the first iteration " "Forward order differs from that of the first iteration "
f"on rank {self.rank}. Collectives are unchecked and may " f"on rank {self.rank}. Collectives are unchecked and may "
f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}" f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}",
stacklevel=2,
) )
self.warn_status = _ExecOrderWarnStatus.WARNING self.warn_status = _ExecOrderWarnStatus.WARNING
self.current_order_index += 1 self.current_order_index += 1

View File

@ -1585,7 +1585,8 @@ class FlatParamHandle:
warnings.warn( warnings.warn(
f"[Rank {self.rank}] Only some but not all ranks have a " f"[Rank {self.rank}] Only some but not all ranks have a "
"`None` `FlatParameter` gradient, so FSDP is using zeros to " "`None` `FlatParameter` gradient, so FSDP is using zeros to "
"approximate those ranks' sharded gradients being `None`" "approximate those ranks' sharded gradients being `None`",
stacklevel=2,
) )
flat_param._saved_grad_shard = None # type: ignore[assignment] flat_param._saved_grad_shard = None # type: ignore[assignment]
sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined] sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined]
@ -2434,7 +2435,8 @@ class FlatParamHandle:
f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs " f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
f"writeback in {self._training_state}\n" f"writeback in {self._training_state}\n"
f"expected shape={expected_shape} shape={src_shape} " f"expected shape={expected_shape} shape={src_shape} "
f"expected device={dst_tensor.device} device={src_device}" f"expected device={dst_tensor.device} device={src_device}",
stacklevel=2,
) )
if src_tensor is not None and src_tensor.shape != expected_shape: if src_tensor is not None and src_tensor.shape != expected_shape:
# NOTE: Gradient shape mismatch is not possible in practice since # NOTE: Gradient shape mismatch is not possible in practice since

View File

@ -431,7 +431,8 @@ def _init_core_state(
warnings.warn( warnings.warn(
"FSDP is switching to use `NO_SHARD` instead of " "FSDP is switching to use `NO_SHARD` instead of "
f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since " f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
"the world size is 1." "the world size is 1.",
stacklevel=2,
) )
sharding_strategy = ShardingStrategy.NO_SHARD sharding_strategy = ShardingStrategy.NO_SHARD
elif sharding_strategy == ShardingStrategy.NO_SHARD: elif sharding_strategy == ShardingStrategy.NO_SHARD:
@ -704,7 +705,8 @@ def _get_ignored_modules(
warnings.warn( warnings.warn(
"Trying to ignore the top-level module passed into the FSDP " "Trying to ignore the top-level module passed into the FSDP "
"constructor itself will result in all parameters being " "constructor itself will result in all parameters being "
f"ignored and is not well-supported: {module}" f"ignored and is not well-supported: {module}",
stacklevel=2,
) )
# Include nested FSDP modules' ignored modules # Include nested FSDP modules' ignored modules
for submodule in root_module.modules(): for submodule in root_module.modules():
@ -847,7 +849,8 @@ def _get_device_from_device_id(
f"FSDP will use the current device {device_handle.current_device()}. " f"FSDP will use the current device {device_handle.current_device()}. "
f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` " f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` "
"before FSDP initialization or pass in the explicit device " "before FSDP initialization or pass in the explicit device "
"index as the `device_id` argument." "index as the `device_id` argument.",
stacklevel=2,
) )
device = torch.device(device_handle.current_device()) device = torch.device(device_handle.current_device())
return device return device
@ -929,7 +932,8 @@ def _materialize_meta_module(
warnings.warn( warnings.warn(
"Unable to call `reset_parameters()` for module on meta " "Unable to call `reset_parameters()` for module on meta "
f"device with error {str(e)}. Please ensure that your module of" f"device with error {str(e)}. Please ensure that your module of"
f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined] f"type {type(module)} implements a `reset_parameters()` method.",
stacklevel=2, # type: ignore[possibly-undefined]
) )
raise e raise e
@ -1049,7 +1053,8 @@ def _warn_cpu_init():
"recommend passing in the `device_id` argument for FSDP to move " "recommend passing in the `device_id` argument for FSDP to move "
"`module` to GPU for the sharding initialization. `module` must also " "`module` to GPU for the sharding initialization. `module` must also "
"be on GPU device to work with the `sync_module_states=True` flag " "be on GPU device to work with the `sync_module_states=True` flag "
"since that requires GPU communication." "since that requires GPU communication.",
stacklevel=2,
) )

View File

@ -506,7 +506,8 @@ def _flatten_optim_state_dict(
flat_osd_state[key] = copy.deepcopy(state) flat_osd_state[key] = copy.deepcopy(state)
else: else:
warnings.warn( warnings.warn(
f"optim_state[{key}] is not on rank{fsdp_state.rank}." f"optim_state[{key}] is not on rank{fsdp_state.rank}.",
stacklevel=2,
) )
else: else:
@ -2051,7 +2052,8 @@ def _optim_state_dict(
"most cases, this is a user-defined state that is not " "most cases, this is a user-defined state that is not "
"associated with any particular parameter. Another possible " "associated with any particular parameter. Another possible "
"case is this state is managed by TorchRec. Otherwise, there may " "case is this state is managed by TorchRec. Otherwise, there may "
" be a mismatched assumption of optim_state_dict of this mode." " be a mismatched assumption of optim_state_dict of this mode.",
stacklevel=2,
) )
fsdp_osd_state[key] = value fsdp_osd_state[key] = value

View File

@ -337,7 +337,8 @@ def _full_post_state_dict_hook(
"This may mean that this state_dict entry could point to invalid " "This may mean that this state_dict entry could point to invalid "
"memory regions after returning from state_dict() call if this " "memory regions after returning from state_dict() call if this "
"parameter is managed by FSDP. Please check clone " "parameter is managed by FSDP. Please check clone "
f"implementation of {fqn}. Error: {str(e)}" f"implementation of {fqn}. Error: {str(e)}",
stacklevel=2,
) )
return _common_unshard_post_state_dict_hook( return _common_unshard_post_state_dict_hook(
@ -708,7 +709,8 @@ def _post_state_dict_hook(
context = _replace_with_full_state_dict_type(fsdp_state) context = _replace_with_full_state_dict_type(fsdp_state)
warnings.warn( warnings.warn(
"When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will " "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
"be returned." "be returned.",
stacklevel=2,
) )
else: else:
context = contextlib.nullcontext() context = contextlib.nullcontext()
@ -770,7 +772,8 @@ def _pre_state_dict_hook(
context = _replace_with_full_state_dict_type(fsdp_state) context = _replace_with_full_state_dict_type(fsdp_state)
warnings.warn( warnings.warn(
"When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will " "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
"be returned." "be returned.",
stacklevel=2,
) )
else: else:
_set_use_dtensor(fsdp_state) _set_use_dtensor(fsdp_state)
@ -824,7 +827,8 @@ def _pre_load_state_dict_hook(
context = _replace_with_full_state_dict_type(fsdp_state) context = _replace_with_full_state_dict_type(fsdp_state)
warnings.warn( warnings.warn(
"When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will" "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
"be returned." "be returned.",
stacklevel=2,
) )
else: else:
_set_use_dtensor(fsdp_state) _set_use_dtensor(fsdp_state)
@ -861,7 +865,8 @@ def _post_load_state_dict_hook(
context = _replace_with_full_state_dict_type(fsdp_state) context = _replace_with_full_state_dict_type(fsdp_state)
warnings.warn( warnings.warn(
"When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will" "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
"be returned." "be returned.",
stacklevel=2,
) )
else: else:
context = contextlib.nullcontext() context = contextlib.nullcontext()

View File

@ -153,7 +153,8 @@ def _validate_unshard_params_args(
"offload_to_cpu=True and rank0_only=False may result in the" "offload_to_cpu=True and rank0_only=False may result in the"
"unsharded parameters being redundantly copied to CPU memory for " "unsharded parameters being redundantly copied to CPU memory for "
"GPUs sharing the same CPU memory, which risks CPU OOM. We " "GPUs sharing the same CPU memory, which risks CPU OOM. We "
"recommend using offload_to_cpu=True with rank0_only=True." "recommend using offload_to_cpu=True with rank0_only=True.",
stacklevel=2,
) )

View File

@ -120,7 +120,8 @@ def _warn_on_overridden_mixed_precision(
"Both mixed precision and an auto_wrap_policy were specified to FSDP, " "Both mixed precision and an auto_wrap_policy were specified to FSDP, "
f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n" f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
"These modules will be wrapped as separate FSDP instacnes with mixed " "These modules will be wrapped as separate FSDP instacnes with mixed "
"precision disabled." "precision disabled.",
stacklevel=2,
) )
@ -172,7 +173,7 @@ def _validate_frozen_params(
f"The following parameters have requires_grad=False:\n{frozen_param_fqns}" f"The following parameters have requires_grad=False:\n{frozen_param_fqns}"
) )
if use_orig_params: if use_orig_params:
warnings.warn(msg) warnings.warn(msg, stacklevel=2)
else: else:
raise ValueError(msg) raise ValueError(msg)

View File

@ -680,6 +680,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
"#torch.distributed.checkpoint.state_dict.get_state_dict ." "#torch.distributed.checkpoint.state_dict.get_state_dict ."
"Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .", "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .",
FutureWarning, FutureWarning,
stacklevel=2,
) )
_state_dict_type_to_config = { _state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig, StateDictType.FULL_STATE_DICT: FullStateDictConfig,
@ -1208,7 +1209,8 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
warnings.warn( warnings.warn(
f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no " f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no "
"gradients -- returning the total norm in the default dtype " "gradients -- returning the total norm in the default dtype "
f"{total_norm.dtype}" f"{total_norm.dtype}",
stacklevel=2,
) # warn since this is generally unexpected ) # warn since this is generally unexpected
return total_norm return total_norm
total_norm_dtype = functools.reduce( total_norm_dtype = functools.reduce(

View File

@ -87,7 +87,8 @@ class _NamedOptimizer(optim.Optimizer):
else: else:
warnings.warn( warnings.warn(
"Since we pass in param_groups, we will use param_groups to " "Since we pass in param_groups, we will use param_groups to "
"initialize the optimizer, not all parameters of the module." "initialize the optimizer, not all parameters of the module.",
stacklevel=2,
) )
param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type]
ordered_param_keys = [] ordered_param_keys = []

View File

@ -92,7 +92,8 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
else: else:
warnings.warn( warnings.warn(
"Loaded state dict does not contain a step counter for an averager. " "Loaded state dict does not contain a step counter for an averager. "
"Setting step counter to 0." "Setting step counter to 0.",
stacklevel=2,
) )
self.averager.step = 0 self.averager.step = 0

View File

@ -513,7 +513,8 @@ class OpDispatcher:
"Found a non-scalar tensor with numel=1 and ndim!=0, " "Found a non-scalar tensor with numel=1 and ndim!=0, "
"we are implicitly creating a replicated DTensor for it. " "we are implicitly creating a replicated DTensor for it. "
"However, please consider changing it to a scalar tensor " "However, please consider changing it to a scalar tensor "
"or explicitly create a DTensor under distributed environment." "or explicitly create a DTensor under distributed environment.",
stacklevel=2,
) )
if tensor_arg.numel() == 1 or self._allow_implicit_replication: if tensor_arg.numel() == 1 or self._allow_implicit_replication:

View File

@ -43,7 +43,8 @@ def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
else: else:
# TODO: Logs way too much # TODO: Logs way too much
warnings.warn( warnings.warn(
f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh" f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh",
stacklevel=2,
) )
return False return False
@ -72,7 +73,8 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
if not is_rng_supported_mesh(device_mesh): if not is_rng_supported_mesh(device_mesh):
warnings.warn( warnings.warn(
"DTensor manual_seed() may not have complete support " "DTensor manual_seed() may not have complete support "
f"on {device_mesh.device_type} device mesh" f"on {device_mesh.device_type} device mesh",
stacklevel=2,
) )
return return

View File

@ -74,7 +74,8 @@ def parallelize_module( # type: ignore[return]
if parallelize_plan is None: if parallelize_plan is None:
warnings.warn( warnings.warn(
"No parallelize_plan is provided and auto-parallel is not supported " "No parallelize_plan is provided and auto-parallel is not supported "
"at the moment, so this parallelize_module call will do nothing." "at the moment, so this parallelize_module call will do nothing.",
stacklevel=2,
) )
return module return module
@ -108,7 +109,8 @@ def parallelize_module( # type: ignore[return]
warnings.warn( warnings.warn(
f"Parallelize plan key '{module_path}' could not be resolved: " f"Parallelize plan key '{module_path}' could not be resolved: "
f"no submodule matching token '{token}' in module {module}, " f"no submodule matching token '{token}' in module {module}, "
f"skipping this plan entry." f"skipping this plan entry.",
stacklevel=2,
) )
continue continue

View File

@ -62,7 +62,8 @@ class Distribution:
warnings.warn( warnings.warn(
f"{self.__class__} does not define `arg_constraints`. " f"{self.__class__} does not define `arg_constraints`. "
+ "Please set `arg_constraints = {}` or initialize the distribution " + "Please set `arg_constraints = {}` or initialize the distribution "
+ "with `validate_args=False` to turn off validation." + "with `validate_args=False` to turn off validation.",
stacklevel=2,
) )
for param, constraint in arg_constraints.items(): for param, constraint in arg_constraints.items():
if constraints.is_dependent(constraint): if constraints.is_dependent(constraint):
@ -313,7 +314,8 @@ class Distribution:
warnings.warn( warnings.warn(
f"{self.__class__} does not define `support` to enable " f"{self.__class__} does not define `support` to enable "
+ "sample validation. Please initialize the distribution with " + "sample validation. Please initialize the distribution with "
+ "`validate_args=False` to turn off validation." + "`validate_args=False` to turn off validation.",
stacklevel=2,
) )
return return
assert support is not None assert support is not None

View File

@ -133,6 +133,7 @@ def _dispatch_kl(type_p, type_q):
f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). " f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). "
f"Please register_kl({left_p.__name__}, {right_q.__name__})", f"Please register_kl({left_p.__name__}, {right_q.__name__})",
RuntimeWarning, RuntimeWarning,
stacklevel=2,
) )
return left_fun return left_fun

View File

@ -127,7 +127,8 @@ class Wishart(ExponentialFamily):
if self.df.lt(event_shape[-1]).any(): if self.df.lt(event_shape[-1]).any():
warnings.warn( warnings.warn(
"Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim." "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.",
stacklevel=2,
) )
super().__init__(batch_shape, event_shape, validate_args=validate_args) super().__init__(batch_shape, event_shape, validate_args=validate_args)
@ -279,7 +280,7 @@ class Wishart(ExponentialFamily):
else: else:
# More optimized version with data-dependent control flow. # More optimized version with data-dependent control flow.
if is_singular.any(): if is_singular.any():
warnings.warn("Singular sample detected.") warnings.warn("Singular sample detected.", stacklevel=2)
for _ in range(max_try_correction): for _ in range(max_try_correction):
sample_new = self._bartlett_sampling(is_singular[is_singular].shape) sample_new = self._bartlett_sampling(is_singular[is_singular].shape)

View File

@ -500,10 +500,10 @@ def load(
if file_info.filename == "serialized_exported_program.json": if file_info.filename == "serialized_exported_program.json":
serialized_exported_program = file_content serialized_exported_program = file_content
elif file_info.filename == "serialized_state_dict.json": elif file_info.filename == "serialized_state_dict.json":
warnings.warn("This version of file is deprecated") warnings.warn("This version of file is deprecated", stacklevel=2)
serialized_state_dict = file_content serialized_state_dict = file_content
elif file_info.filename == "serialized_constants.json": elif file_info.filename == "serialized_constants.json":
warnings.warn("This version of file is deprecated") warnings.warn("This version of file is deprecated", stacklevel=2)
serialized_constants = file_content serialized_constants = file_content
elif file_info.filename == "serialized_state_dict.pt": elif file_info.filename == "serialized_state_dict.pt":
serialized_state_dict = file_content serialized_state_dict = file_content

View File

@ -2113,7 +2113,7 @@ def _export_for_training(
if torch._export.config.error_on_lifted_constant_tensors: if torch._export.config.error_on_lifted_constant_tensors:
raise RuntimeError(error_msg) raise RuntimeError(error_msg)
else: else:
warnings.warn(error_msg) warnings.warn(error_msg, stacklevel=2)
export_graph_signature = export_artifact.aten.sig export_graph_signature = export_artifact.aten.sig
@ -2189,7 +2189,8 @@ def _export_for_training(
f"This is likely result of torch.export.export not being able to track side effects " f"This is likely result of torch.export.export not being able to track side effects "
f"that is happening outside of model scope.\n\n" f"that is happening outside of model scope.\n\n"
f"Leaked tensors:\n {leak_details}\n\n" f"Leaked tensors:\n {leak_details}\n\n"
f"Alternatively, please file a bug report to PyTorch team for further debugging help." f"Alternatively, please file a bug report to PyTorch team for further debugging help.",
stacklevel=2,
) )
del legit_leak del legit_leak

View File

@ -530,7 +530,8 @@ def _create_stateful_graph_module(
f"A model attribute `{constant_fqn}` requires gradient. " f"A model attribute `{constant_fqn}` requires gradient. "
f"but it's not properly registered as a parameter. " f"but it's not properly registered as a parameter. "
f"torch.export will detach it and treat it as a constant tensor " f"torch.export will detach it and treat it as a constant tensor "
f"but please register it as parameter instead." f"but please register it as parameter instead.",
stacklevel=2,
) )
detached_buffer = buffer.detach() detached_buffer = buffer.detach()
original_tensor_to_detached_tensor[buffer] = detached_buffer original_tensor_to_detached_tensor[buffer] = detached_buffer
@ -549,7 +550,8 @@ def _create_stateful_graph_module(
f"A model attribute `{const_name}` requires gradient " f"A model attribute `{const_name}` requires gradient "
f"but it's not properly registered as a parameter. " f"but it's not properly registered as a parameter. "
f"torch.export will detach it and treat it as a constant tensor " f"torch.export will detach it and treat it as a constant tensor "
f"but please register it as parameter instead." f"but please register it as parameter instead.",
stacklevel=2,
) )
if value in original_tensor_to_detached_tensor: if value in original_tensor_to_detached_tensor:
value = original_tensor_to_detached_tensor[value] value = original_tensor_to_detached_tensor[value]

View File

@ -1684,7 +1684,8 @@ def _create_graph_module_for_export(root, graph):
"Unable to execute the generated python source code from " "Unable to execute the generated python source code from "
"the graph. The graph module will no longer be directly callable, " "the graph. The graph module will no longer be directly callable, "
"but you can still run the ExportedProgram, and if needed, you can " "but you can still run the ExportedProgram, and if needed, you can "
"run the graph module eagerly using torch.fx.Interpreter." "run the graph module eagerly using torch.fx.Interpreter.",
stacklevel=2,
) )
gm = torch.fx.GraphModule(root, torch.fx.Graph()) gm = torch.fx.GraphModule(root, torch.fx.Graph())
gm._graph = graph gm._graph = graph

View File

@ -108,7 +108,8 @@ def get_complete(
warnings.warn( warnings.warn(
"No complete tensor found in the group! Returning the first one. " "No complete tensor found in the group! Returning the first one. "
"This may cause issues when your weights are not on CPU." "This may cause issues when your weights are not on CPU.",
stacklevel=2,
) )
assert len(group) > 0 assert len(group) > 0
return next(iter(group)) return next(iter(group))

View File

@ -279,7 +279,8 @@ def _get_cache_or_reload(
f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? " f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? "
"Torchhub will now assume that it's a branch. " "Torchhub will now assume that it's a branch. "
"You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or " "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or "
"refs/tags/tag_name as the ref. That might require using skip_validation=True." "refs/tags/tag_name as the ref. That might require using skip_validation=True.",
stacklevel=2,
) )
disambiguated_branch_ref = f"refs/heads/{ref}" disambiguated_branch_ref = f"refs/heads/{ref}"
url = _git_archive_link( url = _git_archive_link(
@ -338,7 +339,8 @@ def _check_repo_is_trusted(
"trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, " "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with " f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for " f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
f"confirmation if the repo is not already trusted. This will eventually be the default behaviour" f"confirmation if the repo is not already trusted. This will eventually be the default behaviour",
stacklevel=2,
) )
return return
@ -406,7 +408,9 @@ def get_dir() -> str:
""" """
# Issue warning to move data if old env is set # Issue warning to move data if old env is set
if os.getenv("TORCH_HUB"): if os.getenv("TORCH_HUB"):
warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead") warnings.warn(
"TORCH_HUB is deprecated, please use env TORCH_HOME instead", stacklevel=2
)
if _hub_dir is not None: if _hub_dir is not None:
return _hub_dir return _hub_dir
@ -853,7 +857,8 @@ def load_state_dict_from_url(
# Issue warning to move data if old env is set # Issue warning to move data if old env is set
if os.getenv("TORCH_MODEL_ZOO"): if os.getenv("TORCH_MODEL_ZOO"):
warnings.warn( warnings.warn(
"TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead" "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead",
stacklevel=2,
) )
if model_dir is None: if model_dir is None:

View File

@ -257,7 +257,7 @@ class strict_fusion:
def __init__(self) -> None: def __init__(self) -> None:
if not torch._jit_internal.is_scripting(): if not torch._jit_internal.is_scripting():
warnings.warn("Only works in script mode") warnings.warn("Only works in script mode", stacklevel=2)
def __enter__(self): def __enter__(self):
pass pass

View File

@ -180,7 +180,8 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
"instance-level annotations on empty non-base " "instance-level annotations on empty non-base "
"types in `__init__`. Instead, either 1) use a " "types in `__init__`. Instead, either 1) use a "
"type annotation in the class body, or 2) wrap " "type annotation in the class body, or 2) wrap "
"the type in `torch.jit.Attribute`." "the type in `torch.jit.Attribute`.",
stacklevel=2,
) )
def visit_Call(self, node): def visit_Call(self, node):
@ -245,5 +246,6 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
"instance-level annotations on empty non-base " "instance-level annotations on empty non-base "
"types in `__init__`. Instead, either 1) use a " "types in `__init__`. Instead, either 1) use a "
"type annotation in the class body, or 2) wrap " "type annotation in the class body, or 2) wrap "
"the type in `torch.jit.Attribute`." "the type in `torch.jit.Attribute`.",
stacklevel=2,
) )

View File

@ -48,7 +48,9 @@ def signatures_match(decomposition_sig, torch_op_sig):
inspect_empty = inspect._empty # type: ignore[attr-defined] inspect_empty = inspect._empty # type: ignore[attr-defined]
for field in ["name", "annotation"]: for field in ["name", "annotation"]:
if field == "name" and decomp_param.name == "self": if field == "name" and decomp_param.name == "self":
warnings.warn("PyTorch uses 'input' instead of 'self' on public api") warnings.warn(
"PyTorch uses 'input' instead of 'self' on public api", stacklevel=2
)
if getattr(decomp_param, field) != getattr(op_param, field): if getattr(decomp_param, field) != getattr(op_param, field):
return False return False

View File

@ -309,7 +309,8 @@ def infer_concrete_type_builder(nn_module, share_types=True):
warnings.warn( warnings.warn(
f"'{name}' was found in ScriptModule constants, " f"'{name}' was found in ScriptModule constants, "
f" but it is a non-constant {hint}. Consider removing it." f" but it is a non-constant {hint}. Consider removing it.",
stacklevel=2,
) )
continue continue
if not hasattr(nn_module, name): if not hasattr(nn_module, name):
@ -318,7 +319,8 @@ def infer_concrete_type_builder(nn_module, share_types=True):
warnings.warn( warnings.warn(
f"'{name}' was found in ScriptModule constants, " f"'{name}' was found in ScriptModule constants, "
"but was not actually set in __init__. " "but was not actually set in __init__. "
"Consider removing it." "Consider removing it.",
stacklevel=2,
) )
continue continue
value = getattr(nn_module, name) value = getattr(nn_module, name)

View File

@ -775,6 +775,7 @@ if _enabled:
"Lite Interpreter is deprecated. Please consider switching to ExecuTorch. \ "Lite Interpreter is deprecated. Please consider switching to ExecuTorch. \
https://docs.pytorch.org/executorch/stable/getting-started.html", https://docs.pytorch.org/executorch/stable/getting-started.html",
DeprecationWarning, DeprecationWarning,
stacklevel=2,
) )
return self._c._save_for_mobile(*args, **kwargs) return self._c._save_for_mobile(*args, **kwargs)
@ -787,6 +788,7 @@ if _enabled:
"Lite Interpreter is deprecated. Please consider switching to ExecuTorch. \ "Lite Interpreter is deprecated. Please consider switching to ExecuTorch. \
https://docs.pytorch.org/executorch/stable/getting-started.html", https://docs.pytorch.org/executorch/stable/getting-started.html",
DeprecationWarning, DeprecationWarning,
stacklevel=2,
) )
return self._c._save_to_buffer_for_mobile(*args, **kwargs) return self._c._save_to_buffer_for_mobile(*args, **kwargs)
@ -1165,7 +1167,8 @@ def _script_impl(
warnings.warn( warnings.warn(
"Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType " "Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
"to enable Profile-Directed Typing in TorchScript. Refer to " "to enable Profile-Directed Typing in TorchScript. Refer to "
"https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. " "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ",
stacklevel=2,
) )
if isinstance(obj, torch.nn.Module): if isinstance(obj, torch.nn.Module):

View File

@ -686,7 +686,8 @@ def _trace_impl(
# it is hard to trace it because the forward method on ScriptModule is already defined, so it # it is hard to trace it because the forward method on ScriptModule is already defined, so it
# would result in an error. # would result in an error.
warnings.warn( warnings.warn(
"The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is." "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is.",
stacklevel=2,
) )
return func return func

View File

@@ -389,7 +389,8 @@ def is_tensor(ann):
         warnings.warn(
             "TorchScript will treat type annotations of Tensor "
             "dtype-specific subtypes as if they are normal Tensors. "
-            "dtype constraints are not enforced in compilation either."
+            "dtype constraints are not enforced in compilation either.",
+            stacklevel=2,
         )
         return True


@@ -44,7 +44,8 @@ def _apply_docstring_templates(func: Callable[_P, _T]) -> Callable[_P, _T]:
         warnings.warn(
             f"No documentation string available for {func.__name__}."
             " PyTorch team should run `python tools/update_masked_docs.py`"
-            " to generate the missing docstrings."
+            " to generate the missing docstrings.",
+            stacklevel=2,
         )
     else:
         func.__doc__ = doc_string


@@ -322,7 +322,7 @@ class MaskedTensor(torch.Tensor):
             "In the case that the semantics for the operator are not trivial, it would be appreciated "
             "to also include a proposal for the semantics."
         )
-        warnings.warn(msg)
+        warnings.warn(msg, stacklevel=2)
         return NotImplemented

     def __lt__(self, other):


@@ -90,7 +90,7 @@ def _torch_reduce_dim(fn):
                 "In the case that the semantics for the operator are not trivial, it would be appreciated "
                 "to also include a proposal for the semantics."
             )
-            warnings.warn(msg)
+            warnings.warn(msg, stacklevel=2)
             return NotImplemented
         if not is_masked_tensor(self):
             raise TypeError("Input to reduce_dim must be a MaskedTensor")


@@ -223,7 +223,9 @@ class ProcessContext:
 class SpawnContext(ProcessContext):
     def __init__(self, processes, error_files):
-        warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
+        warnings.warn(
+            "SpawnContext is renamed to ProcessContext since 1.4 release.", stacklevel=2
+        )
         super().__init__(processes, error_files)


@@ -13,7 +13,8 @@ def get_enum(reduction: str) -> int:
     elif reduction == "elementwise_mean":
         warnings.warn(
             "reduction='elementwise_mean' is deprecated. "
-            "Please use reduction='mean' instead."
+            "Please use reduction='mean' instead.",
+            stacklevel=2,
         )
         ret = 1
     elif reduction == "sum":
@@ -48,7 +49,7 @@ def legacy_get_string(
     else:
         ret = "none"
     if emit_warning:
-        warnings.warn(warning.format(ret))
+        warnings.warn(warning.format(ret), stacklevel=2)
     return ret


@@ -60,10 +60,10 @@ def _raise_kernel_warnings(params: SDPAParams) -> None:
     """
     if WARN_FOR_UNFUSED_KERNELS:
         if not can_use_efficient_attention(params):
-            warn("Efficient attention can't be used because:")
+            warn("Efficient attention can't be used because:", stacklevel=2)
             can_use_efficient_attention(params, True)
         if not can_use_flash_attention(params):
-            warn("Flash attention can't be used because:")
+            warn("Flash attention can't be used because:", stacklevel=2)
             can_use_flash_attention(params, True)


@@ -134,7 +134,8 @@ class CausalBias(torch.Tensor):
         self.seq_len_kv = seq_len_kv
         if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
             warn(
-                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
+                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!",
+                stacklevel=2,
             )

     def _upper_left(self, device: torch.device) -> torch.Tensor:

Some files were not shown because too many files have changed in this diff.
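Closing aside (not part of the changeset): a rough sketch of how remaining call sites can be spotted locally with the flake8-bugbear plugin. `example_warn.py` is a made-up file name, and the exact B028 message wording may differ between plugin versions.

# example_warn.py -- hypothetical module used only to demonstrate the lint.
import warnings


def helper():
    warnings.warn("no stacklevel given")  # expected to be flagged as B028
    warnings.warn("caller-attributed", stacklevel=2)  # expected to pass


# Checking the file with flake8 plus flake8-bugbear, for example:
#   flake8 --select=B028 example_warn.py
# should report only the first warn() call above.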