首页 > 编程语言 > 详细

tensorflow对多维tensor按照指定索引重排序

时间:2020-01-16 21:43:28      阅读:375      评论:0      收藏:0      [点我收藏+]

背景是这样的,

比如我有一个张量data,shape是(batch_size,100,128)

我还有一个张量inc,shape是(batch_size,100)

我现在想根据这个张量地索引来对data重排序。

为什么会有这样地需求呢,是因为比如data是数据,100代表数据步长,128代表数据内units数目(维度),inc代表一个分数,这个分数表明了这100个步长当中每一步的重要性。现在我想要对data重排序一下,取top10,变成(batch_size,10,128),这样有利于后面的Attention。

操作例子见代码:

最主要的思想就是你有一个N维向量,那么就要指定一个N-1维的索引来对其重排序。例子中我们是一个(batch_size,100,128)的数据,

那么如果:

data是(batch_size,A,B,C,100,128)

inc是(batch_size,A,B,C,100,128)呢?

我的想法是先data reshape成(batch_size*A*B*C,100,128)

inc reshape成(batch_size*A*B*C,100)

后面的操作就一样了,先unstack,分别用gather取出相应切片(其实这里就已经做了个排序)

然后再stack回去

 

import tensorflow as tf
import numpy as np

data = tf.placeholder(tf.int64, [None, 5, 2])

choose = tf.placeholder(tf.int64,[None,5])
sortarg = tf.argsort(choose, direction="DESCENDING")
split_data = tf.unstack(data, num=3, axis=0)
split_choose = tf.unstack(sortarg, num=3, axis=0)
trans_data_list = list()
for i in range(3):
    trans_data_list.append(tf.gather(split_data[i], sortarg[i]))
trans_data = tf.stack(trans_data_list, axis=0)



with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_dict = {
        choose:[[5,4,3,0,1],[2,3,0,4,2],[2,3,5,4,2]],
        data:[[[1,2],[3,4],[5,6],[7,8],[9,10]], [[11,12],[13,14],[15,16],[17,18],[19,20]], [[21,22],[23,24],[25,26],[27,28],[29,30]]]
    }
    print(sess.run(sortarg,feed_dict=feed_dict))
    print("-----------------------------------------------------")
    # print(sess.run(data_trans,feed_dict = feed_dict))
    print(sess.run(data,feed_dict=feed_dict))
    print("-----------------------------------------------------")
    print(sess.run(trans_data, feed_dict=feed_dict))

  

tensorflow对多维tensor按照指定索引重排序

原文:https://www.cnblogs.com/zhouxiaosong/p/12203119.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!