CIFAR-10 完整训练
在 CIFAR-10 数据集上训练 NCT,掌握超参数调优
🎯
CIFAR-10 完整训练
在 CIFAR-10 数据集上训练 NCT,掌握超参数调优
⏱️ 45 分钟📊 中级🧪 图像分类
实验概述
CIFAR-10 是一个更具挑战性的图像分类数据集,包含 10 个类别的 60,000 张 32x32 彩色图像。 本实验将带你掌握数据增强、学习率调度、模型保存和可视化等高级技巧, 帮助你在真实场景中获得最佳性能。
运行实验
方法 1:直接运行脚本
# 克隆仓库
git clone https://github.com/wyg5208/nct.git
cd nct
# 安装依赖(需要额外安装 torchvision)
pip install torch torchvision numpy scipy matplotlib
# 运行 CIFAR-10 完整训练
python experiments/run_cifar10_full.py方法 2:自定义配置训练
from nct_modules import NCTManager, NCTConfig
from datasets import load_cifar10
# 加载数据(带数据增强)
train_data, test_data = load_cifar10(
'./data/CIFAR-10',
augment=True, # 启用数据增强
normalize=True
)
# 创建标准配置(d_model=768, n_heads=8)
config = NCTConfig(
n_heads=8,
d_model=768,
n_layers=6,
dropout=0.1
)
# 初始化并训练(带学习率调度)
manager = NCTManager(config, lr=1e-4)
manager.train(
train_data,
epochs=50,
batch_size=64,
scheduler='cosine', # 余弦退火
save_best=True, # 保存最佳模型
visualize=True # 实时可视化
)关键技术点
🔄 数据增强
- • 随机水平翻转
- • 随机裁剪(padding=4)
- • 颜色抖动
- • CutMix/Mixup(可选)
📈 学习率调度
- • Warmup(前 5 epoch)
- • 余弦退火(5-50 epoch)
- • 学习率范围:1e-4 ~ 1e-6
- • 梯度裁剪(max_norm=1.0)
💾 模型保存
- • 每个 epoch 保存 checkpoint
- • 只保留 top-3 最佳模型
- • 自动删除旧文件节省空间
- • 支持断点续训
📊 可视化监控
- • 实时训练/验证曲线
- • 注意力权重热图
- • 混淆矩阵
- • 错误案例分析
预期结果
~92.3%
测试集准确率
~45 分钟
训练时间(GPU)
Top-1
CIFAR-10 SOTA
🎯 学习目标
- ✅ 掌握 CIFAR-10 数据加载和增强
- ✅ 学会使用学习率调度器
- ✅ 能够保存和加载模型 checkpoint
- ✅ 能够分析和可视化训练过程
- ✅ 理解超参数对性能的影响
💡 优化建议
✅ 推荐配置(平衡性能和速度)
d_model=768, n_heads=8, n_layers=6, batch_size=64, lr=1e-4
🚀 高性能配置(追求最佳准确率)
d_model=1024, n_heads=12, n_layers=8, batch_size=32, lr=5e-5
⚡ 快速原型配置(调试用)
d_model=256, n_heads=4, n_layers=4, batch_size=128, lr=3e-4