mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
This is the second reland attempt for https://github.com/pytorch/pytorch/pull/32140.
The first reland attempt https://github.com/pytorch/pytorch/pull/35011 failed due to a [small incompatible change](https://github.com/pytorch/pytorch/pull/35011#issuecomment-601754216) in recent master (`skipIfRocm` was removed from `test_data_parallel.py`).
The present PR restores skipIfRocm.
Description from first reland attempt https://github.com/pytorch/pytorch/pull/35011:
> https://github.com/pytorch/pytorch/pull/32140 was approved and merged, but [reverted](d0577e19f0) because it broke builds with versions of Visual Studio older than 15.8 that were not represented in public CI. The build failures were caused by a [known VS bug](https://developercommunity.visualstudio.com/content/problem/27729/allow-function-with-internal-linkage-as-template-n.html), fixed in versions 15.8 and newer.
>
> The present PR reverts the revert (restoring https://github.com/pytorch/pytorch/pull/32140 's diffs) and adds a workaround to enable compilation with VS < 15.8. The workaround isn't pretty, but it's guarded by macros such that it's only used when compiling with VS < 15.8. All other builds compile with the same code/control flow as was merged in https://github.com/pytorch/pytorch/pull/32140.
>
> Original description of https://github.com/pytorch/pytorch/pull/32140:
> > Initial integration of eager autocasting, supporting out-of-place ops only for easier review.
> Relevant issue/RFC: https://github.com/pytorch/pytorch/issues/25081
>
> > In-place ops and ops with user-supplied out=... can certainly be supported as well (my initial WIP https://github.com/pytorch/pytorch/issues/29552 handled many) but require substantially more complex special casing in the autocasting backend and tests. Support for these ops (much of which has already been written) will be broken into later PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35102
Differential Revision: D20596918
Pulled By: ezyang
fbshipit-source-id: 60caa279bb0ce4a9bb0b28c1d585d42cf1cc7e50
126 lines
4.0 KiB
C++
126 lines
4.0 KiB
C++
#include <c10/core/impl/LocalDispatchKeySet.h>
|
|
|
|
#include <iostream>
|
|
|
|
namespace c10 {
|
|
namespace impl {
|
|
|
|
// Defines the boolean flag `disable_variable_dispatch` (default: false).
// It is read in this file as FLAGS_disable_variable_dispatch by
// tls_local_dispatch_key_set(), which excludes DispatchKey::VariableTensorId
// when the flag is set.
C10_DEFINE_bool(disable_variable_dispatch, false, "This flag forcibly disables the Variable code paths from executing, which currently breaks profiling in the process.");
|
|
|
|
namespace {
|
|
|
|
/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
|
|
/// thread_local is not supported.
|
|
#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
|
|
|
|
// NB: POD, zero initialized!
|
|
thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
|
|
|
|
#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
|
|
|
|
static PODLocalDispatchKeySet raw_local_dispatch_key_set;
|
|
|
|
#endif
|
|
|
|
} // anonymous namespace
|
|
|
|
/// Returns the current local dispatch key set (included + excluded keys).
///
/// When FLAGS_disable_variable_dispatch is set, the returned value additionally
/// has DispatchKey::VariableTensorId in its excluded set.
///
/// The exclusion is applied to a local copy only. An earlier version wrote it
/// back into raw_local_dispatch_key_set, which made this getter mutate the
/// stored state: the exclusion then outlived the flag being cleared, leaked
/// into the guards and the non-RAII queries below, and (in the build where the
/// storage is a plain static rather than thread_local) was an unsynchronized
/// write performed from a read accessor.
LocalDispatchKeySet tls_local_dispatch_key_set() {
  // Hack until variable performance is fixed
  if (FLAGS_disable_variable_dispatch) {
    // Work on a copy so the stored key set is left untouched.
    PODLocalDispatchKeySet snapshot = raw_local_dispatch_key_set;
    snapshot.set_excluded(
        snapshot.excluded().add(DispatchKey::VariableTensorId));
    return snapshot;
  }
  return raw_local_dispatch_key_set;
}
|
|
|
|
// An RAII guard could snapshot and restore the entire state (entire DispatchKeySet) as
|
|
// opposed to only snapshotting and restoring the state of its assigned DispatchKey.
|
|
// I'm not sure which is better. If only the RAII API is used, the two choices are
|
|
// not distinguishable.
|
|
//
|
|
// However, if the guard chooses to snapshot and restore the entire DispatchKeySet,
|
|
// the interaction with the non-RAII API changes. Consider this sequence of events:
|
|
// - An RAII guard is declared for a particular DispatchKey, but snapshots the entire
|
|
// current DispatchKeySet.
|
|
// - A call to the non-RAII API changes the state for a different DispatchKey.
|
|
// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it snapshotted
|
|
// (which restores the state for its own assigned DispatchKey and wipes out the state
|
|
// for the other DispatchKey set by the non-RAII API).
|
|
|
|
// RAII API
|
|
|
|
// Records whether `x` was already in the included set and, if it was not,
// adds it. The destructor uses the recorded state to decide whether to undo.
IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKey x)
    : tls_(&raw_local_dispatch_key_set),
      id_(x),
      // NB: DispatchKey::Undefined is treated as "already included", which
      // makes the guard a no-op for it.
      prev_state_(x == DispatchKey::Undefined || tls_->included().has(x)) {
  if (prev_state_) {
    return;
  }
  tls_->set_included(tls_->included().add(x));
}
|
|
|
|
// Removes the guarded key again, but only if the constructor was the one that
// added it (i.e. it was not included beforehand).
IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() {
  if (prev_state_) {
    return;
  }
  tls_->set_included(tls_->included().remove(id_));
}
|
|
|
|
// Records whether `x` was already in the excluded set and, if it was not,
// adds it. The destructor uses the recorded state to decide whether to undo.
ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKey x)
    : tls_(&raw_local_dispatch_key_set),
      id_(x),
      // NB: DispatchKey::Undefined is treated as "already excluded", which
      // makes the guard a no-op for it.
      prev_state_(x == DispatchKey::Undefined || tls_->excluded().has(x)) {
  if (prev_state_) {
    return;
  }
  tls_->set_excluded(tls_->excluded().add(x));
}
|
|
|
|
// Removes the guarded key from the excluded set again, but only if the
// constructor was the one that added it.
ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() {
  if (prev_state_) {
    return;
  }
  tls_->set_excluded(tls_->excluded().remove(id_));
}
|
|
|
|
// Non-RAII API
|
|
// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h for details.
|
|
|
|
// Reports whether `x` is currently in the excluded set of
// raw_local_dispatch_key_set.
bool tls_is_dispatch_key_excluded(DispatchKey x) {
  auto excluded_keys = raw_local_dispatch_key_set.excluded();
  return excluded_keys.has(x);
}
|
|
|
|
// Makes membership of `x` in the excluded set equal to `desired_state`.
// Does nothing when the key is already in the requested state.
void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state) {
  auto& local = raw_local_dispatch_key_set;
  const bool already_excluded = local.excluded().has(x);
  if (already_excluded == desired_state) {
    return;
  }
  local.set_excluded(
      desired_state ? local.excluded().add(x) : local.excluded().remove(x));
}
|
|
|
|
// Reports whether `x` is currently in the included set of
// raw_local_dispatch_key_set.
bool tls_is_dispatch_key_included(DispatchKey x) {
  auto included_keys = raw_local_dispatch_key_set.included();
  return included_keys.has(x);
}
|
|
|
|
// Makes membership of `x` in the included set equal to `desired_state`.
// Does nothing when the key is already in the requested state.
void tls_set_dispatch_key_included(DispatchKey x, bool desired_state) {
  auto& local = raw_local_dispatch_key_set;
  const bool already_included = local.included().has(x);
  if (already_included == desired_state) {
    return;
  }
  local.set_included(
      desired_state ? local.included().add(x) : local.included().remove(x));
}
|
|
|
|
}} // namespace c10::impl
|