Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Add meta support to tensor range factories (#67032)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67032

This PR adds meta backend support to the `range`, `arange`, `linspace`, and `logspace` operators. Note that the original PR (#66630) was reverted due to two failing unit tests in the Bionic CI. This revision includes a fix for those tests; otherwise its content is identical to the previous PR.

Original commit changeset: 2f9d8d1acbb0
ghstack-source-id: 142487306

Test Plan: Extended the existing tensor creation tests to assert meta backend support.

Reviewed By: zhaojuanmao

Differential Revision: D31834403

fbshipit-source-id: a489858a2a8a38a03234b14408e14d2b208a8d34
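For context, a minimal sketch of the user-facing behavior this change enables (illustrative only, not part of the commit): meta tensors produced by these factories carry shape, dtype, and device metadata without allocating memory or computing any values.

    import torch

    # Meta tensors record shape/dtype/device only; no memory is allocated
    # and no element values are computed.
    a = torch.arange(0, 10, 2, device='meta')          # 5 elements: 0, 2, 4, 6, 8
    b = torch.linspace(0, 1, steps=50, device='meta')  # 50 elements
    c = torch.logspace(0, 3, steps=4, device='meta')   # 4 elements

    print(a.shape, b.shape, c.shape)  # torch.Size([5]) torch.Size([50]) torch.Size([4])
    print(b.device)                   # meta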
This commit is contained in:
parent 9e8016d8c4
commit efdb17b984
@@ -281,7 +281,7 @@ std::vector<Tensor>& histogramdd_bin_edges_out_cpu(const Tensor& self, IntArrayR
     auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);

     for (int64_t dim = 0; dim < N; dim++) {
-        linspace_cpu_out(outer_bin_edges.first[dim], outer_bin_edges.second[dim],
+        linspace_out(outer_bin_edges.first[dim], outer_bin_edges.second[dim],
                 bin_ct[dim] + 1, bin_edges_out[dim]);
     }
@@ -362,7 +362,7 @@ histogram_out_cpu(const Tensor& self, int64_t bin_ct, c10::optional<c10::ArrayRe

     histogramdd_prepare_out(reshaped_self, std::vector<int64_t>{bin_ct}, hist, bins_out);
     auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);
-    linspace_cpu_out(outer_bin_edges.first[0], outer_bin_edges.second[0], bin_ct + 1, bin_edges);
+    linspace_out(outer_bin_edges.first[0], outer_bin_edges.second[0], bin_ct + 1, bin_edges);

     histogramdd_check_inputs(reshaped_self, bins_in, reshaped_weight);
@@ -391,7 +391,7 @@ Tensor& histogram_histc_cpu_out(const Tensor& self, int64_t bin_ct,
     histogramdd_prepare_out(reshaped, std::vector<int64_t>{bin_ct}, hist, bins_out);

     auto outer_bin_edges = histc_select_outer_bin_edges(self, min, max);
-    linspace_cpu_out(outer_bin_edges.first, outer_bin_edges.second, bin_ct + 1, bin_edges);
+    linspace_out(outer_bin_edges.first, outer_bin_edges.second, bin_ct + 1, bin_edges);

     histogramdd_check_inputs(reshaped, bins_in, {});
@@ -13,7 +13,7 @@ namespace at { namespace native {
 DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub);
 DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub);

-Tensor& linspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional<int64_t> optional_steps, Tensor& result) {
+Tensor& linspace_out(const Scalar& start, const Scalar& end, c10::optional<int64_t> optional_steps, Tensor& result) {
   const auto steps = optional_steps.value_or(100);
   TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
@@ -28,6 +28,10 @@ Tensor& linspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional<i
     result.resize_({steps});
   }

+  if (result.device() == kMeta) {
+    return result;
+  }
+
   if (steps == 0) {
     // skip
   } else if (steps == 1) {
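Note: the early return added above is the heart of the meta support. The output is resized to the requested number of steps first, so its metadata is correct, and the value-filling code that follows is skipped for meta tensors. A small illustrative check of the assumed behavior, not part of the diff:

    import torch

    meta = torch.linspace(0, 1, steps=7, device='meta')
    cpu = torch.linspace(0, 1, steps=7)

    # Same shape and dtype as the CPU result, but no values are filled in.
    assert meta.shape == cpu.shape == torch.Size([7])
    assert meta.dtype == cpu.dtype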
@@ -44,7 +48,7 @@ Tensor& linspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional<i
   return result;
 }

-Tensor& logspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional<int64_t> optional_steps, double base, Tensor& result) {
+Tensor& logspace_out(const Scalar& start, const Scalar& end, c10::optional<int64_t> optional_steps, double base, Tensor& result) {
   const auto steps = optional_steps.value_or(100);
   TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
@@ -58,6 +62,11 @@ Tensor& logspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional<i
   if (result.numel() != steps) {
     result.resize_({steps});
   }
+
+  if (result.device() == kMeta) {
+    return result;
+  }
+
   Tensor r = result.is_contiguous() ? result : result.contiguous();

   if (steps == 0) {
@@ -113,7 +122,7 @@ Tensor& logspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional<i
   return result;
 }

-Tensor& range_cpu_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
+Tensor& range_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
   AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "range_cpu", [&]() {
     using accscalar_t = at::acc_type<scalar_t, false>;
     auto xstart = start.to<accscalar_t>();
@@ -130,6 +139,11 @@ Tensor& range_cpu_out(const Scalar& start, const Scalar& end, const Scalar& step
     if (result.numel() != size) {
       result.resize_({size});
     }
+
+    if (result.device() == kMeta) {
+      return;
+    }
+
     Tensor r = result.is_contiguous() ? result : result.contiguous();
     scalar_t *data_ptr = r.data_ptr<scalar_t>();
@@ -147,7 +161,7 @@ Tensor& range_cpu_out(const Scalar& start, const Scalar& end, const Scalar& step
   return result;
 }

-Tensor& arange_cpu_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
+Tensor& arange_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
   AT_DISPATCH_ALL_TYPES_AND(kBFloat16, result.scalar_type(), "arange_cpu", [&]() {
     using accscalar_t = at::acc_type<scalar_t, false>;
     auto xstart = start.to<accscalar_t>();
@@ -194,6 +208,10 @@ Tensor& arange_cpu_out(const Scalar& start, const Scalar& end, const Scalar& ste
       result.resize_({size});
     }

+    if (result.device() == kMeta) {
+      return;
+    }
+
     Tensor r = result.is_contiguous() ? result : result.contiguous();
     auto iter = TensorIterator::borrowing_nullary_op(r);
     arange_stub(iter.device_type(), iter, start, size, step);
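Note: for range and arange the element count is computed from start, end, and step before the new meta check, so a meta result still ends up with the correct length; only the data write is skipped. An illustrative check of the assumed behavior, not part of the diff:

    import torch

    m = torch.arange(0, 10, 3, device='meta')
    assert m.shape == torch.Size([4])   # ceil((10 - 0) / 3) == 4 elements: 0, 3, 6, 9
    assert m.device.type == 'meta'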
@@ -533,7 +533,7 @@

 - func: arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: arange_cpu_out
+    CPU, Meta: arange_out
     CUDA: arange_cuda_out

 # This function is a temporary hack to allow tracing of arange like constructs with dynamic
@@ -2486,7 +2486,7 @@

 - func: linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: linspace_cpu_out
+    CPU, Meta: linspace_out
     CUDA: linspace_cuda_out

 - func: log(Tensor self) -> Tensor
@@ -2645,7 +2645,7 @@

 - func: logspace.out(Scalar start, Scalar end, int? steps=None, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: logspace_cpu_out
+    CPU, Meta: logspace_out
     CUDA: logspace_cuda_out

 # log_softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models.
@@ -3455,7 +3455,7 @@

 - func: range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: range_cpu_out
+    CPU, Meta: range_out
     CUDA: range_cuda_out

 - func: ravel(Tensor(a) self) -> Tensor(a)
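Note: the native_functions.yaml entries above register the renamed kernels under both the CPU and Meta dispatch keys; the kernels themselves bail out before touching data when the output lives on the meta device. A rough sketch of the user-visible effect (assumed behavior, not part of the diff):

    import torch

    out = torch.empty(0, device='meta')
    torch.arange(0, 5, out=out)   # the Meta key now dispatches to arange_out
    assert out.shape == torch.Size([5])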
@@ -1495,8 +1495,8 @@ class NumpyTests(TestCase):
         expected = b.float().unsqueeze(1).expand(100, 100)
         self.assertEqual(a, expected)

-instantiate_device_type_tests(TestIndexing, globals())
-instantiate_device_type_tests(NumpyTests, globals())
+instantiate_device_type_tests(TestIndexing, globals(), except_for='meta')
+instantiate_device_type_tests(NumpyTests, globals(), except_for='meta')

 if __name__ == '__main__':
     run_tests()
@@ -1626,6 +1626,7 @@ class TestTensorCreation(TestCase):
         self.assertEqual(c1, expected)
         self.assertEqual(c2, expected)

+    @skipMeta
     def test_linlogspace_mem_overlap(self, device):
         x = torch.rand(1, device=device).expand(10)
         with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
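Note: the test changes keep the meta device out of suites that still depend on real data: except_for='meta' avoids instantiating a whole device-generic class for the meta device, while @skipMeta skips a single test. A hedged sketch of how these helpers are typically used (MyTestCase and test_values are made-up names):

    from torch.testing._internal.common_device_type import (
        instantiate_device_type_tests, skipMeta)
    from torch.testing._internal.common_utils import TestCase, run_tests

    class MyTestCase(TestCase):
        @skipMeta                      # skip just this test on the meta device
        def test_values(self, device):
            ...

    # Generate per-device variants, but never a meta variant of this class.
    instantiate_device_type_tests(MyTestCase, globals(), except_for='meta')

    if __name__ == '__main__':
        run_tests()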
@@ -60,6 +60,7 @@ _compare_return_type = Tuple[bool, Optional[str]]
 #
 # The `equal_nan` can be True or False, which maps to the True or False
 # in `torch.allclose`.
+# TODO: Add support for comparing meta tensors. See https://github.com/pytorch/pytorch/pull/67032.
 def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan) -> _compare_return_type:
     debug_msg : Optional[str]
     # Integer (including bool) comparisons are identity comparisons