PaddleNLP使用Vicuna

LLaMA 模型

LLaMa 是一个大型语言模型,由 Meta 开源。它的全称是 Large Language Model Meta AI,参数量从 70 亿到 650 亿不等。例如,130 亿参数的 LLaMA 模型在大多数基准上可以胜过参数量达 1750 亿的 GPT-3,而且可以在单块 V100 GPU 上运行。而最大的 650 亿参数的 LLaMA 模型可以媲美谷歌的 Chinchilla-70B 和 PaLM-540B。

Vicuna 模型

Vicuna 是一个由 UC 伯克利、CMU、斯坦福等机构的学者联手发布的最新开源大模型。基于 Meta 开源的 LLaMA 大模型,使用 ShareGPT 平台上的用户共享对话数据微调而来。包含 7B 和 13B 两个型号的开源预训练模型。

在这里插入图片描述

下载模型

# 下载 Vicuna 7B
# !git lfs clone http://git.aistudio.baidu.com/180581/vicuna-7b-v1.1.git# 下载 Vicuna 13B
!git lfs clone http://git.aistudio.baidu.com/180581/vicuna-13b-v1.1.git

开发环境

!pip install --pre --upgrade paddlenlp -f https://www.paddlepaddle.org.cn/whl/paddlenlp.html --user
!pip install paddlepaddle-gpu==0.0.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html --user

代码

import os
import glob
import paddlefrom tqdm import tqdm
from paddlenlp.transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizerpattern = 'paddle-model-?????-of-?????.pdparams'# Vicuna 7B
# ckpt_dir = 'vicuna-7b-v1.1'
# config_dict =  {
#     "hidden_size": 4096,
#     "initializer_range": 0.02,
#     "intermediate_size": 11008,
#     "max_position_embeddings": 2048,
#     "model_type": "llama",
#     "num_attention_heads": 32,
#     "num_hidden_layers": 32,
#     "rms_norm_eps": 1e-06,
#     "vocab_size": 32000,
#     "bos_token_id": 1,
#     "eos_token_id": 2,
#     "pad_token_id": 0,
#     "use_cache": True,
#     "use_recompute": False,
#     "use_flash_attention": False,
# }# Vicuna 13B
ckpt_dir = 'vicuna-13b-v1.1'
config_dict =  {"hidden_size": 5120,"initializer_range": 0.02,"intermediate_size": 13824,"max_position_embeddings": 2048,"model_type": "llama","num_attention_heads": 40,"num_hidden_layers": 40,"rms_norm_eps": 1e-06,"vocab_size": 32000,"bos_token_id": 1,"eos_token_id": 2,"pad_token_id": 0,"use_cache": True,"use_recompute": False,"use_flash_attention": False,
}paddle.set_default_dtype('float16')tokenizer = LlamaTokenizer.from_pretrained(ckpt_dir)config = LlamaConfig(**config_dict)model = LlamaForCausalLM(config)
model.eval()for name, layer in model.named_sublayers():if 'rotary_emb' in name:layer.inv_freq = layer.inv_freq.cast(paddle.float32)paddle.device.cuda.empty_cache()for file_path in tqdm(glob.glob(os.path.join(ckpt_dir, pattern))):params = paddle.load(file_path)assert model.set_dict(params)[1] == [], 'Load error.'del paramspaddle.device.cuda.empty_cache()input_text = input('USER: ')
prompt = f'''USER: {input_text}\n\nASSISTANT: '''
with paddle.no_grad():with paddle.amp.auto_cast(False, level='O2', dtype='float16'):while True:if input_text == 'exit':breakinputs = tokenizer(prompt, return_tensors="pd", return_attention_mask=True,return_position_ids=True)outputs = model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, position_ids=inputs.position_ids, max_length=2048-inputs.input_ids.shape[1], min_length=0, decode_strategy="sampling",temperature=0.8, top_k=40, top_p=0.95, repetition_penalty=1.1,bos_token_id=tokenizer.bos_token_id,eos_token_id=tokenizer.eos_token_id,pad_token_id=tokenizer.pad_token_id,use_cache=True, use_fast=True, use_fp16_decoding=True)response = tokenizer.decode(outputs[0][0], skip_special_tokens=True)print('ASSISTANT: ' + response)input_text = input('USER: ')prompt += f'''{response}\n\nUSER: {input_text}\n\nASSISTANT: '''del inputsdel outputsdel responsepaddle.device.cuda.empty_cache()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/119317.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTTP协议初识·中篇

