mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make assertions refine types (#23949)
Summary: Make assertions like `x is not None` refine the type of x. This is easy to do now that typing understands [exits](https://github.com/pytorch/pytorch/pull/23565). Pull Request resolved: https://github.com/pytorch/pytorch/pull/23949 Differential Revision: D16692772 Pulled By: eellison fbshipit-source-id: 540f28e65a784c72c7c555e0aed0765d5035bc37
This commit is contained in:
parent
0f5d071d52
commit
e90adf59a0
|
|
@ -5094,6 +5094,17 @@ a")
|
|||
if y is not None and x_none:
|
||||
print(x + y) # noqa: T484
|
||||
|
||||
def test_assertion_optional_refinement(self):
|
||||
@torch.jit.script
|
||||
def test(x, y):
|
||||
# type: (Optional[int], Optional[int]) -> int
|
||||
assert x is not None and y is not None
|
||||
return x + y
|
||||
|
||||
self.assertEqual(test(2, 2), 4)
|
||||
with self.assertRaisesRegex(Exception, ""):
|
||||
test(1, None)
|
||||
|
||||
def test_optional_tensor(self):
|
||||
@torch.jit.script
|
||||
def fn(x, y):
|
||||
|
|
|
|||
|
|
@ -1488,19 +1488,16 @@ struct to_ir {
|
|||
exit_blocks.insert(environment_stack->block());
|
||||
}
|
||||
|
||||
// emit assserions as an if branch so that assertions will reuse the
|
||||
// emitIfElseBlocks refining of types
|
||||
void emitAssert(const Assert& stmt) {
|
||||
Value* cond_value = emitCond(stmt.test());
|
||||
Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
|
||||
|
||||
n->addInput(cond_value);
|
||||
/* true_block =*/n->addBlock();
|
||||
auto* false_block = n->addBlock();
|
||||
|
||||
// if assert test is false throw exception
|
||||
pushFrame(false_block);
|
||||
WithInsertPoint guard(false_block);
|
||||
emitRaise(stmt.range());
|
||||
popFrame();
|
||||
List<Stmt> true_branch = List<Stmt>::create(stmt.range(), {});
|
||||
List<Stmt> false_branch =
|
||||
List<Stmt>::create(stmt.range(), {Raise::create(stmt.range())});
|
||||
auto if_stmt =
|
||||
If::create(stmt.range(), stmt.test(), true_branch, false_branch);
|
||||
emitIfElseBlocks(cond_value, if_stmt);
|
||||
}
|
||||
|
||||
// Validate that the `lhs` Expr's in an assignment statement are valid. That
|
||||
|
|
|
|||
|
|
@ -659,6 +659,10 @@ struct Raise : public Stmt {
|
|||
static Raise create(const SourceRange& range, const Maybe<Expr>& expr) {
|
||||
return Raise(Compound::create(TK_RAISE, range, {expr}));
|
||||
}
|
||||
static Raise create(const SourceRange& range) {
|
||||
return Raise(
|
||||
Compound::create(TK_RAISE, range, {Maybe<Expr>::create(range)}));
|
||||
}
|
||||
};
|
||||
|
||||
struct Assert : public Stmt {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user