大模型之二十八-语音识别Whisper进阶

在上一篇博客大模型之二十七-语音识别Whisper实例浅析中遗留了几个问题,这里来看一下前两个问题。
1.如果不是Huggingface上可以下载的数据该怎么办?
2.上面的代码是可以训练了,但是训练的时候loss真的会和我们预期一致吗?比如如下怎么办?

进阶内容

在Whisper语音识别fine-tune的例子中,我们使用的是Huggingface封装好的数据加载以及Transformer工具,这将很多底层细节对开发人员屏蔽了,但是对于技术人员而言,这还远远不够,本篇通过一个要解决两个问题:
1.数据集是私有的,并不是Huggingface开源的数据集
2.不使用Huggingface封装好的Training pipeline,在Whisper开源的源代码基础之上fine-tune模型,并验证准确性。

整个框架代码使用pytorch-lightning来实现,目前很多优秀的比较大的开源都是实用pytorch-lightning来实现的。

安装一些python库

首先下载Whisper源代码,并且

! pip install git+https://github.com/openai/whisper.git
! pip install jiwer 
! pip install pytorch-lightning==2.4.0
! pip install -qqq evaluate==0.2.2

导入必要的python包

import os
import glob
import numpy as nptry:import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:passimport torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as atfrom pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLoggerfrom tqdm.notebook import tqdm
import evaluatefrom transformers import (AdamW,get_linear_schedule_with_warmup
)

遗留的第一个问题–数据集

这里的数据集基于清华大学开源的30小时中文照着文本读而录的音频,原下载地址,
为了减小资源的开销,在有限的资源下,多迭代epoch,这里对数据集做了处理:

  • 将数据集缩到了10个小时/30小时,
  • 去掉了txt里音素的标注,只留文本,因为在这个数据集开源的时候,那时语音识别系统还是基于音素的。

可以关注私信我,联系索取处理之后的语料。

数据集处理

import globDATASET_DIR = "/kaggle/input/th30-all"
SAMPLE_RATE = 16000
BATCH_SIZE = 4
TRAIN_RATE = 0.85#whipser的输入是30s,16kHz采样率,最长480000 sample
AUDIO_MAX_LENGTH = 480000
TEXT_MAX_LENGTH = 120DEVICE = "gpu" if torch.cuda.is_available() else "cpu"###################### 读取数据信息并分离出train和val
dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))
13388

读取数据信息并分离出train和val

dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:waveform, sr = torchaudio.load(wave_path, normalize=True)if sample_rate != sr:waveform = at.Resample(sr, sample_rate)(waveform)return waveformdef get_audio_file_list(transcripts_path_list, text_max_length=120, audio_max_sample_length=480000, sample_rate=16000):audio_transcript_pair_list = []for transcripts_path in tqdm(transcripts_path_list):# audio文件目录确认audio_dir = os.path.dirname(transcripts_path)# 从翻译文本获取音频和文本with open(transcripts_path, "r") as f:text_list = f.readlines()for text in text_list:audio_id, text = text.replace("\n", "").split(":")#print(audio_id, text)audio_path = os.path.join(audio_dir, f"{audio_id}.wav")if os.path.exists(audio_path):# 检查数据audio = load_wave(audio_path, sample_rate=sample_rate)[0]if len(text) > text_max_length or len(audio) > audio_max_sample_length:print(len(text), len(audio))continueaudio_transcript_pair_list.append((audio_id, str(audio_path), text))return audio_transcript_pair_listtrain_num = int(len(transcripts_path_list) * TRAIN_RATE)
train_transcripts_path_list, eval_transcripts_path_list = transcripts_path_list[:train_num], transcripts_path_list[train_num:]
train_audio_transcript_pair_list = get_audio_file_list(train_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
eval_audio_transcript_pair_list = get_audio_file_list(eval_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
print("TRAIN AUDIO DATASET NUM: ", len(train_audio_transcript_pair_list))
print("EVAL AUDIO DATASET NUM: ", len(eval_audio_transcript_pair_list))
133880%|          | 0/11379 [00:00<?, ?it/s]  0%|          | 0/2009 [00:00<?, ?it/s]
TRAIN AUDIO DATASET NUM:  11379
EVAL AUDIO DATASET NUM:  2009

Data loader

woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
wmodel = whisper.load_model(name="small",download_root="./whisper-small")
wtokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=woptions.task)class Th30Dataset(torch.utils.data.Dataset):def __init__(self, audio_info_list, tokenizer, sample_rate) -> None:super().__init__()self.audio_info_list = audio_info_listself.sample_rate = sample_rateself.tokenizer = tokenizerdef __len__(self):return len(self.audio_info_list)def __getitem__(self, index):audio_id, audio_path, text = self.audio_info_list[index]#aduio monoaudio = load_wave(audio_path, sample_rate=self.sample_rate)audio = whisper.pad_or_trim(audio.flatten(), AUDIO_MAX_LENGTH)mel = whisper.log_mel_spectrogram(audio)#texttext = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)labels = text[1:] + [self.tokenizer.eot]return {"input_ids": mel,"labels": labels,"dec_input_ids": text}
class WhisperDataCollatorWhithPadding:def __call__(self, features):input_ids, labels, dec_input_ids = [], [], []for f in features:input_ids.append(f["input_ids"])labels.append(f["labels"])dec_input_ids.append(f["dec_input_ids"])input_ids = torch.concat([input_id[None, :] for input_id in input_ids])label_lengths = [len(lab) for lab in labels]dec_input_ids_length = [len(e) for e in dec_input_ids]max_label_len = max(label_lengths + dec_input_ids_length)labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token idbatch = {"labels": labels,"dec_input_ids": dec_input_ids}batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}batch["input_ids"] = input_idsreturn batchdataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

