轉載請註明出處

TL;DR

(0). 為什麼需要狀態容器:

因為TensorFlow 2.0以後默認打開動態圖運算,因此tf.Varibable(以及其它狀態類型)如果不放在一個容器中會被python解釋器回收掉,無法求導以及更新,這種容器稱之為狀態容器(stateful container)。TensorFlow 2.0中狀態容器的繼承結構圖為:

Trackable
|
|-- tf.Variable
|
|-- MutableHashTable
|
|-- AutoTrackable
|
|-- ListWrapper/DictWrapper
|
|-- tf.train.Checkpoint
|
|-- tf.Module
|
|-- tf.keras.layers.Layer
|
|-- tf.keras.Model
|
|-- tf.keras.Sequential

(1). Trackable:

Trackablefrom tensorflow.python.training.tracking.base import Trackable)是所有的狀態容器的基類,其定義了一個self._track_trackable(self, name: Str, value: Trackable) -> None方法,來使得value成為self名為name的依賴。最終Trackable的實例之間的依賴關係構成一個有向無環圖(DAG),只要根節點不被解釋器回收,則所有的依賴都會存在。

(2). AutoTrackable:

AutoTrackablefrom tensorflow.python.training.tracking.tracking import AutoTrackable),在__setattr__中運行_track_trackable,執行a.b = c時自動的將c設成a名為b的依賴(這也是叫AutoTrackable的原因),避免顯式調用_track_trackable帶來的無聊和繁瑣。

(3). Trackable vs AutoTrackable

只能成為別人的依賴而不依賴於別人的的對象,繼承Trackable即可,而不用AutoTrackable,比如tf.VariableMutableHashTable。這些類型通常是可以被保存的對象。

(4). tf.train.Checkpoint

基於對象的儲存(object-based save)的類型,地位等同於1.xtf.train.Saver。繼承自AutoTrackable。該類型在儲存時會遍歷Trackable組成的DAG儲存和載入變數(或stateful的組件)。

(5). tf.Module:

定位為一個輕量級的狀態容器,繼承自AutoTrackable,是所有的public api中的狀態容器中最底層的一個,最底層的能夠收集.variables.submodules屬性的一個,因為可以收集變數,所以這個類型可以用來建模,配合tf.GradientTape使用。

(6). tf.keras.layers.Layer:

Layerkeras中最底層的類型,繼承自tf.Module,相比於其父類,Layer開始有了建模的各種特性和功能,包括惰性構建機制(lazy build)統一的調用介面。大量已經被定義好的tf.keras.layers.Layer的子類可以用來快速的構建模型,如全連接層tf.keras.layers.Dense,卷積層tf.keras.layers.Conv2D和遞歸層tf.keras.layers.RNN等等。

(7). tf.keras.Model:

(間接)繼承自tf.keras.layers.Layer,是keras中建模的核心類型,有如下特性:

  • 兩種形式的構造函數(functional和subclassed)
  • 定義了compilefit方法,結合tf.keras.callbacks,提供一站式訓練服務

(8). tf.keras.Sequential

對於輸入的張量,挨個作用裡麪包含的tf.keras.layers.Layer,適合實現VGG16這樣一條路走到黑的(子)模型。

(9). 這麼多狀態容器,選擇繼承哪個?

  • 僅在學習和深入研究狀態容器(或基於對象的儲存)時使用TrackableAutoTrackable
  • tf.Module: 適合自定義訓練循環時使用
  • tf.keras.layers.Layer:適合實現一些中間層,比如Attention之類的,可以配合tf.keras.Sequential使用,極少看見大的模型繼承自這個類型。
  • tf.keras.Model:適合一些固定套路的模型(使用compile + fit)。雖然也可以自定義訓練循環,但是有一種殺雞用牛刀的感覺。
  • tf.keras.Sequential:適合一條路走到黑的(子)模型。

寫在最前

  • 所有內容基於個人的使用和源代碼的閱讀,屬於個人總結,難免會有錯誤遺漏之處以及個人狹隘的觀點,不足之處還請大家多多指教。
  • 完全理解本文的內容需要掌握很多python和tensorflow的知識點如python的垃圾回收機制和魔法方法,tensorflow的運算圖原理和序列化格式SavedModel等等。
  • 運行代碼段需要2.0.beta1版本,其餘版本沒有測試過(1.x肯定不行,2.x可能可以),直接運行pip install tensorflow==2.0.beta1即可(2.7或3.x版本都行)。除非顯式地註明瞭代碼段的運行版本,否則默認一概使用2.0.beta1(有無GPU均可)。 部分代碼需要安裝pytorch
  • 轉載請註明出處

為什麼需要狀態容器

TensorFlow 2.0相對於TensorFlow 1.x的一個巨大的變化是默認打開動態圖(查看此文檔瞭解更詳細的變更),這個變化導致的一個直接結果是tf.Variable(及其子類和其他的所有的狀態類型如MutableHashTable)的生命週期將與其對應的python對象綁定,簡單解釋一下就是:

# 僅僅只能在1.x版本中運行!!!
import tensorflow as tf

v = tf.Variable(2.)
v = "foo"
default_graph = tf.get_default_graph()
vs = default_graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vs)

