转载请注明出处

TL;DR

(0). 为什么需要状态容器:

因为TensorFlow 2.0以后默认打开动态图运算,因此tf.Varibable(以及其它状态类型)如果不放在一个容器中会被python解释器回收掉,无法求导以及更新,这种容器称之为状态容器(stateful container)。TensorFlow 2.0中状态容器的继承结构图为:

Trackable
|
|-- tf.Variable
|
|-- MutableHashTable
|
|-- AutoTrackable
|
|-- ListWrapper/DictWrapper
|
|-- tf.train.Checkpoint
|
|-- tf.Module
|
|-- tf.keras.layers.Layer
|
|-- tf.keras.Model
|
|-- tf.keras.Sequential

(1). Trackable:

Trackablefrom tensorflow.python.training.tracking.base import Trackable)是所有的状态容器的基类,其定义了一个self._track_trackable(self, name: Str, value: Trackable) -> None方法,来使得value成为self名为name的依赖。最终Trackable的实例之间的依赖关系构成一个有向无环图(DAG),只要根节点不被解释器回收,则所有的依赖都会存在。

(2). AutoTrackable:

AutoTrackablefrom tensorflow.python.training.tracking.tracking import AutoTrackable),在__setattr__中运行_track_trackable,执行a.b = c时自动的将c设成a名为b的依赖(这也是叫AutoTrackable的原因),避免显式调用_track_trackable带来的无聊和繁琐。

(3). Trackable vs AutoTrackable

只能成为别人的依赖而不依赖于别人的的对象,继承Trackable即可,而不用AutoTrackable,比如tf.VariableMutableHashTable。这些类型通常是可以被保存的对象。

(4). tf.train.Checkpoint

基于对象的储存(object-based save)的类型,地位等同于1.xtf.train.Saver。继承自AutoTrackable。该类型在储存时会遍历Trackable组成的DAG储存和载入变数(或stateful的组件)。

(5). tf.Module:

定位为一个轻量级的状态容器,继承自AutoTrackable,是所有的public api中的状态容器中最底层的一个,最底层的能够收集.variables.submodules属性的一个,因为可以收集变数,所以这个类型可以用来建模,配合tf.GradientTape使用。

(6). tf.keras.layers.Layer:

Layerkeras中最底层的类型,继承自tf.Module,相比于其父类,Layer开始有了建模的各种特性和功能,包括惰性构建机制(lazy build)统一的调用介面。大量已经被定义好的tf.keras.layers.Layer的子类可以用来快速的构建模型,如全连接层tf.keras.layers.Dense,卷积层tf.keras.layers.Conv2D和递归层tf.keras.layers.RNN等等。

(7). tf.keras.Model:

(间接)继承自tf.keras.layers.Layer,是keras中建模的核心类型,有如下特性:

  • 两种形式的构造函数(functional和subclassed)
  • 定义了compilefit方法,结合tf.keras.callbacks,提供一站式训练服务

(8). tf.keras.Sequential

对于输入的张量,挨个作用里面包含的tf.keras.layers.Layer,适合实现VGG16这样一条路走到黑的(子)模型。

(9). 这么多状态容器,选择继承哪个?

  • 仅在学习和深入研究状态容器(或基于对象的储存)时使用TrackableAutoTrackable
  • tf.Module: 适合自定义训练循环时使用
  • tf.keras.layers.Layer:适合实现一些中间层,比如Attention之类的,可以配合tf.keras.Sequential使用,极少看见大的模型继承自这个类型。
  • tf.keras.Model:适合一些固定套路的模型(使用compile + fit)。虽然也可以自定义训练循环,但是有一种杀鸡用牛刀的感觉。
  • tf.keras.Sequential:适合一条路走到黑的(子)模型。

写在最前

  • 所有内容基于个人的使用和源代码的阅读,属于个人总结,难免会有错误遗漏之处以及个人狭隘的观点,不足之处还请大家多多指教。
  • 完全理解本文的内容需要掌握很多python和tensorflow的知识点如python的垃圾回收机制和魔法方法,tensorflow的运算图原理和序列化格式SavedModel等等。
  • 运行代码段需要2.0.beta1版本,其余版本没有测试过(1.x肯定不行,2.x可能可以),直接运行pip install tensorflow==2.0.beta1即可(2.7或3.x版本都行)。除非显式地注明了代码段的运行版本,否则默认一概使用2.0.beta1(有无GPU均可)。 部分代码需要安装pytorch
  • 转载请注明出处

为什么需要状态容器

TensorFlow 2.0相对于TensorFlow 1.x的一个巨大的变化是默认打开动态图(查看此文档了解更详细的变更),这个变化导致的一个直接结果是tf.Variable(及其子类和其他的所有的状态类型如MutableHashTable)的生命周期将与其对应的python对象绑定,简单解释一下就是:

# 仅仅只能在1.x版本中运行!!!
import tensorflow as tf

v = tf.Variable(2.)
v = "foo"
default_graph = tf.get_default_graph()
vs = default_graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vs)

一开始创建的tf.Variable所绑定的对象v被另外地赋值,理论上根据python的垃圾回收机制,这个tf.Variable应该被解释器回收掉(因为已经没有引用指向这个tf.Variable了),然而列印出的vs并不是空,这是因为1.x版本中创建的变数将依附在运算图tf.Graph中。2.0之后呢?

