Update parallel_apply.py for assertion error when len(modules) != len(inputs) (#94671)

Print the result why it is wrong.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94671
Approved by: https://github.com/ngimel, https://github.com/kit1980
This commit is contained in:
Kim,Won-Joong 2023-03-21 17:46:23 +00:00 committed by PyTorch MergeBot
parent a6bbeec2e1
commit c47cf9bc7f

View File

@ -35,7 +35,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(inputs)
assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else: