轉載請註明出處
TL;DR
(0). 為什麼需要狀態容器:
因為TensorFlow 2.0以後默認打開動態圖運算,因此tf.Varibable(以及其它狀態類型)如果不放在一個容器中會被python解釋器回收掉,無法求導以及更新,這種容器稱之為狀態容器(stateful container)。TensorFlow 2.0中狀態容器的繼承結構圖為:
TensorFlow 2.0
tf.Varibable
Trackable | |-- tf.Variable | |-- MutableHashTable | |-- AutoTrackable | |-- ListWrapper/DictWrapper | |-- tf.train.Checkpoint | |-- tf.Module | |-- tf.keras.layers.Layer | |-- tf.keras.Model | |-- tf.keras.Sequential
(1). Trackable:
Trackable
Trackable(from tensorflow.python.training.tracking.base import Trackable)是所有的狀態容器的基類,其定義了一個self._track_trackable(self, name: Str, value: Trackable) -> None方法,來使得value成為self名為name的依賴。最終Trackable的實例之間的依賴關係構成一個有向無環圖(DAG),只要根節點不被解釋器回收,則所有的依賴都會存在。
from tensorflow.python.training.tracking.base import Trackable
self._track_trackable(self, name: Str, value: Trackable) -> None
value
self
name
(2). AutoTrackable:
AutoTrackable
AutoTrackable(from tensorflow.python.training.tracking.tracking import AutoTrackable),在__setattr__中運行_track_trackable,執行a.b = c時自動的將c設成a名為b的依賴(這也是叫AutoTrackable的原因),避免顯式調用_track_trackable帶來的無聊和繁瑣。
from tensorflow.python.training.tracking.tracking import AutoTrackable
__setattr__
_track_trackable
a.b = c
c
a
b
(3). Trackable vs AutoTrackable:
只能成為別人的依賴而不依賴於別人的的對象,繼承Trackable即可,而不用AutoTrackable,比如tf.Variable和MutableHashTable。這些類型通常是可以被保存的對象。
tf.Variable
MutableHashTable
(4). tf.train.Checkpoint:
tf.train.Checkpoint
基於對象的儲存(object-based save)的類型,地位等同於1.x的tf.train.Saver。繼承自AutoTrackable。該類型在儲存時會遍歷Trackable組成的DAG儲存和載入變數(或stateful的組件)。
1.x
tf.train.Saver
(5). tf.Module:
tf.Module
定位為一個輕量級的狀態容器,繼承自AutoTrackable,是所有的public api中的狀態容器中最底層的一個,最底層的能夠收集.variables和.submodules屬性的一個,因為可以收集變數,所以這個類型可以用來建模,配合tf.GradientTape使用。
.variables
.submodules
tf.GradientTape
(6). tf.keras.layers.Layer:
tf.keras.layers.Layer
Layer是keras中最底層的類型,繼承自tf.Module,相比於其父類,Layer開始有了建模的各種特性和功能,包括惰性構建機制(lazy build)和統一的調用介面。大量已經被定義好的tf.keras.layers.Layer的子類可以用來快速的構建模型,如全連接層tf.keras.layers.Dense,卷積層tf.keras.layers.Conv2D和遞歸層tf.keras.layers.RNN等等。
Layer
keras
tf.keras.layers.Dense
tf.keras.layers.Conv2D
tf.keras.layers.RNN
(7). tf.keras.Model:
tf.keras.Model
(間接)繼承自tf.keras.layers.Layer,是keras中建模的核心類型,有如下特性:
compile
fit
tf.keras.callbacks
(8). tf.keras.Sequential
tf.keras.Sequential
對於輸入的張量,挨個作用裡麪包含的tf.keras.layers.Layer,適合實現VGG16這樣一條路走到黑的(子)模型。
(9). 這麼多狀態容器,選擇繼承哪個?
Attention
tensorflow
SavedModel
2.0.beta1
pip install tensorflow==2.0.beta1
pytorch
TensorFlow 2.0相對於TensorFlow 1.x的一個巨大的變化是默認打開動態圖(查看此文檔瞭解更詳細的變更),這個變化導致的一個直接結果是tf.Variable(及其子類和其他的所有的狀態類型如MutableHashTable)的生命週期將與其對應的python對象綁定,簡單解釋一下就是:
TensorFlow 1.x
# 僅僅只能在1.x版本中運行!!! import tensorflow as tf
v = tf.Variable(2.) v = "foo" default_graph = tf.get_default_graph() vs = default_graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) print(vs)
一開始創建的tf.Variable所綁定的對象v被另外地賦值,理論上根據python的垃圾回收機制,這個tf.Variable應該被解釋器回收掉(因為已經沒有引用指向這個tf.Variable了),然而列印出的vs並不是空,這是因為1.x版本中創建的變數將依附在運算圖tf.Graph中。2.0之後呢?
v
vs
tf.Graph
2.0
import tensorflow as tf
v = tf.Variable(2.) v = "foo"
因為沒有了默認的運算圖,執行完v = "foo"之後tf.Variable這個對象由於最後一個指向他的引用消失而被解釋器回收。
v = "foo"
這個變化有什麼樣的影響呢,看一個線性回歸的例子就知道了
def linear_regression(x): w = tf.Variable(tf.random.normal(3, 4)) b = tf.Variable(tf.random.normal((4,))) return tf.matmul(x, w) + b
這是一段典型的1.x版本的代碼(這種函數配合tf.variable_scope使用並做變數重用(variable resue)甚至成為了面試中關於TensorFlow的一個經典問題)。因為運算圖的存在,函數執行完之後創建的變數w和b就算不再有引用指向他們,還是會被儲存在tf.Graph中,因此不用擔心變數被解釋器回收的事情。然而到了2.0之後,這個函數等於廢了,因為運行完了之後w和b已經不存在了…...變數都不在了,對變數求導、更新和保存就別想了吧。
tf.variable_scope
TensorFlow
w
現在的問題變成了怎麼做才能保證變數不被回收掉:只要保存至少一個指向他們的引用即可,簡單修改一下這個linear_regression函數變成類型就能做到:
linear_regression
class LinearRegression(object): def __init__(self): self.w = tf.Variable(tf.random.normal(3, 4)) self.b = tf.Variable(tf.random.normal((4,)))
def __call__(self, x): return tf.matmul(x, self.w) + self.b
linear_regression = LinearRegression()
現在,就算調用了linear_regression(這個對象可以調用,因為__call__被顯式地定義了),變數也不再會被回收掉,因為總存在指向他們的引用。
__call__
對比一下就能夠發現,這個LinearRegression類型起了一個容器的作用,他用來儲存與自己的運算相關的tf.Variable(當然也要進行運算,在__call__中),防止他們被解釋器回收掉。
LinearRegression
這個LinearRegression類型是我們自己定義的,繼承自object,而在TensorFlow 2.0中,為了能夠更好的收集、儲存和使用tf.Variable(或者更廣義一點,所有的stateful的組件,還比如tf.lookup.StaticHashTable等),新版本定義了一系列的複雜度與功能各不一樣的類型,因為這些類型的主要的作用是充當stateful的組件的容器,因為官方稱之為狀態容器(stateful container)
object
tf.lookup.StaticHashTable
以上是新版本中一定需要狀態容器的原因(如果沒有,變數會被回收無法儲存和復用),現在來介紹一下TensorFlow 2.0中狀態容器的類型。他們整個繼承結構圖,見前面的樹狀圖,以下對這個樹狀繼承圖中重要的類型做一些解釋和說明。
Trackable(引用方式為from tensorflow.python.training.tracking.base import Trackable)是新版本中所有的狀態容器的基類,繼承自objetc,其定義了self._track_trackable(self, name: Str, value: Trackable) -> None的方法,來使得value成為self名為name的依賴,最終所有的Trackable的實例之間的依賴關係構成一個有向無環圖(Directed Acyclic Graph, 縮寫為DAG)。以下是一個通過調用_track_trackable來構造依賴關係DAG的例子:
objetc
import tensorflow as tf from tensorflow.python.training.tracking.base import Trackable from tensorflow.python.ops.lookup_ops import MutableHashTable from tensorflow.python.training.tracking.graph_view import ObjectGraphView import weakref
t1 = Trackable() t2 = Trackable() # 這裡暫時只要知道`tf.Variable`和`MutableHashTable`也都繼承自`Trackable`就行 v = tf.Variable(1.) hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0) t1._track_trackable(t2, "child_trackable") t1._track_trackable(v, "variable1") t2._track_trackable(v, "variable2") t2._track_trackable(hash_table, "dict")
# 使用`ObjectGraphView`這個類型可以獲得這個DAG的結構 object_graph_view = ObjectGraphView(weakref.ref(t1)) # saveables儲存了這個DAG中可以被保存的狀態類型,這裡有兩個,一個是v,另一個是hash_table,saveables會在 # tf.train.Checkpoint中被使用 saveables, dag, _ = object_graph_view.serialize_object_graph() print(dag)
根據代碼可以看出構建的DAG大概如下圖所示:
這從列印得到的dag可以得到印證:
dag
nodes { children { node_id: 1 local_name: "child_trackable" } children { node_id: 2 local_name: "variable1" } } nodes { children { node_id: 2 local_name: "variable2" } children { node_id: 3 local_name: "dict" } } nodes { attributes { name: "VARIABLE_VALUE" full_name: "Variable" checkpoint_key: "variable1/.ATTRIBUTES/VARIABLE_VALUE" } } nodes { attributes { name: "table" checkpoint_key: "child_trackable/dict/.ATTRIBUTES/table" } }
一共有四個節點,很容易可以看出這DAG中的四個節點分別代表了t1、t2、v和hash_table。
t1
t2
hash_table
Trackable的另一個重要的方法self._gather_saveables_for_checkpoint(self),會收集可以被儲存的對象,後文會在基於對象的儲存tf.train.Checkpoint中提到這個方法。
self._gather_saveables_for_checkpoint(self)
引用方式為from tensorflow.python.training.tracking.tracking import AutoTrackable。對於Trackable,每次都調用_track_trackable會顯得非常的繁瑣,能不能a.b = c的時候直接隱式地執行一下a._track_trackable(b, c)呢,AutoTrackable這個類型就做到了這一點。他通過覆蓋(override) __setattr__和__delattr__來達到自動收集(或者刪除)依賴的目的(這也是叫AutoTrackable的原因)。將上面一段關於Trackable的代碼改寫成等價的AutoTrackable的形式,會清爽的多。
a._track_trackable(b, c)
__delattr__
import tensorflow as tf from tensorflow.python.training.tracking.tracking import AutoTrackable from tensorflow.python.ops.lookup_ops import MutableHashTable from tensorflow.python.training.tracking.graph_view import ObjectGraphView import weakref
t1 = AutoTrackable() t2 = AutoTrackable() v = tf.Variable(1.) hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0) t1.child_trackable = t2 t1.variable1 = v t2.variable2 = v t2.dict = hash_table
對比一下AutoTrackable和Trackable的區別就能發現,x.y = z中,x肯定是AutoTrackable的,但是z不一定,如果z只會成為別人的依賴,而不依賴於別人(也就是說z是整個DAG中的葉節點),那z大可不必繼承AutoTrackable因為也不會有人執行z.xxx = xxxxx。在新版本中有很多這種只會成為他人依賴的類型,比如剛剛看到的tf.Variable和MutableHashTable等等,這些類型通常是可以被保存的對象,因此會覆蓋_gather_saveables_for_checkpoint,Trackable一節代碼中的saveables變數,正是因為tf.Variable和MutableHashTable覆蓋了這個方法,才使得自己出現在saveables中(並最終被tf.train.Checkpoint保存,具體看tf.train.Checkpoint一節)。
x.y = z
x
z
z.xxx = xxxxx
_gather_saveables_for_checkpoint
saveables
假如我想attach一個Trackable組成的list或者dict給AutoTrackable,會不會因為list或者dict不是Trackable而導致DAG構建的有問題?實際上並不會,為了應對這個問題,新版本特意設計了兩個類型(引用方式為from tensorflow.python.training.tracking.data_structures import ListWrapper, DictWrapper)。ListWrapper多重繼承自Trackable和collections.Sequence,_DictWrapper多重繼承自Trackable和collections.Mapping。一旦檢測到被attach的對象是一個Trackable的list或者dict,則將這個list或者dict轉變為ListWrapper和DictWrapper(__setattr__被覆蓋所具有的檢測功能),以下為一個例子:
list
dict
from tensorflow.python.training.tracking.data_structures import ListWrapper, DictWrapper
ListWrapper
collections.Sequence
_DictWrapper
collections.Mapping
DictWrapper
from tensorflow.python.training.tracking.tracking import AutoTrackable from tensorflow.python.training.tracking.base import Trackable
auto_trackable = AutoTrackable() auto_trackable.trackable_list = [Trackable() for _ in range(2)] auto_trackable.trackable_dict = {str(i): Trackable() for i in range(2)} print(type(auto_trackable.trackable_list))
對pytorch熟悉的用戶來講,ListWrapper(功能上)等價於torch.nn.ModuleList,DictWrapper(功能上)等價於torch.nn.ModuleDict。
torch.nn.ModuleList
torch.nn.ModuleDict
藉助於Trackable和AutoTrackable這兩個類型就足夠完成基於對象的保存(objec-based saving)了,以下為一個使用AutoTrackable做基於對象的儲存的例子:
import tensorflow as tf from tensorflow.python.training.tracking.tracking import AutoTrackable from tensorflow.python.ops.lookup_ops import MutableHashTable
t1 = AutoTrackable() t2 = AutoTrackable() v = tf.Variable(1.) hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0) hash_table.insert(keys=[1, 2, 3], values=[2, 4, 6]) t1.child_trackable = t2 t1.variable1 = v t2.variable2 = v t2.dict = hash_table
checkpoint = tf.train.Checkpoint(root=t1) checkpoint.save("checkpoint/save")
執行完畢,會在checkpoint目錄下生成文件save-1(多出來的一表示保存次數)。接下來從文件中載入:
checkpoint
save-1
trackable1 = AutoTrackable() trackable2 = AutoTrackable() # 原始值從1.改成了2. variable = tf.Variable(2.) dictionary = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0) # 字典的value從[2, 4, 6]變成了[10, 20, 30] dictionary.insert(keys=[1, 2, 3, 4], values=[10, 20, 30, 40]) trackable1.child_trackable = trackable2 trackable1.variable1 = variable trackable2.variable2 = variable trackable2.dict = dictionary print("before restore") print(trackable1.variable1.numpy()) print(trackable1.child_trackable.dict.lookup(tf.constant([1, 4])).numpy()) checkpoint = tf.train.Checkpoint(root=trackable1) checkpoint.restore("checkpoint/save-1") print("after restore") print(trackable1.variable1.numpy()) print(trackable1.child_trackable.dict.lookup(tf.constant([1, 4])).numpy())
輸出結果為
before restore 2.0 [10 30] after restore 1.0 [2 6]
應該有人注意到了我重命名了很多python變數,比如t1變成了trackable1,hash_table變成了dictionary,但是最終還是能夠成功地恢復變數,因為變數名一點都不重要,重要的是儲存時的鍵的名字(比如child_trackable)。
trackable1
dictionary
child_trackable
這個類型是如何做到儲存的呢?一個核心的類型是剛剛已經出現過了的ObjectGraphView(引用方式為from tensorflow.python.training.tracking.graph_view import ObjectGraphView),這個類型接收DAG的根節點之後開始遍歷整個DAG並通過調用節點的_gather_saveables_for_checkpoint來收集可以被保存的對象(如變數值)以及他們之間的依賴關係並儲存。
ObjectGraphView
from tensorflow.python.training.tracking.graph_view import ObjectGraphView
tf.train.Checkpoint還具有創建時載入的機制(restore on creation),來看一個例子:
import tensorflow as tf from tensorflow.python.training.tracking.tracking import AutoTrackable
class RestoreOnCreation(AutoTrackable): def __init__(self): super(RestoreOnCreation, self).__init__() self.v = None
def __call__(self, init_value): if self.v is None: self.v = tf.Variable(init_value) return self.v
restore_on_create = RestoreOnCreation() # 在這個調用之後`restore_on_create.v`才會被創建 restore_on_create(1.) tf.train.Checkpoint(root=restore_on_create).save("save/restore_on_create")
# 創建新的`RestoreOnCreation`對象`another_restore_on_create` another_restore_on_create = RestoreOnCreation() # 這裡的`another_restore_on_create`的`.v`屬性還是個`None`的時候就執行了載入 tf.train.Checkpoint(root=another_restore_on_create).restore("save/restore_on_create-1") # !!!傳進去的初始值雖然是2.,但是因為restore on creation的機制,還是會被修改成1,因此列印結果為1. print(another_restore_on_create(2.))
restore on creation的機制會在tf.keras.layers.Layer的惰性構建中使用(在tf.keras.layers.Layer一節會提到)。
作為新老兩代變數儲存的核心類型,這裡不得不提的是tf.train.Checkpoint和tf.train.Saver的區別:
實際上,2.0以後就沒有tf.train.Saver了(執行tf.train.Saver會報AttributeError),在1.x靜態運算圖時代,tf.train.Saver會構建一個變數名到變數的字典(隱式地從運算圖中獲得,或者顯式地由用戶傳入這個字典,參考其構造函數)以此來存儲(tf.train.Saver.save)變數或者載入(tf.train.Saver.restore)變數,載入變數時如果變數名變掉了則載入失敗,因此這是一種基於變數名的儲存方式(官方稱之為name-based save/restore)。
AttributeError
tf.train.Saver.save
tf.train.Saver.restore
在保存時(tf.train.Checkpoint.save)會記錄DAG中的依賴關係(這個依賴關係圖會被序列化成字元串儲存在了checkpoint文件中,對應的鍵為_CHECKPOINTABLE_OBJECT_GRAPH),在載入階段(tf.train.Checkpoint.restore)會讀取這個字元串並解析得到DAG依賴關係圖,並用DAG依賴關係圖中的節點匹配Trackable對象,匹配完成後載入變數值,因此這是一種基於對象的儲存方式(官方稱之為object-based save/restore)。
tf.train.Checkpoint.save
_CHECKPOINTABLE_OBJECT_GRAPH
tf.train.Checkpoint.restore
tf.Module是最底層的公開(public api)的狀態容器(Trackable和AutoTrackable都不是public的),關於這個類型可以參考官方的文檔,其核心在於實現了兩個功能:
tf.Module.variables
tf.Module.trainable_variables
tf.Module.submodules
這種收集某個類型的實例的具體的做法是對tf.Module的所有屬性進行遍歷(注意並沒有調用_gather_saveables_for_checkpoint),一旦發現是某個類型的實例,則將其收集。因為其收集變數的功能,可以配合tf.GradientTape(這個類型是新版本中計算導數的核心,tf.GradientTape.gradient與1.x版本中的tf.train.Optimizer.compute_gradients的作用一樣)來進行導數的運算和變數的更新了,以下為使用tf.Module做線性回歸的例子:
tf.GradientTape.gradient
tf.train.Optimizer.compute_gradients
class Model(tf.Module): def __init__(self): super(Model, self).__init__() self.v = tf.Variable([[20.], [-10.]])
def forward(self, x): return tf.matmul(x, self.v)
@staticmethod def mse(y_pred, y_true): return tf.reduce_mean(tf.square(y_pred - y_true))
x = tf.random.normal((1024, 2)) y = tf.matmul(x, tf.constant([[1.], [-1.]])) + tf.random.normal((1024,))
model = Model() optimizer = tf.keras.optimizers.Adam(1e-1) iterations = 1024
for i in range(iterations): with tf.GradientTape() as tape: # 此處的model.trainable_variables當然可以換成人工收集的變數[model.v]的形式,但是當變數數非常大的時候, # model.trainable_variables帶來的便利是巨大的 grads = tape.gradient(Model.mse(model.forward(x), y), model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables))
print(model.v.numpy())
# 由於model是tf.Module的子類的實例,因此也是Trackable的實例,從而就可以用`tf.train.Checkpoint`進行變數的儲存和載入,下面加以展示 checkpoint = tf.train.Checkpoint(model=model) # 務必保證`checkpoint/linear_regression`目錄存在 checkpoint.save("checkpoint/linear_regression/trained-model")
# 新建一個模型,並從訓練好的checkpoint載入變數值 another_model = Model() print("another model before restore") print(another_model.v) another_checkpoint = tf.train.Checkpoint(model=another_model) another_checkpoint.restore("checkpoint/linear_regression/trained-model-1") print("another model after restore") print(another_model.v)
正如在注釋中提到的,非常複雜的網路的變數是非常多的,在計算他們的導數(tf.GradientTape.gradient)中需要收集所有變數時,tf.Module.variables屬性一次性將他們收集好會帶來非常大的便利,這也是tf.Module適合用來建模的原因之一。
官方沒有給tf.Module設定任何的計算入口,因此在這裡我就定義了一個forward,實際上也可以是__call__,反正隨便是什麼名字的方法都可以。
forward
熟悉pytorch的應該都能理解tf.Module為啥叫Module,以及我為啥給這個模型定義的方法名叫forward。實際上torch.nn.Module.parameters()功能上等價於tf.Module.variables,torch.nn.Module.children()功能上等價於tf.Module.submodules,torch.save(torch.nn.Module.state_dict(), xxx)功能上等價於tf.train.Checkpoint(whatever=tf.Module).save(xxx)。但是不能認為這兩個類型的工作方式是完全一樣的,torch.nn.Module.state_dict會從torch.nn.Module.parameters()中來(當然還包含torch.nn.Module.buffers()),然而tf.train.Checkpoint收集的saveables並不是從tf.Module.variables來(而是從每一個葉節點的_gather_saveables_for_checkpoint而來),歸根到底,這還是因為TensorFlow中可以被保存的類型不止tf.Variable一個(還比如MutableHashTable,這種可以被保存的組件並不能算作是tf.Variable,因此也不該出現在tf.Module.variables中)。
Module
torch.nn.Module.parameters()
torch.nn.Module.children()
torch.save(torch.nn.Module.state_dict(), xxx)
tf.train.Checkpoint(whatever=tf.Module).save(xxx)
torch.nn.Module.state_dict
torch.nn.Module.buffers()
官方文檔在此。終於進入到keras的地盤了, Layer是keras中最底層的類型,keras中有大量已經被定義好的tf.keras.layers.Layer的(直接或間接的)子類可以用來快速的構建機器學習模型,如全連接層tf.keras.layers.Dense,卷積層tf.keras.layers.Conv2D和遞歸層tf.keras.layers.RNN等等。相比於輕量級的tf.Module,tf.keras.layers.Layer顯式地定義了用於建模的各種特性和功能,比如:
tf.keras.layers.Layer在調用過構造函數創建實例之後一般是不包含變數的,只有在第一次遇到輸入時,才會根據輸入張量的維度,調用build(self, input_shape)方法,構造變數,具體的看一個例子:
build(self, input_shape)
dense = tf.keras.layers.Dense(2) print(dense.variables) x = tf.random.normal((4, 3)) y = dense(x) print(dense.variables)
初次列印,dense包含的變數是個空list,而作用在一個形狀為[4, 3]的矩陣上後,dense包含的變數構建完畢,並且變數的尺寸恰好可以跟[4, 3]進行運算。在tf.keras.layers.Dense.build實現中可以直接發現這種邏輯(截取部分代碼段):
dense
[4, 3]
tf.keras.layers.Dense.build
def build(self, input_shape): # 取出輸入的特徵維度last_dim last_dim = tensor_shape.dimension_value(input_shape[-1]) # 根據輸入的特徵維度構建kernel和bias self.kernel = self.add_weight( kernel, # kernel的形狀為[last_dim, self.units] shape=[last_dim, self.units], initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, dtype=self.dtype, trainable=True) if self.use_bias: self.bias = self.add_weight( bias, shape=[self.units,], initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, dtype=self.dtype, trainable=True) else: self.bias = None self.built = True
這種惰性構建的機制有什麼好處麼?看一個pytorch的VGG的實現就知道了:
# 截取自pytorch官方的VGG實現 from torch import nn
class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True): super(VGG, self).__init__() self.features = features self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) self.classifier = nn.Sequential( # 512 * 7 * 7是怎麼來的? nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(), nn.Linear(4096, num_classes), ) if init_weights: self._initialize_weights()
def forward(self, x): x = self.features(x) x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.classifier(x) return x
nn.Linear(512 * 7 * 7, 4096)這裡的512 * 7 * 7顯然是根據輸入的圖像的尺寸以及經過的各種卷積和池化之後算出來的輸出的尺寸,用戶需要對這些運算非常熟悉才能算出這個值。如果使用惰性運算的話,則這個層在遇見輸入張量的時候才會構建與之相匹配的變數,減小用戶顯式計算這個值的負擔。具體一點,tf.keras.layers.Dense(4096)即可替代nn.Linear(512 * 7 * 7, 4096),當tf.keras.layers.Dense遇到這個張量並發現輸入的特徵維度為512 * 7 * 7時會在build方法中創建[512 * 7 * 7, 4096] 的權重用於運算。因此,以上pytorch實現中self.classifier可以被替換成tf.keras惰性構建的實現:
nn.Linear(512 * 7 * 7, 4096)
512 * 7 * 7
tf.keras.layers.Dense(4096)
build
[512 * 7 * 7, 4096]
self.classifier
tf.keras
self.classifier = tf.keras.Sequential( [tf.keras.layers.Dense(4096, activation="relu"), tf.keras.layers.Dropout(), tf.keras.layers.Dense(4096, activation="relu"), tf.keras.layers.Dropout(), tf.keras.layers.Dense(num_classes)] )
惰性構建機制的一個可能引起問題的地方在於權重的載入,如果變數還沒被創建就想載入怎麼辦?看一個例子就知道了:
dense = tf.keras.layers.Dense(2) # 注意看這裡,假如我曾經訓練好過這個dense,並且權重儲存在`my-checkpoint`中亦或者是我從別人那兒拿了一個checkpoint文件。則這裡的dense剛剛創建還沒有經過任何調用,因此dense.variables此時是個空,但是依然可以進行restore tf.train.Checkpoint(dense=dense).restore("my-checkpoint") print(dense.variables) x = tf.random.normal((4, 3)) # 在這裡被調用,變數被創建時,會立即賦予checkpoint中的權重,之後才會作用在x上 y = dense(x) print(dense.variables)
從上可以看出,這個因為tf.keras.layers.Layer的lazy build而導致的權重載入的問題被tf.train.Checkpoint的restore-on-create的機制給解決了。
tf.keras.layers.Layer的__call__被覆蓋過,用來對輸入張量進行運算,__call__會調用call,因此用戶需要在call方法中,定義通過輸入張量和該層所具有的變數進行實際的運算的過程。以下代碼摘取自官方tf.keras.layers.Dense.call的實現(截取了部分代碼段):
call
tf.keras.layers.Dense.call
def call(self, inputs): inputs = ops.convert_to_tensor(inputs) rank = common_shapes.rank(inputs) if rank > 2: outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]]) else: outputs = gen_math_ops.mat_mul(inputs, self.kernel) if self.use_bias: outputs = nn.bias_add(outputs, self.bias) if self.activation is not None: return self.activation(outputs) # pylint: disable=not-callable return outputs
即根據輸入的inputs,判斷是否秩是否大於2,如果是則調用張量乘tensordot,如果不是調用矩陣乘mat_mul,最後處理bias和activation。
inputs
tensordot
mat_mul
bias
activation
顯然,藉助tf.Module和tf.keras.layers.Layer(及其子類),已經足夠實現非常複雜的模型了。
官方文檔在此,keras自己的文檔也可以做參考。tf.keras.Model(間接)繼承自tf.keras.layers.Layer,因此也具有惰性構建的機制,也可以被調用。如果說tf.keras.layers.Layer構成了一個個小的組件的話,tf.keras.Model就是要組合這些Layer構建模型的類型。這個類型太過複雜,值得說的東西也太多,其大體的特點如下:
看文檔可以發現,tf.keras.Model有兩種形式的構造函數,以下是第一種(官方稱之為functional api),從keras誕生至今一直存在的一種構造函數調用方式
這種方式首先通過tf.keras.layers.Input來生成輸入,經過各種層tf.keras.layers.Layer及其子類的運算之後得到輸出,tf.keras.Model接受輸入輸出而生成整個網路拓撲結構。開門見山地講,我認為這個形式的構造函數不適合使用,以下為理由:
tf.keras.layers.Input
import tensorflow as tf inputs = tf.keras.Input(shape=(3,)) outputs = inputs + 1 # 在此處下斷點,看不見張量的內容 model = tf.keras.Model(inputs=inputs, outputs=outputs)
y = model(tf.random.normal((6, 3)))
這個方式的巨大的缺點就是調試起來非常困難(這不就是2.0想要改正的 1.x的毛病麼?),假如運算完發現y裡面有nan,想檢查一下中間的運算哪裡出問題了,於是在outputs = inputs + 1這一行下了一個斷點並調試,然而,看到的是
y
nan
outputs = inputs + 1
In[2]: outputs Out[2]: <tf.Tensor add:0 shape=(None, 3) dtype=float32>
沒錯,是符號張量,沒有具體的數值。沒張量數值怎麼知道nan哪來的?新版本不是打開了動態圖執行了麼?
這一切都得從tf.keras.layers.Input開始說起,這個類型會創建一個本來在動態圖模式下並不能存在的Placeholderop(沒錯,新版本已經沒有tf.placeholder這個東西了),為了能創建這個op,只好隱式地創建一個運算圖tf.Graph(用來容納這個op,可以通過inputs.graph來獲得這個tf.Graph),於是,一切對這個inputs的運算,又回到了這個tf.Graph中。
Placeholder
tf.placeholder
inputs.graph
2. 與標準序列化函數tf.saved_model.save不兼容
tf.saved_model.save
姑且忽略這個看不見張量內容的事,忍一忍就算了(就像大家對1.x的態度),然而模型還不能被序列化真的是讓人非常難以接受,看如下的例子:
# 僅在TF 2.0.0-beta1版本中測試會報錯,其餘更新版本不保證 import tensorflow as tf inputs = tf.keras.Input(shape=(3,)) outputs = inputs + 1 model = tf.keras.Model(inputs=inputs, outputs=outputs)
y = model(tf.random.normal((6, 3))) # 報TypeError # 關於`tf.saved_model.save`目前只需要知道他接受一個`Trackable`並儲存成`SavedModel`即可。 tf.saved_model.save(model, "save/model")
當前主流的各種tensorflow模型的部署或者跨語言的使用都是基於SavedModel的,看如下一張圖就知道SavedModel的地位了:
tensorflow serving部署,比如使用java預測,如果不能儲存成SavedModel意味著以上這些功能全都不能使用,意味著tensorflow最強的部署的能力都廢了。(後續的文章中會有一文專門說到新版本下的模型序列化)
這個例子中inputs + 1是造成不能序列化的罪魁禍首(因為隱式地調用了tf.add),經過測試,tf.keras.layers.Layer的子類的調用才能使得tf.saved_model.save調用成功。其實這個問題並不難修復,但是從中可以看出的是,很多組件(比如tf.saved_model.save,tf.print等)在設計的時候並沒有考慮到functional api。
inputs + 1
tf.add
tf.print
如果將以上的代碼修改成eager模式的,也就是第二種方式的構造函數,將會有更多的便利。
class MyModel(tf.keras.Model):
def __init__(self): super(MyModel, self).__init__()
# 目前只需要知道`tf.function`會將動態圖轉變為靜態圖 @tf.function def call(self, inputs): # 如果發現這裡有nan,注釋掉`tf.function`用動態圖加斷點調試非常便利 y = inputs + 1 # 即使在靜態圖中,依然可以用tf.print列印出張量的內容 tf.print(y) return y
model = MyModel() # 序列化成`SavedModel`也不會有任何問題 tf.saved_model.save(model, "save/model")
如果你在此時發現y有nan,在y = inputs + 1處打斷點,就能看到運算過程中張量的具體值了。
y = inputs + 1
同時藉助tf.function修飾call方法可以隨時在動靜態圖之間切換,使用tf.print,即使在靜態圖下依然能夠隨時看到中間變數的內容(後續的文章中會有一文專門說到新版本下的完全動態圖化之後靜態圖轉換器tf.function的使用),也可以傳入額外的python值(如True或者False等等)作為call的參數,這些額外的功能functional api全都沒有。
tf.function
True
False
顯然subclassed方式纔是更適合TF 2.0的。
其實很多人對新版本完全keras化有很大的誤解,比如這個知乎問題直接問出了:
TF變成keras,早知道就直接用keras了,用TF幹嘛?
實際上新版本的tf.keras藉助於動態圖執行和subclassed的書寫代碼的方式,能夠非常好地平衡可調試性和可部署性(說白了就是動靜態圖之間的權衡),然而很多keras的用戶還在堅持keras官網推薦的functional api這種無法調試的方式來寫,丟失了大量tf.keras所擁有的功能如直接序列化成SavedModel用於部署等等。從這個角度上來講,TF並沒有變成keras,只是改造了keras的類作為自己的狀態容器而已。
github上有各種使用keras.Model構建模型並使用compile加fit進行訓練的例子(如transformer,han)(google搜索keras + 模型名稱保證能找到實現),因此這一小節就不再贅述使用方法,只說為什麼。
keras.Model
keras + 模型名稱
很多時候訓練需要額外的寫很多輔助的代碼,比如實現一個process bar看看每個epoch的訓練進度、實現一個early stop或者實現一個當驗證集損失變小時才存checkpoint的機制,這些實現都需要時間,當然如果代碼積累好了直接拿來用,不存在這種問題,如果沒有任何積累,compile + fit的機制可以幫你快速地做掉這些東西。keras經過重構之後,使用subclassed方式構建的模型進行compile + fit的方式跟functional api的方式幾乎沒有任何區別,並且可以配合tf.data.Dataset使用。官方有一些使用subclassed + compile + fit + tf.data.Dataset的例子,有興趣的可以看下,不再贅述。
tf.data.Dataset
經過測試,如果使用compile + fit的方式進行訓練,想進行調試查看運算過程中張量的內容的話,記得在構造函數中加上dynamic=True。
dynamic=True
當然,這種方式非常地死板,很多人不喜歡,而更想要定製化的訓練循環,這當然可以,藉助於tf.GradientTape就可以自己算導數,tf.optimizers.Optimizer做參數更新,總之也不一定非要用compile + fit
tf.optimizers.Optimizer
以前只有theano作為backend和functional api的時候,keras.Model定義了很多符合當時應用場景的屬性和方法,然而隨著tensorflow和動態圖的引入,這些屬性和方法大多已經(部分)失效,顯得這個類型非常的冗餘,當然也有一些仍然還有用的(比如compile和fit)。以下為一些例子:
theano
.save_weights
很久以前keras.Model.save_weights是用來儲存權重的,然而到了新版本中,千萬不要被save_weights這個名字矇蔽雙眼,很多人看到這個方法第一感覺可能是首先獲得keras.Model.weights(或者keras.Model.variables)然後儲存變數就完事?以前是這樣,新版本可就不是了,看個例子
keras.Model.save_weights
save_weights
keras.Model.weights
keras.Model.variables
import tensorflow as tf from tensorflow.python.ops.lookup_ops import MutableHashTable
class Model(tf.keras.Model): def __init__(self, dynamic=True): super(Model, self).__init__(dynamic=dynamic) self.hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0) self.hash_table.insert([1, 2, 3], [2, 4, 6]) self.v = tf.Variable(1)
def call(self, x, training=True, mask=None): prob = self.hash_table.lookup(x) + self.v return prob
model = Model() y = model(tf.constant([1, 3]))
這個例子中除了self.v,self.hash_table也會被儲存下去,因此save_weights應該叫save_saveables才更合適。
self.v
self.hash_table
save_saveables
實際上,tf.keras.Model.save_weights與tf.train.Checkpoint的作用方式是一樣的,.weights和tf.keras.Model.save_weights保存的變數實際上不是相同途徑收集的,不要認為存的checkpoint跟.weights(或.variables)裡面的對象是一樣的。
tf.keras.Model.save_weights
.weights
用tf.keras.Model.save_weights還是用tf.train.Checkpoint存儲checkpoint?官方有一段解釋,大意是說
tf.keras.Model.load_weights
2. 被重構了的.save和tf.keras.models.load_model
.save
tf.keras.models.load_model
熟悉keras的用戶應該都知道,tf.keras.Model.save(或者keras.Model.save)在很久以前設計的時候是儲存整個模型的(儲存格式為.h5),儲存的單元為tf.keras.layers.Layer,弊端在於,對於自定義的層,如果丟了代碼就載入不回來了,而且.h5沒法直接像SavedModel那樣有很多的部署工具可以支持。在新版本中,.save將被修改成默認儲存成SavedModel格式,以後只要有SavedModel,就算丟失了所有模型的實現代碼,也能夠將模型載入回來(使用tf.keras.models.load_model)。截止到本文發稿時,這項工作仍在進行當中。
tf.keras.Model.save
keras.Model.save
.h5
3. 手動.build
.build
首先要明確一下,keras.Model.build是從keras.layers.Layer那兒(間接)繼承來的,在前面(tf.keras.layers.Layer的惰性構建一節)也提到了這個.build的作用是在第一次作用在輸入張量上的時候才構建與之匹配的權重減少人工計算維度的不必要性。而在只有functional api的年代,keras.Model.build方法是個冗餘的方法,因為functional api中的keras.Model中所有的層都有構建好的權重。
keras.Model.build
keras.layers.Layer
新版本中因為subclassed方式書寫的tf.keras.Model跟tf.keras.layers.Layer一樣也是有惰性構建機制的,因此tf.keras.Model.build被修改成可以直接根據輸入形狀構建權重,看個例子:
tf.keras.Model.build
class Model(tf.keras.Model): def __init__(self): super(Model, self).__init__() self.d = tf.keras.layers.Dense(3)
def call(self, x, training=True, mask=None): return self.d(x)
m = Model() # 列印為空 print(m.variables) # 這裡只要給進去輸入張量的形狀即可,不需要具體的張量內容 m.build((None, 4)) # 列印,有兩個權重 print(m.variables)
tf.keras.Model.build可以幫助用戶只給進去輸入的形狀而構建權重,並在此之後對變數進行查看內容、賦值和修改的等等。
4. 不兼容subclassed方式的.summary
.summary
.summary無法展示出subclassed方式構建的模型中的二級及更高子tf.keras.layers.Layer,只能展示一級的,看個例子:
class SuperModel(tf.keras.Model): def __init__(self): super(SuperModel, self).__init__() self.model = Model()
def call(self, x, training=True, mask=None): return self.model(x)
m = SuperModel() m.build((None, 4)) m.summary()
Model: "super_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= model (Model) multiple 15 ================================================================= Total params: 15 Trainable params: 15 Non-trainable params: 0 _________________________________________________________________
SuperModel的第一子級tf.keras.layers.Layer是可以展示,也就是這個Model,但是其第二子級,也就是Model的第一子級d就無法展示了,更別提希望能夠展示出完整的DAG的結構圖了…...
SuperModel
Model
d
5. 各種對subclassed方式不再有效的屬性
在只有functional api和靜態圖的時代,keras.Model還有很多有意義的屬性,比如keras.Model.inputs代表了輸入的符號張量,keras.Model.outputs代表了輸出的符號張量。對於動態圖和subclassed方式構建的模型,這些屬性顯然是沒意義的,因為沒有任何一個張量代表了subclassed的方式構建的模型的輸入,類似的屬性還有input、input_shape、input_mask、input_names、inbound_nodes、output、output_mask、output_names和output_shape等等,這些冗餘的屬性使得tf.keras.Model這個類型顯得非常笨重。
keras.Model.inputs
keras.Model.outputs
input
input_shape
input_mask
input_names
inbound_nodes
output
output_mask
output_names
output_shape
文檔在此,對於輸入的張量,tf.keras.Sequential挨個作用裡麪包含的tf.keras.layers.Layer,適合實現VGG16這樣一條路走到黑的模型。在一個大的模型中某些一條路走到黑的子模型用tf.keras.Sequential實現會極大地提升實現的效率。
因為繼承自tf.keras.Model,擁有tf.keras.Model的幾乎所有的功能,因此也可以使用compile + fit進行訓練。
六種種狀態容器,選擇繼承哪個?
.variable
tf.keras.callbacks.Callback
狀態容器是新版本完全擁抱動態圖執行的直接結果,如果沒有狀態容器,則局部創建的變數tf.Variable將會釋器回收掉。這難道就意味著1.x的靜態圖就完全拋棄了麼?顯然至少有兩個方面動態圖是完全比不過靜態圖的,一是因為靜態圖的各種優化(如constant folding、kernel fusion)帶來的運算速度的提升,二是靜態圖可以被序列化成SavedModel用於部署,用常見的靜態類型語言(C++和Java)都可以在工業場景載入調用。新版本的靜態圖去哪了?後面的《TensorFlow 2.0學習筆記之靜態圖轉換器》揭曉。
C++
Java