一開始創建的tf.Variable所綁定的對象v被另外地賦值,理論上根據python的垃圾回收機制,這個tf.Variable應該被解釋器回收掉(因為已經沒有引用指向這個tf.Variable了),然而列印出的vs並不是空,這是因為1.x版本中創建的變數將依附在運算圖tf.Graph中。2.0之後呢?

import tensorflow as tf

v = tf.Variable(2.)
v = "foo"

因為沒有了默認的運算圖,執行完v = "foo"之後tf.Variable這個對象由於最後一個指向他的引用消失而被解釋器回收。

這個變化有什麼樣的影響呢,看一個線性回歸的例子就知道了

import tensorflow as tf

def linear_regression(x):
w = tf.Variable(tf.random.normal(3, 4))
b = tf.Variable(tf.random.normal((4,)))
return tf.matmul(x, w) + b

這是一段典型的1.x版本的代碼(這種函數配合tf.variable_scope使用並做變數重用(variable resue)甚至成為了面試中關於TensorFlow的一個經典問題)。因為運算圖的存在,函數執行完之後創建的變數wb就算不再有引用指向他們,還是會被儲存在tf.Graph中,因此不用擔心變數被解釋器回收的事情。然而到了2.0之後,這個函數等於廢了,因為運行完了之後wb已經不存在了…...變數都不在了,對變數求導、更新和保存就別想了吧。

現在的問題變成了怎麼做才能保證變數不被回收掉:只要保存至少一個指向他們的引用即可,簡單修改一下這個linear_regression函數變成類型就能做到:

import tensorflow as tf

class LinearRegression(object):
def __init__(self):
self.w = tf.Variable(tf.random.normal(3, 4))
self.b = tf.Variable(tf.random.normal((4,)))

def __call__(self, x):
return tf.matmul(x, self.w) + self.b

linear_regression = LinearRegression()

現在,就算調用了linear_regression(這個對象可以調用,因為__call__被顯式地定義了),變數也不再會被回收掉,因為總存在指向他們的引用。

對比一下就能夠發現,這個LinearRegression類型起了一個容器的作用,他用來儲存與自己的運算相關的tf.Variable(當然也要進行運算,在__call__中),防止他們被解釋器回收掉。

這個LinearRegression類型是我們自己定義的,繼承自object,而在TensorFlow 2.0中,為了能夠更好的收集、儲存和使用tf.Variable(或者更廣義一點,所有的stateful的組件,還比如tf.lookup.StaticHashTable等),新版本定義了一系列的複雜度與功能各不一樣的類型,因為這些類型的主要的作用是充當stateful的組件的容器,因為官方稱之為狀態容器(stateful container)

狀態容器類型

以上是新版本中一定需要狀態容器的原因(如果沒有,變數會被回收無法儲存和復用),現在來介紹一下TensorFlow 2.0中狀態容器的類型。他們整個繼承結構圖,見前面的樹狀圖,以下對這個樹狀繼承圖中重要的類型做一些解釋和說明。

Trackable

Trackable(引用方式為from tensorflow.python.training.tracking.base import Trackable)是新版本中所有的狀態容器的基類,繼承自objetc,其定義了self._track_trackable(self, name: Str, value: Trackable) -> None的方法,來使得value成為self名為name的依賴最終所有的Trackable的實例之間的依賴關係構成一個有向無環圖(Directed Acyclic Graph, 縮寫為DAG)。以下是一個通過調用_track_trackable來構造依賴關係DAG的例子:

import tensorflow as tf
from tensorflow.python.training.tracking.base import Trackable
from tensorflow.python.ops.lookup_ops import MutableHashTable
from tensorflow.python.training.tracking.graph_view import ObjectGraphView
import weakref

t1 = Trackable()
t2 = Trackable()
# 這裡暫時只要知道`tf.Variable`和`MutableHashTable`也都繼承自`Trackable`就行
v = tf.Variable(1.)
hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
t1._track_trackable(t2, "child_trackable")
t1._track_trackable(v, "variable1")
t2._track_trackable(v, "variable2")
t2._track_trackable(hash_table, "dict")

# 使用`ObjectGraphView`這個類型可以獲得這個DAG的結構
object_graph_view = ObjectGraphView(weakref.ref(t1))
# saveables儲存了這個DAG中可以被保存的狀態類型,這裡有兩個,一個是v,另一個是hash_table,saveables會在
# tf.train.Checkpoint中被使用
saveables, dag, _ = object_graph_view.serialize_object_graph()
print(dag)

根據代碼可以看出構建的DAG大概如下圖所示:

這從列印得到的dag可以得到印證:

nodes {
children {
node_id: 1
local_name: "child_trackable"
}
children {
node_id: 2
local_name: "variable1"
}
}
nodes {
children {
node_id: 2
local_name: "variable2"
}
children {
node_id: 3
local_name: "dict"
}
}
nodes {
attributes {
name: "VARIABLE_VALUE"
full_name: "Variable"
checkpoint_key: "variable1/.ATTRIBUTES/VARIABLE_VALUE"
}
}
nodes {
attributes {
name: "table"
checkpoint_key: "child_trackable/dict/.ATTRIBUTES/table"
}
}

一共有四個節點,很容易可以看出這DAG中的四個節點分別代表了t1t2vhash_table

Trackable另一個重要的方法self._gather_saveables_for_checkpoint(self),會收集可以被儲存的對象,後文會在基於對象的儲存tf.train.Checkpoint中提到這個方法。

AutoTrackable