import tensorflow as tf

v = tf.Variable(2.)
v = "foo"

因为没有了默认的运算图,执行完v = "foo"之后tf.Variable这个对象由于最后一个指向他的引用消失而被解释器回收。

这个变化有什么样的影响呢,看一个线性回归的例子就知道了

import tensorflow as tf

def linear_regression(x):
w = tf.Variable(tf.random.normal(3, 4))
b = tf.Variable(tf.random.normal((4,)))
return tf.matmul(x, w) + b

这是一段典型的1.x版本的代码(这种函数配合tf.variable_scope使用并做变数重用(variable resue)甚至成为了面试中关于TensorFlow的一个经典问题)。因为运算图的存在,函数执行完之后创建的变数wb就算不再有引用指向他们,还是会被储存在tf.Graph中,因此不用担心变数被解释器回收的事情。然而到了2.0之后,这个函数等于废了,因为运行完了之后wb已经不存在了…...变数都不在了,对变数求导、更新和保存就别想了吧。

现在的问题变成了怎么做才能保证变数不被回收掉:只要保存至少一个指向他们的引用即可,简单修改一下这个linear_regression函数变成类型就能做到:

import tensorflow as tf

class LinearRegression(object):
def __init__(self):
self.w = tf.Variable(tf.random.normal(3, 4))
self.b = tf.Variable(tf.random.normal((4,)))

def __call__(self, x):
return tf.matmul(x, self.w) + self.b

linear_regression = LinearRegression()

现在,就算调用了linear_regression(这个对象可以调用,因为__call__被显式地定义了),变数也不再会被回收掉,因为总存在指向他们的引用。

对比一下就能够发现,这个LinearRegression类型起了一个容器的作用,他用来储存与自己的运算相关的tf.Variable(当然也要进行运算,在__call__中),防止他们被解释器回收掉。

这个LinearRegression类型是我们自己定义的,继承自object,而在TensorFlow 2.0中,为了能够更好的收集、储存和使用tf.Variable(或者更广义一点,所有的stateful的组件,还比如tf.lookup.StaticHashTable等),新版本定义了一系列的复杂度与功能各不一样的类型,因为这些类型的主要的作用是充当stateful的组件的容器,因为官方称之为状态容器(stateful container)

状态容器类型

以上是新版本中一定需要状态容器的原因(如果没有,变数会被回收无法储存和复用),现在来介绍一下TensorFlow 2.0中状态容器的类型。他们整个继承结构图,见前面的树状图,以下对这个树状继承图中重要的类型做一些解释和说明。

Trackable

Trackable(引用方式为from tensorflow.python.training.tracking.base import Trackable)是新版本中所有的状态容器的基类,继承自objetc,其定义了self._track_trackable(self, name: Str, value: Trackable) -> None的方法,来使得value成为self名为name的依赖最终所有的Trackable的实例之间的依赖关系构成一个有向无环图(Directed Acyclic Graph, 缩写为DAG)。以下是一个通过调用_track_trackable来构造依赖关系DAG的例子:

import tensorflow as tf
from tensorflow.python.training.tracking.base import Trackable
from tensorflow.python.ops.lookup_ops import MutableHashTable
from tensorflow.python.training.tracking.graph_view import ObjectGraphView
import weakref

t1 = Trackable()
t2 = Trackable()
# 这里暂时只要知道`tf.Variable`和`MutableHashTable`也都继承自`Trackable`就行
v = tf.Variable(1.)
hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
t1._track_trackable(t2, "child_trackable")
t1._track_trackable(v, "variable1")
t2._track_trackable(v, "variable2")
t2._track_trackable(hash_table, "dict")

# 使用`ObjectGraphView`这个类型可以获得这个DAG的结构
object_graph_view = ObjectGraphView(weakref.ref(t1))
# saveables储存了这个DAG中可以被保存的状态类型,这里有两个,一个是v,另一个是hash_table,saveables会在
# tf.train.Checkpoint中被使用
saveables, dag, _ = object_graph_view.serialize_object_graph()
print(dag)

根据代码可以看出构建的DAG大概如下图所示:

这从列印得到的dag可以得到印证:

nodes {
children {
node_id: 1
local_name: "child_trackable"
}
children {
node_id: 2
local_name: "variable1"
}
}
nodes {
children {
node_id: 2
local_name: "variable2"
}
children {
node_id: 3
local_name: "dict"
}
}
nodes {
attributes {
name: "VARIABLE_VALUE"
full_name: "Variable"
checkpoint_key: "variable1/.ATTRIBUTES/VARIABLE_VALUE"
}
}
nodes {
attributes {
name: "table"
checkpoint_key: "child_trackable/dict/.ATTRIBUTES/table"
}
}

一共有四个节点,很容易可以看出这DAG中的四个节点分别代表了t1t2vhash_table

Trackable另一个重要的方法self._gather_saveables_for_checkpoint(self),会收集可以被储存的对象,后文会在基于对象的储存tf.train.Checkpoint中提到这个方法。

AutoTrackable

