mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Next stage of breaking up https://github.com/pytorch/pytorch/pull/74710: an IR builder class is introduced to decouple the explicit usage of `TsNode` in core lazy tensors. Requires https://github.com/pytorch/pytorch/pull/75324 to be merged first. **Background** - there are ~5 special ops used in lazy core but defined as `: public {Backend}Node` (DeviceData, Expand, Scalar, ...) - we currently require all nodes to derive from {Backend}Node, so that backends can safely make this assumption - it is hard to have shared 'IR classes' in core/ because they depend on 'Node' **Motivation** 1. avoid copy-pasting the "special" node classes for each backend 2. in general, decouple and remove all dependencies that LTC has on the TS backend **Summary of changes** - new 'IRBuilder' interface that knows how to make the 5 special ops - move the 'special' node classes to `ts_backend/` - implement TSIRBuilder, which makes the special TS Nodes - new backend interface API to get the IRBuilder - update core code to call the builder CC: @wconstab @JackCaoG @henrytwo Partially Fixes #74628 Pull Request resolved: https://github.com/pytorch/pytorch/pull/75433 Approved by: https://github.com/wconstab
55 lines
1.5 KiB
C++
55 lines
1.5 KiB
C++
#include <torch/csrc/lazy/backend/backend_interface.h>

#include <utility>
|
|
|
|
namespace torch {
|
|
namespace lazy {
|
|
|
|
namespace {
|
|
std::atomic<const BackendImplInterface*> backend_impl_registry;
|
|
} // namespace
|
|
|
|
bool hasBackend() {
|
|
return !!backend_impl_registry.load();
|
|
}
|
|
|
|
// Returns the currently registered lazy-tensor backend.
// Raises via TORCH_CHECK ("Lazy tensor backend not registered.") when the
// registry slot is empty, so a pointer returned from here is non-null.
const BackendImplInterface* getBackend() {
  // The local is deliberately not named `interface`. Windows SDK headers
  // (combaseapi.h / objbase.h) define `interface` as a macro expanding to
  // `struct`, which breaks any identifier spelled that way whenever those
  // headers end up in the include chain.
  const BackendImplInterface* backend = backend_impl_registry.load();
  TORCH_CHECK(backend, "Lazy tensor backend not registered.");
  return backend;
}
|
|
|
|
// Publishes `backend_impl_interface` as the process-wide lazy-tensor backend.
// Each construction overwrites whatever pointer was stored before. The pointer
// is stored as-is; nothing in this file takes ownership of it or frees it.
BackendRegistrar::BackendRegistrar(
    const BackendImplInterface* backend_impl_interface) {
  // Atomic assignment is a sequentially-consistent store.
  backend_impl_registry = backend_impl_interface;
}
|
|
|
|
// Get IrBuilder from backend. Use TorchScriptIrBuilder by default
|
|
const IrBuilder* getIrBuilder() {
|
|
static const IrBuilder* builder = getBackend()->GetIrBuilder();
|
|
return builder;
|
|
}
|
|
|
|
|
|
// Builds an at::Tensor from `data` by delegating to the registered backend's
// MakeTensorFromComputationData. `logical_scalar_type` is forwarded unchanged.
// Raises (via getBackend) if no backend is registered.
at::Tensor MakeTensorFromComputationData(
    const BackendDataPtr data,
    c10::optional<at::ScalarType> logical_scalar_type) {
  const BackendImplInterface* const backend = getBackend();
  return backend->MakeTensorFromComputationData(data, logical_scalar_type);
}
|
|
|
|
std::unique_ptr<LoweringContext> LoweringContext::Create(
|
|
const std::string& name,
|
|
BackendDevice device,
|
|
c10::ArrayRef<Node*> post_order,
|
|
Util::EmissionMap emit_status) {
|
|
return getBackend()->CreateLoweringContext(
|
|
name, device, post_order, emit_status);
|
|
}
|
|
|
|
std::unique_ptr<LoweringContext> LoweringContext::Create(
|
|
const std::string& name,
|
|
BackendDevice device) {
|
|
return getBackend()->CreateLoweringContext(name, device);
|
|
}
|
|
|
|
} // namespace lazy
|
|
} // namespace torch
|