import functools

import torch

from .nested_tensor import NestedTensor

from typing import *  # noqa: F403

__all__: List[Any] = []

JAGGED_OPS_TABLE: Dict[Any, Any] = {}


def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    named_arg_types = schema_str.split(", ")
    num_optional_args = sum([x.endswith("?") for x in named_arg_types])
    min_args = len(named_arg_types) - num_optional_args

    if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
        raise ValueError(
            f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
            f"arguments and at most {len(named_arg_types)} arguments, but got: "
            f"{len(args)} arguments"
        )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor),
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
        if normalized_arg_type not in arg_type_check_fns.keys():
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        # Optional arguments may legitimately be passed as None; skip the type check.
        if is_optional and args[i] is None:
            continue

        if not arg_type_check_fns[normalized_arg_type](args[i]):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): {name} should be of "
                f"type {arg_type}, but got: {type(args[i])}"
            )


def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    # Compare the cached ragged dim (_size[1]) directly rather than calling into .shape.
    assert len(a._size) == 3, "NestedTensor must be [B, *, D]"
    if a._size[1] != b._size[1]:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )


def register_func(tables, aten_ops, schema_str):
    if not isinstance(aten_ops, list):
        aten_ops = [aten_ops]
    if not isinstance(tables, list):
        tables = [tables]

    def wrapper(func):
        for aten_op in aten_ops:

            def get_inner(aten_op):
                def inner(*args, **kwargs):
                    check_schema(schema_str, func, *args, **kwargs)
                    return func(aten_op, *args, **kwargs)

                return inner

            for table in tables:
                table[aten_op] = get_inner(aten_op)
        # Return the decorated function unchanged so the module-level name stays usable.
        return func

    return wrapper


register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)


def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
    if torch.Tag.pointwise in func.tags:
        # Assume there aren't additional tensors that aren't the "unary/binary" args
        num_tensor_args = sum([isinstance(x, torch.Tensor) for x in args])
        if num_tensor_args == 1:
            return functools.partial(jagged_unary_pointwise, func)
        elif num_tensor_args == 2:
            check_schema("lhs: jt, rhs: jt", func, *args, **kwargs)
            return functools.partial(jagged_binary_pointwise, func)
        else:
            return None
    return JAGGED_OPS_TABLE.get(func, None)


def extract_kwargs(arg):
    kwargs = {
        "offsets": arg.offsets(),
    }
    return kwargs


def jagged_unary_pointwise(func, *args, **kwargs):
    return NestedTensor(func(args[0].values(), **kwargs), **extract_kwargs(args[0]))


def jagged_binary_pointwise(func, *args, **kwargs):
    check_ragged_dim_same(func, args[0], "lhs", args[1], "rhs")
    return NestedTensor(
        func(args[0].values(), args[1].values(), **kwargs), **extract_kwargs(args[0])
    )


@register_jagged_func(
    [
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.sym_numel.default,
    ],
    "self: jt",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
        return False

    if func == torch.ops.aten.sym_size.default:
        return args[0]._size

    if func == torch.ops.aten.dim.default:
        return 3

    if func == torch.ops.aten.sym_numel.default:
        return args[0].values().numel()


@register_jagged_func(
    [
        torch.ops.aten.size.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.is_contiguous.default,
        torch.ops.aten.is_contiguous.memory_format,
        torch.ops.aten.sym_storage_offset.default,
    ],
    "self: jt, memory_format: any?",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensors do not support directly calling torch.ops.aten.size; "
            "please use `nested_tensor.size()` instead."
        )

    raise RuntimeError(
        "NestedTensors do not support directly querying strides, "
        "storage_offset, or contiguity."
    )


@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
    values = torch.mm(args[0].values(), args[1])
    # bias is optional and may be passed as None.
    if len(args) == 3 and args[2] is not None:
        values += args[2]
    return NestedTensor(values, **extract_kwargs(args[0]))


@register_jagged_func(
    torch.ops.aten.linear_backward.default,
    "self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
    check_ragged_dim_same(func, args[0], "self", args[1], "grad_output")
    ds = NestedTensor(torch.mm(args[1].values(), args[2].T), **extract_kwargs(args[1]))
    dw = torch.mm(args[0].values().T, args[1].values())
    db = None  # NYI: gradient for bias, need to reduce over ragged dim
    return (ds, dw, db)
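

# Illustrative sketch only (not part of the original module): a quick check that
# overloads resolve through the jagged dispatch machinery above. Explicitly
# registered overloads come from JAGGED_OPS_TABLE, while pointwise-tagged ops
# fall back to the generic unary/binary pointwise handlers. Because this module
# uses a relative import, run it as a module (python -m <package>.ops), not as
# a standalone script.
if __name__ == "__main__":
    x = torch.randn(6, 4)
    # aten.linear.default was registered explicitly via @register_jagged_func.
    assert lookup_jagged(torch.ops.aten.linear.default, x, torch.randn(4, 4)) is not None
    # aten.sin.default carries the pointwise tag and has a single tensor arg,
    # so it resolves to the generic unary pointwise handler.
    assert lookup_jagged(torch.ops.aten.sin.default, x) is not None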