Add support for parsing torch.Generator in JIT (#140489)

Fixes #140420

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140489
Approved by: https://github.com/davidberard98
This commit is contained in:
Antonio Kim 2024-11-13 23:06:54 +00:00 committed by PyTorch MergeBot
parent 70060b0927
commit b34bb1f562
2 changed files with 100 additions and 0 deletions

View File

@ -14184,6 +14184,43 @@ dedent """
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
def test_parse_generator(self):
def _test_parse_generator(seed):
jit_graph = parse_ir(
f"""
graph():
%0 : float = prim::Constant[value=-0.31622776601683789]()
%1 : float = prim::Constant[value=0.31622776601683789]()
%2 : Generator = prim::Constant[value=torch.Generator(device="cpu", seed={seed})]()
%3 : NoneType = prim::Constant()
%4 : int[] = prim::Constant[value=[]]()
%5 : int = prim::Constant[value=6]()
%6 : Device = prim::Constant[value="cpu"]()
%7 : Tensor = aten::empty(%4, %5, %3, %6, %3, %3)
%8 : Float() = aten::uniform(%7, %0, %1, %2)
return (%8)
""",
)
node = next(
n
for n in jit_graph.nodes()
if isinstance(n.output().type(), torch._C._GeneratorType)
)
assert isinstance(node.output().type(), torch._C._GeneratorType)
g = node.ival("value")
assert isinstance(g, torch.Generator)
self.assertEqual(g.initial_seed(), seed)
_test_parse_generator(2024)
_test_parse_generator(2**63 - 1)
with self.assertRaisesRegex(RuntimeError, "Seed must be a non-negative integer"):
_test_parse_generator(-2024)
with self.assertRaisesRegex(RuntimeError, "Number is too big"):
_test_parse_generator(2**63)
def test_early_return_rewrite(self):
def test_foo(x: bool):
if x:

View File

@ -427,6 +427,69 @@ void IRParser::parseAttr(Node* n) {
}
L.expect(')');
deferred_empty_container_initializations_.push_back(n);
} else if (L.cur().text() == "torch") {
L.next();
L.expect('.');
auto function = L.cur().text();
if (function == "Generator") {
L.next();
L.expect('(');
std::optional<uint64_t> seed;
std::string device = "cpu";
while (!L.nextIf(')')) {
auto arg = L.expect(TK_IDENT).text();
L.expect('=');
if (arg == "device") {
ParsedLiteral r = parseScalarLiteral(n);
if (r.k != AttributeKind::s) {
throw(
ErrorReport(L.cur().range)
<< "Expected string literal for device argument");
}
if (r.s != "cpu") {
throw(
ErrorReport(L.cur().range)
<< "Only cpu device is supported for Generator at this time.");
}
device = r.s;
} else if (arg == "seed") {
ParsedLiteral r = parseScalarLiteral(n);
if (r.k != AttributeKind::i) {
throw(
ErrorReport(L.cur().range)
<< "Expected int literal for seed argument");
}
if (r.i < 0) {
throw(
ErrorReport(L.cur().range)
<< "Seed must be a non-negative integer");
}
seed = r.i;
} else {
throw(
ErrorReport(L.cur().range)
<< "Generator only supports the following arguments:\n"
<< "- device\n"
<< "- seed\n"
<< "Got: " << arg);
}
L.nextIf(',');
}
if (device == "cpu") {
if (seed.has_value()) {
n->ival_(
Symbol::attr(attrname), at::detail::createCPUGenerator(*seed));
} else {
n->ival_(Symbol::attr(attrname), at::detail::createCPUGenerator());
}
}
} else {
throw(
ErrorReport(L.cur().range)
<< "Expected one of the following torch functions:\n"
<< "- Generator\n"
<< "Got: " << function);
}
} else {
// scalar
ParsedLiteral r = parseScalarLiteral(n);