This was added in https://github.com/pytorch/pytorch/pull/119562. The idea in this loop seems to be the following:

```
if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
  // NB: we could short circuit this once needs_reduce is true but there's
  // no point since the reduction function will guard on this anyway
  if (!c10::guard_or_false(size.sym_eq(target), __FILE__, __LINE__)) {
    needs_reduce = true;
  }
} else {
  if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
    fail();
  }
}
```

1. If we know size == 1:
   1.1. We know for sure that size == target --> no reduction needed.
   1.2. We know for sure that size != target --> we do a reduction.
   1.3. We cannot tell whether size == target --> we do a reduction.
2. If we do not know whether size == 1, we add a runtime assertion that size == target, and we fail at runtime if it does not hold.

We could have simplified 1.1 and always done the reduction under case 1, since doing 1.3 without runtime checks implies that it is safe; I suspect the reason for distinguishing 1.1 is performance, but I'm not sure.

Anyway, using TORCH_GUARD_OR_FALSE instead of TORCH_GUARD_SIZE_OBLIVIOUS here is appropriate: there is really no clear reason for size-oblivious reasoning, nor for this logic not to apply when size is not size-like. Sizes are always >= 0 anyway, but bad reasoning can leave us unable to infer that, even though we know it is true here.

python test/dynamo/test_misc.py -k test_validate_outputs_unbacked

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154172
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #154154, #154164, #154167
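To make the case analysis concrete, here is a minimal, self-contained C++ sketch of the decision logic above. It is not the actual implementation: `MaybeSize`, `guard_or_false_eq`, `expect_true_eq`, and `needs_reduce_or_throw` are hypothetical stand-ins for the SymInt machinery, with `std::nullopt` modeling an unbacked (compile-time-unknown) size.

```
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a symbolic size: nullopt means "unbacked"
// (unknown at compile time), otherwise the concrete value.
using MaybeSize = std::optional<int64_t>;

// Models guard_or_false(a == b): true only when both values are known and
// equal; any uncertainty collapses to false without adding a guard.
bool guard_or_false_eq(MaybeSize a, MaybeSize b) {
  return a.has_value() && b.has_value() && *a == *b;
}

// Models expect_true(a == b): known-unequal fails immediately; if either
// side is unknown, a deferred runtime assert would be emitted (elided here)
// and compilation proceeds as if the equality holds.
bool expect_true_eq(MaybeSize a, MaybeSize b) {
  if (a.has_value() && b.has_value()) {
    return *a == *b;
  }
  return true; // stands in for the deferred runtime assert
}

// Simplified per-dimension loop: decide whether the gradient needs a
// sum-reduction, or throw if the shapes are provably incompatible.
bool needs_reduce_or_throw(
    const std::vector<MaybeSize>& shape,
    const std::vector<MaybeSize>& target) {
  assert(shape.size() == target.size()); // rank mismatch handled elsewhere
  bool needs_reduce = false;
  for (size_t d = 0; d < shape.size(); ++d) {
    if (guard_or_false_eq(shape[d], MaybeSize{1})) {
      // Case 1: size is known to be 1; reduce unless target is also
      // known to equal it (cases 1.2 and 1.3 both reduce).
      if (!guard_or_false_eq(shape[d], target[d])) {
        needs_reduce = true;
      }
    } else if (!expect_true_eq(shape[d], target[d])) {
      // Case 2: size must equal target, enforced at runtime.
      throw std::runtime_error("incompatible gradient shape");
    }
  }
  return needs_reduce;
}
```

With concrete values: `needs_reduce_or_throw({1}, {5})` returns true (case 1.2), `needs_reduce_or_throw({5}, {5})` returns false, `needs_reduce_or_throw({4}, {5})` throws (case 2), and an unbacked size, `needs_reduce_or_throw({std::nullopt}, {5})`, also takes case 2 and relies on the deferred runtime assert.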
#include <torch/csrc/autograd/input_metadata.h>

// TODO: we may be able to move some imports from input_metadata.h to here, but
// it seems that function.h transitively depends on some of them.

namespace torch::autograd {

namespace {

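// A C++ NestedTensor (nested and not python-dispatch) carries per-component
// sizes in a Tensor, so its shape is stored as that Tensor; all other
// tensors store their sym_sizes() as a SymInt vector.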
MetadataShape compute_variant_shape(const at::Tensor& input) {
  if (input.is_nested() && !input.unsafeGetTensorImpl()->is_python_dispatch()) {
    auto nested_size = input._nested_tensor_size();
    return MetadataShape{std::in_place_type<at::Tensor>, nested_size};
  }
  return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}

bool is_python_dispatch(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->is_python_dispatch();
}

bool is_cpp_nested_tensor(const at::Tensor& tensor) {
  return tensor.is_nested() && !is_python_dispatch(tensor);
}

} // namespace

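// Capture the metadata the autograd engine needs to validate gradients
// later: the tensor's options, its shape (in the appropriate variant),
// whether it is a Python tensor subclass, whether it is nested, and the
// current stream on its device.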
InputMetadata::InputMetadata(
    const at::TensorOptions& options,
    MetadataShape input_shape,
    bool is_tensor_subclass,
    bool is_nested)
    : options_{options},
      shape_{std::move(input_shape)},
      is_tensor_subclass_{is_tensor_subclass},
      is_nested_{is_nested},
      was_default_constructed_{false} {
  auto device_ = options.device();
  stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
}

InputMetadata::InputMetadata(const at::Tensor& t)
    : InputMetadata(
          t.options(),
          compute_variant_shape(t),
          is_python_dispatch(t),
          t.is_nested()) {}

at::Tensor InputMetadata::zeros_like() const {
  TORCH_CHECK(
      !is_nested_, "Zeros is not currently supported for nested tensors.")
  return at::zeros_symint(shape_as_dim_vector(), options_);
}

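// Validate and adapt `grad` (the gradient produced for input `i`): return it
// unchanged if the shapes already match, sum-reduce it down to this input's
// shape when the input shape is broadcast-expandable to grad's shape, and
// fail with a formatted error if the shapes cannot be reconciled.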
at::Tensor InputMetadata::maybe_reduce(
    const size_t i,
    at::Tensor grad,
    const std::function<std::string(const std::string&)>& format_error) const {
  auto fail = [&]() {
    const auto message = incompatible_shape_error_message(i, grad);
    TORCH_CHECK(false, format_error(message.str()));
  };

  // Nested tensor makes my brain explode, so I've just hard-coded the logic
  // for this case, at risk of code duplication. This logic does NOT do the
  // careful oblivious logic as seen below
  if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
      ::torch::autograd::is_cpp_nested_tensor(grad)) {
    if (!is_same_shape(grad)) {
      if (is_expandable_to_shape(grad)) {
        return reduce_grad(grad);
      } else {
        fail();
      }
    } else {
      return grad;
    }
  }

  auto shape = shape_as_dim_vector();
  auto desired = grad.sym_sizes();

  size_t ndim = shape.size();
  size_t target_dim = desired.size();
  if (ndim > target_dim) {
    fail();
  }
  bool needs_reduce = false;
  for (const auto i : c10::irange(ndim)) {
    const auto& size = shape[ndim - i - 1];
    const auto& target = desired[target_dim - i - 1];
    // The conditions here are written carefully so that we are able to
    // infer deferred runtime asserts
    if (TORCH_GUARD_OR_FALSE(size.sym_eq(1))) {
      // NB: we could short circuit this once needs_reduce is true but there's
      // no point since the reduction function will guard on this anyway
      if (!c10::guard_or_false(size.sym_eq(target), __FILE__, __LINE__)) {
        needs_reduce = true;
      }
    } else {
      if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
        fail();
      }
    }
  }
  if (ndim != target_dim) {
    needs_reduce = true;
  }

  if (needs_reduce) {
    return reduce_grad(grad);
  } else {
    return grad;
  }
}

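// Compare shapes in the representation that matches the nestedness: C++
// NestedTensors compare their nested-size Tensors, everything else compares
// sym_sizes(). Differing nestedness is never the same shape.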
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
  if (!is_nestedness_same(grad)) {
    return false;
  }
  if (is_cpp_nested_tensor()) {
    return grad._nested_tensor_size().is_same_size(shape_as_tensor());
  }
  return grad.sym_sizes().equals(shape_as_dim_vector());
}

bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
  if (!maybe_expandable_to(grad)) {
    return false;
  }
  return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}

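// at::sum_to reverses broadcasting: e.g. a grad of shape [2, 4, 3] is summed
// over the leading dim to reach an input of shape [4, 3], and [4, 3] is
// summed over dim 0 (kept as size 1) to reach [1, 3].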
at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
  // reduce_grad should only be called if is_expandable_to_shape returns true.
  TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
  return at::sum_to(std::move(grad), shape_as_dim_vector());
}

std::stringstream InputMetadata::incompatible_shape_error_message(
    const size_t index,
    const at::Tensor& grad) const {
  std::stringstream ss{};
  ss << "invalid gradient at index " << index << " - got ";
  if (::torch::autograd::is_cpp_nested_tensor(grad)) {
    ss << grad._nested_tensor_size();
  } else {
    ss << grad.sym_sizes();
  }
  ss << " but expected shape compatible with ";
  if (is_cpp_nested_tensor()) {
    ss << shape_as_tensor();
  } else {
    ss << shape_as_dim_vector();
  }
  return ss;
}

bool InputMetadata::is_cpp_nested_tensor() const {
  bool ret = std::holds_alternative<at::Tensor>(shape_);
  TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_))
  return ret;
}

c10::SymIntArrayRef InputMetadata::shape_as_dim_vector() const {
  const auto& dim_shape = std::get<SymIntSmallVec>(shape_);
  return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
}

// Danger: not thread safe, caller must protect with lock
SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
  return std::get<SymIntSmallVec>(shape_);
}

bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
  return (
      grad.is_nested() == is_nested_ &&
      ::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
}

at::Tensor InputMetadata::shape_as_tensor() const {
  return std::get<at::Tensor>(shape_);
}

bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
  // This is the initial step to determine whether or not the tensor
  // represented by input_metadata is expandable to grad based on
  // is-nestedness information alone. If this function returns true, then
  // is_expandable_to_shape will be called. We support the following 3 types
  // of expansion:
  bool grad_is_nested = grad.is_nested();
  if (!is_nested_ && !grad_is_nested) {
    // Normal case (no NestedTensors are involved)
    // (1) plain Tensor -> plain Tensor
    return true;
  } else {
    // (2) python NT -> python NT
    // (3) plain Tensor -> python NT
    return (
        grad_is_nested && is_python_dispatch(grad) &&
        (!is_nested_ || is_tensor_subclass_));
  }
}

} // namespace torch::autograd