Before the PR, Dynamo handled HF's ModelOutput class with dicts.py/DataClassVariable, which carried its own implementations of __getitem__, __setattr__, and __setitem__. Since ModelOutput is user code, its logic may change and drift from those re-implementations. After the PR, __getitem__, __setattr__, and __setitem__ are inlined via dicts.py/CustomizedDictVariable, so the traced logic always stays consistent with the user code.

Unit test:
* python test/dynamo/test_model_output.py -k test_HF_bert_model_output

Test on HF benchmark:
* python benchmarks/dynamo/huggingface.py -d cuda --inference --accuracy --progress --inductor --print-dataframe-summary 2>&1
* all metrics are the same before/after the PR, including pass rate, unique_graphs, graph_breaks, unique_graph_breaks
* before the PR: P790393916
* after the PR: P790368991

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105044
Approved by: https://github.com/jansel
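For context, here is a minimal standalone sketch of the ModelOutput behaviour that the test file below exercises. It is not part of the PR or the test file, and it assumes `transformers` is installed; it only illustrates the user-code dunders that Dynamo now inlines instead of re-implementing.

# Illustration only (assumes `transformers` is installed); shows the dual
# dict/attribute behaviour of ModelOutput that Dynamo must trace faithfully.
import torch
from transformers.modeling_outputs import BaseModelOutput

out = BaseModelOutput(last_hidden_state=torch.randn(2, 3))
assert "hidden_states" not in out  # None-valued fields are omitted from the mapping
out["attentions"] = torch.ones(2, 3)  # __setitem__ also sets the attribute
assert out.attentions is out["attentions"]
out.last_hidden_state = out.last_hidden_state + 1  # __setattr__ updates the existing mapping entry
assert out["last_hidden_state"] is out.last_hidden_state
assert out[0] is out.last_hidden_state  # integer __getitem__ goes through to_tuple()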
250 lines
8.2 KiB
Python
# Owner(s): ["module: dynamo"]
import dataclasses
import unittest.mock

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same

try:
    from transformers import modeling_outputs
    from transformers.configuration_utils import PretrainedConfig
    from transformers.file_utils import ModelOutput
    from transformers.modeling_outputs import (
        BaseModelOutput,
        BaseModelOutputWithPastAndCrossAttentions,
        BaseModelOutputWithPoolingAndCrossAttentions,
    )
except ImportError:
    modeling_outputs = None


def maybe_skip(fn):
    if modeling_outputs is None:
        return unittest.skip("requires HuggingFace")(fn)
    return fn


class TestHFPretrained(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_pretrained(self):
        def fn(a, tmp):
            if hasattr(tmp, "somekey"):
                a = a + 1
            if tmp.return_dict:
                return a + torch.ones(2) * tmp.max_length
            return a

        x = torch.randn(2)
        tmp = PretrainedConfig(return_dict=True, max_length=20)
        ref = fn(x, tmp)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(x, tmp)
        self.assertTrue(same(ref, res))


class TestModelOutput(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_mo_create(self):
        def fn(a, b):
            tmp = BaseModelOutput(a + 1, attentions=b + 3)
            return tmp

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)

    @maybe_skip
    def test_mo_assign(self):
        def fn(a, b):
            tmp = BaseModelOutput(last_hidden_state=b + 3)
            tmp.hidden_states = a + 7
            tmp["attentions"] = a + b + 6
            return tmp

        args = [torch.randn(10), torch.randn(10)]
        obj1 = fn(*args)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
        self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
        self.assertTrue(same(obj1.attentions, obj2.attentions))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

    def _common(self, fn, op_count):
        # Shared helper: run fn eagerly and under optimize_assert, then compare
        # the results, frame count, and op count.
        args = [
            BaseModelOutput(
                last_hidden_state=torch.randn(10), attentions=torch.randn(10)
            )
        ]
        obj1 = fn(*args)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, op_count)

    @maybe_skip
    def test_mo_getattr(self):
        def fn(obj: BaseModelOutput):
            x = obj.last_hidden_state * 10
            if obj.hidden_states is not None:
                x += obj.hidden_states
            if obj.attentions is not None:
                x += obj.attentions
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_getitem(self):
        def fn(obj: BaseModelOutput):
            x = obj["last_hidden_state"] * 10
            if "hidden_states" in obj:
                x += obj["hidden_states"]
            if "attentions" in obj:
                x += obj["attentions"]
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_tuple(self):
        def fn(obj: BaseModelOutput):
            a, b = obj.to_tuple()
            return a + b * 10

        self._common(fn, 2)

    @maybe_skip
    def test_mo_index(self):
        def fn(obj: BaseModelOutput):
            return obj[0] * 10 + obj[1]

        self._common(fn, 2)

    @maybe_skip
    def test_mo_init(self):
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            a: torch.Tensor
            b: torch.Tensor = None
            c: torch.Tensor = None
            d: torch.Tensor = None
            e: torch.Tensor = None

        def fn(obj):
            class_fields = dataclasses.fields(obj)
            assert len(class_fields)
            assert all(field.default is None for field in class_fields[1:])
            other_fields_are_none = all(
                getattr(obj, field.name) is None for field in class_fields[1:]
            )
            assert not other_fields_are_none

            total = getattr(obj, class_fields[0].name)
            for field in class_fields[1:]:
                v = getattr(obj, field.name)
                if v is not None:
                    total += v

            return total

        tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
        obj1 = MyDataClass(*tensors)
        correct1 = fn(obj1)

        obj2 = MyDataClass(*tensors)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj2), correct1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @maybe_skip
    def test_HF_bert_model_output(self):
        class BertPooler(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dense = torch.nn.Linear(768, 768).to("cuda")
                self.activation = torch.nn.Tanh()

            def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token.
                first_token_tensor = hidden_states[:, 0]
                pooled_output = self.dense(first_token_tensor)
                pooled_output = self.activation(pooled_output)
                return pooled_output

        class BertEncoder(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(
                self,
                hidden_states: torch.Tensor,
            ) -> BaseModelOutputWithPastAndCrossAttentions:
                return BaseModelOutputWithPastAndCrossAttentions(
                    last_hidden_state=hidden_states,
                    past_key_values=None,
                    hidden_states=None,
                    attentions=None,
                    cross_attentions=None,
                )

        class BertModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.encoder = BertEncoder()
                self.pooler = BertPooler()

            def forward(
                self,
                sequence_output: torch.Tensor,
            ) -> BaseModelOutputWithPoolingAndCrossAttentions:
                encoder_outputs = self.encoder(sequence_output)
                # test __getitem__ and to_tuple
                sequence_output = encoder_outputs[0]
                pooled_output = (
                    self.pooler(sequence_output) if self.pooler is not None else None
                )
                # test CustomDictVariable.create
                result = BaseModelOutputWithPoolingAndCrossAttentions(
                    last_hidden_state=sequence_output,
                    pooler_output=pooled_output,
                    past_key_values=encoder_outputs.past_key_values,
                    hidden_states=encoder_outputs.hidden_states,
                    attentions=encoder_outputs.attentions,
                    cross_attentions=encoder_outputs.cross_attentions,
                )
                # test __setattr__
                result.pooler_output = pooled_output
                # test __setitem__
                result["pooler_output"] = pooled_output
                return result

        sequence_output = torch.rand(1, 12, 768).to("cuda")
        model = BertModel()
        orig_result = model(sequence_output)
        compiled_model = torch.compile(model, backend="eager")
        compiled_result = compiled_model(sequence_output)
        self.assertTrue(
            torch.allclose(
                orig_result.last_hidden_state, compiled_result.last_hidden_state
            )
        )
        self.assertTrue(
            torch.allclose(orig_result.pooler_output, compiled_result.pooler_output)
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()