使用RoBERT进行fine tune来复现GLUE的效果

文章目录

  • 一. 参考博客or文献
  • 二. Proprocess GLUE task data
    • 2.1 下载GLUE的数据集
    • 2.2 预处理GLUE的数据集
      • 2.2.1 算法思路与整体代码以及运行结果图
      • 2.2.2 完整代码与处理结果
  • 三. 使用预处理好的数据集进行 finetune
    • 3.1 将RoBERTa的模型下载到本地
    • 3.2 微调任务之RTE(句子二分类任务)

一. 参考博客or文献

Finetuning RoBERTa on GLUE tasks

二. Proprocess GLUE task data

2.1 下载GLUE的数据集

GLUE数据集的下载链接: GLUE

import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfileTASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',"MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',"QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',"STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',"MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',"SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',"QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',"RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',"WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',"diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'def download_and_extract(task, data_dir):print("Downloading and extracting %s..." % task)data_file = "%s.zip" % taskurllib.request.urlretrieve(TASK2PATH[task], data_file)with zipfile.ZipFile(data_file) as zip_ref:zip_ref.extractall(data_dir)os.remove(data_file)print("\tCompleted!")def format_mrpc(data_dir, path_to_data):print("Processing MRPC...")mrpc_dir = os.path.join(data_dir, "MRPC")if not os.path.isdir(mrpc_dir):os.mkdir(mrpc_dir)if path_to_data:mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")else:print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_fileassert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_fileurllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))dev_ids = []with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:for row in ids_fh:dev_ids.append(row.strip().split('\t'))with open(mrpc_train_file, encoding="utf8") as data_fh, \open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:header = data_fh.readline()train_fh.write(header)dev_fh.write(header)for row in data_fh:label, id1, id2, s1, s2 = row.strip().split('\t')if [id1, id2] in dev_ids:dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))else:train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))with open(mrpc_test_file, encoding="utf8") as data_fh, \open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:header = data_fh.readline()test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")for idx, row in enumerate(data_fh):label, id1, id2, s1, s2 = row.strip().split('\t')test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))print("\tCompleted!")def download_diagnostic(data_dir):print("Downloading and extracting diagnostic...")if not os.path.isdir(os.path.join(data_dir, "diagnostic")):os.mkdir(os.path.join(data_dir, "diagnostic"))data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)print("\tCompleted!")returndef get_tasks(task_names):task_names = task_names.split(',')if "all" in task_names:tasks = TASKSelse:tasks = []for task_name in task_names:assert task_name in TASKS, "Task %s not found!" % task_nametasks.append(task_name)return tasksdef main(arguments):parser = argparse.ArgumentParser()parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',type=str, default='all')parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',type=str, default='')args = parser.parse_args(arguments)if not os.path.isdir(args.data_dir):os.mkdir(args.data_dir)tasks = get_tasks(args.tasks)for task in tasks:if task == 'MRPC':format_mrpc(args.data_dir, args.path_to_mrpc)elif task == 'diagnostic':download_diagnostic(args.data_dir)else:download_and_extract(task, args.data_dir)if __name__ == '__main__':sys.exit(main(sys.argv[1:]))

起到的作用是, 每个数据集从网址上下载下来, 存储在文件夹中.

2.2 预处理GLUE的数据集

