3D Gaussian Splatting代码详解(一):模型训练、数据加载

1 模型训练

def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):first_iter = 0# 初始化高斯模型,用于表示场景中的每个点的3D高斯分布gaussians = GaussianModel(dataset.sh_degree)# 初始化场景对象,加载数据集和对应的相机参数scene = Scene(dataset, gaussians)# 为高斯模型参数设置优化器和学习率调度器gaussians.training_setup(opt)# 如果提供了checkpoint,则从checkpoint加载模型参数并恢复训练进度if checkpoint:(model_params, first_iter) = torch.load(checkpoint)gaussians.restore(model_params, opt)# 设置背景颜色,白色或黑色取决于数据集要求bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")# 创建CUDA事件用于计时iter_start = torch.cuda.Event(enable_timing=True)iter_end = torch.cuda.Event(enable_timing=True)viewpoint_stack = Noneema_loss_for_log = 0.0# 使用tqdm库创建进度条,追踪训练进度progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")first_iter += 1for iteration in range(first_iter, opt.iterations + 1):# 记录迭代开始时间iter_start.record()# 根据当前迭代次数更新学习率gaussians.update_learning_rate(iteration)# 每1000次迭代,提升球谐函数的次数以改进模型复杂度if iteration % 1000 == 0:gaussians.oneupSHdegree()# 随机选择一个训练用的相机视角if not viewpoint_stack:viewpoint_stack = scene.getTrainCameras().copy()viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))# 如果达到调试起始点,启用调试模式if (iteration - 1) == debug_from:pipe.debug = True# 根据设置决定是否使用随机背景颜色bg = torch.rand((3), device="cuda") if opt.random_background else background# 渲染当前视角的图像render_pkg = render(viewpoint_cam, gaussians, pipe, bg)image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]# 计算渲染图像与真实图像之间的损失gt_image = viewpoint_cam.original_image.cuda()Ll1 = l1_loss(image, gt_image)loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))loss.backward()# 记录迭代结束时间iter_end.record()with torch.no_grad():# 更新进度条和损失显示ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_logif iteration % 10 == 0:progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})progress_bar.update(10)if iteration == opt.iterations:progress_bar.close()# 定期记录训练数据并保存模型training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))if iteration in saving_iterations:print("\n[ITER {}] Saving Gaussians".format(iteration))scene.save(iteration)# 在指定迭代区间内,对3D高斯模型进行增密和修剪if iteration < opt.densify_until_iter:gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:size_threshold = 20 if iteration > opt.opacity_reset_interval else Nonegaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):gaussians.reset_opacity()# 执行优化器的一步,并准备下一次迭代if iteration < opt.iterations:gaussians.optimizer.step()gaussians.optimizer.zero_grad(set_to_none=True)# 定期保存checkpointif iteration in checkpoint_iterations:print("\n[ITER {}] Saving Checkpoint".format(iteration))torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")

2 数据加载

