MLX vs MPS vs CUDA:苹果新机器学习框架的基准测试

如果你是一个Mac用户和一个深度学习爱好者,你可能希望在某些时候Mac可以处理一些重型模型。苹果刚刚发布了MLX,一个在苹果芯片上高效运行机器学习模型的框架。

最近在PyTorch 1.12中引入MPS后端已经是一个大胆的步骤,但随着MLX的宣布,苹果还想在开源深度学习方面有更大的发展。

在本文中,我们将对这些新方法进行测试,在三种不同的Apple Silicon芯片和两个支持cuda的gpu上和传统CPU后端进行基准测试。

这里把基准测试集中在图卷积网络(GCN)模型上。这个模型主要由线性层组成,所以对于其他的模型也应该得到类似的结果。

创造环境

要为MLX构建环境,我们必须指定是使用i386还是arm架构。使用conda,可以使用:

 CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forgeconda activate mlx

如果检查你的env是否实际使用了arm,下面命令的输出应该是arm,而不是i386(因为我们用的Apple Silicon):

 python -c "import platform; print(platform.processor())"

然后就是使用pip安装MLX:

 pip install mlx

GCN模型

GCN模型是图神经网络(GNN)的一种,它使用邻接矩阵(表示图结构)和节点特征。它通过收集邻近节点的信息来计算节点嵌入。每个节点获得其邻居特征的平均值。这种平均是通过将节点特征与标准化邻接矩阵相乘来完成的,并根据节点度进行调整。为了学习这个过程,特征首先通过线性层投射到嵌入空间中。

我们将使用MLX实现一个GCN层和一个GCN模型:

 import mlx.nn as nnclass GCNLayer(nn.Module):def __init__(self, in_features, out_features, bias=True):super(GCNLayer, self).__init__()self.linear = nn.Linear(in_features, out_features, bias)def __call__(self, x, adj):x = self.linear(x)return adj @ xclass GCN(nn.Module):def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):super(GCN, self).__init__()layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]self.gcn_layers = [GCNLayer(in_dim, out_dim, bias)for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])]self.dropout = nn.Dropout(p=dropout)def __call__(self, x, adj):for layer in self.gcn_layers[:-1]:x = nn.relu(layer(x, adj))x = self.dropout(x)x = self.gcn_layers[-1](x, adj)return x

可以看到,mlx的模型开发方式与tf2基本一样,都是调用

__call__

进行前向传播,其实torch也一样,只不过它自定义了一个forward函数。

下面就是训练

 gcn = GCN(x_dim=x.shape[-1],h_dim=args.hidden_dim,out_dim=args.nb_classes,nb_layers=args.nb_layers,dropout=args.dropout,bias=args.bias,)mx.eval(gcn.parameters())optimizer = optim.Adam(learning_rate=args.lr)loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)# Training loopfor epoch in range(args.epochs):# Loss(loss, y_hat), grads = loss_and_grad_fn(gcn, x, adj, y, train_mask, args.weight_decay)optimizer.update(gcn, grads)mx.eval(gcn.parameters(), optimizer.state)# Validationval_loss = loss_fn(y_hat[val_mask], y[val_mask])val_acc = eval_fn(y_hat[val_mask], y[val_mask])

在MLX中,计算是惰性的,这意味着eval()通常用于在更新后实际计算新的模型参数。而另一个关键函数是nn.value_and_grad(),它生成一个计算参数损失的函数。第一个参数是保存当前参数的模型,第二个参数是用于前向传递和损失计算的可调用函数。它返回的函数接受与forward函数相同的参数(在本例中为forward_fn)。我们可以这样定义这个函数:

 def forward_fn(gcn, x, adj, y, train_mask, weight_decay):y_hat = gcn(x, adj)loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())return loss, y_hat

它仅仅包括计算前向传递和计算损失。Loss_fn()和eval_fn()定义如下:

 def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):l = mx.mean(nn.losses.cross_entropy(y_hat, y))if weight_decay != 0.0:assert parameters != None, "Model parameters missing for L2 reg."l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()return l + weight_decay * l2_regreturn ldef eval_fn(x, y):return mx.mean(mx.argmax(x, axis=1) == y)

损失函数是计算预测和标签之间的交叉熵,并包括L2正则化。由于L2正则化还不是内置特性,需要手动实现。

本文的完整代码:https://github.com/TristanBilot/mlx-GCN

可以看到除了一些细节函数调用的差别,基本的训练流程与pytorch和tf都很类似,但是这里的一个很好的事情是消除了显式地将对象分配给特定设备的需要,就像我们在PyTorch中经常使用.cuda()和.to(device)那样。这是因为苹果硅芯片的统一内存架构,所有变量共存于同一空间,也就是说消除了CPU和GPU之间缓慢的数据传输,这样也可以保证不会再出现与设备不匹配相关的烦人的运行时错误。

基准测试