这是典型的Pytorch而不是前篇中Huggingface的数据加载方法,需要实现datasetDataLoader,详细参考Pytorch Lightning官方文档。至此,遗留的第一个问题解决。

验证数据集加载

DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
for b in loader:print(b["labels"].shape)print(b["input_ids"].shape)print(b["dec_input_ids"].shape)for token, dec in zip(b["labels"], b["dec_input_ids"]):token[token == -100] = wtokenizer.eottext = wtokenizer.decode(token)print(text)dec[dec == -100] = wtokenizer.eottext = wtokenizer.decode(dec)print(text)break
torch.Size([2, 50])
torch.Size([2, 80, 3000])
torch.Size([2, 50])
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着

验证解码器

with torch.no_grad():audio_features = wmodel.encoder(b["input_ids"].cuda())input_ids = b["input_ids"]labels = b["labels"].long()dec_input_ids = b["dec_input_ids"].long()audio_features = wmodel.encoder(input_ids.cuda())print(dec_input_ids)print(input_ids.shape, dec_input_ids.shape, audio_features.shape)print(audio_features.shape)print()# 计算解码器的输出
out = wmodel.decoder(dec_input_ids.cuda(), audio_features)print(out.shape)
print(out.view(-1, out.size(-1)).shape)
print(b["labels"].view(-1).shape)
tensor([[50258, 50260, 50359, 50363, 45161, 11386, 47446,  5708,  5266,   104,5823, 35825, 20708, 17682,  3023,   222,  5975,  1787,   106, 44365,3919,    95, 31906,  1593,   247, 11160, 31382,  1881,   237, 31382,39861,  5155, 39152,  5648,   247, 10928,  3330,   123, 18937,  2289,5881,   249,  5419,   122, 50257, 50257, 50257, 50257, 50257, 50257],[50258, 50260, 50359, 50363, 12744, 25281, 22694,  6734, 42503, 12088,3308,   111, 36257,  4479,   223,  4035, 32045, 41111,   236,  3308,116,  7391,   102,  4479,   245, 11808,   246,  3416,   105, 33597,45506, 37960,  8501,   244, 45506, 37960,  3316,   238, 45506, 17819,99, 15789,  7732,   100, 44059, 10928, 48839, 16337,   234, 20708]])
torch.Size([2, 80, 3000]) torch.Size([2, 50]) torch.Size([2, 1500, 768])
torch.Size([2, 1500, 768])torch.Size([2, 50, 51865])
torch.Size([100, 51865])
torch.Size([100])

token转文本输出

tokens = torch.argmax(out, dim=2)
for token in tokens:token[token == -100] = wtokenizer.eottext = wtokenizer.decode(token)print(text)
<|zh|><|translate|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙的避开了矛盾<|endoftext|><|endoftext|> <|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去,定河兩旁人身顎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>

构造trainer

