Compute type_equal() without reference to backend() (#53823)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53823

Argument for correctness: type_equal previous compared if backends
are equal.  Backend is computed by translation from dispatch key.
I verified that computeDispatchKey never computed a weird
dispatch key (e.g., AutogradXLA), so that dispatchKeyToBackend
was effectively injective.  Then it is always valid to compare
the arguments of an injective function for equality, rather than
the output of the injective function.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D27036575

Pulled By: ezyang

fbshipit-source-id: 6aeafc89f287da0bc0065bd21c1adb5e272dbb81
This commit is contained in:
Edward Yang 2021-03-16 15:15:52 -07:00 committed by Facebook GitHub Bot
parent 3c457043fb
commit d47fd3df81

View File

@ -339,7 +339,7 @@ struct C10_API TensorOptions {
// For compatibility with legacy tensor.type() comparisons
bool type_equal(const TensorOptions& other) const {
return backend() == other.backend() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
return computeDispatchKey() == other.computeDispatchKey() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
}
/// Returns the `pinned_memory` property of the `TensorOptions`, or
@ -404,6 +404,10 @@ struct C10_API TensorOptions {
return DispatchKeySet(computeDispatchKey());
}
// INVARIANT: computeDispatchKey returns only the subset of dispatch keys for
// which dispatchKeyToBackend is injective, if it is defined at all (for
// the most part, this just means that this function never returns an
// Autograd key)
DispatchKey computeDispatchKey() const {
return c10::computeDispatchKey(optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
}