引用方式為from tensorflow.python.training.tracking.tracking import AutoTrackable。對於Trackable,每次都調用_track_trackable會顯得非常的繁瑣,能不能a.b = c的時候直接隱式地執行一下a._track_trackable(b, c)呢,AutoTrackable這個類型就做到了這一點。他通過覆蓋(override) __setattr____delattr__來達到自動收集(或者刪除)依賴的目的(這也是叫AutoTrackable的原因)。將上面一段關於Trackable的代碼改寫成等價的AutoTrackable的形式,會清爽的多。

import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.ops.lookup_ops import MutableHashTable
from tensorflow.python.training.tracking.graph_view import ObjectGraphView
import weakref

t1 = AutoTrackable()
t2 = AutoTrackable()
v = tf.Variable(1.)
hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
t1.child_trackable = t2
t1.variable1 = v
t2.variable2 = v
t2.dict = hash_table

只會成為他人依賴的類型

對比一下AutoTrackableTrackable的區別就能發現,x.y = z中,x肯定是AutoTrackable的,但是z不一定,如果z只會成為別人的依賴,而不依賴於別人(也就是說z是整個DAG中的葉節點),那z大可不必繼承AutoTrackable因為也不會有人執行z.xxx = xxxxx。在新版本中有很多這種只會成為他人依賴的類型,比如剛剛看到的tf.VariableMutableHashTable等等,這些類型通常是可以被保存的對象,因此會覆蓋_gather_saveables_for_checkpointTrackable一節代碼中的saveables變數,正是因為tf.VariableMutableHashTable覆蓋了這個方法,才使得自己出現在saveables中(並最終被tf.train.Checkpoint保存,具體看tf.train.Checkpoint一節)。

ListWrapper和DictWrapper

假如我想attach一個Trackable組成的list或者dictAutoTrackable,會不會因為list或者dict不是Trackable而導致DAG構建的有問題?實際上並不會,為了應對這個問題,新版本特意設計了兩個類型(引用方式為from tensorflow.python.training.tracking.data_structures import ListWrapper, DictWrapper)。ListWrapper多重繼承自Trackablecollections.Sequence_DictWrapper多重繼承自Trackablecollections.Mapping。一旦檢測到被attach的對象是一個Trackablelist或者dict,則將這個list或者dict轉變為ListWrapperDictWrapper__setattr__被覆蓋所具有的檢測功能),以下為一個例子:

from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.training.tracking.base import Trackable

auto_trackable = AutoTrackable()
auto_trackable.trackable_list = [Trackable() for _ in range(2)]
auto_trackable.trackable_dict = {str(i): Trackable() for i in range(2)}
print(type(auto_trackable.trackable_list))

pytorch熟悉的用戶來講,ListWrapper(功能上)等價於torch.nn.ModuleListDictWrapper(功能上)等價於torch.nn.ModuleDict

tf.train.Checkpoint

藉助於TrackableAutoTrackable這兩個類型就足夠完成基於對象的保存(objec-based saving)了,以下為一個使用AutoTrackable做基於對象的儲存的例子:

import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.ops.lookup_ops import MutableHashTable

t1 = AutoTrackable()
t2 = AutoTrackable()
v = tf.Variable(1.)
hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
hash_table.insert(keys=[1, 2, 3], values=[2, 4, 6])
t1.child_trackable = t2
t1.variable1 = v
t2.variable2 = v
t2.dict = hash_table

checkpoint = tf.train.Checkpoint(root=t1)
checkpoint.save("checkpoint/save")

執行完畢,會在checkpoint目錄下生成文件save-1(多出來的一表示保存次數)。接下來從文件中載入:

import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.ops.lookup_ops import MutableHashTable

trackable1 = AutoTrackable()
trackable2 = AutoTrackable()
# 原始值從1.改成了2.
variable = tf.Variable(2.)
dictionary = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
# 字典的value從[2, 4, 6]變成了[10, 20, 30]
dictionary.insert(keys=[1, 2, 3, 4], values=[10, 20, 30, 40])
trackable1.child_trackable = trackable2
trackable1.variable1 = variable
trackable2.variable2 = variable
trackable2.dict = dictionary
print("before restore")
print(trackable1.variable1.numpy())
print(trackable1.child_trackable.dict.lookup(tf.constant([1, 4])).numpy())
checkpoint = tf.train.Checkpoint(root=trackable1)
checkpoint.restore("checkpoint/save-1")
print("after restore")
print(trackable1.variable1.numpy())
print(trackable1.child_trackable.dict.lookup(tf.constant([1, 4])).numpy())

輸出結果為

before restore
2.0
[10 30]
after restore
1.0
[2 6]

應該有人注意到了我重命名了很多python變數,比如t1變成了trackable1hash_table變成了dictionary,但是最終還是能夠成功地恢復變數,因為變數名一點都不重要,重要的是儲存時的鍵的名字(比如child_trackable)。

這個類型是如何做到儲存的呢?一個核心的類型是剛剛已經出現過了的ObjectGraphView(引用方式為from tensorflow.python.training.tracking.graph_view import ObjectGraphView),這個類型接收DAG的根節點之後開始遍歷整個DAG並通過調用節點的_gather_saveables_for_checkpoint來收集可以被保存的對象(如變數值)以及他們之間的依賴關係並儲存。

tf.train.Checkpoint還具有創建時載入的機制(restore on creation),來看一個例子:

