首页 > 其他 > 详细

有监督对比损失TensorFlow版本

时间:2020-12-05 22:52:15      阅读:134      评论:0      收藏:0      [点我收藏+]

 这里给出论文《Supervised Contrastive Learning》(SupContrast)中损失函数的TensorFlow版本,代码改自:https://github.com/thecharm/boundary-aware-nested-ner

损失文件losses.py

"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function


import tensorflow as tf

class SupConLoss(object):
    """Supervised Contrastive Learning loss: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss in SimCLR, which is
    used when neither ``labels`` nor ``mask`` is given to ``forward``.

    Args:
        temperature: divisor applied to the anchor/contrast dot products.
        contrast_mode: ``'all'`` uses every view of every sample as an anchor;
            ``'one'`` uses only the first view (``features[:, 0]``).
        base_temperature: the per-anchor loss is scaled by
            ``temperature / base_temperature``.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute the loss for a batch of multi-view features.

        If both `labels` and `mask` are None, it degenerates to the SimCLR
        unsupervised loss: https://arxiv.org/pdf/2002.05709.pdf

        This method does not normalize `features`; pass L2-normalized
        embeddings if that is what the loss should see.

        Args:
            features: hidden vector of shape [bsz, n_views, ...]. Trailing
                dimensions are flattened. n_views must be statically known.
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if `features` has fewer than 3 dims, n_views is not
                static, both `labels` and `mask` are given, or
                `contrast_mode` is unknown.
        """
        sizes = features.get_shape().as_list()
        if len(sizes) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        # n_views is needed as a Python int for tf.unstack and tf.tile below.
        contrast_count = sizes[1]
        if contrast_count is None:
            raise ValueError('`features` must have a statically known '
                             'n_views (dimension 1)')
        if len(sizes) > 3:
            # Flatten trailing dims. Passing the static n_views (instead of
            # tf.shape(features)[1]) keeps axis 1's static size intact.
            features = tf.reshape(
                features, [tf.shape(features)[0], contrast_count, -1])

        # Build every mask in the features' dtype so the products below type-check.
        dtype = features.dtype
        batch_size = tf.shape(features)[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = tf.eye(batch_size, dtype=dtype)
        elif labels is not None:
            labels = tf.reshape(labels, [-1, 1])
            mask = tf.cast(tf.equal(labels, tf.transpose(labels, [1, 0])),
                           dtype=dtype)
        else:
            mask = tf.cast(mask, dtype=dtype)

        # Stack the views along the batch axis (view-major order):
        # [bsz * n_views, dim].
        contrast_feature = tf.concat(
            tf.unstack(features, num=contrast_count, axis=1), axis=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits: [bsz * anchor_count, bsz * contrast_count]
        anchor_dot_contrast = tf.matmul(
            anchor_feature, contrast_feature, transpose_b=True) / self.temperature
        # Subtract the row max for numerical stability; it cancels out of the
        # log-softmax, so no gradient should flow through it.
        logits_max = tf.reduce_max(anchor_dot_contrast, axis=1, keepdims=True)
        logits = anchor_dot_contrast - tf.stop_gradient(logits_max)

        # Tile the [bsz, bsz] mask to the logits shape.
        mask = tf.tile(mask, [anchor_count, contrast_count])

        # Mask out self-contrast cases. Anchor row i is the same feature as
        # contrast column i, so zero the (possibly non-square) main diagonal.
        # The column count must be bsz * contrast_count, not bsz * anchor_count,
        # otherwise 'one' mode gets a shape mismatch.
        logits_mask = tf.ones_like(mask) - tf.eye(
            batch_size * anchor_count, batch_size * contrast_count, dtype=dtype)
        mask = mask * logits_mask

        # compute log_prob (softmax denominator excludes the anchor itself)
        exp_logits = tf.exp(logits) * logits_mask
        log_prob = logits - tf.math.log(
            tf.reduce_sum(exp_logits, axis=1, keepdims=True))

        # Mean log-likelihood over positives. An anchor with no positive pair
        # would divide by zero (NaN), so its denominator is clamped to 1 and
        # its contribution becomes 0.
        mask_pos_pairs = tf.reduce_sum(mask, axis=1)
        mask_pos_pairs = tf.where(mask_pos_pairs < 1e-6,
                                  tf.ones_like(mask_pos_pairs),
                                  mask_pos_pairs)
        mean_log_prob_pos = tf.reduce_sum(mask * log_prob, axis=1) / mask_pos_pairs

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = tf.reduce_mean(tf.reshape(loss, [anchor_count, batch_size]))
        return loss

 

测试:

import os
# Select the GPU before TensorFlow is imported/initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import tensorflow as tf
import losses

loss = losses.SupConLoss()

# 10 samples, 2 views each, 5-dimensional features.
X = tf.random_uniform([10, 2, 5])

# Random binary class labels for the 10 samples.
y = tf.random_uniform([10], minval=0, maxval=2, dtype=tf.int32)

# The context manager closes the session (and frees its resources) on exit.
with tf.Session() as sess:
    print(sess.run(loss.forward(X, y)))

 输出(输入为随机生成,每次运行结果会不同):8.23587

 

有监督对比损失TensorFlow版本

原文:https://www.cnblogs.com/huadongw/p/14091240.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!