Contents: model arch / S1 Model / S2 model
model arch
S1 model: an autoregressive (AR) model that predicts SSL semantic tokens.
S2 model: VITS-based. The SSL token sequence is already linearly correlated with the mel length; the MRTE(ssl_codes_embs, text, global_mel_emb) module strengthens the correlation with the text and learns a reference-conditioned representation.
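Before the code, the rough data flow of the two stages can be sketched as below. This is only a shape-level sketch with stub tensors; the vocabulary sizes, the 1024-d BERT feature, the reference-mel shape and the frames-per-token value are assumptions for illustration, not values taken from the repo.

# Shape-level sketch of the two-stage pipeline (stub tensors only; all sizes are assumptions).
import torch

B, T_text, T_sem = 1, 20, 100                      # batch, phoneme count, semantic-token count

# --- S1 inputs: phoneme ids plus BERT features already expanded to phoneme length ---
phoneme_ids  = torch.randint(0, 300, (B, T_text))  # hypothetical phoneme vocab size
bert_feature = torch.randn(B, 1024, T_text)        # assumed 1024-d BERT feature per phoneme

# --- S1 output: the AR model emits one SSL semantic token per step until EOS ---
semantic_ids = torch.randint(0, 1024, (B, T_sem))  # stand-in for the predicted token sequence

# --- S2 inputs: semantic tokens + text + a reference mel for the global timbre embedding ---
ref_mel = torch.randn(B, 100, 250)                 # assumed [B, n_mel, T_ref]

# Because each semantic token corresponds to a fixed number of acoustic frames,
# the output length grows linearly with T_sem (the "linear in mel length" point above).
frames_per_token = 2                               # assumed: 25Hz tokens -> 50Hz frames
print(semantic_ids.shape, "->", T_sem * frames_per_token, "acoustic frames")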
S1 Model
class Text2SemanticDecoder(nn.Module):
    def forward_old(self, x, x_lens, y, y_lens, bert_feature):
        """
        x: phoneme_ids
        y: semantic_ids
        bert_feature: already expanded via word2phn, so it has the same length as x

        train: y + EOS is the target, the length is known;
        infer: AR decoding, stop when EOS is predicted; otherwise stop at the preset max length.
        """
        # text side: phoneme embedding + projected BERT feature + positional encoding
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
        x_mask = make_pad_mask(x_lens)

        y_mask = make_pad_mask(y_lens)
        y_mask_int = y_mask.type(torch.int64)
        codes = y.type(torch.int64) * (1 - y_mask_int)

        # zero out padded positions, append EOS and build the shifted targets
        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)

        x_len = x_lens.max()
        y_len = y_lens.max()
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position(y_emb)

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
        ar_xy_padding_mask = xy_padding_mask

        # attention mask: x (text) is a fully visible prefix, y (semantic tokens) is causal
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        # merge with the padding mask, expand per attention head, convert to an additive -inf mask
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_head, -1, -1)
            .reshape(bsz * self.num_head, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        # transformer over the concatenated [text; audio] sequence; loss only on the y part
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
        loss = F.cross_entropy(logits, targets, reduction="sum")
        acc = self.ar_accuracy_metric(logits.detach(), targets).item()
        return loss, acc
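The key detail in forward_old is the combined attention mask: the phoneme prefix x attends bidirectionally to itself but never to y, while each semantic token attends to the whole text prefix and only to earlier semantic tokens. The standalone snippet below reproduces just that mask construction with tiny lengths so the layout can be printed and inspected; the lengths are arbitrary.

# Minimal reproduction of the prefix-LM attention mask built in forward_old (tiny sizes, CPU only).
import torch
import torch.nn.functional as F

x_len, y_len = 4, 3
x_attn_mask = F.pad(
    torch.zeros((x_len, x_len), dtype=torch.bool),  # x -> x: fully visible
    (0, y_len), value=True,                          # x -> y: blocked
)
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),  # y -> y: causal
    (x_len, 0), value=False,                                             # y -> x: visible
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
print(xy_attn_mask.int())  # 1 marks positions that are masked out

After logical_or with the padding mask, the True entries are turned into -inf and added to the attention scores, which is equivalent to removing them from the softmax.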
S2 model
class Encoder(nn.Module):
    def forward(self, ssl, y_lengths, text, text_lengths, speed=1, test=None):
        """
        ssl: SSL features
        y_lengths: mel lengths
        ge: reference-encoder output (global timbre embedding)
        """
        ge = self.ref_enc(y * y_mask, y_mask)
        ssl = self.ssl_proj(ssl)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
        if self.semantic_frame_rate == "25hz":
            # upsample the 25Hz quantized SSL features to the 50Hz frame rate
            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

        y = self.encoder_ssl(y * y_mask, y_mask)

        text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
        if test == 1:
            text[:, :] = 0  # test mode: drop the text input
        text = self.text_embedding(text).transpose(1, 2)
        text = self.encoder_text(text * text_mask, text_mask)

        # MRTE: fuse the SSL features with the text features and the global reference embedding
        y = self.mrte(y, y_mask, text, text_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)

        if speed != 1:
            # speed control: resample the encoder output (and its mask) along time
            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask
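Two of the F.interpolate calls above are easy to miss but do real work: the nearest-neighbour one doubles the 25Hz quantized SSL sequence to the 50Hz frame rate, and the linear one implements speed control by resampling the encoder output and its mask along time. Below is a minimal sketch of both calls; the channel and length sizes are assumptions, not values from the repo.

# Standalone illustration of the two interpolation steps (assumed sizes).
import torch
import torch.nn.functional as F

quantized = torch.randn(1, 192, 50)    # [B, C, T] quantized SSL features at 25Hz (assumed dims)
quantized_50hz = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
print(quantized_50hz.shape)            # torch.Size([1, 192, 100])

speed = 1.25                           # >1 -> shorter sequence -> faster speech
y = torch.randn(1, 192, 100)
y_mask = torch.ones(1, 1, 100)
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
print(y.shape, y_mask.shape)           # both resampled to the new length

Since the downstream modules follow y's time axis, shrinking y by a factor of speed shortens the synthesized audio accordingly.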