import torch# 定义获取标签掩码的函数
def _get_gt_mask(logits, target):print("原始 logits:\n", logits)print("目标 target:\n", target)# 将 target 拉平为一维张量target = target.reshape(-1)print("拉平后的 target:\n", target)# 创建一个和 logits 大小相同的全零张量,然后根据 target 将对应的类别位置设置为1mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()print("生成的标签掩码 mask:\n", mask)# 返回根据 target 设置的标签掩码return mask# 定义组合掩码的函数
def cat_mask(t, mask1, mask2):print("输入张量 t:\n", t)print("标签掩码 mask1:\n", mask1)print("非标签掩码 mask2:\n", mask2)# 计算 mask1 对应的 t 值,sum(dim=1) 表示在类别维度上进行求和t1 = (t * mask1).sum(dim=1, keepdims=True)print("标签类别的加权和 t1:\n", t1)# 计算 mask2 对应的 t 值t2 = (t * mask2).sum(1, keepdims=True)print("非标签类别的加权和 t2:\n", t2)# 将两个值拼接成新的张量rt = torch.cat([t1, t2], dim=1)print("拼接后的结果 rt:\n", rt)return rt# 示例:假设有3个样本和5个类别的logits
logits = torch.tensor([[2.0, 1.0, 0.1, 3.0, 0.5],[1.0, 3.0, 2.5, 0.5, 0.3],[0.5, 2.2, 1.1, 4.0, 1.5]])# 对应的标签 target
target = torch.tensor([3, 1, 4]) # 每个样本的正确类别是3, 1, 4# 获取标签掩码
gt_mask = _get_gt_mask(logits, target)# 获取非标签掩码
def _get_other_mask(logits, target):print("原始 logits:\n", logits)print("目标 target:\n", target)# 将 target 拉平为一维张量target = target.reshape(-1)print("拉平后的 target:\n", target)# 创建一个和 logits 大小相同的全1张量,然后根据 target 将对应的类别位置设置为0mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()print("生成的非标签掩码 mask:\n", mask)return maskother_mask = _get_other_mask(logits, target)# 假设有某些 softmax 结果
t = torch.softmax(logits, dim=1)
print("Softmax 后的 logits (概率值):\n", t)# 使用标签掩码和非标签掩码进行组合
combined = cat_mask(t, gt_mask, other_mask)
print("最终组合后的结果:\n", combined)