看Char-RNN的代碼時遇到這兩個函數,記錄一下備忘。

先來看一下tf.split()函數,作用是分割tensor,將分割後的tensor放入一個list。

split(
value, # 輸入的tensor
num_or_size_splits, # 如果是個整數n,就將輸入的tensor分為n個子tensor。如果是個list T,就將輸入的tensor分為len(T)個子tensor。
axis=0, # 默認為0,表示在哪個維度進行分割。
num=None,
name=split
)

舉個栗子。

import tensorflow as tf

a = np.array([[1,2,3],
[4,5,6]])
b = tf.split(a, 3, 1)
c = tf.split(a, [1, 2], 1)
with tf.Session() as sess:
print (sess.run(b))
print (sess.run(c))
輸出:
[array([[1],
[4]]), array([[2],
[5]]), array([[3],
[6]])]
[array([[1],
[4]]), array([[2, 3],
[5, 6]])]

我們來看一下上面的結果。可以看到輸出的 b 是將 a 在第二個維度也就是「列」上將數組平均分為3份,而 c 則是將 a 在「列」維度上將 a 分成兩份,每一份的長度對應list裏的數值,此處為[1, 2],注意如果num_or_size_splits為一個數,則要分割的那個維度的大小k一定要能被num_or_size_splits整除,上例k=3,num_or_size_splits=3,可以整除,如果num_or_size_splits換為2,則會報錯。同理如果num_or_size_splits是一個list,則list裏的所有值之和應該等於要分割的那個維度的大小k,上例中1 + 2 = k。

上例中數組a為二維,在高維時同理,這裡舉一個三維的栗子,深度學習中處理三維的情況比較多。

import tensorflow as tf

a = np.reshape(range(24),(4,2,3))
b = tf.split(a,2,1)
d = tf.split(a,3,2)

with tf.Session() as sess:
print (sess.run(b))
print (sess.run(d))
輸出:
[array([[[ 0, 1, 2]],

[[ 6, 7, 8]],

[[12, 13, 14]],

[[18, 19, 20]]]), array([[[ 3, 4, 5]],

[[ 9, 10, 11]],

[[15, 16, 17]],

[[21, 22, 23]]])]
[array([[[ 0],
[ 3]],

[[ 6],
[ 9]],

[[12],
[15]],

[[18],
[21]]]), array([[[ 1],
[ 4]],

[[ 7],
[10]],

[[13],
[16]],

[[19],
[22]]]), array([[[ 2],
[ 5]],

[[ 8],
[11]],

[[14],
[17]],

[[20],
[23]]])]

再來看一下 tf.squeeze()函數,作用是去掉維數為1的維度。

tf.squeeze
squeeze(
input, # 輸入的tensor
axis=None, # 默認為None,去掉維數為1的維度,也可以指定,則去掉指定維度
name=None,
squeeze_dims=None
)

繼續用上例舉慄。

import tensorflow as tf

a = np.reshape(range(24),(4,2,3))
b = tf.split(a,2,1)
c = tf.squeeze(b[0])
print (c.shape)
輸出:
(4, 3)

經過分割後的b是一個list,包含兩個數組元素,每個元素的shape都是4 * 1 * 3,將第一個數組元素經過tf.squeeze()函數處理,可以看到維度變為4 * 3,去掉了維數為1的維度。

參考:

【tensorflow 學習】tf.split()和tf.squeeze()

推薦閱讀:

相關文章