小白求问,我在运用vgg19进行3classes图像分类,经过了10epochs后valaccuracy保持在了80%左右,但是train_accuracy一直在上升,请问是什么原因导致的,是不是overfitting呢?


对于神经网路训练中,准确率不足的问题,系统整理一下:

第一部分,我们手里的锤子

  1. 搜集更多的数据:万能方法,但往往成本很高。data augmentation 是一种妥协方法

2. 更多的数据,并且让训练集尽量多种多样:同上,万能方法,成本很高

3. 更久的训练

4. 使用其它优化方法,而不是梯度下降,例如Adam

5. 试试更大的网路

6. 试试更小的网路

7. dropout

8. 正则化

9. 改变网路结构:激活函数,隐含层数量等

10. ........

第二部分,我们眼前的钉子

如何提高训练集上的表现?

  1. 更大的网路
  2. 使用其它优化演算法,例如Adam

如何提高 validation set 上的表现?

  1. 正则化
  2. 扩大训练集

如何提高在 test set 上的表现?

更大的validation set

如何提高真实世界上的表现?

改变 validation set/test set 或者 cost function

另外还有一些手段,例如 early stopping,它同时降低训练集上的精度和提高泛化能力。

一个行为同时显著影响两项得分,使我们在真实工程中非常困扰,能不用就不用。(这里引用吴恩达的一个比喻。如果一辆车有两套控制系统,分别控制左右、前后,这辆车就容易开。如果一辆车的控制系统是 0.7左右+0.3前后,以及0.5左右+0.5前后,理论上这车也能开,但很难很难)

另外还有一个比较笨,但非常有效的办法。把 validation set 中错误的那部分列出来单独分析,如果它们是很相似的一类,那你就要开心了。这种情况下,考虑如下处理:

  1. 数据或者模型需要改进。例如:
    1. 一个图像识别演算法,发现识别错误的基本都是老虎(识别成猫),那么可能是数据本身不足以支持这样的识别,或者模型过于简单/超参数有误。
    2. 一个房价预测模型,发现测试集中很多200平以上房价预测错误,而训练集上,200平以上样本只有2条
    3. 还是一个房价预测模型,发现测试集错误的,集中在某个小区。
  2. 如果怎样都不行的话,考虑外加规则。看起来比较丑陋,但生产环境往往需要借助这种方法。


如果只是10个epoch,可以尝试让它继续跑一会,如果还是不行的话简单说一些可以尝试的办法吧:

首先看数据有没有预处理,减去均值除以方差

再看看数据有没有增强,比如说随机裁剪,水平翻转,亮度饱和度对比度改变

网路最后fc层可以加dropout减轻过拟合

可以加正则项

模型参数的初始化,BN+ReLU的话用Xavier

尝试用label smothing

学习率调整,可以看著loss大概在多少epoch后不太降了就进行学习率衰减,或者可以尝试余弦退火

如果是用自己的数据集,看看能否收集到更多的训练数据


训练的metrics还在提高但是validation的metrics没有提高了,可以考虑是出现overfit。不过也跟你的学习策略相关,即你所选择的optimizor和学习率lr,有时候梯度下降卡在了一个local optimization的位置也可能会出现这种情况。可以使用自适应的学习策略还有改变你的学习率lr都可能帮助你的validation metrics进一步提升。

然后你说的提升模型的方法有很多可以考虑:

1)数据层面解决类间不平衡的问题,以及一个类内正负样本的不平衡问题。即想办法扩充数据,或者在训练时给不同部分的数据加上训练的权重;

2)模型层面我记得VGG19算是比较老一些的模型了,如果自己设计模型的话可以考虑参考ResNet的设计,加入一些Node的跳跃连接制造残差,这样可以加深网路的结构不容易出现过拟合,也不容易出现梯度爆炸和消失。


换一个思路,不要盲目的调

请思考下列问题

请问验证集中经常分类错误的图有哪些?

这类图片有什么特点?

是数据本身标注问题嘛?

是人也很难区分嘛?

开始训练的时候做过数据分析嘛?

数据类别数量平衡吗?

数据干净吗?

深度学习除了用各种技巧刷精度外,也要理解现在的模型的能力,可以从可视化的角度进行分析,如果发现分类错误的这些图片有共性,那么就可以对症下药了


是,不过你可以再训训,待train_acc到1了val可能会上升。我之前在pytorch forums上有个大佬回答我的是改进一下数据增强。当然也可以尝试看下你样本的分布,是不是有那么一点不均衡,然后重写一下sampler。不知道你是什么任务,可不可以加多任务的loss或者用一些类内类间loss来让训练难度加大。(我是菜 ,所以一下就只想到这些,等大佬们来回答吧


推荐阅读:
相关文章