class Scene:"""Scene 类用于管理场景的3D模型,包括相机参数、点云数据和高斯模型的初始化和加载"""def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):"""初始化场景对象:param args: 包含模型路径和源路径等模型参数:param gaussians: 高斯模型对象,用于场景点的3D表示:param load_iteration: 指定加载模型的迭代次数,如果为-1,则自动寻找最大迭代次数:param shuffle: 是否在训练前打乱相机列表:param resolution_scales: 分辨率比例列表,用于处理不同分辨率的相机"""self.model_path = args.model_path  # 模型文件保存路径self.loaded_iter = None  # 已加载的迭代次数self.gaussians = gaussians  # 高斯模型对象# 检查并加载已有的训练模型if load_iteration:if load_iteration == -1:self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))else:self.loaded_iter = load_iterationprint(f"Loading trained model at iteration {self.loaded_iter}")self.train_cameras = {}  # 用于训练的相机参数self.test_cameras = {}  # 用于测试的相机参数# 根据数据集类型(COLMAP或Blender)加载场景信息if os.path.exists(os.path.join(args.source_path, "sparse")):scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):print("Found transforms_train.json file, assuming Blender data set!")scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)else:assert False, "Could not recognize scene type!"# 如果是初次训练,初始化3D高斯模型;否则,加载已有模型if self.loaded_iter:self.gaussians.load_ply(os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply"))else:self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)# 根据resolution_scales加载不同分辨率的训练和测试位姿for resolution_scale in resolution_scales:print("Loading Training Cameras")self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)print("Loading Test Cameras")self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)def save(self, iteration):"""保存当前迭代下的3D高斯模型点云。:param iteration: 当前的迭代次数。"""point_cloud_path = os.path.join(self.model_path, f"point_cloud/iteration_{iteration}")self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))def getTrainCameras(self, scale=1.0):"""获取指定分辨率比例的训练相机列表:param scale: 分辨率比例:return: 指定分辨率比例的训练相机列表"""return self.train_cameras[scale]sceneLoadTypeCallbacks = {"Colmap": readColmapSceneInfo,"Blender" : readNerfSyntheticInfo
}
def readColmapSceneInfo(path, images, eval, llffhold=8):# 尝试读取COLMAP处理结果中的二进制相机外参和内参文件try:cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)except:# 如果二进制文件读取失败,尝试读取文本格式的相机外参和内参文件cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)# 定义存放图片的目录,如果未指定则默认为"images"reading_dir = "images" if images is None else images# 读取并处理相机参数,转换为内部使用的格式cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))# 根据图片名称对相机信息进行排序,以保证顺序一致性cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)# 根据是否为评估模式(eval),将相机分为训练集和测试集# 如果为评估模式,根据llffhold参数(通常用于LLFF数据集)间隔选择测试相机if eval:train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]else:# 如果不是评估模式,所有相机均为训练相机,测试相机列表为空train_cam_infos = cam_infostest_cam_infos = []# 计算场景归一化参数,这是为了处理不同尺寸和位置的场景,使模型训练更稳定nerf_normalization = getNerfppNorm(train_cam_infos)# 尝试读取点云数据,优先从PLY文件读取,如果不存在,则尝试从BIN或TXT文件转换并保存为PLY格式ply_path = os.path.join(path, "sparse/0/points3D.ply")bin_path = os.path.join(path, "sparse/0/points3D.bin")txt_path = os.path.join(path, "sparse/0/points3D.txt")if not os.path.exists(ply_path):print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")try:xyz, rgb, _ = read_points3D_binary(bin_path)except:xyz, rgb, _ = read_points3D_text(txt_path)storePly(ply_path, xyz, rgb)try:pcd = fetchPly(ply_path)except:pcd = None# 组装场景信息,包括点云、训练用相机、测试用相机、场景归一化参数和点云文件路径scene_info = SceneInfo(point_cloud=pcd,train_cameras=train_cam_infos,test_cameras=test_cam_infos,nerf_normalization=nerf_normalization,ply_path=ply_path)return scene_infodef readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):cam_infos = []  # 初始化用于存储相机信息的列表# 遍历所有相机的外参for idx, key in enumerate(cam_extrinsics):# 动态显示读取相机信息的进度sys.stdout.write('\r')sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))sys.stdout.flush()# 获取当前相机的外参和内参extr = cam_extrinsics[key]  # 当前相机的外参intr = cam_intrinsics[extr.camera_id]  # 根据外参中的camera_id找到对应的内参height = intr.height  # 相机图片的高度width = intr.width  # 相机图片的宽度uid = intr.id  # 相机的唯一标识符# 将四元数表示的旋转转换为旋转矩阵RR = np.transpose(qvec2rotmat(extr.qvec))# 外参中的平移向量T = np.array(extr.tvec)# 根据相机内参模型计算视场角(FoV)if intr.model == "SIMPLE_PINHOLE":# 如果是简单针孔模型,只有一个焦距参数focal_length_x = intr.params[0]FovY = focal2fov(focal_length_x, height)  # 计算垂直方向的视场角FovX = focal2fov(focal_length_x, width)  # 计算水平方向的视场角elif intr.model == "PINHOLE":# 如果是针孔模型,有两个焦距参数focal_length_x = intr.params[0]focal_length_y = intr.params[1]FovY = focal2fov(focal_length_y, height)  # 使用y方向的焦距计算垂直视场角FovX = focal2fov(focal_length_x, width)  # 使用x方向的焦距计算水平视场角else:# 如果不是以上两种模型,抛出错误assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"# 构建图片的完整路径image_path = os.path.join(images_folder, os.path.basename(extr.name))image_name = os.path.basename(image_path).split(".")[0]  # 提取图片名称,不包含扩展名# 使用PIL库打开图片文件image = Image.open(image_path)# 创建并存储相机信息cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,image_path=image_path, image_name=image_name, width=width, height=height)cam_infos.append(cam_info)# 在读取完所有相机信息后换行sys.stdout.write('\n')# 返回整理好的相机信息列表return cam_infos

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/459970.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[MySQL#6] 表的CRUD (1) | Create | Retrieve(查) | where

目录 1. 插入 1.1 单行数据 - 全列插入 指定列插入 1.2 多行数据 - 全列插入 指定列插入 1.3 更新 1.4 替换 2. 查找 2.1 select 列 2.2 where 条件 具体案例 2.3 结果排序 总结关键字执行顺序 2.4 筛选分页结果 CRUD : Create(创建),Retrieve(读取)&…

[机器学习]集成学习

1 集成学习 强强联合、弱弱变强Bagging(平权投票):随机森林Boosting(加权投票):Adaboost、GBDT、XGBoost、LightGBM 2 随机森林 3 Adaboost 放大错误数据,缩小正确数据

第三十三篇:TCP协议如何避免/减少网络拥塞,TCP系列八

一、流量控制 一般来说,我们总是希望数据传输得更快一些,但是如果发送方把数据发送得太快,接收方可能来不及接收,造成数据的丢失,数据重发,造成网络资源的浪费甚至网络拥塞。所谓的流量控制(fl…

在Excel中如何快速筛选非特定颜色

Excel中的自动筛选是个非常强大的工具,不仅可以筛选内容,而且可以筛选颜色,例如筛选A列红色单元格。但是有时希望筛选除了红色之外的单元格(下图右侧所示),其他单元格的填充色不固定,有几种颜色…

数据结构---链表(一)【不带头单向非循环】

文章目录 链表概念链表的使用LinkedList 的几种遍历方式单链表的模拟实现(不带头)链表面试题 观察ArrayList 顺序表的源码发现,底层是使用数组实现的。由于其底层是一段连续空间,当在ArrayList任意位置插入或者删除元素时&#xf…

Pytorch(一)

一.PyTorch环境配置及安装 1.1 工具安装 1.1.1 Anaconda下载 清华大学镜像站下载,版本为Anaconda3-5.2.0-Windows-x86_64(对应python3.6.5) Index of /anaconda/archive/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror 1.1.2…

关于我的数据结构与算法——初阶第二篇(排序)

(叠甲:如有侵权请联系,内容都是自己学习的总结,一定不全面,仅当互相交流(轻点骂)我也只是站在巨人肩膀上的一个小卡拉米,已老实,求放过)。 排序的概念及其运…

AI驱动的低代码未来:加速应用开发的智能解决方案

引言 随着数字化转型的浪潮席卷全球,企业对快速构建应用程序的需求愈发强烈。然而,传统的软件开发周期冗长、成本高昂,往往无法满足快速变化的市场需求。在此背景下,低代码平台逐渐成为开发者和企业的优选方案,以其“低…

三周精通FastAPI:21 子依赖项和路径操作装饰器依赖项

官方文档:https://fastapi.tiangolo.com/zh/tutorial/dependencies/sub-dependencies/#_6 子依赖项 FastAPI 支持创建含子依赖项的依赖项。 并且,可以按需声明任意深度的子依赖项嵌套层级。 FastAPI 负责处理解析不同深度的子依赖项。 第一层依赖项 …

模具生产管理系统软件:提升制造业效率的新利器

引言 我们都知道,企业面临着提高生产效率、降低成本和提升产品质量的压力。模具生产作为制造过程中至关重要的一环,如何有效管理和优化模具生产过程,成为企业关注的重点。模具生产管理系统应运而生,能够为企业提供实时监控、流程…

MySQL中,如何定位慢查询?定位到的慢SQL如何分析?

目录 1. 慢查询发生的场景? 2. MySQL中,如何定位慢查询? 2.1 详细解释 3. 定位到的慢SQL如何分析? 3.1 详细说明 1. 慢查询发生的场景? 2. MySQL中,如何定位慢查询? 介绍一下当时产生问题…

「C/C++」C++ 设计模式 之 单例模式(Singleton)

✨博客主页何曾参静谧的博客📌文章专栏「C/C」C/C程序设计📚全部专栏「VS」Visual Studio「C/C」C/C程序设计「UG/NX」BlockUI集合「Win」Windows程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「PK」Parasoli…

华为云开源项目Sermant正式成为CNCF官方项目

近日,云原生计算基金会(CNCF)正式接纳由华为云发起的云原生无代理服务网格项目Sermant。Sermant的加入,极大地丰富了云原生微服务治理技术的探索、创新和发展,为CNCF社区注入了新的活力。 Sermant是华为云在微服务治理…

用sdcc给51单片机编译C程序

学习单片机大部分人用的是Keil uVision,虽然好用,可大部分人用的是盗版,其实单片机程序小的话,完全可以用文本编辑器(推荐notepad)编写,然后用免费的sdcc来编译,下面介绍一下大致的过程。 sdcc…

Ajax:表单 模板引擎

Ajax&#xff1a;表单 & 模板引擎 form 表单form 属性 Ajax操控表单事件监听阻止默认行为收集表单数据 模板引擎art-template{{}}语法原文输出条件输出循环输出过滤器 原理 form 表单 在HTML中&#xff0c;可以通过<form>创建一个表单&#xff0c;收集用户信息。而采…

【水下生物数据集】 水下生物识别 深度学习 目标检测 机器视觉 yolo(含数据集)

一、背景意义 随着全球海洋生态环境的日益变化&#xff0c;水下生物的监测和保护变得愈发重要。水下生物种类繁多&#xff0c;包括螃蟹、鱼类、水母、虾、小鱼和海星等&#xff0c;它们在海洋生态系统中扮演着关键角色。传统的水下生物监测方法通常依赖于人工观察&#xff0c;效…

[vulnhub]Kioptrix: Level 1.2 (#3)

https://www.vulnhub.com/entry/kioptrix-level-12-3,24/ 主机发现端口扫描 使用nmap扫描网段类存活主机 因为靶机是我最后添加的&#xff0c;所以靶机IP是169 nmap -sP 192.168.75.0/24 Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-10-29 13:16 CST …

TVM前端研究--Relay

文章目录 深度学习IR梳理1. IR属性2. DL前端发展3. DL编译器4. DL编程语言Relay的主要内容一、Expression in Relay1. Dataflow and Control Fragments2. 变量3. 函数3.1 闭包3.2 多态和类型关系3.3. Call4. 算子5. ADT Constructors6. Moudle和Global Function7. 常量和元组8.…

Ubuntu UFW防火墙规则与命令示例大全

在服务器安全领域&#xff0c;防火墙是守护网络安全的坚实盾牌。UFW&#xff08;Uncomplicated Firewall&#xff09;&#xff0c;即“不复杂的防火墙”&#xff0c;是一个运行在iptables之上的防火墙配置工具&#xff0c;它为Ubuntu系统默认提供了一个简洁的命令行界面&#x…

Linux高阶——1026—验证内存映射mmap函数使用

1、验证共享映射后修改文件内容&#xff0c;是否能够同步 先创建一个映射文件&#xff0c;写入数据 分为四个步骤 1、打开映射文件 设文件描述符&#xff0c;使用open函数 int fd; if((fdopen("mapfile",O_RDWR))-1) { perror("open failed");exit…