mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[export] Add device and dtype fields to assert_tensor_metadata (#141071)
Differential Revision: D66321128 (https://our.internmc.facebook.com/intern/diff/D66321128)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141071
Approved by: https://github.com/yushangdi, https://github.com/zou3519
This commit is contained in:
parent 45d62d6fc5
commit 0fbc0830ba
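For orientation: the op this commit extends is exposed in Python as torch._assert_tensor_metadata. A minimal sketch of the new keyword arguments, pieced together from the schema and test changes below (the exact error text follows the C++ message format in this diff):

    import torch

    t = torch.ones(2, 3)  # float32, CPU, strided

    # Passes: every supplied field matches the tensor's actual metadata.
    torch._assert_tensor_metadata(t, dtype=torch.float32, device="cpu")

    # Raises RuntimeError, e.g. "Tensor device mismatch!". The commit also
    # switches _assert_match from AT_ASSERT to throwing std::runtime_error,
    # which surfaces in Python as RuntimeError.
    try:
        torch._assert_tensor_metadata(t, device="cuda")
    except RuntimeError as e:
        print(e)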
@@ -19,15 +19,27 @@ void _assert_match(const O& original, const C& compared, const std::string& name
     if (!equal) {
       std::stringstream msg;
       msg << "Tensor " << name << " mismatch!";
-      AT_ASSERT(equal, msg.str());
+      if (!equal) {
+        throw std::runtime_error(msg.str());
+      }
     }
   }
 }
 
-void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<c10::ScalarType> dtype) {
+void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
+  _assert_match(tensor.sym_sizes(), sizes, "sizes");
+  _assert_match(tensor.sym_strides(), strides, "strides");
+  _assert_match(tensor.dtype(), dtype, "dtype");
+  _assert_match(tensor.device(), device, "device");
+  _assert_match(tensor.layout(), layout, "layout");
+}
+
+void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
   _assert_match(tensor.sizes(), sizes, "sizes");
   _assert_match(tensor.strides(), strides, "strides");
   _assert_match(tensor.dtype(), dtype, "dtype");
+  _assert_match(tensor.device(), device, "device");
+  _assert_match(tensor.layout(), layout, "layout");
 }
 
 }
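The hunk above (ATen's C++ implementation) adds a second overload, _assert_tensor_metadata_meta_symint, that compares sym_sizes()/sym_strides(); the native_functions.yaml hunk below registers it under the Meta dispatch key, so the checks can also run on meta/fake tensors during export tracing. A rough sketch of what that enables, assuming a build containing this commit:

    import torch

    # Meta tensors carry metadata (sizes, strides, dtype, ...) but no storage;
    # the Meta-key kernel lets the assert run on them anyway.
    m = torch.empty(2, 3, device="meta")
    torch.ops.aten._assert_tensor_metadata.default(m, dtype=torch.float32)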
@@ -187,7 +187,10 @@
   dispatch:
     CPU: _functional_assert_async_msg_cpu
 
-- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
+- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> ()
+  dispatch:
+    CompositeExplicitAutograd: _assert_tensor_metadata
+    Meta: _assert_tensor_metadata_meta_symint
 
 - func: _print(str s) -> ()
   dispatch:
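Note the `*` in the updated schema: device and layout are keyword-only, while size, stride, and dtype stay positional. For illustration, a call using the schema's own argument names (values chosen to match the tensor, so nothing is raised):

    import torch

    t = torch.ones(3)  # sizes (3,), strides (1,)
    torch.ops.aten._assert_tensor_metadata(
        t, size=[3], stride=[1], dtype=torch.float32,
        device=torch.device("cpu"), layout=torch.strided)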
@@ -30,6 +30,7 @@ aten::_amp_update_scale.out
 aten::_amp_update_scale_
 aten::_assert_async
 aten::_assert_async.msg
+aten::_assert_tensor_metadata
 aten::_batch_norm_no_update.out
 aten::_batch_norm_with_update.out
 aten::_cdist_backward
@@ -1,6 +1,8 @@
 #!/usr/bin/env python3
 # Owner(s): ["module: internals"]
 
+import unittest
+
 import torch
 from torch.testing._internal.common_utils import run_tests, TestCase
 
@@ -32,6 +34,19 @@ class TestComparisonUtils(TestCase):
         with self.assertRaises(RuntimeError):
             torch._assert_tensor_metadata(t, [3], [1], torch.float)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
+    def test_assert_device(self):
+        t = torch.tensor([0.5], device="cpu")
+
+        with self.assertRaises(RuntimeError):
+            torch._assert_tensor_metadata(t, device="cuda")
+
+    def test_assert_layout(self):
+        t = torch.tensor([0.5])
+
+        with self.assertRaises(RuntimeError):
+            torch._assert_tensor_metadata(t, layout=torch.sparse_coo)
+
 
 if __name__ == "__main__":
     run_tests()
@@ -80,6 +80,7 @@ _side_effectful_functions: Set[Callable] = {
     torch._assert_async,
     _ops.aten._assert_async.msg,
     _ops.aten._assert_scalar.default,
+    _ops.aten._assert_tensor_metadata.default,
     _ops.aten.sym_constrain_range.default,
     _ops.aten.sym_constrain_range_for_size.default,
     _ops.profiler._record_function_enter,
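This set (in torch.fx) marks ops that must survive dead-code elimination even when their results are unused, which is exactly the situation for an assert that returns nothing. A small sketch of the effect, assuming the registration above:

    import torch
    import torch.fx as fx

    g = fx.Graph()
    x = g.placeholder("x")
    # Unused result; DCE keeps this node only because the overload is
    # registered as side-effectful.
    g.call_function(torch.ops.aten._assert_tensor_metadata.default,
                    (x,), {"dtype": torch.float32})
    out = g.call_function(torch.ops.aten.add.Tensor, (x, 1))
    g.output(out)

    g.eliminate_dead_code()
    print(g)  # the _assert_tensor_metadata node is still present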
@@ -53,6 +53,7 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
 FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
     "_assert_async",  # no return
     "_assert_async.msg",  # no return
+    "_assert_tensor_metadata",  # no return
     "_cslt_sparse_mm_search",  # returns an int
     "_assert_scalar",  # no return
     "_dimI",  # returns an int
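torchgen normally synthesizes out= variants for functional ops; since _assert_tensor_metadata has no return, it is excluded here. A quick way to check that only the base overload exists (output shown as expected, not verified):

    import torch

    # An op with no returns gets no synthesized `.out` overload.
    print(torch.ops.aten._assert_tensor_metadata.overloads())  # ['default']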