class Config:learning_rate = 0.0001weight_decay = 0.01adam_epsilon = 1e-8warmup_steps = 2batch_size = 16num_worker = 2num_train_epochs = 1000gradient_accumulation_steps = 1sample_rate = SAMPLE_RATEclass WhisperModelModule(LightningModule):def __init__(self, cfg: Config, model_name="small", lang="zh", train_dataset=[], eval_dataset=[]) -> None:super().__init__()self.options = whisper.DecodingOptions(language=lang, without_timestamps=True)self.model = whisper.load_model(model_name)self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=self.options.task)# only decoder trainingfor p in self.model.encoder.parameters():p.requires_grad = Falseself.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)self.metrics_wer = evaluate.load("wer")self.metrics_cer = evaluate.load("cer")self.cfg = cfgself.__train_dataset = train_datasetself.__eval_dataset = eval_datasetdef forward(self, x):return self.model(x)def training_step(self, batch, batch_id):input_ids = batch["input_ids"]labels = batch["labels"].long()dec_input_ids = batch["dec_input_ids"].long()with torch.no_grad():audio_features = self.model.encoder(input_ids)out = self.model.decoder(dec_input_ids, audio_features)loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))self.log("train/loss", loss, on_step=False, on_epoch=True,  prog_bar=True, logger=True)return lossdef on_train_epoch_end(self):avg_loss = self.trainer.callback_metrics.get("train/loss")# 获取当前的 epoch 数量epoch = self.current_epochprint(f"Epoch: {epoch}, Training - Loss: {avg_loss:.4f}")def validation_step(self, batch, batch_id):input_ids = batch["input_ids"]labels = batch["labels"].long()dec_input_ids = batch["dec_input_ids"].long()audio_features = self.model.encoder(input_ids)out = self.model.decoder(dec_input_ids, audio_features)loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))out[out == -100] = self.tokenizer.eotlabels[labels == -100] = self.tokenizer.eoto_list, l_list = [], []for o, l in zip(out, labels):o = torch.argmax(o, dim=1)o_list.append(self.tokenizer.decode(o))l_list.append(self.tokenizer.decode(l))wer = self.metrics_wer.compute(references=l_list, predictions=o_list)self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)self.log("val/wer", wer, on_step=False, on_epoch=True, prog_bar=True, logger=True)# 打印到终端#print(f"Validation - Loss: {loss:.4f}, WER: {wer:.4f}")return {"wer": wer,"loss": loss}def on_validation_epoch_end(self):avg_loss = self.trainer.callback_metrics.get("val/loss")avg_wer = self.trainer.callback_metrics.get("val/wer")# 获取当前的 epoch 数量epoch = self.current_epochprint(f"Epoch: {epoch}, Validation - Loss: {avg_loss:.4f}, WER: {avg_wer:.4f}")def configure_optimizers(self):"""创建优化程序和调度器 """model = self.modelno_decay = ["bias", "LayerNorm.weight"]optimizer_grouped_parameters = [{"params": [p for n, p in model.named_parameters()if not any(nd in n for nd in no_decay)],"weight_decay": self.cfg.weight_decay,},{"params": [p for n, p in model.named_parameters()if any(nd in n for nd in no_decay)],"weight_decay": 0.0,},]optimizer = AdamW(optimizer_grouped_parameters,lr=self.cfg.learning_rate,eps=self.cfg.adam_epsilon)self.optimizer = optimizerscheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.cfg.warmup_steps,num_training_steps=self.t_total)self.scheduler = schedulerreturn [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]def setup(self, stage=None):"""初始设置(读取数据集)"""if stage == 'fit' or stage is None:self.t_total = ((len(self.__train_dataset) // (self.cfg.batch_size))// self.cfg.gradient_accumulation_steps* float(self.cfg.num_train_epochs))def train_dataloader(self):""" 创建训练数据加载程序 """dataset = Th30Dataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate)return torch.utils.data.DataLoader(dataset,batch_size=self.cfg.batch_size,drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,collate_fn=WhisperDataCollatorWhithPadding())def val_dataloader(self):""" 创建验证数据加载程序 """dataset = Th30Dataset(self.__eval_dataset, self.tokenizer, self.cfg.sample_rate)return torch.utils.data.DataLoader(dataset,batch_size=self.cfg.batch_size,num_workers=self.cfg.num_worker,collate_fn=WhisperDataCollatorWhithPadding())  

主要是对LightningModule类相关方法的重载,定义了train、validate以及optimizer的行为,以及在训练过程中日志和相关信息、checkpoint的保存。

启动训练