import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable

class RestoreOnCreation(AutoTrackable):
def __init__(self):
super(RestoreOnCreation, self).__init__()
self.v = None

def __call__(self, init_value):
if self.v is None:
self.v = tf.Variable(init_value)
return self.v

restore_on_create = RestoreOnCreation()
# 在這個調用之後`restore_on_create.v`才會被創建
restore_on_create(1.)
tf.train.Checkpoint(root=restore_on_create).save("save/restore_on_create")

# 創建新的`RestoreOnCreation`對象`another_restore_on_create`
another_restore_on_create = RestoreOnCreation()
# 這裡的`another_restore_on_create`的`.v`屬性還是個`None`的時候就執行了載入
tf.train.Checkpoint(root=another_restore_on_create).restore("save/restore_on_create-1")
# !!!傳進去的初始值雖然是2.,但是因為restore on creation的機制,還是會被修改成1,因此列印結果為1.
print(another_restore_on_create(2.))

restore on creation的機制會在tf.keras.layers.Layer的惰性構建中使用(在tf.keras.layers.Layer一節會提到)。

作為新老兩代變數儲存的核心類型,這裡不得不提的是tf.train.Checkpointtf.train.Saver的區別:

  • tf.train.Saver:

實際上,2.0以後就沒有tf.train.Saver了(執行tf.train.Saver會報AttributeError),在1.x靜態運算圖時代,tf.train.Saver會構建一個變數名到變數的字典(隱式地從運算圖中獲得,或者顯式地由用戶傳入這個字典,參考其構造函數)以此來存儲(tf.train.Saver.save)變數或者載入(tf.train.Saver.restore)變數,載入變數時如果變數名變掉了則載入失敗,因此這是一種基於變數名的儲存方式(官方稱之為name-based save/restore)。

  • tf.train.Checkpoint

在保存時(tf.train.Checkpoint.save)會記錄DAG中的依賴關係(這個依賴關係圖會被序列化成字元串儲存在了checkpoint文件中,對應的鍵為_CHECKPOINTABLE_OBJECT_GRAPH),在載入階段(tf.train.Checkpoint.restore)會讀取這個字元串並解析得到DAG依賴關係圖,並用DAG依賴關係圖中的節點匹配Trackable對象,匹配完成後載入變數值,因此這是一種基於對象的儲存方式(官方稱之為object-based save/restore)。

tf.Module

