Revert D25966661: Support needsOutputs for RecordFunction and ObserverUtil improvements

Test Plan: revert-hammer

Differential Revision:
D25966661 (0e43a73f76)

Original commit changeset: 707886e1f212

fbshipit-source-id: a4e4af29abf622c1e0aaaf7dfb019c045988b4bc
This commit is contained in:
Qi Zhao 2021-03-30 15:39:33 -07:00 committed by Facebook GitHub Bot
parent 23b15ef98a
commit 5b448cf21a
5 changed files with 29 additions and 237 deletions

View File

@@ -377,10 +377,6 @@ namespace impl {
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
return c10::ivalue::from(std::move(v));
}
static IValue copy(const T& v) {
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
return IValue(v);
}
};
// Special case to allow kernels to return `Tensor&`.
@@ -390,9 +386,6 @@ namespace impl {
static IValue call(at::Tensor& v) {
return c10::ivalue::from(v);
}
static IValue copy(at::Tensor& v) {
return IValue(v);
}
};
// wrap_kernel_functor_unboxed_
@@ -484,35 +477,23 @@ namespace impl {
static void call(OutputType&& output, Stack* stack) {
torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(std::forward<OutputType>(output)));
}
static void copy(const OutputType& output, Stack* stack) {
torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output));
}
};
template<class... OutputTypes, bool AllowDeprecatedTypes>
struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
call_(std::move(output), stack, std::make_index_sequence<sizeof...(OutputTypes)>());
}
static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) {
copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>());
}
private:
template<size_t... indices>
static void call_(std::tuple<OutputTypes...>&& output, Stack* stack, std::index_sequence<indices...>) {
torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(std::forward<OutputTypes>(std::get<indices>(output)))...);
}
template<size_t... indices>
static void copy_(const std::tuple<OutputTypes...>& output, Stack* stack, std::index_sequence<indices...>) {
torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(std::get<indices>(output))...);
}
};
template<bool AllowDeprecatedTypes>
struct push_outputs<void, AllowDeprecatedTypes> final {
static void call(int /*dummy*/, Stack* /*stack*/) {
}
static void copy(int /*dummy*/, Stack* /*stack*/) {
}
};
// make_boxed_from_unboxed_functor

View File

@@ -402,72 +402,9 @@ private:
};
namespace detail {
template <class... Args> inline void unused_arg_(const Args&...) {}
// CaptureKernelCall is intended to capture return values from Dispatcher
// unboxed kernel calls. A record function may request to get outputs from the
// kernel calls. For boxed kernels, it's straightforward, the returned values
// are in the stack object. The stack can be passed to record functions. For
// unboxed kernels, we need to handle different kinds of return values, cache
// them temporarily, then release the values for the actual function call
// return.
template <typename ReturnType>
struct CaptureKernelCall {
template <typename F, typename... Args>
CaptureKernelCall(
const F& kernel,
const TypedOperatorHandle<ReturnType(Args...)>& op,
const DispatchKeySet& dispatchKeySet,
Args&&... args)
// Calls the kernel and capture the result in output_.
: output_{kernel.template call<ReturnType, Args...>(
op,
dispatchKeySet,
std::forward<Args>(args)...)} {}
// Wraps the return values in a Stack.
Stack getOutputs() {
Stack stack;
impl::push_outputs<ReturnType, false>::copy(output_, &stack);
return stack;
}
// Since we are returning the output_, we don't expect the output_ to be used
// afterward. Copy elision and RVO do not apply to class data members. Using
// move semantic to avoid copies when possible.
ReturnType release() && {
return std::move(output_);
}
private:
ReturnType output_;
};
// Handle the lvalue reference differently since it should not be moved.
template <>
inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
return output_;
template<class... Args> inline void unused_arg_(const Args&...) {}
}
// Handle case where the kernel returns void.
template <>
struct CaptureKernelCall<void> {
template <typename F, typename... Args>
CaptureKernelCall(
const F& kernel,
const TypedOperatorHandle<void(Args...)>& op,
const DispatchKeySet& dispatchKeySet,
Args&&... args) {
// Calling the kernel and no need to capture void.
kernel.template call<void, Args...>(
op, dispatchKeySet, std::forward<Args>(args)...);
}
Stack getOutputs() {
return Stack();
}
void release() && {}
};
} // namespace detail
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
@@ -486,15 +423,6 @@ inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<
} else {
runRecordFunction(guard, op, dispatchKey);
}
if (C10_UNLIKELY(guard.needsOutputs())) {
// Calls the kernel and capture the output temporarily to pass to
// RecordFunction.
detail::CaptureKernelCall<Return> captureKernelCall(
kernel, op, dispatchKeySet, std::forward<Args>(args)...);
guard.setOutputs(captureKernelCall.getOutputs());
// Releases the captured output to return to caller.
return std::move(captureKernelCall).release();
}
}
}
// keeping the guard alive while executing the kernel
@@ -558,11 +486,6 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
}
// keeping the guard alive while executing the kernel
kernel.callBoxed(op, dispatchKeySet, stack);
// track outputs
if (C10_UNLIKELY(
guard.isActive() && entry.isObserved() && guard.needsOutputs())) {
guard.setOutputs(*stack);
}
return;
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING

