Support C++ statically_known_true (#151346)
Differential Revision: [D73040543](https://our.internmc.facebook.com/intern/diff/D73040543/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151346
Approved by: https://github.com/laithsakka
This commit is contained in:
parent 8895c290f4
commit eb1f85a2a0
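For context, a minimal sketch of what the new C++ API is for, written against the headers this diff touches. It is not part of the change itself; skip_noop_slice and its arguments are illustrative. TORCH_STATICALLY_KNOWN_TRUE(cond) evaluates to true only when cond can be proven from facts the shape environment already knows, and to false otherwise; unlike TORCH_GUARD_SIZE_OBLIVIOUS, it never installs a new guard.

#include <optional>

#include <ATen/core/Tensor.h>
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>

// Returns the input unchanged when the requested slice is provably the whole
// dimension; std::nullopt tells the caller to perform the real slice.
std::optional<at::Tensor> skip_noop_slice(
    const at::Tensor& self,
    const c10::SymInt& start,
    const c10::SymInt& stop) {
  const c10::SymInt length = self.sym_size(0);
  // True only when provable without adding a guard; on unbacked or otherwise
  // unknown sizes it simply returns false instead of erroring or specializing.
  if (TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
      TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop))) {
    return self;
  }
  return std::nullopt;
}

This mirrors the applySlice change below, which takes the early-return path only when the slice is provably a no-op and otherwise records the real slice rather than guarding.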
@@ -222,8 +222,8 @@ inline Tensor applySlice(
         ? (*self_sizes)[dim]
         : self.sym_size(dim);
     if (!disable_slice_optimization &&
-        TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
-        TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
+        TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
+        TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) {
       return self;
     }
   }
@@ -80,6 +80,14 @@ bool SymBool::guard_or_false(const char* file, int64_t line) const {
   return a->guard_or_false(file, line);
 }
 
+bool SymBool::statically_known_true(const char* file, int64_t line) const {
+  if (auto ma = maybe_as_bool()) {
+    return *ma;
+  }
+  SymNode a = toSymNodeImpl();
+  return a->statically_known_true(file, line);
+}
+
 bool SymBool::guard_or_true(const char* file, int64_t line) const {
   if (auto ma = maybe_as_bool()) {
     return *ma;
@@ -1,3 +1,4 @@
 #pragma once
 
 #include <c10/core/SymNodeImpl.h>
@@ -62,6 +63,7 @@ class C10_API SymBool {
   bool guard_bool(const char* file, int64_t line) const;
   bool expect_true(const char* file, int64_t line) const;
   bool guard_size_oblivious(const char* file, int64_t line) const;
+  bool statically_known_true(const char* file, int64_t line) const;
   bool guard_or_false(const char* file, int64_t line) const;
   bool guard_or_true(const char* file, int64_t line) const;
@@ -129,6 +131,20 @@ inline bool guard_or_false(
   return b.guard_or_false(file, line);
 }
 
+inline bool statically_known_true(
+    bool b,
+    const char* file [[maybe_unused]],
+    int64_t line [[maybe_unused]]) {
+  return b;
+}
+
+inline bool statically_known_true(
+    const c10::SymBool& b,
+    const char* file,
+    int64_t line) {
+  return b.statically_known_true(file, line);
+}
+
 inline bool guard_or_true(
     bool b,
     const char* file [[maybe_unused]],
@@ -146,6 +162,9 @@ inline bool guard_or_true(
 #define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \
   c10::guard_size_oblivious((cond), __FILE__, __LINE__)
 
+#define TORCH_STATICALLY_KNOWN_TRUE(cond) \
+  c10::statically_known_true((cond), __FILE__, __LINE__)
+
 #define TORCH_GUARD_OR_FALSE(cond) \
   c10::guard_or_false((cond), __FILE__, __LINE__)
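Having both a plain-bool and a SymBool overload of c10::statically_known_true lets the macro be used uniformly regardless of which type the condition evaluates to, mirroring the existing TORCH_GUARD_SIZE_OBLIVIOUS / TORCH_GUARD_OR_FALSE pattern. A small sketch, not part of the diff (helper names are illustrative), assuming the c10 headers are available:

#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>

bool starts_at_zero_plain(int64_t start) {
  // `start == 0` is a plain bool, so the bool overload just returns it.
  return TORCH_STATICALLY_KNOWN_TRUE(start == 0);
}

bool starts_at_zero_symbolic(const c10::SymInt& start) {
  // `start.sym_eq(0)` is a c10::SymBool, so this dispatches to
  // SymBool::statically_known_true and is true only if provable without guarding.
  return TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0));
}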
@@ -1,3 +1,4 @@
 #pragma once
 
 #include <c10/macros/Export.h>
@@ -191,6 +192,11 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
     // with a better implementation!
     return guard_bool(file, line);
   }
+  virtual bool statically_known_true(const char* file, int64_t line) {
+    // No improvement for unbacked SymBools by default, replace this
+    // with a better implementation!
+    return guard_bool(file, line);
+  }
   virtual bool guard_or_true(const char* file, int64_t line) {
     // No improvement for unbacked SymBools by default, replace this
     // with a better implementation!
@@ -696,6 +696,30 @@ graph():
         ep = export(f, args, strict=False)
         self.assertEqual(ep.module()(*args), f(*args))
 
+    @testing.expectedFailureCppSerDes  # Cpp serder seems to fail parsing complicated guards
+    def test_export_statically_known_true(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x, y):
+                shape = y.shape[0] ** 2 - 3 * y.shape[0]
+                end = shape
+                return x[:, :end]
+
+        dynamic_shapes = (
+            (torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC),
+            (torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC),
+        )
+
+        ep = export(
+            Foo(),
+            (torch.randn(4, 4), torch.randn(4, 4)),
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+        FileCheck().check_count("torch.ops.aten.slice.Tensor", 2, exactly=True).run(
+            str(ep.graph)
+        )
+        FileCheck().check_count("operator.sub", 1, exactly=True).run(str(ep.graph))
+
     def test_colon_parameter(self):
         class M(torch.nn.Module):
             def __init__(self) -> None:
@@ -1218,7 +1218,7 @@ def forward(self, x_1):
         gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
         # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
         # 1 ok)
-        self.assertEqual(len(gm.shape_env.guards), 1)
+        self.assertEqual(len(gm.shape_env.guards), 0)
 
     @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
     def test_cpu_scalar_cuda(self):
@@ -140,6 +140,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
     return getPyObj().attr("guard_or_false")(file, line).cast<bool>();
   }
 
+  bool statically_known_true(const char* file, int64_t line) override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("statically_known_true")(file, line).cast<bool>();
+  }
+
   bool guard_or_true(const char* file, int64_t line) override {
     py::gil_scoped_acquire acquire;
     return getPyObj().attr("guard_or_true")(file, line).cast<bool>();
@@ -572,6 +572,12 @@ class SymNode:
             _advise_is_size(SymInt(self))
         return r
 
+    def statically_known_true(self, file, line):
+        from torch.fx.experimental.symbolic_shapes import statically_known_true
+
+        assert self.is_bool()
+        return statically_known_true(SymBool(self))
+
     def guard_size_oblivious(self, file, line):
         """
         Like guard_bool, but if we encounter unbacked symbols, if those symbols