tf.Module是最底層的公開(public api)的狀態容器(TrackableAutoTrackable都不是public的),關於這個類型可以參考官方的文檔,其核心在於實現了兩個功能:

  • 收集所有的變數,儲存在tf.Module.variables(或者僅僅是可以訓練的變數tf.Module.trainable_variables
  • 收集所有的子tf.Module,儲存在tf.Module.submodules

這種收集某個類型的實例的具體的做法是tf.Module的所有屬性進行遍歷(注意並沒有調用_gather_saveables_for_checkpoint),一旦發現是某個類型的實例,則將其收集。因為其收集變數的功能,可以配合tf.GradientTape(這個類型是新版本中計算導數的核心,tf.GradientTape.gradient與1.x版本中的tf.train.Optimizer.compute_gradients的作用一樣)來進行導數的運算和變數的更新了,以下為使用tf.Module做線性回歸的例子:

import tensorflow as tf

class Model(tf.Module):
def __init__(self):
super(Model, self).__init__()
self.v = tf.Variable([[20.],
[-10.]])

def forward(self, x):
return tf.matmul(x, self.v)

@staticmethod
def mse(y_pred, y_true):
return tf.reduce_mean(tf.square(y_pred - y_true))

x = tf.random.normal((1024, 2))
y = tf.matmul(x, tf.constant([[1.],
[-1.]])) + tf.random.normal((1024,))

model = Model()
optimizer = tf.keras.optimizers.Adam(1e-1)
iterations = 1024

for i in range(iterations):
with tf.GradientTape() as tape:
# 此處的model.trainable_variables當然可以換成人工收集的變數[model.v]的形式,但是當變數數非常大的時候,
# model.trainable_variables帶來的便利是巨大的
grads = tape.gradient(Model.mse(model.forward(x), y), model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

print(model.v.numpy())

# 由於model是tf.Module的子類的實例,因此也是Trackable的實例,從而就可以用`tf.train.Checkpoint`進行變數的儲存和載入,下面加以展示
checkpoint = tf.train.Checkpoint(model=model)
# 務必保證`checkpoint/linear_regression`目錄存在
checkpoint.save("checkpoint/linear_regression/trained-model")

# 新建一個模型,並從訓練好的checkpoint載入變數值
another_model = Model()
print("another model before restore")
print(another_model.v)
another_checkpoint = tf.train.Checkpoint(model=another_model)
another_checkpoint.restore("checkpoint/linear_regression/trained-model-1")
print("another model after restore")
print(another_model.v)

正如在注釋中提到的,非常複雜的網路的變數是非常多的,在計算他們的導數(tf.GradientTape.gradient)中需要收集所有變數時,tf.Module.variables屬性一次性將他們收集好會帶來非常大的便利,這也是tf.Module適合用來建模的原因之一。

官方沒有給tf.Module設定任何的計算入口,因此在這裡我就定義了一個forward,實際上也可以是__call__,反正隨便是什麼名字的方法都可以。

熟悉pytorch的應該都能理解tf.Module為啥叫Module,以及我為啥給這個模型定義的方法名叫forward。實際上torch.nn.Module.parameters()功能上等價於tf.Module.variablestorch.nn.Module.children()功能上等價於tf.Module.submodulestorch.save(torch.nn.Module.state_dict(), xxx)功能上等價於tf.train.Checkpoint(whatever=tf.Module).save(xxx)。但是不能認為這兩個類型的工作方式是完全一樣的torch.nn.Module.state_dict會從torch.nn.Module.parameters()中來(當然還包含torch.nn.Module.buffers()),然而tf.train.Checkpoint收集的saveables並不是從tf.Module.variables來(而是從每一個葉節點的_gather_saveables_for_checkpoint而來),歸根到底,這還是因為TensorFlow中可以被保存的類型不止tf.Variable一個(還比如MutableHashTable,這種可以被保存的組件並不能算作是tf.Variable,因此也不該出現在tf.Module.variables中)。

tf.keras.layers.Layer

官方文檔在此。終於進入到keras的地盤了, Layerkeras中最底層的類型,keras中有大量已經被定義好的tf.keras.layers.Layer的(直接或間接的)子類可以用來快速的構建機器學習模型,如全連接層tf.keras.layers.Dense,卷積層tf.keras.layers.Conv2D和遞歸層tf.keras.layers.RNN等等。相比於輕量級的tf.Moduletf.keras.layers.Layer顯式地定義了用於建模的各種特性和功能,比如:

惰性構建的機制

tf.keras.layers.Layer在調用過構造函數創建實例之後一般是不包含變數的,只有在第一次遇到輸入時,才會根據輸入張量的維度,調用build(self, input_shape)方法,構造變數,具體的看一個例子:

import tensorflow as tf

dense = tf.keras.layers.Dense(2)
print(dense.variables)
x = tf.random.normal((4, 3))
y = dense(x)
print(dense.variables)

初次列印,dense包含的變數是個空list,而作用在一個形狀為[4, 3]的矩陣上後,dense包含的變數構建完畢,並且變數的尺寸恰好可以跟[4, 3]進行運算。在tf.keras.layers.Dense.build實現中可以直接發現這種邏輯(截取部分代碼段):

def build(self, input_shape):
# 取出輸入的特徵維度last_dim
last_dim = tensor_shape.dimension_value(input_shape[-1])
# 根據輸入的特徵維度構建kernel和bias
self.kernel = self.add_weight(
kernel,
# kernel的形狀為[last_dim, self.units]
shape=[last_dim, self.units],
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
dtype=self.dtype,
trainable=True)
if self.use_bias:
self.bias = self.add_weight(
bias,
shape=[self.units,],
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
dtype=self.dtype,
trainable=True)
else:
self.bias = None
self.built = True

這種惰性構建的機制有什麼好處麼?看一個pytorch的VGG的實現就知道了:

# 截取自pytorch官方的VGG實現
from torch import nn

class VGG(nn.Module):

def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
# 512 * 7 * 7是怎麼來的?
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
if init_weights:
self._initialize_weights()

def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x

nn.Linear(512 * 7 * 7, 4096)這裡的512 * 7 * 7顯然是根據輸入的圖像的尺寸以及經過的各種卷積和池化之後算出來的輸出的尺寸,用戶需要對這些運算非常熟悉才能算出這個值。如果使用惰性運算的話,則這個層在遇見輸入張量的時候才會構建與之相匹配的變數,減小用戶顯式計算這個值的負擔。具體一點,tf.keras.layers.Dense(4096)即可替代nn.Linear(512 * 7 * 7, 4096),當tf.keras.layers.Dense遇到這個張量並發現輸入的特徵維度為512 * 7 * 7時會在build方法中創建[512 * 7 * 7, 4096] 的權重用於運算。因此,以上pytorch實現中self.classifier可以被替換成tf.keras惰性構建的實現:

self.classifier = tf.keras.Sequential(
[tf.keras.layers.Dense(4096, activation="relu"),
tf.keras.layers.Dropout(),
tf.keras.layers.Dense(4096, activation="relu"),
tf.keras.layers.Dropout(),
tf.keras.layers.Dense(num_classes)]
)

惰性構建機制的一個可能引起問題的地方在於權重的載入,如果變數還沒被創建就想載入怎麼辦?看一個例子就知道了:

import tensorflow as tf

dense = tf.keras.layers.Dense(2)
# 注意看這裡,假如我曾經訓練好過這個dense,並且權重儲存在`my-checkpoint`中亦或者是我從別人那兒拿了一個checkpoint文件。則這裡的dense剛剛創建還沒有經過任何調用,因此dense.variables此時是個空,但是依然可以進行restore
tf.train.Checkpoint(dense=dense).restore("my-checkpoint")
print(dense.variables)
x = tf.random.normal((4, 3))
# 在這裡被調用,變數被創建時,會立即賦予checkpoint中的權重,之後才會作用在x上
y = dense(x)
print(dense.variables)

從上可以看出,這個因為tf.keras.layers.Layer的lazy build而導致的權重載入的問題被tf.train.Checkpoint的restore-on-create的機制給解決了。

統一的調用介面

tf.keras.layers.Layer__call__被覆蓋過,用來對輸入張量進行運算,__call__會調用call,因此用戶需要在call方法中,定義通過輸入張量和該層所具有的變數進行實際的運算的過程。以下代碼摘取自官方tf.keras.layers.Dense.call的實現(截取了部分代碼段):

def call(self, inputs):
inputs = ops.convert_to_tensor(inputs)
rank = common_shapes.rank(inputs)
if rank > 2:
outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
else:
outputs = gen_math_ops.mat_mul(inputs, self.kernel)
if self.use_bias:
outputs = nn.bias_add(outputs, self.bias)
if self.activation is not None:
return self.activation(outputs) # pylint: disable=not-callable
return outputs

即根據輸入的inputs,判斷是否秩是否大於2,如果是則調用張量乘tensordot,如果不是調用矩陣乘mat_mul,最後處理biasactivation

顯然,藉助tf.Moduletf.keras.layers.Layer(及其子類),已經足夠實現非常複雜的模型了。

tf.keras.Model

官方文檔在此,keras自己的文檔也可以做參考。tf.keras.Model(間接)繼承自tf.keras.layers.Layer,因此也具有惰性構建的機制,也可以被調用。如果說tf.keras.layers.Layer構成了一個個小的組件的話,tf.keras.Model就是要組合這些Layer構建模型的類型。這個類型太過複雜,值得說的東西也太多,其大體的特點如下:

兩種形式的構造函數

看文檔可以發現,tf.keras.Model有兩種形式的構造函數,以下是第一種(官方稱之為functional api),從keras誕生至今一直存在的一種構造函數調用方式

  • functional api

這種方式首先通過tf.keras.layers.Input來生成輸入,經過各種層tf.keras.layers.Layer及其子類的運算之後得到輸出,tf.keras.Model接受輸入輸出而生成整個網路拓撲結構。開門見山地講,我認為這個形式的構造函數不適合使用,以下為理由:

  1. 無法調試,丟失了動態圖執行的本意

import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
outputs = inputs + 1 # 在此處下斷點,看不見張量的內容
model = tf.keras.Model(inputs=inputs, outputs=outputs)

y = model(tf.random.normal((6, 3)))

這個方式的巨大的缺點就是調試起來非常困難(這不就是2.0想要改正的 1.x的毛病麼?),假如運算完發現y裡面有nan,想檢查一下中間的運算哪裡出問題了,於是在outputs = inputs + 1這一行下了一個斷點並調試,然而,看到的是

In[2]: outputs
Out[2]: <tf.Tensor add:0 shape=(None, 3) dtype=float32>

沒錯,是符號張量,沒有具體的數值。沒張量數值怎麼知道nan哪來的?新版本不是打開了動態圖執行了麼?

這一切都得從tf.keras.layers.Input開始說起,這個類型會創建一個本來在動態圖模式下並不能存在的Placeholderop(沒錯,新版本已經沒有tf.placeholder這個東西了),為了能創建這個op,只好隱式地創建一個運算圖tf.Graph(用來容納這個op,可以通過inputs.graph來獲得這個tf.Graph),於是,一切對這個inputs的運算,又回到了這個tf.Graph中。

2. 與標準序列化函數tf.saved_model.save不兼容

姑且忽略這個看不見張量內容的事,忍一忍就算了(就像大家對1.x的態度),然而模型還不能被序列化真的是讓人非常難以接受,看如下的例子:

# 僅在TF 2.0.0-beta1版本中測試會報錯,其餘更新版本不保證
import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
outputs = inputs + 1
model = tf.keras.Model(inputs=inputs, outputs=outputs)

y = model(tf.random.normal((6, 3)))
# 報TypeError
# 關於`tf.saved_model.save`目前只需要知道他接受一個`Trackable`並儲存成`SavedModel`即可。
tf.saved_model.save(model, "save/model")

當前主流的各種tensorflow模型的部署或者跨語言的使用都是基於SavedModel的,看如下一張圖就知道SavedModel的地位了:

tensorflow serving部署,比如使用java預測,如果不能儲存成SavedModel意味著以上這些功能全都不能使用,意味著tensorflow最強的部署的能力都廢了。(後續的文章中會有一文專門說到新版本下的模型序列化)

這個例子中inputs + 1是造成不能序列化的罪魁禍首(因為隱式地調用了tf.add),經過測試,tf.keras.layers.Layer的子類的調用才能使得tf.saved_model.save調用成功。其實這個問題並不難修復,但是從中可以看出的是,很多組件(比如tf.saved_model.savetf.print等)在設計的時候並沒有考慮到functional api。

如果將以上的代碼修改成eager模式的,也就是第二種方式的構造函數,將會有更多的便利。

  • subclassed

import tensorflow as tf

class MyModel(tf.keras.Model):

def __init__(self):
super(MyModel, self).__init__()

# 目前只需要知道`tf.function`會將動態圖轉變為靜態圖
@tf.function
def call(self, inputs):
# 如果發現這裡有nan,注釋掉`tf.function`用動態圖加斷點調試非常便利
y = inputs + 1
# 即使在靜態圖中,依然可以用tf.print列印出張量的內容
tf.print(y)
return y

model = MyModel()
# 序列化成`SavedModel`也不會有任何問題
tf.saved_model.save(model, "save/model")

如果你在此時發現ynan,在y = inputs + 1處打斷點,就能看到運算過程中張量的具體值了。

同時藉助tf.function修飾call方法可以隨時在動靜態圖之間切換,使用tf.print,即使在靜態圖下依然能夠隨時看到中間變數的內容(後續的文章中會有一文專門說到新版本下的完全動態圖化之後靜態圖轉換器tf.function的使用),也可以傳入額外的python值(如True或者False等等)作為call的參數,這些額外的功能functional api全都沒有。

顯然subclassed方式纔是更適合TF 2.0的。

其實很多人對新版本完全keras化有很大的誤解,比如這個知乎問題直接問出了:

TF變成keras,早知道就直接用keras了,用TF幹嘛?

實際上新版本的tf.keras藉助於動態圖執行和subclassed的書寫代碼的方式,能夠非常好地平衡可調試性和可部署性(說白了就是動靜態圖之間的權衡),然而很多keras的用戶還在堅持keras官網推薦的functional api這種無法調試的方式來寫,丟失了大量tf.keras所擁有的功能如直接序列化成SavedModel用於部署等等。從這個角度上來講,TF並沒有變成keras,只是改造了keras的類作為自己的狀態容器而已

compile和fit的一站式訓練

github上有各種使用keras.Model構建模型並使用compilefit進行訓練的例子(如transformer,han)(google搜索keras + 模型名稱保證能找到實現),因此這一小節就不再贅述使用方法,只說為什麼。

很多時候訓練需要額外的寫很多輔助的代碼,比如實現一個process bar看看每個epoch的訓練進度、實現一個early stop或者實現一個當驗證集損失變小時才存checkpoint的機制,這些實現都需要時間,當然如果代碼積累好了直接拿來用,不存在這種問題,如果沒有任何積累,compile + fit的機制可以幫你快速地做掉這些東西。keras經過重構之後,使用subclassed方式構建的模型進行compile + fit的方式跟functional api的方式幾乎沒有任何區別,並且可以配合tf.data.Dataset使用。官方有一些使用subclassed + compile + fit + tf.data.Dataset的例子,有興趣的可以看下,不再贅述。

經過測試,如果使用compile + fit的方式進行訓練,想進行調試查看運算過程中張量的內容的話,記得在構造函數中加上dynamic=True

當然,這種方式非常地死板,很多人不喜歡,而更想要定製化的訓練循環,這當然可以,藉助於tf.GradientTape就可以自己算導數,tf.optimizers.Optimizer做參數更新,總之也不一定非要用compile + fit

廢棄或被更改的各種方法與屬性

以前只有theano作為backend和functional api的時候,keras.Model定義了很多符合當時應用場景的屬性和方法,然而隨著tensorflow和動態圖的引入,這些屬性和方法大多已經(部分)失效,顯得這個類型非常的冗餘,當然也有一些仍然還有用的(比如compilefit)。以下為一些例子:

  1. 命名失當的.save_weights

很久以前keras.Model.save_weights是用來儲存權重的,然而到了新版本中,千萬不要被save_weights這個名字矇蔽雙眼,很多人看到這個方法第一感覺可能是首先獲得keras.Model.weights(或者keras.Model.variables)然後儲存變數就完事?以前是這樣,新版本可就不是了,看個例子

import tensorflow as tf
from tensorflow.python.ops.lookup_ops import MutableHashTable

class Model(tf.keras.Model):
def __init__(self, dynamic=True):
super(Model, self).__init__(dynamic=dynamic)
self.hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
self.hash_table.insert([1, 2, 3], [2, 4, 6])
self.v = tf.Variable(1)

def call(self, x, training=True, mask=None):
prob = self.hash_table.lookup(x) + self.v
return prob

model = Model()
y = model(tf.constant([1, 3]))

這個例子中除了self.vself.hash_table也會被儲存下去,因此save_weights應該叫save_saveables才更合適。

實際上,tf.keras.Model.save_weightstf.train.Checkpoint的作用方式是一樣的,.weightstf.keras.Model.save_weights保存的變數實際上不是相同途徑收集的,不要認為存的checkpoint跟.weights(或.variables)裡面的對象是一樣的

tf.keras.Model.save_weights還是用tf.train.Checkpoint存儲checkpoint?官方有一段解釋,大意是說

  • tf.keras.Model.save_weights存的checkpoint要用tf.keras.Model.load_weights載入
  • tf.train.Checkpoint.save存的checkpoint要用tf.train.Checkpoint.restore載入
  • 最好使用tf.train.Checkpoint而不是tf.keras.Model.save_weights(不知道原因,我猜應該是save_weights在新版本真的不是個能夠反應其功能的名字

2. 被重構了的.savetf.keras.models.load_model

熟悉keras的用戶應該都知道,tf.keras.Model.save(或者keras.Model.save)在很久以前設計的時候是儲存整個模型的(儲存格式為.h5),儲存的單元為tf.keras.layers.Layer,弊端在於,對於自定義的層,如果丟了代碼就載入不回來了而且.h5沒法直接像SavedModel那樣有很多的部署工具可以支持。在新版本中,.save將被修改成默認儲存成SavedModel格式,以後只要有SavedModel,就算丟失了所有模型的實現代碼,也能夠將模型載入回來(使用tf.keras.models.load_model)。截止到本文發稿時,這項工作仍在進行當中。

3. 手動.build

首先要明確一下,keras.Model.build是從keras.layers.Layer那兒(間接)繼承來的,在前面(tf.keras.layers.Layer的惰性構建一節)也提到了這個.build的作用是在第一次作用在輸入張量上的時候才構建與之匹配的權重減少人工計算維度的不必要性。而在只有functional api的年代,keras.Model.build方法是個冗餘的方法,因為functional api中的keras.Model中所有的層都有構建好的權重。

新版本中因為subclassed方式書寫的tf.keras.Modeltf.keras.layers.Layer一樣也是有惰性構建機制的,因此tf.keras.Model.build被修改成可以直接根據輸入形狀構建權重,看個例子:

import tensorflow as tf

class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
self.d = tf.keras.layers.Dense(3)

def call(self, x, training=True, mask=None):
return self.d(x)

m = Model()
# 列印為空
print(m.variables)
# 這裡只要給進去輸入張量的形狀即可,不需要具體的張量內容
m.build((None, 4))
# 列印,有兩個權重
print(m.variables)

tf.keras.Model.build可以幫助用戶只給進去輸入的形狀而構建權重,並在此之後對變數進行查看內容、賦值和修改的等等。

4. 不兼容subclassed方式的.summary

.summary無法展示出subclassed方式構建的模型中的二級及更高子tf.keras.layers.Layer,只能展示一級的,看個例子:

import tensorflow as tf

class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
self.d = tf.keras.layers.Dense(3)

def call(self, x, training=True, mask=None):
return self.d(x)

class SuperModel(tf.keras.Model):
def __init__(self):
super(SuperModel, self).__init__()
self.model = Model()

def call(self, x, training=True, mask=None):
return self.model(x)

m = SuperModel()
m.build((None, 4))
m.summary()

輸出結果為

Model: "super_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
model (Model) multiple 15
=================================================================
Total params: 15
Trainable params: 15
Non-trainable params: 0
_________________________________________________________________

SuperModel的第一子級tf.keras.layers.Layer是可以展示,也就是這個Model,但是其第二子級,也就是Model的第一子級d就無法展示了,更別提希望能夠展示出完整的DAG的結構圖了…...

5. 各種對subclassed方式不再有效的屬性

在只有functional api和靜態圖的時代,keras.Model還有很多有意義的屬性,比如keras.Model.inputs代表了輸入的符號張量keras.Model.outputs代表了輸出的符號張量。對於動態圖和subclassed方式構建的模型,這些屬性顯然是沒意義的,因為沒有任何一個張量代表了subclassed的方式構建的模型的輸入,類似的屬性還有inputinput_shapeinput_maskinput_namesinbound_nodesoutputoutput_maskoutput_namesoutput_shape等等,這些冗餘的屬性使得tf.keras.Model這個類型顯得非常笨重

tf.keras.Sequential

文檔在此,對於輸入的張量,tf.keras.Sequential挨個作用裡麪包含的tf.keras.layers.Layer,適合實現VGG16這樣一條路走到黑的模型。在一個大的模型中某些一條路走到黑的子模型用tf.keras.Sequential實現會極大地提升實現的效率。

因為繼承自tf.keras.Model,擁有tf.keras.Model的幾乎所有的功能,因此也可以使用compile + fit進行訓練。

狀態容器選擇

六種種狀態容器,選擇繼承哪個?

  • 僅在學習和深入研究狀態容器(或基於對象的儲存)時使用TrackableAutoTrackable建模中不要使用這兩個類型,因為沒有顯式地變數收集(.variable屬性)的機制,使用tf.GradientTape.gradient會費力。
  • tf.Module: 優點是輕量級,介面設計非常乾淨,可以配合tf.keras.layers.Layer使用,缺點是無法配合tf.keras.Sequential等等使用(雖然自己實現一個也不難),如果希望有自定義的訓練循環(custom training loop)tf.Module是一個好的選擇。
  • tf.keras.layers.Layer:已經有大量定義好的tf.keras.layers.Layer的子類如tf.keras.layers.Conv2D等等可以在建模中使用。這個類型適合實現一些中間層,比如Attention之類的,還可以配合tf.keras.Sequential使用,極少看見大的模型繼承自這個類型(沒看見不代表不可以)。
  • tf.keras.Model:曾經為靜態圖定製的類型,此後經過大量的改造以此適應新版本的動態圖和subclassed的寫法,代價是大量的屬性和方法(部分)都失效了,顯得非常笨重。但至少不同版本的兼容性還是在的,依然可以使用固定套路compile + fit跑模型,配合tf.keras.callbacks.Callback能夠極快速地實現各種訓練功能(如定期Checkpoint等)。雖然也可以自定義訓練循環,但是有一種殺雞用牛刀的感覺。
  • tf.keras.Sequential:適合實現一條路走到黑的(子)模型,比如VGG,可能不到三十行就能實現一個。

狀態容器是新版本完全擁抱動態圖執行的直接結果,如果沒有狀態容器,則局部創建的變數tf.Variable將會釋器回收掉。這難道就意味著1.x的靜態圖就完全拋棄了麼?顯然至少有兩個方面動態圖是完全比不過靜態圖的,一是因為靜態圖的各種優化(如constant folding、kernel fusion)帶來的運算速度的提升,二是靜態圖可以被序列化成SavedModel用於部署,用常見的靜態類型語言(C++Java)都可以在工業場景載入調用。新版本的靜態圖去哪了?後面的《TensorFlow 2.0學習筆記之靜態圖轉換器》揭曉。


推薦閱讀:
相關文章