李宏毅HW03(CNN)记录
基于PyTorch的食物分类项目:从Baseline到FixMatch与TTA的实践之路
引言
这是让AI帮我整理的李宏毅hw03CNN,记录了相对baseline的改进,并让AI帮我把混乱的代码稍微做了重构。这次没有做模型集成,没怎么调参,最终在kaggle排行榜10%左右,过了private strong baseline。
在深度学习领域,图像分类是一个基础且重要的任务。本次项目旨在对 food-11 数据集进行分类,但更重要的目标是探索如何从一个简单的基线模型(Baseline)出发,通过应用一系列先进的机器学习技术和工程实践,系统性地提升模型的性能和训练效率。
本文将分为两个部分:
- 第一部分(核心):详细记录了项目代码(
hw03.py)相比于原始基线(HW03.ipynb)在算法和训练策略上的关键改进。这部分将深入探讨每一项技术的原理和实践价值,是我工作的核心成果。 - 第二部分(次要):介绍在此基础上进行的工程化重构,旨在将实验性脚本转化为一个灵活、可复用的命令行工具。
Part 1: 我的模型改进与实践 (核心)
相比于 HW03.ipynb 中提供的简单CNN和基础训练流程,我的 hw03.py 脚本集成了一系列现代深度学习技术,实现了从监督学习到半监督学习,从基础训练到高级混合策略训练的跨越。
1.1 模型架构:从零开始的ResNet-18
- 基线 (
HW03.ipynb): 提供了一个简单的自定义CNN,并建议可以使用resnet18(pretrained=False)。 - 我的实现 (
hw03.py):- 果断采用了ResNet-18 (
torchvision.models.resnet18(weights=None)) 作为骨干网络。相比基线的浅层CNN,ResNet的残差连接结构能有效解决深度网络中的梯度消失问题,允许网络构建得更深,从而学习到更丰富的特征。 - 严格遵守了作业要求,通过设置
weights=None(等同于pretrained=False) 确保模型是从零开始训练,保证了实验的公平性。 - 精确地替换了模型的最后一层全连接层 (
model.fc = nn.Linear(model.fc.in_features, 11)),以匹配food-11数据集的11个类别。
- 果断采用了ResNet-18 (
1.2 高级数据增强策略 (Advanced Data Augmentation)
数据增强是防止过拟合、提升模型泛化能力的关键。我采用了远比基线复杂的增强策略。
- 基线 (
HW03.ipynb): 仅使用了基础的Resize,RandomHorizontalFlip和ColorJitter。这些操作虽然有效,但组合相对简单。 - 我的实现 (
hw03.py):RandomResizedCrop(224): 代替了简单的Resize(128)。它在训练时随机裁剪图像的不同部分并缩放到224x224,这让模型对物体的位置和大小变化更具鲁棒性,是ImageNet预训练的标准做法。RandAugment: 这是我引入的一项核心增强技术。它是一种自动化的数据增强策略,能从一系列(如旋转、色彩、对比度等)变换中自动学习并选择最佳的组合和强度。相比基线手动、固定的ColorJitter,RandAugment提供了更丰富、更强大的增强效果,极大地提升了模型的泛化能力。RandomErasing: 训练中随机擦除图像的一块矩形区域。这个过程模拟了真实世界中的物体遮挡情况,强迫模型去关注物体的全局特征而非仅仅是局部细节,是提升模型鲁棒性的又一利器。
1.3 半监督学习:FixMatch的应用
这是本次实践中最为重要的升级之一,旨在利用项目中提供的大量未标注数据。
- 基线 (
HW03.ipynb): 提供了一个get_pseudo_labels函数的框架,该方法通过简单的置信度阈值来生成伪标签,是一种较为基础的半监督方法(Self-Training),且实现较为复杂,容易出错。 - 我的实现 (
hw03.py):- 完整地实现并应用了 FixMatch 算法,这是一种业界领先的半监督学习框架。
- 核心思想: FixMatch的核心是一致性正则化。它假设模型对于同一张图片的不同增强版本,应该给出一致的预测。
- 实现流程:
- 弱增强与强增强: 对每一张未标注图片,同时生成一个“弱增强”版本(仅翻转和裁剪)和一个“强增强”版本(应用
RandAugment)。 - 生成伪标签: 将“弱增强”图片输入模型,得到预测。只有当模型对这个预测的置信度高于一个阈值(如0.95)时,我们才认为这个伪标签是可靠的。
- 计算一致性损失: 对于那些拥有可靠伪标签的图片,我们要求模型在看到其“强增强”版本时,其预测结果应与伪标签保持一致。这个差异就构成了无监督损失 (
loss_u)。 - 总损失: 最终的训练损失是有监督损失 (
loss_s,来自标注数据)和无监督损失 (loss_u)的加权和。
- 弱增强与强增强: 对每一张未标注图片,同时生成一个“弱增强”版本(仅翻转和裁剪)和一个“强增强”版本(应用
- 优势: 相比基线的简单伪标签方法,FixMatch通过“弱增强生成标签、强增强进行学习”的非对称设计,能生成更高质量的伪标签,并利用一致性正则化让模型在无标签数据上学到更鲁棒的特征表示。
1.4 训练过程优化 (Training Process Optimization)
为了最大化训练效果和效率,我引入了多种优化技术。
AdamW与CosineLRScheduler:- 优化器: 使用
AdamW代替了基线的Adam。AdamW通过解耦权重衰减和梯度更新,通常能获得比Adam更好的泛化性能。 - 学习率调度: 实现了带有预热(Warmup)的余弦退火学习率 (
CosineLRScheduler)。在训练初期使用一个较小的学习率(预热)让模型参数稳定下来,然后学习率按余弦曲线逐渐下降。这种策略在大型模型训练中被证明非常有效,能帮助模型跳出局部最优,达到更好的收敛点。
- 优化器: 使用
- 混合精度训练 (AMP - Automatic Mixed Precision):
- 通过
torch.cuda.amp.autocast和GradScaler开启了自动混合精度训练。 - 原理: 在训练中,部分计算(如卷积)使用16位浮点数(FP16),而另一部分(如损失计算)保持32位浮点数(FP32)。
- 优势: 显著减少GPU显存占用并加快训练速度,同时通过
GradScaler动态缩放损失,避免了FP16下数值下溢的问题,保证了训练的稳定性。
- 通过
Mixup与CutMix:- 通过
timm库引入了这两种强大的正则化技术。Mixup将两张图片按比例混合,标签也做相应混合;CutMix则是将一张图的一部分剪切并粘贴到另一张图上。 - 效果: 它们通过在训练样本之间创造虚拟的训练数据,极大地丰富了数据分布,有效防止了模型过拟合,并提升了模型的泛化能力和鲁棒性。
- 通过
- 梯度裁剪 (Gradient Clipping):
- 在优化器更新前,使用
torch.nn.utils.clip_grad_norm_对梯度进行裁剪。这可以防止在训练中可能出现的梯度爆炸问题,使训练过程更加稳定。
- 在优化器更新前,使用
1.5 提升预测准确率:测试时增强 (TTA)
- 基线 (
HW03.ipynb): 在测试时,仅对每张图片进行一次中心裁剪预测。 - 我的实现 (
hw03.py):- 引入了**测试时增强 (Test-Time Augmentation, TTA)**。在预测阶段,对每一张测试图片生成多个增强版本(例如,原始图、水平翻转图、不同角度旋转图等),分别进行预测,最后将所有预测结果进行平均(或投票)。
- 效果: TTA能有效减少模型对特定视角或裁剪方式的依赖,通过集成多个视角的预测结果,通常能带来显著且“免费”的准确率提升。
1.6 实验监控与管理
TensorBoard: 使用SummaryWriter将训练和验证过程中的关键指标(如损失loss、准确率acc、学习率lr)实时记录到日志中。通过TensorBoard可视化这些指标,可以直观地监控训练状态、诊断问题、比较不同实验的效果。EarlyStopping(早停机制): 我编写并集成了一个EarlyStopping类。它会监控验证集上的损失(或准确率),如果在连续多个epoch内没有改善,就会自动停止训练。这不仅可以防止模型过拟合,还能节省大量的训练时间,并自动保存性能最佳的模型。
Part 2: 代码重构:从实验脚本到命令行工具
在您完成上述所有算法和策略的探索后,代码已经变得非常强大,但同时也成了一个冗长的脚本。为了便于复用、分享和进行更多的自动化实验,我对其进行了工程化重构。
重构目标
将一个从头运行到尾的脚本,改造成一个可以通过命令行控制其行为的、模块化的工具。
主要改进
参数化 (
argparse):问题: 原始脚本中的所有超参数(如学习率、批次大小、epoch数、文件路径)都以硬编码的形式散落在代码中。每次调整都需要手动修改代码,非常低效且容易出错。
解决方案: 引入Python内置的
argparse库,将所有可变参数都定义为命令行参数。示例: 现在,你可以像下面这样启动训练,轻松改变配置:
1
2
3
4
5# 使用不同的学习率和批次大小进行训练
python hw03_refactored.py --lr 0.001 --batch_size 32
# 关闭半监督学习,只训练100个epoch
python hw03_refactored.py --use_semi_supervised False --epochs 100
模块化(函数封装):
- 问题: 原始脚本的所有逻辑(数据加载、模型定义、训练、验证、测试)都混合在主执行流程中,代码可读性和可维护性差。
- 解决方案: 将代码按功能拆分为独立的函数,例如:
get_data_loaders(): 负责所有数据加载器的创建。build_model(): 负责模型的构建。train_one_epoch(): 包含一个epoch的完整训练逻辑。validate(): 包含验证逻辑。main(): 作为程序的入口,负责解析命令行参数,并按顺序调用其他函数。
- 优势: 这种结构使得代码逻辑清晰,每个函数职责单一,极大地提高了代码的可读性、可维护性和可测试性。
结论
从一个基础的CNN模型,到最终集成了FixMatch、RandAugment、Mixup、AMP、TTA等一系列先进技术的复杂训练流程,这次项目实践完整地展示了如何在深度学习任务中系统性地提升模型性能。第一部分详述的各项技术是性能飞跃的核心,而第二部分的工程化重构则为未来的进一步实验和部署打下了坚实的基础。