Make assertions refine types (#23949)

Summary:
Make assertions like `x is not None` refine the type of x. This is easy to do now that typing understands [exits](https://github.com/pytorch/pytorch/pull/23565).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23949

Differential Revision: D16692772

Pulled By: eellison

fbshipit-source-id: 540f28e65a784c72c7c555e0aed0765d5035bc37
This commit is contained in:
Elias Ellison 2019-08-07 12:58:59 -07:00 committed by Facebook Github Bot
parent 0f5d071d52
commit e90adf59a0
3 changed files with 23 additions and 11 deletions

View File

@ -5094,6 +5094,17 @@ a")
if y is not None and x_none:
print(x + y) # noqa: T484
def test_assertion_optional_refinement(self):
@torch.jit.script
def test(x, y):
# type: (Optional[int], Optional[int]) -> int
assert x is not None and y is not None
return x + y
self.assertEqual(test(2, 2), 4)
with self.assertRaisesRegex(Exception, ""):
test(1, None)
def test_optional_tensor(self):
@torch.jit.script
def fn(x, y):

View File

@ -1488,19 +1488,16 @@ struct to_ir {
exit_blocks.insert(environment_stack->block());
}
// emit assserions as an if branch so that assertions will reuse the
// emitIfElseBlocks refining of types
void emitAssert(const Assert& stmt) {
Value* cond_value = emitCond(stmt.test());
Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
n->addInput(cond_value);
/* true_block =*/n->addBlock();
auto* false_block = n->addBlock();
// if assert test is false throw exception
pushFrame(false_block);
WithInsertPoint guard(false_block);
emitRaise(stmt.range());
popFrame();
List<Stmt> true_branch = List<Stmt>::create(stmt.range(), {});
List<Stmt> false_branch =
List<Stmt>::create(stmt.range(), {Raise::create(stmt.range())});
auto if_stmt =
If::create(stmt.range(), stmt.test(), true_branch, false_branch);
emitIfElseBlocks(cond_value, if_stmt);
}
// Validate that the `lhs` Expr's in an assignment statement are valid. That

View File

@ -659,6 +659,10 @@ struct Raise : public Stmt {
static Raise create(const SourceRange& range, const Maybe<Expr>& expr) {
return Raise(Compound::create(TK_RAISE, range, {expr}));
}
static Raise create(const SourceRange& range) {
return Raise(
Compound::create(TK_RAISE, range, {Maybe<Expr>::create(range)}));
}
};
struct Assert : public Stmt {