看Char-RNN的代码时遇到这两个函数,记录一下备忘。

先来看一下tf.split()函数,作用是分割tensor,将分割后的tensor放入一个list。

split(
value, # 输入的tensor
num_or_size_splits, # 如果是个整数n,就将输入的tensor分为n个子tensor。如果是个list T,就将输入的tensor分为len(T)个子tensor。
axis=0, # 默认为0,表示在哪个维度进行分割。
num=None,
name=split
)

举个栗子。

import tensorflow as tf

a = np.array([[1,2,3],
[4,5,6]])
b = tf.split(a, 3, 1)
c = tf.split(a, [1, 2], 1)
with tf.Session() as sess:
print (sess.run(b))
print (sess.run(c))
输出:
[array([[1],
[4]]), array([[2],
[5]]), array([[3],
[6]])]
[array([[1],
[4]]), array([[2, 3],
[5, 6]])]

我们来看一下上面的结果。可以看到输出的 b 是将 a 在第二个维度也就是「列」上将数组平均分为3份,而 c 则是将 a 在「列」维度上将 a 分成两份,每一份的长度对应list里的数值,此处为[1, 2],注意如果num_or_size_splits为一个数,则要分割的那个维度的大小k一定要能被num_or_size_splits整除,上例k=3,num_or_size_splits=3,可以整除,如果num_or_size_splits换为2,则会报错。同理如果num_or_size_splits是一个list,则list里的所有值之和应该等于要分割的那个维度的大小k,上例中1 + 2 = k。

上例中数组a为二维,在高维时同理,这里举一个三维的栗子,深度学习中处理三维的情况比较多。

import tensorflow as tf

a = np.reshape(range(24),(4,2,3))
b = tf.split(a,2,1)
d = tf.split(a,3,2)

with tf.Session() as sess:
print (sess.run(b))
print (sess.run(d))
输出:
[array([[[ 0, 1, 2]],

[[ 6, 7, 8]],

[[12, 13, 14]],

[[18, 19, 20]]]), array([[[ 3, 4, 5]],

[[ 9, 10, 11]],

[[15, 16, 17]],

[[21, 22, 23]]])]
[array([[[ 0],
[ 3]],

[[ 6],
[ 9]],

[[12],
[15]],

[[18],
[21]]]), array([[[ 1],
[ 4]],

[[ 7],
[10]],

[[13],
[16]],

[[19],
[22]]]), array([[[ 2],
[ 5]],

[[ 8],
[11]],

[[14],
[17]],

[[20],
[23]]])]

再来看一下 tf.squeeze()函数,作用是去掉维数为1的维度。

tf.squeeze
squeeze(
input, # 输入的tensor
axis=None, # 默认为None,去掉维数为1的维度,也可以指定,则去掉指定维度
name=None,
squeeze_dims=None
)

继续用上例举栗。

import tensorflow as tf

a = np.reshape(range(24),(4,2,3))
b = tf.split(a,2,1)
c = tf.squeeze(b[0])
print (c.shape)
输出:
(4, 3)

经过分割后的b是一个list,包含两个数组元素,每个元素的shape都是4 * 1 * 3,将第一个数组元素经过tf.squeeze()函数处理,可以看到维度变为4 * 3,去掉了维数为1的维度。

参考:

【tensorflow 学习】tf.split()和tf.squeeze()

推荐阅读:

相关文章