[pt2] add SymInt support for cdist (#98881)

Fixes #98853.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98881
Approved by: https://github.com/ezyang
Author: Nikita Karetnikov, 2023-04-12 01:44:55 +02:00 (committed by PyTorch MergeBot)
parent a2e809f29b
commit 8db04e080c
5 changed files with 30 additions and 25 deletions
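In short: cdist previously read sizes as concrete int64_t values, so tracing it with symbolic shapes failed (see the xfail removed below: "aten.size.default - couldn't find symbolic meta function/decomposition"). The commit routes all shape computation through SymInt (sym_size, expand_symint, view_symint, empty_symint, and friends). A minimal sketch of the user-visible effect under dynamic shapes (illustrative only, not part of the commit):

import torch

@torch.compile(dynamic=True)
def pairwise(x1, x2):
    return torch.cdist(x1, x2)

# With this patch, cdist's shape arithmetic stays symbolic instead of
# demanding concrete sizes at trace time.
out = pairwise(torch.rand(30, 4), torch.rand(20, 4))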

aten/src/ATen/ExpandUtils.cpp

@@ -46,6 +46,10 @@ std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
   return infer_size_impl<std::vector<int64_t>>(a, b);
 }
 
+std::vector<SymInt> infer_size_symint(SymIntArrayRef a, SymIntArrayRef b) {
+  return infer_size_impl<std::vector<SymInt>>(a, b);
+}
+
 DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
   return infer_size_impl<DimVector, IntArrayRef>(a, b);
 }
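For context, infer_size_impl implements NumPy-style broadcast shape inference; the new infer_size_symint instantiation runs the same rule over SymInt so it works on symbolic sizes. The Python-visible counterpart of that rule (an analogy, not the C++ API) is torch.broadcast_shapes:

import torch

# Size-1 dims stretch and missing leading dims count as 1, exactly the
# rule infer_size applies to the batch portions of cdist's inputs.
assert torch.broadcast_shapes((2, 1, 4), (3, 1)) == torch.Size([2, 3, 4])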

aten/src/ATen/ExpandUtils.h

@@ -21,6 +21,9 @@
 namespace at {
 
 TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
+TORCH_API std::vector<SymInt> infer_size_symint(
+    SymIntArrayRef a,
+    SymIntArrayRef b);
 TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
 TORCH_API SymDimVector
 infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);

aten/src/ATen/native/Distance.cpp

@@ -85,8 +85,8 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
   TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
   // TODO: This is bad; this test should apply universally
   TORCH_CHECK(!x1.is_cuda() || x1.get_device() == x2.get_device(), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
-  int64_t c1 = x1.size(-1);
-  int64_t c2 = x2.size(-1);
+  SymInt c1 = x1.sym_size(-1);
+  SymInt c2 = x2.sym_size(-1);
   // 0 - default value. If p = 2 and r1 > 25 or r2 > 25 (these values are based on performance metrics),
   // it will try to compute distance using matrix multiplication approach
   // 1 - force to use matrix multiplication for p = 2
@@ -94,8 +94,8 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
   int64_t mode = compute_mode.value_or(0);
   TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);
-  int64_t r1 = x1.size(-2);
-  int64_t r2 = x2.size(-2);
+  SymInt r1 = x1.sym_size(-2);
+  SymInt r2 = x2.sym_size(-2);
 
   // See Note [cdist relies on cdist_impl redispatching]
   // Keep this condition in sync with the condition at the Note
@@ -109,37 +109,37 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
   //For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of them.
   //The last two dimensions will stay the same
-  IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
-  IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
-  std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
-  std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+  SymIntArrayRef batch_tensor1(x1.sym_sizes().data(), dim1 - 2);
+  SymIntArrayRef batch_tensor2(x2.sym_sizes().data(), dim2 - 2);
+  std::vector<SymInt> expand_batch_portion = infer_size_symint(batch_tensor1, batch_tensor2);
+  std::vector<SymInt> tensor1_expand_size(expand_batch_portion);
   tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
-  std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+  std::vector<SymInt> tensor2_expand_size(expand_batch_portion);
   tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
-  const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
-  std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
-  std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};
+  const SymInt expand_batch_product = c10::multiply_integers(expand_batch_portion);
+  std::vector<SymInt> tensor1_view{expand_batch_product, r1, c1};
+  std::vector<SymInt> tensor2_view{expand_batch_product, r2, c2};
 
-  Tensor tensor1_expanded = x1.expand(tensor1_expand_size).contiguous().view(tensor1_view);
-  Tensor tensor2_expanded = x2.expand(tensor2_expand_size).contiguous().view(tensor2_view);
+  Tensor tensor1_expanded = x1.expand_symint(tensor1_expand_size).contiguous().view_symint(tensor1_view);
+  Tensor tensor2_expanded = x2.expand_symint(tensor2_expand_size).contiguous().view_symint(tensor2_view);
 
-  std::vector<int64_t> output_shape(std::move(expand_batch_portion));
+  std::vector<SymInt> output_shape(std::move(expand_batch_portion));
   output_shape.insert(output_shape.end(), {r1, r2});
 
   Tensor result;
   if (r1 == 0 || r2 == 0 || expand_batch_product == 0) {
-    result = at::empty(output_shape, x1.options());
+    result = at::empty_symint(output_shape, x1.options());
   } else if (c1 == 0) {
-    result = at::zeros(output_shape, x1.options());
+    result = at::zeros_symint(output_shape, x1.options());
   } else if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
     // See Note [cdist relies on cdist_impl redispatching]
     // Keep the condition above in sync with the condition at the Note
     Tensor dist = (expand_batch_product == 1) ? at::_euclidean_dist(x1, x2) :
                   at::_euclidean_dist(tensor1_expanded, tensor2_expanded);
-    result = dist.view(output_shape);
+    result = dist.view_symint(output_shape);
   } else {
-    result = at::empty(output_shape, x1.options());
+    result = at::empty_symint(output_shape, x1.options());
     cdist_stub(device1, result, tensor1_expanded, tensor2_expanded, p);
   }
   return result;
@@ -148,14 +148,14 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
 Tensor cdist(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
   TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
   TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
-  TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
+  TORCH_CHECK(x1.sym_size(-1) == x2.sym_size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.sym_size(-1), " X2: ", x2.sym_size(-1));
   auto maybe_outnames = namedinference::compute_cdist_outnames(x1, x2);
   auto result = [&]() {
     NoNamesGuard guard;
-    int64_t r1 = x1.size(-2);
-    int64_t r2 = x2.size(-2);
+    SymInt r1 = x1.sym_size(-2);
+    SymInt r2 = x2.sym_size(-2);
     // Special case for empty input: always call the version with explicit autograd to ensure the graph is properly connected
-    if (x1.numel() == 0 || x2.numel() == 0) {
+    if (x1.sym_numel() == 0 || x2.sym_numel() == 0) {
       return at::_cdist_forward(x1, x2, p, compute_mode);
     }
     int64_t mode = compute_mode.value_or(0);
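The expand/view sequence in cdist_impl broadcasts the batch dimensions of the two inputs, flattens them into one leading dimension for the kernel, and restores batch + (r1, r2) on the output. That behavior predates this commit (only the symbolic handling is new); a quick Python-level illustration:

import torch

x1 = torch.rand(2, 1, 30, 4)
x2 = torch.rand(3, 20, 4)
# Batch portions (2, 1) and (3,) broadcast to (2, 3); the result shape is
# the broadcast batch plus (r1, r2) = (30, 20).
assert torch.cdist(x1, x2).shape == torch.Size([2, 3, 30, 20])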

test/test_meta.py

@@ -681,7 +681,6 @@ meta_function_skips = {
     torch.aminmax : {i8, i64, u8, f64, b8, f32, i32, i16},
     torch.diff : {b8},
     torch.equal : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
-    torch.functional.cdist : {f64, f32},
     torch.nanmean : {bf16, f64, f32, f16, c32, c64, c128},
     torch.nn.functional.cross_entropy : {bf16, f64, f32},
     torch.nn.functional.interpolate : {bf16, f64, f32, u8},
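Dropping torch.functional.cdist from meta_function_skips asserts that the meta kernel now agrees with the real one for f32/f64. Roughly what that buys at the user level (a sketch, assuming a build with this patch):

import torch

# On the meta device only shapes and dtypes are computed; with sizes
# flowing through SymInt, cdist can report its output shape data-free.
x = torch.rand(5, 30, 4, device="meta")
y = torch.rand(5, 20, 4, device="meta")
assert torch.cdist(x, y).shape == torch.Size([5, 30, 20])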

test/test_proxy_tensor.py

@@ -1367,7 +1367,6 @@ symbolic_tensor_failures = {
     xfail('polar'),
     xfail('linalg.eig'),
     xfail('linalg.eigvals'),
-    xfail('cdist', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
     xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
     xfail('combinations', ''),
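Likewise, removing this xfail marks cdist as passing under symbolic tracing, the path that used to die with "aten.size.default - couldn't find symbolic meta function/decomposition". A minimal repro using the same tracing helper this test suite exercises:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x1, x2):
    return torch.cdist(x1, x2)

# Before this patch, symbolic tracing failed on cdist's size() calls;
# now the traced graph carries SymInt shapes end to end.
gm = make_fx(f, tracing_mode="symbolic")(torch.rand(30, 4), torch.rand(20, 4))
print(gm.graph)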