mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[forward fix] add support for MemoryFormat after type tightening (#154658)
Summary:
fixes error:
```
raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}")
AssertionError: Unexpected type in c_type_for_prim_type: type_=MemoryFormat
```
after https://github.com/pytorch/pytorch/pull/154371 | D75568111
Test Plan:
```
buck test 'fbcode//mode/opt' fbcode//deeplearning/aot_inductor/test:test_custom_ops -- --exact 'deeplearning/aot_inductor/test:test_custom_ops - test_export_extern_fallback_nodes (deeplearning.aot_inductor.test.test_custom_ops.TestAOTInductorProxyExecutor)'
```
Differential Revision: D75617432
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154658
Approved by: https://github.com/Camyll, https://github.com/atalman, https://github.com/malfet
This commit is contained in:
parent
a4b0023f3b
commit
0fdd568b78
|
|
@ -2388,7 +2388,7 @@ if (!custom_op_wrapper) {
|
|||
return "int64_t"
|
||||
elif isinstance(
|
||||
type_, (torch.BoolType, torch.SymBoolType, torch.EnumType)
|
||||
) or repr(type_) in ("ScalarType", "Layout"):
|
||||
) or repr(type_) in ("ScalarType", "Layout", "MemoryFormat"):
|
||||
return "int32_t"
|
||||
elif isinstance(type_, torch.FloatType):
|
||||
return "double"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user