pytorch/torch/csrc/lazy/backend/backend_interface.cpp
Antonio Kim f3f327e103 Decouple LTC from TS Backend using Lazy IR Builder
Next stage of breaking up https://github.com/pytorch/pytorch/pull/74710

An IR builder class is introduced to remove the explicit use of `TsNode` from the core lazy tensor code.

Requires https://github.com/pytorch/pytorch/pull/75324 to be merged in first.

**Background**
- there are ~5 special ops used in lazy core but defined as `: public {Backend}Node` (DeviceData, Expand, Scalar, ...)
- we currently require all nodes to derive from {Backend}Node, so that backends can safely make this assumption
- it is hard to have shared 'IR classes' in core/ because they depend on 'Node'
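To make the constraint above concrete, here is a minimal sketch of why a backend wants every node to derive from its own node class. All names (`Node`, `BackendNode`, `CoreDeviceData`, `BackendDeviceData`, `LowerNode`) are hypothetical stand-ins, not the real `torch::lazy` classes. The sketch uses `dynamic_cast` so the failure is observable; a backend that relies on the invariant and uses `static_cast` instead would have undefined behavior when handed a core-only node.

```cpp
#include <string>

// Hypothetical backend-neutral base node (stand-in, not torch::lazy::Node).
struct Node {
  virtual ~Node() = default;
  virtual std::string op() const = 0;
};

// Hypothetical backend-specific base, in the spirit of TsNode: it adds
// backend-only behavior such as lowering to the backend's IR.
struct BackendNode : Node {
  virtual std::string Lower() const { return "lowered(" + op() + ")"; }
};

// A "special" op defined in shared core code. It derives only from Node,
// so it has no Lower() and breaks the backend's assumption.
struct CoreDeviceData : Node {
  std::string op() const override { return "device_data"; }
};

// The same op defined inside the backend, deriving from BackendNode.
struct BackendDeviceData : BackendNode {
  std::string op() const override { return "device_data"; }
};

// The backend's lowering step expects every node to be a BackendNode.
std::string LowerNode(const Node& node) {
  if (const auto* bn = dynamic_cast<const BackendNode*>(&node)) {
    return bn->Lower();
  }
  return "error: node is not a BackendNode";
}
```

This is why each backend ended up redefining the special ops itself: a single shared definition in core cannot satisfy every backend's base class at once.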

**Motivation**

1. avoid copy-paste of "special" node classes for each backend
2. in general decouple and remove all dependencies that LTC has on the TS backend

**Summary of changes**
- new 'IRBuilder' interface that knows how to make 5 special ops
- move 'special' node classes to `ts_backend/`
- implement TSIRBuilder that makes the special TS Nodes
- new backend interface API to get the IRBuilder
- update core code to call the builder
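The builder approach summarized above can be sketched as a small abstract factory. This is a simplified illustration under stated assumptions: `IrBuilder`, `MakeScalar`, `MakeExpand`, `TsLikeIrBuilder`, and `BuildExpandedScalar` are hypothetical names with simplified signatures, not the actual LTC API. Core code only talks to the interface, and the backend's builder returns nodes from its own hierarchy, so the backend's "all nodes derive from my base class" assumption still holds.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical backend-neutral node type seen by core code.
struct Node {
  virtual ~Node() = default;
  virtual std::string Describe() const = 0;
};
using NodePtr = std::shared_ptr<Node>;

// Hypothetical builder interface: core asks it to create the special ops.
struct IrBuilder {
  virtual ~IrBuilder() = default;
  virtual NodePtr MakeScalar(int value) const = 0;
  virtual NodePtr MakeExpand(NodePtr input) const = 0;
};

// Backend-owned node hierarchy (analogous in spirit to TsNode).
struct TsLikeNode : Node {};

struct TsLikeScalar : TsLikeNode {
  explicit TsLikeScalar(int v) : value(v) {}
  std::string Describe() const override {
    return "ts.scalar(" + std::to_string(value) + ")";
  }
  int value;
};

struct TsLikeExpand : TsLikeNode {
  explicit TsLikeExpand(NodePtr in) : input(std::move(in)) {}
  std::string Describe() const override {
    return "ts.expand(" + input->Describe() + ")";
  }
  NodePtr input;
};

// The backend's builder returns its own node subclasses.
struct TsLikeIrBuilder : IrBuilder {
  NodePtr MakeScalar(int value) const override {
    return std::make_shared<TsLikeScalar>(value);
  }
  NodePtr MakeExpand(NodePtr input) const override {
    return std::make_shared<TsLikeExpand>(std::move(input));
  }
};

// Core code: builds special ops through the interface and never names
// any backend-specific class.
NodePtr BuildExpandedScalar(const IrBuilder& builder, int value) {
  return builder.MakeExpand(builder.MakeScalar(value));
}
```

For example, `BuildExpandedScalar(TsLikeIrBuilder{}, 2)->Describe()` yields `"ts.expand(ts.scalar(2))"`, and the resulting node is a `TsLikeNode`.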

CC: @wconstab @JackCaoG @henrytwo

Partially Fixes #74628

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75433
Approved by: https://github.com/wconstab
2022-04-28 02:07:02 +00:00


```cpp
#include <torch/csrc/lazy/backend/backend_interface.h>

namespace torch {
namespace lazy {

namespace {
std::atomic<const BackendImplInterface*> backend_impl_registry;
} // namespace

bool hasBackend() {
  return !!backend_impl_registry.load();
}

const BackendImplInterface* getBackend() {
  auto* interface = backend_impl_registry.load();
  TORCH_CHECK(interface, "Lazy tensor backend not registered.");
  return interface;
}

BackendRegistrar::BackendRegistrar(
    const BackendImplInterface* backend_impl_interface) {
  backend_impl_registry.store(backend_impl_interface);
}

// Get IrBuilder from backend. Use TorchScriptIrBuilder by default
const IrBuilder* getIrBuilder() {
  static const IrBuilder* builder = getBackend()->GetIrBuilder();
  return builder;
}

at::Tensor MakeTensorFromComputationData(
    const BackendDataPtr data,
    c10::optional<at::ScalarType> logical_scalar_type) {
  return getBackend()->MakeTensorFromComputationData(data, logical_scalar_type);
}

std::unique_ptr<LoweringContext> LoweringContext::Create(
    const std::string& name,
    BackendDevice device,
    c10::ArrayRef<Node*> post_order,
    Util::EmissionMap emit_status) {
  return getBackend()->CreateLoweringContext(
      name, device, post_order, emit_status);
}

std::unique_ptr<LoweringContext> LoweringContext::Create(
    const std::string& name,
    BackendDevice device) {
  return getBackend()->CreateLoweringContext(name, device);
}

} // namespace lazy
} // namespace torch
```
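The file follows a registry pattern: a process-wide `std::atomic` pointer holds the active backend, a registrar object stores into it when constructed, `getBackend()` fails loudly if nothing is registered, and `getIrBuilder()` caches the builder in a function-local static. The sketch below mirrors that structure with hypothetical names (`Backend`, `Builder`, `Registrar`, `GetBackendOrThrow`, `FakeBackend`) and throws `std::runtime_error` in place of `TORCH_CHECK`; it is an illustration, not the LTC implementation.

```cpp
#include <atomic>
#include <stdexcept>

// Hypothetical stand-ins for IrBuilder and BackendImplInterface.
struct Builder {
  int id;
};

struct Backend {
  virtual ~Backend() = default;
  virtual const Builder* GetBuilder() const = 0;
};

namespace {
// Process-wide slot for the active backend; nullptr until one registers.
std::atomic<const Backend*> registry{nullptr};
}  // namespace

bool HasBackend() { return registry.load() != nullptr; }

// Analogous to getBackend(): refuse to continue without a backend.
const Backend* GetBackendOrThrow() {
  const Backend* b = registry.load();
  if (b == nullptr) {
    throw std::runtime_error("Lazy tensor backend not registered.");
  }
  return b;
}

// Analogous to BackendRegistrar: constructing one registers a backend.
struct Registrar {
  explicit Registrar(const Backend* b) { registry.store(b); }
};

// Analogous to getIrBuilder(): the function-local static is initialized
// once, on the first call (thread-safe since C++11), so a backend
// registered afterwards does not change the cached builder.
const Builder* GetBuilder() {
  static const Builder* builder = GetBackendOrThrow()->GetBuilder();
  return builder;
}

// Hypothetical concrete backend used to exercise the sketch.
struct FakeBackend : Backend {
  explicit FakeBackend(int id) : builder{id} {}
  const Builder* GetBuilder() const override { return &builder; }
  Builder builder;
};
```

One design consequence worth noting: because of the function-local static, the builder should be requested only after the intended backend has registered, since the first successful call fixes the builder for the rest of the process.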