Make adding buffers more like adding parameters (#104069)

Make creating a buffer work like creating a parameter by introducing a new `Buffer` class that can be used for type disambiguation. The underlying registration functionality is unchanged: the `register_buffer` method has not been modified, and assigning a `Buffer` to a module attribute simply ends up calling it. The `persistent` argument of `Buffer` indicates whether the buffer should be part of the module's `state_dict`. The other non-test changes teach inductor and dynamo to recognize the new `Buffer` type. The remaining test changes verify that `Buffer` works as a drop-in replacement for `register_buffer`, since it just leads to `register_buffer` being called. Plain tensors can still be used as buffers, so these changes are intended to be backwards compatible.
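
For illustration, a minimal usage sketch of the new assignment semantics (the `RunningStats` module is a made-up example, not code from this PR):

```python
import torch
import torch.nn as nn

class RunningStats(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Buffer attribute registers it via register_buffer.
        self.mean = nn.Buffer(torch.zeros(10))
        # persistent=False keeps the buffer out of state_dict().
        self.scratch = nn.Buffer(torch.zeros(10), persistent=False)
        # Plain tensors still need the explicit call, as before.
        self.register_buffer("count", torch.zeros(()))

m = RunningStats()
assert sorted(n for n, _ in m.named_buffers()) == ["count", "mean", "scratch"]
assert set(m.state_dict()) == {"count", "mean"}  # "scratch" is non-persistent
```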

Fixes #35735

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104069
Approved by: https://github.com/mikaylagawarecki
ekamiti 2023-07-17 17:59:02 +00:00 committed by PyTorch MergeBot
parent 4fc47b4156
commit 32d422f335
41 changed files with 268 additions and 149 deletions

View File

@ -22,6 +22,7 @@ These are the basic building blocks for graphs:
:nosignatures:
:template: classtemplate.rst
~parameter.Buffer
~parameter.Parameter
~parameter.UninitializedParameter
~parameter.UninitializedBuffer

View File

@ -58,7 +58,7 @@ class TestFlattenParams(FSDPTest):
dim_feedforward=128,
dropout=0.1,
)
module.register_buffer("dummy_buffer", torch.tensor(1.0))
module.dummy_buffer = nn.Buffer(torch.tensor(1.0))
def get_input(device, dtype):
torch.manual_seed(1) # keep everything deterministic

View File

@ -453,11 +453,11 @@ class TestFSDPMiscMultiProcess(FSDPTest):
# Check that `device_id` with `sync_module_states=True` works
nested_wrapped_module = init_nested_wrapped_module()
nested_wrapped_module.register_buffer(
"buf", torch.ones((2, 2), device="cpu") * self.rank
nested_wrapped_module.buf = nn.Buffer(
torch.ones((2, 2), device="cpu") * self.rank
)
nested_wrapped_module.module[0].register_buffer(
"buf", torch.ones((3, 2), device="cpu") * self.rank
nested_wrapped_module.module[0].buf = nn.Buffer(
torch.ones((3, 2), device="cpu") * self.rank
)
nested_wrapped_module = FSDP(
nested_wrapped_module,
@ -705,7 +705,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
self.lin = nn.Linear(10, 10, bias=False)
self.register_buffer("buffer", torch.ones(1) * rank)
self.buffer = nn.Buffer(torch.ones(1) * rank)
m = MyModel(self.rank).cuda()
_assert_module_states(

View File

@ -87,7 +87,7 @@ class Model(Module):
super().__init__()
self.inner = Linear(*INNER_SHAPE)
if register_buffers:
self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
self.inner.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
self.inner.register_buffer(
"non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
)
@ -97,7 +97,7 @@ class Model(Module):
)
self.outer = Linear(*OUTER_SHAPE)
if register_buffers:
self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
self.outer.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
self.outer.register_buffer(
"non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
)

View File

@ -423,7 +423,7 @@ class TestUnshardParams(TestUnshardParamsBase):
CUDAInitMode.CUDA_BEFORE,
deterministic=True,
)
model.register_buffer("buffer", torch.ones(1))
model.buffer = nn.Buffer(torch.ones(1))
# Wrap the top-level with FSDP since `named_parameters()` and
# `named_buffers` will contain FSDP prefixes if called on a non-FSDP
# root module
@ -436,7 +436,7 @@ class TestUnshardParams(TestUnshardParamsBase):
),
self.process_group,
)
fsdp_model.register_buffer("buffer", torch.ones(1))
fsdp_model.buffer = nn.Buffer(torch.ones(1))
with FSDP.summon_full_params(fsdp_model):
for call in ["named_parameters", "named_buffers"]:
for (n1, p1), (n2, p2) in itertools.zip_longest(

View File

@ -855,8 +855,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
).to(self.device)
model.register_buffer(
"test_buffer",
model.test_buffer = torch.nn.Buffer(
torch.ones((1), device=self.device) * self.rank,
)
# Define models/optimizers for DDP with ZeRO and DDP with local

View File

@ -34,8 +34,8 @@ class TestDataParallel(TestCase):
class TestModule(nn.Module):
def __init__(self, t):
super().__init__()
self.register_buffer('t_rg', t)
self.register_buffer('t_not_rg', t.clone().detach())
self.t_rg = nn.Buffer(t, t.requires_grad)
self.t_not_rg = nn.Buffer(t.clone().detach())
def forward(self, x):
return x * self.t_rg + self.t_not_rg

View File

@ -760,9 +760,8 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
def __init__(self) -> None:
super().__init__()
self._param = torch.randn((3,), device="cuda")
self.register_buffer(
"_buf", torch.randn((3,), requires_grad=False, device="cuda")
)
self._buf = torch.nn.Buffer(
torch.randn((3,), requires_grad=False, device="cuda"))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Use `_param` and `_buf` each twice in this compiled forward
@ -789,8 +788,8 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
class BufModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"_buf", torch.randn((3,), requires_grad=False, device="cuda")
self._buf = nn.Buffer(
torch.randn((3,), requires_grad=False, device="cuda")
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -802,7 +801,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
self._param = nn.Parameter(torch.randn((1,), device="cuda"))
self._buf_module = BufModule()
# Share the buffer, meaning same tensor but different source
self.register_buffer("_buf", self._buf_module._buf)
self._buf = self._buf_module._buf
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Use the same buffer tensor twice in the compiled forward,

View File

@ -973,7 +973,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(1, 1))
self.register_buffer("buffer", torch.ones(1, 1))
self.buffer = torch.nn.Buffer(torch.ones(1, 1))
def forward(self, x):
x = torch.nn.functional.linear(x, torch.randn(4, 4))
@ -2668,7 +2668,7 @@ def forward(self, x):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 2))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 2))
def forward(self, x):
x.add_(2)

View File

@ -1451,7 +1451,7 @@ class MockModule(torch.nn.Module):
super().__init__()
self.relu = torch.nn.ReLU()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
def forward(self, x):
return self.relu(self.linear(x) + self.buf0)
@ -1500,7 +1500,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
def forward(self, x):
return self.r(torch.sin(x)) + self.buf0
@ -1527,7 +1527,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.register_buffer("buf0", torch.nn.Buffer(torch.randn(10, 10)))
self.register_parameter(
name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
)

View File

@ -1859,8 +1859,8 @@ class ReproTests(torch._dynamo.test_case.TestCase):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("sorted", torch.ones(4, 4))
self.register_buffer("indices", torch.ones(4, 4, dtype=torch.long))
self.sorted = torch.nn.Buffer(torch.ones(4, 4))
self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long))
def forward(self, x):
torch.sort(x, out=(self.sorted, self.indices))
@ -1891,7 +1891,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("base", torch.ones(4, 4))
self.base = torch.nn.Buffer(torch.ones(4, 4))
def forward(self, x):
torch.sigmoid(x, out=self.base)
@ -2174,8 +2174,8 @@ class ReproTests(torch._dynamo.test_case.TestCase):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("x", torch.ones(3))
self.register_buffer("y", torch.ones(3))
self.x = torch.nn.Buffer(torch.ones(3))
self.y = torch.nn.Buffer(torch.ones(3))
def forward(self, inp):
res = 0

View File

@ -50,7 +50,7 @@ class TestExperimentalExport(TestCase):
class Foo(torch.nn.Module):
def __init__(self, float_val):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 1))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 1))
def forward(self, x):
self.buffer1.add_(2)

View File

@ -178,8 +178,7 @@ class VerifierTest(TestCase):
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"a",
self.a = torch.nn.Buffer(
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
)

View File

@ -1945,7 +1945,7 @@ def forward(self, tangents_1):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(4, 5))
self.buffer = torch.nn.Buffer(torch.ones(4, 5))
def forward(self, x):
y = self.buffer.add_(3)
@ -2140,7 +2140,7 @@ class <lambda>(torch.nn.Module):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 4))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))
def forward(self, x):
x.add_(4)
@ -2153,7 +2153,7 @@ class <lambda>(torch.nn.Module):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 4))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))
def forward(self, x, y):
y.add_(4)

View File

@ -3466,8 +3466,8 @@ class TestMakeFunctional(TestCase):
super().__init__()
self.bias = nn.Parameter(torch.randn(3))
self.linear = nn.Linear(3, 3)
self.register_buffer('buffer', torch.randn(3))
self.register_buffer('buffer_tied', self.buffer)
self.buffer = nn.Buffer(torch.randn(3))
self.buffer_tied = self.buffer
def forward(self, x):
x = self.linear(x)
@ -3497,7 +3497,7 @@ class TestMakeFunctional(TestCase):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.register_buffer('buffer', torch.randn(3))
self.buffer = nn.Buffer(torch.randn(3))
def forward(self, x):
x = self.linear(x)
@ -3517,7 +3517,7 @@ class TestMakeFunctional(TestCase):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.register_buffer('buffer', torch.randn(3))
self.buffer = nn.Buffer(torch.randn(3))
def forward(self, x):
x = self.linear(x)
@ -3573,8 +3573,8 @@ class TestMakeFunctional(TestCase):
self.linear = nn.Linear(3, 3)
self.weight = self.linear.weight
self.bias = self.linear.bias
self.register_buffer('buffer', torch.randn(3))
self.register_buffer('buffer_tied', self.buffer)
self.buffer = nn.Buffer(torch.randn(3))
self.buffer_tied = self.buffer
def forward(self, x):
x = self.linear(x)

View File

@ -565,8 +565,7 @@ class CudaReproTests(TestCase):
start = math.log2(0.5)
end = math.log2(1 / (2**8))
self.register_buffer(
"scales",
self.scales = nn.Buffer(
2
** torch.arange(
start,

View File

@ -6184,9 +6184,7 @@ class CommonTemplate:
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"_tensor_constant0", torch.randn([], dtype=torch.float32)
)
self._tensor_constant0 = nn.Buffer(torch.randn([], dtype=torch.float32))
def forward(self, arg0_1, arg1_1):
convert_element_type = torch.ops.prims.convert_element_type.default(

View File

@ -487,6 +487,7 @@ class TestSaveLoad(JitTestCase):
self.parameter_b = torch.nn.Parameter(torch.randn(4))
self.submodule_b = Submodule()
self.buffer_b = torch.nn.Buffer(torch.randn(4))
m = TestModule()
m_loaded = self.getExportImportCopy(torch.jit.script(m))
@ -526,7 +527,7 @@ class TestSaveLoad(JitTestCase):
super().__init__()
self.foo = torch.nn.Linear(2, 3, device="meta")
self.bar = torch.nn.Linear(3, 4)
self.register_buffer("buffer", torch.randn(4, device="meta"))
self.buffer = torch.nn.Buffer(torch.randn(4, device="meta"))
def forward(self, x):
x = self.foo(x)
@ -1150,6 +1151,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
self.parameter_b = torch.nn.Parameter(torch.randn(4))
self.submodule_b = Submodule()
self.buffer_b = torch.nn.Buffer(torch.randn(4))
m = TestModule()
m_loaded = self.getExportImportCopy(torch.jit.script(m))

View File

@ -5,7 +5,7 @@ import pickle
import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
from torch.nn import Parameter
from torch.nn import Buffer, Parameter
from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings
from torch.testing._internal.common_cuda import TEST_CUDA
@ -47,29 +47,29 @@ class TestLazyModules(TestCase):
@suppress_warnings
def test_lazy_module_buffer(self):
module = LazyModule()
module.register_buffer('test_buffer', UninitializedBuffer())
module.test_buffer = UninitializedBuffer()
self.assertTrue(module.has_uninitialized_params())
state_dict = module.state_dict()
self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer)
new_module = LazyModule()
# An error is raised when there is an attempt to replace an existing parameter
# with an uninitialized one
new_module.register_buffer('test_buffer', torch.ones(5, 5))
new_module.test_buffer = Buffer(torch.ones(5, 5))
with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
new_module.load_state_dict(state_dict)
# Uninitialized parameters are overriden when the state dict to be loaded contains a valid one
new_module = LazyModule()
new_module.register_buffer('test_buffer', torch.ones(5, 5))
new_module.test_buffer = Buffer(torch.ones(5, 5))
module.load_state_dict(new_module.state_dict())
self.assertEqual(module.test_buffer, torch.ones((5, 5)))
# Uninitialized parameters are left unchanged
module = LazyModule()
module.register_buffer('test_buffer', UninitializedBuffer())
module.test_buffer = UninitializedBuffer()
self.assertTrue(module.has_uninitialized_params())
new_module = LazyModule()
new_module.register_buffer('test_buffer', UninitializedBuffer())
new_module.test_buffer = UninitializedBuffer()
module.load_state_dict(new_module.state_dict())
module.load_state_dict(new_module.state_dict())
self.assertTrue(module.has_uninitialized_params())
@ -85,7 +85,7 @@ class TestLazyModules(TestCase):
@suppress_warnings
def test_lazy_module_jit_buffer(self):
module = LazyModule()
module.register_buffer('test_buffer', UninitializedBuffer())
module.test_buffer = UninitializedBuffer()
self.assertTrue(module.has_uninitialized_params())
with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
torch.jit.script(module)
@ -101,7 +101,7 @@ class TestLazyModules(TestCase):
@suppress_warnings
def test_lazy_share_memory_buffer(self):
module = LazyModule()
module.register_buffer('test_buffer', UninitializedBuffer())
module.test_buffer = UninitializedBuffer()
self.assertTrue(module.has_uninitialized_params())
with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
module.share_memory()

View File

@ -9,7 +9,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.parametrize as parametrize
from torch.nn import Parameter
from torch.nn import Buffer, Parameter
from torch.testing._internal.common_utils import run_tests, skipIfNoLapack, \
TemporaryFileName, instantiate_parametrized_tests, set_default_dtype
from torch.testing._internal.common_cuda import TEST_MULTIGPU
@ -305,7 +305,7 @@ class TestNNParametrization(NNTestCase):
# Instantiate parametrizations on buffers. It should work as expected
delattr(model, "bias")
model.register_buffer("bias", torch.ones(8))
model.bias = Buffer(torch.ones(8))
parametrize.register_parametrization(model, "bias", FirstZero())
parametrize.register_parametrization(model, "bias", LastZero())
self.assertTrue(parametrize.is_parametrized(model))
@ -333,8 +333,8 @@ class TestNNParametrization(NNTestCase):
class Orthogonal(nn.Module):
def __init__(self, n):
super().__init__()
self.register_buffer("id", torch.eye(n))
self.register_buffer("B", torch.empty(n, n))
self.id = Buffer(torch.eye(n))
self.B = Buffer(torch.empty(n, n))
init.orthogonal_(self.B)
def forward(self, X):
@ -396,7 +396,7 @@ class TestNNParametrization(NNTestCase):
class Orthogonal(nn.Module):
def __init__(self, n):
super().__init__()
self.register_buffer("B", torch.eye(n))
self.B = Buffer(torch.eye(n))
def forward(self, X):
Id = torch.eye(X.size(0))

View File

@ -297,7 +297,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
scale_1 = self.weight.reshape(1, -1, 1, 1)
@ -4214,7 +4214,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
class GatherModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
# torch.nn.Embedding is converted to ONNX::Gather.
# Constant folding will be triggerred for constant inputs.
# This pattern is common for constant mask inputs in transformer models.
@ -4233,7 +4233,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
class GatherModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("weight", torch.ones(2))
self.weight = torch.nn.Buffer(torch.ones(2))
def forward(self, x):
# shape is of rank 0
@ -4248,7 +4248,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
class GatherModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))
self.rb = torch.nn.Buffer(torch.randn(1, 1, 3, 1, 1))
def forward(self, x):
x += self.rb[0]
@ -9394,7 +9394,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
class ShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
shape = self.weight.shape[0]
@ -10895,7 +10895,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
def __init__(self, embedding_dim):
super().__init__()
self.weights = InnerModule2.get_embedding(embedding_dim)
self.register_buffer("_float_tensor", torch.FloatTensor(1))
self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1))
self.const = 2
@staticmethod
@ -10957,7 +10957,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
self.embedding_dim = embedding_dim
self.const = 2.5
self.weights = InnerModule.get_embedding(self.embedding_dim)
self.register_buffer("_float_tensor", torch.FloatTensor(1))
self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1))
@staticmethod
def get_embedding(embedding_dim: int):

View File

@ -540,7 +540,7 @@ class TestUtilityFuns(_BaseTestCase):
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
b = self.weight.reshape(1, -1, 1, 1)
@ -563,7 +563,7 @@ class TestUtilityFuns(_BaseTestCase):
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
div = self.weight.div(torch.tensor([1, 2, 3, 4, 5]))
@ -586,7 +586,7 @@ class TestUtilityFuns(_BaseTestCase):
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
@ -609,7 +609,7 @@ class TestUtilityFuns(_BaseTestCase):
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
add = self.weight + torch.tensor([1, 2, 3, 4, 5])
@ -640,7 +640,7 @@ class TestUtilityFuns(_BaseTestCase):
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
@ -671,7 +671,7 @@ class TestUtilityFuns(_BaseTestCase):
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
sqrt = torch.sqrt(self.weight)
@ -691,7 +691,7 @@ class TestUtilityFuns(_BaseTestCase):
class ShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("weight", torch.ones(5))
self.weight = torch.nn.Buffer(torch.ones(5))
def forward(self, x):
shape = self.weight.shape[0]

View File

@ -1440,7 +1440,7 @@ class TestQuantizedTensor(TestCase):
s = torch.rand(5, dtype=torch.float64) + 0.1
zp = torch.randint(5, 15, (5,))
x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
self.register_buffer('x', x_q)
self.x = torch.nn.Buffer(x_q)
@torch.jit.script_method
def forward(self):

View File

@ -94,9 +94,9 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
self.beta = nn.Parameter(torch.empty(out_channels))
self.affine = True
self.track_running_stats = True
self.register_buffer('running_mean', torch.zeros(out_channels))
self.register_buffer('running_var', torch.ones(out_channels))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.running_mean = nn.Buffer(torch.zeros(out_channels))
self.running_var = nn.Buffer(torch.ones(out_channels))
self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long))
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
if bias:

View File

@ -817,7 +817,7 @@ class TestFX(JitTestCase):
self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.lin = torch.nn.Linear(d_hid, d_hid)
self.register_buffer('buffer', torch.randn(bs + 100, d_hid))
self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))
def forward(self, x):
x = torch.mm(x, self.mm_param)
@ -2660,7 +2660,7 @@ class TestFX(JitTestCase):
class GetItemBase(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('pe', torch.randn(8, 8))
self.pe = torch.nn.Buffer(torch.randn(8, 8))
class GetItem1(GetItemBase):
def forward(self, x):
@ -3026,7 +3026,7 @@ class TestFX(JitTestCase):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(100, 200)
self.register_buffer("buf", torch.randn(2, 3))
self.buf = torch.nn.Buffer(torch.randn(2, 3))
self.net_c = C()
def forward(self, x):
@ -3196,7 +3196,7 @@ class TestFX(JitTestCase):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.register_buffer('buffer', torch.ones(1))
self.buffer = torch.nn.Buffer(torch.ones(1))
def forward(self, x):
return self.l1(x) + self.buffer

View File

@ -1215,8 +1215,8 @@ class {test_classname}(torch.nn.Module):
self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
self.linear = torch.nn.Linear(2, 2)
self.attr = torch.randn(2)
self.register_buffer("attr2", torch.randn(2))
self.register_buffer("attr3", torch.ones(2, dtype=torch.int32))
self.attr2 = torch.nn.Buffer(torch.randn(2))
self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))
def forward(self, x):
return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))

View File

@ -482,7 +482,7 @@ class TestJit(JitTestCase):
class MyModule(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.register_buffer('b0', torch.randn(1, 3))
self.b0 = nn.Buffer(torch.randn(1, 3))
self.p0 = nn.Parameter(torch.randn(2, 3))
@torch.jit.script_method
@ -538,7 +538,7 @@ class TestJit(JitTestCase):
super().__init__()
whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
self.register_buffer('b0', whole_tensor.narrow(0, 3, 1))
self.b0 = nn.Buffer(whole_tensor.narrow(0, 3, 1))
m = Foo()
m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
@ -3989,7 +3989,7 @@ def foo(x):
a.p = nn.Parameter(torch.rand(3, 4))
a.foo = nn.Module()
a.foo.name = 'foo'
a.foo.register_buffer('b', torch.rand(1, 1))
a.foo.b = nn.Buffer(torch.rand(1, 1))
a.foo.bar = nn.Module()
a.foo.bar.name = 'bar'
a.foo.bar.an_int = 4
@ -8957,7 +8957,7 @@ dedent """
class ModuleBufferMutate(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
self.running_var = nn.Buffer(torch.tensor(0, dtype=torch.long))
@torch.jit.script_method
def forward(self):
@ -9084,12 +9084,12 @@ dedent """
def __init__(self):
super(TestScript.DerivedStateModule, self).__init__()
self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
self.register_buffer('derived', torch.neg(self.param).detach().clone())
self.derived = nn.Buffer(torch.neg(self.param).detach().clone())
# This is a flag so we can test that the pack method was called
self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
self.pack_called = nn.Buffer(torch.zeros(1, dtype=torch.long))
# This is a flag so we can test that the unpack method was called
self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
self.unpack_called = nn.Buffer(torch.zeros(1, dtype=torch.long))
@torch.jit.script_method
def _pack(self):
@ -9269,7 +9269,7 @@ dedent """
class SubSubMod(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.register_buffer('buf', torch.ones(3, 4) * 3)
self.buf = nn.Buffer(torch.ones(3, 4) * 3)
@torch.jit.script_method
def _pack(self):
@ -9286,7 +9286,7 @@ dedent """
class SubMod(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.register_buffer('buf', torch.ones(3, 4) * 2)
self.buf = nn.Buffer(torch.ones(3, 4) * 2)
self.ssm = SubSubMod()
@torch.jit.script_method
@ -9305,7 +9305,7 @@ dedent """
def __init__(self):
super().__init__()
self.submod = SubMod()
self.register_buffer('buf', torch.ones(3, 4) * 1)
self.buf = nn.Buffer(torch.ones(3, 4) * 1)
@torch.jit.script_method
def _pack(self):
@ -13111,7 +13111,7 @@ dedent """
self.out_features = out_features
self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
self.bias = torch.nn.Parameter(torch.empty(out_features))
self.register_buffer('counter', torch.ones(out_features))
self.counter = nn.Buffer(torch.ones(out_features))
self.reset_parameters()
def reset_parameters(self):
@ -13164,7 +13164,7 @@ dedent """
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
self.bias = torch.nn.Parameter(torch.ones(out_features))
self.register_buffer("buffer", torch.ones(out_features))
self.buffer = nn.Buffer(torch.ones(out_features))
self.submodule = Submodule()
def forward(self, x):
@ -13619,8 +13619,8 @@ dedent """
def __init__(self, number):
super().__init__()
self.register_buffer('buffer1', torch.ones(2, 2))
self.register_buffer('buffer2', torch.ones(2, 2))
self.buffer1 = nn.Buffer(torch.ones(2, 2))
self.buffer2 = nn.Buffer(torch.ones(2, 2))
self.number = number
@torch.jit.script_method
@ -13638,8 +13638,8 @@ dedent """
def __init__(self, number, submodule):
super().__init__()
self.register_buffer('buffer1', torch.ones(2, 2))
self.register_buffer('buffer2', torch.ones(2, 2))
self.buffer1 = nn.Buffer(torch.ones(2, 2))
self.buffer2 = nn.Buffer(torch.ones(2, 2))
self.number = number
self.submodule = submodule
@ -13675,8 +13675,8 @@ dedent """
class NoArgState(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('buffer1', torch.ones(2, 2))
self.register_buffer('buffer2', torch.ones(2, 2))
self.buffer1 = nn.Buffer(torch.ones(2, 2))
self.buffer2 = nn.Buffer(torch.ones(2, 2))
def forward(self):
pass
@ -15091,7 +15091,7 @@ dedent """
def __init__(self):
super().__init__()
tensor = torch.zeros(1, requires_grad=False)
self.register_buffer('some_state', torch.nn.Parameter(tensor))
self.some_state = nn.Buffer(torch.nn.Parameter(tensor))
@torch.jit.script_method
def forward(self, x):
@ -15484,8 +15484,8 @@ dedent """
self.mod = (torch.nn.ReLU())
self.mod2 = (torch.nn.ReLU())
self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU()))
self.register_buffer('x', torch.zeros(3))
self.register_buffer('y', torch.zeros(3))
self.x = nn.Buffer(torch.zeros(3))
self.y = nn.Buffer(torch.zeros(3))
self.z = torch.zeros(3)
def bleh(self):

View File

@ -19,7 +19,7 @@ import torch.nn.functional as F
import itertools
from collections import defaultdict
from torch import inf
from torch.nn import Parameter
from torch.nn import Buffer, Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
(gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, NoTest,
@ -7643,14 +7643,14 @@ class TestNNMPS(NNTestCase):
def __init__(self):
super().__init__()
self.layer_dummy_param = Parameter(torch.empty(3, 5))
self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))
self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7))
class Net(nn.Module):
def __init__(self):
super().__init__()
self.l1 = Layer()
self.dummy_param = Parameter(torch.empty(3, 5))
self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))
self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1))
l = Layer()
n = Net()

View File

@ -31,7 +31,7 @@ from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn import Parameter
from torch.nn import Buffer, Parameter
from torch.nn.parallel._functions import Broadcast
from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
@ -365,8 +365,8 @@ class TestNN(NNTestCase):
class M(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.empty(3, 5))
self.register_buffer("buffer2", self.buffer1)
self.buffer1 = Buffer(torch.empty(3, 5))
self.buffer2 = self.buffer1
m = M()
self.assertEqual(names(m.named_buffers()),
@ -425,7 +425,7 @@ class TestNN(NNTestCase):
linear = nn.Linear(2, 2)
linear._test_submodule = nn.Linear(2, 2)
linear._test_parameter = Parameter(torch.empty(2, 2))
linear.register_buffer('_test_buffer', torch.empty(2, 2))
linear._test_buffer = Buffer(torch.empty(2, 2))
keys = dir(linear)
self.assertIn('_test_submodule', keys)
self.assertIn('_test_parameter', keys)
@ -530,6 +530,9 @@ class TestNN(NNTestCase):
with self.assertRaises(KeyError):
m.register_buffer('attribute_name', torch.rand(5))
with self.assertRaises(KeyError):
m.attribute_name = Buffer(torch.rand(5))
del m.attribute_name
m.register_parameter('attribute_name', nn.Parameter())
with self.assertRaises(KeyError):
@ -556,12 +559,18 @@ class TestNN(NNTestCase):
self.assertEqual(m.buffer_name, buffer2)
m.register_buffer('buffer_name', buffer3)
self.assertEqual(m.buffer_name, buffer3)
m.buffer_name = Buffer(buffer1)
self.assertEqual(m.buffer_name, Buffer(buffer1))
m.buffer_name = Buffer(buffer2)
self.assertEqual(m.buffer_name, Buffer(buffer2))
m.buffer_name = Buffer(buffer3)
self.assertEqual(m.buffer_name, Buffer(buffer3))
def test_get_buffer(self):
m = nn.Module()
buffer1 = torch.randn(2, 3)
buffer2 = torch.randn(4, 5)
m.register_buffer('foo', buffer1)
m.foo = Buffer(buffer1)
m.register_buffer('bar', buffer2)
self.assertEqual(buffer1, m.get_buffer('foo'))
self.assertEqual(buffer2, m.get_buffer('bar'))
@ -575,13 +584,13 @@ class TestNN(NNTestCase):
class Sub(nn.Module):
def __init__(self, foo, bar):
super().__init__()
self.register_buffer('foo', foo)
self.foo = Buffer(foo)
self.subsub = SubSub(bar)
class SubSub(nn.Module):
def __init__(self, bar):
super().__init__()
self.register_buffer('bar', bar)
self.bar = Buffer(bar)
foo = torch.randn(2, 3)
bar = torch.randn(4, 5)
@ -591,33 +600,35 @@ class TestNN(NNTestCase):
def test_buffer_not_persistent(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.buf = nn.Buffer(torch.rand(5), persistent=False)
self.assertTrue(len(list(m.buffers())) == 1)
self.assertTrue(len(m.state_dict()) == 0)
def test_buffer_not_persistent_del(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.buf = nn.Buffer(torch.rand(5), persistent=False)
del m.buf
self.assertTrue(len(list(m.buffers())) == 0)
def test_buffer_not_persistent_overwrite(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.register_buffer('buf', torch.rand(5))
m.buf = nn.Buffer(torch.rand(5), persistent=False)
m.buf = nn.Buffer(torch.rand(5))
# can we overwrite a non-persistent buffer with a persistent one?
self.assertTrue(len(list(m.buffers())) == 1)
self.assertTrue(len(m.state_dict()) == 1)
# can we overwrite a persistent buffer with a non-persistent one?
m.register_buffer('buf', torch.rand(5), persistent=False)
m.buf = nn.Buffer(torch.rand(5), persistent=False)
self.assertTrue(len(list(m.buffers())) == 1)
self.assertTrue(len(m.state_dict()) == 0)
def test_buffer_not_persistent_assign(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.buf = nn.Buffer(torch.rand(5), persistent=False)
self.assertTrue(len(list(m.buffers())) == 1)
self.assertTrue(len(m.state_dict()) == 0)
# Assigning None removes the buffer but if we then assign a new Tensor
# to the same property, it should still be marked as a buffer.
@ -659,7 +670,7 @@ class TestNN(NNTestCase):
def test_buffer_not_persistent_load(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.buf = nn.Buffer(torch.rand(5), persistent=False)
m.load_state_dict({})
def test_register_parameter_raises_error_if_name_is_not_string(self):
@ -681,6 +692,11 @@ class TestNN(NNTestCase):
with self.assertRaises(KeyError):
m.register_parameter('attribute_name', nn.Parameter())
del m.attribute_name
m.attribute_name = Buffer(torch.rand(5))
with self.assertRaises(KeyError):
m.register_parameter('attribute_name', nn.Parameter())
del m.attribute_name
m.add_module('attribute_name', nn.Module())
with self.assertRaises(KeyError):
@ -1625,7 +1641,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
net.l = l
net.l2 = l
net.add_module('empty', None)
net.register_buffer('indices', torch.LongTensor(1))
net.indices = Buffer(torch.LongTensor(1))
net.float()
self.assertIsInstance(l.weight.data, torch.FloatTensor)
self.assertIsInstance(l.bias.data, torch.FloatTensor)
@ -2811,8 +2827,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
del l.a, l.b
self.assertEqual(list(l.children()), [])
buf = torch.randn(10)
l.register_buffer('buf', buf)
buf = Buffer(torch.randn(10))
l.buf = buf
self.assertIs(l.buf, buf)
l.buf = None
self.assertIs(l.buf, None)

View File

@ -18,7 +18,7 @@ class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.register_buffer('buffer', torch.ones(1))
self.buffer = torch.nn.Buffer(torch.ones(1))
self.foo = 0.0
def forward(self, x):
@ -30,8 +30,8 @@ class MockTiedModule(torch.nn.Module):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.tied_bias = self.l1.bias
self.register_buffer('buffer', torch.ones(1))
self.register_buffer('tied_buffer', self.buffer)
self.buffer = torch.nn.Buffer(torch.ones(1))
self.tied_buffer = self.buffer
def forward(self, x):
return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
@ -408,7 +408,7 @@ class TestStatelessFunctionalAPI(TestCase):
def test_tied_weights_warns(self, functional_call):
module = MockModule()
module.tied_bias = module.l1.bias
module.register_buffer("tied_buffer", module.buffer)
module.tied_buffer = torch.nn.Buffer(module.buffer)
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
@ -613,7 +613,7 @@ class TestStatelessFunctionalAPI(TestCase):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('foo', torch.tensor([0.0]))
self.foo = torch.nn.Buffer(torch.tensor([0.0]))
def forward(self, x):
self.foo = self.foo + 1
@ -637,7 +637,7 @@ class TestStatelessFunctionalAPI(TestCase):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('foo', torch.tensor([0.0]))
self.foo = torch.nn.Buffer(torch.tensor([0.0]))
def forward(self, x):
self.foo.add_(1)
@ -759,7 +759,7 @@ class TestStatelessFunctionalAPI(TestCase):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.register_buffer('buffer', torch.ones(1))
self.buffer = torch.nn.Buffer(torch.ones(1))
def forward(self, x):
parameters = tuple(self.parameters())

View File

@ -448,6 +448,7 @@ def istensor(obj):
"""Check of obj is a tensor"""
tensor_list = (
torch.Tensor,
torch.nn.Buffer,
torch.nn.Parameter,
*config.traceable_tensor_subclasses,
)

View File

@ -292,7 +292,12 @@ class VariableBuilder:
# NB: Careful not to close over self to avoid ref cycle from lru_cache
entries = [
(
(torch.Tensor, torch.nn.Parameter, torch._subclasses.FakeTensor),
(
torch.Tensor,
torch.nn.Buffer,
torch.nn.Parameter,
torch._subclasses.FakeTensor,
),
cls.wrap_tensor,
),
((tuple, list, odict_values), cls.wrap_listlike),
@ -882,6 +887,7 @@ class VariableBuilder:
else:
assert type(value) in (
torch.Tensor,
torch.nn.Buffer,
torch.nn.Parameter,
torch._subclasses.fake_tensor.FakeTensor,
), type(value)
@ -1463,7 +1469,7 @@ def _automatic_dynamic(e, tx, name, static_shapes):
def wrap_to_fake_tensor_and_record(
e, tx, ignore_subclass=False, *, source: Optional[Source], is_tensor: bool
):
if type(e) in (torch.Tensor, torch.nn.Parameter) or (
if type(e) in (torch.Tensor, torch.nn.Buffer, torch.nn.Parameter) or (
ignore_subclass and isinstance(e, torch.Tensor)
):
assert source is not None

View File

@ -1426,6 +1426,7 @@ class FakeTensorMode(TorchDispatchMode):
not isinstance(x, FakeTensor)
and type(x) is not torch.Tensor
and type(x) is not torch.nn.Parameter
and type(x) is not torch.nn.Buffer
)
return [

View File

@ -496,6 +496,7 @@ class MetaConverter:
if (
type(t) is torch.Tensor
or type(t) is torch.nn.Buffer
or type(t) is torch.nn.Parameter
or (ignore_subclass and isinstance(t, torch.Tensor))
or isinstance(t, FakeTensor)
@ -544,6 +545,9 @@ class MetaConverter:
# NB: Cannot directly use Parameter constructor
# because that would force a detach, not desirable
r._is_param = True
elif type(t) is torch.nn.Buffer:
# similar to above
r._is_buffer = True
return r
elif torch.overrides.is_tensor_like(t):
# Blindly converting tensor subclasses to meta can cause

View File

@ -387,6 +387,17 @@ def _rebuild_qtensor(
return tensor
def _rebuild_buffer(data, requires_grad, persistent):
buffer = torch.nn.Buffer(data, requires_grad, persistent)
return buffer
def _rebuild_buffer_with_state(data, requires_grad, persistent, state):
buffer = torch.nn.Buffer(data, requires_grad, persistent)
buffer = _set_obj_state(buffer, state)
return buffer
def _rebuild_parameter(data, requires_grad, backward_hooks):
param = torch.nn.Parameter(data, requires_grad)
# NB: This line exists only for backwards compatibility; the

View File

@ -238,7 +238,7 @@ def fetch_sym_proxy(tracer):
def fetch_tensor_proxy(tracer):
return lambda t: get_proxy_slot(t, tracer, t)
HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter, torch.nn.Buffer)
def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
unrecognized_types = []

View File

@ -1,5 +1,6 @@
from .modules import * # noqa: F403
from .parameter import (
Buffer as Buffer,
Parameter as Parameter,
UninitializedParameter as UninitializedParameter,
UninitializedBuffer as UninitializedBuffer,

View File

@ -5,7 +5,7 @@ import functools
import weakref
import torch
from ..parameter import Parameter
from ..parameter import Parameter, Buffer
import torch.utils.hooks as hooks
from torch import Tensor, device, dtype
@ -1745,16 +1745,16 @@ class Module:
modules[name] = value
else:
buffers = self.__dict__.get('_buffers')
if buffers is not None and name in buffers:
if isinstance(value, Buffer) or buffers is not None and name in buffers:
if value is not None and not isinstance(value, torch.Tensor):
raise TypeError("cannot assign '{}' as buffer '{}' "
"(torch.Tensor or None expected)"
"(torch.nn.Buffer, torch.Tensor or None expected)"
.format(torch.typename(value), name))
for hook in _global_buffer_registration_hooks.values():
output = hook(self, name, value)
if output is not None:
value = output
buffers[name] = value
if isinstance(value, Buffer):
persistent = value.persistent
else:
persistent = name not in self._non_persistent_buffers_set
self.register_buffer(name, value, persistent)
else:
super().__setattr__(name, value)
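
A quick sketch of the behavior this dispatch provides (it mirrors the `test_buffer_not_persistent_overwrite` changes elsewhere in this PR; not code from the diff itself):

```python
import torch
import torch.nn as nn

m = nn.Module()
m.buf = nn.Buffer(torch.rand(5), persistent=False)  # routed through register_buffer
assert len(list(m.buffers())) == 1 and len(m.state_dict()) == 0

m.buf = nn.Buffer(torch.rand(5))  # re-assigning a persistent Buffer updates persistence
assert len(m.state_dict()) == 1
```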

View File

@ -196,6 +196,74 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
memo[id(self)] = result
return result
# Metaclass to combine _TensorMeta and the instance check override for Buffer.
class _BufferMeta(torch._C._TensorMeta):
# Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag.
def __instancecheck__(self, instance):
return isinstance(instance, torch.Tensor) and getattr(instance, '_is_buffer', False)
class Buffer(torch.Tensor, metaclass=_BufferMeta):
r"""A kind of Tensor that should not be considered a model
parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state.
Buffers are :class:`~torch.Tensor` subclasses that have a
very special property when used with :class:`Module` s - when they're
assigned as Module attributes they are automatically added to the list of
its buffers, and will appear e.g. in :meth:`~Module.buffers` iterator.
Assigning a plain Tensor doesn't have such an effect. One can still explicitly
assign a Tensor as a buffer by using the module's :meth:`~Module.register_buffer` function.
Args:
data (Tensor): buffer tensor.
requires_grad (bool, optional): if the buffer requires gradient.
Default: `False`
persistent (bool, optional): whether the buffer is part of the module's
:attr:`state_dict`. Default: `True`
"""
def __new__(cls, data=None, requires_grad=False, persistent=True):
if data is None:
data = torch.empty(0)
# Path for custom tensors: set a flag on the instance to indicate buffer-ness.
t = data.detach().requires_grad_(requires_grad)
if type(t) is not type(data) and not isinstance(data, Parameter):
raise RuntimeError(f"Creating a Buffer from an instance of type {type(data).__name__} "
"requires that detach() returns an instance of the same type, but return "
f"type {type(t).__name__} was found instead. To use the type as a "
"Buffer, please correct the detach() semantics defined by "
"its __torch_dispatch__() implementation.")
t.persistent = persistent
t._is_buffer = True
return t
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad, self.persistent)
memo[id(self)] = result
return result
def __repr__(self):
return 'Buffer containing:\n' + super().__repr__()
def __reduce_ex__(self, proto):
state = torch._utils._get_obj_state(self)
if not state:
return (
torch._utils._rebuild_buffer,
(self.data, self.requires_grad, self.persistent)
)
return (
torch._utils._rebuild_buffer_with_state,
(self.data, self.requires_grad, self.persistent, state)
)
__torch_function__ = _disabled_torch_function_impl
class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
r"""A buffer that is not initialized.
@ -214,7 +282,10 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
cls_to_become = torch.Tensor
def __new__(cls, requires_grad=False, device=None, dtype=None) -> None:
def __new__(cls, requires_grad=False, device=None, dtype=None, persistent=True) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
data = torch.empty(0, **factory_kwargs)
return torch.Tensor._make_subclass(cls, data, requires_grad)
ret = torch.Tensor._make_subclass(cls, data, requires_grad)
ret.persistent = persistent
ret._is_buffer = True
return ret

View File

@ -26,11 +26,22 @@ class UninitializedParameter(Tensor):
dtype: Optional[torch.dtype] = None,
): ...
class UninitializedBuffer(Tensor):
class Buffer(Tensor):
persistent: builtins.bool
def __init__(
self,
data: Tensor = ...,
requires_grad: builtins.bool = ...,
persistent: builtins.bool = ...,
): ...
class UninitializedBuffer(Tensor):
persistent: builtins.bool
def __init__(
self,
data: Tensor = ...,
requires_grad: builtins.bool = ...,
persistent: builtins.bool = ...,
): ...
def materialize(
self,

View File

@ -4736,14 +4736,14 @@ def _create_basic_net():
def __init__(self):
super().__init__()
self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))
self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7))
class Net(nn.Module):
def __init__(self):
super().__init__()
self.l1 = Layer()
self.dummy_param = nn.Parameter(torch.empty(3, 5))
self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))
self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1))
l = Layer()
n = Net()