tensorboard的基本使用及案例

TensorBoard 是一个可视化工具,用于展示机器学习模型的训练过程和结果。以下是 TensorBoard 的基本使用方法及一些案例。

基本使用

  1. 安装

    • 安装 TensorBoard:

      pip install tensorboard
    • 如果使用 PyTorch,还需要安装 torchtorchvision

      pip install torch torchvision
  2. 启动 TensorBoard

    • 在命令行中运行以下命令启动 TensorBoard:

      tensorboard --logdir=runs

      其中 --logdir 指定日志文件所在的目录。

  3. 在代码中使用 TensorBoard

    • PyTorch 示例:以下是一个简单的线性回归模型,使用 TensorBoard 记录训练过程中的损失值。

      import torch
      import torch.nn as nn
      import torch.optim as optim
      from torch.utils.tensorboard import SummaryWriter
      import numpy as np
      import matplotlib.pyplot as plt# 生成数据
      x = np.random.rand(100, 1) * 20
      y = 3 * x + np.random.randn(100, 1) * 5# 转换为 Tensor
      x_tensor = torch.FloatTensor(x)
      y_tensor = torch.FloatTensor(y)# 定义线性模型
      model = nn.Linear(1, 1)# 损失函数和优化器
      criterion = nn.MSELoss()
      optimizer = optim.SGD(model.parameters(), lr=0.01)# 初始化 TensorBoard
      writer = SummaryWriter('runs/linear_regression')# 训练模型
      for epoch in range(100):model.train()optimizer.zero_grad()outputs = model(x_tensor)loss = criterion(outputs, y_tensor)loss.backward()optimizer.step()# 记录损失writer.add_scalar('Loss/train', loss.item(), epoch)# 关闭 TensorBoard
      writer.close()# 绘制真实与预测数据
      plt.figure(figsize=(10, 6))
      plt.scatter(x, y, label='Data', color='blue')
      plt.plot(x, model(x_tensor).detach().numpy(), label='Prediction', color='red')
      plt.xlabel('x')
      plt.ylabel('y')
      plt.legend()
      plt.title('Linear Regression')
      plt.grid()
      plt.show()

      在浏览器中访问 http://localhost:6006,可以看到实时的训练损失变化曲线。

  4. TensorBoard 的可视化功能

    • 标量可视化:使用 add_scalar 方法记录标量数据,如损失值或准确率。

    • 图像可视化:使用 add_image 方法记录图像数据。

    • 直方图可视化:使用 add_histogram 方法记录直方图数据。

    • 图形可视化:使用 add_graph 方法记录模型的计算图。

    • 嵌入可视化:使用 add_embedding 方法记录嵌入数据。