log_output_dir = "./logs"
check_output_dir = "./artifacts"train_name = "whisper"
train_id = "00001"model_name = "small"
lang = "zh"cfg = Config()# os.mkdir(log_output_dir)
# os.mkdir(check_output_dir)tflogger = TensorBoardLogger(save_dir=log_output_dir,name=train_name,version=train_id
)from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpointcheckpoint_callback = ModelCheckpoint(dirpath=f"{check_output_dir}/checkpoint",filename="checkpoint-{epoch:04d}",save_top_k=2, # all model savesave_on_train_epoch_end=False,monitor='val/wer',  # 需要监控的验证损失mode='min',  # 最小化 val_lossverbose=True  # 打印更多的信息到控制台
)
callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="epoch")]
model = WhisperModelModule(cfg, model_name, lang, train_audio_transcript_pair_list, eval_audio_transcript_pair_list)trainer = Trainer(precision=16,accelerator="gpu",max_epochs=cfg.num_train_epochs,check_val_every_n_epoch=2,accumulate_grad_batches=cfg.gradient_accumulation_steps,logger=tflogger,callbacks=callback_list
)trainer.fit(model)
```shell
对于10小时数据集,你可能看到如下输出:

Epoch: 0, Training - Loss: 1.0872
Epoch: 1, Validation - Loss: 0.4207, WER: 0.9847
Epoch: 1, Training - Loss: 0.2955
Epoch: 2, Training - Loss: 0.5555
Epoch: 3, Validation - Loss: 0.2505, WER: 0.9006
Epoch: 3, Training - Loss: 0.0979
Epoch: 4, Training - Loss: 0.0602
Epoch: 5, Validation - Loss: 0.2889, WER: 0.8764
Epoch: 5, Training - Loss: 0.0721
Epoch: 6, Training - Loss: 0.0947
Epoch: 7, Validation - Loss: 0.3839, WER: 0.9809
Epoch: 7, Training - Loss: 0.1379

对于30小时数据集,即使用完整的th30数据,其中85%用于Training,而15%用于validation你可能看到如下输出:
```shell
3051.3s	121	Epoch: 1, Validation - Loss: 0.0588, WER: 0.3499
3061.1s	122	Epoch: 1, Training - Loss: 0.0340
4279.7s	123	Epoch: 2, Training - Loss: 0.0268
5691.4s	124	Epoch: 3, Validation - Loss: 0.0676, WER: 0.8318
5701.1s	125	Epoch: 3, Training - Loss: 0.0201
6919.2s	126	Epoch: 4, Training - Loss: 0.0257
8329.5s	127	Epoch: 5, Validation - Loss: 0.0484, WER: 0.8472
8329.5s	128	Epoch: 5, Training - Loss: 0.0144
9547.5s	129	Epoch: 6, Training - Loss: 0.1127
10959.3s	130	Epoch: 7, Validation - Loss: 0.0422, WER: 0.3982
10969.4s	131	Epoch: 7, Training - Loss: 0.0053
12188.2s	132	Epoch: 8, Training - Loss: 0.0076
13600.0s	133	Epoch: 9, Validation - Loss: 0.0482, WER: 0.8158
13600.0s	134	Epoch: 9, Training - Loss: 0.0126
14819.0s	135	Epoch: 10, Training - Loss: 0.0152
16230.8s	136	Epoch: 11, Validation - Loss: 0.0544, WER: 0.6829
16230.8s	137	Epoch: 11, Training - Loss: 0.0114
17450.1s	138	Epoch: 12, Training - Loss: 0.0174
18862.4s	139	Epoch: 13, Validation - Loss: 0.0523, WER: 0.3225
18872.1s	140	Epoch: 13, Training - Loss: 0.0117
20091.5s	141	Epoch: 14, Training - Loss: 0.0075
21503.2s	142	Epoch: 15, Validation - Loss: 0.0567, WER: 0.5187
21503.2s	143	Epoch: 15, Training - Loss: 0.0137
22722.5s	144	Epoch: 16, Training - Loss: 0.0150
24134.2s	145	Epoch: 17, Validation - Loss: 0.0631, WER: 0.4559
24134.2s	146	Epoch: 17, Training - Loss: 0.0122
25352.9s	147	Epoch: 18, Training - Loss: 0.0120
26765.0s	148	Epoch: 19, Validation - Loss: 0.0523, WER: 0.7387
26765.0s	149	Epoch: 19, Training - Loss: 0.0060
27983.9s	150	Epoch: 20, Training - Loss: 0.0154
29395.1s	151	Epoch: 21, Validation - Loss: 0.0520, WER: 0.4749
29395.1s	152	Epoch: 21, Training - Loss: 0.0073
30612.5s	153	Epoch: 22, Training - Loss: 0.6361
32022.4s	154	Epoch: 23, Validation - Loss: 0.0396, WER: 0.2912
32033.0s	155	Epoch: 23, Training - Loss: 0.0029
33250.8s	156	Epoch: 24, Training - Loss: 0.0036
34662.0s	157	Epoch: 25, Validation - Loss: 0.0461, WER: 0.6043
34662.0s	158	Epoch: 25, Training - Loss: 0.0094
35880.1s	159	Epoch: 26, Training - Loss: 0.0082
37291.0s	160	Epoch: 27, Validation - Loss: 0.0428, WER: 0.7481
37291.0s	161	Epoch: 27, Training - Loss: 0.0051
38509.7s	162	Epoch: 28, Training - Loss: 0.0075
39920.4s	163	Epoch: 29, Validation - Loss: 0.0447, WER: 0.8736
39920.4s	164	Epoch: 29, Training - Loss: 0.0091
41138.9s	165	Epoch: 30, Training - Loss: 0.0088
42549.9s	166	Epoch: 31, Validation - Loss: 0.0530, WER: 0.4500
42549.9s	167	Epoch: 31, Training - Loss: 0.0072

