Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40840

**Summary**
This commit moves the TestBackend used for the JIT backend extension to the tests directory. It was temporarily placed in the source directory while some details of the user experience for this feature were being worked out.

**Test Plan**
`python test/test_jit.py TestBackends`

**Fixes**
This commit fixes #40067.

Test Plan: Imported from OSS

Differential Revision: D22418682

Pulled By: SplitInfinity

fbshipit-source-id: 9356af1341ec4d552a41c2a8929b327bc8b56057
77 lines
2.4 KiB
C++
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend.h>

namespace torch {
namespace jit {

// This test JIT backend is intended to do the minimal amount of work
// necessary to test that the JIT backend registration endpoints and
// code generation are working correctly. It is not intended to
// produce numerically correct results.
class TestBackend : public PyTorchBackendInterface {
 public:
  // Constructor.
  explicit TestBackend() {}
  virtual ~TestBackend() = default;

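  // Preprocessing is a no-op for this test backend: the module IValue is
  // returned unchanged.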
  c10::IValue preprocess(
      c10::IValue mod,
      c10::impl::GenericDict method_compile_spec) override {
    return mod;
  }

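  // "Compile" each method named in method_compile_spec by producing a handle
  // for it.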
  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    auto spec =
        c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);

    // Map every key in method_compile_spec to itself; the method name doubles
    // as the handle.
    auto handles = c10::Dict<std::string, std::string>();
    for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
      handles.insert(it->key(), it->key());
    }
    return c10::impl::toGenericDict(handles);
  }

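  // Run the toy ops described below; which accumulators are returned depends
  // on the handle produced by compile().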
  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    TORCH_INTERNAL_ASSERT(handle.isString());
    TORCH_INTERNAL_ASSERT(inputs.size() > 0);

    c10::List<at::Tensor> output_list;

    // Implement a simple accumulator and a subtracting accumulator. Return one
    // or both of them depending on the handle to make sure multiple outputs
    // are handled.
    c10::IValue value = inputs[0];
    at::Tensor accum = value.toTensor();
    accum = accum.clone();
    at::Tensor sub_accum = value.toTensor();
    sub_accum = sub_accum.clone();

    for (size_t i = 1, e = inputs.size(); i < e; ++i) {
      value = inputs[i];
      accum.add_(value.toTensor(), 1.0);
      sub_accum.sub_(value.toTensor(), 1.0);
    }

    if (handle.toStringRef() == "accum") {
      output_list.emplace_back(accum);
    } else if (handle.toStringRef() == "sub_accum") {
      output_list.emplace_back(sub_accum);
    } else if (handle.toStringRef() == "forward") {
      output_list.emplace_back(accum);
      output_list.emplace_back(sub_accum);
    }

    return c10::impl::toList(output_list);
  }
};

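// Register TestBackend under the name "test_backend" so that modules can be
// lowered to it (this is what the TestBackends tests in test/test_jit.py
// target).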
namespace {
static auto cls = torch::jit::backend<TestBackend>("test_backend");
} // namespace

} // namespace jit
} // namespace torch
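For context, the sketch below drives a TestBackend instance by hand to show the contract the three entry points implement: compile() turns each method name in the spec into a handle (here the name itself), and execute() chooses its outputs based on that handle. This is a minimal sketch, not how the backend is actually reached in the test flow: it assumes it lives in the same translation unit as the class above, and the function name, the spec contents, and the input tensors are made up for illustration.

#include <torch/torch.h>

// Hypothetical, hand-driven use of TestBackend (illustration only).
void drive_test_backend_sketch() {
  torch::jit::TestBackend backend;

  // A method_compile_spec with a single "forward" entry; the value is ignored
  // by TestBackend::compile, so an empty IValue is enough here.
  auto spec = c10::Dict<std::string, at::IValue>();
  spec.insert("forward", at::IValue());

  // compile() maps each method name to itself: {"forward": "forward"}.
  auto handles =
      backend.compile(at::IValue(), c10::impl::toGenericDict(spec));
  auto handle = handles.at(at::IValue(std::string("forward")));

  // With the "forward" handle, execute() returns both the running sum and the
  // running difference of the inputs.
  c10::List<at::Tensor> inputs;
  inputs.push_back(torch::ones({2, 2}));
  inputs.push_back(torch::ones({2, 2}) * 3);
  auto outputs = backend.execute(handle, c10::impl::toList(inputs));
  // outputs holds two tensors: 1 + 3 = 4 everywhere and 1 - 3 = -2 everywhere.
}

This compile-then-execute sequence is roughly what a module lowered via the `torch::jit::backend<TestBackend>("test_backend")` registration goes through when `python test/test_jit.py TestBackends` runs.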