This is a little awkward, so I wouldn't mind more thoughts on how best to do it.
Suppose you have a graph break inside an inlined function call. We don't actually print this graph break right away; instead, we restart analysis so that we can run up until the inlined function call. When this happens, the only log message we ever get is the `graph_breaks` log (seen here) reporting that a graph break has occurred.
In the current code, we don't print the fully formatted exception if you only have `graph_breaks` logging enabled, so the traceback of the exception that induced the graph break is lost forever. For some classes of errors, especially guards on data-dependent SymInts, this is quite bad.
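For context, the `graph_breaks` artifact can be enabled on its own, without the rest of Dynamo's debug output, via `TORCH_LOGS="graph_breaks"` or equivalently through the `torch._logging` API:
```
import torch._logging

# Enable only the graph_breaks artifact log; equivalent to running with
# TORCH_LOGS="graph_breaks" set in the environment.
torch._logging.set_logs(graph_breaks=True)
```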
With this change, we do print the traceback. On this sample program:
```
import torch
import torch._dynamo.config
torch._dynamo.config.capture_scalar_outputs = True
def g(x, y):
    y = x.item()
    if y < 3:
        return x + 2
    else:
        return x + 3

@torch.compile()
def f(x, y):
    y = y * y
    return g(x, y)
f(torch.tensor(4), torch.randn(4))
```
It looks like this:
```
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Graph break: Traceback (most recent call last):
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/tensor.py", line 878, in evaluate_expr
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return guard_scalar(self.sym_num)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 414, in guard_scalar
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return guard_bool(a)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 663, in guard_bool
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return a.node.guard_bool("", 0) # NB: uses Python backtrace
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/sym_node.py", line 366, in guard_bool
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/recording.py", line 227, in wrapper
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return fn(*args, **kwargs)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3670, in evaluate_expr
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] concrete_val = self.size_hint(orig_expr)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3403, in size_hint
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] raise self._make_data_dependent_error(result_expr, expr)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: It appears that you're trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is u0 < 3 (unhinted: u0 < 3). For more information, run with TORCH_LOGS="+dynamic".
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] During handling of the above exception, another exception occurred:
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Traceback (most recent call last):
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 469, in wrapper
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return inner_fn(self, inst)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 1196, in CALL_FUNCTION
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] self.call_function(fn, args, {})
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 651, in call_function
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] self.push(fn.call_function(self, args, kwargs))
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/functions.py", line 279, in call_function
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return super().call_function(tx, args, kwargs)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/functions.py", line 87, in call_function
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return tx.inline_user_function_return(
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in inline_user_function_return
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 2262, in inline_call
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return cls.inline_call_(parent, func, args, kwargs)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 2372, in inline_call_
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] tracer.run()
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 787, in run
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] and self.step()
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 750, in step
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] getattr(self, inst.opname)(inst)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 431, in inner
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] eval_result = value.evaluate_expr(self.output)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/tensor.py", line 880, in evaluate_expr
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] raise UserError( # noqa: TRY200
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] torch._dynamo.exc.UserError: Consider annotating your code using torch._constrain_as_*(). It appears that you're trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is u0 < 3 (unhinted: u0 < 3). For more information, run with TORCH_LOGS="+dynamic".
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] From user code at:
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/b.py", line 16, in f
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] return g(x, y)
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] File "/data/users/ezyang/b/pytorch/b.py", line 8, in g
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] if y < 3:
[2024-02-06 10:32:24,334] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]
```
The end of the log for the restarted computation could probably be improved too. Right now it looks like this:
```
[2024-02-06 10:32:24,338] [0/0_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 2 [UserFunctionVariable(), LazyVariableTracker(), TensorVariable()]
[2024-02-06 10:32:24,338] [0/0_1] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='Consider annotating your code using torch._constrain_as_*(). It appears that you\'re trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is u0 < 3 (unhinted: u0 < 3). For more information, run with TORCH_LOGS="+dynamic".\n\nFor more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example', user_stack=[<FrameSummary file /data/users/ezyang/b/pytorch/b.py, line 16 in f>, <FrameSummary file /data/users/ezyang/b/pytorch/b.py, line 8 in g>], graph_break=True)
```
An alternative to doing it this way: I could make symbolic shapes itself print a warning log when we guard on an unbacked SymInt, so that we don't have to worry about Dynamo generating the backtrace well. If the backtrace is mostly irrelevant for other kinds of graph breaks, that would be the more expedient solution.
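For illustration, a minimal sketch of what that alternative could look like; the helper name and the hook point are hypothetical, not the actual ShapeEnv internals:
```
import logging
import traceback

log = logging.getLogger("torch.fx.experimental.symbolic_shapes")

# Hypothetical helper: symbolic shapes would call this at the point where it
# is about to guard on a data-dependent (unbacked) SymInt, so the user stack
# is captured even if Dynamo later swallows the exception on restart.
def _warn_data_dependent_guard(expr):
    log.warning(
        "guard on data-dependent expression %s at:\n%s",
        expr,
        "".join(traceback.format_stack()),
    )
```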
PTAL and submit your opinions.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119292
Approved by: https://github.com/yanboliang
336 lines · 9.3 KiB · Python
# Owner(s): ["module: dynamo"]
import logging
import unittest

import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
from torch._dynamo.comptime import comptime
from torch._dynamo.exc import Unsupported
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import IS_FBCODE, munge_exc, TEST_Z3
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


class ExcTests(LoggingTestCase):
    maxDiff = None

    def test_unsupported_real_stack(self):
        # exercise Unsupported constructor and augment_exc_message
        def fn002(x):
            torch._dynamo.graph_break()

        def fn001(x):
            x = x + 1
            fn002(x)

        self.assertExpectedInlineMunged(
            Unsupported,
            lambda: torch.compile(fn001, backend="eager", fullgraph=True)(
                torch.randn(1)
            ),
            """\
'skip function graph_break in file _dynamo/decorators.py'

from user code:
  File "test_exc.py", line N, in fn001
    fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()""",
        )

    @torch._dynamo.config.patch(verbose=True, suppress_errors=True)
    @make_logging_test()
    @unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode")
    def test_internal_error_suppress_errors(self, records):
        def fn001(x):
            def f(ctx):
                raise AssertionError()

            comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise AssertionError()
AssertionError:

from user code:
  File "test_exc.py", line N, in fn001
    comptime(f)


========== The above exception occurred while processing the following code ==========

  File "test_exc.py", line N, in test_internal_error_suppress_errors
    torch.compile(fn001, backend="eager")(torch.randn(1))
  File "test_exc.py", line N, in fn001
    comptime(f)

==========""",
        )

    @make_logging_test()
    def test_not_implemented_error(self, records):
        def fn001(x):
            def f(ctx):
                raise NotImplementedError()

            # Ensure graph break is not possible
            for i in range(3):
                comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise NotImplementedError()
torch._dynamo.exc.InternalTorchDynamoError:

from user code:
  File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @unittest.expectedFailure
    @torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True)
    @make_logging_test(dynamo=logging.DEBUG)
    def test_unsupported_error(self, records):
        def fn001(x):
            return {1, 2}

        torch.compile(fn001, backend="eager")(torch.randn(1))

        # TODO: There is no graph break log! This is because the graph break
        # logging is not in a centralized location; unsupported
        # instruction bypasses it
        self.getRecord(records, "Graph break:")

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_internal_error_no_suppress(self):
        def fn001(x):
            # NB: avoid decorator, as 3.11 changed the line number attributed
            # in this situation
            def f(ctx):
                raise AssertionError()

            comptime(f)

        # NB: OK for user code to be truncated here, because the regular
        # exception backtrace has the rest of the crumbs
        self.assertExpectedInlineMunged(
            AssertionError,
            lambda: torch.compile(fn001, backend="eager")(torch.randn(1)),
            """\


from user code:
  File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @make_logging_test(graph_breaks=True)
    def test_graph_break_log(self, records):
        def fn002(x):
            x = x + 1
            torch._dynamo.graph_break()
            x = x + 1
            return x

        def fn001(x):
            return fn002(x)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "Graph break:")

        # TODO: This should also report the enclosing frames; need to plumb
        # frame object to it
        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
Graph break: from user code at:
  File "test_exc.py", line N, in fn001
    return fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()
""",  # noqa: B950
        )

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_backend_suppress_line(self):
        def fn001(x):
            x = torch.relu(x)
            return x + 1

        # Do NOT let this get attributed to x + 1
        self.assertExpectedInlineMunged(
            torch._dynamo.exc.BackendCompilerFailed,
            lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")(
                torch.randn(1)
            ),
            """\
backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
        translation_validation_no_bisect=True,
    )
    def test_trigger_on_error(self):
        from torch.fx.experimental.validator import ValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            ValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed.

Model:
  ==> L['shape'][0]: 0
  ==> L['shape'][1]: 0
  ==> L['shape'][2]: 0
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 0
  ==> s2: 0
  ==> s3: 0

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)
  ==> (True)

Target Expressions:
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (== 0 L['shape'][0])
  ==> (== 0 L['shape'][1])
  ==> (== 0 L['shape'][2])
  ==> (== 0 L['x'].storage_offset())
  ==> (== 0 s1)
  ==> (== 0 s2)
  ==> (== 0 s3)
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)
  ==> (>= 9223372036854775806 s0)
  ==> (>= 9223372036854775806 s1)
  ==> (>= 9223372036854775806 s2)
  ==> (>= 9223372036854775806 s3)

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
    )
    def test_trigger_bisect_on_error(self):
        from torch.fx.experimental.validator import BisectValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            BisectValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)

Failure occurred while running node:
    %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})

Model:
  ==> L['shape'][0]: -9223372036854775807
  ==> L['shape'][1]: -9223372036854775807
  ==> L['shape'][2]: -9223372036854775807
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: -9223372036854775807
  ==> s2: -9223372036854775807
  ==> s3: -9223372036854775807

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= -9223372036854775808 s1)
  ==> (<= -9223372036854775808 s2)
  ==> (<= -9223372036854775808 s3)
  ==> (<= 2 s0)
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)
  ==> (>= 9223372036854775806 s0)
  ==> (>= 9223372036854775807 s1)
  ==> (>= 9223372036854775807 s2)
  ==> (>= 9223372036854775807 s3)

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()