Deploying Mini-Omni, an End-to-End Speech Model for GPT-4o-Style Real-Time Voice Interaction

Mini-Omni is an open-source multimodal large language model from Tsinghua University with real-time speech input and streaming audio output.

Mini-Omni can listen, speak, and think at the same time, giving a conversational experience similar to ChatGPT's voice mode.

Its key feature is the ability to reason directly in the audio modality and produce streaming output, without relying on a separate text-to-speech (TTS) system, which reduces latency.

Architecturally, Mini-Omni builds on Qwen2-0.5B and adds a Whisper-small encoder to process speech input efficiently.
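As a quick illustration of the speech-input path, the snippet below runs the Whisper-small encoder over a wav file to obtain the continuous audio features that are fed to the language model backbone. This is only a minimal sketch assuming the openai-whisper package is installed; "sample.wav" is a placeholder path, and the full helpers appear in the inference code later in this post.

import whisper

# Load the Whisper-small encoder that Mini-Omni uses for speech input.
model = whisper.load_model("small")

# "sample.wav" is a placeholder for any audio file you want to encode.
audio = whisper.load_audio("sample.wav")
audio = whisper.pad_or_trim(audio)          # pad/clip to the 30 s window Whisper expects
mel = whisper.log_mel_spectrogram(audio)    # log-mel features

# embed_audio returns the encoder's continuous features; Mini-Omni feeds these
# to the Qwen2-0.5B backbone instead of discretizing the input speech.
features = model.embed_audio(mel.unsqueeze(0))
print(features.shape)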

Mini-Omni adopts a parallel text-audio generation scheme: speech and text tokens are produced together through batch parallel decoding, so the model's reasoning ability across modalities is not degraded.
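The parallel decoding works by giving each of the seven SNAC audio codebooks its own offset region in the output vocabulary, so a single step can emit one text token plus seven audio tokens. Below is a minimal sketch of that vocabulary layout, using the same constants as the inference code later in this post; the standalone layershift helper here is illustrative and mirrors the repo's utility of the same name.

# Vocabulary layout behind Mini-Omni's parallel text-audio decoding (illustrative sketch).
text_vocabsize, text_specialtokens = 151936, 64
audio_vocabsize, audio_specialtokens = 4096, 64

padded_text_vocabsize = text_vocabsize + text_specialtokens      # 152000
padded_audio_vocabsize = audio_vocabsize + audio_specialtokens   # 4160

def layershift(audio_token: int, layer: int) -> int:
    # Map an audio token from SNAC codebook `layer` (0..6) into the combined vocabulary:
    # text tokens come first, followed by seven contiguous blocks of audio tokens.
    return audio_token + padded_text_vocabsize + layer * padded_audio_vocabsize

# Example: token 123 from codebook 2 lands in the third audio block.
print(layershift(123, 2))   # 152000 + 2 * 4160 + 123 = 160443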

The project also introduces the VoiceAssistant-400K dataset, used to fine-tune the model for better speech output.
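If you want to inspect the fine-tuning data, the dataset is published on Hugging Face under gpt-omni/VoiceAssistant-400K. The hedged sketch below assumes a default train split and streams a single record to print its fields, rather than assuming particular column names.

from datasets import load_dataset

# Stream VoiceAssistant-400K so nothing has to be downloaded in full.
ds = load_dataset("gpt-omni/VoiceAssistant-400K", split="train", streaming=True)

sample = next(iter(ds))
print(sample.keys())   # inspect the available fields before building a fine-tuning pipeline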

GitHub project: https://github.com/gpt-omni/mini-omni

I. Environment Installation

1. Python environment

Python 3.10 or later is recommended.
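If you prefer an isolated environment, the upstream README suggests creating one with conda before installing anything (a hedged example; any virtual environment with Python 3.10+ works):

conda create -n omni python=3.10

conda activate omni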

2. Installing the pip dependencies

pip install torch==2.3.1+cu118 torchvision==0.18.1+cu118 torchaudio==2.3.1 --extra-index-url https://download.pytorch.org/whl/cu118

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
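After the install finishes, a quick sanity check (minimal sketch) confirms that the CUDA 11.8 build of PyTorch can actually see the GPU:

import torch

print(torch.__version__)          # expect 2.3.1+cu118
print(torch.cuda.is_available())  # should print True on a machine with a CUDA-capable GPU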

3. Model download

git lfs install

git clone https://huggingface.co/gpt-omni/mini-omni
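If git-lfs is inconvenient, the same checkpoint can be pulled with huggingface_hub, which is also what the download_model helper in the code below does. A standalone hedged snippet, downloading into ./checkpoint:

from huggingface_hub import snapshot_download

# Download the Mini-Omni weights into ./checkpoint (same repo as the git clone above).
snapshot_download("gpt-omni/mini-omni", local_dir="./checkpoint", revision="main")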

II. Functional Testing

1. Running the tests

(1) Calling the model from Python code

import os
import torch
import time
import lightning as L
import soundfile as sf
import whisper
from snac import SNAC
from litgpt import Tokenizer
from tqdm import tqdm
from huggingface_hub import snapshot_download
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str, get_snac, generate_audio_data
from litgpt.utils import num_parameters
from litgpt.generate.base import generate_AA, generate_ASR, generate_TA, generate_TT, generate_AT, generate_TA_BATCH, next_token_batch
from litgpt.model import GPT, Config

torch.set_printoptions(sci_mode=False)

# Constants Definitions
text_vocabsize = 151936
text_specialtokens = 64
audio_vocabsize = 4096
audio_specialtokens = 64

padded_text_vocabsize = text_vocabsize + text_specialtokens
padded_audio_vocabsize = audio_vocabsize + audio_specialtokens

_eot = text_vocabsize
_pad_t = text_vocabsize + 1
_input_t = text_vocabsize + 2
_answer_t = text_vocabsize + 3
_asr = text_vocabsize + 4

_eoa = audio_vocabsize
_pad_a = audio_vocabsize + 1
_input_a = audio_vocabsize + 2
_answer_a = audio_vocabsize + 3
_split = audio_vocabsize + 4


# Utility Functions
def get_input_ids_TA(text, text_tokenizer):
    input_ids_item = [[] for _ in range(8)]
    text_tokens = text_tokenizer.encode(text)
    for i in range(7):
        input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [layershift(_answer_a, i)]
        input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
    input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
    input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
    return input_ids_item


def get_input_ids_TT(text, text_tokenizer):
    input_ids_item = [[] for i in range(8)]
    text_tokens = text_tokenizer.encode(text).tolist()
    for i in range(7):
        input_ids_item[i] = torch.tensor([layershift(_pad_a, i)] * (len(text_tokens) + 3)).unsqueeze(0)
    input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
    input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
    return input_ids_item


def get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_answer_a, special_token_t=_answer_t):
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]
    T = audio_feature.size(0)
    input_ids = []
    for i in range(7):
        input_ids_item = []
        input_ids_item.append(layershift(_input_a, i))
        input_ids_item += [layershift(_pad_a, i)] * T
        input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
    input_ids.append(input_id_T.unsqueeze(0))
    return audio_feature.unsqueeze(0), input_ids


def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]
    T = audio_feature.size(0)
    input_ids_AA, input_ids_AT = [], []
    for i in range(7):
        lang_shift = layershift(_pad_a, i)
        input_ids_item_AA = [layershift(_input_a, i)] + [lang_shift] * T + [(layershift(_eoa, i)), layershift(_answer_a, i)]
        input_ids_item_AT = [layershift(_input_a, i)] + [lang_shift] * T + [(layershift(_eoa, i)), lang_shift]
        input_ids_AA.append(torch.tensor(input_ids_item_AA))
        input_ids_AT.append(torch.tensor(input_ids_item_AT))
    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
    input_ids_AA.append(input_id_T)
    input_ids_AT.append(input_id_T)
    return torch.stack([audio_feature, audio_feature]), [input_ids_AA, input_ids_AT]


def load_audio(path):
    audio = whisper.load_audio(path)
    duration_ms = (len(audio) / 16000) * 1000
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio)
    return mel, int(duration_ms / 20) + 1


def model_inference(fabric, model, snacmodel, tokenizer, func, step, *args, **kwargs):
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    output = func(fabric, *args, **kwargs)
    model.clear_kv_cache()
    return output


def load_model(ckpt_dir, device):
    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
    whispermodel = whisper.load_model("small").to(device)
    text_tokenizer = Tokenizer(ckpt_dir)
    fabric = L.Fabric(devices=1, strategy="auto")
    config = Config.from_file(ckpt_dir + "/model_config.yaml")
    config.post_adapter = False
    with fabric.init_module(empty_init=False):
        model = GPT(config)
    model = fabric.setup(model)
    state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
    model.load_state_dict(state_dict, strict=True)
    model.to(device).eval()
    return fabric, model, text_tokenizer, snacmodel, whispermodel


def download_model(ckpt_dir):
    repo_id = "gpt-omni/mini-omni"
    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")


class OmniInference:

    def __init__(self, ckpt_dir='checkpoint', device='cuda'):
        self.device = device
        if not os.path.exists(ckpt_dir):
            print(f"Checkpoint directory {ckpt_dir} not found, downloading from huggingface")
            download_model(ckpt_dir)
        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)

    def warm_up(self, sample='./data/samples/output1.wav'):
        for _ in self.run_AT_batch_stream(sample):
            pass

    @torch.inference_mode()
    def run_AT_batch_stream(self, audio_path, stream_stride=4, max_returned_tokens=2048, temperature=0.9,
                            top_k=1, top_p=1.0, eos_id_a=_eoa, eos_id_t=_eot):
        # Streams the spoken answer for one wav file: audio tokens are decoded in chunks
        # of `stream_stride` SNAC frames and yielded as audio data while generation continues.
        assert os.path.exists(audio_path), f"Audio file {audio_path} not found"
        with self.fabric.init_tensor():
            self.model.set_kv_cache(batch_size=2)
        mel, leng = load_audio(audio_path)
        audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
        T = input_ids[0].size(1)
        device = input_ids[0].device
        assert max_returned_tokens > T, f"Max returned tokens {max_returned_tokens} should be greater than audio length {T}"
        if self.model.max_seq_length < max_returned_tokens - 1:
            raise NotImplementedError(f"max_seq_length {self.model.max_seq_length} needs to be >= {max_returned_tokens - 1}")
        input_pos = torch.tensor([T], device=device)
        list_output = [[] for _ in range(8)]
        tokens_A, token_T = next_token_batch(
            self.model,
            audio_feature.to(torch.float32).to(self.model.device),
            input_ids,
            [T - 3, T - 3],
            ["A1T2", "A1T2"],
            input_pos=torch.arange(0, T, device=device),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        for i in range(7):
            list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])
        model_input_ids = [[] for _ in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
            model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
            model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
            model_input_ids[i] = torch.stack(model_input_ids[i])
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])
        text_end = False
        index = 1
        nums_generate = stream_stride
        begin_generate = False
        current_index = 0
        for _ in tqdm(range(2, max_returned_tokens - T + 1)):
            tokens_A, token_T = next_token_batch(self.model, None, model_input_ids, None, None,
                                                 input_pos=input_pos, temperature=temperature, top_k=top_k, top_p=top_p)
            if text_end:
                token_T = torch.tensor([_pad_t], device=device)
            if tokens_A[-1] == eos_id_a:
                break
            if token_T == eos_id_t:
                text_end = True
            for i in range(7):
                list_output[i].append(tokens_A[i].tolist()[0])
            list_output[7].append(token_T.tolist()[0])
            model_input_ids = [[] for _ in range(8)]
            for i in range(7):
                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
                model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
                model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
                model_input_ids[i] = torch.stack(model_input_ids[i])
            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1] = torch.stack(model_input_ids[-1])
            if index == 7:
                begin_generate = True
            if begin_generate:
                current_index += 1
                if current_index == nums_generate:
                    current_index = 0
                    snac = get_snac(list_output, index, nums_generate)
                    audio_stream = generate_audio_data(snac, self.snacmodel)
                    yield audio_stream
            input_pos = input_pos.add_(1)
            index += 1
        text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
        print(f"Text output: {text}")
        self.model.clear_kv_cache()
        return list_output


# Note: the task-specific wrappers used below (A1_A2, A1_T1, A1_T2, T1_A2, T1_T2, A1_A2_batch)
# are defined in the full inference.py of the upstream repo and are omitted from this excerpt.
def test_infer():
    device = "cuda:0"
    out_dir = f"./output/{get_time_str()}"
    ckpt_dir = f"./checkpoint"
    if not os.path.exists(ckpt_dir):
        print(f"Checkpoint directory {ckpt_dir} not found, downloading from huggingface")
        download_model(ckpt_dir)
    fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
    task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']

    # Prepare test data
    test_audio_list = sorted(os.listdir('./data/samples'))
    test_audio_list = [os.path.join('./data/samples', path) for path in test_audio_list]
    test_audio_transcripts = [
        "What is your name?",
        "What are your hobbies?",
        "Do you like Beijing?",
        "How are you feeling today?",
        "What is the weather like today?",
    ]
    test_text_list = [
        "What is your name?",
        "How are you feeling today?",
        "Can you describe your surroundings?",
        "What did you do yesterday?",
        "What is your favorite book and why?",
        "How do you make a cup of tea?",
        "What is the weather like today?",
        "Can you explain the concept of time?",
        "Can you tell me a joke?",
    ]

    with torch.no_grad():
        for task_name in task:
            if "A1A2" in task_name:
                print("===============================================================")
                print("                       Testing A1A2")
                print("===============================================================")
                for i, path in enumerate(test_audio_list):
                    try:
                        mel, leng = load_audio(path)
                        audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
                        text = model_inference(fabric, model, snacmodel, text_tokenizer, A1_A2, i,
                                               fabric, audio_feature, input_ids, leng, model, text_tokenizer, i, snacmodel, out_dir)
                        print(f"Input: {test_audio_transcripts[i]}")
                        print(f"Output: {text}")
                        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                    except Exception as e:
                        print(f"[Error] Failed to process {path}: {e}")
                print("===============================================================")

            if 'asr' in task_name:
                print("===============================================================")
                print("                       Testing ASR")
                print("===============================================================")
                for i, path in enumerate(test_audio_list):
                    mel, leng = load_audio(path)
                    audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device,
                                                                     special_token_a=_pad_a, special_token_t=_asr)
                    output = model_inference(fabric, model, snacmodel, text_tokenizer, A1_T1, i,
                                             fabric, audio_feature, input_ids, leng, model, text_tokenizer, i).lower().replace(',', '').replace('.', '').replace('?', '')
                    print(f"Audio path: {path}")
                    print(f"Audio transcript: {test_audio_transcripts[i]}")
                    print(f"ASR output: {output}")
                    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                print("===============================================================")

            if "T1A2" in task_name:
                print("===============================================================")
                print("                       Testing T1A2")
                print("===============================================================")
                for i, text in enumerate(test_text_list):
                    input_ids = get_input_ids_TA(text, text_tokenizer)
                    text_output = model_inference(fabric, model, snacmodel, text_tokenizer, T1_A2, i,
                                                  fabric, input_ids, model, text_tokenizer, i, snacmodel, out_dir)
                    print(f"Input: {text}")
                    print(f"Output: {text_output}")
                    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                print("===============================================================")

            if "T1T2" in task_name:
                print("===============================================================")
                print("                       Testing T1T2")
                print("===============================================================")
                for i, text in enumerate(test_text_list):
                    input_ids = get_input_ids_TT(text, text_tokenizer)
                    text_output = model_inference(fabric, model, snacmodel, text_tokenizer, T1_T2, i,
                                                  fabric, input_ids, model, text_tokenizer, i)
                    print(f"Input: {text}")
                    print(f"Output: {text_output}")
                    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                print("===============================================================")

            if "AT" in task_name:
                print("===============================================================")
                print("                       Testing A1T2")
                print("===============================================================")
                for i, path in enumerate(test_audio_list):
                    mel, leng = load_audio(path)
                    audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device,
                                                                     special_token_a=_pad_a, special_token_t=_answer_t)
                    text = model_inference(fabric, model, snacmodel, text_tokenizer, A1_T2, i,
                                           fabric, audio_feature, input_ids, leng, model, text_tokenizer, i)
                    print(f"Input: {test_audio_transcripts[i]}")
                    print(f"Output: {text}")
                    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                print("===============================================================")

            if "AA-BATCH" in task_name:
                print("===============================================================")
                print("                       Testing A1A2-BATCH")
                print("===============================================================")
                for i, path in enumerate(test_audio_list):
                    mel, leng = load_audio(path)
                    audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
                    text = model_inference(fabric, model, snacmodel, text_tokenizer, A1_A2_batch, i,
                                           fabric, audio_feature, input_ids, leng, model, text_tokenizer, i, snacmodel, out_dir)
                    print(f"Input: {test_audio_transcripts[i]}")
                    print(f"Output: {text}")
                    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                print("===============================================================")

    print("========================= Test End ============================")


if __name__ == "__main__":
    test_infer()
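Running the listing above as a script executes test_infer() over the bundled samples. Alternatively, the OmniInference class can be used on its own to stream a spoken answer for a single wav file. The sketch below is a hedged usage example: it assumes the checkpoint sits in ./checkpoint, that ./data/samples/output1.wav exists (it ships with the repo), and that each yielded chunk is 16-bit mono PCM at 24 kHz, matching the SNAC 24 kHz decoder loaded above; adjust the wav parameters if your build differs.

import wave

# Build the inference wrapper; it downloads the checkpoint on first use.
client = OmniInference(ckpt_dir="./checkpoint", device="cuda:0")
client.warm_up()

# Collect the streamed audio chunks for one spoken question.
chunks = []
for audio_chunk in client.run_AT_batch_stream("./data/samples/output1.wav"):
    chunks.append(audio_chunk)

# Write the concatenated stream to a wav file (assumes 16-bit mono PCM at 24 kHz).
with wave.open("answer.wav", "wb") as f:
    f.setnchannels(1)
    f.setsampwidth(2)
    f.setframerate(24000)
    f.writeframes(b"".join(chunks))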

To be continued...

For more details, follow: 杰哥新技术
