#pragma once #include #include #include struct DeepAndWide : torch::nn::Module { DeepAndWide(int num_features = 50) { mu_ = register_parameter("mu_", torch::randn({1, num_features})); sigma_ = register_parameter("sigma_", torch::randn({1, num_features})); fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1})); fc_b_ = register_parameter("fc_b_", torch::randn({1})); } torch::Tensor forward( torch::Tensor ad_emb_packed, torch::Tensor user_emb, torch::Tensor wide) { auto wide_offset = wide + mu_; auto wide_normalized = wide_offset * sigma_; auto wide_noNaN = wide_normalized; // Placeholder for ReplaceNaN auto wide_preproc = torch::clamp(wide_noNaN, -10.0, 10.0); auto user_emb_t = torch::transpose(user_emb, 1, 2); auto dp_unflatten = torch::bmm(ad_emb_packed, user_emb_t); auto dp = torch::flatten(dp_unflatten, 1); auto input = torch::cat({dp, wide_preproc}, 1); auto fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_); auto pred = torch::sigmoid(fc1); return pred; } torch::Tensor mu_, sigma_, fc_w_, fc_b_; }; // Implementation using native functions and pre-allocated tensors. // It could be used as a "speed of light" for static runtime. struct DeepAndWideFast : torch::nn::Module { DeepAndWideFast(int num_features = 50) { mu_ = register_parameter("mu_", torch::randn({1, num_features})); sigma_ = register_parameter("sigma_", torch::randn({1, num_features})); fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1})); fc_b_ = register_parameter("fc_b_", torch::randn({1})); allocated = false; prealloc_tensors = {}; } torch::Tensor forward( torch::Tensor ad_emb_packed, torch::Tensor user_emb, torch::Tensor wide) { torch::NoGradGuard no_grad; if (!allocated) { auto wide_offset = at::add(wide, mu_); auto wide_normalized = at::mul(wide_offset, sigma_); // Placeholder for ReplaceNaN auto wide_preproc = at::cpu::clamp(wide_normalized, -10.0, 10.0); auto user_emb_t = at::native::transpose(user_emb, 1, 2); auto dp_unflatten = at::cpu::bmm(ad_emb_packed, user_emb_t); // auto dp = at::native::flatten(dp_unflatten, 1); auto dp = dp_unflatten.view({dp_unflatten.size(0), 1}); auto input = at::cpu::cat({dp, wide_preproc}, 1); // fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_); fc_w_t_ = torch::t(fc_w_); auto fc1 = torch::addmm(fc_b_, input, fc_w_t_); auto pred = at::cpu::sigmoid(fc1); prealloc_tensors = { wide_offset, wide_normalized, wide_preproc, user_emb_t, dp_unflatten, dp, input, fc1, pred}; allocated = true; return pred; } else { // Potential optimization: add and mul could be fused together (e.g. with // Eigen). at::add_out(prealloc_tensors[0], wide, mu_); at::mul_out(prealloc_tensors[1], prealloc_tensors[0], sigma_); at::native::clip_out( prealloc_tensors[1], -10.0, 10.0, prealloc_tensors[2]); // Potential optimization: original tensor could be pre-transposed. // prealloc_tensors[3] = at::native::transpose(user_emb, 1, 2); if (prealloc_tensors[3].data_ptr() != user_emb.data_ptr()) { auto sizes = user_emb.sizes(); auto strides = user_emb.strides(); prealloc_tensors[3].set_( user_emb.storage(), 0, {sizes[0], sizes[2], sizes[1]}, {strides[0], strides[2], strides[1]}); } // Potential optimization: call MKLDNN directly. at::cpu::bmm_out(ad_emb_packed, prealloc_tensors[3], prealloc_tensors[4]); if (prealloc_tensors[5].data_ptr() != prealloc_tensors[4].data_ptr()) { // in unlikely case that the input tensor changed we need to // reinitialize the view prealloc_tensors[5] = prealloc_tensors[4].view({prealloc_tensors[4].size(0), 1}); } // Potential optimization: we can replace cat with carefully constructed // tensor views on the output that are passed to the _out ops above. at::cpu::cat_outf( {prealloc_tensors[5], prealloc_tensors[2]}, 1, prealloc_tensors[6]); at::cpu::addmm_out( prealloc_tensors[7], fc_b_, prealloc_tensors[6], fc_w_t_, 1, 1); at::cpu::sigmoid_out(prealloc_tensors[7], prealloc_tensors[8]); return prealloc_tensors[8]; } } torch::Tensor mu_, sigma_, fc_w_, fc_b_, fc_w_t_; std::vector prealloc_tensors; bool allocated = false; }; torch::jit::Module getDeepAndWideSciptModel(int num_features = 50); torch::jit::Module getTrivialScriptModel(); torch::jit::Module getLeakyReLUScriptModel(); torch::jit::Module getLeakyReLUConstScriptModel(); torch::jit::Module getLongScriptModel(); torch::jit::Module getSignedLog1pModel();