《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门!
解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界
目录
- 迁移学习概述
- 环境准备与数据预处理
- 使用Keras实现迁移学习
- 使用PyTorch实现迁移学习
- 模型评估与结果分析
- 迁移学习技巧与最佳实践
- 应用场景与总结
1. 迁移学习概述
迁移学习(Transfer Learning)是机器学习中的一种技术,通过将在一个任务上训练好的模型参数迁移到另一个相关任务中,从而加速模型训练过程并提升模型性能。在计算机视觉领域,常用的预训练模型(如VGG16、ResNet、Inception等)已经在ImageNet数据集上经过充分训练,可以直接用于特征提取或微调(Fine-tuning)。
迁移学习的优势:
- 节省训练时间:预训练模型已学习通用特征
- 降低数据需求:适合小样本场景
- 提升模型性能:利用已有知识提升新任务表现
典型应用场景:
- 医学影像分类
- 卫星图像识别
- 工业缺陷检测
- 自然场景物体识别
2. 环境准备与数据预处理
2.1 环境配置
# 安装必要库(Keras版本)
!pip install tensorflow keras numpy pandas matplotlib scikit-learn
2.2 数据准备
假设我们使用Kaggle的猫狗分类数据集(包含25000张训练图像)
import os
import numpy as np
from keras.preprocessing.image import ImageDataGenerator# 数据集路径配置
train_dir = '/path/to/train'
validation_dir = '/path/to/validation'
test_dir = '/path/to/test'# 图像预处理参数
img_width, img_height = 224, 224 # 匹配预训练模型输入尺寸
batch_size = 32
num_classes = 2 # 猫和狗分类# 数据增强配置
train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=20,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')validation_datagen = ImageDataGenerator(rescale=1./255)# 创建数据生成器
train_generator = train_datagen.flow_from_directory(train_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical')validation_generator = validation_datagen.flow_from_directory(validation_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical')
3. 使用Keras实现迁移学习
3.1 加载预训练模型
from keras.applications import VGG16
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D, Dropout# 加载VGG16模型(不包括顶层)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(img_width, img_height, 3))# 冻结卷积基
for layer in base_model.layers:layer.trainable = False# 添加自定义顶层
x = base_model.output
x = GlobalAveragePooling2D()(x) # 全局平均池化
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x) # 防止过拟合
predictions = Dense(num_classes, activation='softmax')(x)# 构建完整模型
model = Model(inputs=base_model.input, outputs=predictions)