mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add integration test for pos_w
Summary: Title Reviewed By: azzolini Differential Revision: D5197307 fbshipit-source-id: 425bf8e7c5068ea544e5b2709b6bb27eef140bf3
This commit is contained in:
parent
df72826ead
commit
c291c97494
|
|
@ -1216,7 +1216,10 @@ class GatherOp : public Operator<Context> {
|
|||
auto idx = idxs[i];
|
||||
CAFFE_ENFORCE(
|
||||
0 <= idx && idx < data.dim(0),
|
||||
"INDICES element is out of DATA bounds");
|
||||
"INDICES element is out of DATA bounds, id=",
|
||||
idx,
|
||||
" data_dim=",
|
||||
data.dim(0));
|
||||
auto src = src_base + idx * block_bytesize;
|
||||
context_.template CopyItems<Context, Context>(
|
||||
data.meta(), block_size, src, out + block_bytesize * i);
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class PositionWeighted(ModelLayer):
|
|||
|
||||
# TODO: Replace this with correct estimation after we compute
|
||||
# cardinality from run_meta
|
||||
self.shape = 1000
|
||||
self.shape = 2000
|
||||
|
||||
self.pos_w = model.net.NextScopedBlob(name + "_pos_w")
|
||||
self.params.append(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user