大概可以參考pytorch裡面torch.nn.functional.gumbel_softmax的實現:

def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):

if eps != 1e-10:
warnings.warn("`eps` parameter is deprecated and has no effect.")

gumbels = -torch.empty_like(logits).exponential_().log() # ~Gumbel(0,1)
gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim)

if hard:
# Straight through.
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
# Reparametrization trick.
ret = y_soft
return ret

我把注釋刪掉了…… 其實可以看到 hard=True 的時候返回的就是one-hot向量。其中y_soft 是採樣出來的概率分佈,y_hard是根據這個概率分佈得到求出來的one-hot向量,detach()這個方法實際上是把一個張量移除計算圖變成常量,這樣反向傳播的時候就不會計算它的梯度。

所以這個東西:

ret = y_hard - y_soft.detach() + y_soft

就是構造了一個數值上等於one-hot向量的張量,但實際上反向傳播的時候梯度是順著y_soft傳回去的。


使用gumbel-max(注意是max不是softmax)能夠等價於對softmax進行採樣,但是還有一個問題就是argmax不可導,因此將argmax替換為softmax,調節係數使得softmax分佈無限趨近於one-hot,以達到近似argmax的效果,選取softmax最大輸出作為action就可以了


可以用如下trick,假設s是gumbel softmax中採樣,stop_gradient(one_hot(argmax (s))-s)+s, 此時前向傳播的時候是one hot (因為上述操作的值等於one_hot(argmax(s))),後向傳播的時候有gradient (gradient 從 s 裏傳回),但注意這個時候gradient是biased,bias多大可以通過gumbel softmax裏的temperature調節(但不可能變為0),實際效果還行。


這圖有些誤導吧

作者自己也說了

we approximate the argmax by a (low-temperature) softmax … We have differentiable sampling operator (albeit with a soft one-hot output instead of a scalar).

就是把softmax的T調低

得到的只是接近one-hot的soft one-hot

不是真的one-hot


我本來是想按照圖上的流程走可以讓Actor網路輸出一個動作值(標量),並且從該動作值出發可以bp求導。後來問了原作者之後,他解釋道:With Gumbel-Softmax sampling, you actually dont get a scalar result. The output of the sampling process (after adding Gumbel noise to the logits then applying a softmax) is a vector, with most of the weight concentrated in the action that would have been chosen by a normal sampling process. 也就是說加完雜訊並且通過softmax之後得到一個概率向量,只能再通過採樣才能得到標量的動作值。所以這樣的話,是從動作值開始bp全程可導是行不通的。


推薦閱讀:
相關文章