VGG应用:猫狗大战——基于VGG16的猫狗数据分类

一、数据集的处理与加载

class CatDogDataset(Dataset):def __init__(self, data_dir, mode="train", split_n=0.9, rng_seed=620, transform=None):self.mode = modeself.data_dir = data_dirself.rng_seed = rng_seedself.split_n = split_nself.data_info = self._get_img_info()  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本self.transform = transformdef __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB')     # 0~255if self.transform is not None:img = self.transform(img)   # 在这里做transform,转为tensor等等return img, labeldef __len__(self):if len(self.data_info) == 0:raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.data_dir))return len(self.data_info)def _get_img_info(self):img_names = os.listdir(self.data_dir)img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))random.seed(self.rng_seed)random.shuffle(img_names)img_labels = [0 if n.startswith('cat') else 1 for n in img_names]split_idx = int(len(img_labels) * self.split_n) # 25000* 0.9 = 22500# split_idx = int(100 * self.split_n)if self.mode == "train":img_set = img_names[:split_idx]     # img_set = img_names[:22500]     label_set = img_labels[:split_idx]elif self.mode == "valid":img_set = img_names[split_idx:]label_set = img_labels[split_idx:]else:raise Exception("self.mode 无法识别,仅支持(train, valid)")path_img_set = [os.path.join(self.data_dir, n) for n in img_set]data_info = [(n, l) for n, l in zip(path_img_set, label_set)]return data_info
    norm_mean = [0.485, 0.456, 0.406]norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((256)),transforms.CenterCrop(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),])normalizes = transforms.Normalize(norm_mean, norm_std)valid_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.TenCrop(224, vertical_flip=False),transforms.Lambda(lambda crops: torch.stack([normalizes(transforms.ToTensor()(crop)) for crop in crops])),])# 构建MyDataset实例train_data = CatDogDataset(data_dir=data_dir, mode="train", transform=train_transform)valid_data = CatDogDataset(data_dir=data_dir, mode="valid", transform=valid_transform)# 构建DataLodertrain_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)valid_loader = DataLoader(dataset=valid_data, batch_size=4)

二、调用模型,定义损失函数以及优化器

 vgg16_model = get_vgg16(path_state_dict, device, False)num_ftrs = vgg16_model.classifier._modules["6"].in_featuresvgg16_model.classifier._modules["6"] = nn.Linear(num_ftrs, num_classes)vgg16_model.to(device)criterion = nn.CrossEntropyLoss()flag = 0# flag = 1if flag:fc_params_id = list(map(id, vgg16_model.classifier.parameters()))  # 返回的是parameters的 内存地址base_params = filter(lambda p: id(p) not in fc_params_id, vgg16_model.parameters())optimizer = optim.SGD([{'params': base_params, 'lr': LR * 0.1},  # 0{'params': vgg16_model.classifier.parameters(), 'lr': LR}], momentum=0.9)else:optimizer = optim.SGD(vgg16_model.parameters(), lr=LR, momentum=0.9)  # 选择优化器scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1) 

三、训练过程

 train_curve = list()valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.vgg16_model.train()for i, data in enumerate(train_loader):# forwardinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = vgg16_model(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().cpu().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%} lr:{}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total, scheduler.get_last_lr()))loss_mean = 0.scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.vgg16_model.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)bs, ncrops, c, h, w = inputs.size()outputs = vgg16_model(inputs.view(-1, c, h, w))outputs_avg = outputs.view(bs, ncrops, -1).mean(1)loss = criterion(outputs_avg, labels)_, predicted = torch.max(outputs_avg.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().cpu().sum().numpy()loss_val += loss.item()loss_val_mean = loss_val/len(valid_loader)valid_curve.append(loss_val_mean)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))vgg16_model.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/33378.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows 7 专业版如何安装英文、中文语言包

下载相应的语言包,以管理员身份运行下载的exe文件,会在该exe文件的旁边生成一个lp.cab文件,赶快将此lp.cab复制一个副本文把lp.cab放到C:根目录接下来的步骤: 1. 在所有程序附件中,以管理员运行命令行:在命令窗口打入如…

chat中文国内版软件开发

如果要开发中文国内版的Chat软件,可能会包括以下一些功能: 中文自然语言处理:对于中文文本,需要进行中文自然语言处理,包括分词、词性标注、命名实体识别、情感分析等。 智能问答和对话系统:开发智…

Win32:C++其实早已支持中文编程

我们以前学习C/C的时候,对于变量和标识符的命名都有如下规则: 变量名必须由字母、数字、下划线构成只能以字母、下划线开头 似乎对中文不太友善啊,于是后来出现了一些中文编程的呼声,甚至还真的出现了一些中文编程语言。 其实在…

硅基MEMS制造技术分析

MEMS(微电子机械系统)技术是一种使产品集成化、微型化、智能化的微型机电系统。在半导体集成电路技术之上发展起来的硅基MEMS制造技术目前使用十分广泛。   国外技术发展日趋成熟 上世纪80年代,在美国政府的高度重视下MEMS技术研发开始起步。1992年“美国国家关…

硅基罗丹明铜离子荧光探针/烷氧基羰基取代硅基罗丹明衍生物

硅基罗丹明类荧光染料分子由于其良好的光谱学和化学性质,被应用于荧光探针的合成中。通过将罗丹明分子中的氧原子用硅原子进行取代,使其广谱范围红移,可以满足近红外荧光检测需求,同时保留了罗丹明染料诸多性质,如荧光…