加上目录,会出现导向不正确的情况,可能是bug,目录一长就容易出错? 本篇主要讲解了: 网页分离(网页代码和.c文件分离) html链接跳转 网页添加图片 确认并返回资源类型 填写正文长度属性 添加表单 临时重定向 补充知识&a…

Jdk8 动态编译 Java 源码为 Class 文件(三)

Jdk8 动态编译 Java 源码为 Class 文件 一.JDK版本二.工程介绍1.依赖2.启动类3.配置类(用于测试依赖注入)4.工具类1.Java 源码文件读取类2.SpringBoot 容器实例管理类 5.测试类1.抽象类2.接口类3.默认抽象实现4.默认接口实现 6.接口类1.测试接口2.类重载…

SpringWeb(SpringMVC)

目录 SpringWeb介绍 搭建 SpringWeb SpringWeb介绍 Spring Web是一个基于 Servlet API 构建的原始 web 框架,用于构建基于MVC模式的Web应用程序。在 web 层框架历经 Strust1,WebWork,Strust2 等诸多产品的历代更选 之后,目前业界普…

QT DAY4

一、使用鼠标时间完成组件的移动 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include<QDebug> #include<QMouseEvent>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Widg…

LLMs:OpenAI官方重磅更新——新增GPT-3.5Turbo调和API更新功能

LLMs&#xff1a;OpenAI官方重磅更新——新增GPT-3.5Turbo调和API更新功能 导读&#xff1a;2023年8月22日&#xff0c;OpenAI官方发布&#xff0c;开发者现在可以使用自己的数据来定制适用于其用例的GPT-3.5 Turbo模型。GPT-3.5 Turbo的微调现在已经可用&#xff0c;GPT-4的微…

【算法与数据结构】106、LeetCode从中序与后序遍历序列构造二叉树

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析&#xff1a;首先我们要知道后序遍历数组的最后一个元素必然是根节点&#xff0c;然后根据根节点在中序遍历数组中的…

【LeetCode-中等题】994. 腐烂的橘子

文章目录 题目方法一&#xff1a;bfs层序遍历 题目 该题值推荐用bfs&#xff0c;因为是一层一层的感染&#xff0c;而不是一条线走到底的那种&#xff0c;所以深度优先搜索不适合 方法一&#xff1a;bfs层序遍历 广度优先搜索&#xff0c;就是从起点出发&#xff0c;每次都尝…

Android GB28181客户端开发(1):GB28181协议简介

Android GB28181客户端开发(1):GB28181协议简介 公共安全视频监控联网系统信息传输、交换、控制技术要求(2016版) 源码请翻到文章结尾 介绍GB28181协议 GB28181协议是一种基于IP网络的远程视频监控系统,它定义了设备之间的通信协议和数据格式。GB28181协议的主要特点是支…

【Rust】001-基础语法:变量声明及数据类型

【Rust】001-基础语法&#xff1a;变量声明及数据类型 文章目录 【Rust】001-基础语法&#xff1a;变量声明及数据类型一、概述1、学习起源2、依托课程 二、入门程序1、Hello World2、交互程序代码演示执行结果 3、继续上难度&#xff1a;访问链接并打印响应依赖代码执行命令 三…

Collections和CollectionUtils集合操作

0.引入依赖 <dependency><groupId>org.apache.commons</groupId><artifactId>commons-collections4</artifactId><version>4.4</version> </dependency> 一.Collections用法&#xff1a; 01、排序操作 reverse(List list)…