我们将使用MLX与MPS, CPU和GPU设备进行比较。我们的测试平台是一个2层GCN模型,应用于Cora数据集,其中包括2708个节点和5429条边。

对于MLX, MPS和CPU测试,我们对M1 Pro, M2 Ultra和M3 Max进行基准测试。在两款NVIDIA V100 PCIe和V100 NVLINK上进行测试

MPS:比M1 Pro的CPU快2倍以上,在其他两个芯片上,与CPU相比有30-50%的改进。

MLX:比M1 Pro上的MPS快2.34倍。与MPS相比,M2 Ultra的性能提高了24%。在M3 Pro上MPS和MLX之间没有真正的改进。

CUDA V100 PCIe & NVLINK:只有23%和34%的速度比M3 Max与MLX,这里的原因可能是因为我们的模型比较小,所以发挥不出V100和NVLINK的优势(NVLINK主要GPU之间的数据传输大的情况下会有提高)。这也说明了苹果的统一内存架构的确可以消除CPU和GPU之间缓慢的数据传输。

总结

与CPU和MPS相比,MLX可以说是非常大的金币,在小数据量的情况下它甚至接近特斯拉V100的性能。也就是说我们可以使用MLX跑一些不是那么大的模型,比如一些表格数据。

从上面的基准测试也可以看到,现在可以利用苹果芯片的全部力量在本地运行深度学习模型(我一直认为MPS还没发挥苹果的优势,这回MPS已经证明了这一点)。

MLX刚刚发布就已经取得了惊人的影响力,并展示了巨大的潜力。相信未来几年开源社区的进一步增强,可以期待在不久的将来更强大的苹果芯片,将MLX的性能提升到一个全新的水平。

另外也说明了MPS(虽然也发布不久)还是有巨大的发展空间的,毕竟切换框架是一件很麻烦的事情,如果MPS能达到MLX 80%或者90%的速度,我想不会有人去换框架的。

最后说到框架,现在已经有了Pytorch,TF,JAX,现在又多了一个MLX。各种设备、各种后端包括:TPU(pytorch使用的XLA),CUDA,ROCM,现在又多了一个MPS。

https://avoid.overfit.cn/post/eb87d12f29eb4665adb43ad59fd3d64f

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/223722.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【EasyExcel实践】万能导出,一个接口导出多张表以及任意字段(可指定字段顺序)-简化升级版

文章目录 前言正文一、项目简介二、核心代码2.1 pom.xml 依赖配置2.2 ExcelHeadMapFactory2.3 ExcelDataLinkedHashMap2.4 自定义注解 ExcelExportBean2.5 自定义注解 ExcelColumnTitle2.6 建造器接口 Builder2.7 表格工具类 ExcelUtils2.8 GsonUtil2.9 模版类 ExportDynamicCo…

Opencv中的滤波器

一副图像通过滤波器得到另一张图像,其中滤波器又称为卷积核,滤波的过程称之为卷积。 这就是一个卷积的过程,通过一个卷积核得到另一张图片,明显发现新的到的图片边缘部分更加清晰了(锐化)。 上图就是一个卷…

java8实战 lambda表达式、函数式接口、方法引用双冒号(中)

前言 书接上文,上一篇博客讲到了lambda表达式的应用场景,本篇接着将java8实战第三章的总结。建议读者先看第一篇博客 其他函数式接口例子 上一篇有讲到Java API也有其他的函数式接口,书里也举了2个例子,一个是java.util.functi…

Ubuntu系统如何安装和卸载CUDA和CUDNN

背景 最近在学习PaddlePaddle在各个显卡驱动版本的安装和使用,所以同时也学习如何在Ubuntu安装和卸载CUDA和CUDNN,在学习过程中,顺便记录学习过程。在供大家学习的同时,也在加强自己的记忆。本文章以卸载CUDA 8.0 和 CUDNN 7.05 …

Docker 编译OpenHarmony 4.0 release

一、背景介绍 1.1、环境配置 编译环境:Ubuntu 20.04OpenHarmony版本:4.0 release平台设备:RK3568 OpenHarmony 3.2更新至OpenHarmony 4.0后,公司服务器无法编译通过,总是在最后几十个文件时报错,错误码4000&#xf…

Linux一行命令配置jdk环境

使用方法: 压缩包上传 到/opt, 更换命令中对应的jdk包名即可。 注意点:jdk-8u151-linux-x64.tar.gz 解压后名字是jdk1.8.0_151 sudo tar -zxvf jdk-8u151-linux-x64.tar.gz -C /opt && echo export JAVA_HOME/opt/jdk1.8.0_151 | sudo tee -a …

Diffusion扩散模型学习:图片高斯加噪

高斯分布即正态分布;图片高斯加噪即把图片矩阵每个值和一个高斯分布的矩阵上的对应值相加 1、高斯分布 np.random.normal 一维: import numpy as np import matplotlib.pyplot as pltdef generate_gaussian_noise(mean, std_dev, size):noise np.ran…

【CentOS 7.9 分区】挂载硬盘为LVM操作实例

