mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D25966661: Support needsOutputs for RecordFunction and ObserverUtil improvements
Test Plan: revert-hammer
Differential Revision:
D25966661 (0e43a73f76)
Original commit changeset: 707886e1f212
fbshipit-source-id: a4e4af29abf622c1e0aaaf7dfb019c045988b4bc
This commit is contained in:
parent
23b15ef98a
commit
5b448cf21a
|
|
@ -377,10 +377,6 @@ namespace impl {
|
|||
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
|
||||
return c10::ivalue::from(std::move(v));
|
||||
}
|
||||
static IValue copy(const T& v) {
|
||||
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
|
||||
return IValue(v);
|
||||
}
|
||||
};
|
||||
|
||||
// Special case to allow kernels to return `Tensor&`.
|
||||
|
|
@ -390,9 +386,6 @@ namespace impl {
|
|||
static IValue call(at::Tensor& v) {
|
||||
return c10::ivalue::from(v);
|
||||
}
|
||||
static IValue copy(at::Tensor& v) {
|
||||
return IValue(v);
|
||||
}
|
||||
};
|
||||
|
||||
// wrap_kernel_functor_unboxed_
|
||||
|
|
@ -484,35 +477,23 @@ namespace impl {
|
|||
static void call(OutputType&& output, Stack* stack) {
|
||||
torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(std::forward<OutputType>(output)));
|
||||
}
|
||||
static void copy(const OutputType& output, Stack* stack) {
|
||||
torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output));
|
||||
}
|
||||
};
|
||||
template<class... OutputTypes, bool AllowDeprecatedTypes>
|
||||
struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
|
||||
static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
|
||||
call_(std::move(output), stack, std::make_index_sequence<sizeof...(OutputTypes)>());
|
||||
}
|
||||
static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) {
|
||||
copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>());
|
||||
}
|
||||
|
||||
private:
|
||||
template<size_t... indices>
|
||||
static void call_(std::tuple<OutputTypes...>&& output, Stack* stack, std::index_sequence<indices...>) {
|
||||
torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(std::forward<OutputTypes>(std::get<indices>(output)))...);
|
||||
}
|
||||
template<size_t... indices>
|
||||
static void copy_(const std::tuple<OutputTypes...>& output, Stack* stack, std::index_sequence<indices...>) {
|
||||
torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(std::get<indices>(output))...);
|
||||
}
|
||||
};
|
||||
template<bool AllowDeprecatedTypes>
|
||||
struct push_outputs<void, AllowDeprecatedTypes> final {
|
||||
static void call(int /*dummy*/, Stack* /*stack*/) {
|
||||
}
|
||||
static void copy(int /*dummy*/, Stack* /*stack*/) {
|
||||
}
|
||||
};
|
||||
|
||||
// make_boxed_from_unboxed_functor
|
||||
|
|
|
|||
|
|
@ -402,72 +402,9 @@ private:
|
|||
};
|
||||
|
||||
namespace detail {
|
||||
template <class... Args> inline void unused_arg_(const Args&...) {}
|
||||
|
||||
// CaptureKernelCall is intended to capture return values from Dispatcher
|
||||
// unboxed kernel calls. A record function may request to get outputs from the
|
||||
// kernel calls. For boxed kernels, it's straightforward, the returned values
|
||||
// are in the stack object. The stack can be passed to record functions. For
|
||||
// unboxed kernels, we need to handle different kinds of return values, cache
|
||||
// them temporarily, then release the values for the actual function call
|
||||
// return.
|
||||
template <typename ReturnType>
|
||||
struct CaptureKernelCall {
|
||||
template <typename F, typename... Args>
|
||||
CaptureKernelCall(
|
||||
const F& kernel,
|
||||
const TypedOperatorHandle<ReturnType(Args...)>& op,
|
||||
const DispatchKeySet& dispatchKeySet,
|
||||
Args&&... args)
|
||||
// Calls the kernel and capture the result in output_.
|
||||
: output_{kernel.template call<ReturnType, Args...>(
|
||||
op,
|
||||
dispatchKeySet,
|
||||
std::forward<Args>(args)...)} {}
|
||||
// Wraps the return values in a Stack.
|
||||
Stack getOutputs() {
|
||||
Stack stack;
|
||||
impl::push_outputs<ReturnType, false>::copy(output_, &stack);
|
||||
return stack;
|
||||
}
|
||||
// Since we are returning the output_, we don't expect the output_ to be used
|
||||
// afterward. Copy elision and RVO do not apply to class data members. Using
|
||||
// move semantic to avoid copies when possible.
|
||||
ReturnType release() && {
|
||||
return std::move(output_);
|
||||
}
|
||||
|
||||
private:
|
||||
ReturnType output_;
|
||||
};
|
||||
|
||||
// Handle the lvalue reference differently since it should not be moved.
|
||||
template <>
|
||||
inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
|
||||
return output_;
|
||||
template<class... Args> inline void unused_arg_(const Args&...) {}
|
||||
}
|
||||
|
||||
// Handle case where the kernel returns void.
|
||||
template <>
|
||||
struct CaptureKernelCall<void> {
|
||||
template <typename F, typename... Args>
|
||||
CaptureKernelCall(
|
||||
const F& kernel,
|
||||
const TypedOperatorHandle<void(Args...)>& op,
|
||||
const DispatchKeySet& dispatchKeySet,
|
||||
Args&&... args) {
|
||||
// Calling the kernel and no need to capture void.
|
||||
kernel.template call<void, Args...>(
|
||||
op, dispatchKeySet, std::forward<Args>(args)...);
|
||||
}
|
||||
Stack getOutputs() {
|
||||
return Stack();
|
||||
}
|
||||
void release() && {}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
|
||||
template<class Return, class... Args>
|
||||
inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
|
||||
|
|
@ -486,15 +423,6 @@ inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<
|
|||
} else {
|
||||
runRecordFunction(guard, op, dispatchKey);
|
||||
}
|
||||
if (C10_UNLIKELY(guard.needsOutputs())) {
|
||||
// Calls the kernel and capture the output temporarily to pass to
|
||||
// RecordFunction.
|
||||
detail::CaptureKernelCall<Return> captureKernelCall(
|
||||
kernel, op, dispatchKeySet, std::forward<Args>(args)...);
|
||||
guard.setOutputs(captureKernelCall.getOutputs());
|
||||
// Releases the captured output to return to caller.
|
||||
return std::move(captureKernelCall).release();
|
||||
}
|
||||
}
|
||||
}
|
||||
// keeping the guard alive while executing the kernel
|
||||
|
|
@ -558,11 +486,6 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
|
|||
}
|
||||
// keeping the guard alive while executing the kernel
|
||||
kernel.callBoxed(op, dispatchKeySet, stack);
|
||||
// track outputs
|
||||
if (C10_UNLIKELY(
|
||||
guard.isActive() && entry.isObserved() && guard.needsOutputs())) {
|
||||
guard.setOutputs(*stack);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
|
|
|
|||
|
|
@ -189,7 +189,6 @@ class CallbackManager {
|
|||
// to be executed and whether any of them need inputs
|
||||
inline void init(RecordFunction& rec_fn, RecordScope scope, bool pre_sampled) {
|
||||
bool found_needs_inputs = false;
|
||||
bool found_needs_outputs = false;
|
||||
bool found_needs_ids = false;
|
||||
|
||||
for (const auto& cb: rf_tls_.sorted_tls_callbacks_) {
|
||||
|
|
@ -197,9 +196,6 @@ class CallbackManager {
|
|||
if (cb.first.needsInputs()) {
|
||||
found_needs_inputs = true;
|
||||
}
|
||||
if (cb.first.needsOutputs()) {
|
||||
found_needs_outputs = true;
|
||||
}
|
||||
if (cb.first.needsIds()) {
|
||||
found_needs_ids = true;
|
||||
}
|
||||
|
|
@ -215,9 +211,6 @@ class CallbackManager {
|
|||
if (cb.first.needsInputs()) {
|
||||
found_needs_inputs = true;
|
||||
}
|
||||
if (cb.first.needsOutputs()) {
|
||||
found_needs_outputs = true;
|
||||
}
|
||||
if (cb.first.needsIds()) {
|
||||
found_needs_ids = true;
|
||||
}
|
||||
|
|
@ -237,7 +230,6 @@ class CallbackManager {
|
|||
rec_fn.state_->global_ctx_.resize(rec_fn.state_->sorted_active_global_handles_.size());
|
||||
|
||||
rec_fn.state_->needs_inputs = found_needs_inputs;
|
||||
rec_fn.state_->needs_outputs = found_needs_outputs;
|
||||
if (found_needs_ids) {
|
||||
rec_fn.setHandle(next_unique_record_function_handle());
|
||||
}
|
||||
|
|
@ -460,8 +452,6 @@ void RecordFunction::before(
|
|||
state_->sequence_nr_ = sequence_nr;
|
||||
state_->thread_id_ = currentThreadId();
|
||||
state_->operator_name_ = op.operator_name();
|
||||
state_->op_input_size = op.schema().arguments().size();
|
||||
state_->op_output_size = op.schema().returns().size();
|
||||
state_->name_ = StringView(op.schema().name());
|
||||
|
||||
manager().runStartCallbacks(*this);
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ enum class C10_API_ENUM RecordScope : uint8_t {
|
|||
namespace std {
|
||||
template <>
|
||||
struct hash<at::RecordScope> {
|
||||
size_t operator()(
|
||||
inline size_t operator()(
|
||||
const at::RecordScope& sc) const {
|
||||
return static_cast<std::size_t>(sc);
|
||||
}
|
||||
|
|
@ -52,7 +52,7 @@ struct TORCH_API StringView {
|
|||
: owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
|
||||
str_ptr_(owned_str_ptr_->c_str()) {}
|
||||
|
||||
const char* str() const {
|
||||
inline const char* str() const {
|
||||
return str_ptr_;
|
||||
}
|
||||
|
||||
|
|
@ -116,12 +116,12 @@ struct TORCH_API RecordFunction {
|
|||
RecordFunction(const RecordFunction&) = delete;
|
||||
RecordFunction& operator=(const RecordFunction&) = delete;
|
||||
|
||||
const StringView& name() const {
|
||||
inline const StringView& name() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called name() on inactive RecordFunction");
|
||||
return state_->name_;
|
||||
}
|
||||
|
||||
int64_t seqNr() const {
|
||||
inline int64_t seqNr() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called seqNr() on inactive RecordFunction");
|
||||
return state_->sequence_nr_;
|
||||
}
|
||||
|
|
@ -131,35 +131,10 @@ struct TORCH_API RecordFunction {
|
|||
return state_->inputs_;
|
||||
}
|
||||
|
||||
const std::vector<c10::IValue>& outputs() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called outputs() on inactive RecordFunction");
|
||||
return state_->outputs_;
|
||||
}
|
||||
|
||||
void setOutputs(std::vector<c10::IValue>&& outputs) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setOutputs() on inactive RecordFunction");
|
||||
state_->outputs_ = std::move(outputs);
|
||||
}
|
||||
|
||||
void setOutputs(c10::ArrayRef<c10::IValue> outputs) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setOutputs() on inactive RecordFunction");
|
||||
state_->outputs_ = outputs.vec();
|
||||
}
|
||||
|
||||
size_t num_inputs() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called num_inputs() on inactive RecordFunction");
|
||||
return state_->op_input_size;
|
||||
}
|
||||
|
||||
size_t num_outputs() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called num_outputs() on inactive RecordFunction");
|
||||
return state_->op_output_size;
|
||||
}
|
||||
|
||||
// Retrieves the thread_id that this RecordFunction ran start callbacks with.
|
||||
// Useful for writing thread safe end callbacks that may be potentially
|
||||
// executed in a different thread (async ops)
|
||||
uint64_t threadId() const {
|
||||
inline uint64_t threadId() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called threadId() on inactive RecordFunction");
|
||||
return state_->thread_id_;
|
||||
}
|
||||
|
|
@ -168,17 +143,17 @@ struct TORCH_API RecordFunction {
|
|||
// or zero otherwise;
|
||||
// used alongside with sequence number to correlate backward functions with
|
||||
// the forward ones
|
||||
uint64_t forwardThreadId() const {
|
||||
inline uint64_t forwardThreadId() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called forwardThreadId() on inactive RecordFunction");
|
||||
return state_->fwd_thread_id_;
|
||||
}
|
||||
|
||||
void setForwardThreadId(uint64_t thread_id) {
|
||||
inline void setForwardThreadId(uint64_t thread_id) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setForwardThreadId() on inactive RecordFunction");
|
||||
state_->fwd_thread_id_ = thread_id;
|
||||
}
|
||||
|
||||
RecordScope scope() const {
|
||||
inline RecordScope scope() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called scope() on inactive RecordFunction");
|
||||
return state_->scope_;
|
||||
}
|
||||
|
|
@ -227,17 +202,17 @@ struct TORCH_API RecordFunction {
|
|||
// Calls end callbacks. After end(), accessors will no longer provide useful results.
|
||||
void end();
|
||||
|
||||
RecordFunctionHandle handle() const {
|
||||
inline RecordFunctionHandle handle() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called handle() on inactive RecordFunction");
|
||||
return state_->handle_;
|
||||
}
|
||||
|
||||
c10::optional<OperatorName> operator_name() const {
|
||||
inline c10::optional<OperatorName> operator_name() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called operator_name() on inactive RecordFunction");
|
||||
return state_->operator_name_;
|
||||
}
|
||||
|
||||
void setHandle(RecordFunctionHandle handle) {
|
||||
inline void setHandle(RecordFunctionHandle handle) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called setHandle() on inactive RecordFunction");
|
||||
state_->handle_ = handle;
|
||||
}
|
||||
|
|
@ -252,11 +227,6 @@ struct TORCH_API RecordFunction {
|
|||
return state_->needs_inputs;
|
||||
}
|
||||
|
||||
bool needsOutputs() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called needsOutputs() on inactive RecordFunction");
|
||||
return state_->needs_outputs;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
// Allows the modification of some internal states for callbacks.
|
||||
|
|
@ -268,9 +238,6 @@ struct TORCH_API RecordFunction {
|
|||
// Whether any of the picked callbacks require inputs
|
||||
bool needs_inputs = false;
|
||||
|
||||
// Whether any of the picked callbacks require outputs
|
||||
bool needs_outputs = false;
|
||||
|
||||
// In cases when RecordFunction might be active but we chose not to
|
||||
// use the observers (e.g. operator is not observed), this boolean
|
||||
// flag is used to check whether the start callbacks were called
|
||||
|
|
@ -295,11 +262,8 @@ struct TORCH_API RecordFunction {
|
|||
StringView name_;
|
||||
int64_t sequence_nr_ = -1;
|
||||
std::vector<c10::IValue> inputs_;
|
||||
std::vector<c10::IValue> outputs_;
|
||||
|
||||
c10::optional<c10::OperatorName> operator_name_;
|
||||
size_t op_input_size{0};
|
||||
size_t op_output_size{0};
|
||||
|
||||
// Kind of scope this RecordFunction is observing
|
||||
const RecordScope scope_;
|
||||
|
|
@ -359,11 +323,6 @@ class TORCH_API RecordFunctionCallback {
|
|||
return *this;
|
||||
}
|
||||
|
||||
RecordFunctionCallback& needsOutputs(bool needs_outputs) {
|
||||
needs_outputs_ = needs_outputs;
|
||||
return *this;
|
||||
}
|
||||
|
||||
RecordFunctionCallback& needsIds(bool needs_ids) {
|
||||
needs_ids_ = needs_ids;
|
||||
return *this;
|
||||
|
|
@ -395,31 +354,27 @@ class TORCH_API RecordFunctionCallback {
|
|||
return *this;
|
||||
}
|
||||
|
||||
bool needsInputs() const {
|
||||
inline bool needsInputs() const {
|
||||
return needs_inputs_;
|
||||
}
|
||||
|
||||
bool needsOutputs() const {
|
||||
return needs_outputs_;
|
||||
}
|
||||
|
||||
bool needsIds() const {
|
||||
inline bool needsIds() const {
|
||||
return needs_ids_;
|
||||
}
|
||||
|
||||
double samplingProb() const {
|
||||
inline double samplingProb() const {
|
||||
return sampling_prob_;
|
||||
}
|
||||
|
||||
bool checkScope(RecordScope sc) const {
|
||||
inline bool checkScope(RecordScope sc) const {
|
||||
return scopes_[(size_t)sc];
|
||||
}
|
||||
|
||||
StartCallback start() const {
|
||||
inline StartCallback start() const {
|
||||
return start_;
|
||||
}
|
||||
|
||||
EndCallback end() const {
|
||||
inline EndCallback end() const {
|
||||
return end_;
|
||||
}
|
||||
|
||||
|
|
@ -431,7 +386,6 @@ class TORCH_API RecordFunctionCallback {
|
|||
double sampling_prob_ = 1.0;
|
||||
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
|
||||
bool needs_inputs_ = false;
|
||||
bool needs_outputs_ = false;
|
||||
bool needs_ids_ = false;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -691,10 +691,10 @@ at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
|
|||
return module->forward({t}).toTensor();
|
||||
}
|
||||
|
||||
using TracedTestValues =
|
||||
using TracedTestInputs =
|
||||
std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
|
||||
|
||||
void checkTracedInputs(const TracedTestValues& inputs) {
|
||||
void checkTracedInputs(const TracedTestInputs& inputs) {
|
||||
bool found_test = false;
|
||||
bool found_pow = false;
|
||||
bool found_mul = false;
|
||||
|
|
@ -722,32 +722,6 @@ void checkTracedInputs(const TracedTestValues& inputs) {
|
|||
TORCH_CHECK(found_mul);
|
||||
}
|
||||
|
||||
void checkTracedOutputs(const TracedTestValues& outputs) {
|
||||
bool found_test = false;
|
||||
bool found_pow = false;
|
||||
bool found_mul = false;
|
||||
for (const auto& output : outputs) {
|
||||
const auto& fn = std::get<0>(output);
|
||||
const auto& sizes = std::get<1>(output);
|
||||
|
||||
if (fn == "test") {
|
||||
found_test = true;
|
||||
TORCH_CHECK(sizes.empty());
|
||||
} else if (fn == "aten::pow") {
|
||||
found_pow = true;
|
||||
TORCH_CHECK(sizes.size() == 1);
|
||||
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
||||
} else if (fn == "aten::mul") {
|
||||
found_mul = true;
|
||||
TORCH_CHECK(sizes.size() == 1);
|
||||
TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(found_test);
|
||||
TORCH_CHECK(found_pow);
|
||||
TORCH_CHECK(found_mul);
|
||||
}
|
||||
|
||||
static bool bad_scope = false;
|
||||
template <RecordScope scope, size_t* cnt>
|
||||
std::unique_ptr<at::ObserverContext> checkScopeCallback(
|
||||
|
|
@ -829,10 +803,8 @@ static bool shouldRunCallback(const RecordFunctionCallback&) {
|
|||
return should_run;
|
||||
}
|
||||
|
||||
static TracedTestValues traced_inputs;
|
||||
static TracedTestValues traced_outputs;
|
||||
static std::unordered_set<std::string> ts_input_names;
|
||||
static std::unordered_set<std::string> ts_output_names;
|
||||
static TracedTestInputs traced_inputs;
|
||||
static std::unordered_set<std::string> ts_names;
|
||||
|
||||
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
|
||||
const RecordFunction& fn) {
|
||||
|
|
@ -848,71 +820,43 @@ std::unique_ptr<at::ObserverContext> tracedInputsCallback(
|
|||
}
|
||||
traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
|
||||
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
|
||||
ts_input_names.insert(fn.name().str());
|
||||
ts_names.insert(fn.name().str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) {
|
||||
if (fn.scope() == RecordScope::FUNCTION) {
|
||||
auto outputs = fn.outputs();
|
||||
std::vector<std::vector<int64_t>> sizes;
|
||||
for (const auto& output : outputs) {
|
||||
if (output.isTensor()) {
|
||||
sizes.push_back(output.toTensor().sizes().vec());
|
||||
} else if (output.isScalar()) {
|
||||
sizes.emplace_back();
|
||||
}
|
||||
}
|
||||
traced_outputs.push_back(std::make_tuple(fn.name().str(), sizes));
|
||||
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
|
||||
ts_output_names.insert(fn.name().str());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(RecordFunctionTest, TracedTestInputsOutputs) {
|
||||
TEST(RecordFunctionTest, TracedTestInputs) {
|
||||
// disabling the inlining of method calls
|
||||
GraphOptimizerEnabledGuard opt_guard(false);
|
||||
|
||||
// [(fn, [[sizes], [sizes], ...]), ...]
|
||||
addGlobalCallback(
|
||||
RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
|
||||
.needsInputs(true)
|
||||
.needsOutputs(true));
|
||||
RecordFunctionCallback(tracedInputsCallback).needsInputs(true));
|
||||
|
||||
TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
|
||||
TracedTestInputs eager_inputs, jit_inputs;
|
||||
{
|
||||
auto t = torch::randn({1, 2, 3}, at::kCPU);
|
||||
t.set_requires_grad(true);
|
||||
auto t2 = invokeTestRecordFunction(t);
|
||||
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
|
||||
eager_inputs = traced_inputs;
|
||||
eager_outputs = traced_outputs;
|
||||
traced_inputs.clear();
|
||||
traced_outputs.clear();
|
||||
|
||||
TORCH_CHECK(ts_input_names.empty());
|
||||
TORCH_CHECK(ts_output_names.empty());
|
||||
TORCH_CHECK(ts_names.empty());
|
||||
|
||||
t = torch::randn({1, 2, 3}, at::kCPU);
|
||||
t.set_requires_grad(true);
|
||||
t2 = invokeTestRecordFunctionJIT(t);
|
||||
t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
|
||||
jit_inputs = traced_inputs;
|
||||
jit_outputs = traced_outputs;
|
||||
traced_inputs.clear();
|
||||
traced_outputs.clear();
|
||||
}
|
||||
|
||||
TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
|
||||
TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
|
||||
TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
|
||||
TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());
|
||||
TORCH_CHECK(ts_names.find("forward") != ts_names.end());
|
||||
TORCH_CHECK(ts_names.find("foo") != ts_names.end());
|
||||
|
||||
checkTracedInputs(eager_inputs);
|
||||
checkTracedOutputs(eager_outputs);
|
||||
checkTracedInputs(jit_inputs);
|
||||
checkTracedOutputs(jit_outputs);
|
||||
at::clearCallbacks();
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user