案例

  1. 记录训练过程中的标量数据

    • 在 PyTorch 中,可以使用 SummaryWriteradd_scalar 方法记录训练过程中的损失值和准确率。

      writer = SummaryWriter()
      for n_iter in range(100):writer.add_scalar('Loss/train', np.random.random(), n_iter)writer.add_scalar('Loss/test', np.random.random(), n_iter)writer.add_scalar('Accuracy/train', np.random.random(), n_iter)writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
      writer.close()

      在 TensorBoard 中,这些标量数据会被分组显示,便于比较训练和测试阶段的性能。

  2. 记录模型的计算图

    • 在 PyTorch 中,可以使用 add_graph 方法记录模型的计算图。

      from torch.utils.tensorboard import SummaryWriter
      import torchvision
      from torchvision import datasets, transformswriter = SummaryWriter()
      transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
      trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
      trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
      model = torchvision.models.resnet50(False)
      model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
      images, labels = next(iter(trainloader))
      writer.add_graph(model, images)
      writer.close()

      在 TensorBoard 中,可以查看模型的计算图,了解模型的结构。

  3. 记录嵌入数据

    • 在 PyTorch 中,可以使用 add_embedding 方法记录嵌入数据。

      from torch.utils.tensorboard import SummaryWriter
      import torchwriter = SummaryWriter()
      meta = []
      while len(meta) < 100:meta = meta + keyword.kwlist
      meta = meta[:100]
      for i, v in enumerate(meta):meta[i] = v + str(i)
      label_img = torch.rand(100, 3, 10, 32)
      for i in range(100):label_img[i] *= i / 100.0
      writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
      writer.close()

      在 TensorBoard 中,嵌入数据(embedding)通常用于可视化高维数据的低维投影,例如通过 t-SNE 或 PCA 方法。以下是一个完整的案例,展示如何在 PyTorch 中记录嵌入数据并使用 TensorBoard 进行可视化。

      完整代码示例
      from torch.utils.tensorboard import SummaryWriter
      import torch
      import torch.nn as nn
      import torch.optim as optim
      import torchvision
      from torchvision import datasets, transforms
      import numpy as np# 设置 TensorBoard
      writer = SummaryWriter('runs/embedding')# 数据预处理
      transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
      ])# 加载数据集
      trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
      trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)# 定义一个简单的模型
      class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)  # 输入是 28x28 的图像self.fc2 = nn.Linear(128, 10)  # 输出是 10 类def forward(self, x):x = x.view(-1, 28 * 28)  # 将图像展平x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNet()# 损失函数和优化器
      criterion = nn.CrossEntropyLoss()
      optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
      for epoch in range(10):model.train()for batch_idx, (data, target) in enumerate(trainloader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 每隔一定步数记录一次嵌入数据if batch_idx % 100 == 0:# 获取嵌入层的输出embedding = model.fc1(data.view(-1, 28 * 28)).detach().numpy()# 获取标签labels = target.numpy()# 记录嵌入数据writer.add_embedding(mat=embedding,metadata=labels,label_img=data.unsqueeze(1),global_step=epoch * len(trainloader) + batch_idx)print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")writer.close()
    • 模型定义

      • 定义了一个简单的全连接网络,包含两个全连接层。fc1 的输出可以作为嵌入数据。

    • 嵌入数据记录

      • 使用 add_embedding 方法记录嵌入数据。mat 参数是嵌入数据,metadata 是标签,label_img 是对应的图像数据。

      • 在训练过程中,每隔一定步数(例如每100步)记录一次嵌入数据。

    • TensorBoard 可视化

      • 启动 TensorBoard:

        bash复制

        tensorboard --logdir=runs
      • 在浏览器中访问 http://localhost:6006,进入 Embeddings 选项卡,可以看到嵌入数据的可视化结果。通过 t-SNE 或 PCA 方法,可以将高维嵌入数据投影到低维空间,观察数据的分布情况。

    • 总结

      TensorBoard 是一个强大的可视化工具,可以帮助我们更好地理解模型的训练过程和结果。通过记录标量、图像、直方图、嵌入数据等信息,我们可以在训练过程中实时观察模型的性能,调整训练策略,优化模型结构。

      其他案例

      1. 图像可视化

      在训练过程中,可以记录中间层的输出图像或输入图像,用于观察模型的特征提取效果。

      writer = SummaryWriter('runs/image_visualization')# 假设 data 是输入图像,output 是模型的某个中间层输出
      writer.add_image('input', data[0], global_step=epoch)
      writer.add_image('output', output[0], global_step=epoch)
      writer.close()
      2. 直方图可视化

      记录模型参数的分布情况,有助于分析模型的训练状态。

      writer = SummaryWriter('runs/histogram_visualization')# 假设 model 是模型对象
      for name, param in model.named_parameters():writer.add_histogram(name, param, global_step=epoch)
      writer.close()
      3. 自定义图表

      TensorBoard 支持自定义图表,可以通过 add_custom_scalars 方法定义多条曲线的组合视图。

      writer = SummaryWriter('runs/custom_scalars')# 定义自定义图表
      writer.add_custom_scalars_multilinechart(['Loss/train', 'Loss/test'])
      writer.add_custom_scalars_marginchart(['Accuracy/train', 'Accuracy/test'])# 记录数据
      for epoch in range(100):writer.add_scalar('Loss/train', np.random.random(), epoch)writer.add_scalar('Loss/test', np.random.random(), epoch)writer.add_scalar('Accuracy/train', np.random.random(), epoch)writer.add_scalar('Accuracy/test', np.random.random(), epoch)writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/11130.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【ArcGIS遇上Python】批量提取多波段影像至单个波段

本案例基于ArcGIS python,将landsat影像的7个波段影像数据,批量提取至单个波段。 相关阅读:【ArcGIS微课1000例】0141:提取多波段影像中的单个波段 文章目录 一、数据准备二、效果比对二、python批处理1. 编写python代码2. 运行代码一、数据准备 实验数据及完整的python位…

吴恩达深度学习——超参数调试

内容来自https://www.bilibili.com/video/BV1FT4y1E74V&#xff0c;仅为本人学习所用。 文章目录 超参数调试调试选择范围 Batch归一化公式整合 Softmax 超参数调试 调试 目前学习的一些超参数有学习率 α \alpha α&#xff08;最重要&#xff09;、动量梯度下降法 β \bet…

Alibaba开发规范_编程规约之命名风格

文章目录 命名风格的基本原则1. 命名不能以下划线或美元符号开始或结束2. 严禁使用拼音与英文混合或直接使用中文3. 类名使用 UpperCamelCase 风格&#xff0c;但以下情形例外&#xff1a;DO / BO / DTO / VO / AO / PO / UID 等4. 方法名、参数名、成员变量、局部变量使用 low…

从0开始,来看看怎么去linux排查Java程序故障

一&#xff0c;前提准备 最基本前提&#xff1a;你需要有liunx环境&#xff0c;如果没有请参考其它文献在自己得到local建立一个虚拟机去进行测试。 有了虚拟机之后&#xff0c;你还需要安装jdk和配置环境变量 1. 安装JDK&#xff08;以OpenJDK 17为例&#xff09; 下载JDK…

智能园区管理系统助力企业安全与效率双提升的成功案例分析

内容概要 在当今迅速发展的商业环境中&#xff0c;企业面临着资产管理、风险控制和运营效率提高等多重挑战。为了应对这些挑战&#xff0c;智能园区管理系统应运而生&#xff0c;为企业提供了全新的解决方案。例如&#xff0c;快鲸智慧园区&#xff08;楼宇&#xff09;管理系…

nacos 配置管理、 配置热更新、 动态路由

文章目录 配置管理引入jar包添加 bootstrap.yaml 文件配置在application.yaml 中添加自定义信息nacos 配置信息 配置热更新采用第一种配置根据服务名确定配置文件根据后缀确定配置文件 动态路由DynamicRouteLoaderNacosConfigManagerRouteDefinitionWriter 路由配置 配置管理 …

Linux-CentOS的yum源

1、什么是yum yum是CentOS的软件仓库管理工具。 2、yum的仓库 2.1、yum的远程仓库源 2.1.1、国内仓库 国内较知名的网络源(aliyun源&#xff0c;163源&#xff0c;sohu源&#xff0c;知名大学开源镜像等) 阿里源:https://opsx.alibaba.com/mirror 网易源:http://mirrors.1…

16.[前端开发]Day16-HTML+CSS阶段练习(网易云音乐五)

完整代码 网易云-main-left-rank&#xff08;排行榜&#xff09; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name&q…

【ts + java】古玩系统开发总结

src别名的配置 开发中文件和文件的关系会比较复杂&#xff0c;我们需要给src文件夹一个别名吧 vite.config.js import { defineConfig } from vite import vue from vitejs/plugin-vue import path from path// https://vitejs.dev/config/ export default defineConfig({pl…

使用Pygame制作“俄罗斯方块”游戏

1. 前言 俄罗斯方块&#xff08;Tetris&#xff09; 是一款由方块下落、行消除等核心规则构成的经典益智游戏&#xff1a; 每次从屏幕顶部出现一个随机的方块&#xff08;由若干小方格组成&#xff09;&#xff0c;玩家可以左右移动或旋转该方块&#xff0c;让它合适地堆叠在…

小程序设计和开发:什么是竞品分析,如何进行竞品分析

一、竞品分析的定义 竞品分析是指对竞争对手的产品进行深入研究和比较&#xff0c;以了解市场动态、发现自身产品的优势和不足&#xff0c;并为产品的设计、开发和营销策略提供参考依据。在小程序设计和开发中&#xff0c;竞品分析可以帮助开发者了解同类型小程序的功能、用户体…

Vue简介

目录 Vue是什么&#xff1f;为什么要使用Vue&#xff1f;Vue的三种加载方式拓展&#xff1a;什么是渐进式框架&#xff1f; Vue是什么&#xff1f; Vue是一套用于构建用户界面的渐进式 JavaScript (主张最少)框架 &#xff0c;开发者只需关注视图层。另一方面&#xff0c;当与…

Linux多路转接poll

Linux多路转接poll 1. poll() poll() 结构包含了要监视的 event 和发生的 event &#xff0c;接口使用比 select() 更方便。且 poll 并没有最大数量限制&#xff08;但是数量过大后性能也是会下降&#xff09;。 2. poll() 的工作原理 poll() 不再需要像 select() 那样自行…

C++【深入底层,手撕vector】

vector是向量的意思&#xff0c;看了vector的底层实现之后&#xff0c;能够很明确的认识到它其实就是我们经常使用的顺序表。在我们的认知中&#xff0c;顺序表会有一个数组、数据的size以及容量的大小。vector作为一个向量容器&#xff0c;它可以存放任意类型的数据。所以在实…

基于FPGA的BT656编解码

概述 BT656全称为“ITU-R BT.656-4”或简称“BT656”,是一种用于数字视频传输的接口标准。它规定了数字视频信号的编码方式、传输格式以及接口电气特性。在物理层面上,BT656接口通常包含10根线(在某些应用中可能略有不同,但标准配置为10根)。这些线分别用于传输视频数据、…

关于系统重构实践的一些思考与总结

文章目录 一、前言二、系统重构的范式1.明确目标和背景2.兼容屏蔽对上层的影响3.设计灰度迁移方案3.1 灰度策略3.2 灰度过程设计3.2.1 case1 业务逻辑变更3.2.2 case2 底层数据变更&#xff08;数据平滑迁移&#xff09;3.2.3 case3 在途新旧流程兼容3.2.4 case4 接口变更3.2.5…

Microsoft Power BI:融合 AI 的文本分析

Microsoft Power BI 是微软推出的一款功能强大的商业智能工具&#xff0c;旨在帮助用户从各种数据源中提取、分析和可视化数据&#xff0c;以支持业务决策和洞察。以下是关于 Power BI 的深度介绍&#xff1a; 1. 核心功能与特点 Power BI 提供了全面的数据分析和可视化功能&…

【机器学习】自定义数据集 ,使用朴素贝叶斯对其进行分类

一、贝叶斯原理 贝叶斯算法是基于贝叶斯公式的&#xff0c;其公式为&#xff1a; 其中叫做先验概率&#xff0c;叫做条件概率&#xff0c;叫做观察概率&#xff0c;叫做后验概率&#xff0c;也是我们求解的结果&#xff0c;通过比较后验概率的大小&#xff0c;将后验概率最大的…

AMS仿真方法

1. 准备好verilog文件。并且准备一份.vc文件&#xff0c;将所有的verilog file的路径全部写在里面。 2. 将verilog顶层导入到virtuoso中&#xff1a; 注意.v只要引入顶层即可。不需要全部引入。实际上顶层里面只要包含端口即可&#xff0c;即便是空的也没事。 引入时会报warni…

OpenAI o3-mini全面解析:最新免费推理模型重磅发布

引言 2025年1月31日&#xff0c;OpenAI重磅发布全新推理模型o3-mini。这款模型作为OpenAI推理系列的最新突破&#xff0c;不仅在性能和性价比方面实现跨越式提升&#xff0c;更是首次全面开放免费使用。这一重大举措彰显了OpenAI在人工智能技术普及和成本优化领域的创新决心。…