引用方式为from tensorflow.python.training.tracking.tracking import AutoTrackable。对于Trackable,每次都调用_track_trackable会显得非常的繁琐,能不能a.b = c的时候直接隐式地执行一下a._track_trackable(b, c)呢,AutoTrackable这个类型就做到了这一点。他通过覆盖(override) __setattr____delattr__来达到自动收集(或者删除)依赖的目的(这也是叫AutoTrackable的原因)。将上面一段关于Trackable的代码改写成等价的AutoTrackable的形式,会清爽的多。

import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.ops.lookup_ops import MutableHashTable
from tensorflow.python.training.tracking.graph_view import ObjectGraphView
import weakref

t1 = AutoTrackable()
t2 = AutoTrackable()
v = tf.Variable(1.)
hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
t1.child_trackable = t2
t1.variable1 = v
t2.variable2 = v
t2.dict = hash_table

只会成为他人依赖的类型

对比一下AutoTrackableTrackable的区别就能发现,x.y = z中,x肯定是AutoTrackable的,但是z不一定,如果z只会成为别人的依赖,而不依赖于别人(也就是说z是整个DAG中的叶节点),那z大可不必继承AutoTrackable因为也不会有人执行z.xxx = xxxxx。在新版本中有很多这种只会成为他人依赖的类型,比如刚刚看到的tf.VariableMutableHashTable等等,这些类型通常是可以被保存的对象,因此会覆盖_gather_saveables_for_checkpointTrackable一节代码中的saveables变数,正是因为tf.VariableMutableHashTable覆盖了这个方法,才使得自己出现在saveables中(并最终被tf.train.Checkpoint保存,具体看tf.train.Checkpoint一节)。

ListWrapper和DictWrapper

假如我想attach一个Trackable组成的list或者dictAutoTrackable,会不会因为list或者dict不是Trackable而导致DAG构建的有问题?实际上并不会,为了应对这个问题,新版本特意设计了两个类型(引用方式为from tensorflow.python.training.tracking.data_structures import ListWrapper, DictWrapper)。ListWrapper多重继承自Trackablecollections.Sequence_DictWrapper多重继承自Trackablecollections.Mapping。一旦检测到被attach的对象是一个Trackablelist或者dict,则将这个list或者dict转变为ListWrapperDictWrapper__setattr__被覆盖所具有的检测功能),以下为一个例子:

from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.training.tracking.base import Trackable

auto_trackable = AutoTrackable()
auto_trackable.trackable_list = [Trackable() for _ in range(2)]
auto_trackable.trackable_dict = {str(i): Trackable() for i in range(2)}
print(type(auto_trackable.trackable_list))

pytorch熟悉的用户来讲,ListWrapper(功能上)等价于torch.nn.ModuleListDictWrapper(功能上)等价于torch.nn.ModuleDict

tf.train.Checkpoint

借助于TrackableAutoTrackable这两个类型就足够完成基于对象的保存(objec-based saving)了,以下为一个使用AutoTrackable做基于对象的储存的例子:

import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.ops.lookup_ops import MutableHashTable

t1 = AutoTrackable()
t2 = AutoTrackable()
v = tf.Variable(1.)
hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
hash_table.insert(keys=[1, 2, 3], values=[2, 4, 6])
t1.child_trackable = t2
t1.variable1 = v
t2.variable2 = v
t2.dict = hash_table

checkpoint = tf.train.Checkpoint(root=t1)
checkpoint.save("checkpoint/save")

执行完毕,会在checkpoint目录下生成文件save-1(多出来的一表示保存次数)。接下来从文件中载入:

import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.ops.lookup_ops import MutableHashTable

trackable1 = AutoTrackable()
trackable2 = AutoTrackable()
# 原始值从1.改成了2.
variable = tf.Variable(2.)
dictionary = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
# 字典的value从[2, 4, 6]变成了[10, 20, 30]
dictionary.insert(keys=[1, 2, 3, 4], values=[10, 20, 30, 40])
trackable1.child_trackable = trackable2
trackable1.variable1 = variable
trackable2.variable2 = variable
trackable2.dict = dictionary
print("before restore")
print(trackable1.variable1.numpy())
print(trackable1.child_trackable.dict.lookup(tf.constant([1, 4])).numpy())
checkpoint = tf.train.Checkpoint(root=trackable1)
checkpoint.restore("checkpoint/save-1")
print("after restore")
print(trackable1.variable1.numpy())
print(trackable1.child_trackable.dict.lookup(tf.constant([1, 4])).numpy())

输出结果为

before restore
2.0
[10 30]
after restore
1.0
[2 6]

应该有人注意到了我重命名了很多python变数,比如t1变成了trackable1hash_table变成了dictionary,但是最终还是能够成功地恢复变数,因为变数名一点都不重要,重要的是储存时的键的名字(比如child_trackable)。

这个类型是如何做到储存的呢?一个核心的类型是刚刚已经出现过了的ObjectGraphView(引用方式为from tensorflow.python.training.tracking.graph_view import ObjectGraphView),这个类型接收DAG的根节点之后开始遍历整个DAG并通过调用节点的_gather_saveables_for_checkpoint来收集可以被保存的对象(如变数值)以及他们之间的依赖关系并储存。

