mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Compute type_equal() without reference to backend() (#53823)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53823 Argument for correctness: type_equal previous compared if backends are equal. Backend is computed by translation from dispatch key. I verified that computeDispatchKey never computed a weird dispatch key (e.g., AutogradXLA), so that dispatchKeyToBackend was effectively injective. Then it is always valid to compare the arguments of an injective function for equality, rather than the output of the injective function. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D27036575 Pulled By: ezyang fbshipit-source-id: 6aeafc89f287da0bc0065bd21c1adb5e272dbb81
This commit is contained in:
parent
3c457043fb
commit
d47fd3df81
|
|
@ -339,7 +339,7 @@ struct C10_API TensorOptions {
|
|||
|
||||
// For compatibility with legacy tensor.type() comparisons
|
||||
bool type_equal(const TensorOptions& other) const {
|
||||
return backend() == other.backend() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
|
||||
return computeDispatchKey() == other.computeDispatchKey() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
|
||||
}
|
||||
|
||||
/// Returns the `pinned_memory` property of the `TensorOptions`, or
|
||||
|
|
@ -404,6 +404,10 @@ struct C10_API TensorOptions {
|
|||
return DispatchKeySet(computeDispatchKey());
|
||||
}
|
||||
|
||||
// INVARIANT: computeDispatchKey returns only the subset of dispatch keys for
|
||||
// which dispatchKeyToBackend is injective, if it is defined at all (for
|
||||
// the most part, this just means that this function never returns an
|
||||
// Autograd key)
|
||||
DispatchKey computeDispatchKey() const {
|
||||
return c10::computeDispatchKey(optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user