mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Make counts datatype int. Used as index.
Summary: To avoid Numpy warning: using a non-integer number instead of an integer will result in an error in the future Closes https://github.com/caffe2/caffe2/pull/64 Differential Revision: D4658348 Pulled By: Yangqing fbshipit-source-id: 3a1b33cbb27849bc167b08147d078e8d487567f4
This commit is contained in:
parent
9ef35f4a0b
commit
c61a7ca777
|
|
@ -119,7 +119,7 @@ class SegmentsTester(TesterBase):
|
|||
dtype=data.dtype
|
||||
) for seg_id in range(0, K)
|
||||
]
|
||||
counts = np.zeros(K)
|
||||
counts = np.zeros(K, dtype=int)
|
||||
for i, seg_id in enumerate(segment_ids):
|
||||
data_idx = i if indices is None else indices[i]
|
||||
outputs[seg_id][counts[seg_id]] = data[data_idx]
|
||||
|
|
@ -132,7 +132,7 @@ class SegmentsTester(TesterBase):
|
|||
if len(segment_ids) == 0:
|
||||
return output
|
||||
K = max(segment_ids) + 1
|
||||
counts = np.zeros(K)
|
||||
counts = np.zeros(K, dtype=int)
|
||||
for i, seg_id in enumerate(segment_ids):
|
||||
output[i] = inputs[seg_id][counts[seg_id]]
|
||||
counts[seg_id] += 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user