Python和PyTorch库实现基于生成对抗网络(GAN)将小纹理合成大纹理的详细步骤及代码示例

以下是使用Python和PyTorch库实现基于生成对抗网络(GAN)将小纹理合成大纹理的详细步骤及代码示例。

思路概述

我们将使用生成对抗网络(GAN)来完成小纹理到大纹理的合成任务。GAN由生成器(Generator)和判别器(Discriminator)组成。生成器的目标是生成逼真的大纹理图像,而判别器的任务是区分生成的图像和真实的大纹理图像。通过两者的对抗训练,最终生成器能够学习到如何合成高质量的大纹理图像。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os# 定义数据集类
class TextureDataset(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.transform = transformself.image_files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]def __len__(self):return len(self.image_files)def __getitem__(self, idx):image_path = self.image_files[idx]image = Image.open(image_path).convert('RGB')if self.transform:image = self.transform(image)return image# 定义生成器
class Generator(nn.Module):def __init__(self, z_dim, img_channels):super(Generator, self).__init__()self.gen = nn.Sequential(self._block(z_dim, 1024, 4, 1, 0),self._block(1024, 512, 4, 2, 1),self._block(512, 256, 4, 2, 1),self._block(256, 128, 4, 2, 1),nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),nn.Tanh())def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(True))def forward(self, x):return self.gen(x)# 定义判别器
class Discriminator(nn.Module):def __init__(self, img_channels):super(Discriminator, self).__init__()self.disc = nn.Sequential(nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),self._block(64, 128, 4, 2, 1),self._block(128, 256, 4, 2, 1),self._block(256, 512, 4, 2, 1),nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),nn.Sigmoid())def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2))def forward(self, x):return self.disc(x)# 超参数设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 0.0002
batch_size = 32
image_size = 64
z_dim = 100
img_channels = 3
num_epochs = 50# 数据预处理
transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集
dataset = TextureDataset(root_dir='path/to/your/texture/images', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 初始化生成器和判别器
gen = Generator(z_dim, img_channels).to(device)
disc = Discriminator(img_channels).to(device)# 定义优化器和损失函数
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()# 训练循环
for epoch in range(num_epochs):for i, real_images in enumerate(dataloader):real_images = real_images.to(device)### 训练判别器opt_disc.zero_grad()noise = torch.randn(batch_size, z_dim, 1, 1).to(device)fake_images = gen(noise)disc_real = disc(real_images).reshape(-1)lossD_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake_images.detach()).reshape(-1)lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))lossD = (lossD_real + lossD_fake) / 2lossD.backward()opt_disc.step()### 训练生成器opt_gen.zero_grad()output = disc(fake_images).reshape(-1)lossG = criterion(output, torch.ones_like(output))lossG.backward()opt_gen.step()print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}")# 生成大纹理图像
num_samples = 1
noise = torch.randn(num_samples, z_dim, 1, 1).to(device)
generated_images = gen(noise)
generated_images = (generated_images + 1) / 2  # 反归一化
generated_images = generated_images.cpu().detach().permute(0, 2, 3, 1).numpy()# 显示生成的图像
plt.imshow(generated_images[0])
plt.axis('off')
plt.show()

代码说明

  1. 数据集类 TextureDataset:用于加载纹理图像数据集,并进行必要的预处理。
  2. 生成器 Generator:通过一系列反卷积层将随机噪声向量转换为大纹理图像。
  3. 判别器 Discriminator:使用卷积层来区分真实的大纹理图像和生成的图像。
  4. 训练循环:交替训练判别器和生成器,通过对抗训练不断提高生成器的性能。
  5. 生成大纹理图像:训练完成后,使用生成器生成大纹理图像并显示。

使用方法

  1. 将代码中的 'path/to/your/texture/images' 替换为你实际的小纹理图像文件夹路径。
  2. 确保你已经安装了PyTorch和相关的依赖库。
  3. 运行代码,等待训练完成,最后会显示生成的大纹理图像。

通过以上步骤,你就可以使用生成对抗网络将小纹理合成成大纹理图像。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/34462.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

正点原子[第三期]Arm(iMX6U)Linux移植学习笔记-4 uboot目录分析

前言: 本文是根据哔哩哔哩网站上“Arm(iMX6U)Linux系统移植和根文件系统构键篇”视频的学习笔记,在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。 引用: …

视频AI方案:数据+算力+算法,人工智能的三大基石

背景分析 随着信息技术的迅猛发展,人工智能(AI)已经逐渐渗透到我们生活的各个领域,从智能家居到自动驾驶,从医疗诊断到金融风控,AI的应用正在改变着我们的生活方式。而数据、算法和算力,正是构…

MySQL -- 表的约束

概念引入:真正的约束表字段的是数据类型,但是数据类型的约束方式比较单一的,所以需要一些额外的一些约束,用于表示数据的合法性,在只有数据类型一种约束的情况下,我们比较难保证数据是百分百合法。通过添加…

嵌入式Zephyr RTOS面试题及参考答案

目录 Zephyr RTOS 的主要设计目标是什么?适用于哪些领域? Zephyr 支持哪些内核对象类型?举例说明其应用场景。 Zephyr 支持哪些线程同步机制?举例说明其适用场景。 Zephyr 内核支持哪些任务状态?状态转换的条件是什么? Zephyr 如何实现低延迟中断处理?(如直接中断服…

《TCP/IP网络编程》学习笔记 | Chapter 18:多线程服务器端的实现

《TCP/IP网络编程》学习笔记 | Chapter 18:多线程服务器端的实现 《TCP/IP网络编程》学习笔记 | Chapter 18:多线程服务器端的实现线程的概念引入线程的背景线程与进程的区别 线程创建与运行pthread_createpthread_join可在临界区内调用的函数工作&#…

C++相关基础概念之入门讲解(上)

1. 命名空间 C中的命名空间(namespace)是用来避免命名冲突问题的一种机制。通过将类、函数、变量等封装在命名空间中,可以避免不同部分的代码中出现相同名称的冲突。在C中,可以使用namespace关键字来定义命名空间。 然后我们在调…

创新技术引领软件供应链安全,助力数字中国建设

编者按 随着数字化转型的加速,针对软件供应链的攻击事件呈快速增长态势,目前已成为网络空间安全的焦点。如何将安全嵌入到软件开发到运营的全流程,实现防护技术的自动化、一体化、智能化,成为技术领域追逐的热点。 悬镜安全作为…

PyTorch 系列教程:使用CNN实现图像分类

图像分类是计算机视觉领域的一项基本任务,也是深度学习技术的一个常见应用。近年来,卷积神经网络(cnn)和PyTorch库的结合由于其易用性和鲁棒性已经成为执行图像分类的流行选择。 理解卷积神经网络(cnn) 卷…

【2025】基于python+django的驾校招生培训管理系统(源码、万字文档、图文修改、调试答疑)

课题功能结构图如下: 驾校招生培训管理系统设计 一、课题背景 随着机动车保有量的不断增加,人们对驾驶技能的需求也日益增长。驾校作为驾驶培训的主要机构,面临着激烈的市场竞争和学员需求多样化等挑战。传统的驾校管理模式往往依赖于人工操作…

【JavaWeb】快速入门——HTMLCSS

文章目录 一、 HTML简介1、HTML概念2、HTML文件结构3、可视化网页结构 二、 HTML标签语法1、标题标签2、段落标签3、超链接4、换行5、无序列表6、路径7、图片8、块1 盒子模型2 布局标签 三、 使用HTML表格展示数据1、定义表格2、合并单元格横向合并纵向合并 四、 使用HTML表单收…

MySQL 优化方案

一、MySQL 查询过程 MySQL 查询过程是指从客户端发送 SQL 语句到 MySQL 服务器,再到服务器返回结果集的整个过程。这个过程涉及多个组件的协作,包括连接管理、查询解析、优化、执行和结果返回等。 1.1 查询过程的关键组件 连接管理器:管理…

服务性能防腐体系:基于自动化压测的熔断机制

01# 背景 在系统架构的演进过程中,项目初始阶段都会通过压力测试构建安全护城河,此时的服务性能与资源水位保持着黄金比例关系。然而在业务高速发展时期,每个冲刺周期都被切割成以业务需求为单位的开发单元,压力测试逐渐从必选项…

六十天前端强化训练之第二十天React Router 基础详解

欢迎来到编程星辰海的博客讲解 看完可以给一个免费的三连吗,谢谢大佬! 目录 一、核心概念 1.1 核心组件 1.2 路由模式对比 二、核心代码示例 2.1 基础路由配置 2.2 动态路由示例 2.3 嵌套路由实现 2.4 完整示例代码 三、关键功能实现效果 四、…

grad_traj_optimization 开源项目

开源项目 grad_traj_optimization 使用教程-CSDN博客 ubuntu如何切换到root用户_ubuntu切换到root用户-CSDN博客 catkin_make: command not found 解决办法_catkin-make not found-CSDN博客 这就说明需要编译的package虽然存在,但不在指定的目录下。catkin_make命…

深圳南柯电子|净水器EMC测试整改:水质安全与电磁兼容性的双赢

在当今注重健康生活的时代,净水器作为家庭用水安全的第一道防线,其性能与安全性备受关注。其中,电磁兼容性(EMC)测试是净水器产品上市前不可或缺的一环,它直接关系到产品在复杂电磁环境中的稳定运行及不对其…

要登录的设备ip未知时的处理方法

目录 1 应用场景... 1 2 解决方法:... 1 2.1 wireshark设置... 1 2.2 获取网口mac地址,wireshark抓包前预过滤掉自身mac地址的影响。... 2 2.3 pc网口和设备对接... 3 2.3.1 情况1:... 3 2.3.2 情…

GHCTF web方向题解

upload?SSTI! import os import refrom flask import Flask, request, jsonify,render_template_string,send_from_directory, abort,redirect from werkzeug.utils import secure_filename import os from werkzeug.utils import secure_filenameapp Flask(__name__)# 配置…

Vision Transformer (ViT):将Transformer带入计算机视觉的革命性尝试(代码实现)

Vision Transformer (ViT):将Transformer带入计算机视觉的革命性尝试 作为一名深度学习研究者,如果你对自然语言处理(NLP)领域的Transformer架构了如指掌,那么你一定不会对它在序列建模中的强大能力感到陌生。然而&am…

蓝耘携手通义万象 2.1 图生视频:开启创意无限的共享新时代

在科技飞速发展的今天,各种新奇的技术不断涌现,改变着我们的生活和工作方式。蓝耘和通义万象 2.1 图生视频就是其中两项非常厉害的技术。蓝耘就像是一个超级大管家,能把各种资源管理得井井有条;而通义万象 2.1 图生视频则像是一个…

IEC61850标准下MMS 缓存报告控制块 ResvTms详细解析

IEC61850标准是电力系统自动化领域唯一的全球通用标准。IEC61850通过标准的实现,使得智能变电站的工程实施变得规范、统一和透明,这大大提高了变电站自动化系统的技术水平和安全稳定运行水平。 在 IEC61850 标准体系中,ResvTms(r…