import functools

import torch

from .nested_tensor import NestedTensor
from typing import *  # noqa: F403

from torch.fx.operator_schemas import normalize_function

__all__: List[Any] = []

JAGGED_OPS_TABLE: Dict[Any, Any] = {}


def _wrap_jagged_dim(ndim, dim, op_name):
    from torch._prims_common import canonicalize_dims

    wrapped = canonicalize_dims(ndim, dim)
    if wrapped < 2:
        raise RuntimeError(
            f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1"
        )
    return wrapped - 1


def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    named_arg_types = schema_str.split(", ")
    num_optional_args = sum([x.endswith("?") for x in named_arg_types])
    min_args = len(named_arg_types) - num_optional_args

    if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
        raise ValueError(
            f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
            f"arguments and at most {len(named_arg_types)} arguments, but got: "
            f"{len(args)} arguments"
        )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor),
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
        if normalized_arg_type not in arg_type_check_fns.keys():
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        if not arg_type_check_fns[normalized_arg_type](args[i]):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): {name} should be of "
                f"type {arg_type}, but got: {type(args[i])}"
            )


def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    # Calling into .shape here
    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )
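
# Illustrative example (assumed names, not executed): the schema strings used with
# check_schema are comma-separated "name: type" pairs, where type is "t" (dense tensor),
# "jt" (jagged NestedTensor), or "any", and a trailing "?" marks an optional argument.
# For instance, with a jagged NestedTensor `nt` and a dense tensor `w`:
#     check_schema("self: jt, weight: t, bias: t?", some_func, nt, w)
# accepts (nt, w) or (nt, w, bias) but raises if a dense tensor is passed in the "jt"
# slot or if a required argument is missing.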


def register_func(tables, aten_ops, schema_str):
    if not isinstance(aten_ops, list):
        aten_ops = [aten_ops]
    if not isinstance(tables, list):
        tables = [tables]

    def wrapper(func):
        for aten_op in aten_ops:

            def get_inner(aten_op):
                def inner(*args, **kwargs):
                    check_schema(schema_str, func, *args, **kwargs)
                    return func(aten_op, *args, **kwargs)

                return inner

            for table in tables:
                table[aten_op] = get_inner(aten_op)
        # Return the original func so that decorated names stay bound to it
        return func

    return wrapper


register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)


def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
    dispatch_func = JAGGED_OPS_TABLE.get(func, None)
    if dispatch_func is not None:
        return dispatch_func

    # Handle pointwise fallbacks
    if torch.Tag.pointwise in func.tags:
        # Assume there aren't additional tensors that aren't the "unary/binary" args
        num_tensor_args = sum([isinstance(x, torch.Tensor) for x in args])
        if num_tensor_args == 1:
            return functools.partial(jagged_unary_pointwise, func)
        elif num_tensor_args == 2:
            check_schema("lhs: jt, rhs: any", func, *args, **kwargs)
            return functools.partial(jagged_binary_pointwise, func)

    return None


def extract_kwargs(arg):
    kwargs = {
        "offsets": arg.offsets(),
        "ragged_size": arg._size[arg._ragged_idx],
    }
    return kwargs


def jagged_unary_pointwise(func, *args, **kwargs):
    return NestedTensor(
        func(args[0]._values, *args[1:], **kwargs), **extract_kwargs(args[0])
    )


def jagged_binary_pointwise(func, *args, **kwargs):
    a, b = args[0], args[1]
    assert isinstance(a, NestedTensor)

    if isinstance(b, NestedTensor):
        check_ragged_dim_same(func, a, "lhs", b, "rhs")
        b = b._values
    else:
        # TODO: Verify this more and handle the a.dim() == b.dim() case specially if needed
        if a.dim() <= b.dim():
            # need to use offsets to broadcast across batch dim properly
            # NB: inefficient fallback here; Triton codegen can help this
            assert a.shape[0] == b.shape[0]
            outputs = []
            for a_comp, b_comp in zip(a.unbind(), b.unbind()):
                outputs.append(func(a_comp, b_comp, **kwargs))
            new_values = torch.cat(outputs, dim=0)
            return NestedTensor(new_values, **extract_kwargs(a))

    return NestedTensor(func(a._values, b, **kwargs), **extract_kwargs(a))


@register_jagged_func(
    [
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.sym_storage_offset.default,
    ],
    "self: jt",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
        return False

    if func == torch.ops.aten.sym_size.default:
        return args[0]._size

    if func == torch.ops.aten.dim.default:
        return len(args[0]._size)

    if func == torch.ops.aten.sym_numel.default:
        return args[0]._values.numel()

    if func == torch.ops.aten.sym_stride.default:
        return args[0]._strides

    if func == torch.ops.aten.sym_storage_offset.default:
        return 0


@register_jagged_func(torch.ops.prim.layout.default, "self: jt")
def prim_layout_default(func, *args, **kwargs):
    return torch.jagged


@register_jagged_func(
    [
        torch.ops.aten.size.default,
        torch.ops.aten.is_contiguous.default,
        torch.ops.aten.is_contiguous.memory_format,
    ],
    "self: jt, memory_format: any?",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensor does not support directly calling torch.ops.aten.size; "
            "please use `nested_tensor.size()` instead."
        )

    raise RuntimeError(
        "NestedTensors do not support directly querying strides, "
        "storage_offset, or contiguity."
    )
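
# Illustrative note (assumed example, not executed): ops that carry the pointwise tag and
# have no explicit entry in JAGGED_OPS_TABLE are handled by lookup_jagged's fallback above.
# Assuming `nt` is a jagged NestedTensor, torch.ops.aten.abs.default(nt) would route to
# jagged_unary_pointwise (one tensor arg) and torch.ops.aten.add.Tensor(nt, nt) to
# jagged_binary_pointwise (two tensor args). The registrations below cover ops that need
# op-specific handling.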


@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    weight = new_kwargs["weight"]
    bias = new_kwargs["bias"]

    # aten.linear's weight is (out_features, in_features), so multiply by its transpose
    new_values = torch.mm(inp._values, weight.T)
    if bias is not None:
        new_values += bias

    return NestedTensor(new_values, **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.linear_backward.default,
    "self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    grad_output = new_kwargs.pop("grad_output")
    weight = new_kwargs.pop("weight")

    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
    # With weight of shape (out_features, in_features):
    #   d(input)  = grad_output @ weight    -> (total_length, in_features)
    #   d(weight) = grad_output.T @ input   -> (out_features, in_features)
    ds = NestedTensor(
        torch.mm(grad_output._values, weight), **extract_kwargs(grad_output)
    )
    dw = torch.mm(grad_output._values.T, inp._values)
    db = None  # NYI: gradient for bias, need to reduce over ragged dim
    return (ds, dw, db)


@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt")
def to_copy_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # don't change layout
    new_kwargs.pop("layout")

    new_values = func(inp._values, **new_kwargs)
    # NB: Purposefully keep offsets on the old device.
    return NestedTensor(new_values, **extract_kwargs(inp))


register_jagged_func(
    [
        torch.ops.aten.ones_like.default,
        torch.ops.aten.zeros_like.default,
        torch.ops.aten.randn_like.default,
    ],
    "self: jt",
)(jagged_unary_pointwise)


@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
def prod_dim_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    # TODO: Figure out how to handle this better
    # keepdim=True is required to keep the result in jagged format
    if not new_kwargs["keepdim"]:
        raise RuntimeError("prod(): keepdim=True must be set for NestedTensor")

    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp.shape), dim, "prod")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.unbind.int, "self: jt, dim: any?")
def unbind_int(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    dim = new_kwargs["dim"]
    if dim != 0:
        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")

    inp = new_kwargs.pop("input")
    values = inp._values
    offsets = inp.offsets()

    views = []
    start = 0
    for length in offsets.diff().cpu().tolist():
        views.append(values[start : start + length, ...])
        start += length
    return tuple(views)


@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
def unsqueeze_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    values = inp._values

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(inp.shape) + 1, dim, "unsqueeze")

    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))
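
# Illustrative note (assumed shapes, not executed): _wrap_jagged_dim maps a dim on the
# logical (batch, ragged, ...) shape onto the packed values tensor, whose batch and ragged
# dims are collapsed into values dim 0. E.g. for a NestedTensor of logical shape (B, j0, D),
# dim=2 wraps to values dim 1, while dims 0 and 1 are rejected. unsqueeze above and cat
# below rely on this mapping before calling the dense op on `_values`.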


@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
def cat_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    tensors = new_kwargs.pop("tensors")

    # Convert any non-nested to nested
    nested = [t for t in tensors if t.is_nested]
    assert len(nested) > 0
    first = nested[0]
    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]

    # Account for collapsed jagged dim
    dim = new_kwargs["dim"]
    new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat")

    return NestedTensor(
        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
    )


@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
def matmul_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    if (not inp.is_nested) or other.is_nested:
        raise RuntimeError(
            "matmul(): only supported input pattern is (nested, non-nested)"
        )

    return NestedTensor(func(inp._values, other, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(
    torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
)
def expand_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    size = new_kwargs["size"]

    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
    if list(size[:2]) != list(inp.shape[:2]):
        raise RuntimeError("expand(): cannot expand if ragged dims don't match")

    expand_arg = [-1, *size[2:]]
    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
def expand_as_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    return NestedTensor(func(inp, other._values), **extract_kwargs(other))


@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
def where_self(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    condition = new_kwargs.pop("condition")
    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    assert condition.shape == other.shape == inp.shape

    return NestedTensor(
        func(condition._values, inp._values, other._values, **new_kwargs),
        **extract_kwargs(condition),
    )


@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
def _pin_memory_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))


@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
def is_pinned_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    return func(inp._values, **new_kwargs)
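
# Illustrative note (assumed usage, not executed): these handlers are meant to be looked up
# via lookup_jagged(func, *args, **kwargs) from NestedTensor's dispatch path. For example,
# assuming `nt` is a jagged NestedTensor whose values have shape (total_length, D) and `w`
# is a dense (D, E) matrix, torch.matmul(nt, w) should resolve to matmul_default above and
# return a NestedTensor sharing nt's offsets, with values of shape (total_length, E).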