if [[ $# -ne 2 ]]; thenecho "Run as following:"echo "./examples/roberta/preprocess_GLUE_tasks.sh <glud_data_folder> <task_name>"exit 1
fi
  • ‘$#’: 是一个特殊的变量, 它表示命令行参数的数量.
  • ‘-ne’: 是一个比较运算符, 表示不等于.
  • ‘echo “Run as following” ’: epoch是一个命令, 用于在终端输出文本, 所以这段代码表的含义是: 在终端输出这句话.
  • ‘exit 1’: exit是一个内置命令, 用于退出当前的脚本或终端会话; 1是一个退出状态码, 用于指示脚本的非正常退出, 非零的退出状态码通常表示发生了错误.
TASKS=$2 # QQP
  • ‘TASKS=’: 这是一个变量赋值的语法形式, TASKS是变量名.
  • ‘$2’: 这是一个特殊的变量, 在Bash脚本中表示命令行参数的索引位置; '$2’表示命令行参数的第二个参数.
SPLITS="train dev test"INPUT_COUNT=2if [ "$TASK" = "QQP" ]     thenINPUT_COLUMNS=( 4 5 )     # train.tsv: id, qid1, qid2, question1, question2, is_duplicateTEST_INPUT_COLUMNS=( 2 3 )     # test.tsv: id, question1, question2LABEL_COLUMN=6elif [ "$TASK" = "MNLI" ]thenSPLITS="train dev_matched dev_mismatched test_matched test_mismatched"INPUT_COLUMNS=( 9 10 )   # train.tsv: index, promptID,	pairID,	genre,	sentence1_binary_parse,	sentence2_binary_parse,	sentence1_parse,	sentence2_parse,	sentence1,	sentence2,	label1,	gold_labelTEST_INPUT_COLUMNS=( 9 10 )   DEV_LABEL_COLUMN=16      # dev.tsv: index	promptID	pairID	genre	sentence1_binary_parse	sentence2_binary_parse	sentence1_parse	sentence2_parse	sentence1	sentence2	label1	label2	label3	label4	label5	gold_labelLABEL_COLUMN=12elif [ "$TASK" = "QNLI" ]thenINPUT_COLUMNS=( 2 3 )     # train.tsv: index, question, sentence, labelTEST_INPUT_COLUMNS=( 2 3 )     LABEL_COLUMN=4elif [ "$TASK" = "MRPC" ]thenINPUT_COLUMNS=( 4 5 )     # train.txt: Quality, #1 ID, #2 ID, #1 String, #2 StringTEST_INPUT_COLUMNS=( 4 5 )LABEL_COLUMN=1elif [ "$TASK" = "RTE" ]thenINPUT_COLUMNS=( 2 3 )     # train.tsv: index, sentence1, sentence2, labelTEST_INPUT_COLUMNS=( 2 3 )LABEL_COLUMN=4elif [ "$TASK" = "STS-B" ]thenINPUT_COLUMNS=( 8 9 )     # train.tsv: index, genre, filename, year, old_index, source1, source2, sentence1, sentence2, scoreTEST_INPUT_COLUMNS=( 8 9 )LABEL_COLUMN=10# Following are single sentence tasks.elif [ "$TASK" = "SST-2" ]thenINPUT_COLUMNS=( 1 )     # train.tsv: sentece, labelTEST_INPUT_COLUMNS=( 2 )     # test.tsv: index, sentenceLABEL_COLUMN=2INPUT_COUNT=1elif [ "$TASK" = "CoLA" ]thenINPUT_COLUMNS=( 4 )     # train.tsv: gj04, 1, *, 'The gardener watered the flowers.'TEST_INPUT_COLUMNS=( 2 )     # test.tsv: index, sentenceLABEL_COLUMN=2INPUT_COUNT=1fi
  • ‘INPUT_COLUMNS’: train数据集中的features. 举例: 如果是做两个句子的关联任务, 则features有两个, 分别代表两个句子.(第几列)
  • ‘TEST_INPUT_COLUMNS’: test数据集中的features.(第几列)
  • ‘LABEL_COLUMN’: train数据集中的flag/score.(第几列)
  • ‘DEV_LABEL_COLUMN’: dev数据集中的features.(第几列)
  • ‘INPUT_COUNT’: train数据集features的列数.
  rm -rf "$TASK_DATA_FOLDER/processed"mkdir -p "$TASK_DATA_FOLDER/processed"
  • 删除该数据集文件夹下的processed文件/文件夹, 并创建该数据集目录下文件夹processed, 猜测用于存放处理后的数据.
  for SPLIT in $SPLITS     # SPLITS: 'train, dev, test' or 'train, dev_matched, dev_mismatched, test_matched, test_mismatched'do# CoLA train and dev doesn't have header.if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]]thencp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";elsetail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";fi# Remove unformatted lines from train and dev files for QQP dataset.if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]]thenawk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";elsecp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";firm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";done
  • 如果是CoLA, 且为train与dev数据集, 则直接cp到/processed/xxx.tsv.temp; 否则将/xxx.temp 的第二行开始读取到/processed/xxx.tsv.temp中. 这其中tail -n +2, 指的是从数据集的第二行开始读取数据.
  • 标准的QQP数据集一共有6列, 这里想要通过awk命令来去除掉那些非标准的数据.
  # Split into input0, input1 and labelfor SPLIT in $SPLITSdofor INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))doif [[ "$SPLIT" != test* ]]thenCOLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}elseCOLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}ficut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE";doneif [[ "$SPLIT" != test* ]]thenif [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]thencut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv"  > "$TASK_DATA_FOLDER/processed/$SPLIT.label";elsecut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";fifi

这段代码的作用是将input, label给提取出来.其中利用了shell命令中的cut命令能够指定原文件的列到输出文件中.
所有的文件被处理到目录: xxx/processed/train.label, xxx/processed/test.raw.input0

    # BPE encode.for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))doLANG="input$INPUT_TYPE"echo "BPE encoding $SPLIT/$LANG"python -m examples.roberta.multiprocessing_bpe_encoder \--encoder-json encoder.json \--vocab-bpe vocab.bpe \--inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \--outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \--workers 60 \--keep-empty;done

使用BPE编码方式, 进行编码.

  # Remove output directory.rm -rf "$TASK-bin"DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG"TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG"if [ "$TASK" = "MNLI" ]thenDEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG"TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG"fi# Run fairseq preprocessing:for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))doLANG="input$INPUT_TYPE"fairseq-preprocess \--only-source \--trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \--validpref "${DEVPREF//LANG/$LANG}" \--testpref "${TESTPREF//LANG/$LANG}" \--destdir "$TASK-bin/$LANG" \--workers 60 \--srcdict dict.txt;doneif [[ "$TASK" !=  "STS-B" ]]thenfairseq-preprocess \--only-source \--trainpref "$TASK_DATA_FOLDER/processed/train.label" \--validpref "${DEVPREF//LANG/label}" \--destdir "$TASK-bin/label" \--workers 60;else# For STS-B output range is converted to be between: [0.0, 1.0]mkdir -p "$TASK-bin/label"awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label"awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label"fi

调用fairseq-preprocess 命令来对数据集进行最后一步的处理.

2.2.1 算法思路与整体代码以及运行结果图

  • 整体思路: 将九个glue的文件处理为vector类型的可训练文件.
  • 需要的文件:
    • glue的九个数据集: CoLA, MNLI, MRPC, QNLI, QQP, RTE, SST-2, STS-B, WNLI.(第一次测试只处处理前8个数据集)
    • 撰写 preprocess_GLUE_task.sh 脚本文件.
    • 准备好 BPE encode 的单词表以及使用BPE的python文件, 其中包括 以下文件: examples.roberta.multiprocessing_bpe_encoder, encoder.json, multiprocessing_bpe_encoder.py.
  • 细节思路:
      1. 获取各个下游任务的 input_features 以及 label 所在的列.
      1. 去除每个下游任务的head, 这里对于CoLA数据集需要特殊考虑, 因为它本身文件中就没有head.
      1. 去除QQP数据集中一些 unformatted 的 lines.
      1. 利用第一步得到的列以及2,3步得到的清洗后的数据集, 直接提取出features与label.
      1. 使用BPE文件, 对features进行encoding.
      1. 使用 fairseq-preprocess进行train, dev, test数据集的制作, 生成了bin, log与idx文件, 方便后续模型的训练.
  • 实际操作的例子:
      1. 文件路径设置, 以 all task 为例子, 假设为~/LLM/GLUE/MyProcess_Glue/preprocess_GLUE_tasks.sh data ALL.
      1. 去掉所有文件的header, 以CoLA为例子, 将 data/CoLA/train.tsv -> ‘data/CoLA/processed/train_temp.tsv’
      1. 去掉QQA数据集中的unformatted的数据line, 并它以及其他数据集存储到 'data/CoLA/processed/train.tsv’中, 将’data/CoLA/processed/train_temp.tsv’删除.
      1. 提取出文件中的 features 和 label 的columns, 这里需要注意MNLI数据集的dev与train中label所在的column是不一样的, 需要分开处理, 其他的都一样; features输出到 ‘data/CoLA/processed/train_raw_input0.tsv’, ‘data/CoLA/processed/train_raw_input1.tsv’; label输出到 ‘data/CoLA/processed/train_label.tsv’;
      1. 使用bpe进行encoding, 将encoding后的文件输出为 ‘CoLA/processed/train_input0.tsv’.
      1. 使用 fairseq-preprocess命令, 将encoding好的文件作为input, 输出对应的bin, log, idx文件, 并放到 CoLA-bin文件夹下.

2.2.2 完整代码与处理结果

完整代码

#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.# This is my program to address glue dataset.# judge the line
if [[ $# -ne 2 ]]; thenecho "Run as following:"echo "~/LLM/GLUE/MyProcess_Glue/preprocess_GLUE_tasks.sh <glue_data_folder> <tssk_name>"exit 1
fi# get the path of folder
GLUE_DATA_FOLDER=$1
# get the tasks of operating
TASKS=$2# download bpe encoder.json, vocabulary and fairseq dictionary
# wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
# wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
# wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'if [ "$TASKS" = "ALL" ]
thenTASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA"
fi# the starting of preprocessing tasks
for TASK in $TASKS
doecho "Precessing $TASK"# get current task's directoryTASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK"echo "Raw data as download from glue directory: $TASK_DATA_FOLDER"# We will get three datasetSPLITS="train dev test"INPUT_COUNT=2 # default the number of input senteces# get the columns of features and label, in respect to train, dev, test.if [ "$TASK" = "QQP" ] # train,dev,testthenINPUT_COLUMNS=( 4 5 ) # id,qid1,qid2,question1,question2,is_duplicataTEST_INPUT_COLUMNS=( 2 3 ) # id,question1,question2LABEL_COLUMN=6elif [ "$TASK" = "MNLI" ] # train,test_match,test_mismatch,dev_m,dev_mthenSPLITS="train dev_matched dev_mismatched test_matched test_mismatched"INPUT_COLUMNS=( 9 10 ) # index,proptID,pairID,genre,sentence1_binary_parse,sentence2_binary_parse,sentence1_parse,sentence2_parse,sentence1,sentence2,label1,gold_labelTEST_INPUT_COLUMNS=( 9 10 ) # index,promptID,pairID,genre,sentence1_binary_parse,sentence2_binary_parse,sentence1_parse,sentence2_parse,sentence1,sentence2DEV_LABEL_COLUMN=16 # index,promptID,pairID,genre,sentence1_binary_parse,sentence2_binary_parse,sentence1_parse,sentence2_parse,sentence1,sentence2,label1,label2,label3,label4,label5,gold_labelLABEL_COLUMN=12 elif [ "$TASK" = "QNLI" ] # train,test,devthenINPUT_COLUMNS=( 2 3 ) # index,question,sentence,labelTEST_INPUT_COLUMNS=( 2 3 ) # index, question, sentenceLABEL_COLUMN=4elif [ "$TASK" = "MRPC" ] # train(dev), testthenINPUT_COLUMNS=( 4 5 ) # Quality,1ID,2ID,1String,2StringTEST_INPUT_COLUMNS=( 4 5 ) # Quality,1ID,2ID,1String,2StringLABEL_COLUMN=1elif [ "$TASK" = "RTE" ] # train,dev,testthenINPUT_COLUMNS=( 2 3 ) # index,sentence1,sentence2,labelTEST_INPUT_COLUMNS=( 2 3 ) # index,sentence1,sentence2LABEL_COLUMN=4elif [ "$TASK" = "STS-B" ] # train,dev,testthenINPUT_COLUMNS=( 8 9 ) # index,genre,filename,year,old_index,source1,source2,sentence1,sentence2,scoreTEST_INPUT_COLUMNS=( 8 9 ) # index,genre,filename,year,old_index,source1,source2,sentence1,sentence2LABEL_COLUMN=10elif [ "$TASK" = "SST-2" ] # train,dev,testthenINPUT_COLUMNS=( 1 ) # sentence,labelTEST_INPUT_COLUMNS=( 2 ) # index,sentenceLABEL_COLUMN=2INPUT_COUNT=1elif [ "$TASK" = "CoLA" ] # train(there aren't heads),dev(too),testthenINPUT_COLUMNS=( 4 ) # xxx,1,*,sentenceTEST_INPUT_COLUMNS=( 2 )LABEL_COLUMN=2INPUT_COUNT=1fi # mkdir a folder to save our new dataset processedrm -rf "$TASK_DATA_FOLDER/processed"mkdir -p "$TASK_DATA_FOLDER/processed"get the pointed columns from $TASK_DATA_FOLDER=$GLUE_DATA_FOLDER/$TASKfor SPLIT in $SPLITS do# CoLA train and dev doesn't have hdeader.if [[ ( "$TASK" = "GoLA" ) && ( "$SPLIT" != "test" ) ]]then   # CoLA's train or devcp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv";elsetail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv";fi# Remove unformatted lines from train and dev files for QQP dataset.if [[ ( "$TASK" = "QQP" ) && ( "$SPLIT" != "test" ) ]]thenawk -F '\t' -v NUM_FILELDS=6 'NF==NUM_FILELDS{print}{}' "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";elsecp "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";firm "$TASK_DATA_FOLDER/processed/${SPLIT}_temp.tsv";done# Get features and label columns, called them "input0, input1, label"for SPLIT in $SPLITSdo# Extract featuresfor INPUT_TYPE in $(seq 0 $(( INPUT_COUNT - 1 )))do   # process the train and dev dataset.if [[ "$SPLIT" != test* ]]then COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}elseCOLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}ficut -f "$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/${SPLIT}_raw_input$INPUT_TYPE.tsv";done# Extract labelsif [[ "$SPLIT" != test* ]]thenif [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]  # Only this dataset's dev have a different label column, in respect to train datasetthencut -f "$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/${SPLIT}_label.tsv";            elsecut -f "$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/${SPLIT}_label.tsv";    fifi# BPE encodefor INPUT_TYPE in $(seq 0 $(( INPUT_COUNT - 1 )))doLANG="input$INPUT_TYPE" echo "BPE encoding $SPLIT/$LANG"python -m multiprocessing_bpe_encoder \--encoder-json encoder.json \--vocab-bpe vocab.bpe \--inputs "$TASK_DATA_FOLDER/processed/${SPLIT}_raw_$LANG.tsv" \--outputs "$TASK_DATA_FOLDER/processed/${SPLIT}_$LANG.tsv" \--workers 60 \--keep-empty;donedone# Remove output directoryrm -rf "$TASK-bin"DEVPREF="$TASK_DATA_FOLDER/processed/dev_LANG"TESTPREF="$TASK_DATA_FOLDER/processed/test_LANG"if [ "$TASK" = "MNLI" ]thenDEVPREF="$TASK_DATA_FOLDER/processed/dev_matched_LANG,$TASK_DATA_FOLDER/processed/dev_mismatched_LANG"TESTPREF="$TASK_DATA_FOLDER/processed/test_matched_LANG,$TASK_DATA_FOLDER/processd/test_mismatched_LANG"fi# Run fairseq preprocessing:for INPUT_TYPE in $(seq 0 $(( INPUT_COUNT-1 )))doLANG="input$INPUT_TYPE.tsv"fairseq-preprocess \--only-source \--trainpref "$TASK_DATA_FOLDER/processed/train_$LANG" \--validpref "${DEVPREF//LANG/$LANG}" \--testpref "${TESTPREF//LANG/$LANG}" \--destdir "$TASK-bin/$LANG" \--workers 60 \--srcdict dict.txt;doneif [[ "$TASK" != "STS-B" ]]thenfairseq-preprocess \--only-source \--trainpref "$TASK_DATA_FOLDER/processed/train_label.tsv" \--validpref "${DEVPREF//LANG/label.tsv}" \--destdir "$TASK-bin/label" \--workers 60;else# For STS-B output range is converted to be between: [0.0, 1.0]awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train_label.tsv" > "$TASK-bin/label/train_label.tsv"awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev_label.tsv" > "$TASK-bin/label/valid_label.tsv" fi
done

处理结果:
在这里插入图片描述

三. 使用预处理好的数据集进行 finetune

3.1 将RoBERTa的模型下载到本地

这里我使用base模型来做例子.
在这里插入图片描述

3.2 微调任务之RTE(句子二分类任务)

  • 需要涉及到的文件
      1. Pretrain的model(.pt)文件
      1. 指令: fairseq-hydra-train
      1. 对应下游任务的yaml文件, 这里是RTE.yaml
      1. 微调的数据集文件(文件夹), RTE-bin.
      1. 指定需要存储的checkpoint路径文件.(一般与model文件是一样的).
  • shell代码:
#!/bin/bashROBERTA_PATH=/home/phac123/LLM/RoBERTa/fine_tune_demo1_RTE/base/model.ptCUDA_VISIBLE_DEVICES=1 fairseq-hydra-train --config-dir /home/phac123/LLM/RoBERTa/fine_tune_demo1_RTE/ --config-name rte \
task.data=/home/phac123/LLM/RoBERTa/fine_tune_demo1_RTE/RTE-bin checkpoint.restore_file=$ROBERTA_PATH

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/92900.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python编程需要的电脑配置,python编程对电脑的要求

大家好&#xff0c;给大家分享一下python编程用什么笔记本电脑&#xff0c;很多人还不知道这一点。下面详细解释一下。现在让我们来看看&#xff01; 不打游戏&#xff0c;只学编程。刚开始自学 Python小发猫伪原创,python下载需要花钱吗。 如果不搞机器学习的话&#xff0c;也…

TypeScript 语法

环境搭建 以javascript为基础构建的语言&#xff0c;一个js的超集&#xff0c;可以在任何支持js的平台中执行&#xff0c;ts扩展了js并且添加了类型&#xff0c;但是ts不能被js解析器直接执行&#xff0c;需要编译器编译为js文件&#xff0c;然后引入到 html 页面使用。 ts增…

Blender增强现实3D模型制作指南【AR】

推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景 将静态和动画 3D 内容集成到移动增强现实 (AR) 体验中是增强用户沉浸感和参与度的高效方法。 然而&#xff0c;为 AR 创建 3D 对象可能相当艰巨&#xff0c;尤其是对于那些缺乏 3D 建模经验的人来说。 与添加视频或照片 AR…

Mariadb高可用MHA (四十二)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 一、概述 1.1 概念 1.2 组成 1.3 特点 1.4 工作原理 二、构建MHA 2.1 ssh免密登录 2.2 主从复制 2.3 MHA安装 2.3.1所有节点安装perl环境 2.3..2 node 2.3.…

零售行业供应链管理核心KPI指标(三)

完美订单满足率和退货率 完美订单满足率有三个方面的因素影响&#xff1a;订单按时、足量、无损交货。通常情况下零售企业追求线上订单履行周期慢慢达到行业平均水平&#xff0c;就是交付的速度变快了&#xff0c;这个肯定是一件好事情&#xff0c;趋势越来越好。 同时&#…

Vim的插件管理器之Vundle

1、安装Vundle插件管理器 Vim可以安装插件&#xff0c;但是需要手动安装比较麻烦&#xff0c;Vim本身没有提供插件管理器&#xff0c;所以会有很多的第三方的插件管理器&#xff0c;有一个vim的插件叫做 “vim-easymotion”&#xff0c;在它的github的安装说明里有列出对于不同…

log4j:WARN No appenders could be found for logger问题

本文将idea场景下的使用。 IDEA中&#xff0c;将配置文件命名为log4j.properties&#xff08;该命名才会被自动加载&#xff09;&#xff0c; 并放到某个目录下&#xff08;通常放到resources目录&#xff09;&#xff0c;并在resources上右键&#xff0c;找到Mark Directory a…

Nginx转发请求到后端服务报400 Bad Request

问题描述 系统部署好后&#xff0c;进行测试时发现有部分接口出错&#xff0c;项目采用Nginx作为后端代理服务器&#xff0c;有Nginx统一将请求转发到后端的网关服务&#xff0c;再由网关服务路由到具体的服务上&#xff0c;发布好后&#xff0c;大部分接口都是正常的&#xff…

POSTGRESQL 关于2023-08-14 数据库自动启动文章中使用KILL 来进行配置RELOAD的问题解释...

开头还是介绍一下群&#xff0c;如果感兴趣Polardb ,mongodb ,MySQL ,Postgresql ,redis &#xff0c;SQL SERVER ,ORACLE,Oceanbase 等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;CTO&#xff0c;可以解决你的问题。加群请加 liuaustin3微信号 &…

常见排序集锦-C语言实现数据结构

目录 排序的概念 常见排序集锦 1.直接插入排序 2.希尔排序 3.选择排序 4.堆排序 5.冒泡排序 6.快速排序 hoare 挖坑法 前后指针法 非递归 7.归并排序 非递归 排序实现接口 算法复杂度与稳定性分析 排序的概念 排序 &#xff1a;所谓排序&#xff0c;就是使一串记录&#…

【计算机网络】13、ARP 包:广播自己的 mac 地址和 ip

机器启动时&#xff0c;会向外广播自己的 mac 地址和 ip 地址&#xff0c;这个即称为 arp 协议。范围是未经过路由器的部分&#xff0c;如下图的蓝色部分&#xff0c;范围内的设备都会在本地记录 mac 和 ip 的绑定信息&#xff0c;若有重复则覆盖更新&#xff08;例如先收到 ma…

ESP32+VSCode开发环境搭建(全网最强最终解决方案)

文章目录 1 安装步骤2 开发机器环境准备3 安装ESP-IDF-tools离线包4 创建VSCode配置文件(纯净的开发环境)5 安装espressif IDF 插件6 程序测试7 常见问题7.1环境变量设置问题&#xff1f;问题1&#xff1a;到底是设置IDF_TOOLS_PATH和IDF_PATH还是只配置一个IDF_TOOLS_PATH? 7…

Spring的简介ioc容器及注入方式

一.Spring的简介 1.Spring的特性 Spring是一个开源框架&#xff0c;它由Rod Johnson创建。它是为了解决企业应用开发的复杂性而创建的。 Spring使用基本的JavaBean来完成以前只可能由EJB完成的事情。 然而&#xff0c;Spring的用途不仅限于服务器端的开发。从简单性、可测试性…

Python文件操作与输入输出:从基础到高级应用

文章目录 &#x1f340;引言&#x1f340;文件操作基础&#x1f340;上下文管理器与文件自动关闭&#x1f340;文件的迭代与逐行读取&#x1f340;文件的其他常见操作&#x1f340;输入输出基础&#x1f340; 文件输入输出&#x1f340;格式化输出&#x1f340;高级文件操作&am…

(二)掌握最基本的Linux服务器用法——Linux下简单的C/C++ 程序、项目编译

1、静态库与动态库 静态库(Static Library)&#xff1a;静态库是编译后的库文件&#xff0c;其中的代码在编译时被链接到程序中&#xff0c;因此它会与程序一起形成一个独立的可执行文件。每个使用静态库的程序都会有自己的库的副本&#xff0c;这可能会导致内存浪费。常用后缀…

AI 绘画Stable Diffusion 研究(九)sd图生图功能详解-老照片高清修复放大

大家好&#xff0c;我是风雨无阻。 通过前面几篇文章的介绍&#xff0c;相信各位小伙伴&#xff0c;对 Stable Diffusion 这款强大的AI 绘图系统有了全新的认知。我们见识到了借助 Stable Diffusion的文生图功能&#xff0c;利用简单的几个单词&#xff0c;就可以生成完美的图片…

7-3 求给定精度的简单交错序列部分和

分数 15 全屏浏览题目 切换布局 作者 C课程组 单位 浙江大学 本题要求编写程序&#xff0c;计算序列部分和 1 - 1/4 1/7 - 1/10 ... 直到最后一项的绝对值不大于给定精度eps。 输入格式: 输入在一行中给出一个正实数eps。 输出格式: 在一行中按照“sum S”的格式输出…

.net连接mysql,提示找不到请求的 .Net Framework Data Provider。可能没有安装

开发完成的.net程序需要连接mysql数据库&#xff0c;在个人电脑上运行没问题&#xff0c;别人运行时提示“提示找不到请求的 .Net Framework Data Provider。可能没有安装”。经过查询&#xff0c;安装Connector/NET 8.1.0&#xff0c;下载地址如下所示&#xff1a; https://d…

(分治) 剑指 Offer 16. 数值的整数次方 ——【Leetcode每日一题】

❓剑指 Offer 16. 数值的整数次方 难度&#xff1a;中等 实现 pow(x, n) &#xff0c;即计算 x 的 n 次幂函数&#xff08;即&#xff0c; x n x^n xn&#xff09;。不得使用库函数&#xff0c;同时不需要考虑大数问题。 示例 1&#xff1a; 输入&#xff1a;x 2.00000, n …

腾讯云轻量服务器测评:2核 2G 4M

腾讯云轻量2核2G4M服务器&#xff0c;4M带宽下载速度可达512KB/秒&#xff0c;系统盘为50GB SSD盘&#xff0c;300GB月流量&#xff0c;地域节点可选上海、广州和北京&#xff0c;腾讯云百科分享腾讯云2核2G4M轻量应用服务器配置性能表&#xff1a; 目录 腾讯云轻量2核2G4M服…