简单答案:会!而且效果会变的很差

详细答案:

这其实是一个很有趣的研究领域:连续学习(continual learning)或者叫 终身学习(lifelong learning)

为了更清楚的 理解,以及解决 这个问题,我们可以来看一下 这个领域 里面,比较重要的 两篇 文章

论文: Overcoming catastrophic forgetting in neural networks

这是 Deepmind 的一篇文章,发表在了 Proceedings of the National Academy of Sciences of the United States of America (PNAS) 上面。这个期刊上很多都是自然科学的文章,可以说也很有趣了,不过按照这个期刊的描述:It is the official journal of the National Academy of Sciences, published since 1915, and publishes original research, scientific reviews, commentaries, and letters.,倒也算契合。


问题定义

先说要解决什么问题。

现在有 dataset A,我们搭了一个模型,然后在 A 上也 train 出了满意的结果。然后 dataset B 来了,我们那刚刚 train 好的模型在 B 上继续 finetune,train 之后,发现模型在 B 上的结果也很满意。

但是,这个时候你如果拿最终这个模型去 A 上重新测一遍,你会发现准确率已经惨不忍睹了。所以这个问题就是

「机器学习的很快,但是遗忘的也快」,英文叫做 「catastrophic forgetting」

为了更好定义这个问题,必须要指出的是,在上述例子中,一旦我们开始在 B 上训练是,A 就永远的不可见了。

方法

High-level idea:

find a solution to a new task in the neighbourhood of an older one

就是说,我们在 B 上进行 finetune 的时候,不要离之前的 model 太远了,这就是基本思路

上图里面,两个 set 的重合处,就是我们要找的,既对 B 效果好,又离之前的 model 的比较近。

理论依据:

Many configurations of θ will result in the same performance [1, 2]。

[1] Robert H. Nielsen. Theory of the backpropagation neural network. In Proceedings ofthe International Joint Conference on Neural Networks, volume I, pages 593–605. Piscataway, NJ: IEEE, 1989.

[2] Héctor J. Sussmann. Uniqueness of the weights for minimal feedforward nets with a given input- output map. Neural Networks, 5:589–593, 1992.

就是说,对于同一个 task,是有可能找到很多符合条件(也就是效果不错)的 model 的,那么我们就可以从中找到一个离之前 model 最近的一个 model。试想,如果没有这个结论,那么「既在 B 上效果好,又不至于离之前的 model 太远」 可能就是一个不可能的任务。

具体做法

基本的出发点是,当我在 B 上面 finetune 的时候,并没有必要调整所有的参数,我们尽可能只动那些对 A 影响比较小的参数,就足够了。

对于 train 一个 model 来说,我们在更新参数的过程可以用一个贝叶斯公式来表达,

上面公式中,

[公式]

就是我们常用的 loss 的相反数。后两项分别是 weights 和 dataset 的先验分布。其中 dataset 的先验分布由于本身在 dataset 确定之后就固定了,所以一般也不考虑。而 weights 的先验分布则常常是一些 regularization,比如 weights decay 等等。

那么,当我们从 A 转到 B 的时候,上面公式发生了如下变化,

最关键的就是理解等号右边的中间这一项。

在本文的问题中,我再不是对 weight 加一个 regularization 那么简单了,而是我希望这个 weights 能够对 A dataset 也有比较好的准确率。所以这里替换成了 A dataset 的优化目标。

上面这个公式,此时不只是 B 的优化目标,而是 A 和 B 整体的优化目标。

直接求,

[公式]

是没办法求的,因为当我们在 train B dataset 的时候,A dataset 已经看不到了。那么这里我们做一个假设,认为每个 theta 是符合一个高斯分布的,

  1. 这个高斯分布的均值就是 现在 A 上 train 后得到的那个值。这个很好理解,因为那个值就是我们想保持的值,所以把他设为分布的中心(也就是均值)是合理的。
  2. 这个高斯分布的 precision(就是方差的倒数)

[公式]

可以用 Fisher information matrix, F 来估计。我们暂时不讲为什么 precision 可以用 Fisher information matrix 估计,先来讲一讲为什么:

方差的倒数代表这个参数对 A 的重要性。倒数值越低,改变他对 A 的影响越小。

上面的绿色的线是一个方差比较大的分布,我们稍微改变一下参数的值,概率的变化并不大。

上面的蓝色的线是一个方差比较小的分布,我们同样改变相同的量,他的概率变化要大得多。

