文章目录
- 模型训练验证
- 损失函数和优化器
- 模型优化
- 训练函数
- 验证函数
- 模型保存
模型训练验证
损失函数和优化器
loss_function = nn.CrossEntropyLoss() # 损失函数
optimizer = Adam(model.parameters()) # 优化器,优化参数
模型优化
获得模型所有的可训练参数(比如每一层的权重、偏置),设置优化器类型,自动调整学习步长(自适应学习率),后续训练更新参数。
# 雇佣Adam教练,让他管理模型参数
optimizer = Adam(model.parameters(), lr=0.001) # lr是初始学习率
# 1. optimizer.zero_grad() # 清空上一轮的成绩单
# 2. loss.backward() # 计算每个参数要改进的方向(梯度)
# 3. optimizer.step() # 参数调整
训练函数
def train():loss = 0accuracy = 0model.train()for x, y in train_loader: # 获得每个batch数据x, y = x.to(device), y.to(device)output = model(x) # 得到预测labeloptimizer.zero_grad() # 梯度清零batch_loss = loss_function(output, y) # 计算batch误差batch_loss.backward() # 计算误差梯度optimizer.step() # 调整模型参数loss += batch_loss.item()accuracy += get_batch_accuracy(output, y, train_N)print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
验证函数
def validate():loss = 0accuracy = 0model.eval() # 评估模式,关闭随机性等增加稳定性with torch.no_grad(): # 禁用梯度,提高效率for x, y in valid_loader:x, y = x.to(device), y.to(device)output = model(x) # 不用进行梯度计算、参数调整loss += loss_function(output, y).item()accuracy += get_batch_accuracy(output, y, valid_N)print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
模型保存
.pth 文件是PyTorch模型的“存档文件”,保存了所有必要信息。加载后,模型即可直接运行,无需重新训练!
# 保存整个模型(结构 + 参数)
torch.save(model, 'model.pth')
.pth 文件可以用https://netron.app/查看