Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Summary:

Continuing the work from https://github.com/pytorch/pytorch/pull/146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in https://github.com/pytorch/pytorch/issues/146414 . Please see the issue for a detailed definition of the format. Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```

Done in this PR:
* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:
* performance optimizations for casting
* torch._scaled_mm
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147466
Approved by: https://github.com/drisspg

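The `# 2 ** exponent` comment in the example can be made concrete with a small pure-Python sketch. It assumes the E8M0 layout used for scale factors in the OCP Microscaling (MX) formats: 8 exponent bits, no sign or mantissa bits, a bias of 127, and `0xFF` reserved for NaN. The helper names `e8m0_decode` and `e8m0_encode_pow2` are hypothetical and are not PyTorch APIs; the linked issue remains the authoritative definition, and the RNE rounding performed by the real cast is deliberately left out here.

```python
import math

E8M0_BIAS = 127  # assumed exponent bias (OCP MX scale format)
E8M0_NAN = 0xFF  # assumed single NaN pattern; no zero, infinity, or sign


def e8m0_decode(byte: int) -> float:
    """Map an 8-bit E8M0 pattern to the value it represents."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a value in [0, 255]")
    if byte == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)  # 2 ** (byte - 127)


def e8m0_encode_pow2(value: float) -> int:
    """Encode an exact positive power of two; rounding is out of scope."""
    if math.isnan(value):
        return E8M0_NAN
    mantissa, exponent = math.frexp(value)  # value == mantissa * 2 ** exponent
    if mantissa != 0.5:
        raise ValueError("only exact positive powers of two are handled")
    byte = exponent - 1 + E8M0_BIAS
    if not 0 <= byte < E8M0_NAN:
        raise ValueError("exponent outside the E8M0 range")
    return byte
```

Under these assumptions, `e8m0_decode(127)` is `1.0` and every non-NaN byte survives a decode/encode round trip, which is why a cast back to float32 only ever yields powers of two.
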
| File | | |
|---|---|---|
| byte_order.cpp | ||
| byte_order.h | ||
| cpp_stacktraces.cpp | ||
| cpp_stacktraces.h | ||
| cuda_enabled.h | ||
| device_lazy_init.cpp | ||
| device_lazy_init.h | ||
| disable_torch_function.cpp | ||
| disable_torch_function.h | ||
| generated_serialization_types.h | ||
| init.cpp | ||
| init.h | ||
| invalid_arguments.cpp | ||
| invalid_arguments.h | ||
| nested.cpp | ||
| nested.h | ||
| numpy_stub.h | ||
| object_ptr.cpp | ||
| object_ptr.h | ||
| out_types.cpp | ||
| out_types.h | ||
| pybind.cpp | ||
| pybind.h | ||
| pycfunction_helpers.h | ||
| pyobject_preservation.cpp | ||
| pyobject_preservation.h | ||
| python_arg_parser.cpp | ||
| python_arg_parser.h | ||
| python_compat.h | ||
| python_dispatch.cpp | ||
| python_dispatch.h | ||
| python_numbers.h | ||
| python_raii.h | ||
| python_scalars.h | ||
| python_strings.h | ||
| python_stub.h | ||
| python_symnode.cpp | ||
| python_symnode.h | ||
| python_torch_function_mode.h | ||
| python_tuples.h | ||
| pythoncapi_compat.h | ||
| schema_info.cpp | ||
| schema_info.h | ||
| six.h | ||
| structseq.cpp | ||
| structseq.h | ||
| tensor_apply.cpp | ||
| tensor_apply.h | ||
| tensor_dtypes.cpp | ||
| tensor_dtypes.h | ||
| tensor_flatten.cpp | ||
| tensor_flatten.h | ||
| tensor_layouts.cpp | ||
| tensor_layouts.h | ||
| tensor_list.cpp | ||
| tensor_list.h | ||
| tensor_memoryformats.cpp | ||
| tensor_memoryformats.h | ||
| tensor_new.cpp | ||
| tensor_new.h | ||
| tensor_numpy.cpp | ||
| tensor_numpy.h | ||
| tensor_qschemes.cpp | ||
| tensor_qschemes.h | ||
| tensor_types.cpp | ||
| tensor_types.h | ||
| throughput_benchmark-inl.h | ||
| throughput_benchmark.cpp | ||
| throughput_benchmark.h | ||
| torch_dispatch_mode.h | ||
| variadic.cpp | ||
| variadic.h | ||
| verbose.cpp | ||
| verbose.h | ||