Add support for parsing torch.Generator in JIT (#140489)
Fixes #140420

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140489
Approved by: https://github.com/davidberard98
This commit is contained in:
parent 70060b0927
commit b34bb1f562
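For orientation, a minimal sketch of the syntax this change makes parseable (a sketch only, assuming torch._C.parse_ir as the entry point; the test hunk below exercises the same path through the parse_ir helper, and the graph text, variable names, and seed of 42 are illustrative):

    import torch

    # Parse an IR graph that embeds a torch.Generator constant.
    # Assumption: torch._C.parse_ir accepts the same IR text as the
    # parse_ir helper used in the test below.
    jit_graph = torch._C.parse_ir(
        """
        graph():
            %gen : Generator = prim::Constant[value=torch.Generator(device="cpu", seed=42)]()
            return (%gen)
        """
    )
    node = next(iter(jit_graph.nodes()))
    gen = node.ival("value")  # the parsed constant comes back as a torch.Generator
    assert isinstance(gen, torch.Generator)
    assert gen.initial_seed() == 42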
@@ -14184,6 +14184,43 @@ dedent """
        FileCheck().check_not("prim::PythonOp").run(cu.test.graph)

    def test_parse_generator(self):
        def _test_parse_generator(seed):
            jit_graph = parse_ir(
                f"""
                graph():
                    %0 : float = prim::Constant[value=-0.31622776601683789]()
                    %1 : float = prim::Constant[value=0.31622776601683789]()
                    %2 : Generator = prim::Constant[value=torch.Generator(device="cpu", seed={seed})]()
                    %3 : NoneType = prim::Constant()
                    %4 : int[] = prim::Constant[value=[]]()
                    %5 : int = prim::Constant[value=6]()
                    %6 : Device = prim::Constant[value="cpu"]()
                    %7 : Tensor = aten::empty(%4, %5, %3, %6, %3, %3)
                    %8 : Float() = aten::uniform(%7, %0, %1, %2)
                    return (%8)
                """,
            )

            node = next(
                n
                for n in jit_graph.nodes()
                if isinstance(n.output().type(), torch._C._GeneratorType)
            )
            assert isinstance(node.output().type(), torch._C._GeneratorType)
            g = node.ival("value")
            assert isinstance(g, torch.Generator)
            self.assertEqual(g.initial_seed(), seed)

        _test_parse_generator(2024)
        _test_parse_generator(2**63 - 1)

        with self.assertRaisesRegex(RuntimeError, "Seed must be a non-negative integer"):
            _test_parse_generator(-2024)

        with self.assertRaisesRegex(RuntimeError, "Number is too big"):
            _test_parse_generator(2**63)

    def test_early_return_rewrite(self):
        def test_foo(x: bool):
            if x:
@@ -427,6 +427,69 @@ void IRParser::parseAttr(Node* n) {
    }
    L.expect(')');
    deferred_empty_container_initializations_.push_back(n);
  } else if (L.cur().text() == "torch") {
    L.next();
    L.expect('.');
    auto function = L.cur().text();
    if (function == "Generator") {
      L.next();
      L.expect('(');
      std::optional<uint64_t> seed;
      std::string device = "cpu";
      while (!L.nextIf(')')) {
        auto arg = L.expect(TK_IDENT).text();
        L.expect('=');
        if (arg == "device") {
          ParsedLiteral r = parseScalarLiteral(n);
          if (r.k != AttributeKind::s) {
            throw(
                ErrorReport(L.cur().range)
                << "Expected string literal for device argument");
          }
          if (r.s != "cpu") {
            throw(
                ErrorReport(L.cur().range)
                << "Only cpu device is supported for Generator at this time.");
          }
          device = r.s;
        } else if (arg == "seed") {
          ParsedLiteral r = parseScalarLiteral(n);
          if (r.k != AttributeKind::i) {
            throw(
                ErrorReport(L.cur().range)
                << "Expected int literal for seed argument");
          }
          if (r.i < 0) {
            throw(
                ErrorReport(L.cur().range)
                << "Seed must be a non-negative integer");
          }
          seed = r.i;
        } else {
          throw(
              ErrorReport(L.cur().range)
              << "Generator only supports the following arguments:\n"
              << "- device\n"
              << "- seed\n"
              << "Got: " << arg);
        }
        L.nextIf(',');
      }
      if (device == "cpu") {
        if (seed.has_value()) {
          n->ival_(
              Symbol::attr(attrname), at::detail::createCPUGenerator(*seed));
        } else {
          n->ival_(Symbol::attr(attrname), at::detail::createCPUGenerator());
        }
      }
    } else {
      throw(
          ErrorReport(L.cur().range)
          << "Expected one of the following torch functions:\n"
          << "- Generator\n"
          << "Got: " << function);
    }
  } else {
    // scalar
    ParsedLiteral r = parseScalarLiteral(n);