mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[NNC] Update Buf on mutation instead of creating new ones (#57513)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57513 Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D28226917 Pulled By: navahgar fbshipit-source-id: 4e74c56a85b7aadc285b872b8ef8f8e26f31c8ce
This commit is contained in:
parent
95fbc158d4
commit
eef72f3f8a
|
|
@ -153,6 +153,10 @@ class TORCH_API Var : public ExprNode<Var> {
|
|||
return name_hint_;
|
||||
}
|
||||
|
||||
void set_name_hint(const std::string& name_hint) {
|
||||
name_hint_ = name_hint;
|
||||
}
|
||||
|
||||
Var(std::string name_hint, Dtype dtype)
|
||||
: ExprNodeBase(dtype, kPrimitive), name_hint_(std::move(name_hint)) {}
|
||||
|
||||
|
|
@ -172,9 +176,16 @@ class TORCH_API Buf : public ExprNode<Buf> {
|
|||
const Var* base_handle() const {
|
||||
return base_handle_;
|
||||
}
|
||||
void set_base_handle(Var* base_handle) {
|
||||
base_handle_ = base_handle;
|
||||
}
|
||||
|
||||
const std::string& name_hint() const {
|
||||
return base_handle_->name_hint();
|
||||
}
|
||||
void set_name_hint(const std::string& name_hint) {
|
||||
base_handle_->set_name_hint(name_hint);
|
||||
}
|
||||
|
||||
Buf(const std::string& name_hint,
|
||||
const std::vector<const Expr*>& dims,
|
||||
|
|
@ -182,7 +193,7 @@ class TORCH_API Buf : public ExprNode<Buf> {
|
|||
const Expr* initializer = nullptr)
|
||||
: Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {}
|
||||
|
||||
Buf(const Var* var,
|
||||
Buf(Var* var,
|
||||
std::vector<const Expr*> dims,
|
||||
Dtype dtype,
|
||||
const Expr* initializer = nullptr)
|
||||
|
|
@ -223,7 +234,7 @@ class TORCH_API Buf : public ExprNode<Buf> {
|
|||
}
|
||||
|
||||
private:
|
||||
const Var* base_handle_;
|
||||
Var* base_handle_;
|
||||
std::vector<const Expr*> dims_;
|
||||
const Expr* initializer_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -184,9 +184,10 @@ const Expr* IRMutator::mutate(const Load* v) {
|
|||
return new Load(dtype, buf_new, indices_new);
|
||||
}
|
||||
|
||||
const Expr* IRMutator::mutate(const Buf* v) {
|
||||
const Expr* IRMutator::mutate(Buf* v) {
|
||||
const Var* var = v->base_handle();
|
||||
const Var* var_new = dynamic_cast<const Var*>(var->accept_mutator(this));
|
||||
Var* var_new =
|
||||
dynamic_cast<Var*>(const_cast<Expr*>(var->accept_mutator(this)));
|
||||
if (!var_new) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -203,7 +204,9 @@ const Expr* IRMutator::mutate(const Buf* v) {
|
|||
return (Expr*)v;
|
||||
}
|
||||
|
||||
return new Buf(var_new, dims_new, v->dtype());
|
||||
v->set_base_handle(var_new);
|
||||
v->set_dims(dims_new);
|
||||
return v;
|
||||
}
|
||||
|
||||
const Expr* IRMutator::mutate(const Broadcast* v) {
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class TORCH_API IRMutator {
|
|||
virtual const Expr* mutate(const Cast* v);
|
||||
virtual const Expr* mutate(const BitCast* v);
|
||||
virtual const Expr* mutate(const Var* v);
|
||||
virtual const Expr* mutate(const Buf* v);
|
||||
virtual const Expr* mutate(Buf* v);
|
||||
virtual const Expr* mutate(const Ramp* v);
|
||||
virtual const Expr* mutate(const Load* v);
|
||||
virtual const Expr* mutate(const Broadcast* v);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user