QZQ的小世界!

  • 首页
你好!
这里是QZQ的博客站!
  1. 首页
  2. 未分类
  3. 正文

Pytorch训练的过程中出现nan的情况处理

2025年4月4日 45点热度 0人点赞 0条评论

前言

搞定了LSTM理论之后,按照理解搭建了一个简易的模型,但是在一切都看起来没什么问题的情况下,报错了。

报错不太寻常,因为并不是跑一轮就报错,而是正常跑几轮才会报错。

用异常捕获机制强制停止出现异常的轮次,打印发现模型输出包含nan的张量。

漫长的debug就这样开始了。

神经网络传播过程中nan的处理

如果神经网络设计有缺陷,确实可能出现传播过程中出现nan的情况,而且这种情况在网络上非常常见。

比如:PyTorch训练过程中出现NaN的排查笔记 - 知乎

我觉得就排查的非常好,非常有条理。

神经网络传播过程中可能出现nan的地方很多,这里直接借用这位作者整理的:

  1. 学习率过大

学习率是控制模型参数更新的重要超参数。如果学习率设置得过大,模型参数更新的幅度可能会过大,从而导致损失值发散。这种情况下,测试损失的值可能会变为NaN。可以尝试减小学习率,以确保模型稳定地收敛。

  1. 模型设计问题

测试损失变为NaN还可能是模型设计存在问题。在一些情况下,模型的架构可能导致数值不稳定,从而出现NaN值。这可能是由于某些操作(例如relu、softmax等)在特定情况下产生了数值溢出或欠溢出。可以尝试改变模型的架构,例如使用不同的激活函数或正则化操作,来处理这个问题。

  1. 尺度不平衡的初始化

“尺度不平衡的初始化”是指权重初始化得过大或过小,造成了梯度更新时的不稳定性。使用适合你使用的激活函数的初始化方法(如He或Xavier初始化)可以有效地解决这一问题。

这位作者的排查方式比较专业,但是对于我来说实在是太难了,因为我并不十分清楚lstm的构造,说白了把lstm当个黑盒在使用,有没有更方便的方法?

有的,总算给我找到了:训练过程中出现nan(not a number)的原因及解决方案 - 知乎

可以在 python 文件头部使用如下函数打开 nan 检查:

torch``**.**``autograd``**.**``set_detect_anomaly(True)

加了这行,一旦传播过程中出现了nan,会直接在该处报错停止代码。由于是直接给出了问题发生的位置,这让debug变得非常容易!

另外,如果是反向传播过程中要使用这种检查,需要这样

loss = model(X)
with torch.autograd.detect_anomaly():
    loss.backward()

这样就能定位因为梯度爆炸之类的原因产生的nan了。

开启 nan 检查后,直接定位在了数据输入模型这步上,也就是说数据集有空值!

数据集有nan值的处理

一般来说,用以下语句定位数据集确定的有nan的位置

assert not torch.any(torch.isnan(T))

然后,要么变换,要么删掉就好了!

其他可能出现nan的情况

调研的时候看了不少文章,这里汇总一下

1、pytorch混合精度训练,使用半精度等能提升batch,但是有出现nan的风险

pytorch混合精度训练出现nan问题解决 - 知乎

2、升级你的numpy…

pytorch中第一轮训练loss就是nan是为什么啊? - 知乎

反正看到过很多奇怪的情况,结果都是升级numpy解决的…

[文章导入自 http://qzq-go.notion.site/12949a7b4e7580fc88fbe94d75bcf251 访问原文获取高清图片]

本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可
标签: IT技术 Pytorch和它学不完的AI =.=
最后更新:2025年4月3日

QZQ

一只涉猎广泛的技术爱好者,绝赞养猫中~

点赞
< 上一篇
下一篇 >

归档

  • 2025 年 4 月
  • 2025 年 3 月
  • 2025 年 2 月
  • 2025 年 1 月
  • 2024 年 12 月
  • 2024 年 11 月

分类

  • 技术
  • 未分类

COPYRIGHT © 2024 QZQ的小世界!. ALL RIGHTS RESERVED.

Theme Kratos Made By Seaton Jiang