tf.train.Checkpoint还具有创建时载入的机制(restore on creation),来看一个例子:

import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable

class RestoreOnCreation(AutoTrackable):
def __init__(self):
super(RestoreOnCreation, self).__init__()
self.v = None

def __call__(self, init_value):
if self.v is None:
self.v = tf.Variable(init_value)
return self.v

restore_on_create = RestoreOnCreation()
# 在这个调用之后`restore_on_create.v`才会被创建
restore_on_create(1.)
tf.train.Checkpoint(root=restore_on_create).save("save/restore_on_create")

# 创建新的`RestoreOnCreation`对象`another_restore_on_create`
another_restore_on_create = RestoreOnCreation()
# 这里的`another_restore_on_create`的`.v`属性还是个`None`的时候就执行了载入
tf.train.Checkpoint(root=another_restore_on_create).restore("save/restore_on_create-1")
# !!!传进去的初始值虽然是2.,但是因为restore on creation的机制,还是会被修改成1,因此列印结果为1.
print(another_restore_on_create(2.))

restore on creation的机制会在tf.keras.layers.Layer的惰性构建中使用(在tf.keras.layers.Layer一节会提到)。

作为新老两代变数储存的核心类型,这里不得不提的是tf.train.Checkpointtf.train.Saver的区别:

  • tf.train.Saver:

实际上,2.0以后就没有tf.train.Saver了(执行tf.train.Saver会报AttributeError),在1.x静态运算图时代,tf.train.Saver会构建一个变数名到变数的字典(隐式地从运算图中获得,或者显式地由用户传入这个字典,参考其构造函数)以此来存储(tf.train.Saver.save)变数或者载入(tf.train.Saver.restore)变数,载入变数时如果变数名变掉了则载入失败,因此这是一种基于变数名的储存方式(官方称之为name-based save/restore)。

  • tf.train.Checkpoint

在保存时(tf.train.Checkpoint.save)会记录DAG中的依赖关系(这个依赖关系图会被序列化成字元串储存在了checkpoint文件中,对应的键为_CHECKPOINTABLE_OBJECT_GRAPH),在载入阶段(tf.train.Checkpoint.restore)会读取这个字元串并解析得到DAG依赖关系图,并用DAG依赖关系图中的节点匹配Trackable对象,匹配完成后载入变数值,因此这是一种基于对象的储存方式(官方称之为object-based save/restore)。

tf.Module

