Remove legacy constructor calls from _torch_ folder. (#53889)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/53146
Related to https://github.com/pytorch/pytorch/issues/47112

As mentioned in https://github.com/pytorch/pytorch/issues/47112, the plan is to:
1. Verify that all `torch.Tensor()` scenarios are covered by other functions
2. Scrub internal `torch.Tensor()` uses
3. Update the docs and throw `TORCH_WARN_ONCE` if someone uses `torch.Tensor()`

In this PR, I replaced all occurrences of `torch.Tensor` present in the _torch_ folder.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53889

Reviewed By: walterddr, zou3519

Differential Revision: D27190743

Pulled By: jbschlosser

fbshipit-source-id: 7ecc201d57935b8dbb98ae3718b60d95cb55a010
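For context (not part of the commit message): the legacy constructor is ambiguous about whether its argument is a size or data, and it always produces a float32 tensor, uninitialized in the size case. A minimal sketch of the behavior difference that motivates the replacements in this diff:

    import torch

    # Legacy constructor: the meaning of the argument depends on its type,
    # and the result is always the default float32 dtype.
    a = torch.Tensor(2, 3)    # sizes -> uninitialized 2x3 float32 tensor
    b = torch.Tensor([1, 2])  # data  -> tensor([1., 2.]), values forced to float32

    # The replacements used throughout this PR make the intent explicit:
    c = torch.empty(2, 3)     # sizes -> uninitialized tensor
    d = torch.tensor([1, 2])  # data  -> tensor([1, 2]), dtype inferred (int64 here)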
This commit is contained in: parent 6a4d2c61d5, commit 27048c1dfa
@@ -6435,7 +6435,7 @@ Keyword args:
 Example::
 
     >>> eps = torch.finfo(torch.float32).eps
-    >>> torch.nextafter(torch.Tensor([1, 2]), torch.Tensor([2, 1])) == torch.Tensor([eps + 1, 2 - eps])
+    >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps])
     tensor([True, True])
 
 """.format(**common_args))
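The integer literals also become floats here because `torch.tensor` infers its dtype from the data (int64 for `[1, 2]`), whereas the old `torch.Tensor` always returned float32, and the comparison against `eps` only makes sense for floating-point values. A quick check of the updated example:

    import torch

    eps = torch.finfo(torch.float32).eps
    x = torch.tensor([1.0, 2.0])   # float literals keep the inferred dtype float32
    y = torch.tensor([2.0, 1.0])
    print(torch.nextafter(x, y) == torch.tensor([eps + 1, 2 - eps]))  # tensor([True, True])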
@@ -21,7 +21,7 @@ class Kumaraswamy(TransformedDistribution):
     Example::
 
-        >>> m = Kumaraswamy(torch.Tensor([1.0]), torch.Tensor([1.0]))
+        >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
         >>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
         tensor([ 0.1729])
 
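A runnable version of the updated docstring example (the sampled value is random; the one shown is just illustrative):

    import torch
    from torch.distributions import Kumaraswamy

    # concentration parameters passed as data tensors, not as sizes
    m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
    print(m.sample())  # e.g. tensor([0.1729]); with both concentrations 1 the distribution is uniform on (0, 1)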
@@ -840,7 +840,7 @@ Args:
     compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True.
     out (tuple, optional): a tuple of three tensors to use for the outputs. If compute_uv=False,
                            the 1st and 3rd arguments must be tensors, but they are ignored. E.g. you can
-                           pass `(torch.Tensor(), out_S, torch.Tensor())`
+                           pass `(torch.tensor([]), out_S, torch.tensor([]))`
 
 Example::
 
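A minimal sketch of the out-tuple pattern this doc line describes, with `torch.tensor([])` as the zero-element placeholders for the ignored `U` and `V` slots (assuming the usual behavior that `out` tensors are resized as needed):

    import torch

    a = torch.randn(5, 3)
    out_U, out_S, out_V = torch.tensor([]), torch.empty(3), torch.tensor([])
    # With compute_uv=False the 1st and 3rd out tensors are ignored.
    torch.svd(a, compute_uv=False, out=(out_U, out_S, out_V))
    print(out_S)  # singular values of a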
@@ -49,7 +49,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
         self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
         self.weight_fake_quant = self.qconfig.weight()
         if bias:
-            self.bias = Parameter(torch.Tensor(out_channels))
+            self.bias = Parameter(torch.empty(out_channels))
         else:
             self.register_parameter('bias', None)
         self.reset_bn_parameters()
@@ -875,9 +875,9 @@ class MultiheadAttention(Module):
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
         if self._qkv_same_embed_dim is False:
-            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
-            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
-            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+            self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
+            self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
+            self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
             self.register_parameter('in_proj_weight', None)
         else:
             self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
@@ -1043,7 +1043,7 @@ class PReLU(Module):
     def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None:
         self.num_parameters = num_parameters
         super(PReLU, self).__init__()
-        self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
+        self.weight = Parameter(torch.empty(num_parameters).fill_(init))
 
     def forward(self, input: Tensor) -> Tensor:
         return F.prelu(input, self.weight)
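`torch.empty` allocates without initializing, so it is only a safe drop-in for `torch.Tensor(sizes)` when every element is written afterwards: PReLU fills the buffer immediately, and MultiheadAttention's projection weights are overwritten by its reset logic. A small illustration of the filled-in-place case:

    import torch
    from torch.nn import Parameter

    num_parameters, init = 3, 0.25
    # Uninitialized storage, immediately overwritten by fill_():
    weight = Parameter(torch.empty(num_parameters).fill_(init))
    print(weight.data)  # tensor([0.2500, 0.2500, 0.2500])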
@@ -38,8 +38,8 @@ class _NormBase(Module):
         self.affine = affine
         self.track_running_stats = track_running_stats
         if self.affine:
-            self.weight = Parameter(torch.Tensor(num_features))
-            self.bias = Parameter(torch.Tensor(num_features))
+            self.weight = Parameter(torch.empty(num_features))
+            self.bias = Parameter(torch.empty(num_features))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
@@ -122,13 +122,13 @@ class _ConvNd(Module):
         self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)
 
         if transposed:
-            self.weight = Parameter(torch.Tensor(
+            self.weight = Parameter(torch.empty(
                 in_channels, out_channels // groups, *kernel_size))
         else:
-            self.weight = Parameter(torch.Tensor(
+            self.weight = Parameter(torch.empty(
                 out_channels, in_channels // groups, *kernel_size))
         if bias:
-            self.bias = Parameter(torch.Tensor(out_channels))
+            self.bias = Parameter(torch.empty(out_channels))
         else:
             self.register_parameter('bias', None)
         self.reset_parameters()
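The conv/linear/norm changes in this diff all follow the same allocate-then-initialize pattern: `torch.empty` reserves storage and `reset_parameters()` gives it values. A simplified, hypothetical module (not the real `_ConvNd` code) showing that pattern:

    import math
    import torch
    from torch import nn
    from torch.nn import Parameter

    class TinyConvLike(nn.Module):
        def __init__(self, out_channels: int, in_channels: int, kernel_size: int):
            super().__init__()
            # Allocate storage only; values are undefined until reset_parameters() runs.
            self.weight = Parameter(torch.empty(out_channels, in_channels, kernel_size))
            self.bias = Parameter(torch.empty(out_channels))
            self.reset_parameters()

        def reset_parameters(self) -> None:
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            nn.init.uniform_(self.bias, -1 / math.sqrt(fan_in), 1 / math.sqrt(fan_in))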
@@ -76,9 +76,9 @@ class Linear(Module):
         super(Linear, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.weight = Parameter(torch.Tensor(out_features, in_features))
+        self.weight = Parameter(torch.empty(out_features, in_features))
         if bias:
-            self.bias = Parameter(torch.Tensor(out_features))
+            self.bias = Parameter(torch.empty(out_features))
         else:
             self.register_parameter('bias', None)
         self.reset_parameters()
@@ -157,10 +157,10 @@ class Bilinear(Module):
         self.in1_features = in1_features
         self.in2_features = in2_features
         self.out_features = out_features
-        self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))
+        self.weight = Parameter(torch.empty(out_features, in1_features, in2_features))
 
         if bias:
-            self.bias = Parameter(torch.Tensor(out_features))
+            self.bias = Parameter(torch.empty(out_features))
         else:
             self.register_parameter('bias', None)
         self.reset_parameters()
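For `Bilinear`, the weight shape allocated above, `(out_features, in1_features, in2_features)`, reflects the bilinear form it computes: y_k = x1 @ W_k @ x2 + b_k. A quick shape check (a sketch, not library code):

    import torch
    import torch.nn.functional as F

    out_features, in1, in2 = 2, 3, 4
    weight = torch.randn(out_features, in1, in2)  # same shape Bilinear allocates above
    bias = torch.randn(out_features)
    x1, x2 = torch.randn(in1), torch.randn(in2)

    y = torch.stack([x1 @ weight[k] @ x2 for k in range(out_features)]) + bias
    ref = F.bilinear(x1.unsqueeze(0), x2.unsqueeze(0), weight, bias).squeeze(0)
    print(torch.allclose(y, ref))  # True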
@@ -154,8 +154,8 @@ class LayerNorm(Module):
         self.eps = eps
         self.elementwise_affine = elementwise_affine
         if self.elementwise_affine:
-            self.weight = Parameter(torch.Tensor(*self.normalized_shape))
-            self.bias = Parameter(torch.Tensor(*self.normalized_shape))
+            self.weight = Parameter(torch.empty(self.normalized_shape))
+            self.bias = Parameter(torch.empty(self.normalized_shape))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
@@ -230,8 +230,8 @@ class GroupNorm(Module):
         self.eps = eps
         self.affine = affine
         if self.affine:
-            self.weight = Parameter(torch.Tensor(num_channels))
-            self.bias = Parameter(torch.Tensor(num_channels))
+            self.weight = Parameter(torch.empty(num_channels))
+            self.bias = Parameter(torch.empty(num_channels))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
@@ -84,12 +84,12 @@ class RNNBase(Module):
                 real_hidden_size = proj_size if proj_size > 0 else hidden_size
                 layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions
 
-                w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
-                w_hh = Parameter(torch.Tensor(gate_size, real_hidden_size))
-                b_ih = Parameter(torch.Tensor(gate_size))
+                w_ih = Parameter(torch.empty(gate_size, layer_input_size))
+                w_hh = Parameter(torch.empty(gate_size, real_hidden_size))
+                b_ih = Parameter(torch.empty(gate_size))
                 # Second bias vector included for CuDNN compatibility. Only one
                 # bias vector is needed in standard definition.
-                b_hh = Parameter(torch.Tensor(gate_size))
+                b_hh = Parameter(torch.empty(gate_size))
                 layer_params: Tuple[Tensor, ...] = ()
                 if self.proj_size == 0:
                     if bias:
@@ -97,7 +97,7 @@ class RNNBase(Module):
                     else:
                         layer_params = (w_ih, w_hh)
                 else:
-                    w_hr = Parameter(torch.Tensor(proj_size, hidden_size))
+                    w_hr = Parameter(torch.empty(proj_size, hidden_size))
                     if bias:
                         layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
                     else:
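On the "second bias vector" comment kept in this hunk: an Elman cell adds both biases, h' = tanh(x @ W_ih.T + b_ih + h @ W_hh.T + b_hh), so mathematically one bias would suffice; b_hh exists so the parameters map one-to-one onto cuDNN's layout. A tiny illustration with buffers allocated the new way and then explicitly initialized:

    import torch

    hidden_size, input_size = 4, 3
    w_ih = torch.empty(hidden_size, input_size).uniform_(-0.1, 0.1)
    w_hh = torch.empty(hidden_size, hidden_size).uniform_(-0.1, 0.1)
    b_ih, b_hh = torch.zeros(hidden_size), torch.zeros(hidden_size)

    x, h = torch.randn(input_size), torch.zeros(hidden_size)
    # The two biases are simply summed, which is why only one is strictly needed.
    h_next = torch.tanh(x @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh)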
@@ -850,11 +850,11 @@ class RNNCellBase(Module):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.bias = bias
-        self.weight_ih = Parameter(torch.Tensor(num_chunks * hidden_size, input_size))
-        self.weight_hh = Parameter(torch.Tensor(num_chunks * hidden_size, hidden_size))
+        self.weight_ih = Parameter(torch.empty(num_chunks * hidden_size, input_size))
+        self.weight_hh = Parameter(torch.empty(num_chunks * hidden_size, hidden_size))
         if bias:
-            self.bias_ih = Parameter(torch.Tensor(num_chunks * hidden_size))
-            self.bias_hh = Parameter(torch.Tensor(num_chunks * hidden_size))
+            self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size))
+            self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size))
         else:
             self.register_parameter('bias_ih', None)
             self.register_parameter('bias_hh', None)
@@ -134,7 +134,7 @@ class Embedding(Module):
         self.norm_type = norm_type
         self.scale_grad_by_freq = scale_grad_by_freq
         if _weight is None:
-            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
+            self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
             self.reset_parameters()
         else:
             assert list(_weight.shape) == [num_embeddings, embedding_dim], \
@@ -322,7 +322,7 @@ class EmbeddingBag(Module):
         self.norm_type = norm_type
         self.scale_grad_by_freq = scale_grad_by_freq
         if _weight is None:
-            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
+            self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
             self.reset_parameters()
         else:
             assert list(_weight.shape) == [num_embeddings, embedding_dim], \
@@ -22,7 +22,7 @@ class Parameter(torch.Tensor):
     """
     def __new__(cls, data=None, requires_grad=True):
         if data is None:
-            data = torch.Tensor()
+            data = torch.tensor([])
         return torch.Tensor._make_subclass(cls, data, requires_grad)
 
     def __deepcopy__(self, memo):
@@ -146,7 +146,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
     cls_to_become = Parameter
 
     def __new__(cls, requires_grad=True):
-        data = torch.Tensor()
+        data = torch.tensor([])
         return torch.Tensor._make_subclass(cls, data, requires_grad)
 
 
@@ -166,5 +166,5 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
     cls_to_become = torch.Tensor
 
     def __new__(cls, requires_grad=False):
-        data = torch.Tensor()
+        data = torch.tensor([])
         return torch.Tensor._make_subclass(cls, data, requires_grad)
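`torch.tensor([])` yields a zero-element float32 tensor, the same shape and dtype the legacy `torch.Tensor()` produced, so an argument-less `Parameter()` (and the uninitialized variants) keeps working. A quick check:

    import torch
    from torch.nn import Parameter

    p = Parameter()                  # data defaults to the zero-element placeholder
    print(p.shape, p.dtype)          # torch.Size([0]) torch.float32
    print(torch.tensor([]).shape)    # torch.Size([0])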
@@ -1144,7 +1144,7 @@ def custom_from_mask(module, name, mask):
 
     Examples:
         >>> m = prune.custom_from_mask(
-                nn.Linear(5, 3), name='bias', mask=torch.Tensor([0, 1, 0])
+                nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
             )
         >>> print(m.bias_mask)
         tensor([0., 1., 0.])
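Roughly what the docstring example sets up (a sketch, hedged to how `torch.nn.utils.prune` reparametrizes a module): the original values are kept in `bias_orig`, the mask in `bias_mask`, and `m.bias` is recomputed as their product.

    import torch
    from torch import nn
    from torch.nn.utils import prune

    m = prune.custom_from_mask(nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0]))
    print(m.bias_mask)   # tensor([0., 1., 0.])
    print(m.bias)        # only the middle entry can be non-zero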
@@ -2455,7 +2455,7 @@ def scatter_add(g, self, dim, index, src):
 
 def log2(g, self):
     _ln2 = 0.693147180559945309
-    return g.op('Div', log(g, self), g.op('Constant', value_t=torch.Tensor([_ln2])))
+    return g.op('Div', log(g, self), g.op('Constant', value_t=torch.tensor([_ln2])))
 
 
 def prim_shape(g, self):
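The exported graph implements log2 by change of base, log2(x) = ln(x) / ln(2), with ln 2 ≈ 0.6931471805599453 baked in as a constant node; `torch.tensor([_ln2])` builds that float constant. A quick numerical check:

    import math
    import torch

    _ln2 = 0.693147180559945309
    x = torch.tensor([1.0, 2.0, 8.0])
    print(torch.allclose(torch.log(x) / _ln2, torch.log2(x)))  # True
    print(math.log(2))  # 0.6931471805599453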
@@ -588,8 +588,8 @@ class Quantizer:
             # converting List[int] to Tensor since module attribute is
             # Union[Tensor, Module]
             model._standalone_module_input_quantized_idxs = \
-                torch.Tensor(input_quantized_idxs)
-            model._standalone_module_output_quantized_idxs = torch.Tensor(output_quantized_idxs)
+                torch.tensor(input_quantized_idxs)
+            model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs)
         return model
 
     def save_state(self, observed: GraphModule) -> None:
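Here the point of `torch.tensor` is that it copies the Python list of indices and keeps an integer dtype, whereas the legacy constructor would have produced float32 values. A small check:

    import torch

    input_quantized_idxs = [0, 2]
    idxs = torch.tensor(input_quantized_idxs)
    print(idxs, idxs.dtype)                          # tensor([0, 2]) torch.int64
    print(torch.Tensor(input_quantized_idxs).dtype)  # torch.float32 (legacy behavior)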
@@ -101,7 +101,7 @@ results_type={self.results_type}, index_within_arg={self.index_within_arg})"""
     # # one of NSSingleResultValuesType
     # 'type': 'weight',
     # # the values of type specified above
-    # 'values': [torch.Tensor(...), ...],
+    # 'values': [torch.tensor(...), ...],
     # # name of the node directly before the logger
     # 'prev_node_name': 'linear1',
     # # type of the underlying function or module
@@ -652,7 +652,7 @@ def mseloss_no_reduce_scalar_test():
 
 
 def nllloss_no_reduce_test():
-    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
     kwargs = {'reduction': 'none'}
     return dict(
         fullname='NLLLoss_no_reduce',
@@ -668,7 +668,7 @@ def nllloss_no_reduce_test():
 
 
 def nllloss_no_reduce_ignore_index_test():
-    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
     kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
     return dict(
         fullname='NLLLoss_no_reduce_ignore_index',
@@ -685,7 +685,7 @@ def nllloss_no_reduce_ignore_index_test():
 
 
 def nllloss_no_reduce_weights_test():
-    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
     weight = torch.rand(10)
 
     def kwargs(i):
@@ -706,7 +706,7 @@ def nllloss_no_reduce_weights_test():
 
 
 def nllloss_no_reduce_weights_ignore_index_test():
-    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
     weight = torch.rand(10)
 
     def kwargs(i):
@@ -728,7 +728,7 @@ def nllloss_no_reduce_weights_ignore_index_test():
 
 
 def nllloss_no_reduce_weights_ignore_index_neg_test():
-    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
     weight = torch.rand(10)
 
     def kwargs(i):
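In these test helpers `torch.Tensor(15)` was the size-style call, so `torch.empty(15)` is the direct equivalent; the chained `uniform_().mul(10).floor().long()` then overwrites the uninitialized storage, producing 15 random class indices in [0, 10). A standalone version:

    import torch

    t = torch.empty(15).uniform_().mul(10).floor().long()  # 15 random targets in [0, 10)
    print(t.dtype, int(t.min()) >= 0, int(t.max()) < 10)    # torch.int64 True True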
@@ -4258,7 +4258,7 @@ criterion_tests = [
     dict(
         module_name='NLLLoss',
         input_fn=lambda: torch.rand(15, 10).log(),
-        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
         reference_fn=lambda i, t, m:
             nllloss_reference(i, t, reduction=get_reduction(m)),
         check_sum_reduction=True,
@@ -4269,7 +4269,7 @@ criterion_tests = [
         constructor_args=(None, None, 2),
         cpp_constructor_args='torch::nn::NLLLossOptions().weight({}).ignore_index(2)',
         input_fn=lambda: torch.rand(15, 10).log(),
-        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
         reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
         desc='ignore_index',
         check_bfloat16=True,
@@ -4279,7 +4279,7 @@ criterion_tests = [
         constructor_args_fn=lambda: (torch.rand(10),),
         cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10))',
         input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
-        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
         reference_fn=lambda i, t, m:
             nllloss_reference(i, t, weight=get_weight(m)),
         desc='weights',
@@ -4290,7 +4290,7 @@ criterion_tests = [
         constructor_args_fn=lambda: (torch.rand(10), None, 2),
         cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(2)',
         input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
-        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
         reference_fn=lambda i, t, m:
             nllloss_reference(i, t, weight=get_weight(m), ignore_index=2),
         desc='weights_ignore_index',
@@ -4301,7 +4301,7 @@ criterion_tests = [
         constructor_args_fn=lambda: (torch.rand(10), None, -1),
         cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(-1)',
         input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
-        target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1,
+        target_fn=lambda: torch.empty(15).uniform_().mul(10 + 1).floor().long() - 1,
         reference_fn=lambda i, t, m:
             nllloss_reference(i, t, weight=get_weight(m), ignore_index=-1),
         desc='weights_ignore_index_neg',
@@ -4354,14 +4354,14 @@ criterion_tests = [
     dict(
         module_name='CrossEntropyLoss',
         input_size=(15, 10),
-        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
     ),
     dict(
         module_name='CrossEntropyLoss',
        constructor_args_fn=lambda: (torch.rand(10),),
         cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(10))',
         input_size=(15, 10),
-        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
         desc='weights',
     ),
     dict(
@@ -97,7 +97,7 @@ class RemoteEM(nn.Module):
         self.em = nn.EmbeddingBag(
             num_embeddings,
             embedding_dim,
-            _weight=torch.Tensor([init_em] * num_embeddings),
+            _weight=torch.tensor([init_em] * num_embeddings),
         )
 
     def forward(self, input: torch.Tensor):
@@ -278,10 +278,10 @@ def get_training_examples():
     # Every example has another one that has exactly the same features but an
     # opposite value. Therefore, their grads cancel each other in all-reduce.
     for value in (-1, 1):
-        for x in (-1 * value, 1 * value):
-            for y in (1 * value, -1 * value):
+        for x in (-1.0 * value, 1.0 * value):
+            for y in (1.0 * value, -1.0 * value):
                 for z in (0, 1):
-                    training_examples.dense_features[idx, :] = torch.Tensor((x, y))
+                    training_examples.dense_features[idx, :] = torch.tensor((x, y))
                     training_examples.sparse_features[idx] = z
                     training_examples.values[idx] = value
                     idx += 1
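The loop literals also change to floats here so that `torch.tensor((x, y))` infers a floating-point dtype, matching what the legacy `torch.Tensor((x, y))` produced; with plain Python ints the inferred dtype would be int64. A quick check of the inference:

    import torch

    value = 1
    print(torch.tensor((-1.0 * value, 1.0 * value)).dtype)  # torch.float32
    print(torch.tensor((-1 * value, 1 * value)).dtype)      # torch.int64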