[NNC] Update Buf on mutation instead of creating new ones (#57513)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57513

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D28226917

Pulled By: navahgar

fbshipit-source-id: 4e74c56a85b7aadc285b872b8ef8f8e26f31c8ce
This commit is contained in:
Raghavan Raman 2021-05-06 01:05:49 -07:00 committed by Facebook GitHub Bot
parent 95fbc158d4
commit eef72f3f8a
3 changed files with 20 additions and 6 deletions

View File

@ -153,6 +153,10 @@ class TORCH_API Var : public ExprNode<Var> {
return name_hint_;
}
void set_name_hint(const std::string& name_hint) {
name_hint_ = name_hint;
}
Var(std::string name_hint, Dtype dtype)
: ExprNodeBase(dtype, kPrimitive), name_hint_(std::move(name_hint)) {}
@ -172,9 +176,16 @@ class TORCH_API Buf : public ExprNode<Buf> {
const Var* base_handle() const {
return base_handle_;
}
void set_base_handle(Var* base_handle) {
base_handle_ = base_handle;
}
const std::string& name_hint() const {
return base_handle_->name_hint();
}
void set_name_hint(const std::string& name_hint) {
base_handle_->set_name_hint(name_hint);
}
Buf(const std::string& name_hint,
const std::vector<const Expr*>& dims,
@ -182,7 +193,7 @@ class TORCH_API Buf : public ExprNode<Buf> {
const Expr* initializer = nullptr)
: Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {}
Buf(const Var* var,
Buf(Var* var,
std::vector<const Expr*> dims,
Dtype dtype,
const Expr* initializer = nullptr)
@ -223,7 +234,7 @@ class TORCH_API Buf : public ExprNode<Buf> {
}
private:
const Var* base_handle_;
Var* base_handle_;
std::vector<const Expr*> dims_;
const Expr* initializer_;
};

View File

@ -184,9 +184,10 @@ const Expr* IRMutator::mutate(const Load* v) {
return new Load(dtype, buf_new, indices_new);
}
const Expr* IRMutator::mutate(const Buf* v) {
const Expr* IRMutator::mutate(Buf* v) {
const Var* var = v->base_handle();
const Var* var_new = dynamic_cast<const Var*>(var->accept_mutator(this));
Var* var_new =
dynamic_cast<Var*>(const_cast<Expr*>(var->accept_mutator(this)));
if (!var_new) {
return nullptr;
}
@ -203,7 +204,9 @@ const Expr* IRMutator::mutate(const Buf* v) {
return (Expr*)v;
}
return new Buf(var_new, dims_new, v->dtype());
v->set_base_handle(var_new);
v->set_dims(dims_new);
return v;
}
const Expr* IRMutator::mutate(const Broadcast* v) {

View File

@ -77,7 +77,7 @@ class TORCH_API IRMutator {
virtual const Expr* mutate(const Cast* v);
virtual const Expr* mutate(const BitCast* v);
virtual const Expr* mutate(const Var* v);
virtual const Expr* mutate(const Buf* v);
virtual const Expr* mutate(Buf* v);
virtual const Expr* mutate(const Ramp* v);
virtual const Expr* mutate(const Load* v);
virtual const Expr* mutate(const Broadcast* v);