Simulating int8 BMM computation with a PyTorch model.
Simulated pipeline: BMM -> softmax -> BMM.
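The quantization the script simulates is symmetric per-token int8: each token slice is scaled by its own max-abs value, rounded and clamped, and each BMM result is dequantized with the product of the two scales. As a summary of what the code below implements (nothing beyond it is assumed):

$$x_q = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right),\,-128,\,127\right),\qquad s = \frac{\max|x|}{127}$$
$$QK^\top \approx \left(Q_q K_q^\top\right)\cdot s_q s_k$$

The same pattern is applied to the second BMM, using the softmax probabilities and V with their own scales.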
import torch
import numpy as np
torch.manual_seed(777)
def int8_quantize_per_token(x: torch.Tensor, axis: int = -1, attns=False):
    # Symmetric int8 quantization with one scale per token (max-abs reduced over `axis`).
    if x.dtype != torch.float32:
        x = x.type(torch.float32)
    xmax = torch.abs(x)
    xmax = torch.max(xmax, dim=axis, keepdim=True)[0]
    scale = xmax / 127.0
    if not attns:
        # scale = torch.clamp(scale, 1e-5, np.finfo(np.float32).max)
        pass
    else:
        # scale = torch.tensor(1 / 127.0, dtype=torch.float32)
        pass
    out = x / scale
    out = torch.round(out)
    out = torch.clamp(out, -128, 127)
    quantized_out = out.type(torch.int8)
    return quantized_out, scale


def int8_quantize_per_tensor(x, axis=0, attns=False):
    # Symmetric int8 quantization with a single scale per (..., -2, -1) slice,
    # i.e. per batch/head; the `axis` argument is unused.
    if x.dtype != torch.float32:
        x = x.type(torch.float32)
    xmax = torch.abs(x)
    xmax = torch.max(xmax, dim=-1, keepdim=True)[0]
    xmax = torch.max(xmax, dim=-2, keepdim=True)[0]
    scale = xmax / 127.0
    if not attns:
        # scale = torch.clamp(scale, 1e-5, np.finfo(np.float32).max)
        pass
    else:
        # scale = torch.tensor(1 / 127.0, dtype=torch.float32)
        pass
    out = x / scale
    out = torch.round(out)
    out = torch.clamp(out, -128, 127)
    quantized_out = out.type(torch.int8)
    return quantized_out, scale


def matmul_int8(key, query, value):
    # Int8 simulation of BMM -> softmax -> BMM.
    key = key.permute([0, 1, 3, 2])
    query, q_s = int8_quantize_per_token(query)
    # After the transpose, key's feature dim is -2, so reducing over -2
    # still gives one scale per token.
    key, k_s = int8_quantize_per_token(key, -2)
    attention_scores = torch.matmul(query.type(torch.float32),
                                    key.type(torch.float32))
    # Dequantize BMM1 with the product of the two per-token scales.
    scale = q_s * k_s
    attention_1 = torch.mul(attention_scores, scale)
    attention_scores = attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32))  # sqrt(head_dim)
    attention_scores = torch.softmax(attention_scores, dim=-1)
    attention_scores_int8, attn_p_s = int8_quantize_per_token(attention_scores, attns=True)
    # Value is reduced over the token dim (-2), i.e. one scale per output channel,
    # so the dequant scale broadcasts over the BMM2 result.
    value, v_s = int8_quantize_per_token(value, -2)
    context = torch.matmul(attention_scores_int8.type(torch.float32),
                           value.type(torch.float32))
    scale = attn_p_s * v_s
    context = torch.mul(context, scale)
    return attention_1, context


def matmul_fp(key, query, value):
    # Floating-point reference for the same BMM -> softmax -> BMM flow.
    key = key.permute([0, 1, 3, 2])
    attention_1 = torch.matmul(query.type(torch.float32),
                               key.type(torch.float32))
    attention_scores = attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32))
    attention_scores = torch.softmax(attention_scores, dim=-1)
    context = torch.matmul(attention_scores.type(torch.float32),
                           value.type(torch.float32))
    return attention_1, context


def mtx_similar1(arr1: np.ndarray, arr2: np.ndarray) -> float:
    '''One way to measure matrix similarity: flatten both matrices into vectors
    and divide their dot product by the product of their norms.
    Note the flattening step.
    :param arr1: matrix 1
    :param arr2: matrix 2
    :return: effectively the cosine of the angle, ret = (cos + 1) / 2
    '''
    farr1 = arr1.ravel()
    farr2 = arr2.ravel()
    len1 = len(farr1)
    len2 = len(farr2)
    if len1 > len2:
        farr1 = farr1[:len2]
    else:
        farr2 = farr2[:len1]
    numer = np.sum(farr1 * farr2)
    denom = np.sqrt(np.sum(farr1**2) * np.sum(farr2**2))
    similar = numer / denom  # this is the cosine of the angle
    return (similar + 1) / 2  # treat the cosine as roughly linear


if __name__ == "__main__":
    # key/value: [batch, heads, seq_len, head_dim]; query holds a single token.
    key = torch.randn((2, 6, 10, 32))
    value = torch.randn((2, 6, 10, 32))
    query = torch.randn((2, 6, 1, 32))
    i_key = key.clone().detach()
    i_value = value.clone().detach()
    i_query = query.clone().detach()
    fp_score, fp_context = matmul_fp(key, query, value)
    int8_score, int8_context = matmul_int8(i_key, i_query, i_value)
    similar1 = mtx_similar1(int8_score.cpu().detach().numpy(),
                            fp_score.cpu().detach().numpy())
    similar2 = mtx_similar1(int8_context.cpu().detach().numpy(),
                            fp_context.cpu().detach().numpy())
    print(similar1, similar2)
    np.testing.assert_allclose(fp_score.detach().cpu().numpy(),
                               int8_score.detach().cpu().numpy(),
                               rtol=1e-02, atol=1e-03)
    np.testing.assert_allclose(fp_context.detach().cpu().numpy(),
                               int8_context.detach().cpu().numpy(),
                               rtol=1e-02, atol=1e-03)
Conclusions:
Per-token quantization is more accurate than per-tensor (a per-tensor variant can be swapped in as sketched after these conclusions).
After both BMM1 and BMM2 are computed in fixed point (int8), the output error is relatively large.
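To reproduce the per-token vs. per-tensor comparison, a minimal sketch (not part of the original script; the name matmul_int8_per_tensor is introduced here for illustration) is to copy matmul_int8 and replace every quantization call with int8_quantize_per_tensor:

def matmul_int8_per_tensor(key, query, value):
    # Same BMM -> softmax -> BMM flow, but with one scale per (batch, head)
    # slice instead of one scale per token.
    key = key.permute([0, 1, 3, 2])
    query, q_s = int8_quantize_per_tensor(query)
    key, k_s = int8_quantize_per_tensor(key)
    scores = torch.matmul(query.type(torch.float32), key.type(torch.float32))
    attention_1 = torch.mul(scores, q_s * k_s)
    probs = torch.softmax(attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32)), dim=-1)
    probs_int8, p_s = int8_quantize_per_tensor(probs, attns=True)
    value, v_s = int8_quantize_per_tensor(value)
    context = torch.mul(torch.matmul(probs_int8.type(torch.float32),
                                     value.type(torch.float32)), p_s * v_s)
    return attention_1, context

Comparing its outputs against matmul_fp with mtx_similar1, in the same way as in the __main__ block, is one way to check the first conclusion.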