我们当然希望是条件概率越大越好,所以我们要找到那些方差比较大的参数,因为改变他们对 A 的概率项影响较小。

Fisher Information Matrix

总结自:

https://www.inference.vc/on-empirical-fisher-information/

https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/

对于一个监督学习的问题,

[公式]

Fisher Information Matrix (FIM) 的定义是,

[公式]

一个比较直观的解释是,FIM是 两个分布 KL Divergence 的 Hessian 矩阵,而且这两个分布之间只有非常细微的差别,下面来推导一下。

首先来看两个细微差别的分布的 KL Divergence 的定义,

[公式]

然后我们先来对

[公式]

求一阶导数,

然后再求二阶导数,

[公式]

然后就可以得到 Hessian 矩阵了,

论文:Memory Aware Synapses: Learning What (not) to Forget

这里的 [公式] 是 neural network,

值得注意的是,这里的 [公式] 并不是损失函数对参数的导数,而是网路对应的函数对参数的导数。一般来说,后者要复杂不少。因为损失函数是一个标量,而网路的输出则是是一个向量。加入这个向量的长度为10(总共有10类),那么算网路导数的计算量就是算损失函数导数的10倍,所以这个问题必须解决。

本文的解决办法也比较直接,那就是我算输出向量 L2 norm 关于参数的导数。关于这个妥协,作者后面也做了实验,发现并不会让效果打折。

这里的 [公式]是可以作为一个重要性程度衡量标准的。因为自然 [公式]越大,变动参数带来的函数的变化也就越大。所以综合考虑所有的 data points,我们可以得到:

局部版本

之前的 [公式] 是一个函数,这个函数代表整个网路。同样的,我们认为网路中的每一层也是一个函数

[公式],这里的下标 [公式] 表示第几层。背后的 idea 是,只要我们能够保证每一层变化足够小,那么整个网路的变化也会比较小。

这里红框这个公式其实是有问题的。

[公式] 实际上应该是 [公式] 是这一层第 [公式]-th 输入,也就是 [公式] 的输出;

[公式] 实际上应该是 [公式] 是这一层第 [公式]-th 输出

同样的,最终每个参数的重要程度要综合考虑所有的 data points,

实验

前面讲了讲个版本:全局的,局部的。实验部分,作者用的都是全局的这个方法(主要也是因为全局方法效果更好一些)。

「这个方法真正牛的地方在于,他没有对 regularization 这一项前面乘的系数进行调参。使用默认值 1,也取得了比较好的效果」

而且,这篇文章只用到了一阶导数信息,相对来说比较容易计算,而且不要「网路已经train到收敛」的假设条件。

总的来说,这篇文章可以算作是一篇 「simple but work」 的工作。


这就是传说中的「灾难性遗忘的问题」,你可以看看life long learning或continual learning就是专门研究这个问题。

这问题说难也难,说简单也简单。只要你保留全部的老数据,微调新数据的时候,加一个multi task的分支给老数据一起训练,基本就是这个问题的upper bound了。

这个问题难在某些NLP任务中, 需要不断扩展类别。每次加入新类别,需要重新训练模型。尤其当体系大海量数据时,重新训练模型代价很高。此外还要保证已有类型的稳定,避免对线上业务的影响。

如何高效地处理这个问题?

有一篇文章Learning without Forgetting,提供了很好的思路:

简单说就是先用老模型过一遍新数据,保留老模型的输出,然后训练新模型,同时对新模型的输出计算distillation loss对冲掉新数据对老模型参数的影响。这篇文章的idea真是棒棒哒!


ICLR2020上有一篇 Editable Neural Networks 就是解决你说的问题,针对error 样本进行调整但是不影响其他数据。

另外最近一众Life long learning/ incremental learning 还有关于catastrophic forgetting都是研究这个问题的。


当然会,所以有个现象叫做灾难遗忘嘛。

怎么解决?Continual learning可以看看。

Transfer就解决不了了,因为迁移的目的就是要求在目标领域的性能最好,不管之前的领域了。

多任务学习倒是也可以参考。


答案是肯定的,这叫做灾难性遗忘现象(catastrophic forgetting),已经有许多相关的研究了。比如说Overcoming catastrophic forgetting in neural networks这篇文章已经达到700+的引用量。而针对这一问题提出的连续学习(continual learning)近期是研究的大热门,连续学习就是要解决在连续学习多个任务的情况下,保持前部任务的高准确度,当前仍然是一个充满挑战的领域。


推荐阅读:
相关文章