遗留的第二个问题

首先是数据集的问题,因为可以看到随着时长的增加,看到模型训练过程在符合预期方向走,

  1. 最低数据量:起步来说,至少需要几个小时的音频数据来进行有效的fine-tuning。例如,从10小时开始,这是一个相对较小的数据集,可以用来调试模型和流程。

  2. 中等数据量:为了获得更佳的效果,推荐使用20至50小时的音频数据。这可以帮助模型更好地学习到特定语言的特性。

  3. 理想数据量:如果资源允许,使用超过100小时的音频数据将更有助于模型性能的提升。更多的数据可以显著提高模型的泛化能力和准确性。

当然对于大模型,数据质量越高越好,数据多样性越多越好。
进一步通过tensorboard图可以看到:
请添加图片描述
在运行12个小时之后可以看到WER比一开始的确实下降了不少,但是还没有达到20%左右,最低的WER在0.2912,但是这里可以观察到一个非常有趣的现象:

在观察到训练损失(Training Loss)持续下降而验证损失(Validation Loss)和字错误率 (WER, Word Error Rate) 没有持续改善或波动较大的情况时,这通常是过拟合的一个迹象。在这种情况下,模型在训练数据上表现得越来越好,但在未见过的验证数据上的表现却没有相对应的提升,甚至出现恶化。

由于callback回调中会保持前两个在验证集上WER最小的两个checkpoint,接下来有几个思路:

1.分析模型在验证集上的错误,看是否存在特定模式或类型的错误,这可能帮助诊断问题并指导进一步模型调整,因为我们是在whipser开源的基础上fine-tune的,所以不可能简化模型结构的本身,如减少层数或神经元数目以改善过拟合。

2.可以考虑正则化技术(L2正则化、Dropout)等以有助于缓解过拟现象,增强模型的泛化能力

3.调整训练策略,调整学习率或者使用不同的优化器,以评估模型在验证集上的表现;

4.增加更多数据,帮助模型学习到更多特征,从而提高模型泛化能力

观察验证集识别效果

由于输出缩略或视觉上的相似性,一些小的差异(如标点、空白或特殊字符)可能不容易觉察。这些微小的差异在计算WER时会被考虑进去,但在人眼检查时可能会被忽略。
请添加图片描述
可以看到基本上是一致的,但是个别词是有出入的,这是因为th30是人工读的,准确性比较高,并不意味着通话、会议、游戏场景的识别率也能如此。
这里再留几个尾巴给读者自己实现:

CER

1.中文是基于字符的语言,通常我们会使用CER(Character Error Rate,字符错误率)来进行更精确的评估。然而,如果你使用的是WER来评估中文语音识别的质量,这里有几点可能需要注意:

  • 在处理中文时,如果WER是基于词的,就必须先进行准确的分词。中文没有明显的词与词之间的分隔,因此分词的准确性对于WER的计算非常关键。错误的分词可能导致高WER,即使识别的字符完全正确。
  • 中文中的一些微小差异,如同音字错误、词序变化或者是语气词的使用,都可以在视觉上看起来非常相似,但在WER的计算中会被视为错误。
  • 你查看的样本可能并不代表整体数据集的平均表现。此外,中文语音识别可能特别擅长处理某些特定的语句或者在某些领域表现更好。

数据质量

