Remove legacy constructor calls from _torch_ folder. (#53889)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/53146
Related to https://github.com/pytorch/pytorch/issues/47112

As mentioned in https://github.com/pytorch/pytorch/issues/47112, the plan is to:

1. Verify that all `torch.Tensor()` scenarios are covered by other functions
2. Scrub internal `torch.Tensor()` uses
3. Update the docs and throw `TORCH_WARN_ONCE` if someone uses `torch.Tensor()`

In this PR, I replaced all occurrences of `torch.Tensor` present in the _torch_ folder.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53889

Reviewed By: walterddr, zou3519

Differential Revision: D27190743

Pulled By: jbschlosser

fbshipit-source-id: 7ecc201d57935b8dbb98ae3718b60d95cb55a010
This commit is contained in:
Yukio Siraichi 2021-03-19 15:17:23 -07:00 committed by Facebook GitHub Bot
parent 6a4d2c61d5
commit 27048c1dfa
18 changed files with 56 additions and 56 deletions

View File

@ -6435,7 +6435,7 @@ Keyword args:
Example:: Example::
>>> eps = torch.finfo(torch.float32).eps >>> eps = torch.finfo(torch.float32).eps
>>> torch.nextafter(torch.Tensor([1, 2]), torch.Tensor([2, 1])) == torch.Tensor([eps + 1, 2 - eps]) >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps])
tensor([True, True]) tensor([True, True])
""".format(**common_args)) """.format(**common_args))

View File

@ -21,7 +21,7 @@ class Kumaraswamy(TransformedDistribution):
Example:: Example::
>>> m = Kumaraswamy(torch.Tensor([1.0]), torch.Tensor([1.0])) >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample() # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1 >>> m.sample() # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
tensor([ 0.1729]) tensor([ 0.1729])

View File

@ -840,7 +840,7 @@ Args:
compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True. compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True.
out (tuple, optional): a tuple of three tensors to use for the outputs. If compute_uv=False, out (tuple, optional): a tuple of three tensors to use for the outputs. If compute_uv=False,
the 1st and 3rd arguments must be tensors, but they are ignored. E.g. you can the 1st and 3rd arguments must be tensors, but they are ignored. E.g. you can
pass `(torch.Tensor(), out_S, torch.Tensor())` pass `(torch.tensor([]), out_S, torch.tensor([]))`
Example:: Example::

View File

@ -49,7 +49,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True) self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
self.weight_fake_quant = self.qconfig.weight() self.weight_fake_quant = self.qconfig.weight()
if bias: if bias:
self.bias = Parameter(torch.Tensor(out_channels)) self.bias = Parameter(torch.empty(out_channels))
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.reset_bn_parameters() self.reset_bn_parameters()

View File

@ -875,9 +875,9 @@ class MultiheadAttention(Module):
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if self._qkv_same_embed_dim is False: if self._qkv_same_embed_dim is False:
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
self.register_parameter('in_proj_weight', None) self.register_parameter('in_proj_weight', None)
else: else:
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
@ -1043,7 +1043,7 @@ class PReLU(Module):
def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None: def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None:
self.num_parameters = num_parameters self.num_parameters = num_parameters
super(PReLU, self).__init__() super(PReLU, self).__init__()
self.weight = Parameter(torch.Tensor(num_parameters).fill_(init)) self.weight = Parameter(torch.empty(num_parameters).fill_(init))
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
return F.prelu(input, self.weight) return F.prelu(input, self.weight)

View File

@ -38,8 +38,8 @@ class _NormBase(Module):
self.affine = affine self.affine = affine
self.track_running_stats = track_running_stats self.track_running_stats = track_running_stats
if self.affine: if self.affine:
self.weight = Parameter(torch.Tensor(num_features)) self.weight = Parameter(torch.empty(num_features))
self.bias = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.empty(num_features))
else: else:
self.register_parameter('weight', None) self.register_parameter('weight', None)
self.register_parameter('bias', None) self.register_parameter('bias', None)

View File