tf.Module是最底层的公开(public api)的状态容器(TrackableAutoTrackable都不是public的),关于这个类型可以参考官方的文档,其核心在于实现了两个功能:

  • 收集所有的变数,储存在tf.Module.variables(或者仅仅是可以训练的变数tf.Module.trainable_variables
  • 收集所有的子tf.Module,储存在tf.Module.submodules

这种收集某个类型的实例的具体的做法是tf.Module的所有属性进行遍历(注意并没有调用_gather_saveables_for_checkpoint),一旦发现是某个类型的实例,则将其收集。因为其收集变数的功能,可以配合tf.GradientTape(这个类型是新版本中计算导数的核心,tf.GradientTape.gradient与1.x版本中的tf.train.Optimizer.compute_gradients的作用一样)来进行导数的运算和变数的更新了,以下为使用tf.Module做线性回归的例子:

import tensorflow as tf

class Model(tf.Module):
def __init__(self):
super(Model, self).__init__()
self.v = tf.Variable([[20.],
[-10.]])

def forward(self, x):
return tf.matmul(x, self.v)

@staticmethod
def mse(y_pred, y_true):
return tf.reduce_mean(tf.square(y_pred - y_true))

x = tf.random.normal((1024, 2))
y = tf.matmul(x, tf.constant([[1.],
[-1.]])) + tf.random.normal((1024,))

model = Model()
optimizer = tf.keras.optimizers.Adam(1e-1)
iterations = 1024

for i in range(iterations):
with tf.GradientTape() as tape:
# 此处的model.trainable_variables当然可以换成人工收集的变数[model.v]的形式,但是当变数数非常大的时候,
# model.trainable_variables带来的便利是巨大的
grads = tape.gradient(Model.mse(model.forward(x), y), model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

print(model.v.numpy())

# 由于model是tf.Module的子类的实例,因此也是Trackable的实例,从而就可以用`tf.train.Checkpoint`进行变数的储存和载入,下面加以展示
checkpoint = tf.train.Checkpoint(model=model)
# 务必保证`checkpoint/linear_regression`目录存在
checkpoint.save("checkpoint/linear_regression/trained-model")

# 新建一个模型,并从训练好的checkpoint载入变数值
another_model = Model()
print("another model before restore")
print(another_model.v)
another_checkpoint = tf.train.Checkpoint(model=another_model)
another_checkpoint.restore("checkpoint/linear_regression/trained-model-1")
print("another model after restore")
print(another_model.v)

正如在注释中提到的,非常复杂的网路的变数是非常多的,在计算他们的导数(tf.GradientTape.gradient)中需要收集所有变数时,tf.Module.variables属性一次性将他们收集好会带来非常大的便利,这也是tf.Module适合用来建模的原因之一。

官方没有给tf.Module设定任何的计算入口,因此在这里我就定义了一个forward,实际上也可以是__call__,反正随便是什么名字的方法都可以。

熟悉pytorch的应该都能理解tf.Module为啥叫Module,以及我为啥给这个模型定义的方法名叫forward。实际上torch.nn.Module.parameters()功能上等价于tf.Module.variablestorch.nn.Module.children()功能上等价于tf.Module.submodulestorch.save(torch.nn.Module.state_dict(), xxx)功能上等价于tf.train.Checkpoint(whatever=tf.Module).save(xxx)。但是不能认为这两个类型的工作方式是完全一样的torch.nn.Module.state_dict会从torch.nn.Module.parameters()中来(当然还包含torch.nn.Module.buffers()),然而tf.train.Checkpoint收集的saveables并不是从tf.Module.variables来(而是从每一个叶节点的_gather_saveables_for_checkpoint而来),归根到底,这还是因为TensorFlow中可以被保存的类型不止tf.Variable一个(还比如MutableHashTable,这种可以被保存的组件并不能算作是tf.Variable,因此也不该出现在tf.Module.variables中)。

tf.keras.layers.Layer

官方文档在此。终于进入到keras的地盘了, Layerkeras中最底层的类型,keras中有大量已经被定义好的tf.keras.layers.Layer的(直接或间接的)子类可以用来快速的构建机器学习模型,如全连接层tf.keras.layers.Dense,卷积层tf.keras.layers.Conv2D和递归层tf.keras.layers.RNN等等。相比于轻量级的tf.Moduletf.keras.layers.Layer显式地定义了用于建模的各种特性和功能,比如:

惰性构建的机制

tf.keras.layers.Layer在调用过构造函数创建实例之后一般是不包含变数的,只有在第一次遇到输入时,才会根据输入张量的维度,调用build(self, input_shape)方法,构造变数,具体的看一个例子:

import tensorflow as tf

dense = tf.keras.layers.Dense(2)
print(dense.variables)
x = tf.random.normal((4, 3))
y = dense(x)
print(dense.variables)

初次列印,dense包含的变数是个空list,而作用在一个形状为[4, 3]的矩阵上后,dense包含的变数构建完毕,并且变数的尺寸恰好可以跟[4, 3]进行运算。在tf.keras.layers.Dense.build实现中可以直接发现这种逻辑(截取部分代码段):

def build(self, input_shape):
# 取出输入的特征维度last_dim
last_dim = tensor_shape.dimension_value(input_shape[-1])
# 根据输入的特征维度构建kernel和bias
self.kernel = self.add_weight(
kernel,
# kernel的形状为[last_dim, self.units]
shape=[last_dim, self.units],
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
dtype=self.dtype,
trainable=True)
if self.use_bias:
self.bias = self.add_weight(
bias,
shape=[self.units,],
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
dtype=self.dtype,
trainable=True)
else:
self.bias = None
self.built = True

这种惰性构建的机制有什么好处么?看一个pytorch的VGG的实现就知道了:

# 截取自pytorch官方的VGG实现
from torch import nn

class VGG(nn.Module):

def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
# 512 * 7 * 7是怎么来的?
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
if init_weights:
self._initialize_weights()

def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x

nn.Linear(512 * 7 * 7, 4096)这里的512 * 7 * 7显然是根据输入的图像的尺寸以及经过的各种卷积和池化之后算出来的输出的尺寸,用户需要对这些运算非常熟悉才能算出这个值。如果使用惰性运算的话,则这个层在遇见输入张量的时候才会构建与之相匹配的变数,减小用户显式计算这个值的负担。具体一点,tf.keras.layers.Dense(4096)即可替代nn.Linear(512 * 7 * 7, 4096),当tf.keras.layers.Dense遇到这个张量并发现输入的特征维度为512 * 7 * 7时会在build方法中创建[512 * 7 * 7, 4096] 的权重用于运算。因此,以上pytorch实现中self.classifier可以被替换成tf.keras惰性构建的实现:

self.classifier = tf.keras.Sequential(
[tf.keras.layers.Dense(4096, activation="relu"),
tf.keras.layers.Dropout(),
tf.keras.layers.Dense(4096, activation="relu"),
tf.keras.layers.Dropout(),
tf.keras.layers.Dense(num_classes)]
)

惰性构建机制的一个可能引起问题的地方在于权重的载入,如果变数还没被创建就想载入怎么办?看一个例子就知道了:

import tensorflow as tf

dense = tf.keras.layers.Dense(2)
# 注意看这里,假如我曾经训练好过这个dense,并且权重储存在`my-checkpoint`中亦或者是我从别人那儿拿了一个checkpoint文件。则这里的dense刚刚创建还没有经过任何调用,因此dense.variables此时是个空,但是依然可以进行restore
tf.train.Checkpoint(dense=dense).restore("my-checkpoint")
print(dense.variables)
x = tf.random.normal((4, 3))
# 在这里被调用,变数被创建时,会立即赋予checkpoint中的权重,之后才会作用在x上
y = dense(x)
print(dense.variables)

从上可以看出,这个因为tf.keras.layers.Layer的lazy build而导致的权重载入的问题被tf.train.Checkpoint的restore-on-create的机制给解决了。

统一的调用介面

tf.keras.layers.Layer__call__被覆盖过,用来对输入张量进行运算,__call__会调用call,因此用户需要在call方法中,定义通过输入张量和该层所具有的变数进行实际的运算的过程。以下代码摘取自官方tf.keras.layers.Dense.call的实现(截取了部分代码段):

def call(self, inputs):
inputs = ops.convert_to_tensor(inputs)
rank = common_shapes.rank(inputs)
if rank > 2:
outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
else:
outputs = gen_math_ops.mat_mul(inputs, self.kernel)
if self.use_bias:
outputs = nn.bias_add(outputs, self.bias)
if self.activation is not None:
return self.activation(outputs) # pylint: disable=not-callable
return outputs

即根据输入的inputs,判断是否秩是否大于2,如果是则调用张量乘tensordot,如果不是调用矩阵乘mat_mul,最后处理biasactivation

显然,借助tf.Moduletf.keras.layers.Layer(及其子类),已经足够实现非常复杂的模型了。

tf.keras.Model

官方文档在此,keras自己的文档也可以做参考。tf.keras.Model(间接)继承自tf.keras.layers.Layer,因此也具有惰性构建的机制,也可以被调用。如果说tf.keras.layers.Layer构成了一个个小的组件的话,tf.keras.Model就是要组合这些Layer构建模型的类型。这个类型太过复杂,值得说的东西也太多,其大体的特点如下:

两种形式的构造函数

看文档可以发现,tf.keras.Model有两种形式的构造函数,以下是第一种(官方称之为functional api),从keras诞生至今一直存在的一种构造函数调用方式

  • functional api

这种方式首先通过tf.keras.layers.Input来生成输入,经过各种层tf.keras.layers.Layer及其子类的运算之后得到输出,tf.keras.Model接受输入输出而生成整个网路拓扑结构。开门见山地讲,我认为这个形式的构造函数不适合使用,以下为理由:

  1. 无法调试,丢失了动态图执行的本意

import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
outputs = inputs + 1 # 在此处下断点,看不见张量的内容
model = tf.keras.Model(inputs=inputs, outputs=outputs)

y = model(tf.random.normal((6, 3)))

这个方式的巨大的缺点就是调试起来非常困难(这不就是2.0想要改正的 1.x的毛病么?),假如运算完发现y里面有nan,想检查一下中间的运算哪里出问题了,于是在outputs = inputs + 1这一行下了一个断点并调试,然而,看到的是

In[2]: outputs
Out[2]: <tf.Tensor add:0 shape=(None, 3) dtype=float32>

没错,是符号张量,没有具体的数值。没张量数值怎么知道nan哪来的?新版本不是打开了动态图执行了么?

这一切都得从tf.keras.layers.Input开始说起,这个类型会创建一个本来在动态图模式下并不能存在的Placeholderop(没错,新版本已经没有tf.placeholder这个东西了),为了能创建这个op,只好隐式地创建一个运算图tf.Graph(用来容纳这个op,可以通过inputs.graph来获得这个tf.Graph),于是,一切对这个inputs的运算,又回到了这个tf.Graph中。

2. 与标准序列化函数tf.saved_model.save不兼容

姑且忽略这个看不见张量内容的事,忍一忍就算了(就像大家对1.x的态度),然而模型还不能被序列化真的是让人非常难以接受,看如下的例子:

# 仅在TF 2.0.0-beta1版本中测试会报错,其余更新版本不保证
import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
outputs = inputs + 1
model = tf.keras.Model(inputs=inputs, outputs=outputs)

y = model(tf.random.normal((6, 3)))
# 报TypeError
# 关于`tf.saved_model.save`目前只需要知道他接受一个`Trackable`并储存成`SavedModel`即可。
tf.saved_model.save(model, "save/model")

当前主流的各种tensorflow模型的部署或者跨语言的使用都是基于SavedModel的,看如下一张图就知道SavedModel的地位了:

tensorflow serving部署,比如使用java预测,如果不能储存成SavedModel意味著以上这些功能全都不能使用,意味著tensorflow最强的部署的能力都废了。(后续的文章中会有一文专门说到新版本下的模型序列化)

这个例子中inputs + 1是造成不能序列化的罪魁祸首(因为隐式地调用了tf.add),经过测试,tf.keras.layers.Layer的子类的调用才能使得tf.saved_model.save调用成功。其实这个问题并不难修复,但是从中可以看出的是,很多组件(比如tf.saved_model.savetf.print等)在设计的时候并没有考虑到functional api。

如果将以上的代码修改成eager模式的,也就是第二种方式的构造函数,将会有更多的便利。

  • subclassed

import tensorflow as tf

class MyModel(tf.keras.Model):

def __init__(self):
super(MyModel, self).__init__()

# 目前只需要知道`tf.function`会将动态图转变为静态图
@tf.function
def call(self, inputs):
# 如果发现这里有nan,注释掉`tf.function`用动态图加断点调试非常便利
y = inputs + 1
# 即使在静态图中,依然可以用tf.print列印出张量的内容
tf.print(y)
return y

model = MyModel()
# 序列化成`SavedModel`也不会有任何问题
tf.saved_model.save(model, "save/model")

如果你在此时发现ynan,在y = inputs + 1处打断点,就能看到运算过程中张量的具体值了。

同时借助tf.function修饰call方法可以随时在动静态图之间切换,使用tf.print,即使在静态图下依然能够随时看到中间变数的内容(后续的文章中会有一文专门说到新版本下的完全动态图化之后静态图转换器tf.function的使用),也可以传入额外的python值(如True或者False等等)作为call的参数,这些额外的功能functional api全都没有。

显然subclassed方式才是更适合TF 2.0的。

其实很多人对新版本完全keras化有很大的误解,比如这个知乎问题直接问出了:

TF变成keras,早知道就直接用keras了,用TF干嘛?

实际上新版本的tf.keras借助于动态图执行和subclassed的书写代码的方式,能够非常好地平衡可调试性和可部署性(说白了就是动静态图之间的权衡),然而很多keras的用户还在坚持keras官网推荐的functional api这种无法调试的方式来写,丢失了大量tf.keras所拥有的功能如直接序列化成SavedModel用于部署等等。从这个角度上来讲,TF并没有变成keras,只是改造了keras的类作为自己的状态容器而已

compile和fit的一站式训练

github上有各种使用keras.Model构建模型并使用compilefit进行训练的例子(如transformer,han)(google搜索keras + 模型名称保证能找到实现),因此这一小节就不再赘述使用方法,只说为什么。

很多时候训练需要额外的写很多辅助的代码,比如实现一个process bar看看每个epoch的训练进度、实现一个early stop或者实现一个当验证集损失变小时才存checkpoint的机制,这些实现都需要时间,当然如果代码积累好了直接拿来用,不存在这种问题,如果没有任何积累,compile + fit的机制可以帮你快速地做掉这些东西。keras经过重构之后,使用subclassed方式构建的模型进行compile + fit的方式跟functional api的方式几乎没有任何区别,并且可以配合tf.data.Dataset使用。官方有一些使用subclassed + compile + fit + tf.data.Dataset的例子,有兴趣的可以看下,不再赘述。

经过测试,如果使用compile + fit的方式进行训练,想进行调试查看运算过程中张量的内容的话,记得在构造函数中加上dynamic=True

当然,这种方式非常地死板,很多人不喜欢,而更想要定制化的训练循环,这当然可以,借助于tf.GradientTape就可以自己算导数,tf.optimizers.Optimizer做参数更新,总之也不一定非要用compile + fit

废弃或被更改的各种方法与属性

以前只有theano作为backend和functional api的时候,keras.Model定义了很多符合当时应用场景的属性和方法,然而随著tensorflow和动态图的引入,这些属性和方法大多已经(部分)失效,显得这个类型非常的冗余,当然也有一些仍然还有用的(比如compilefit)。以下为一些例子:

  1. 命名失当的.save_weights

很久以前keras.Model.save_weights是用来储存权重的,然而到了新版本中,千万不要被save_weights这个名字蒙蔽双眼,很多人看到这个方法第一感觉可能是首先获得keras.Model.weights(或者keras.Model.variables)然后储存变数就完事?以前是这样,新版本可就不是了,看个例子

import tensorflow as tf
from tensorflow.python.ops.lookup_ops import MutableHashTable

class Model(tf.keras.Model):
def __init__(self, dynamic=True):
super(Model, self).__init__(dynamic=dynamic)
self.hash_table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=0)
self.hash_table.insert([1, 2, 3], [2, 4, 6])
self.v = tf.Variable(1)

def call(self, x, training=True, mask=None):
prob = self.hash_table.lookup(x) + self.v
return prob

model = Model()
y = model(tf.constant([1, 3]))

这个例子中除了self.vself.hash_table也会被储存下去,因此save_weights应该叫save_saveables才更合适。

实际上,tf.keras.Model.save_weightstf.train.Checkpoint的作用方式是一样的,.weightstf.keras.Model.save_weights保存的变数实际上不是相同途径收集的,不要认为存的checkpoint跟.weights(或.variables)里面的对象是一样的

tf.keras.Model.save_weights还是用tf.train.Checkpoint存储checkpoint?官方有一段解释,大意是说

  • tf.keras.Model.save_weights存的checkpoint要用tf.keras.Model.load_weights载入
  • tf.train.Checkpoint.save存的checkpoint要用tf.train.Checkpoint.restore载入
  • 最好使用tf.train.Checkpoint而不是tf.keras.Model.save_weights(不知道原因,我猜应该是save_weights在新版本真的不是个能够反应其功能的名字

2. 被重构了的.savetf.keras.models.load_model

熟悉keras的用户应该都知道,tf.keras.Model.save(或者keras.Model.save)在很久以前设计的时候是储存整个模型的(储存格式为.h5),储存的单元为tf.keras.layers.Layer,弊端在于,对于自定义的层,如果丢了代码就载入不回来了而且.h5没法直接像SavedModel那样有很多的部署工具可以支持。在新版本中,.save将被修改成默认储存成SavedModel格式,以后只要有SavedModel,就算丢失了所有模型的实现代码,也能够将模型载入回来(使用tf.keras.models.load_model)。截止到本文发稿时,这项工作仍在进行当中。

3. 手动.build

首先要明确一下,keras.Model.build是从keras.layers.Layer那儿(间接)继承来的,在前面(tf.keras.layers.Layer的惰性构建一节)也提到了这个.build的作用是在第一次作用在输入张量上的时候才构建与之匹配的权重减少人工计算维度的不必要性。而在只有functional api的年代,keras.Model.build方法是个冗余的方法,因为functional api中的keras.Model中所有的层都有构建好的权重。

新版本中因为subclassed方式书写的tf.keras.Modeltf.keras.layers.Layer一样也是有惰性构建机制的,因此tf.keras.Model.build被修改成可以直接根据输入形状构建权重,看个例子:

import tensorflow as tf

class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
self.d = tf.keras.layers.Dense(3)

def call(self, x, training=True, mask=None):
return self.d(x)

m = Model()
# 列印为空
print(m.variables)
# 这里只要给进去输入张量的形状即可,不需要具体的张量内容
m.build((None, 4))
# 列印,有两个权重
print(m.variables)

tf.keras.Model.build可以帮助用户只给进去输入的形状而构建权重,并在此之后对变数进行查看内容、赋值和修改的等等。

4. 不兼容subclassed方式的.summary

.summary无法展示出subclassed方式构建的模型中的二级及更高子tf.keras.layers.Layer,只能展示一级的,看个例子:

import tensorflow as tf

class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
self.d = tf.keras.layers.Dense(3)

def call(self, x, training=True, mask=None):
return self.d(x)

class SuperModel(tf.keras.Model):
def __init__(self):
super(SuperModel, self).__init__()
self.model = Model()

def call(self, x, training=True, mask=None):
return self.model(x)

m = SuperModel()
m.build((None, 4))
m.summary()

输出结果为

Model: "super_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
model (Model) multiple 15
=================================================================
Total params: 15
Trainable params: 15
Non-trainable params: 0
_________________________________________________________________

SuperModel的第一子级tf.keras.layers.Layer是可以展示,也就是这个Model,但是其第二子级,也就是Model的第一子级d就无法展示了,更别提希望能够展示出完整的DAG的结构图了…...

5. 各种对subclassed方式不再有效的属性

在只有functional api和静态图的时代,keras.Model还有很多有意义的属性,比如keras.Model.inputs代表了输入的符号张量keras.Model.outputs代表了输出的符号张量。对于动态图和subclassed方式构建的模型,这些属性显然是没意义的,因为没有任何一个张量代表了subclassed的方式构建的模型的输入,类似的属性还有inputinput_shapeinput_maskinput_namesinbound_nodesoutputoutput_maskoutput_namesoutput_shape等等,这些冗余的属性使得tf.keras.Model这个类型显得非常笨重

tf.keras.Sequential

文档在此,对于输入的张量,tf.keras.Sequential挨个作用里面包含的tf.keras.layers.Layer,适合实现VGG16这样一条路走到黑的模型。在一个大的模型中某些一条路走到黑的子模型用tf.keras.Sequential实现会极大地提升实现的效率。

因为继承自tf.keras.Model,拥有tf.keras.Model的几乎所有的功能,因此也可以使用compile + fit进行训练。

状态容器选择

六种种状态容器,选择继承哪个?

  • 仅在学习和深入研究状态容器(或基于对象的储存)时使用TrackableAutoTrackable建模中不要使用这两个类型,因为没有显式地变数收集(.variable属性)的机制,使用tf.GradientTape.gradient会费力。
  • tf.Module: 优点是轻量级,介面设计非常干净,可以配合tf.keras.layers.Layer使用,缺点是无法配合tf.keras.Sequential等等使用(虽然自己实现一个也不难),如果希望有自定义的训练循环(custom training loop)tf.Module是一个好的选择。
  • tf.keras.layers.Layer:已经有大量定义好的tf.keras.layers.Layer的子类如tf.keras.layers.Conv2D等等可以在建模中使用。这个类型适合实现一些中间层,比如Attention之类的,还可以配合tf.keras.Sequential使用,极少看见大的模型继承自这个类型(没看见不代表不可以)。
  • tf.keras.Model:曾经为静态图定制的类型,此后经过大量的改造以此适应新版本的动态图和subclassed的写法,代价是大量的属性和方法(部分)都失效了,显得非常笨重。但至少不同版本的兼容性还是在的,依然可以使用固定套路compile + fit跑模型,配合tf.keras.callbacks.Callback能够极快速地实现各种训练功能(如定期Checkpoint等)。虽然也可以自定义训练循环,但是有一种杀鸡用牛刀的感觉。
  • tf.keras.Sequential:适合实现一条路走到黑的(子)模型,比如VGG,可能不到三十行就能实现一个。

状态容器是新版本完全拥抱动态图执行的直接结果,如果没有状态容器,则局部创建的变数tf.Variable将会释器回收掉。这难道就意味著1.x的静态图就完全抛弃了么?显然至少有两个方面动态图是完全比不过静态图的,一是因为静态图的各种优化(如constant folding、kernel fusion)带来的运算速度的提升,二是静态图可以被序列化成SavedModel用于部署,用常见的静态类型语言(C++Java)都可以在工业场景载入调用。新版本的静态图去哪了?后面的《TensorFlow 2.0学习笔记之静态图转换器》揭晓。


推荐阅读:
相关文章