View File

@@ -189,7 +189,6 @@ class CallbackManager {
// to be executed and whether any of them need inputs
inline void init(RecordFunction& rec_fn, RecordScope scope, bool pre_sampled) {
bool found_needs_inputs = false;
bool found_needs_outputs = false;
bool found_needs_ids = false;
for (const auto& cb: rf_tls_.sorted_tls_callbacks_) {
@@ -197,9 +196,6 @@ class CallbackManager {
if (cb.first.needsInputs()) {
found_needs_inputs = true;
}
if (cb.first.needsOutputs()) {
found_needs_outputs = true;
}
if (cb.first.needsIds()) {
found_needs_ids = true;
}
@@ -215,9 +211,6 @@ class CallbackManager {
if (cb.first.needsInputs()) {
found_needs_inputs = true;
}
if (cb.first.needsOutputs()) {
found_needs_outputs = true;
}
if (cb.first.needsIds()) {
found_needs_ids = true;
}
@@ -237,7 +230,6 @@ class CallbackManager {
rec_fn.state_->global_ctx_.resize(rec_fn.state_->sorted_active_global_handles_.size());
rec_fn.state_->needs_inputs = found_needs_inputs;
rec_fn.state_->needs_outputs = found_needs_outputs;
if (found_needs_ids) {
rec_fn.setHandle(next_unique_record_function_handle());
}
@@ -460,8 +452,6 @@ void RecordFunction::before(
state_->sequence_nr_ = sequence_nr;
state_->thread_id_ = currentThreadId();
state_->operator_name_ = op.operator_name();
state_->op_input_size = op.schema().arguments().size();
state_->op_output_size = op.schema().returns().size();
state_->name_ = StringView(op.schema().name());
manager().runStartCallbacks(*this);

View File

@@ -35,7 +35,7 @@ enum class C10_API_ENUM RecordScope : uint8_t {
namespace std {
template <>
struct hash<at::RecordScope> {
size_t operator()(
inline size_t operator()(
const at::RecordScope& sc) const {
return static_cast<std::size_t>(sc);
}
@@ -52,7 +52,7 @@ struct TORCH_API StringView {
: owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
str_ptr_(owned_str_ptr_->c_str()) {}
const char* str() const {
inline const char* str() const {
return str_ptr_;
}
@@ -116,12 +116,12 @@ struct TORCH_API RecordFunction {
RecordFunction(const RecordFunction&) = delete;
RecordFunction& operator=(const RecordFunction&) = delete;
const StringView& name() const {
inline const StringView& name() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called name() on inactive RecordFunction");
return state_->name_;
}
int64_t seqNr() const {
inline int64_t seqNr() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called seqNr() on inactive RecordFunction");
return state_->sequence_nr_;
}
@@ -131,35 +131,10 @@ struct TORCH_API RecordFunction {
return state_->inputs_;
}
const std::vector<c10::IValue>& outputs() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called outputs() on inactive RecordFunction");
return state_->outputs_;
}
void setOutputs(std::vector<c10::IValue>&& outputs) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setOutputs() on inactive RecordFunction");
state_->outputs_ = std::move(outputs);
}
void setOutputs(c10::ArrayRef<c10::IValue> outputs) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setOutputs() on inactive RecordFunction");
state_->outputs_ = outputs.vec();
}
size_t num_inputs() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called num_inputs() on inactive RecordFunction");
return state_->op_input_size;
}
size_t num_outputs() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called num_outputs() on inactive RecordFunction");
return state_->op_output_size;
}
// Retrieves the thread_id that this RecordFunction ran start callbacks with.
// Useful for writing thread safe end callbacks that may be potentially
// executed in a different thread (async ops)
uint64_t threadId() const {
inline uint64_t threadId() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called threadId() on inactive RecordFunction");
return state_->thread_id_;
}
@@ -168,17 +143,17 @@ struct TORCH_API RecordFunction {
// or zero otherwise;
// used alongside with sequence number to correlate backward functions with
// the forward ones
uint64_t forwardThreadId() const {
inline uint64_t forwardThreadId() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called forwardThreadId() on inactive RecordFunction");
return state_->fwd_thread_id_;
}
void setForwardThreadId(uint64_t thread_id) {
inline void setForwardThreadId(uint64_t thread_id) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setForwardThreadId() on inactive RecordFunction");
state_->fwd_thread_id_ = thread_id;
}
RecordScope scope() const {
inline RecordScope scope() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called scope() on inactive RecordFunction");
return state_->scope_;
}
@@ -227,17 +202,17 @@ struct TORCH_API RecordFunction {
// Calls end callbacks. After end(), accessors will no longer provide useful results.
void end();
RecordFunctionHandle handle() const {
inline RecordFunctionHandle handle() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called handle() on inactive RecordFunction");
return state_->handle_;
}
c10::optional<OperatorName> operator_name() const {
inline c10::optional<OperatorName> operator_name() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called operator_name() on inactive RecordFunction");
return state_->operator_name_;
}
void setHandle(RecordFunctionHandle handle) {
inline void setHandle(RecordFunctionHandle handle) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setHandle() on inactive RecordFunction");
state_->handle_ = handle;
}
@@ -252,11 +227,6 @@ struct TORCH_API RecordFunction {
return state_->needs_inputs;
}
bool needsOutputs() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called needsOutputs() on inactive RecordFunction");
return state_->needs_outputs;
}
private:
// Allows the modification of some internal states for callbacks.
@@ -268,9 +238,6 @@ struct TORCH_API RecordFunction {
// Whether any of the picked callbacks require inputs
bool needs_inputs = false;
// Whether any of the picked callbacks require outputs
bool needs_outputs = false;
// In cases when RecordFunction might be active but we chose not to
// use the observers (e.g. operator is not observed), this boolean
// flag is used to check whether the start callbacks were called
@@ -295,11 +262,8 @@ struct TORCH_API RecordFunction {
StringView name_;
int64_t sequence_nr_ = -1;
std::vector<c10::IValue> inputs_;
std::vector<c10::IValue> outputs_;
c10::optional<c10::OperatorName> operator_name_;
size_t op_input_size{0};
size_t op_output_size{0};
// Kind of scope this RecordFunction is observing
const RecordScope scope_;
@@ -359,11 +323,6 @@ class TORCH_API RecordFunctionCallback {
return *this;
}
RecordFunctionCallback& needsOutputs(bool needs_outputs) {
needs_outputs_ = needs_outputs;
return *this;
}
RecordFunctionCallback& needsIds(bool needs_ids) {
needs_ids_ = needs_ids;
return *this;
@@ -395,31 +354,27 @@ class TORCH_API RecordFunctionCallback {
return *this;
}
bool needsInputs() const {
inline bool needsInputs() const {
return needs_inputs_;
}
bool needsOutputs() const {
return needs_outputs_;
}
bool needsIds() const {
inline bool needsIds() const {
return needs_ids_;
}
double samplingProb() const {
inline double samplingProb() const {
return sampling_prob_;
}
bool checkScope(RecordScope sc) const {
inline bool checkScope(RecordScope sc) const {
return scopes_[(size_t)sc];
}
StartCallback start() const {
inline StartCallback start() const {
return start_;
}
EndCallback end() const {
inline EndCallback end() const {
return end_;
}
@@ -431,7 +386,6 @@ class TORCH_API RecordFunctionCallback {
double sampling_prob_ = 1.0;
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
bool needs_inputs_ = false;
bool needs_outputs_ = false;
bool needs_ids_ = false;
};

View File

@@ -691,10 +691,10 @@ at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
return module->forward({t}).toTensor();
}
using TracedTestValues =
using TracedTestInputs =
std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
void checkTracedInputs(const TracedTestValues& inputs) {
void checkTracedInputs(const TracedTestInputs& inputs) {
bool found_test = false;
bool found_pow = false;
bool found_mul = false;
@@ -722,32 +722,6 @@ void checkTracedInputs(const TracedTestValues& inputs) {
TORCH_CHECK(found_mul);
}
void checkTracedOutputs(const TracedTestValues& outputs) {
bool found_test = false;
bool found_pow = false;
bool found_mul = false;
for (const auto& output : outputs) {
const auto& fn = std::get<0>(output);
const auto& sizes = std::get<1>(output);
if (fn == "test") {
found_test = true;
TORCH_CHECK(sizes.empty());
} else if (fn == "aten::pow") {
found_pow = true;
TORCH_CHECK(sizes.size() == 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
} else if (fn == "aten::mul") {
found_mul = true;
TORCH_CHECK(sizes.size() == 1);
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
}
}
TORCH_CHECK(found_test);
TORCH_CHECK(found_pow);
TORCH_CHECK(found_mul);
}
static bool bad_scope = false;
template <RecordScope scope, size_t* cnt>
std::unique_ptr<at::ObserverContext> checkScopeCallback(
@@ -829,10 +803,8 @@ static bool shouldRunCallback(const RecordFunctionCallback&) {
return should_run;
}
static TracedTestValues traced_inputs;
static TracedTestValues traced_outputs;
static std::unordered_set<std::string> ts_input_names;
static std::unordered_set<std::string> ts_output_names;
static TracedTestInputs traced_inputs;
static std::unordered_set<std::string> ts_names;
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
const RecordFunction& fn) {
@@ -848,71 +820,43 @@ std::unique_ptr<at::ObserverContext> tracedInputsCallback(
}
traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
ts_input_names.insert(fn.name().str());
ts_names.insert(fn.name().str());
}
return nullptr;
}
void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) {
if (fn.scope() == RecordScope::FUNCTION) {
auto outputs = fn.outputs();
std::vector<std::vector<int64_t>> sizes;
for (const auto& output : outputs) {
if (output.isTensor()) {
sizes.push_back(output.toTensor().sizes().vec());
} else if (output.isScalar()) {
sizes.emplace_back();
}
}
traced_outputs.push_back(std::make_tuple(fn.name().str(), sizes));
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
ts_output_names.insert(fn.name().str());
}
}
TEST(RecordFunctionTest, TracedTestInputsOutputs) {
TEST(RecordFunctionTest, TracedTestInputs) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
// [(fn, [[sizes], [sizes], ...]), ...]
addGlobalCallback(
RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
.needsInputs(true)
.needsOutputs(true));
RecordFunctionCallback(tracedInputsCallback).needsInputs(true));
TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
TracedTestInputs eager_inputs, jit_inputs;
{
auto t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
auto t2 = invokeTestRecordFunction(t);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
eager_inputs = traced_inputs;
eager_outputs = traced_outputs;
traced_inputs.clear();
traced_outputs.clear();
TORCH_CHECK(ts_input_names.empty());
TORCH_CHECK(ts_output_names.empty());
TORCH_CHECK(ts_names.empty());
t = torch::randn({1, 2, 3}, at::kCPU);
t.set_requires_grad(true);
t2 = invokeTestRecordFunctionJIT(t);
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
jit_inputs = traced_inputs;
jit_outputs = traced_outputs;
traced_inputs.clear();
traced_outputs.clear();
}
TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());
TORCH_CHECK(ts_names.find("forward") != ts_names.end());
TORCH_CHECK(ts_names.find("foo") != ts_names.end());
checkTracedInputs(eager_inputs);
checkTracedOutputs(eager_outputs);
checkTracedInputs(jit_inputs);
checkTracedOutputs(jit_outputs);
at::clearCallbacks();
}