mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor][subgraph] Plumbing to get ShapeAsConstantBuffer from subgraph to main graph output (#147559)
I am unable to create a test case that fails without the next PR. The idea is to have a symint which is returned by the inner subgraph and then returned by the forward graph after partitioning. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147559 Approved by: https://github.com/eellison
This commit is contained in:
parent
c87097e74a
commit
fd16311e7f
|
|
@ -1269,6 +1269,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
sympy.logic.boolalg.Boolean,
|
sympy.logic.boolalg.Boolean,
|
||||||
int,
|
int,
|
||||||
ir.EffectfulKernel,
|
ir.EffectfulKernel,
|
||||||
|
ir.ShapeAsConstantBuffer,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
for x in result
|
for x in result
|
||||||
|
|
|
||||||
|
|
@ -208,6 +208,7 @@ def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
|
||||||
Expr,
|
Expr,
|
||||||
int,
|
int,
|
||||||
EffectfulKernel,
|
EffectfulKernel,
|
||||||
|
ShapeAsConstantBuffer,
|
||||||
),
|
),
|
||||||
), (
|
), (
|
||||||
f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
|
f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
|
||||||
|
|
@ -5166,7 +5167,7 @@ class ExternKernel(InputsKernel):
|
||||||
# TODO(jansel): impose layout preference on realized buffer
|
# TODO(jansel): impose layout preference on realized buffer
|
||||||
x.realize()
|
x.realize()
|
||||||
return x
|
return x
|
||||||
if isinstance(x, (NonTensorObj)):
|
if isinstance(x, (NonTensorObj, ShapeAsConstantBuffer)):
|
||||||
return x
|
return x
|
||||||
return cls.copy_input(x)
|
return cls.copy_input(x)
|
||||||
|
|
||||||
|
|
@ -7041,6 +7042,8 @@ class MutableBox(IRNode):
|
||||||
class TensorBox(MutableBox):
|
class TensorBox(MutableBox):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(data): # type: ignore[no-untyped-def]
|
def create(data): # type: ignore[no-untyped-def]
|
||||||
|
if isinstance(data, ShapeAsConstantBuffer):
|
||||||
|
return data
|
||||||
return TensorBox(StorageBox(data))
|
return TensorBox(StorageBox(data))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user