@ -122,13 +122,13 @@ class _ConvNd(Module):
self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)
if transposed: if transposed:
self.weight = Parameter(torch.Tensor( self.weight = Parameter(torch.empty(
in_channels, out_channels // groups, *kernel_size)) in_channels, out_channels // groups, *kernel_size))
else: else:
self.weight = Parameter(torch.Tensor( self.weight = Parameter(torch.empty(
out_channels, in_channels // groups, *kernel_size)) out_channels, in_channels // groups, *kernel_size))
if bias: if bias:
self.bias = Parameter(torch.Tensor(out_channels)) self.bias = Parameter(torch.empty(out_channels))
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()

View File

@ -76,9 +76,9 @@ class Linear(Module):
super(Linear, self).__init__() super(Linear, self).__init__()
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features)) self.weight = Parameter(torch.empty(out_features, in_features))
if bias: if bias:
self.bias = Parameter(torch.Tensor(out_features)) self.bias = Parameter(torch.empty(out_features))
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()
@ -157,10 +157,10 @@ class Bilinear(Module):
self.in1_features = in1_features self.in1_features = in1_features
self.in2_features = in2_features self.in2_features = in2_features
self.out_features = out_features self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features)) self.weight = Parameter(torch.empty(out_features, in1_features, in2_features))
if bias: if bias:
self.bias = Parameter(torch.Tensor(out_features)) self.bias = Parameter(torch.empty(out_features))
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()

View File

@ -154,8 +154,8 @@ class LayerNorm(Module):
self.eps = eps self.eps = eps
self.elementwise_affine = elementwise_affine self.elementwise_affine = elementwise_affine
if self.elementwise_affine: if self.elementwise_affine:
self.weight = Parameter(torch.Tensor(*self.normalized_shape)) self.weight = Parameter(torch.empty(self.normalized_shape))
self.bias = Parameter(torch.Tensor(*self.normalized_shape)) self.bias = Parameter(torch.empty(self.normalized_shape))
else: else:
self.register_parameter('weight', None) self.register_parameter('weight', None)
self.register_parameter('bias', None) self.register_parameter('bias', None)
@ -230,8 +230,8 @@ class GroupNorm(Module):
self.eps = eps self.eps = eps
self.affine = affine self.affine = affine
if self.affine: if self.affine:
self.weight = Parameter(torch.Tensor(num_channels)) self.weight = Parameter(torch.empty(num_channels))
self.bias = Parameter(torch.Tensor(num_channels)) self.bias = Parameter(torch.empty(num_channels))
else: else:
self.register_parameter('weight', None) self.register_parameter('weight', None)
self.register_parameter('bias', None) self.register_parameter('bias', None)

View File

@ -84,12 +84,12 @@ class RNNBase(Module):
real_hidden_size = proj_size if proj_size > 0 else hidden_size real_hidden_size = proj_size if proj_size > 0 else hidden_size
layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions
w_ih = Parameter(torch.Tensor(gate_size, layer_input_size)) w_ih = Parameter(torch.empty(gate_size, layer_input_size))
w_hh = Parameter(torch.Tensor(gate_size, real_hidden_size)) w_hh = Parameter(torch.empty(gate_size, real_hidden_size))
b_ih = Parameter(torch.Tensor(gate_size)) b_ih = Parameter(torch.empty(gate_size))
# Second bias vector included for CuDNN compatibility. Only one # Second bias vector included for CuDNN compatibility. Only one
# bias vector is needed in standard definition. # bias vector is needed in standard definition.
b_hh = Parameter(torch.Tensor(gate_size)) b_hh = Parameter(torch.empty(gate_size))
layer_params: Tuple[Tensor, ...] = () layer_params: Tuple[Tensor, ...] = ()
if self.proj_size == 0: if self.proj_size == 0:
if bias: if bias:
@ -97,7 +97,7 @@ class RNNBase(Module):
else: else:
layer_params = (w_ih, w_hh) layer_params = (w_ih, w_hh)
else: else:
w_hr = Parameter(torch.Tensor(proj_size, hidden_size)) w_hr = Parameter(torch.empty(proj_size, hidden_size))
if bias: if bias:
layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
else: else:
@ -850,11 +850,11 @@ class RNNCellBase(Module):
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.bias = bias self.bias = bias
self.weight_ih = Parameter(torch.Tensor(num_chunks * hidden_size, input_size)) self.weight_ih = Parameter(torch.empty(num_chunks * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(num_chunks * hidden_size, hidden_size)) self.weight_hh = Parameter(torch.empty(num_chunks * hidden_size, hidden_size))
if bias: if bias:
self.bias_ih = Parameter(torch.Tensor(num_chunks * hidden_size)) self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size))
self.bias_hh = Parameter(torch.Tensor(num_chunks * hidden_size)) self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size))
else: else:
self.register_parameter('bias_ih', None) self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None) self.register_parameter('bias_hh', None)

View File

@ -134,7 +134,7 @@ class Embedding(Module):
self.norm_type = norm_type self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq self.scale_grad_by_freq = scale_grad_by_freq
if _weight is None: if _weight is None:
self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters() self.reset_parameters()
else: else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \ assert list(_weight.shape) == [num_embeddings, embedding_dim], \
@ -322,7 +322,7 @@ class EmbeddingBag(Module):
self.norm_type = norm_type self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq self.scale_grad_by_freq = scale_grad_by_freq
if _weight is None: if _weight is None:
self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters() self.reset_parameters()
else: else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \ assert list(_weight.shape) == [num_embeddings, embedding_dim], \

