pytorch/torch/csrc/Generator.h
Frank Lin 249e65b92d Graph-Safe RNG State Exchange for Tensor Parallelism (#114068)
See #113541

This PR allows registering and controlling multiple RNG states by index, keeps these operations CUDA-graph-safe, and includes both the C++ and Python API changes needed to support this functionality.

cc  @eellison @anijain2305 @jansel @ezyang @ptrblck @csarofeen @mcarilli
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114068
Approved by: https://github.com/ezyang, https://github.com/eqy, https://github.com/xuzhao9
2024-03-27 01:14:38 +00:00

31 lines
1.0 KiB
C

#pragma once
#include <ATen/core/Generator.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/python_headers.h>
// Python object that wraps an at::Generator. Like every CPython object, it
// begins with the standard object header (PyObject_HEAD); the wrapped
// generator follows it.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPGenerator {
  PyObject_HEAD
  // The ATen generator this Python object wraps.
  at::Generator cdata;
};
// Builds a new Python object that wraps the given default at::Generator.
// The wrapper only borrows the reference: the caller is responsible for making
// sure the at::Generator object's lifetime lasts at least as long as the
// returned Python wrapper.
TORCH_PYTHON_API PyObject* THPGenerator_initDefaultGenerator(
    at::Generator cdata);
// Python type object for Generator wrappers (THPGenerator instances).
TORCH_PYTHON_API extern PyObject* THPGeneratorClass;

// Evaluates to 1 if `obj` passes Python's isinstance(obj, THPGeneratorClass)
// check, and to 0 otherwise.
// PyObject_IsInstance returns -1 when the check itself raises. A raw -1 is
// truthy, so it must not be used directly as a boolean; otherwise a failed
// check would be reported as "is a Generator". Comparing against 1 makes a
// failure read as "not a Generator". In that case a Python exception remains
// set, and callers that need to tell the two apart can consult
// PyErr_Occurred().
#define THPGenerator_Check(obj) \
  (PyObject_IsInstance((obj), THPGeneratorClass) == 1)
// Initialization hook for the Python Generator binding on `module`.
// NOTE(review): presumably registers the Generator type with `module` and
// returns false on failure (likely with a Python exception set) — confirm
// against the definition.
bool THPGenerator_init(PyObject* module);
// Produces a Python object for the at::Generator `gen`.
// NOTE(review): whether the result is a new reference and whether an existing
// wrapper is reused is not visible from this header — confirm in the
// definition.
TORCH_PYTHON_API PyObject* THPGenerator_Wrap(at::Generator gen);
// Recovers the at::Generator held by the Python object `state`.
// NOTE(review): `state` is presumably expected to be a Generator instance;
// behavior for other objects is not visible from this header — confirm.
TORCH_PYTHON_API at::Generator THPGenerator_Unwrap(PyObject* state);
// Builds a new Python object for the Generator `gen`.
// Precondition: `gen` must not already have a PyObject* associated with it.
// NOTE(review): `type` is presumably the Python type to instantiate (e.g. the
// Generator class or a subclass of it) — confirm against the definition.
PyObject* THPGenerator_NewWithVar(
    PyTypeObject* type,
    at::Generator gen);