LVM与标准分区有何区别,如何选择 目录 1 小系统使用LVM的益处:2 大系统使用LVM的益处:3 优点:CentOS 7.9 挂载硬盘为LVM操作实例查看硬盘情况格式化硬盘创建PV创建VG创建LV创建文件系统并挂载自动挂载添加:注意用空格间…

VSCode SSH 连接提示: spawn UNKNOWN

随笔记录 目录 1. 背景介绍 2. 确认问题 : ssh -V 3. 解决问题 3.1 确认本地 ssh.exe 路径 3.2 修改vscode Remote.ssh:Path 3.2.1 设置 Reomte.ssh:Path - 方法一 3.2.2 设置 Reomte.ssh:Path - 方法二 1. 背景介绍 windows 系统vscode ssh remote CentOS7&#xff…

【零基础入门Docker】什么是Dockerfile Syntax

✍面向读者:所有人 ✍所属专栏:零基础入门Docker专栏https://blog.csdn.net/arthas777/category_12455882.html 目录 编写Dockerfile和Format的语法 2. MAINTAINER 3. RUN 4. ADD 6. ENTRYPOINT 7. CMD 8. EXPOSE 9. VOLUME 11. USER 12. ARG …

.NET core 自定义过滤器 Filter 实现webapi RestFul 统一接口数据返回格式

之前写过使用自定义返回类的方式来统一接口数据返回格式,.Net Core webapi RestFul 统一接口数据返回格式-CSDN博客 但是这存在一个问题,不是所有接口会按照定义的数据格式返回,除非每个接口都返回我们自定义的类,这种实现起来不…

Adobe InDesign各版本安装指南

下载链接 https://pan.baidu.com/s/1VWGKDUijTTETU9sVWFjCtg?pwd0531 #2024版本 1.鼠标右击【InCopy2024(64bit)】压缩包(win11及以上系统需先点击“显示更多选项”)【解压到 InCopy2024(64bit)】。 2.打开解压后的文件夹,鼠标右击【Setup…

DevOps系列文章 : 使用dpkg命令打deb包

创建一个打包的目录,类似rpmbuild,这里创建了目录deb_build mkdir deb_build目标 我有一个hello的二进制文件hello和源码hello.c, 准备安装到/opt/helloworld目录中 步骤 在deb_build目录创建一个文件夹用于存放我的安装文件 mkdir helloworld在he…

深入探讨多模态模型和计算机视觉

近年来,机器学习领域在从图像识别到自然语言处理的不同问题类型上取得了显着进展。然而,这些模型中的大多数都对来自单一模态的数据进行操作,例如图像、文本或语音。相比之下,现实世界的数据通常来自多种模态,例如图像…

【Linux】Linux常见指令解析上

目录 1. 前言2. ls指令3. pwd指令4. cd指令3.1 cd常见快捷指令 4. touch指令5. mkdir指令6. rmdir指令 && rm指令 (重要)6.1 rmdir指令6.2 rm指令 7. man指令 1. 前言 这篇文章我们将详细介绍一下Linux下常见的基本指令。 2. ls指令 语法: ls [选…

系列一、GitHub搜索技巧

一、GitHub搜索技巧 1.1、概述 作为程序员,GitHub大家应该都再熟悉不过了,很多时候当我们需要使用某一项技能而又无从下手时,通常会在百度(面向百度编程)或者在GitHub上通过关键字寻找相关案例,比如我想学…

Go自定义PriorityQueue优先队列使用Heap堆

题目 分析 每次找最大的,pop出来 然后折半,再丢进去 go写法 go如果想用heap,要实现less\len\swap\push\pop 但可以偷懒,用sort.IntSlice,已经实现了less\len\swap 但由于目前是大根堆,要重写一下less 因此&#xff…

PWM/PFM 自动切换升压型转换器系统(一)

通过对芯片整体设计要求的考虑,搭建全负载高效率升压型 DC-DC 转换器的整体系 统框架,对系统的工作过程和模块电路的功能进行简要阐述,对外围电路的选取进行准确计 算,分析系统的损耗来源,实现高效率的设计目标。 芯片…

机场信息集成系统系列介绍(8):基于视频分析的航班保障核心数据自动采集系统

目录 一、背景 二、相关功能规划 1、功能设计 2、其他设计要求 三、具体保障数据采集的覆盖点 四、相关性能指标要求 1、性能指标要求 2、算法指标要求 一、背景 基于视频分析的航班保障核心数据自动化采集系统,是ACDM系统建设的延伸,此类系统并…

Uniapp + Vue3 + Pinia + Vant3 框架搭建

现在越来越多项目都偏向于Vue3开发&#xff0c;想着uniapp搭配Vue3试试效果怎么样&#xff0c;接下来就是详细操作步骤。 初始化Uniapp Vue3项目 App.vue setup语法 <script setup>import {onLaunch,onShow,onHide} from dcloudio/uni-apponLaunch(() > {console.l…