View File

@ -22,7 +22,7 @@ class Parameter(torch.Tensor):
""" """
def __new__(cls, data=None, requires_grad=True): def __new__(cls, data=None, requires_grad=True):
if data is None: if data is None:
data = torch.Tensor() data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad) return torch.Tensor._make_subclass(cls, data, requires_grad)
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -146,7 +146,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
cls_to_become = Parameter cls_to_become = Parameter
def __new__(cls, requires_grad=True): def __new__(cls, requires_grad=True):
data = torch.Tensor() data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad) return torch.Tensor._make_subclass(cls, data, requires_grad)
@ -166,5 +166,5 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
cls_to_become = torch.Tensor cls_to_become = torch.Tensor
def __new__(cls, requires_grad=False): def __new__(cls, requires_grad=False):
data = torch.Tensor() data = torch.tensor([])
return torch.Tensor._make_subclass(cls, data, requires_grad) return torch.Tensor._make_subclass(cls, data, requires_grad)

View File

@ -1144,7 +1144,7 @@ def custom_from_mask(module, name, mask):
Examples: Examples:
>>> m = prune.custom_from_mask( >>> m = prune.custom_from_mask(
nn.Linear(5, 3), name='bias', mask=torch.Tensor([0, 1, 0]) nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
) )
>>> print(m.bias_mask) >>> print(m.bias_mask)
tensor([0., 1., 0.]) tensor([0., 1., 0.])

View File

@ -2455,7 +2455,7 @@ def scatter_add(g, self, dim, index, src):
def log2(g, self): def log2(g, self):
_ln2 = 0.693147180559945309 _ln2 = 0.693147180559945309
return g.op('Div', log(g, self), g.op('Constant', value_t=torch.Tensor([_ln2]))) return g.op('Div', log(g, self), g.op('Constant', value_t=torch.tensor([_ln2])))
def prim_shape(g, self): def prim_shape(g, self):

View File

@ -588,8 +588,8 @@ class Quantizer:
# converting List[int] to Tensor since module attribute is # converting List[int] to Tensor since module attribute is
# Union[Tensor, Module] # Union[Tensor, Module]
model._standalone_module_input_quantized_idxs = \ model._standalone_module_input_quantized_idxs = \
torch.Tensor(input_quantized_idxs) torch.tensor(input_quantized_idxs)
model._standalone_module_output_quantized_idxs = torch.Tensor(output_quantized_idxs) model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs)
return model return model
def save_state(self, observed: GraphModule) -> None: def save_state(self, observed: GraphModule) -> None:

View File