硅基罗丹明近红外荧光染料/硅基罗丹明近红外发射双光子

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文 前言 荧光探针分子通常是由荧光团(Fluorophore),识别基团(Receptor)和起传递作用的链接基团(Spacer)组成,荧光团是将识别基团选择性的与被分析物结合或…

硅基光电子集成

摩尔定律:摩尔定律是由英特尔(Intel)创始人之一戈登摩尔(Gordon Moore)提出来的。其内容为:当价格不变时,集成电路上可容纳的元器件的数目,约每隔18-24个月便会增加一倍,…

硅基芯片与光纤耦合及封装

前言 随着硅基光子集成的设计与工艺条件逐步完善,芯片上各类有源、无器以 及它们组合而成的光模块,目前已经能够很好地实现小尺寸下信号处理。对于一 个完整的光通信链路,常由“发射端 ——传输介质 ——接收端”三部分构成,而集…

[项目管理-2]:软硬件项目管理 - 干系人管理、实践活动、常见工具

目录 第2章 干系人管理(谁来做?) 2.1 概述 2.2 常见的干系人 2.3 干系人的信息内容 2.4 干系人登记册存放 2.5 干系人表格的作用(登记册) 2.6 干系人管理策略 2.7 干系人管理与其他模块的关系 第2章 干系人管理…

虚拟人都能导购了,还要实体导购干什么?

作者 | 曾响铃 文 | 响铃说 请珍惜还在和你内卷的同事吧,毕竟他还是个“真人”。 2021年年底,当某地产商“总部优秀新人奖”被颁给首位虚拟员工“崔筱盼”时,还在“反内卷”的打工人瞬间“悟了”——原来,“我的同事不是人”这…

[管理与领导-7]:新任管理第1课:管理转身--从技术业务走向管理 - 管理常识1

目录 第1章 管理基本概念 1.1 什么是管理? 1.2 管理的要素与职能 第2章 管理是什么? 2.1 以终为始 2.2 资源的优化配置 2.3 分而治之:分目标、分任务、分权力、分利益 2.4 目标明确 2.5 优先级 2.6 知人善用,人尽其才 …

硅基生命

硅基生命 编辑 讨论2 硅基生命是相对于碳基生命而言的。所谓碳基生命,根源于有机物的原始概念:只能由生物产生的物质(有机物现在指的是除了碳氧化物,碳硫化物,碳酸盐,氰化物,碳化物&#xff0c…

硅基压力传感器—MEMS

背景介绍 压力传感器作为触觉传感的核心部件,要求对外界机械力进行精确、稳定的探测和反馈,是人机交互系统发展的关键,在工业机器人、电子皮肤等领域具有广泛的应用。调研可知,压力测量的方法主要有:1. 电阻应变式压力…

硅基生命之漫谈-1:天马行空

1. 身(生理)》硬件 1.1 分解与组合 原子-》分子-》有机分子-》基因-》器官-》组织-》人体 1.2 五官 眼》摄像头 耳》拾音器 鼻》各种气体床传感器 口》发声器 舌》味道传感器 1.3 人体八大系统 运动系统(手,足,…

[人工智能-综述-3]:人工智能与硅基生命,人类终将成为造物主

作者主页(文火冰糖的硅基工坊):https://blog.csdn.net/HiWangWenBing 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/119061112 目录 引言 第1 部分 人工智能的过去 - 人类的智能 1.1 宇宙的诞生与生命的出现 1.2 人类的出现与人…

npm包管理,这一篇就够了

文章目录 人工智能福利文章npm是什么?npm install 安装模块npm uninstall 卸载模块npm update 更新模块npm outdated 检查模块是否已经过时npm ls 查看安装的模块npm init 在项目中引导创建一个package.json文件npm help 查看某条命令的详细帮助npm root 查看包的安…

飞书即时消息无需代码连接PaLM Google AI的方法

飞书即时消息用户使用场景: 许多企业使用飞书系统办公,现在有了PaLM Google AI技术,能够根据用户的提问来自动产生回答,而且不需要人为干预。企业负责人常常在想,如果可以将PaLM Google AI技术融入到飞书机器人中打造一…

飞书即时消息无需代码连接文本翻译(免费版)的方法

飞书即时消息用户使用场景: 在一个跨国企业中,飞书即时消息应用机器人被用于不同部门的沟通协作。当一个部门的成员收到来自其他部门或者国外的消息时,需要企业人员同步到翻译软件上进行翻译。但这个过程会存在一定的问题: 翻译一…

AI图像生成无需代码连接集简云数据表的方法

1 场景描述 人工智能的出现,各个领域都开始尝试将AI作为提高工作效率的必备工具。除了AI对话等,越来越多的AI图像生成工具也出现在市场上。这些AI图像生成工具可以自动创建惊人的图像、艺术作品和设计,从而帮助设计师和创意人员更快速地实现其…

OpenAI DALL·E无需代码连接集简云小程序的方法

使用场景 随着chatgpt的大火,带来了一波人工客服智能机器人的热潮,除自动聊天外,又增加了ai生成图像的功能,也有越来越多企业关注到了AI绘画的热度,并选择在这一领域加速布局。 在设计领域中,它可以帮助设计…