[pt2] add SymInt support for cdist (#98881)
Fixes #98853.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98881
Approved by: https://github.com/ezyang
This commit is contained in:
parent a2e809f29b
commit 8db04e080c
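The xfail removed at the bottom of this diff ("aten.size.default - couldn't find symbolic meta function/decomposition") is the failure this change fixes: torch.cdist can now be traced with symbolic shapes. A minimal sketch of the newly working behavior, assuming a build that includes this commit:

```python
# Symbolically trace cdist so tensor sizes are recorded as SymInts
# rather than baked-in constants. Before this commit the trace failed
# with "aten.size.default - couldn't find symbolic meta function".
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def pairwise(x1, x2):
    return torch.cdist(x1, x2, p=2)

gm = make_fx(pairwise, tracing_mode="symbolic")(
    torch.randn(2, 5, 4), torch.randn(2, 6, 4))
print(gm.graph)  # sizes appear as sym_size calls, not literal ints
```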
aten/src/ATen/ExpandUtils.cpp

@@ -46,6 +46,10 @@ std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
   return infer_size_impl<std::vector<int64_t>>(a, b);
 }
 
+std::vector<SymInt> infer_size_symint(SymIntArrayRef a, SymIntArrayRef b) {
+  return infer_size_impl<std::vector<SymInt>>(a, b);
+}
+
 DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
   return infer_size_impl<DimVector, IntArrayRef>(a, b);
 }
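infer_size_symint is the SymInt twin of infer_size: both compute the broadcasted result shape of two size lists, but the SymInt version keeps symbolic dimensions symbolic instead of forcing them to concrete integers. The same broadcasting rule is exposed in Python as torch.broadcast_shapes; a quick illustration:

```python
import torch

# infer_size-style broadcasting: align shapes from the right; a
# dimension of size 1 stretches to match the other shape's dimension.
print(torch.broadcast_shapes((2, 1, 5), (3, 5)))  # torch.Size([2, 3, 5])
print(torch.broadcast_shapes((4, 1), (1, 6)))     # torch.Size([4, 6])
```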
aten/src/ATen/ExpandUtils.h

@@ -21,6 +21,9 @@
 namespace at {
 
 TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
+TORCH_API std::vector<SymInt> infer_size_symint(
+    SymIntArrayRef a,
+    SymIntArrayRef b);
 TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
 TORCH_API SymDimVector
 infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
aten/src/ATen/native/Distance.cpp

@@ -85,8 +85,8 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
   TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
   // TODO: This is bad; this test should apply universally
   TORCH_CHECK(!x1.is_cuda() || x1.get_device() == x2.get_device(), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
-  int64_t c1 = x1.size(-1);
-  int64_t c2 = x2.size(-1);
+  SymInt c1 = x1.sym_size(-1);
+  SymInt c2 = x2.sym_size(-1);
   // 0 - default value. If p = 2 and r1 > 25 or r2 > 25 (these values are based on performance metrics),
   // it will try to compute distance using matrix multiplication approach
   // 1 - force to use matrix multiplication for p = 2
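Tensor::sym_size returns a SymInt instead of a concrete int64_t; on ordinary eager tensors it compares equal to the plain size, while under tracing it stays a symbol, so checks like c1 == c2 become guards rather than baked-in constants. A small sketch from the Python side, where sizes read through x.shape during symbolic tracing come back as SymInt:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def probe(x):
    # During symbolic tracing this is a torch.SymInt, not an int.
    print(type(x.shape[-1]).__name__)
    return x + 1

make_fx(probe, tracing_mode="symbolic")(torch.randn(3, 4))  # prints: SymInt
```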
@@ -94,8 +94,8 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
   int64_t mode = compute_mode.value_or(0);
   TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);
 
-  int64_t r1 = x1.size(-2);
-  int64_t r2 = x2.size(-2);
+  SymInt r1 = x1.sym_size(-2);
+  SymInt r2 = x2.sym_size(-2);
 
   // See Note [cdist relies on cdist_impl redispatching]
   // Keep this condition in sync with the condition at the Note
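With r1, r2, c1, and c2 carried as SymInts, the row and column counts no longer force shape specialization, so a dynamically compiled cdist can serve many input sizes with one graph. A hedged sketch using torch.compile (exact recompilation behavior depends on the build and backend):

```python
import torch

@torch.compile(dynamic=True)
def pairwise(x1, x2):
    return torch.cdist(x1, x2, p=2)

# Different row counts can reuse the same dynamically shaped graph
# instead of triggering a recompile per concrete (r1, r2).
pairwise(torch.randn(10, 3), torch.randn(12, 3))
pairwise(torch.randn(50, 3), torch.randn(70, 3))
```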
@@ -109,37 +109,37 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
 
   //For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of them.
   //The last two dimensions will stay the same
-  IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
-  IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
-  std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
-  std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+  SymIntArrayRef batch_tensor1(x1.sym_sizes().data(), dim1 - 2);
+  SymIntArrayRef batch_tensor2(x2.sym_sizes().data(), dim2 - 2);
+  std::vector<SymInt> expand_batch_portion = infer_size_symint(batch_tensor1, batch_tensor2);
+  std::vector<SymInt> tensor1_expand_size(expand_batch_portion);
   tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
-  std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+  std::vector<SymInt> tensor2_expand_size(expand_batch_portion);
   tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
 
-  const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
-  std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
-  std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};
+  const SymInt expand_batch_product = c10::multiply_integers(expand_batch_portion);
+  std::vector<SymInt> tensor1_view{expand_batch_product, r1, c1};
+  std::vector<SymInt> tensor2_view{expand_batch_product, r2, c2};
 
-  Tensor tensor1_expanded = x1.expand(tensor1_expand_size).contiguous().view(tensor1_view);
-  Tensor tensor2_expanded = x2.expand(tensor2_expand_size).contiguous().view(tensor2_view);
+  Tensor tensor1_expanded = x1.expand_symint(tensor1_expand_size).contiguous().view_symint(tensor1_view);
+  Tensor tensor2_expanded = x2.expand_symint(tensor2_expand_size).contiguous().view_symint(tensor2_view);
 
-  std::vector<int64_t> output_shape(std::move(expand_batch_portion));
+  std::vector<SymInt> output_shape(std::move(expand_batch_portion));
   output_shape.insert(output_shape.end(), {r1, r2});
 
   Tensor result;
   if (r1 == 0 || r2 == 0 || expand_batch_product == 0) {
-    result = at::empty(output_shape, x1.options());
+    result = at::empty_symint(output_shape, x1.options());
   } else if (c1 == 0) {
-    result = at::zeros(output_shape, x1.options());
+    result = at::zeros_symint(output_shape, x1.options());
   } else if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
     // See Note [cdist relies on cdist_impl redispatching]
     // Keep the condition above in sync with the condition at the Note
     Tensor dist = (expand_batch_product == 1) ? at::_euclidean_dist(x1, x2) :
                   at::_euclidean_dist(tensor1_expanded, tensor2_expanded);
-    result = dist.view(output_shape);
+    result = dist.view_symint(output_shape);
   } else {
-    result = at::empty(output_shape, x1.options());
+    result = at::empty_symint(output_shape, x1.options());
     cdist_stub(device1, result, tensor1_expanded, tensor2_expanded, p);
   }
   return result;
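The hunk above is the heart of the change: the batch dimensions are broadcast, flattened into a single leading dimension for the kernel, and restored afterwards, and every size in that dance is now a SymInt. A rough Python re-enactment of the flattening strategy (the name cdist_batched_sketch is invented for illustration; this is not the actual kernel):

```python
import math
import torch

def cdist_batched_sketch(x1, x2, p=2.0):
    r1, c1 = x1.shape[-2:]
    r2, c2 = x2.shape[-2:]
    # Broadcast the leading (batch) dims, as infer_size_symint does.
    batch = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
    n = math.prod(batch)  # expand_batch_product in the C++ code
    # Expand, flatten to 3D, compute, then restore the batch shape.
    t1 = x1.expand(*batch, r1, c1).contiguous().view(n, r1, c1)
    t2 = x2.expand(*batch, r2, c2).contiguous().view(n, r2, c2)
    out = torch.cdist(t1, t2, p=p)
    return out.view(*batch, r1, r2)

x1, x2 = torch.randn(2, 1, 5, 4), torch.randn(3, 6, 4)
torch.testing.assert_close(cdist_batched_sketch(x1, x2),
                           torch.cdist(x1, x2))
```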
@@ -148,14 +148,14 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
 Tensor cdist(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
   TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
   TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
-  TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
+  TORCH_CHECK(x1.sym_size(-1) == x2.sym_size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.sym_size(-1), " X2: ", x2.sym_size(-1));
   auto maybe_outnames = namedinference::compute_cdist_outnames(x1, x2);
   auto result = [&]() {
     NoNamesGuard guard;
-    int64_t r1 = x1.size(-2);
-    int64_t r2 = x2.size(-2);
+    SymInt r1 = x1.sym_size(-2);
+    SymInt r2 = x2.sym_size(-2);
     // Special case for empty input: always call the version with explicit autograd to ensure the graph is properly connected
-    if (x1.numel() == 0 || x2.numel() == 0) {
+    if (x1.sym_numel() == 0 || x2.sym_numel() == 0) {
       return at::_cdist_forward(x1, x2, p, compute_mode);
     }
     int64_t mode = compute_mode.value_or(0);
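compute_mode picks between the matrix-multiplication path (mode 1, or mode 0 when p == 2 and r1 > 25 or r2 > 25) and the direct kernel; from Python the same knob is a string argument, and the two paths agree up to floating-point error:

```python
import torch

a, b = torch.randn(30, 8), torch.randn(40, 8)
d_mm = torch.cdist(a, b, p=2, compute_mode="use_mm_for_euclid_dist")
d_direct = torch.cdist(a, b, p=2, compute_mode="donot_use_mm_for_euclid_dist")
# The mm path trades a little accuracy for speed on large inputs.
torch.testing.assert_close(d_mm, d_direct, rtol=1e-4, atol=1e-5)
```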
test/test_meta.py

@@ -681,7 +681,6 @@ meta_function_skips = {
     torch.aminmax : {i8, i64, u8, f64, b8, f32, i32, i16},
     torch.diff : {b8},
     torch.equal : {bf16, i8, c32, i64, u8, c128, b8, f64, i16, i32, f32, f16, c64},
-    torch.functional.cdist : {f64, f32},
     torch.nanmean : {bf16, f64, f32, f16, c32, c64, c128},
     torch.nn.functional.cross_entropy : {bf16, f64, f32},
     torch.nn.functional.interpolate : {bf16, f64, f32, u8},
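Dropping torch.functional.cdist from meta_function_skips means the meta kernel is now expected to produce correct output shapes for f32/f64, so shape inference runs without real data. A sketch, assuming a build with this commit:

```python
import torch

x1 = torch.randn(2, 5, 4, device="meta")
x2 = torch.randn(2, 6, 4, device="meta")
# Shape-only computation: no data is allocated or touched.
print(torch.cdist(x1, x2).shape)  # torch.Size([2, 5, 6])
```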
test/test_proxy_tensor.py

@@ -1367,7 +1367,6 @@ symbolic_tensor_failures = {
     xfail('polar'),
     xfail('linalg.eig'),
     xfail('linalg.eigvals'),
-    xfail('cdist', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
     xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
     xfail('combinations', ''),