mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Fix typo in test folder.

_typos.toml
```toml
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166374
Approved by: https://github.com/cyyever, https://github.com/ezyang
38 lines
1.4 KiB
Python
38 lines
1.4 KiB
Python
from pathlib import Path
|
|
|
|
import torch
|
|
from torch.fx import symbolic_trace
|
|
from torch.package import PackageExporter
|
|
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
|
|
|
|
|
|
# Disable torch.package's TorchScript-serialization gate for this process so
# that TorchScript modules can be packaged.
# NOTE(review): meaning inferred from the private flag's name -- confirm
# against torch.package.package_exporter.
torch.package.package_exporter._gate_torchscript_serialization = False

# Where the generated backward-compatibility archives are written: a
# ``package_bc`` directory next to this file, kept as a plain string.
packaging_directory = str(Path(__file__).parent) + "/package_bc"
|
|
|
|
|
|
def generate_bc_packages():
    """Function to create packages for testing backwards compatibility.

    Builds three variants of ``package_a.test_nn_module.TestNnModule`` -- the
    eager module, a ``torch.jit.script`` copy, and a ``torch.fx``
    ``symbolic_trace`` copy -- and saves each into its own ``torch.package``
    archive under ``packaging_directory``:

    * ``test_nn_module.pt``: package ``nn_module``, resource ``nn_module.pkl``
    * ``test_torchscript_module.pt``: package ``torchscript_module``,
      resource ``torchscript_module.pkl``
    * ``test_fx_module.pt``: package ``fx_module``, resource ``fx_module.pkl``

    The output directory is created if it does not already exist. Nothing is
    generated when running inside fbcode unless on Sandcastle.
    """
    # Same condition as ``not IS_FBCODE or IS_SANDCASTLE``, written as a guard.
    # NOTE(review): presumably because ``package_a`` is only importable in
    # these environments -- confirm.
    if IS_FBCODE and not IS_SANDCASTLE:
        return

    from package_a.test_nn_module import TestNnModule

    # Build every variant before writing anything, so a failure while
    # scripting or tracing happens before any archive is touched. Each variant
    # starts from its own fresh TestNnModule instance.
    test_nn_module = TestNnModule()
    test_torchscript_module = torch.jit.script(TestNnModule())
    test_fx_module: torch.fx.GraphModule = symbolic_trace(TestNnModule())

    # Make sure the target directory exists so the fixtures can be
    # regenerated from a tree where it is missing.
    Path(packaging_directory).mkdir(parents=True, exist_ok=True)

    # Every archive follows the same naming scheme:
    #   file ``test_<name>.pt``, package ``<name>``, resource ``<name>.pkl``.
    for name, module in (
        ("nn_module", test_nn_module),
        ("torchscript_module", test_torchscript_module),
        ("fx_module", test_fx_module),
    ):
        with PackageExporter(f"{packaging_directory}/test_{name}.pt") as exporter:
            # Intern all dependencies ("**" matches every module name).
            exporter.intern("**")
            exporter.save_pickle(name, f"{name}.pkl", module)


if __name__ == "__main__":
    generate_bc_packages()
|