除了数据量之外,数据的质量也非常重要:

  • 多样性:数据应该涵盖多种口音、语速和语调,以及不同的背景噪声环境,这将帮助模型在各种输入条件下都能保持稳定的表现。
  • 标注准确性:确保你的数据标注尽可能准确,错误的标注会直接影响模型学习的结果。

预处理和增强

  • 预处理:对音频进行预处理,如采样率转换(确保和模型训练时使用的采样率一致),音量标准化等。
  • 数据增强:可以考虑使用音频数据增强技术,如添加背景噪声、改变语速和音高等,以增加模型的鲁棒性。

资源和迭代

  • 计算资源:fine-tuning一个语音识别模型可能需要大量的计算资源,特别是当使用大量数据时。确保你有足够的GPU资源进行训练。
  • 迭代和评估:在fine-tuning过程中需要多次迭代和评估,以找到最优的模型参数和设置。

总结来说,训练的过程如炼丹,有些训练的经验是不能从小模型直接用到大模型上的。比如small和对于large-v3两种。
在模型相对较小的时候,learning rate的设置可以比较激进,但是对非常大模型的时候,较大的lr可能导致模型一开始loss就无法收敛,是发散的,但如果设置的lr比较小,那可能使得训练的时长成倍增加,怎么办呢?针对很大的模型,warm-up策略是很多时候会使用的。

load weight and inference

checkpoint_path = "whisper-checkpoint/checkpoint-epoch0023.ckpt"
state_dict = torch.load(checkpoint_path)
print(state_dict.keys())
state_dict = state_dict['state_dict']
/tmp/ipykernel_36/4099222220.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.state_dict = torch.load(checkpoint_path)
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision'])

加载模型参数

cfg = Config()
whisper_model = WhisperModelModule(cfg)
whisper_model.load_state_dict(state_dict)
100%|███████████████████████████████████████| 461M/461M [00:05<00:00, 87.6MiB/s]
Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]
Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]
<All keys matched successfully>

前向推理

woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
dataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())refs = []
res = []
for b in tqdm(loader):input_ids = b["input_ids"].half().cuda()labels = b["labels"].long().cuda()with torch.no_grad():#audio_features = whisper_model.model.encoder(input_ids)#out = whisper_model.model.decoder(enc_input_ids, audio_features)results = whisper_model.model.decode(input_ids, woptions)for r in results:res.append(r.text)for l in labels:l[l == -100] = wtokenizer.eotref = wtokenizer.decode(l)refs.append(ref)```打印推理结果```for k, v in zip(refs, res):print("-"*10)print(k)print(v)