【摆烂之小左】Maven配置IDEA教程

Maven是什么 Maven项目对象模型(POM)&#xff0c;可以通过一小段描述信息来管理项目的构建&#xff0c;报告和文档的项目管理工具软件。 Maven 除了以程序构建能力为特色之外&#xff0c;还提供高级项目管理工具。由于 Maven 的缺省构建规则有较高的可重用性&#xff0c;所以常…

数学建模:模糊综合评价分析

&#x1f506; 文章首发于我的个人博客&#xff1a;欢迎大佬们来逛逛 数学建模&#xff1a;模糊综合评价分析 文章目录 数学建模&#xff1a;模糊综合评价分析综合评价分析常用评价方法一级模糊综合评价综合代码 多级模糊综合评价总结 综合评价分析 构成综合评价类问题的五个…

go语言基础操作--二

a : 10str : "mike"//匿名函数&#xff0c;没有函数名字 形成一个闭包,函数定义&#xff0c;还没有调用f1 : func() { //:自动推到类型fmt.Println("a ", a)fmt.Println("str ", str)}f1()//给一个函数类型起别名 这个写法不推荐type FuncType …

Flutter状态管理 — 探索Flutter中的状态

前言 随着响应式编程的理念&Flutter被大众所了解以来&#xff0c;状态管理一直是一个引人深思的话题。如果想要学习好Flutter这样的响应式的编程框架就一定是离不开状态管理的。我遇到过很多没有了解过响应式编程框架的&#xff0c;或者从事后端开发&#xff0c;自己想用F…

docker笔记4:高级复杂安装-mysql主从复制

1.主从搭建步骤 1.1新建主服务器容器实例3307 docker run -p 3307:3306 --name mysql-master \ -v /mydata/mysql-master/log:/var/log/mysql \ -v /mydata/mysql-master/data:/var/lib/mysql \ -v /mydata/mysql-master/conf:/etc/mysql \ -e MYSQL_ROOT_PASSWORDroot \ -d…

Linux的目录结构特点

Linux的目录结构特点 1、使用树形目录结构来组织和管理文件。 2、整个系统只有一个根目录&#xff08;树根&#xff09;&#xff0c;Linux的根目录用“/”表示。 3、其他所有分区以及外部设备&#xff08;如硬盘&#xff0c;光驱等&#xff09;都是以根目录为起点&#xff0…

【已解决】激活虚拟环境报错:此时不应有Anaconda3\envs\[envs]\Library\ssl\cacert.pem。

新建虚拟环境后&#xff0c;进入虚拟环境的时候出现这样的报错&#xff1a; 此时不应有Anaconda3 envs yolov5 Library ssl cacert.pem。 但是之前装的虚拟环境也还能再次激活&#xff0c;base环境也无任何问题&#xff0c;仅新装的虚拟环境无法激活。 查遍了百度谷歌&#xff…

ThreadLocal源码剖析(简单理解)

Thread部分源码 public class Thread implements Runnable {ThreadLocal.ThreadLocalMap threadLocals null; }ThreadLocal源码,其中ThreadLocal有一个静态内部类ThreadLocalMap,这个Map不是类似二叉树类型的,只是一个普通数组,其中具体使用什么算法其实我也不太理解. 然后对…

通过ref 操作dom , 点击按钮后跳转到页面指定图片位置

滚动图片到视图 定义了一个名为 scrollToIndex 的函数&#xff0c;它接受一个参数 index。当按钮被点击时&#xff0c;这个函数会被调用&#xff0c;并根据传入的 index 值来滚动到对应的图片。 以 alt 来标记图片位置 alt“Tom” import { useRef } from "react";c…

1.(python数模)单函数读取常用文件

Python单函数读取常用文件 代码如下&#xff1a; import pandas as pd# 读取数据文件 def readDataFile(readPath): # readPath: 数据文件的地址和文件名try:if (readPath[-4:] ".csv"):dfFile pd.read_csv(readPath, header0, sep",") # 间隔符为逗…