pytorch/caffe2/operators/boolean_unmask_ops_test.cc
Nikita Shulga 4cb534f92e Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os

def get_compiled_files_list():
    """Return the source paths listed in build/compile_commands.json.

    Every path is made relative to the current working directory. Entries
    named ``build/<path>.DEFAULT.cpp`` are mapped back to ``<path>`` so they
    can be matched against ``git ls-files`` output.
    """
    import json

    prefix, suffix = 'build/', '.DEFAULT.cpp'

    def normalize(entry):
        rel = os.path.relpath(entry['file'])
        if rel.startswith(prefix) and rel.endswith(suffix):
            return rel[len(prefix):-len(suffix)]
        return rel

    with open("build/compile_commands.json") as f:
        commands = json.load(f)
    return [normalize(entry) for entry in commands]

def run_clang_tidy(fname):
    """Apply clang-tidy fixes to ``fname`` and commit any resulting changes.

    Runs tools/clang_tidy.py against the ``build`` compile database. If git
    then reports modified tracked files, they are committed with a message
    naming ``fname``; otherwise nothing is committed.
    """
    tidy_cmd = ["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname, "-s"]
    check_call(tidy_cmd)
    modified = check_output(["git", "ls-files", "-m"])
    if modified:
        check_call(["git", "commit", "--all", "-m", f"NOLINT stubs for {fname}"])

def main():
    """Run clang-tidy on every git-tracked file that is part of the build.

    Files under caffe2/contrib/aten/ are skipped. Progress is printed as
    ``[idx/total]``, where ``idx`` is the file's position in the
    ``git ls-files`` listing.
    """
    git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
    # A set keeps the per-file membership test O(1). A list lookup would make
    # this loop O(len(git_files) * len(compiled_files)).
    compiled_files = set(get_compiled_files_list())
    for idx, fname in enumerate(git_files):
        if fname not in compiled_files:
            continue
        if fname.startswith("caffe2/contrib/aten/"):
            continue
        print(f"[{idx}/{len(git_files)}] Processing {fname}")
        run_clang_tidy(fname)


if __name__ == "__main__":
    main()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892

Reviewed By: H-Huang

Differential Revision: D27991944

Pulled By: malfet

fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 14:10:25 -07:00

73 lines
1.7 KiB
C++

#include <iostream>
#include <gtest/gtest.h>
#include "caffe2/core/context.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/operator.h"
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DECLARE_string(caffe_test_root);
namespace caffe2 {

// Creates a CPU tensor blob named `name` in `ws`. The tensor holds either one
// element of type DataT or zero elements.
//
// value:   element stored when the tensor is non-empty (ignored otherwise).
// name:    name of the blob created in the workspace.
// ws:      workspace in which the blob is created.
// isEmpty: when true, the tensor is resized to zero elements instead of one.
template <class DataT>
static void AddScalarInput(
    const DataT& value,
    const string& name,
    Workspace* ws,
    bool isEmpty = false) {
  Blob* blob = ws->CreateBlob(name);
  auto* tensor = BlobGetMutableTensor(blob, CPU);
  if (!isEmpty) {
    tensor->Resize(vector<int64_t>{1});
    *(tensor->template mutable_data<DataT>()) = value;
  } else {
    tensor->Resize(vector<int64_t>{0});
    // The returned pointer is intentionally unused. The call is made for its
    // side effect, presumably assigning DataT as the tensor's element type
    // even though no elements are stored. TODO(review): confirm.
    tensor->template mutable_data<DataT>();
  }
}

// Test case for BooleanUnmask operator
// mask1: [ false ]
// values1: [ ]
// mask2: [ true ]
// values2: [ 1.0 ]
//
// Expected Output: [ 1.0 ]
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(BooleanUnmaskTest, Test) {
  Workspace ws;
  OperatorDef def;
  def.set_name("test");
  def.set_type("BooleanUnmask");
  def.add_input("mask1");
  def.add_input("values1");
  def.add_input("mask2");
  def.add_input("values2");
  def.add_output("unmasked_data");
  AddScalarInput(false, "mask1", &ws);
  AddScalarInput(float(), "values1", &ws, /*isEmpty=*/true);
  AddScalarInput(true, "mask2", &ws);
  AddScalarInput(1.0f, "values2", &ws);

  unique_ptr<OperatorBase> op(CreateOperator(def, &ws));
  // ASSERT rather than EXPECT: the next line dereferences `op`, so a null
  // operator must end the test instead of crashing it.
  ASSERT_NE(nullptr, op.get());
  ASSERT_TRUE(op->Run());

  Blob* unmasked_data_blob = ws.GetBlob("unmasked_data");
  // Fatal for the same reason: the blob is dereferenced immediately below.
  ASSERT_NE(nullptr, unmasked_data_blob);
  const auto& unmasked_data = unmasked_data_blob->Get<TensorCPU>();
  // Element 0 is read below, so a wrong size must stop the test before an
  // out-of-bounds read.
  ASSERT_EQ(unmasked_data.numel(), 1);
  // gtest assertion instead of glog CHECK_EQ: a mismatch is reported as a
  // test failure instead of aborting the whole test binary.
  EXPECT_FLOAT_EQ(unmasked_data.data<float>()[0], 1.0f);
}

} // namespace caffe2