在上一篇博客大模型之二十七-语音识别Whisper实例浅析中遗留了几个问题,这里来看一下前两个问题。
1.如果不是Huggingface上可以下载的数据该怎么办?
2.上面的代码是可以训练了,但是训练的时候loss真的会和我们预期一致吗?比如如下怎么办?
进阶内容
在Whisper语音识别fine-tune的例子中,我们使用的是Huggingface封装好的数据加载以及Transformer工具,这将很多底层细节对开发人员屏蔽了,但是对于技术人员而言,这还远远不够,本篇通过一个要解决两个问题:
1.数据集是私有的,并不是Huggingface开源的数据集
2.不使用Huggingface封装好的Training pipeline,在Whisper开源的源代码基础之上fine-tune模型,并验证准确性。
整个框架代码使用pytorch-lightning来实现,目前很多优秀的比较大的开源都是实用pytorch-lightning来实现的。
安装一些python库
首先下载Whisper源代码,并且
! pip install git+https://github.com/openai/whisper.git
! pip install jiwer
! pip install pytorch-lightning==2.4.0
! pip install -qqq evaluate==0.2.2
导入必要的python包
import os
import glob
import numpy as nptry:import tensorflow # required in Colab to avoid protobuf compatibility issues
except ImportError:passimport torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as atfrom pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLoggerfrom tqdm.notebook import tqdm
import evaluatefrom transformers import (AdamW,get_linear_schedule_with_warmup
)
遗留的第一个问题–数据集
这里的数据集基于清华大学开源的30小时中文照着文本读而录的音频,原下载地址,
为了减小资源的开销,在有限的资源下,多迭代epoch,这里对数据集做了处理:
- 将数据集缩到了10个小时/30小时,
- 去掉了txt里音素的标注,只留文本,因为在这个数据集开源的时候,那时语音识别系统还是基于音素的。
可以关注私信我,联系索取处理之后的语料。
数据集处理
import globDATASET_DIR = "/kaggle/input/th30-all"
SAMPLE_RATE = 16000
BATCH_SIZE = 4
TRAIN_RATE = 0.85#whipser的输入是30s,16kHz采样率,最长480000 sample
AUDIO_MAX_LENGTH = 480000
TEXT_MAX_LENGTH = 120DEVICE = "gpu" if torch.cuda.is_available() else "cpu"###################### 读取数据信息并分离出train和val
dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))
13388
读取数据信息并分离出train和val
dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:waveform, sr = torchaudio.load(wave_path, normalize=True)if sample_rate != sr:waveform = at.Resample(sr, sample_rate)(waveform)return waveformdef get_audio_file_list(transcripts_path_list, text_max_length=120, audio_max_sample_length=480000, sample_rate=16000):audio_transcript_pair_list = []for transcripts_path in tqdm(transcripts_path_list):# audio文件目录确认audio_dir = os.path.dirname(transcripts_path)# 从翻译文本获取音频和文本with open(transcripts_path, "r") as f:text_list = f.readlines()for text in text_list:audio_id, text = text.replace("\n", "").split(":")#print(audio_id, text)audio_path = os.path.join(audio_dir, f"{audio_id}.wav")if os.path.exists(audio_path):# 检查数据audio = load_wave(audio_path, sample_rate=sample_rate)[0]if len(text) > text_max_length or len(audio) > audio_max_sample_length:print(len(text), len(audio))continueaudio_transcript_pair_list.append((audio_id, str(audio_path), text))return audio_transcript_pair_listtrain_num = int(len(transcripts_path_list) * TRAIN_RATE)
train_transcripts_path_list, eval_transcripts_path_list = transcripts_path_list[:train_num], transcripts_path_list[train_num:]
train_audio_transcript_pair_list = get_audio_file_list(train_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
eval_audio_transcript_pair_list = get_audio_file_list(eval_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
print("TRAIN AUDIO DATASET NUM: ", len(train_audio_transcript_pair_list))
print("EVAL AUDIO DATASET NUM: ", len(eval_audio_transcript_pair_list))
133880%| | 0/11379 [00:00<?, ?it/s] 0%| | 0/2009 [00:00<?, ?it/s]
TRAIN AUDIO DATASET NUM: 11379
EVAL AUDIO DATASET NUM: 2009
Data loader
woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
wmodel = whisper.load_model(name="small",download_root="./whisper-small")
wtokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=woptions.task)class Th30Dataset(torch.utils.data.Dataset):def __init__(self, audio_info_list, tokenizer, sample_rate) -> None:super().__init__()self.audio_info_list = audio_info_listself.sample_rate = sample_rateself.tokenizer = tokenizerdef __len__(self):return len(self.audio_info_list)def __getitem__(self, index):audio_id, audio_path, text = self.audio_info_list[index]#aduio monoaudio = load_wave(audio_path, sample_rate=self.sample_rate)audio = whisper.pad_or_trim(audio.flatten(), AUDIO_MAX_LENGTH)mel = whisper.log_mel_spectrogram(audio)#texttext = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)labels = text[1:] + [self.tokenizer.eot]return {"input_ids": mel,"labels": labels,"dec_input_ids": text}
class WhisperDataCollatorWhithPadding:def __call__(self, features):input_ids, labels, dec_input_ids = [], [], []for f in features:input_ids.append(f["input_ids"])labels.append(f["labels"])dec_input_ids.append(f["dec_input_ids"])input_ids = torch.concat([input_id[None, :] for input_id in input_ids])label_lengths = [len(lab) for lab in labels]dec_input_ids_length = [len(e) for e in dec_input_ids]max_label_len = max(label_lengths + dec_input_ids_length)labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token idbatch = {"labels": labels,"dec_input_ids": dec_input_ids}batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}batch["input_ids"] = input_idsreturn batchdataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())
这是典型的Pytorch而不是前篇中Huggingface的数据加载方法,需要实现dataset
和DataLoader
,详细参考Pytorch Lightning官方文档。至此,遗留的第一个问题解决。
验证数据集加载
DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
for b in loader:print(b["labels"].shape)print(b["input_ids"].shape)print(b["dec_input_ids"].shape)for token, dec in zip(b["labels"], b["dec_input_ids"]):token[token == -100] = wtokenizer.eottext = wtokenizer.decode(token)print(text)dec[dec == -100] = wtokenizer.eottext = wtokenizer.decode(dec)print(text)break
torch.Size([2, 50])
torch.Size([2, 80, 3000])
torch.Size([2, 50])
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着
验证解码器
with torch.no_grad():audio_features = wmodel.encoder(b["input_ids"].cuda())input_ids = b["input_ids"]labels = b["labels"].long()dec_input_ids = b["dec_input_ids"].long()audio_features = wmodel.encoder(input_ids.cuda())print(dec_input_ids)print(input_ids.shape, dec_input_ids.shape, audio_features.shape)print(audio_features.shape)print()# 计算解码器的输出
out = wmodel.decoder(dec_input_ids.cuda(), audio_features)print(out.shape)
print(out.view(-1, out.size(-1)).shape)
print(b["labels"].view(-1).shape)
tensor([[50258, 50260, 50359, 50363, 45161, 11386, 47446, 5708, 5266, 104,5823, 35825, 20708, 17682, 3023, 222, 5975, 1787, 106, 44365,3919, 95, 31906, 1593, 247, 11160, 31382, 1881, 237, 31382,39861, 5155, 39152, 5648, 247, 10928, 3330, 123, 18937, 2289,5881, 249, 5419, 122, 50257, 50257, 50257, 50257, 50257, 50257],[50258, 50260, 50359, 50363, 12744, 25281, 22694, 6734, 42503, 12088,3308, 111, 36257, 4479, 223, 4035, 32045, 41111, 236, 3308,116, 7391, 102, 4479, 245, 11808, 246, 3416, 105, 33597,45506, 37960, 8501, 244, 45506, 37960, 3316, 238, 45506, 17819,99, 15789, 7732, 100, 44059, 10928, 48839, 16337, 234, 20708]])
torch.Size([2, 80, 3000]) torch.Size([2, 50]) torch.Size([2, 1500, 768])
torch.Size([2, 1500, 768])torch.Size([2, 50, 51865])
torch.Size([100, 51865])
torch.Size([100])
token转文本输出
tokens = torch.argmax(out, dim=2)
for token in tokens:token[token == -100] = wtokenizer.eottext = wtokenizer.decode(token)print(text)
<|zh|><|translate|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙的避开了矛盾<|endoftext|><|endoftext|> <|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去,定河兩旁人身顎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
构造trainer
class Config:learning_rate = 0.0001weight_decay = 0.01adam_epsilon = 1e-8warmup_steps = 2batch_size = 16num_worker = 2num_train_epochs = 1000gradient_accumulation_steps = 1sample_rate = SAMPLE_RATEclass WhisperModelModule(LightningModule):def __init__(self, cfg: Config, model_name="small", lang="zh", train_dataset=[], eval_dataset=[]) -> None:super().__init__()self.options = whisper.DecodingOptions(language=lang, without_timestamps=True)self.model = whisper.load_model(model_name)self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=self.options.task)# only decoder trainingfor p in self.model.encoder.parameters():p.requires_grad = Falseself.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)self.metrics_wer = evaluate.load("wer")self.metrics_cer = evaluate.load("cer")self.cfg = cfgself.__train_dataset = train_datasetself.__eval_dataset = eval_datasetdef forward(self, x):return self.model(x)def training_step(self, batch, batch_id):input_ids = batch["input_ids"]labels = batch["labels"].long()dec_input_ids = batch["dec_input_ids"].long()with torch.no_grad():audio_features = self.model.encoder(input_ids)out = self.model.decoder(dec_input_ids, audio_features)loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)return lossdef on_train_epoch_end(self):avg_loss = self.trainer.callback_metrics.get("train/loss")# 获取当前的 epoch 数量epoch = self.current_epochprint(f"Epoch: {epoch}, Training - Loss: {avg_loss:.4f}")def validation_step(self, batch, batch_id):input_ids = batch["input_ids"]labels = batch["labels"].long()dec_input_ids = batch["dec_input_ids"].long()audio_features = self.model.encoder(input_ids)out = self.model.decoder(dec_input_ids, audio_features)loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))out[out == -100] = self.tokenizer.eotlabels[labels == -100] = self.tokenizer.eoto_list, l_list = [], []for o, l in zip(out, labels):o = torch.argmax(o, dim=1)o_list.append(self.tokenizer.decode(o))l_list.append(self.tokenizer.decode(l))wer = self.metrics_wer.compute(references=l_list, predictions=o_list)self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)self.log("val/wer", wer, on_step=False, on_epoch=True, prog_bar=True, logger=True)# 打印到终端#print(f"Validation - Loss: {loss:.4f}, WER: {wer:.4f}")return {"wer": wer,"loss": loss}def on_validation_epoch_end(self):avg_loss = self.trainer.callback_metrics.get("val/loss")avg_wer = self.trainer.callback_metrics.get("val/wer")# 获取当前的 epoch 数量epoch = self.current_epochprint(f"Epoch: {epoch}, Validation - Loss: {avg_loss:.4f}, WER: {avg_wer:.4f}")def configure_optimizers(self):"""创建优化程序和调度器 """model = self.modelno_decay = ["bias", "LayerNorm.weight"]optimizer_grouped_parameters = [{"params": [p for n, p in model.named_parameters()if not any(nd in n for nd in no_decay)],"weight_decay": self.cfg.weight_decay,},{"params": [p for n, p in model.named_parameters()if any(nd in n for nd in no_decay)],"weight_decay": 0.0,},]optimizer = AdamW(optimizer_grouped_parameters,lr=self.cfg.learning_rate,eps=self.cfg.adam_epsilon)self.optimizer = optimizerscheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.cfg.warmup_steps,num_training_steps=self.t_total)self.scheduler = schedulerreturn [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]def setup(self, stage=None):"""初始设置(读取数据集)"""if stage == 'fit' or stage is None:self.t_total = ((len(self.__train_dataset) // (self.cfg.batch_size))// self.cfg.gradient_accumulation_steps* float(self.cfg.num_train_epochs))def train_dataloader(self):""" 创建训练数据加载程序 """dataset = Th30Dataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate)return torch.utils.data.DataLoader(dataset,batch_size=self.cfg.batch_size,drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,collate_fn=WhisperDataCollatorWhithPadding())def val_dataloader(self):""" 创建验证数据加载程序 """dataset = Th30Dataset(self.__eval_dataset, self.tokenizer, self.cfg.sample_rate)return torch.utils.data.DataLoader(dataset,batch_size=self.cfg.batch_size,num_workers=self.cfg.num_worker,collate_fn=WhisperDataCollatorWhithPadding())
主要是对LightningModule
类相关方法的重载,定义了train、validate以及optimizer的行为,以及在训练过程中日志和相关信息、checkpoint的保存。
启动训练
log_output_dir = "./logs"
check_output_dir = "./artifacts"train_name = "whisper"
train_id = "00001"model_name = "small"
lang = "zh"cfg = Config()# os.mkdir(log_output_dir)
# os.mkdir(check_output_dir)tflogger = TensorBoardLogger(save_dir=log_output_dir,name=train_name,version=train_id
)from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpointcheckpoint_callback = ModelCheckpoint(dirpath=f"{check_output_dir}/checkpoint",filename="checkpoint-{epoch:04d}",save_top_k=2, # all model savesave_on_train_epoch_end=False,monitor='val/wer', # 需要监控的验证损失mode='min', # 最小化 val_lossverbose=True # 打印更多的信息到控制台
)
callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="epoch")]
model = WhisperModelModule(cfg, model_name, lang, train_audio_transcript_pair_list, eval_audio_transcript_pair_list)trainer = Trainer(precision=16,accelerator="gpu",max_epochs=cfg.num_train_epochs,check_val_every_n_epoch=2,accumulate_grad_batches=cfg.gradient_accumulation_steps,logger=tflogger,callbacks=callback_list
)trainer.fit(model)
```shell
对于10小时数据集,你可能看到如下输出:
Epoch: 0, Training - Loss: 1.0872
Epoch: 1, Validation - Loss: 0.4207, WER: 0.9847
Epoch: 1, Training - Loss: 0.2955
Epoch: 2, Training - Loss: 0.5555
Epoch: 3, Validation - Loss: 0.2505, WER: 0.9006
Epoch: 3, Training - Loss: 0.0979
Epoch: 4, Training - Loss: 0.0602
Epoch: 5, Validation - Loss: 0.2889, WER: 0.8764
Epoch: 5, Training - Loss: 0.0721
Epoch: 6, Training - Loss: 0.0947
Epoch: 7, Validation - Loss: 0.3839, WER: 0.9809
Epoch: 7, Training - Loss: 0.1379
对于30小时数据集,即使用完整的th30数据,其中85%用于Training,而15%用于validation你可能看到如下输出:
```shell
3051.3s 121 Epoch: 1, Validation - Loss: 0.0588, WER: 0.3499
3061.1s 122 Epoch: 1, Training - Loss: 0.0340
4279.7s 123 Epoch: 2, Training - Loss: 0.0268
5691.4s 124 Epoch: 3, Validation - Loss: 0.0676, WER: 0.8318
5701.1s 125 Epoch: 3, Training - Loss: 0.0201
6919.2s 126 Epoch: 4, Training - Loss: 0.0257
8329.5s 127 Epoch: 5, Validation - Loss: 0.0484, WER: 0.8472
8329.5s 128 Epoch: 5, Training - Loss: 0.0144
9547.5s 129 Epoch: 6, Training - Loss: 0.1127
10959.3s 130 Epoch: 7, Validation - Loss: 0.0422, WER: 0.3982
10969.4s 131 Epoch: 7, Training - Loss: 0.0053
12188.2s 132 Epoch: 8, Training - Loss: 0.0076
13600.0s 133 Epoch: 9, Validation - Loss: 0.0482, WER: 0.8158
13600.0s 134 Epoch: 9, Training - Loss: 0.0126
14819.0s 135 Epoch: 10, Training - Loss: 0.0152
16230.8s 136 Epoch: 11, Validation - Loss: 0.0544, WER: 0.6829
16230.8s 137 Epoch: 11, Training - Loss: 0.0114
17450.1s 138 Epoch: 12, Training - Loss: 0.0174
18862.4s 139 Epoch: 13, Validation - Loss: 0.0523, WER: 0.3225
18872.1s 140 Epoch: 13, Training - Loss: 0.0117
20091.5s 141 Epoch: 14, Training - Loss: 0.0075
21503.2s 142 Epoch: 15, Validation - Loss: 0.0567, WER: 0.5187
21503.2s 143 Epoch: 15, Training - Loss: 0.0137
22722.5s 144 Epoch: 16, Training - Loss: 0.0150
24134.2s 145 Epoch: 17, Validation - Loss: 0.0631, WER: 0.4559
24134.2s 146 Epoch: 17, Training - Loss: 0.0122
25352.9s 147 Epoch: 18, Training - Loss: 0.0120
26765.0s 148 Epoch: 19, Validation - Loss: 0.0523, WER: 0.7387
26765.0s 149 Epoch: 19, Training - Loss: 0.0060
27983.9s 150 Epoch: 20, Training - Loss: 0.0154
29395.1s 151 Epoch: 21, Validation - Loss: 0.0520, WER: 0.4749
29395.1s 152 Epoch: 21, Training - Loss: 0.0073
30612.5s 153 Epoch: 22, Training - Loss: 0.6361
32022.4s 154 Epoch: 23, Validation - Loss: 0.0396, WER: 0.2912
32033.0s 155 Epoch: 23, Training - Loss: 0.0029
33250.8s 156 Epoch: 24, Training - Loss: 0.0036
34662.0s 157 Epoch: 25, Validation - Loss: 0.0461, WER: 0.6043
34662.0s 158 Epoch: 25, Training - Loss: 0.0094
35880.1s 159 Epoch: 26, Training - Loss: 0.0082
37291.0s 160 Epoch: 27, Validation - Loss: 0.0428, WER: 0.7481
37291.0s 161 Epoch: 27, Training - Loss: 0.0051
38509.7s 162 Epoch: 28, Training - Loss: 0.0075
39920.4s 163 Epoch: 29, Validation - Loss: 0.0447, WER: 0.8736
39920.4s 164 Epoch: 29, Training - Loss: 0.0091
41138.9s 165 Epoch: 30, Training - Loss: 0.0088
42549.9s 166 Epoch: 31, Validation - Loss: 0.0530, WER: 0.4500
42549.9s 167 Epoch: 31, Training - Loss: 0.0072
遗留的第二个问题
首先是数据集的问题,因为可以看到随着时长的增加,看到模型训练过程在符合预期方向走,
-
最低数据量:起步来说,至少需要几个小时的音频数据来进行有效的fine-tuning。例如,从10小时开始,这是一个相对较小的数据集,可以用来调试模型和流程。
-
中等数据量:为了获得更佳的效果,推荐使用20至50小时的音频数据。这可以帮助模型更好地学习到特定语言的特性。
-
理想数据量:如果资源允许,使用超过100小时的音频数据将更有助于模型性能的提升。更多的数据可以显著提高模型的泛化能力和准确性。
当然对于大模型,数据质量越高越好,数据多样性越多越好。
进一步通过tensorboard图可以看到:
在运行12个小时之后可以看到WER比一开始的确实下降了不少,但是还没有达到20%左右,最低的WER在0.2912,但是这里可以观察到一个非常有趣的现象:
在观察到训练损失(Training Loss)持续下降而验证损失(Validation Loss)和字错误率 (WER, Word Error Rate) 没有持续改善或波动较大的情况时,这通常是过拟合的一个迹象。在这种情况下,模型在训练数据上表现得越来越好,但在未见过的验证数据上的表现却没有相对应的提升,甚至出现恶化。
由于callback回调中会保持前两个在验证集上WER最小的两个checkpoint,接下来有几个思路:
1.分析模型在验证集上的错误,看是否存在特定模式或类型的错误,这可能帮助诊断问题并指导进一步模型调整,因为我们是在whipser开源的基础上fine-tune的,所以不可能简化模型结构的本身,如减少层数或神经元数目以改善过拟合。
2.可以考虑正则化技术(L2正则化、Dropout)等以有助于缓解过拟现象,增强模型的泛化能力
3.调整训练策略,调整学习率或者使用不同的优化器,以评估模型在验证集上的表现;
4.增加更多数据,帮助模型学习到更多特征,从而提高模型泛化能力
观察验证集识别效果
由于输出缩略或视觉上的相似性,一些小的差异(如标点、空白或特殊字符)可能不容易觉察。这些微小的差异在计算WER时会被考虑进去,但在人眼检查时可能会被忽略。
可以看到基本上是一致的,但是个别词是有出入的,这是因为th30是人工读的,准确性比较高,并不意味着通话、会议、游戏场景的识别率也能如此。
这里再留几个尾巴给读者自己实现:
CER
1.中文是基于字符的语言,通常我们会使用CER(Character Error Rate,字符错误率)来进行更精确的评估。然而,如果你使用的是WER来评估中文语音识别的质量,这里有几点可能需要注意:
- 在处理中文时,如果WER是基于词的,就必须先进行准确的分词。中文没有明显的词与词之间的分隔,因此分词的准确性对于WER的计算非常关键。错误的分词可能导致高WER,即使识别的字符完全正确。
- 中文中的一些微小差异,如同音字错误、词序变化或者是语气词的使用,都可以在视觉上看起来非常相似,但在WER的计算中会被视为错误。
- 你查看的样本可能并不代表整体数据集的平均表现。此外,中文语音识别可能特别擅长处理某些特定的语句或者在某些领域表现更好。
数据质量
除了数据量之外,数据的质量也非常重要:
- 多样性:数据应该涵盖多种口音、语速和语调,以及不同的背景噪声环境,这将帮助模型在各种输入条件下都能保持稳定的表现。
- 标注准确性:确保你的数据标注尽可能准确,错误的标注会直接影响模型学习的结果。
预处理和增强
- 预处理:对音频进行预处理,如采样率转换(确保和模型训练时使用的采样率一致),音量标准化等。
- 数据增强:可以考虑使用音频数据增强技术,如添加背景噪声、改变语速和音高等,以增加模型的鲁棒性。
资源和迭代
- 计算资源:fine-tuning一个语音识别模型可能需要大量的计算资源,特别是当使用大量数据时。确保你有足够的GPU资源进行训练。
- 迭代和评估:在fine-tuning过程中需要多次迭代和评估,以找到最优的模型参数和设置。
总结来说,训练的过程如炼丹,有些训练的经验是不能从小模型直接用到大模型上的。比如small和对于large-v3两种。
在模型相对较小的时候,learning rate的设置可以比较激进,但是对非常大模型的时候,较大的lr可能导致模型一开始loss就无法收敛,是发散的,但如果设置的lr比较小,那可能使得训练的时长成倍增加,怎么办呢?针对很大的模型,warm-up策略是很多时候会使用的。
load weight and inference
checkpoint_path = "whisper-checkpoint/checkpoint-epoch0023.ckpt"
state_dict = torch.load(checkpoint_path)
print(state_dict.keys())
state_dict = state_dict['state_dict']
/tmp/ipykernel_36/4099222220.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.state_dict = torch.load(checkpoint_path)
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision'])
加载模型参数
cfg = Config()
whisper_model = WhisperModelModule(cfg)
whisper_model.load_state_dict(state_dict)
100%|███████████████████████████████████████| 461M/461M [00:05<00:00, 87.6MiB/s]
Downloading builder script: 0%| | 0.00/4.49k [00:00<?, ?B/s]
Downloading builder script: 0%| | 0.00/5.60k [00:00<?, ?B/s]
<All keys matched successfully>
前向推理
woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
dataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())refs = []
res = []
for b in tqdm(loader):input_ids = b["input_ids"].half().cuda()labels = b["labels"].long().cuda()with torch.no_grad():#audio_features = whisper_model.model.encoder(input_ids)#out = whisper_model.model.decoder(enc_input_ids, audio_features)results = whisper_model.model.decode(input_ids, woptions)for r in results:res.append(r.text)for l in labels:l[l == -100] = wtokenizer.eotref = wtokenizer.decode(l)refs.append(ref)```打印推理结果```for k, v in zip(refs, res):print("-"*10)print(k)print(v)
部分输出结果
----------
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙地避开了矛盾
----------
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着
----------
<|zh|><|transcribe|><|notimestamps|>旅与游的时间比往往旅长游短与游客的愿望相悖<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
旅与游的时间比往往旅长游短与游客的愿望相悖
----------
<|zh|><|transcribe|><|notimestamps|>中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职<|endoftext|>
中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职
----------
<|zh|><|transcribe|><|notimestamps|>该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等<|endoftext|>
该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等
----------
<|zh|><|transcribe|><|notimestamps|>何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品<|endoftext|><|endoftext|>
何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品
----------
<|zh|><|transcribe|><|notimestamps|>印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明<|endoftext|>
印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明
----------
<|zh|><|transcribe|><|notimestamps|>今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下
----------
<|zh|><|transcribe|><|notimestamps|>亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军<|endoftext|>
亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军
----------
<|zh|><|transcribe|><|notimestamps|>小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤<|endoftext|><|endoftext|><|endoftext|>
小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤
----------
<|zh|><|transcribe|><|notimestamps|>这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血
----------
<|zh|><|transcribe|><|notimestamps|>驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声<|endoftext|>
驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声
----------
<|zh|><|transcribe|><|notimestamps|>其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮运而享誉海内外<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮存而享誉海内外
----------
<|zh|><|transcribe|><|notimestamps|>一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人<|endoftext|>
一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人
----------
<|zh|><|transcribe|><|notimestamps|>当船往下漂时白唇鹿扬起四蹄在岸边追随好像是送行一直跑了十几里太亲切了<|endoftext|>
当船往下漂时白唇鹿扬起四蹄在岸边追赶好像是送行一直跑了十几里太亲切了
----------
<|zh|><|transcribe|><|notimestamps|>有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额
----------
<|zh|><|transcribe|><|notimestamps|>女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍<|endoftext|>
女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍
----------
<|zh|><|transcribe|><|notimestamps|>日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌<|endoftext|><|endoftext|><|endoftext|>
日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌
----------
<|zh|><|transcribe|><|notimestamps|>如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉<|endoftext|><|endoftext|><|endoftext|>
如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉```
接下来还有三个问题对于应用更需要细致考虑:
1.Whisper除了识别,还有直接翻译功能,在以前要先识别成中文,再汉译英等,这个好处是显而易见的,首先只要一个模型,节约部分人力、机器以及服务端GPU,业务场景上可以是会议的实时翻译、看英文视频实时翻译成中文,这会减少latency,用户体验也更好;
2.如何在实时的流式场景中使用?
3.kv-caching是个什么技术?12倍是如何做到的?这在工程部署商用价值非常大。欢迎点赞、收藏、关注,以便及时收到下一篇推送。