mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add shape_env guards to tracing context (#90876)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90876 Approved by: https://github.com/Chillee, https://github.com/ezyang
This commit is contained in:
parent
a01c1ee594
commit
53e71fad8f
|
|
@ -196,13 +196,24 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
||||||
super(OutputGraph, self).__init__()
|
super(OutputGraph, self).__init__()
|
||||||
self.graph = torch.fx.Graph()
|
self.graph = torch.fx.Graph()
|
||||||
self.graphargs: List[GraphArg] = []
|
self.graphargs: List[GraphArg] = []
|
||||||
self.tracing_context: TracingContext = TracingContext()
|
fake_mode = torch._subclasses.FakeTensorMode(
|
||||||
|
throw_on_data_dependent_ops=True,
|
||||||
|
shape_env=ShapeEnv() if config.dynamic_shapes else None,
|
||||||
|
)
|
||||||
|
self.tracing_context: TracingContext = TracingContext(fake_mode)
|
||||||
# tracked_fakes says where any tensor that was wrapped to fake came
|
# tracked_fakes says where any tensor that was wrapped to fake came
|
||||||
# from. It is similar to GraphArg, in that all GraphArgs will get
|
# from. It is similar to GraphArg, in that all GraphArgs will get
|
||||||
# will get added to TrackedFakes, but TrackedFakes also contains
|
# will get added to TrackedFakes, but TrackedFakes also contains
|
||||||
# GraphArgs that got pruned, and things like Tensor attributes which
|
# GraphArgs that got pruned, and things like Tensor attributes which
|
||||||
# aren't explicit graph inputs. Used by shape guard
|
# aren't explicit graph inputs. Used by shape guard
|
||||||
self.tracked_fakes: List[TrackedFake] = []
|
self.tracked_fakes: List[TrackedFake] = []
|
||||||
|
# Although we prune unused graphargs before sending graphs to
|
||||||
|
# compilers, we may have legitimately triggered shape guards
|
||||||
|
# on "unused" inputs that we must keep track of. So after
|
||||||
|
# remove_unused_graphargs is called, orig_graphargs and
|
||||||
|
# graphargs no longer alias; orig_graphargs is the original
|
||||||
|
# graphargs, and graphargs is the pruned list. Guard creation
|
||||||
|
# should use original graphargs.
|
||||||
self.orig_graphargs: List[GraphArg] = self.graphargs
|
self.orig_graphargs: List[GraphArg] = self.graphargs
|
||||||
self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
|
self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
|
||||||
self.side_effects = SideEffects()
|
self.side_effects = SideEffects()
|
||||||
|
|
@ -228,7 +239,6 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
||||||
self.unspec_variable_map: Dict[
|
self.unspec_variable_map: Dict[
|
||||||
str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable]
|
str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable]
|
||||||
] = {}
|
] = {}
|
||||||
self.shape_env = ShapeEnv() if config.dynamic_shapes else None
|
|
||||||
self.intermediary_symbols: Dict[sympy.Expr, None] = {}
|
self.intermediary_symbols: Dict[sympy.Expr, None] = {}
|
||||||
|
|
||||||
# Enables creating unique node names by tracking
|
# Enables creating unique node names by tracking
|
||||||
|
|
@ -245,6 +255,10 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
||||||
def fake_mode(self):
|
def fake_mode(self):
|
||||||
return self.root_tx.fake_mode
|
return self.root_tx.fake_mode
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape_env(self):
|
||||||
|
return self.tracing_context.fake_mode.shape_env
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guards(self) -> Set[Guard]:
|
def guards(self) -> Set[Guard]:
|
||||||
return self.tracing_context.guards_context.dynamo_guards
|
return self.tracing_context.guards_context.dynamo_guards
|
||||||
|
|
|
||||||
|
|
@ -1577,10 +1577,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
||||||
# Flag to indicate whether tracing is used for export.
|
# Flag to indicate whether tracing is used for export.
|
||||||
self.export = export
|
self.export = export
|
||||||
|
|
||||||
self._fake_mode = torch._subclasses.FakeTensorMode(
|
self._fake_mode = output.tracing_context.fake_mode
|
||||||
throw_on_data_dependent_ops=True,
|
|
||||||
shape_env=output.shape_env,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.checkpoint = None
|
self.checkpoint = None
|
||||||
self.random_calls = []
|
self.random_calls = []
|
||||||
|
|
|
||||||
|
|
@ -285,8 +285,9 @@ class TracingContext:
|
||||||
def get() -> Optional["TracingContext"]:
|
def get() -> Optional["TracingContext"]:
|
||||||
return _CURRENT_TRACING_CONTEXT
|
return _CURRENT_TRACING_CONTEXT
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, fake_mode):
|
||||||
self.guards_context = GuardsContext()
|
self.guards_context = GuardsContext()
|
||||||
|
self.fake_mode = fake_mode
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -753,7 +753,7 @@ class ShapeEnv(object):
|
||||||
try:
|
try:
|
||||||
exprs.append(ShapeGuardPrinter(symbol_to_source).doprint(g))
|
exprs.append(ShapeGuardPrinter(symbol_to_source).doprint(g))
|
||||||
except Exception:
|
except Exception:
|
||||||
log.warning(f"Failing guard allocated at:\n{tb}")
|
log.warning(f"Failing guard allocated at: \n{tb}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 3. Every symbol must not be equal to 0/1
|
# 3. Every symbol must not be equal to 0/1
|
||||||
|
|
@ -821,7 +821,7 @@ class ShapeEnv(object):
|
||||||
return bindings
|
return bindings
|
||||||
|
|
||||||
def get_nontrivial_guards(self):
|
def get_nontrivial_guards(self):
|
||||||
return [self.simplify(guard) for guard, _ in self.guards if self._maybe_evaluate_static(guard) is None]
|
return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr) is None]
|
||||||
|
|
||||||
def format_guards(self, verbose=False):
|
def format_guards(self, verbose=False):
|
||||||
def format_tb(tb):
|
def format_tb(tb):
|
||||||
|
|
@ -829,7 +829,7 @@ class ShapeEnv(object):
|
||||||
return ""
|
return ""
|
||||||
return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
|
return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
|
||||||
|
|
||||||
return '\n'.join(f" - {guard}{format_tb(tb)}" for guard, tb in self.guards)
|
return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)
|
||||||
|
|
||||||
def get_shape_groups(self):
|
def get_shape_groups(self):
|
||||||
shape_groups = collections.defaultdict(list)
|
shape_groups = collections.defaultdict(list)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user