Summary:

Continuing the work from https://github.com/pytorch/pytorch/pull/146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in https://github.com/pytorch/pytorch/issues/146414. Please see the issue for a detailed definition of the format.

Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```

Done in this PR:

* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:

* performance optimizations for casting
* `torch._scaled_mm`
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147466
Approved by: https://github.com/drisspg
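For reference, a minimal sketch of the round-trip behavior shown in the summary above (an illustrative example, not part of the PR; it assumes a PyTorch build that already includes this change, and the exact rounding semantics are defined in issue 146414):

```python
import torch

# float8_e8m0fnu stores only an 8-bit exponent, so a round trip through the
# dtype snaps every element to a power of two (decoded as 2 ** exponent).
x = torch.tensor([0.1, 1.0, 3.0, 1000.0], dtype=torch.float32)
x_e8m0 = x.to(torch.float8_e8m0fnu)  # cast uses RNE rounding per the summary
x_back = x_e8m0.to(torch.float32)

print(x_back)              # powers of two near the original values
print(torch.log2(x_back))  # exponents come back as integers
```

Because the format carries only an exponent, every decoded value is a power of two, which is what the `# 2 ** exponent` comment in the summary refers to.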
| Name |
|---|
| _autoheuristic |
| aoti |
| api |
| decompositions |
| dest |
| executorch |
| fuse |
| operator_versions |
| selective_build |
| shape_functions |
| static_runtime |
| `__init__.py` |
| BUCK.oss |
| BUILD.bazel |
| build.bzl |
| code_template.py |
| context.py |
| gen_aoti_c_shim.py |
| gen_backend_stubs.py |
| gen_executorch.py |
| gen_functionalization_type.py |
| gen_lazy_tensor.py |
| gen_schema_utils.py |
| gen_vmap_plumbing.py |
| gen.py |
| local.py |
| model.py |
| native_function_generation.py |
| utils.py |
| yaml_utils.py |