Make adding buffers more like adding parameters (#104069)
Add semantics for creating a buffer object that mirror creating a parameter, by introducing a new `Buffer` class that can be used for type disambiguation. The underlying registration machinery is unchanged: the `register_buffer` method is not modified, and assigning a `Buffer` to a module attribute simply leads to `register_buffer` being called. The `persistent` argument on the `Buffer` type indicates whether the buffer should be persistent, i.e. included in the module's `state_dict`. The remaining non-test changes teach inductor and dynamo to recognize the new `Buffer` type. The test changes verify that `Buffer` can be used as a drop-in replacement for `register_buffer`. Plain tensors can still be registered as buffers, so these changes are intended to be backwards compatible.

Fixes #35735

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104069
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
parent 4fc47b4156
commit 32d422f335
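A minimal sketch of the usage this change enables (the module and attribute names below are illustrative, not taken from the diff):

```python
import torch
import torch.nn as nn

class Norm(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Buffer registers it, just as register_buffer() would.
        self.running_mean = nn.Buffer(torch.zeros(10))
        # persistent=False keeps the buffer out of the state_dict.
        self.scratch = nn.Buffer(torch.empty(10), persistent=False)

    def forward(self, x):
        return x - self.running_mean

m = Norm()
print([name for name, _ in m.named_buffers()])  # ['running_mean', 'scratch']
print(list(m.state_dict().keys()))              # ['running_mean']
```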
@@ -22,6 +22,7 @@ These are the basic building blocks for graphs:
     :nosignatures:
     :template: classtemplate.rst
 
+    ~parameter.Buffer
     ~parameter.Parameter
     ~parameter.UninitializedParameter
     ~parameter.UninitializedBuffer
@@ -58,7 +58,7 @@ class TestFlattenParams(FSDPTest):
             dim_feedforward=128,
             dropout=0.1,
         )
-        module.register_buffer("dummy_buffer", torch.tensor(1.0))
+        module.dummy_buffer = nn.Buffer(torch.tensor(1.0))
 
         def get_input(device, dtype):
             torch.manual_seed(1)  # keep everything deterministic
@@ -453,11 +453,11 @@ class TestFSDPMiscMultiProcess(FSDPTest):
 
         # Check that `device_id` with `sync_module_states=True` works
         nested_wrapped_module = init_nested_wrapped_module()
-        nested_wrapped_module.register_buffer(
-            "buf", torch.ones((2, 2), device="cpu") * self.rank
+        nested_wrapped_module.buf = nn.Buffer(
+            torch.ones((2, 2), device="cpu") * self.rank
         )
-        nested_wrapped_module.module[0].register_buffer(
-            "buf", torch.ones((3, 2), device="cpu") * self.rank
+        nested_wrapped_module.module[0].buf = nn.Buffer(
+            torch.ones((3, 2), device="cpu") * self.rank
         )
         nested_wrapped_module = FSDP(
             nested_wrapped_module,

@@ -705,7 +705,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
                 torch.manual_seed(rank)
                 torch.cuda.manual_seed(rank)
                 self.lin = nn.Linear(10, 10, bias=False)
-                self.register_buffer("buffer", torch.ones(1) * rank)
+                self.buffer = nn.Buffer(torch.ones(1) * rank)
 
         m = MyModel(self.rank).cuda()
         _assert_module_states(
@@ -87,7 +87,7 @@ class Model(Module):
         super().__init__()
         self.inner = Linear(*INNER_SHAPE)
         if register_buffers:
-            self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
+            self.inner.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
             self.inner.register_buffer(
                 "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
             )

@@ -97,7 +97,7 @@ class Model(Module):
             )
         self.outer = Linear(*OUTER_SHAPE)
         if register_buffers:
-            self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
+            self.outer.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
             self.outer.register_buffer(
                 "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
             )
@@ -423,7 +423,7 @@ class TestUnshardParams(TestUnshardParamsBase):
             CUDAInitMode.CUDA_BEFORE,
             deterministic=True,
         )
-        model.register_buffer("buffer", torch.ones(1))
+        model.buffer = nn.Buffer(torch.ones(1))
         # Wrap the top-level with FSDP since `named_parameters()` and
         # `named_buffers` will contain FSDP prefixes if called on a non-FSDP
         # root module

@@ -436,7 +436,7 @@ class TestUnshardParams(TestUnshardParamsBase):
             ),
             self.process_group,
         )
-        fsdp_model.register_buffer("buffer", torch.ones(1))
+        fsdp_model.buffer = nn.Buffer(torch.ones(1))
         with FSDP.summon_full_params(fsdp_model):
             for call in ["named_parameters", "named_buffers"]:
                 for (n1, p1), (n2, p2) in itertools.zip_longest(
@@ -855,8 +855,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
             torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
             torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
         ).to(self.device)
-        model.register_buffer(
-            "test_buffer",
+        model.test_buffer = torch.nn.Buffer(
             torch.ones((1), device=self.device) * self.rank,
         )
         # Define models/optimizers for DDP with ZeRO and DDP with local
@@ -34,8 +34,8 @@ class TestDataParallel(TestCase):
         class TestModule(nn.Module):
             def __init__(self, t):
                 super().__init__()
-                self.register_buffer('t_rg', t)
-                self.register_buffer('t_not_rg', t.clone().detach())
+                self.t_rg = nn.Buffer(t, t.requires_grad)
+                self.t_not_rg = nn.Buffer(t.clone().detach())
 
             def forward(self, x):
                 return x * self.t_rg + self.t_not_rg
@@ -760,9 +760,8 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
             def __init__(self) -> None:
                 super().__init__()
                 self._param = torch.randn((3,), device="cuda")
-                self.register_buffer(
-                    "_buf", torch.randn((3,), requires_grad=False, device="cuda")
-                )
+                self._buf = torch.nn.Buffer(
+                    torch.randn((3,), requires_grad=False, device="cuda"))
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 # Use `_param` and `_buf` each twice in this compiled forward

@@ -789,8 +788,8 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
         class BufModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.register_buffer(
-                    "_buf", torch.randn((3,), requires_grad=False, device="cuda")
+                self._buf = nn.Buffer(
+                    torch.randn((3,), requires_grad=False, device="cuda")
                 )
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -802,7 +801,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
                 self._param = nn.Parameter(torch.randn((1,), device="cuda"))
                 self._buf_module = BufModule()
                 # Share the buffer, meaning same tensor but different source
-                self.register_buffer("_buf", self._buf_module._buf)
+                self._buf = self._buf_module._buf
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 # Use the same buffer tensor twice in the compiled forward,
@@ -973,7 +973,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
             def __init__(self):
                 super().__init__()
                 self.weight = torch.nn.Parameter(torch.ones(1, 1))
-                self.register_buffer("buffer", torch.ones(1, 1))
+                self.buffer = torch.nn.Buffer(torch.ones(1, 1))
 
             def forward(self, x):
                 x = torch.nn.functional.linear(x, torch.randn(4, 4))

@@ -2668,7 +2668,7 @@ def forward(self, x):
         class Foo(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("buffer1", torch.ones(6, 2))
+                self.buffer1 = torch.nn.Buffer(torch.ones(6, 2))
 
             def forward(self, x):
                 x.add_(2)
@@ -1451,7 +1451,7 @@ class MockModule(torch.nn.Module):
         super().__init__()
         self.relu = torch.nn.ReLU()
         self.linear = torch.nn.Linear(10, 10)
-        self.register_buffer("buf0", torch.randn(10, 10))
+        self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
 
     def forward(self, x):
         return self.relu(self.linear(x) + self.buf0)

@@ -1500,7 +1500,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(10, 10)
-                self.register_buffer("buf0", torch.randn(10, 10))
+                self.buf0 = torch.nn.Buffer(torch.randn(10, 10))
 
             def forward(self, x):
                 return self.r(torch.sin(x)) + self.buf0

@@ -1527,7 +1527,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(10, 10)
-                self.register_buffer("buf0", torch.randn(10, 10))
+                self.register_buffer("buf0", torch.nn.Buffer(torch.randn(10, 10)))
                 self.register_parameter(
                     name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
                 )
@@ -1859,8 +1859,8 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         class MyModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("sorted", torch.ones(4, 4))
-                self.register_buffer("indices", torch.ones(4, 4, dtype=torch.long))
+                self.sorted = torch.nn.Buffer(torch.ones(4, 4))
+                self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long))
 
             def forward(self, x):
                 torch.sort(x, out=(self.sorted, self.indices))

@@ -1891,7 +1891,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         class MyModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("base", torch.ones(4, 4))
+                self.base = torch.nn.Buffer(torch.ones(4, 4))
 
             def forward(self, x):
                 torch.sigmoid(x, out=self.base)

@@ -2174,8 +2174,8 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         class Foo(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("x", torch.ones(3))
-                self.register_buffer("y", torch.ones(3))
+                self.x = torch.nn.Buffer(torch.ones(3))
+                self.y = torch.nn.Buffer(torch.ones(3))
 
             def forward(self, inp):
                 res = 0
@@ -50,7 +50,7 @@ class TestExperimentalExport(TestCase):
         class Foo(torch.nn.Module):
             def __init__(self, float_val):
                 super().__init__()
-                self.register_buffer("buffer1", torch.ones(6, 1))
+                self.buffer1 = torch.nn.Buffer(torch.ones(6, 1))
 
             def forward(self, x):
                 self.buffer1.add_(2)
@@ -178,8 +178,7 @@ class VerifierTest(TestCase):
         class TestModel(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer(
-                    "a",
+                self.a = torch.nn.Buffer(
                     torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
                 )
 
@@ -1945,7 +1945,7 @@ def forward(self, tangents_1):
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("buffer", torch.ones(4, 5))
+                self.buffer = torch.nn.Buffer(torch.ones(4, 5))
 
             def forward(self, x):
                 y = self.buffer.add_(3)

@@ -2140,7 +2140,7 @@ class <lambda>(torch.nn.Module):
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("buffer1", torch.ones(6, 4))
+                self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))
 
             def forward(self, x):
                 x.add_(4)

@@ -2153,7 +2153,7 @@ class <lambda>(torch.nn.Module):
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("buffer1", torch.ones(6, 4))
+                self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))
 
             def forward(self, x, y):
                 y.add_(4)
@@ -3466,8 +3466,8 @@ class TestMakeFunctional(TestCase):
                 super().__init__()
                 self.bias = nn.Parameter(torch.randn(3))
                 self.linear = nn.Linear(3, 3)
-                self.register_buffer('buffer', torch.randn(3))
-                self.register_buffer('buffer_tied', self.buffer)
+                self.buffer = nn.Buffer(torch.randn(3))
+                self.buffer_tied = self.buffer
 
             def forward(self, x):
                 x = self.linear(x)

@@ -3497,7 +3497,7 @@ class TestMakeFunctional(TestCase):
             def __init__(self):
                 super().__init__()
                 self.linear = nn.Linear(3, 3)
-                self.register_buffer('buffer', torch.randn(3))
+                self.buffer = nn.Buffer(torch.randn(3))
 
             def forward(self, x):
                 x = self.linear(x)

@@ -3517,7 +3517,7 @@ class TestMakeFunctional(TestCase):
             def __init__(self):
                 super().__init__()
                 self.linear = nn.Linear(3, 3)
-                self.register_buffer('buffer', torch.randn(3))
+                self.buffer = nn.Buffer(torch.randn(3))
 
             def forward(self, x):
                 x = self.linear(x)

@@ -3573,8 +3573,8 @@ class TestMakeFunctional(TestCase):
                 self.linear = nn.Linear(3, 3)
                 self.weight = self.linear.weight
                 self.bias = self.linear.bias
-                self.register_buffer('buffer', torch.randn(3))
-                self.register_buffer('buffer_tied', self.buffer)
+                self.buffer = nn.Buffer(torch.randn(3))
+                self.buffer_tied = self.buffer
 
             def forward(self, x):
                 x = self.linear(x)
@@ -565,8 +565,7 @@ class CudaReproTests(TestCase):
                 start = math.log2(0.5)
                 end = math.log2(1 / (2**8))
 
-                self.register_buffer(
-                    "scales",
+                self.scales = nn.Buffer(
                     2
                     ** torch.arange(
                         start,
@@ -6184,9 +6184,7 @@ class CommonTemplate:
         class Repro(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer(
-                    "_tensor_constant0", torch.randn([], dtype=torch.float32)
-                )
+                self._tensor_constant0 = nn.Buffer(torch.randn([], dtype=torch.float32))
 
             def forward(self, arg0_1, arg1_1):
                 convert_element_type = torch.ops.prims.convert_element_type.default(
@@ -487,6 +487,7 @@ class TestSaveLoad(JitTestCase):
 
                 self.parameter_b = torch.nn.Parameter(torch.randn(4))
                 self.submodule_b = Submodule()
+                self.buffer_b = torch.nn.Buffer(torch.randn(4))
 
         m = TestModule()
         m_loaded = self.getExportImportCopy(torch.jit.script(m))

@@ -526,7 +527,7 @@ class TestSaveLoad(JitTestCase):
                 super().__init__()
                 self.foo = torch.nn.Linear(2, 3, device="meta")
                 self.bar = torch.nn.Linear(3, 4)
-                self.register_buffer("buffer", torch.randn(4, device="meta"))
+                self.buffer = torch.nn.Buffer(torch.randn(4, device="meta"))
 
             def forward(self, x):
                 x = self.foo(x)

@@ -1150,6 +1151,7 @@ class TestSaveLoadFlatbuffer(JitTestCase):
 
                 self.parameter_b = torch.nn.Parameter(torch.randn(4))
                 self.submodule_b = Submodule()
+                self.buffer_b = torch.nn.Buffer(torch.randn(4))
 
         m = TestModule()
         m_loaded = self.getExportImportCopy(torch.jit.script(m))
@@ -5,7 +5,7 @@ import pickle
 import torch
 import torch.nn as nn
 from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
-from torch.nn import Parameter
+from torch.nn import Buffer, Parameter
 from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings
 from torch.testing._internal.common_cuda import TEST_CUDA
 

@@ -47,29 +47,29 @@ class TestLazyModules(TestCase):
     @suppress_warnings
     def test_lazy_module_buffer(self):
         module = LazyModule()
-        module.register_buffer('test_buffer', UninitializedBuffer())
+        module.test_buffer = UninitializedBuffer()
         self.assertTrue(module.has_uninitialized_params())
         state_dict = module.state_dict()
         self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer)
         new_module = LazyModule()
         # An error is raised when there is an attempt to replace an existing parameter
         # with an uninitialized one
-        new_module.register_buffer('test_buffer', torch.ones(5, 5))
+        new_module.test_buffer = Buffer(torch.ones(5, 5))
         with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
             new_module.load_state_dict(state_dict)
         # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one
         new_module = LazyModule()
-        new_module.register_buffer('test_buffer', torch.ones(5, 5))
+        new_module.test_buffer = Buffer(torch.ones(5, 5))
         module.load_state_dict(new_module.state_dict())
         self.assertEqual(module.test_buffer, torch.ones((5, 5)))
 
         # Uninitialized parameters are left unchanged
         module = LazyModule()
-        module.register_buffer('test_buffer', UninitializedBuffer())
+        module.test_buffer = UninitializedBuffer()
         self.assertTrue(module.has_uninitialized_params())
 
         new_module = LazyModule()
-        new_module.register_buffer('test_buffer', UninitializedBuffer())
+        new_module.test_buffer = UninitializedBuffer()
         module.load_state_dict(new_module.state_dict())
         self.assertTrue(module.has_uninitialized_params())

@@ -85,7 +85,7 @@ class TestLazyModules(TestCase):
     @suppress_warnings
     def test_lazy_module_jit_buffer(self):
         module = LazyModule()
-        module.register_buffer('test_buffer', UninitializedBuffer())
+        module.test_buffer = UninitializedBuffer()
         self.assertTrue(module.has_uninitialized_params())
         with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
             torch.jit.script(module)

@@ -101,7 +101,7 @@ class TestLazyModules(TestCase):
     @suppress_warnings
     def test_lazy_share_memory_buffer(self):
         module = LazyModule()
-        module.register_buffer('test_buffer', UninitializedBuffer())
+        module.test_buffer = UninitializedBuffer()
         self.assertTrue(module.has_uninitialized_params())
         with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
             module.share_memory()
@@ -9,7 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
 import torch.nn.utils.parametrize as parametrize
-from torch.nn import Parameter
+from torch.nn import Buffer, Parameter
 from torch.testing._internal.common_utils import run_tests, skipIfNoLapack, \
     TemporaryFileName, instantiate_parametrized_tests, set_default_dtype
 from torch.testing._internal.common_cuda import TEST_MULTIGPU

@@ -305,7 +305,7 @@ class TestNNParametrization(NNTestCase):
 
         # Instantiate parametrizations on buffers. It should work as expected
         delattr(model, "bias")
-        model.register_buffer("bias", torch.ones(8))
+        model.bias = Buffer(torch.ones(8))
         parametrize.register_parametrization(model, "bias", FirstZero())
         parametrize.register_parametrization(model, "bias", LastZero())
         self.assertTrue(parametrize.is_parametrized(model))

@@ -333,8 +333,8 @@ class TestNNParametrization(NNTestCase):
         class Orthogonal(nn.Module):
             def __init__(self, n):
                 super().__init__()
-                self.register_buffer("id", torch.eye(n))
-                self.register_buffer("B", torch.empty(n, n))
+                self.id = Buffer(torch.eye(n))
+                self.B = Buffer(torch.empty(n, n))
                 init.orthogonal_(self.B)
 
             def forward(self, X):

@@ -396,7 +396,7 @@ class TestNNParametrization(NNTestCase):
         class Orthogonal(nn.Module):
             def __init__(self, n):
                 super().__init__()
-                self.register_buffer("B", torch.eye(n))
+                self.B = Buffer(torch.eye(n))
 
             def forward(self, X):
                 Id = torch.eye(X.size(0))
@@ -297,7 +297,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
                 self,
             ):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 scale_1 = self.weight.reshape(1, -1, 1, 1)

@@ -4214,7 +4214,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         class GatherModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
                 # torch.nn.Embedding is converted to ONNX::Gather.
                 # Constant folding will be triggerred for constant inputs.
                 # This pattern is common for constant mask inputs in transformer models.

@@ -4233,7 +4233,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         class GatherModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(2))
+                self.weight = torch.nn.Buffer(torch.ones(2))
 
             def forward(self, x):
                 # shape is of rank 0

@@ -4248,7 +4248,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         class GatherModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))
+                self.rb = torch.nn.Buffer(torch.randn(1, 1, 3, 1, 1))
 
             def forward(self, x):
                 x += self.rb[0]

@@ -9394,7 +9394,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         class ShapeModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 shape = self.weight.shape[0]

@@ -10895,7 +10895,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
             def __init__(self, embedding_dim):
                 super().__init__()
                 self.weights = InnerModule2.get_embedding(embedding_dim)
-                self.register_buffer("_float_tensor", torch.FloatTensor(1))
+                self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1))
                 self.const = 2
 
             @staticmethod

@@ -10957,7 +10957,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
                 self.embedding_dim = embedding_dim
                 self.const = 2.5
                 self.weights = InnerModule.get_embedding(self.embedding_dim)
-                self.register_buffer("_float_tensor", torch.FloatTensor(1))
+                self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1))
 
             @staticmethod
             def get_embedding(embedding_dim: int):
@@ -540,7 +540,7 @@ class TestUtilityFuns(_BaseTestCase):
                 self,
             ):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 b = self.weight.reshape(1, -1, 1, 1)

@@ -563,7 +563,7 @@ class TestUtilityFuns(_BaseTestCase):
                 self,
             ):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 div = self.weight.div(torch.tensor([1, 2, 3, 4, 5]))

@@ -586,7 +586,7 @@ class TestUtilityFuns(_BaseTestCase):
                 self,
             ):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))

@@ -609,7 +609,7 @@ class TestUtilityFuns(_BaseTestCase):
                 self,
             ):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 add = self.weight + torch.tensor([1, 2, 3, 4, 5])

@@ -640,7 +640,7 @@ class TestUtilityFuns(_BaseTestCase):
                 self,
             ):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 sub = self.weight - torch.tensor([1, 2, 3, 4, 5])

@@ -671,7 +671,7 @@ class TestUtilityFuns(_BaseTestCase):
                 self,
             ):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 sqrt = torch.sqrt(self.weight)

@@ -691,7 +691,7 @@ class TestUtilityFuns(_BaseTestCase):
         class ShapeModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("weight", torch.ones(5))
+                self.weight = torch.nn.Buffer(torch.ones(5))
 
             def forward(self, x):
                 shape = self.weight.shape[0]
@@ -1440,7 +1440,7 @@ class TestQuantizedTensor(TestCase):
                 s = torch.rand(5, dtype=torch.float64) + 0.1
                 zp = torch.randint(5, 15, (5,))
                 x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
-                self.register_buffer('x', x_q)
+                self.x = torch.nn.Buffer(x_q)
 
             @torch.jit.script_method
             def forward(self):
@@ -94,9 +94,9 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
         self.beta = nn.Parameter(torch.empty(out_channels))
         self.affine = True
         self.track_running_stats = True
-        self.register_buffer('running_mean', torch.zeros(out_channels))
-        self.register_buffer('running_var', torch.ones(out_channels))
-        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
+        self.running_mean = nn.Buffer(torch.zeros(out_channels))
+        self.running_var = nn.Buffer(torch.ones(out_channels))
+        self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long))
         self.activation_post_process = self.qconfig.activation()
         self.weight_fake_quant = self.qconfig.weight()
         if bias:
@@ -817,7 +817,7 @@ class TestFX(JitTestCase):
                 self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                 self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                 self.lin = torch.nn.Linear(d_hid, d_hid)
-                self.register_buffer('buffer', torch.randn(bs + 100, d_hid))
+                self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))
 
             def forward(self, x):
                 x = torch.mm(x, self.mm_param)

@@ -2660,7 +2660,7 @@ class TestFX(JitTestCase):
         class GetItemBase(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('pe', torch.randn(8, 8))
+                self.pe = torch.nn.Buffer(torch.randn(8, 8))
 
         class GetItem1(GetItemBase):
             def forward(self, x):

@@ -3026,7 +3026,7 @@ class TestFX(JitTestCase):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(100, 200)
-                self.register_buffer("buf", torch.randn(2, 3))
+                self.buf = torch.nn.Buffer(torch.randn(2, 3))
                 self.net_c = C()
 
             def forward(self, x):

@@ -3196,7 +3196,7 @@ class TestFX(JitTestCase):
             def __init__(self):
                 super().__init__()
                 self.l1 = torch.nn.Linear(1, 1)
-                self.register_buffer('buffer', torch.ones(1))
+                self.buffer = torch.nn.Buffer(torch.ones(1))
 
             def forward(self, x):
                 return self.l1(x) + self.buffer
@@ -1215,8 +1215,8 @@ class {test_classname}(torch.nn.Module):
         self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
         self.linear = torch.nn.Linear(2, 2)
         self.attr = torch.randn(2)
-        self.register_buffer("attr2", torch.randn(2))
-        self.register_buffer("attr3", torch.ones(2, dtype=torch.int32))
+        self.attr2 = torch.nn.Buffer(torch.randn(2))
+        self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))
 
     def forward(self, x):
         return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))
@@ -482,7 +482,7 @@ class TestJit(JitTestCase):
         class MyModule(torch.jit.ScriptModule):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('b0', torch.randn(1, 3))
+                self.b0 = nn.Buffer(torch.randn(1, 3))
                 self.p0 = nn.Parameter(torch.randn(2, 3))
 
             @torch.jit.script_method

@@ -538,7 +538,7 @@ class TestJit(JitTestCase):
                 super().__init__()
                 whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
                 self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
-                self.register_buffer('b0', whole_tensor.narrow(0, 3, 1))
+                self.b0 = nn.Buffer(whole_tensor.narrow(0, 3, 1))
 
         m = Foo()
         m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))

@@ -3989,7 +3989,7 @@ def foo(x):
         a.p = nn.Parameter(torch.rand(3, 4))
         a.foo = nn.Module()
         a.foo.name = 'foo'
-        a.foo.register_buffer('b', torch.rand(1, 1))
+        a.foo.b = nn.Buffer(torch.rand(1, 1))
         a.foo.bar = nn.Module()
         a.foo.bar.name = 'bar'
         a.foo.bar.an_int = 4

@@ -8957,7 +8957,7 @@ dedent """
         class ModuleBufferMutate(torch.jit.ScriptModule):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
+                self.running_var = nn.Buffer(torch.tensor(0, dtype=torch.long))
 
             @torch.jit.script_method
             def forward(self):

@@ -9084,12 +9084,12 @@ dedent """
         def __init__(self):
             super(TestScript.DerivedStateModule, self).__init__()
             self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
-            self.register_buffer('derived', torch.neg(self.param).detach().clone())
+            self.derived = nn.Buffer(torch.neg(self.param).detach().clone())
 
             # This is a flag so we can test that the pack method was called
-            self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
+            self.pack_called = nn.Buffer(torch.zeros(1, dtype=torch.long))
             # This is a flag so we can test that the unpack method was called
-            self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
+            self.unpack_called = nn.Buffer(torch.zeros(1, dtype=torch.long))
 
         @torch.jit.script_method
        def _pack(self):

@@ -9269,7 +9269,7 @@ dedent """
         class SubSubMod(torch.jit.ScriptModule):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('buf', torch.ones(3, 4) * 3)
+                self.buf = nn.Buffer(torch.ones(3, 4) * 3)
 
             @torch.jit.script_method
             def _pack(self):

@@ -9286,7 +9286,7 @@ dedent """
         class SubMod(torch.jit.ScriptModule):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('buf', torch.ones(3, 4) * 2)
+                self.buf = nn.Buffer(torch.ones(3, 4) * 2)
                 self.ssm = SubSubMod()
 
             @torch.jit.script_method

@@ -9305,7 +9305,7 @@ dedent """
             def __init__(self):
                 super().__init__()
                 self.submod = SubMod()
-                self.register_buffer('buf', torch.ones(3, 4) * 1)
+                self.buf = nn.Buffer(torch.ones(3, 4) * 1)
 
             @torch.jit.script_method
             def _pack(self):

@@ -13111,7 +13111,7 @@ dedent """
                 self.out_features = out_features
                 self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
                 self.bias = torch.nn.Parameter(torch.empty(out_features))
-                self.register_buffer('counter', torch.ones(out_features))
+                self.counter = nn.Buffer(torch.ones(out_features))
                 self.reset_parameters()
 
             def reset_parameters(self):

@@ -13164,7 +13164,7 @@ dedent """
                 super().__init__()
                 self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
                 self.bias = torch.nn.Parameter(torch.ones(out_features))
-                self.register_buffer("buffer", torch.ones(out_features))
+                self.buffer = nn.Buffer(torch.ones(out_features))
                 self.submodule = Submodule()
 
             def forward(self, x):

@@ -13619,8 +13619,8 @@ dedent """
 
             def __init__(self, number):
                 super().__init__()
-                self.register_buffer('buffer1', torch.ones(2, 2))
-                self.register_buffer('buffer2', torch.ones(2, 2))
+                self.buffer1 = nn.Buffer(torch.ones(2, 2))
+                self.buffer2 = nn.Buffer(torch.ones(2, 2))
                 self.number = number
 
             @torch.jit.script_method

@@ -13638,8 +13638,8 @@ dedent """
 
             def __init__(self, number, submodule):
                 super().__init__()
-                self.register_buffer('buffer1', torch.ones(2, 2))
-                self.register_buffer('buffer2', torch.ones(2, 2))
+                self.buffer1 = nn.Buffer(torch.ones(2, 2))
+                self.buffer2 = nn.Buffer(torch.ones(2, 2))
                 self.number = number
                 self.submodule = submodule
 

@@ -13675,8 +13675,8 @@ dedent """
         class NoArgState(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('buffer1', torch.ones(2, 2))
-                self.register_buffer('buffer2', torch.ones(2, 2))
+                self.buffer1 = nn.Buffer(torch.ones(2, 2))
+                self.buffer2 = nn.Buffer(torch.ones(2, 2))
 
             def forward(self):
                 pass

@@ -15091,7 +15091,7 @@ dedent """
             def __init__(self):
                 super().__init__()
                 tensor = torch.zeros(1, requires_grad=False)
-                self.register_buffer('some_state', torch.nn.Parameter(tensor))
+                self.some_state = nn.Buffer(torch.nn.Parameter(tensor))
 
             @torch.jit.script_method
             def forward(self, x):

@@ -15484,8 +15484,8 @@ dedent """
                 self.mod = (torch.nn.ReLU())
                 self.mod2 = (torch.nn.ReLU())
                 self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU()))
-                self.register_buffer('x', torch.zeros(3))
-                self.register_buffer('y', torch.zeros(3))
+                self.x = nn.Buffer(torch.zeros(3))
+                self.y = nn.Buffer(torch.zeros(3))
                 self.z = torch.zeros(3)
 
             def bleh(self):
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 import itertools
 from collections import defaultdict
 from torch import inf
-from torch.nn import Parameter
+from torch.nn import Buffer, Parameter
 from torch.testing._internal import opinfo
 from torch.testing._internal.common_utils import \
     (gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, NoTest,

@@ -7643,14 +7643,14 @@ class TestNNMPS(NNTestCase):
             def __init__(self):
                 super().__init__()
                 self.layer_dummy_param = Parameter(torch.empty(3, 5))
-                self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))
+                self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7))
 
         class Net(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.l1 = Layer()
                 self.dummy_param = Parameter(torch.empty(3, 5))
-                self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))
+                self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1))
 
         l = Layer()
         n = Net()
@@ -31,7 +31,7 @@ from torch.nn.utils import clip_grad_norm_, clip_grad_value_
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 from torch.nn.utils.fusion import fuse_linear_bn_weights
-from torch.nn import Parameter
+from torch.nn import Buffer, Parameter
 from torch.nn.parallel._functions import Broadcast
 from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types
 from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \

@@ -365,8 +365,8 @@ class TestNN(NNTestCase):
         class M(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer("buffer1", torch.empty(3, 5))
-                self.register_buffer("buffer2", self.buffer1)
+                self.buffer1 = Buffer(torch.empty(3, 5))
+                self.buffer2 = self.buffer1
 
         m = M()
         self.assertEqual(names(m.named_buffers()),

@@ -425,7 +425,7 @@ class TestNN(NNTestCase):
         linear = nn.Linear(2, 2)
         linear._test_submodule = nn.Linear(2, 2)
         linear._test_parameter = Parameter(torch.empty(2, 2))
-        linear.register_buffer('_test_buffer', torch.empty(2, 2))
+        linear._test_buffer = Buffer(torch.empty(2, 2))
         keys = dir(linear)
         self.assertIn('_test_submodule', keys)
         self.assertIn('_test_parameter', keys)

@@ -530,6 +530,9 @@ class TestNN(NNTestCase):
         with self.assertRaises(KeyError):
             m.register_buffer('attribute_name', torch.rand(5))
 
+        with self.assertRaises(KeyError):
+            m.attribute_name = Buffer(torch.rand(5))
+
         del m.attribute_name
         m.register_parameter('attribute_name', nn.Parameter())
         with self.assertRaises(KeyError):

@@ -556,12 +559,18 @@ class TestNN(NNTestCase):
         self.assertEqual(m.buffer_name, buffer2)
         m.register_buffer('buffer_name', buffer3)
         self.assertEqual(m.buffer_name, buffer3)
+        m.buffer_name = Buffer(buffer1)
+        self.assertEqual(m.buffer_name, Buffer(buffer1))
+        m.buffer_name = Buffer(buffer2)
+        self.assertEqual(m.buffer_name, Buffer(buffer2))
+        m.buffer_name = Buffer(buffer3)
+        self.assertEqual(m.buffer_name, Buffer(buffer3))
 
     def test_get_buffer(self):
         m = nn.Module()
         buffer1 = torch.randn(2, 3)
         buffer2 = torch.randn(4, 5)
-        m.register_buffer('foo', buffer1)
+        m.foo = Buffer(buffer1)
         m.register_buffer('bar', buffer2)
         self.assertEqual(buffer1, m.get_buffer('foo'))
         self.assertEqual(buffer2, m.get_buffer('bar'))

@@ -575,13 +584,13 @@ class TestNN(NNTestCase):
         class Sub(nn.Module):
             def __init__(self, foo, bar):
                 super().__init__()
-                self.register_buffer('foo', foo)
+                self.foo = Buffer(foo)
                 self.subsub = SubSub(bar)
 
         class SubSub(nn.Module):
             def __init__(self, bar):
                 super().__init__()
-                self.register_buffer('bar', bar)
+                self.bar = Buffer(bar)
 
         foo = torch.randn(2, 3)
         bar = torch.randn(4, 5)

@@ -591,33 +600,35 @@ class TestNN(NNTestCase):
 
     def test_buffer_not_persistent(self):
         m = nn.Module()
-        m.register_buffer('buf', torch.rand(5), persistent=False)
+        m.buf = nn.Buffer(torch.rand(5), persistent=False)
         self.assertTrue(len(list(m.buffers())) == 1)
         self.assertTrue(len(m.state_dict()) == 0)
 
     def test_buffer_not_persistent_del(self):
         m = nn.Module()
-        m.register_buffer('buf', torch.rand(5), persistent=False)
+        m.buf = nn.Buffer(torch.rand(5), persistent=False)
         del m.buf
         self.assertTrue(len(list(m.buffers())) == 0)
 
     def test_buffer_not_persistent_overwrite(self):
         m = nn.Module()
-        m.register_buffer('buf', torch.rand(5), persistent=False)
-        m.register_buffer('buf', torch.rand(5))
+        m.buf = nn.Buffer(torch.rand(5), persistent=False)
+        m.buf = nn.Buffer(torch.rand(5))
+
+        # can we overwrite a non-persistent buffer with a persistent one?
         self.assertTrue(len(list(m.buffers())) == 1)
         self.assertTrue(len(m.state_dict()) == 1)
 
-        m.register_buffer('buf', torch.rand(5), persistent=False)
+        # can we overwrite a persistent buffer with a non-persistent one?
+        m.buf = nn.Buffer(torch.rand(5), persistent=False)
         self.assertTrue(len(list(m.buffers())) == 1)
         self.assertTrue(len(m.state_dict()) == 0)
 
     def test_buffer_not_persistent_assign(self):
         m = nn.Module()
-        m.register_buffer('buf', torch.rand(5), persistent=False)
+        m.buf = nn.Buffer(torch.rand(5), persistent=False)
         self.assertTrue(len(list(m.buffers())) == 1)
         self.assertTrue(len(m.state_dict()) == 0)
 
         # Assigning None removes the buffer but if we then assign a new Tensor
         # to the same property, it should still be marked as a buffer.

@@ -659,7 +670,7 @@ class TestNN(NNTestCase):
 
     def test_buffer_not_persistent_load(self):
         m = nn.Module()
-        m.register_buffer('buf', torch.rand(5), persistent=False)
+        m.buf = nn.Buffer(torch.rand(5), persistent=False)
         m.load_state_dict({})
 
     def test_register_parameter_raises_error_if_name_is_not_string(self):

@@ -681,6 +692,11 @@ class TestNN(NNTestCase):
         with self.assertRaises(KeyError):
             m.register_parameter('attribute_name', nn.Parameter())
 
+        del m.attribute_name
+        m.attribute_name = Buffer(torch.rand(5))
+        with self.assertRaises(KeyError):
+            m.register_parameter('attribute_name', nn.Parameter())
+
         del m.attribute_name
         m.add_module('attribute_name', nn.Module())
         with self.assertRaises(KeyError):

@@ -1625,7 +1641,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         net.l = l
         net.l2 = l
         net.add_module('empty', None)
-        net.register_buffer('indices', torch.LongTensor(1))
+        net.indices = Buffer(torch.LongTensor(1))
         net.float()
         self.assertIsInstance(l.weight.data, torch.FloatTensor)
         self.assertIsInstance(l.bias.data, torch.FloatTensor)

@@ -2811,8 +2827,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         del l.a, l.b
         self.assertEqual(list(l.children()), [])
 
-        buf = torch.randn(10)
-        l.register_buffer('buf', buf)
+        buf = Buffer(torch.randn(10))
+        l.buf = buf
         self.assertIs(l.buf, buf)
         l.buf = None
         self.assertIs(l.buf, None)
@@ -18,7 +18,7 @@ class MockModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.l1 = torch.nn.Linear(1, 1)
-        self.register_buffer('buffer', torch.ones(1))
+        self.buffer = torch.nn.Buffer(torch.ones(1))
         self.foo = 0.0
 
     def forward(self, x):

@@ -30,8 +30,8 @@ class MockTiedModule(torch.nn.Module):
         super().__init__()
         self.l1 = torch.nn.Linear(1, 1)
         self.tied_bias = self.l1.bias
-        self.register_buffer('buffer', torch.ones(1))
-        self.register_buffer('tied_buffer', self.buffer)
+        self.buffer = torch.nn.Buffer(torch.ones(1))
+        self.tied_buffer = self.buffer
 
     def forward(self, x):
         return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer

@@ -408,7 +408,7 @@ class TestStatelessFunctionalAPI(TestCase):
     def test_tied_weights_warns(self, functional_call):
         module = MockModule()
         module.tied_bias = module.l1.bias
-        module.register_buffer("tied_buffer", module.buffer)
+        module.tied_buffer = torch.nn.Buffer(module.buffer)
 
     @parametrize("functional_call", [
         subtest(torch.func.functional_call, "torch_func"),

@@ -613,7 +613,7 @@ class TestStatelessFunctionalAPI(TestCase):
         class Foo(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('foo', torch.tensor([0.0]))
+                self.foo = torch.nn.Buffer(torch.tensor([0.0]))
 
             def forward(self, x):
                 self.foo = self.foo + 1

@@ -637,7 +637,7 @@ class TestStatelessFunctionalAPI(TestCase):
         class Foo(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer('foo', torch.tensor([0.0]))
+                self.foo = torch.nn.Buffer(torch.tensor([0.0]))
 
             def forward(self, x):
                 self.foo.add_(1)

@@ -759,7 +759,7 @@ class TestStatelessFunctionalAPI(TestCase):
             def __init__(self):
                 super().__init__()
                 self.l1 = torch.nn.Linear(1, 1)
-                self.register_buffer('buffer', torch.ones(1))
+                self.buffer = torch.nn.Buffer(torch.ones(1))
 
             def forward(self, x):
                 parameters = tuple(self.parameters())
@@ -448,6 +448,7 @@ def istensor(obj):
     """Check of obj is a tensor"""
     tensor_list = (
         torch.Tensor,
+        torch.nn.Buffer,
         torch.nn.Parameter,
         *config.traceable_tensor_subclasses,
     )
@@ -292,7 +292,12 @@ class VariableBuilder:
         # NB: Careful not to close over self to avoid ref cycle from lru_cache
         entries = [
             (
-                (torch.Tensor, torch.nn.Parameter, torch._subclasses.FakeTensor),
+                (
+                    torch.Tensor,
+                    torch.nn.Buffer,
+                    torch.nn.Parameter,
+                    torch._subclasses.FakeTensor,
+                ),
                 cls.wrap_tensor,
             ),
             ((tuple, list, odict_values), cls.wrap_listlike),

@@ -882,6 +887,7 @@ class VariableBuilder:
         else:
             assert type(value) in (
                 torch.Tensor,
+                torch.nn.Buffer,
                 torch.nn.Parameter,
                 torch._subclasses.fake_tensor.FakeTensor,
             ), type(value)

@@ -1463,7 +1469,7 @@ def _automatic_dynamic(e, tx, name, static_shapes):
 def wrap_to_fake_tensor_and_record(
     e, tx, ignore_subclass=False, *, source: Optional[Source], is_tensor: bool
 ):
-    if type(e) in (torch.Tensor, torch.nn.Parameter) or (
+    if type(e) in (torch.Tensor, torch.nn.Buffer, torch.nn.Parameter) or (
         ignore_subclass and isinstance(e, torch.Tensor)
     ):
         assert source is not None
@@ -1426,6 +1426,7 @@ class FakeTensorMode(TorchDispatchMode):
                 not isinstance(x, FakeTensor)
                 and type(x) is not torch.Tensor
                 and type(x) is not torch.nn.Parameter
+                and type(x) is not torch.nn.Buffer
             )
 
         return [
@@ -496,6 +496,7 @@ class MetaConverter:
 
         if (
             type(t) is torch.Tensor
+            or type(t) is torch.nn.Buffer
             or type(t) is torch.nn.Parameter
             or (ignore_subclass and isinstance(t, torch.Tensor))
             or isinstance(t, FakeTensor)

@@ -544,6 +545,9 @@ class MetaConverter:
                     # NB: Cannot directly use Parameter constructor
                     # because that would force a detach, not desirable
                     r._is_param = True
+                elif type(t) is torch.nn.Buffer:
+                    # similar to above
+                    r._is_buffer = True
                 return r
             elif torch.overrides.is_tensor_like(t):
                 # Blindly converting tensor subclasses to meta can cause
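The dynamo, fake-tensor, and meta-converter hunks above only add `torch.nn.Buffer` to existing tensor-type checks. A rough sanity-check sketch of the pattern they keep working (assumes a build where `torch.compile` is functional):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        # A buffer created via assignment, the pattern these hunks teach dynamo to accept.
        self.buf = nn.Buffer(torch.ones(4))

    def forward(self, x):
        return x + self.buf

compiled = torch.compile(M())
print(compiled(torch.zeros(4)))  # tensor([1., 1., 1., 1.])
```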
@@ -387,6 +387,17 @@ def _rebuild_qtensor(
     return tensor
 
 
+def _rebuild_buffer(data, requires_grad, persistent):
+    buffer = torch.nn.Buffer(data, requires_grad, persistent)
+    return buffer
+
+
+def _rebuild_buffer_with_state(data, requires_grad, persistent, state):
+    buffer = torch.nn.Buffer(data, requires_grad, persistent)
+    buffer = _set_obj_state(buffer, state)
+    return buffer
+
+
 def _rebuild_parameter(data, requires_grad, backward_hooks):
     param = torch.nn.Parameter(data, requires_grad)
     # NB: This line exists only for backwards compatibility; the
@@ -238,7 +238,7 @@ def fetch_sym_proxy(tracer):
 def fetch_tensor_proxy(tracer):
     return lambda t: get_proxy_slot(t, tracer, t)
 
-HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
+HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter, torch.nn.Buffer)
 
 def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
     unrecognized_types = []
@@ -1,5 +1,6 @@
 from .modules import *  # noqa: F403
 from .parameter import (
+    Buffer as Buffer,
     Parameter as Parameter,
     UninitializedParameter as UninitializedParameter,
     UninitializedBuffer as UninitializedBuffer,
@@ -5,7 +5,7 @@ import functools
 import weakref
 
 import torch
-from ..parameter import Parameter
+from ..parameter import Parameter, Buffer
 import torch.utils.hooks as hooks
 
 from torch import Tensor, device, dtype

@@ -1745,16 +1745,16 @@ class Module:
                 modules[name] = value
             else:
                 buffers = self.__dict__.get('_buffers')
-                if buffers is not None and name in buffers:
+                if isinstance(value, Buffer) or buffers is not None and name in buffers:
                     if value is not None and not isinstance(value, torch.Tensor):
                         raise TypeError("cannot assign '{}' as buffer '{}' "
-                                        "(torch.Tensor or None expected)"
+                                        "(torch.nn.Buffer, torch.Tensor or None expected)"
                                         .format(torch.typename(value), name))
-                    for hook in _global_buffer_registration_hooks.values():
-                        output = hook(self, name, value)
-                        if output is not None:
-                            value = output
-                    buffers[name] = value
+                    if isinstance(value, Buffer):
+                        persistent = value.persistent
+                    else:
+                        persistent = name not in self._non_persistent_buffers_set
+                    self.register_buffer(name, value, persistent)
                 else:
                     super().__setattr__(name, value)
 
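To make the `__setattr__` change above concrete, a small illustrative sketch (behaviour as described by the tests in this PR; the attribute names are made up):

```python
import torch
import torch.nn as nn

m = nn.Module()

m.plain = torch.ones(3)                              # plain Tensor: stays an ordinary attribute
m.buf = nn.Buffer(torch.ones(3))                     # Buffer: __setattr__ routes it to register_buffer()
m.tmp = nn.Buffer(torch.ones(3), persistent=False)   # registered, but excluded from the state_dict

print([n for n, _ in m.named_buffers()])  # ['buf', 'tmp']
print(list(m.state_dict()))               # ['buf']

m.buf = nn.Buffer(torch.zeros(3), persistent=False)  # re-assignment can also flip persistence
print(list(m.state_dict()))               # []
```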
@@ -196,6 +196,74 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
             memo[id(self)] = result
             return result
 
+
+# Metaclass to combine _TensorMeta and the instance check override for Buffer.
+class _BufferMeta(torch._C._TensorMeta):
+    # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag.
+    def __instancecheck__(self, instance):
+        return isinstance(instance, torch.Tensor) and getattr(instance, '_is_buffer', False)
+
+
+class Buffer(torch.Tensor, metaclass=_BufferMeta):
+    r"""A kind of Tensor that should not be considered a model
+    parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state.
+
+    Buffers are :class:`~torch.Tensor` subclasses that have a
+    very special property when used with :class:`Module` s - when they're
+    assigned as Module attributes they are automatically added to the list of
+    its buffers, and will appear e.g. in the :meth:`~Module.buffers` iterator.
+    Assigning a plain Tensor does not have such an effect; a Tensor can still be
+    registered as a buffer explicitly with the module's :meth:`~Module.register_buffer` method.
+
+    Args:
+        data (Tensor): buffer tensor.
+        requires_grad (bool, optional): if the buffer requires gradient.
+            Default: `False`
+        persistent (bool, optional): whether the buffer is part of the module's
+            :attr:`state_dict`. Default: `True`
+    """
+    def __new__(cls, data=None, requires_grad=False, persistent=True):
+        if data is None:
+            data = torch.empty(0)
+
+        # Path for custom tensors: set a flag on the instance to indicate buffer-ness.
+        t = data.detach().requires_grad_(requires_grad)
+        if type(t) is not type(data) and not isinstance(data, Parameter):
+            raise RuntimeError(f"Creating a Buffer from an instance of type {type(data).__name__} "
+                               "requires that detach() returns an instance of the same type, but return "
+                               f"type {type(t).__name__} was found instead. To use the type as a "
+                               "Buffer, please correct the detach() semantics defined by "
+                               "its __torch_dispatch__() implementation.")
+        t.persistent = persistent
+        t._is_buffer = True
+        return t
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad, self.persistent)
+            memo[id(self)] = result
+            return result
+
+    def __repr__(self):
+        return 'Buffer containing:\n' + super().__repr__()
+
+    def __reduce_ex__(self, proto):
+        state = torch._utils._get_obj_state(self)
+
+        if not state:
+            return (
+                torch._utils._rebuild_buffer,
+                (self.data, self.requires_grad, self.persistent)
+            )
+
+        return (
+            torch._utils._rebuild_buffer_with_state,
+            (self.data, self.requires_grad, self.persistent, state)
+        )
+
+    __torch_function__ = _disabled_torch_function_impl
+
+
 class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
     r"""A buffer that is not initialized.
 

@@ -214,7 +282,10 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
 
     cls_to_become = torch.Tensor
 
-    def __new__(cls, requires_grad=False, device=None, dtype=None) -> None:
+    def __new__(cls, requires_grad=False, device=None, dtype=None, persistent=True) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         data = torch.empty(0, **factory_kwargs)
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        ret = torch.Tensor._make_subclass(cls, data, requires_grad)
+        ret.persistent = persistent
+        ret._is_buffer = True
+        return ret
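A short sketch of the behaviour the `Buffer` class above aims for, based on the `__new__` and `_BufferMeta` code in this hunk (values shown are what that code implies):

```python
import torch
from torch.nn import Buffer

data = torch.randn(3)
b = Buffer(data, persistent=False)

# __new__ detaches `data` and sets the _is_buffer flag; the _BufferMeta metaclass
# makes isinstance() key off that flag rather than the concrete class.
print(isinstance(b, Buffer))  # True
print(b.requires_grad)        # False (the default for buffers)
print(b.persistent)           # False
```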
@@ -26,11 +26,22 @@ class UninitializedParameter(Tensor):
         dtype: Optional[torch.dtype] = None,
     ): ...
 
-class UninitializedBuffer(Tensor):
+class Buffer(Tensor):
+    persistent: builtins.bool
+    def __init__(
+        self,
+        data: Tensor = ...,
+        requires_grad: builtins.bool = ...,
+        persistent: builtins.bool = ...,
+    ): ...
+
+class UninitializedBuffer(Tensor):
+    persistent: builtins.bool
     def __init__(
         self,
         data: Tensor = ...,
         requires_grad: builtins.bool = ...,
+        persistent: builtins.bool = ...,
     ): ...
     def materialize(
         self,
@@ -4736,14 +4736,14 @@ def _create_basic_net():
         def __init__(self):
             super().__init__()
             self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
-            self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))
+            self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7))
 
     class Net(nn.Module):
         def __init__(self):
             super().__init__()
             self.l1 = Layer()
             self.dummy_param = nn.Parameter(torch.empty(3, 5))
-            self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))
+            self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1))
 
     l = Layer()
     n = Net()