log_softmax: fix meta function output argument dtype check. (#140289)
Tracking issue: #138399

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140289
Approved by: https://github.com/ezyang
ghstack dependencies: #140186, #140286, #140288
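For context, a minimal sketch (not part of the commit) of the behavior being aligned: eager _log_softmax is a structured kernel that requires the out= tensor's dtype to match the result dtype exactly, whereas the Python decomposition previously also accepted any safely-castable dtype. After this change, both reject the mismatch:

import torch

x = torch.randn(4, dtype=torch.float32)
out = torch.empty(4, dtype=torch.float64)  # dtype differs from the float32 result

# Eager raises a RuntimeError for the exact-dtype mismatch; with this fix the
# Python decomposition (which backs the meta function) raises the same way
# instead of cast-copying the result into `out`.
try:
    torch.ops.aten._log_softmax.out(x, 0, False, out=out)
except RuntimeError as e:
    print(e)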
parent 435286e985
commit 48a276c5a0
@@ -168,7 +168,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
     xfail("linalg.solve"),
     xfail("linalg.solve_ex"),
     xfail("linalg.solve_triangular"),
-    xfail("log_softmax"),
     xfail("logcumsumexp"),
     xfail("lu_solve"),
     xfail("lu_unpack"),
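On the test side, removing xfail("log_softmax") from meta_consistency_out_dtype_mismatch_xfails means the out= dtype-mismatch consistency test is now expected to pass for log_softmax: the meta function and the eager kernel agree on rejecting a mismatched out dtype.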
@@ -1220,7 +1220,7 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
 
 
 @register_decomposition(aten._log_softmax)
-@out_wrapper()
+@out_wrapper(exact_dtype=True)
 def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
     # eager log_softmax returns a contiguous tensor. Ensure that decomp also
     # returns a contiguous tensor.
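The decorator change is the whole fix: out_wrapper(exact_dtype=True) makes the wrapper demand an exact dtype match for the out argument instead of allowing a safe cast. A simplified, hypothetical sketch of that distinction (illustrative only; the real logic lives in torch/_prims_common/wrappers.py and is more involved):

import torch

def check_out_dtype(result: torch.Tensor, out: torch.Tensor, exact_dtype: bool) -> None:
    # Hypothetical helper mimicking the out= dtype validation in out_wrapper.
    if exact_dtype:
        # Exact match required, mirroring the eager structured kernel.
        if result.dtype != out.dtype:
            raise RuntimeError(
                f"Expected out tensor to have dtype {result.dtype}, "
                f"but got {out.dtype} instead"
            )
    elif not torch.can_cast(result.dtype, out.dtype):
        # Default behavior: any safely-castable out dtype is accepted.
        raise RuntimeError(
            f"result type {result.dtype} can't be cast to the desired "
            f"output type {out.dtype}"
        )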