@ -101,7 +101,7 @@ results_type={self.results_type}, index_within_arg={self.index_within_arg})"""
# # one of NSSingleResultValuesType # # one of NSSingleResultValuesType
# 'type': 'weight', # 'type': 'weight',
# # the values of type specified above # # the values of type specified above
# 'values': [torch.Tensor(...), ...], # 'values': [torch.tensor(...), ...],
# # name of the node directly before the logger # # name of the node directly before the logger
# 'prev_node_name': 'linear1', # 'prev_node_name': 'linear1',
# # type of the underlying function or module # # type of the underlying function or module

View File

@ -652,7 +652,7 @@ def mseloss_no_reduce_scalar_test():
def nllloss_no_reduce_test(): def nllloss_no_reduce_test():
t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long()) t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
kwargs = {'reduction': 'none'} kwargs = {'reduction': 'none'}
return dict( return dict(
fullname='NLLLoss_no_reduce', fullname='NLLLoss_no_reduce',
@ -668,7 +668,7 @@ def nllloss_no_reduce_test():
def nllloss_no_reduce_ignore_index_test(): def nllloss_no_reduce_ignore_index_test():
t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long()) t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'} kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
return dict( return dict(
fullname='NLLLoss_no_reduce_ignore_index', fullname='NLLLoss_no_reduce_ignore_index',
@ -685,7 +685,7 @@ def nllloss_no_reduce_ignore_index_test():
def nllloss_no_reduce_weights_test(): def nllloss_no_reduce_weights_test():
t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long()) t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
weight = torch.rand(10) weight = torch.rand(10)
def kwargs(i): def kwargs(i):
@ -706,7 +706,7 @@ def nllloss_no_reduce_weights_test():
def nllloss_no_reduce_weights_ignore_index_test(): def nllloss_no_reduce_weights_ignore_index_test():
t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long()) t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
weight = torch.rand(10) weight = torch.rand(10)
def kwargs(i): def kwargs(i):
@ -728,7 +728,7 @@ def nllloss_no_reduce_weights_ignore_index_test():
def nllloss_no_reduce_weights_ignore_index_neg_test(): def nllloss_no_reduce_weights_ignore_index_neg_test():
t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long()) t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
weight = torch.rand(10) weight = torch.rand(10)
def kwargs(i): def kwargs(i):
@ -4258,7 +4258,7 @@ criterion_tests = [
dict( dict(
module_name='NLLLoss', module_name='NLLLoss',
input_fn=lambda: torch.rand(15, 10).log(), input_fn=lambda: torch.rand(15, 10).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
reference_fn=lambda i, t, m: reference_fn=lambda i, t, m:
nllloss_reference(i, t, reduction=get_reduction(m)), nllloss_reference(i, t, reduction=get_reduction(m)),
check_sum_reduction=True, check_sum_reduction=True,
@ -4269,7 +4269,7 @@ criterion_tests = [
constructor_args=(None, None, 2), constructor_args=(None, None, 2),
cpp_constructor_args='torch::nn::NLLLossOptions().weight({}).ignore_index(2)', cpp_constructor_args='torch::nn::NLLLossOptions().weight({}).ignore_index(2)',
input_fn=lambda: torch.rand(15, 10).log(), input_fn=lambda: torch.rand(15, 10).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2), reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
desc='ignore_index', desc='ignore_index',
check_bfloat16=True, check_bfloat16=True,
@ -4279,7 +4279,7 @@ criterion_tests = [
constructor_args_fn=lambda: (torch.rand(10),), constructor_args_fn=lambda: (torch.rand(10),),
cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10))', cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10))',
input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
reference_fn=lambda i, t, m: reference_fn=lambda i, t, m:
nllloss_reference(i, t, weight=get_weight(m)), nllloss_reference(i, t, weight=get_weight(m)),
desc='weights', desc='weights',
@ -4290,7 +4290,7 @@ criterion_tests = [
constructor_args_fn=lambda: (torch.rand(10), None, 2), constructor_args_fn=lambda: (torch.rand(10), None, 2),
cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(2)', cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(2)',
input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
reference_fn=lambda i, t, m: reference_fn=lambda i, t, m:
nllloss_reference(i, t, weight=get_weight(m), ignore_index=2), nllloss_reference(i, t, weight=get_weight(m), ignore_index=2),
desc='weights_ignore_index', desc='weights_ignore_index',
@ -4301,7 +4301,7 @@ criterion_tests = [
constructor_args_fn=lambda: (torch.rand(10), None, -1), constructor_args_fn=lambda: (torch.rand(10), None, -1),
cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(-1)', cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(-1)',
input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1, target_fn=lambda: torch.empty(15).uniform_().mul(10 + 1).floor().long() - 1,
reference_fn=lambda i, t, m: reference_fn=lambda i, t, m:
nllloss_reference(i, t, weight=get_weight(m), ignore_index=-1), nllloss_reference(i, t, weight=get_weight(m), ignore_index=-1),
desc='weights_ignore_index_neg', desc='weights_ignore_index_neg',
@ -4354,14 +4354,14 @@ criterion_tests = [
dict( dict(
module_name='CrossEntropyLoss', module_name='CrossEntropyLoss',
input_size=(15, 10), input_size=(15, 10),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
), ),
dict( dict(
module_name='CrossEntropyLoss', module_name='CrossEntropyLoss',
constructor_args_fn=lambda: (torch.rand(10),), constructor_args_fn=lambda: (torch.rand(10),),
cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(10))', cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(10))',
input_size=(15, 10), input_size=(15, 10),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(), target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
desc='weights', desc='weights',
), ),
dict( dict(

View File

@ -97,7 +97,7 @@ class RemoteEM(nn.Module):
self.em = nn.EmbeddingBag( self.em = nn.EmbeddingBag(
num_embeddings, num_embeddings,
embedding_dim, embedding_dim,
_weight=torch.Tensor([init_em] * num_embeddings), _weight=torch.tensor([init_em] * num_embeddings),
) )
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
@ -278,10 +278,10 @@ def get_training_examples():
# Every example has another one that has exactly the same features but an # Every example has another one that has exactly the same features but an
# opposite value. Therefore, their grads cancel each other in all-reduce. # opposite value. Therefore, their grads cancel each other in all-reduce.
for value in (-1, 1): for value in (-1, 1):
for x in (-1 * value, 1 * value): for x in (-1.0 * value, 1.0 * value):
for y in (1 * value, -1 * value): for y in (1.0 * value, -1.0 * value):
for z in (0, 1): for z in (0, 1):
training_examples.dense_features[idx, :] = torch.Tensor((x, y)) training_examples.dense_features[idx, :] = torch.tensor((x, y))
training_examples.sparse_features[idx] = z training_examples.sparse_features[idx] = z
training_examples.values[idx] = value training_examples.values[idx] = value
idx += 1 idx += 1