[RELAND] Add metadata coverage for unsafe_split and unsafe_split_with_sizes (#92802)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92802
Approved by: https://github.com/soumith
parent 53ef803705
commit 8f3600b966
@@ -2416,7 +2416,6 @@ symbolic_aot_autograd_failures = {
     xfail('var_mean', 'unbiased'),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('view_as', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('vsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('unsafe_split', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 }

     def _test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args):
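The removed xfail above had marked unsafe_split as an expected failure under AOTAutograd with symbolic shapes, since the op previously hit "Cannot call sizes() on tensor with symbolic sizes/strides"; with the decomposition registrations later in this commit, it traces through Python instead. A minimal sketch of the kind of round trip the test asserts, assuming the functorch.compile entry point; the function, compiler, and shapes here are illustrative, not from the test suite, and the real suite additionally enables symbolic shapes:

```python
import torch
from functorch.compile import aot_function, nop

def fn(x):
    # unsafe_split: the split variant whose outputs autograd does not track as views
    a, b = torch.ops.aten.unsafe_split.Tensor(x, 2, 0)
    return (a * b).sum()

# AOTAutograd traces both the forward and backward graphs ahead of time;
# nop is the identity compiler, running the traced graph unchanged
compiled = aot_function(fn, fw_compiler=nop)
x = torch.randn(4, 3, requires_grad=True)
compiled(x).backward()
```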
@@ -2574,7 +2573,6 @@ aot_autograd_module_failures = set({
 })

 symbolic_aot_autograd_module_failures = {
     torch.nn.GRU,  # Cannot call sizes() on tensor with symbolic sizes/strides
     torch.nn.Transformer,  # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
     torch.nn.TransformerEncoder,  # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
     torch.nn.TransformerEncoderLayer,  # RuntimeError: tried to get Double out of SymFloat
@@ -1351,7 +1351,6 @@ symbolic_tensor_failures = {
     xfail('vsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
     xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition
-    xfail('unsafe_split', ''),  # cannot call sizes() on tensor with symbolic sizes/strides
 }

 symbolic_tensor_segfaults = {
     skip('nn.functional.batch_norm')  # Segfault??
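This hunk drops the matching expected failure from the ProxyTensor tests, which trace operators under symbolic shapes via make_fx. A hedged sketch of the trace that now succeeds; the traced function and input shape are illustrative:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn(x):
    return torch.ops.aten.unsafe_split.Tensor(x, 2, 0)

# tracing_mode="symbolic" swaps concrete sizes for SymInts, so any path
# that calls sizes() on the C++ side raises instead of tracing
gm = make_fx(fn, tracing_mode="symbolic")(torch.randn(6, 3))
print(gm.graph)
```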
@@ -1085,7 +1085,7 @@ def prod(x: List[int]):
     return r


-@register_decomposition(aten.split_with_sizes)
+@register_decomposition([aten.split_with_sizes, aten.unsafe_split_with_sizes])
 def split_with_sizes(
     self: Tensor, split_sizes: List[int], dim: int = 0
 ) -> List[Tensor]:
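register_decomposition accepts either a single operator or a list of them, so the one-line change above makes the existing split_with_sizes body cover the unsafe overload as well. A small illustrative check using the stock lookup helper in torch._decomp (usage sketch only, not part of this commit):

```python
import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten
# both operators now resolve to the same Python decomposition
table = get_decompositions([aten.split_with_sizes, aten.unsafe_split_with_sizes])
print(table)
```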
@@ -1099,7 +1099,7 @@ def split_with_sizes(
     return splits


-@register_decomposition(aten.split.Tensor)
+@register_decomposition([aten.split.Tensor, aten.unsafe_split.Tensor])
 def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
     input_sizes = self.shape
     dim_size = input_sizes[dim]
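The split decomposition above cuts the tensor into ceil(dim_size / split_size) pieces, all of length split_size except a possibly shorter tail. A worked example of that arithmetic, with values chosen purely for illustration:

```python
# split(self, split_size=3, dim=0) on a dimension of size 10
dim_size, split_size = 10, 3
num_splits = (dim_size + split_size - 1) // split_size  # ceil(10 / 3) == 4
sizes = [split_size] * num_splits
sizes[-1] = dim_size - split_size * (num_splits - 1)    # tail piece == 1
assert sizes == [3, 3, 3, 1]
```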
@@ -1462,6 +1462,18 @@ def native_batch_norm_decomposition(
     )


+@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
+def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
+    dim_size = tensor.size(dim)
+    split_size = (dim_size + chunks - 1) // chunks
+
+    if split_size == 0 and dim_size == 0:
+        split_sizes = [split_size for _ in range(chunks)]
+        split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
+        return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
+    return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)
+
+
 @register_decomposition(aten._native_batch_norm_legit.default)
 def _native_batch_norm_legit(
     input: Tensor,
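The new unsafe_chunk py_impl added above derives a single split size by ceiling division and defers to unsafe_split, building explicit per-chunk sizes only for the empty-dimension edge case. A quick numeric check of that equivalence, with values chosen for illustration:

```python
import torch

t = torch.randn(10, 4)
chunks, dim = 3, 0
split_size = (t.size(dim) + chunks - 1) // chunks  # ceil(10 / 3) == 4
out = torch.ops.aten.unsafe_split.Tensor(t, split_size, dim)
assert [o.size(0) for o in out] == [4, 4, 2]
# agrees with the public chunk op
assert [c.size(0) for c in torch.chunk(t, chunks, dim)] == [4, 4, 2]
```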