## @package dot_product
# Module caffe2.python.layers.dot_product

from caffe2.python import schema
from caffe2.python.layers.layers import (
    ModelLayer,
)


class PairwiseSimilarity(ModelLayer):
    """Layer that multiplies two embedding blobs with ``BatchMatMul``.

    The input record must be a ``schema.Struct`` holding exactly one of:

    * ``all_embeddings`` -- the same field is used as both operands, or
    * ``x_embeddings`` and ``y_embeddings`` -- one field per operand.

    An optional ``indices_to_gather`` field, when present, causes the
    flattened product to be passed through ``BatchGather`` with those
    indices before it is written to the output blob.

    Args:
        model: model helper forwarded to ``ModelLayer``.
        input_record: ``schema.Struct`` described above; every embedding
            field (and ``indices_to_gather``) must be a ``schema.Scalar``.
        output_dim: size used for the one-dimensional output shape
            ``(output_dim,)`` recorded in the output schema.
        pairwise_similarity_func: ``'dot'`` (default) or
            ``'cosine_similarity'``. The value is only checked in
            ``add_ops``, which raises ``NotImplementedError`` otherwise.
        name: layer name forwarded to ``ModelLayer``.
        **kwargs: forwarded unchanged to ``ModelLayer.__init__``.
    """

    def __init__(self, model, input_record, output_dim,
                 pairwise_similarity_func='dot', name='pairwise_similarity',
                 **kwargs):
        super().__init__(model, name, input_record, **kwargs)
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}".
            format(input_record))

        # Exactly one of the two input layouts may be supplied.
        has_shared = 'all_embeddings' in input_record
        has_pair = (
            'x_embeddings' in input_record and 'y_embeddings' in input_record
        )
        assert has_shared != has_pair, (
            "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
            "should be given."
        )

        self.pairwise_similarity_func = pairwise_similarity_func

        # With the shared layout both operands come from the same field.
        x_key, y_key = (
            ('all_embeddings', 'all_embeddings') if has_shared
            else ('x_embeddings', 'y_embeddings')
        )
        lhs = input_record[x_key]
        rhs = input_record[y_key]

        assert isinstance(lhs, schema.Scalar), (
            "Incorrect input type for x. Expected Scalar, " +
            "but received: {0}".format(lhs))
        assert isinstance(rhs, schema.Scalar), (
            "Incorrect input type for y. Expected Scalar, " +
            "but received: {0}".format(rhs)
        )

        # Optional gather indices; stays None when the field is absent.
        gather_field = None
        if 'indices_to_gather' in input_record:
            gather_field = input_record['indices_to_gather']
            assert isinstance(gather_field, schema.Scalar), (
                "Incorrect type of indices_to_gather. \n"
                "Expected Scalar, but received: {0}".format(gather_field)
            )
        self.indices_to_gather = gather_field

        self.x_embeddings = lhs
        self.y_embeddings = rhs

        # The output reuses the base dtype of the first operand.
        element_type = lhs.field_types()[0].base
        self.output_schema = schema.Scalar(
            (element_type, (output_dim,)),
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Append the similarity operators for this layer to ``net``.

        For ``'cosine_similarity'`` both operands are first passed through
        ``Normalize(axis=1)``; for ``'dot'`` they are used as-is. The two
        operands are then combined with ``BatchMatMul(trans_b=1)`` and the
        result is flattened. If ``indices_to_gather`` was provided, the
        flattened blob goes through ``BatchGather`` into the output blob;
        otherwise ``Flatten`` writes directly to the output blob.

        Raises:
            NotImplementedError: if ``pairwise_similarity_func`` is neither
                ``'cosine_similarity'`` nor ``'dot'``.
        """
        similarity = self.pairwise_similarity_func
        if similarity == "cosine_similarity":
            left = net.Normalize(self.x_embeddings(), axis=1)
            right = net.Normalize(self.y_embeddings(), axis=1)
        elif similarity == "dot":
            left = self.x_embeddings()
            right = self.y_embeddings()
        else:
            raise NotImplementedError(
                "pairwise_similarity_func={} is not valid".format(
                    self.pairwise_similarity_func
                )
            )

        # The intermediate blob is named after the left-hand operand.
        product = net.BatchMatMul(
            [left, right],
            [self.get_next_blob_reference(left + '_matmul')],
            trans_b=1,
        )

        if not self.indices_to_gather:
            net.Flatten(product, self.output_schema())
            return

        flat = net.Flatten(
            product,
            product + '_flatten',
        )
        net.BatchGather(
            [flat, self.indices_to_gather()],
            self.output_schema(),
        )