pytorch/torch/testing/_internal
Joel Schlosser 8ba555ec8a Fix where() for NJT (#141500)
**Background:** It's common to use `scalar_tensor()` in the input to `where()` to convert any scalars present to compatible tensors with matching options, *including layout*. This shows up in various places, notably including derivative formulas ([example](78491d6afc/tools/autograd/derivatives.yaml (L432-L434))). It causes problems for NJTs because they have `layout=torch.jagged` and it never makes sense to create a scalar tensor with this layout. Some of the breakage only seems to happen in CI for reasons I don't fully understand (see the revert of #140736 due to softshrink's derivative formula).
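
For illustration only, here is a minimal sketch of the pattern described above, using ordinary dense (strided) tensors. The helper name `keep_or_zero` is hypothetical and does not come from this PR. The point is that the scalar is built with the input's own options, layout included, which is exactly what goes wrong when the input has `layout=torch.jagged`.

```python
import torch


def keep_or_zero(x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep x where condition holds, else 0."""
    # Build a 0-dim tensor that copies x's dtype, layout, and device.
    # For a dense x this is torch.strided; an NJT would pass torch.jagged
    # here, which is the situation the commit message describes.
    zero = torch.scalar_tensor(
        0.0, dtype=x.dtype, layout=x.layout, device=x.device
    )
    return torch.where(condition, x, zero)


x = torch.tensor([-1.0, 2.0, -3.0])
out = keep_or_zero(x, x > 0)
zero_dim = torch.scalar_tensor(0.0, dtype=x.dtype, layout=x.layout).dim()
```

With a dense input the scalar tensor is 0-dimensional and broadcasts against `x` inside `where()`.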

**This PR:**
* Allows non-contiguous NJT inputs to `where()` + adds tests for this
* Handles scalar tensor / dense tensor inputs for `condition` / `other` + adds tests for this
    * Uses limited `broadcast_tensors()` / `broadcast_to()` support
    * Improves `expand()` to work on non-contig NJTs
* Changes `scalar_tensor()` to use `torch.strided` instead of `torch.jagged` in both eager and torch.compile (i.e. meta registration)
* Changes backward formulas for `sinc`, `pow`, `special.i1`, and `special.i1e` to use `scalar_tensor()` instead of e.g. `zeros({})`

**Alternative approach:** Update all problematic usages of `scalar_tensor()` to avoid ever passing `layout=torch.jagged`. This would be an extensive change, touching `torch.where()` logic, a number of derivative formulas, and likely other places not yet discovered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141500
Approved by: https://github.com/malfet, https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-26 20:13:27 +00:00
codegen
data
distributed [DTensorTestbase] Fix TestFunc typing issue (#141513) 2024-11-26 19:48:34 +00:00
generated
opinfo Fix where() for NJT (#141500) 2024-11-26 20:13:27 +00:00
optests Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)" 2024-11-05 23:10:38 +00:00
test_module
__init__.py
autocast_test_lists.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00
autograd_function_db.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00
check_kernel_launches.py
common_cuda.py Make sure all SDPA tests are ran with tensor cores enabled (#135592) 2024-10-29 20:53:10 +00:00
common_device_type.py General per-SampleInput xfail / skip system (#140443) 2024-11-19 23:09:38 +00:00
common_dist_composable.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00
common_distributed.py Revert "[CI] Reduce distributed test timeout to 60s (#141168)" 2024-11-22 15:46:37 +00:00
common_dtype.py [redo] Fp8 support for item() with cuda, index_select, and fill_ cpu (#137341) 2024-10-07 00:58:51 +00:00
common_fsdp.py Fix type-safety of torch.nn.Module instances (#141240) 2024-11-22 00:05:05 +00:00
common_jit.py
common_methods_invocations.py [BE]: Apply PERF401 autofixes from ruff (#140980) 2024-11-20 17:52:07 +00:00
common_mkldnn.py
common_modules.py [BE] Use torch.log1p(x) instead of torch.log(1+x) (#141167) 2024-11-21 00:36:20 +00:00
common_nn.py
common_optimizers.py Support tensor betas in Adam and AdamW (#134171) 2024-11-15 21:55:55 +00:00
common_pruning.py Fix type-safety of torch.nn.Module instances (#141240) 2024-11-22 00:05:05 +00:00
common_quantization.py [Intel GPU] XPUInductorQuantizer for XPU int8 recipe customization (#139578) 2024-11-26 09:44:14 +00:00
common_quantized.py [BE]: Apply PERF401 autofixes from ruff (#140980) 2024-11-20 17:52:07 +00:00
common_subclass.py Fix wrapper subclass serialization with custom sizes / strides (#137030) 2024-10-02 18:55:03 +00:00
common_utils.py [Intel GPU] qconv at XPU backend (#133080) 2024-11-26 02:24:30 +00:00
composite_compliance.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00
custom_op_db.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00
custom_tensor.py
dist_utils.py
dynamo_test_failures.py
fake_config_module.py Add type annotations to Configs (#139833) 2024-11-07 03:49:09 +00:00
hop_db.py [invoke_subgraph] User facing API to support arbitrary args and kwargs (#139162) 2024-11-08 03:31:19 +00:00
hypothesis_utils.py [BE]: Apply PERF401 autofixes from ruff (#140980) 2024-11-20 17:52:07 +00:00
inductor_utils.py [Inductor UT] Generalize newly introduced inductor UTs for intel GPU (Part 3) (#136947) 2024-10-12 13:21:20 +00:00
jit_metaprogramming_utils.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00
jit_utils.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00
logging_tensor.py
logging_utils.py
quantization_torch_package_models.py
static_module.py
subclasses.py [aotd] coerce_same_metadata_as_tangent with expected_type for e.g.AsyncCollectiveTensor (#139095) 2024-11-07 16:24:48 +00:00
torchbind_impls.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00
triton_utils.py [RFC] Implement caching for user defined triton kernels (#140326) 2024-11-16 02:37:16 +00:00
two_tensor.py Fix tensor subclass + dynamic shapes in torch.compile + aot autograd (#125941) 2024-10-28 21:58:59 +00:00