部分输出结果

  ----------
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙地避开了矛盾
----------
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着
----------
<|zh|><|transcribe|><|notimestamps|>旅与游的时间比往往旅长游短与游客的愿望相悖<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
旅与游的时间比往往旅长游短与游客的愿望相悖
----------
<|zh|><|transcribe|><|notimestamps|>中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职<|endoftext|>
中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职
----------
<|zh|><|transcribe|><|notimestamps|>该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等<|endoftext|>
该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等
----------
<|zh|><|transcribe|><|notimestamps|>何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品<|endoftext|><|endoftext|>
何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品
----------
<|zh|><|transcribe|><|notimestamps|>印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明<|endoftext|>
印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明
----------
<|zh|><|transcribe|><|notimestamps|>今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下
----------
<|zh|><|transcribe|><|notimestamps|>亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军<|endoftext|>
亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军
----------
<|zh|><|transcribe|><|notimestamps|>小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤<|endoftext|><|endoftext|><|endoftext|>
小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤
----------
<|zh|><|transcribe|><|notimestamps|>这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血
----------
<|zh|><|transcribe|><|notimestamps|>驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声<|endoftext|>
驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声
----------
<|zh|><|transcribe|><|notimestamps|>其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮运而享誉海内外<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮存而享誉海内外
----------
<|zh|><|transcribe|><|notimestamps|>一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人<|endoftext|>
一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人
----------
<|zh|><|transcribe|><|notimestamps|>当船往下漂时白唇鹿扬起四蹄在岸边追随好像是送行一直跑了十几里太亲切了<|endoftext|>
当船往下漂时白唇鹿扬起四蹄在岸边追赶好像是送行一直跑了十几里太亲切了
----------
<|zh|><|transcribe|><|notimestamps|>有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额
----------
<|zh|><|transcribe|><|notimestamps|>女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍<|endoftext|>
女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍
----------
<|zh|><|transcribe|><|notimestamps|>日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌<|endoftext|><|endoftext|><|endoftext|>
日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌
----------
<|zh|><|transcribe|><|notimestamps|>如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉<|endoftext|><|endoftext|><|endoftext|>
如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉```
接下来还有三个问题对于应用更需要细致考虑:
1.Whisper除了识别,还有直接翻译功能,在以前要先识别成中文,再汉译英等,这个好处是显而易见的,首先只要一个模型,节约部分人力、机器以及服务端GPU,业务场景上可以是会议的实时翻译、看英文视频实时翻译成中文,这会减少latency,用户体验也更好;
2.如何在实时的流式场景中使用?
3.kv-caching是个什么技术?12倍是如何做到的?这在工程部署商用价值非常大。欢迎点赞、收藏、关注,以便及时收到下一篇推送。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/412171.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

最新视频合成后调优技术ExVideo模型部署

ExVideo是一种新型的视频合成模型后调优技术&#xff0c;由华东师范大学和阿里巴巴的研究人员共同开发。 ExVideo提出了一种新的后调优策略&#xff0c;无需对整个模型进行大规模重训&#xff0c;仅通过对模型中时序相关组件的微调&#xff0c;就能够显著增强其生成更长视频片…

Linux 安装Mysql保姆级教程

一、检查环境 我们登录服务器&#xff0c;查看之前是否安装过mysql rpm -qa | grep mysql 由于我之前安装过&#xff0c;所以这里是有数据的 如果需要删除重新下载&#xff0c;可以使用 rpm -e mysql57-community-release-el7-10.noarch.rpm 二、安装 1、下载 接下来下载安装…

群晖(Docker Compose)配置 frp 服务

为了方便远程电脑&#xff0c;访问自己电脑上的ComfyUI等服务&#xff0c;配置了 frp 服务。 配置 frp 服务后&#xff0c;发现群晖中的一些服务也可以 stcp 安全的暴露出来。 直接在群晖通过 Docker Compose 方式部署 frps 和 frpc&#xff0c;访问者通过 frpc 安全访问暴露…

CentOS 7安装和配置 NFS

前言 NFS 是 Network File System 的缩写&#xff0c;即网络文件系统。功能是让客户端通过网络访问不同主机上磁盘里的数据&#xff0c;主要用在类 Unix 系统上实现文件共享的一种方法。本例演示 CentOS 7 下安装和配置 NFS 的基本步骤。 环境说明 CentOS 7&#xff08;Mini…

光学涡旋Talbot阵列照明器的matlab模拟与仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 光学涡旋 Talbot 阵列照明器是一种利用光学涡旋&#xff08;Optical Vortex&#xff09;和 Talbot 效应&#xff08;Talbot Effect&#xff09;相结合的技术&…

LVS部署——DR集群

目录 一、LVS—DR工作原理 二、LVS-DR数据流向 三、LVS-DR模式特点和优缺点 3.1、特点 3.2、优缺点 四、LVS-DR中的ARP问题 4.1、IP地址冲突 4.2、第二次访问请求失败 五、部署LVS-DR集群 5.1、实验准备 5.2、配置负载调度器&#xff08;192.168.20.15&#xff09; …

SpringBoot2:学SpringBoot前的知识准备-用IDEA创建传统的webapp工程,并整合SpringMVC

1、IDEA创建工程 基于Maven模板创建的SpringMVC工程 工程创建好后&#xff0c;只有webapp目录 这里&#xff0c;我们需要手动创建java目录和resources配置文件目录 创建好后&#xff0c;配置下目录属性 最终结构 至此&#xff0c;工程就创建好了 2、配置Tomcat 参考&am…

【Tesla FSD V12的前世今生】从模块化设计到端到端自动驾驶技术的跃迁

自动驾驶技术的发展一直是全球汽车行业的焦点&#xff0c;Tesla的Full-Self Driving&#xff08;FSD&#xff09;系统凭借其持续的技术革新和强大的数据支持&#xff0c;在这个领域独占鳌头。本文将深入介绍Tesla FSD V12的演进历史&#xff0c;从自动驾驶的基础概念入手&#…

机器学习 之 决策树与随机森林的实现

引言 随着互联网技术的发展&#xff0c;垃圾邮件过滤已成为一项重要的任务。机器学习技术&#xff0c;尤其是决策树和随机森林&#xff0c;在解决这类问题时表现出色。本文将介绍随机森林的基本概念&#xff0c;并通过一个具体的案例——筛选垃圾电子邮件——来展示随机森林的…

【OpenGL】xcode+glfw画三角形

环境搭建 1. 执行brew install glfw 2. 项目中Build Settings中header Search Paths中添加glfw的include路径 3. 项目中Build Phases中的Link Binary With Libraries中添加glfw的lib文件&#xff08;路径/opt/homebrew/Cellar/glfw/3.4/lib/libglfw.3.4.dylib&#xff09;及…

数据结构之内核链表,栈,队列

今天主要学习了内核链表&#xff0c;顺序栈&#xff0c;链式栈&#xff0c;顺序队列&#xff0c;链式队列的相关内容。 一.内核链表 内核链表和之前的单向&#xff0c;双向链表有所不同的是内核链表的结构是数据包含节点&#xff0c;特点如下&#xff1a; 1.一种链表结构能够操…

谷歌的 GameNGen:无需游戏引擎,人工智能模拟 “毁灭战士“,开辟新天地

谷歌公司的研究人员创建了一个神经网络&#xff0c;可以在不使用传统游戏引擎的情况下生成经典射击游戏《毁灭战士》的实时游戏&#xff0c;从而实现了人工智能领域的一个重要里程碑。这个名为 GameNGen 的系统标志着人工智能向前迈出了重要一步&#xff0c;它能在单芯片上以每…

c语言(二叉树)

第4章 二叉树和BST 树与二叉树 基本概念 树是一种非线性结构&#xff0c;其严格的数学定义是&#xff1a;如果一组数据中除了第一个节点&#xff08;第一个节点称为根节点&#xff0c;没有直接前驱节点&#xff09;之外&#xff0c;其余任意节点有且仅有一个直接前驱&#xff…

Python相关系数导图

&#x1f3af;要点 量化变量和特征关联绘图对比皮尔逊相关系数、斯皮尔曼氏秩和肯德尔秩汽车性价比相关性矩阵热图大流行病与资产波动城镇化模型预测交通量宝可梦类别特征非线性依赖性捕捉向量加权皮尔逊相关系数量化图像相似性 Python皮尔逊-斯皮尔曼-肯德尔 皮尔逊相关系…

Node.js原生开发脚手架工具(下)

前言 在现代软件开发中&#xff0c;脚手架工具成为提高开发效率和一致性的关键利器。使用Node.js原生开发自己的脚手架工具不仅能帮助自动化常见任务&#xff0c;还能根据具体需求进行高度定制。Node.js的异步非阻塞特性和丰富的模块系统使其成为构建这种工具的理想选择。本篇文…

★ 算法OJ题 ★ 力扣202 - 快乐数

Ciallo&#xff5e;(∠・ω< )⌒☆ ~ 今天&#xff0c;我将和大家一起做一道双指针算法题--快乐数~ 目录 一 题目 二 算法解析 三 编写算法 一 题目 202. 快乐数 - 力扣&#xff08;LeetCode&#xff09; 二 算法解析 题⽬告诉我们&#xff0c;当我们不断重复操作…

Java设计模式之外观模式详细讲解和案例示范

1. 引言 在软件开发过程中&#xff0c;复杂的系统往往包含许多子系统和模块&#xff0c;随着系统功能的增加&#xff0c;模块之间的交互也变得更加复杂。这种复杂性可能会导致系统的可维护性和扩展性降低。外观模式&#xff08;Facade Pattern&#xff09;是一种结构型设计模式…

java同步概念

同步&#xff08;Synchronization&#xff09;在Java多线程编程中是一个既重要又复杂的概念。它涉及到如何确保多个线程在访问共享资源时能够保持数据的一致性和完整性&#xff0c;避免出现竞态条件&#xff08;Race Condition&#xff09;等问题。 同步的基本概念 同步的主要目…

深入解析体育馆蓝牙导航系统的技术实现与应用

技术爱好者与开发者们&#xff0c;您是否在大型体育馆内常常为找不到洗手间、休息区或观赛区而烦恼&#xff1f;随着科技的进步&#xff0c;我们团队倾力打造了体育馆蓝牙导航系统&#xff0c;专为解决这一痛点而生。本系统利用先进的蓝牙信标技术和精准的室内定位算法&#xf…

YOLO | YOLO目标检测算法(YOLO-V1)

github&#xff1a;https://github.com/MichaelBeechan CSDN&#xff1a;https://blog.csdn.net/u011344545 YOLO目标检测算法 YOLO V1概述&#xff08;2016&#xff09; YOLO V1概述&#xff08;2016&#xff09; 经典的One-stage方法 YOLO&#xff1a;You Only Look Once 把…