mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62336 This PR was generated by removing `const` for all types of nodes in NNC IR, and fixing compilation errors that were the result of this change. This is the first step in making all NNC mutations in-place. Test Plan: Imported from OSS Reviewed By: iramazanli Differential Revision: D30049829 Pulled By: navahgar fbshipit-source-id: ed14e2d2ca0559ffc0b92ac371f405579c85dd63
82 lines · 2.2 KiB · C++
#include <torch/csrc/jit/tensorexpr/codegen.h>
|
|
|
|
#include <sstream>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace tensorexpr {
|
|
|
|
// Look up the statement-codegen factory registered under `name`.
//
// Returns a copy of the stored factory. If no factory is registered under
// `name`, throws std::runtime_error whose message lists every registered
// codegen name, comma-separated, in the map's iteration order.
RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
    FindStmtFactoryMethod(const std::string& name) {
  const auto found = stmt_factory_methods_.find(name);
  if (found != stmt_factory_methods_.end()) {
    return found->second;
  }

  // Unknown name: build a diagnostic that shows what *is* available.
  std::ostringstream message;
  message << "Invalid stmt codegen name: " << name << ". ";
  message << "Existing codegen names: [";
  // Empty before the first entry, ", " before every later one.
  const char* separator = "";
  for (const auto& registered : stmt_factory_methods_) {
    message << separator << registered.first;
    separator = ", ";
  }
  message << "]";
  throw std::runtime_error(message.str());
}
|
|
|
|
void RegisterCodeGenList::AddStmtFactoryMethod(
|
|
const std::string& name,
|
|
const StmtFactoryMethod& stmt_factory_method) {
|
|
stmt_factory_methods_[name] = stmt_factory_method;
|
|
}
|
|
|
|
std::unique_ptr<CodeGen> CreateCodeGen(
|
|
const std::string& name,
|
|
Stmt* stmt,
|
|
const std::vector<CodeGen::BufferArg>& params,
|
|
at::Device device,
|
|
const std::string& kernel_func_name) {
|
|
RegisterCodeGenList::StmtFactoryMethod method =
|
|
RegisterCodeGenList::GetInstance().FindStmtFactoryMethod(name);
|
|
return method(stmt, params, device, kernel_func_name);
|
|
}
|
|
|
|
// Expand the sigmoid intrinsic into elementary arithmetic; every other
// intrinsic falls through to the default IRMutator handling.
//
// sigmoid(x) is rewritten as 1 / (1 + exp(0 - x)). The constants 1 and 0 are
// created with the intrinsic's dtype and widened via expr_to_vec to its lane
// count. The argument is mutated first, so nested intrinsics are expanded too.
Expr* GenericIntrinsicsExpander::mutate(Intrinsics* v) {
  if (v->op_type() != kSigmoid) {
    return IRMutator::mutate(v);
  }

  auto arg = v->param(0)->accept_mutator(this);
  const auto dtype = v->dtype();
  const auto lanes = dtype.lanes();
  auto one = expr_to_vec(ExprHandle(getImmediateByType(dtype, 1.0)), lanes);
  auto zero = expr_to_vec(ExprHandle(getImmediateByType(dtype, 0.0)), lanes);
  ExprHandle expanded = one / (one + exp(zero - ExprHandle(arg)));
  return expanded.node();
}
|
|
|
|
void* CodeGen::argToPtr(const BufferArg& bufferArg, const CallArg& callArg) {
|
|
if (!bufferArg.isVar()) {
|
|
return callArg.data();
|
|
}
|
|
|
|
switch (bufferArg.dtype().scalar_type()) {
|
|
#define TYPE_CASE(_1, Name) \
|
|
case ScalarType::Name: \
|
|
return callArg.Name##Ptr();
|
|
|
|
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
|
|
#undef TYPE_CASE
|
|
|
|
default:
|
|
throw unsupported_dtype();
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
} // namespace tensorexpr
|
|
} // namespace jit
|
|
} // namespace torch
|