【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

 

目录

一、引言

1.1 往期回顾

1.2 本期概要

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

三、总结


一、引言

在朴素的深度学习ctr预估模型中(如DNN),通常以一个行为为预估目标,比如通过ctr预估点击率。但实际推荐系统业务场景中,更多是多种目标融合的结果,比如视频推荐,会存在视频点击率、视频完整播放率、视频播放时长等多个目标,而多种目标如何更好的融合,在工业界与学术界均有较多内容产出,由于该环节对实际业务影响最为直接,特开此专栏对推荐系统深度学习多目标问题进行讲述。

1.1 往期回顾

上一篇文章主要介绍了推荐系统多目标算法中的“样本Loss加权”,该方法在训练时Loss乘以样本权重实现对多种目标的加权,通过引导Loss梯度的学习方向,让模型参数朝着你设定的权重方向去学习。

1.2 本期概要

今天进一步深化,主要介绍Shared-Bottom Multi-task Model算法,该算法中文可译为“底部共享多任务模型”,该算法设定多个任务,每个任务设定多个目标,通过“Loss计算时调整每个任务的权重”,亦或是“每个塔单元内,多目标Loss计算时调整每个目标的权重”进行多任务多目标的调整。

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

Shared-Bottom Multi-task Model(SBMM)全称为底层共享多任务模型,主要由底层共享网络、多任务塔、多目标输出构成。核心原理:通过构造多任务多目标样本数据,在Loss计算环节,将各任务Loss求和(或加权求和),对Loss求导(求梯度)后,逐步后向传播迭代。

  1. 底部网络:Shared-Bottom 网络通常位于底部,可以为一个DNN网络,或者emb+pooling+mlp的方式对input输入的稀疏(sparse)特征进行稠密(dense)化。
  2. ​​​​​​​多个任务塔:底部网络上层接N个任务塔(Tower),每个塔根据需要可以定义为简单或复杂的多层感知器(mlp)网络。每个塔可以对应特定的场景,比如一二级页面场景。
  3. 多个目标:每个任务塔(Tower)可以输出多个学习目标,每个学习目标还可以像上一篇文章一样进行样本Loss加权。每个目标可以对应一种特定的指标行为,比如点击、时长、下单等。

2.2 技术优缺点

相比于上一篇文章提到的样本Loss加权融合法,以及后续文章将会介绍的MoE、MMoE方法,有如下优缺点:

优点:

  • 可以对多级场景任务进行建模,使得ctcvr等点击后转化问题可以被深度学习
  • 浅层参数共享,互相补充学习,任务相关性越高,模型的loss可以降低到更低

缺点: 

  • 跷跷板问题:任务没有好的相关性时,这种Hard parameter sharing会损害效果

2.3 业务代码实践

我们以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

跨场景多目标建模:我们定义一个SBMM算法结构,底层是一个3层的MLP(64,32,16),MLP出来后接一级场景Tower和二级场景Tower,一级场景任务中分别定义视频一级页“是否停留”、“停留时长”、“是否点击”,二级场景任务中分别定义“点击后播放时长”,“播放后是否点赞”

伪代码:

导入 pytorch 库
定义 SharedBottomMultiTaskModel 类 继承自 nn.Module:定义 __init__ 方法 参数 (self, 输入维度, 隐藏层1大小, 隐藏层2大小, 隐藏层3大小, 输出任务1维度, 输出任务2维度):初始化共享底部的三层全连接层初始化任务1的三层全连接层初始化任务2的三层全连接层定义 forward 方法 参数 (self, 输入数据):计算输入数据通过共享底部后的输出从共享底部输出分别计算任务1和任务2的结果返回任务1和任务2的结果生成虚拟样本数据:创建训练集和测试集实例化模型对象
定义损失函数和优化器
训练循环:前向传播: 获取预测值计算每个任务的损失反向传播和优化

PyTorch版本:

算法逻辑

  1. 导入必要的库。
  2. 定义一个类来表示共享底部和特定任务头部的模型结构。
  3. 在初始化方法中定义共享底部和两个独立的任务头部网络层。
  4. 实现前向传播函数,处理输入数据通过共享底部后分发到不同的任务头部。
  5. 生成虚拟样本数据。
  6. 定义损失函数和优化器。
  7. 编写训练循环。
  8. 进行模型预测。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass SharedBottomMultiTaskModel(nn.Module):def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim, output_task1_dim, output_task2_dim):super(SharedBottomMultiTaskModel, self).__init__()# 定义共享底部的三层全连接层self.shared_bottom = nn.Sequential(nn.Linear(input_dim, hidden1_dim),nn.ReLU(),nn.Linear(hidden1_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, hidden3_dim),nn.ReLU())# 定义任务1的三层全连接层self.task1_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task1_dim))# 定义任务2的三层全连接层self.task2_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task2_dim))def forward(self, x):# 计算输入数据通过共享底部后的输出shared_output = self.shared_bottom(x)# 从共享底部输出分别计算任务1和任务2的结果task1_output = self.task1_head(shared_output)task2_output = self.task2_head(shared_output)return task1_output, task2_output# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
task1_dim = 3
task2_dim = 2
num_samples = 1000
X_train = torch.randn(num_samples, input_dim)
y_train_task1 = torch.randn(num_samples, task1_dim)  # 假设任务1的输出维度为task1_dim
y_train_task2 = torch.randn(num_samples, task2_dim)  # 假设任务2的输出维度为task2_dim# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 实例化模型对象
model = SharedBottomMultiTaskModel(input_dim, 64, 32, 16, task1_dim, task2_dim)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)#print(f'loss_task1:{loss_task1},loss_task2:{loss_task2}')total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randn(1, input_dim)  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'任务1预测结果: {pred_task1}')print(f'任务2预测结果: {pred_task2}')

三、总结

本文从技术原理、技术优缺点方面对推荐系统深度学习多任务多目标“Shared-Bottom Multi-task Model”算法进行讲解,该模型使用深度学习模型对多个任务场景多个目标的业务问题进行建模,使得用户在多个场景连续性行为可以被学习,在现实推荐系统业务中是比较基础的方法,后面本专栏还会陆续介绍MoE、MMoE等多任务多目标算法,期待您的关注和支持。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/562.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分类模型为什么使用交叉熵作为损失函数

推导过程 让推理更有体感,进行下面假设: 假设要对猫、狗进行图片识别分类假设模型输出 y y y,是一个几率,表示是猫的概率 训练资料如下: x n x^n xn类别 y ^ n \widehat{y}^n y ​n x 1 x^1 x1猫1 x 2 x^2 x2猫1 x …

快速导入请求到postman

1.确定请求,右键复制为cURL(bash) 2.postman菜单栏Import-Raw text,粘贴复制的内容保存,请求添加成功

第432场周赛:跳过交替单元格的之字形遍历、机器人可以获得的最大金币数、图的最大边权的最小值、统计 K 次操作以内得到非递减子数组的数目

Q1、跳过交替单元格的之字形遍历 1、题目描述 给你一个 m x n 的二维数组 grid,数组由 正整数 组成。 你的任务是以 之字形 遍历 grid,同时跳过每个 交替 的单元格。 之字形遍历的定义如下: 从左上角的单元格 (0, 0) 开始。在当前行中向…

专题 - STM32

基础 基础知识 STM所有产品线(列举型号): STM产品的3内核架构(列举ARM芯片架构): STM32的3开发方式: STM32的5开发工具和套件: 若要在电脑上直接硬件级调试STM32设备,则…

基于Django的个性化餐饮管理系统

系统展示 用户前台界面 管理员后台界面 系统背景 该系统的研发对于餐饮行业具有重要意义。首先,通过个性化餐饮管理系统的应用,餐饮企业能够精准把握顾客需求,提供定制化服务,从而增强顾客粘性,提升顾客满意度。其次&a…

scala代码打包配置(maven)

目录 mavenpom.xml打包配置项&#xff08;非完整版&#xff0c;仅含打包的内容< build>&#xff09;pom.xml完整示例&#xff08;需要修改参数&#xff09;效果说明 maven 最主要的方式还是maven进行打包&#xff0c;也好进行配置项的管理 以下为pom文件&#xff08;不要…

plane开源的自托管项目

Plane 是一个开源的自托管项目规划解决方案&#xff0c;专注于问题管理、里程碑跟踪以及产品路线图的设计。作为一款开源软件&#xff0c;Plane 的代码托管在 GitHub 平台上&#xff0c;允许任何人查看和贡献代码。它为用户提供了便捷的项目创建与管理手段&#xff0c;并配备了…

wireshark排除私接小路由

1.wireshark打开&#xff0c;发现了可疑地址&#xff0c;合法的地址段DHCP是192.168.100.0段的&#xff0c;打开后查看发现可疑地址段&#xff0c;分别是&#xff0c;192.168.0.1 192.168.1.174 192.168.1.1。查找到它对应的MAC地址。 ip.src192.168.1.1 2.通过show fdb p…

Elasticsearch:使用 Playground 与你的 PDF 聊天

LLMs作者&#xff1a;来自 Elastic Toms Mura 了解如何将 PDF 文件上传到 Kibana 并使用 Elastic Playground 与它们交互。本博客展示了在 Playground 中与 PDF 聊天的实用示例。 Elasticsearch 8.16 具有一项新功能&#xff0c;可让你将 PDF 文件直接上传到 Kibana 并使用 Pla…

【C++】深入理解string相关函数:实现和分析

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 &#x1f4af;前言&#x1f4af;1. 使用 stoi 和 stol 函数1.1 stoi 和 stol 的基本概述参数说明进制支持示例代码与解析运行结果解析 异常处理 &#x1f4af;2. 使用 stod 和 stof 函数2.1 stod 和 stof …

“AI智能服务平台系统,让生活更便捷、更智能

大家好&#xff0c;我是资深产品经理老王&#xff0c;今天咱们来聊聊一个让生活变得越来越方便的高科技产品——AI智能服务平台系统。这个系统可是现代服务业的一颗璀璨明珠&#xff0c;它究竟有哪些魅力呢&#xff1f;下面我就跟大家伙儿闲聊一下。 一、什么是AI智能服务平台系…

回归预测 | MATLAB实MLR多元线性回归多输入单输出回归预测

回归预测 | MATLAB实MLR多元线性回归多输入单输出回归预测 目录 回归预测 | MATLAB实MLR多元线性回归多输入单输出回归预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 回归预测 | MATLAB实MLR多元线性回归多输入单输出回归预测。 程序设计 完整代码&#xff1a;回…

页面滚动下拉时,元素变为fixed浮动,上拉到顶部时恢复原状,js代码以视频示例

页面滚动下拉时,元素变为fixed浮动js代码 以视频示例 <style>video{width:100%;height:auto}.div2,#float1{position:fixed;_position:absolute;top:45px;right:0; z-index:250;}button{float:right;display:block;margin:5px} </style><section id"abou…

【Vim Masterclass 笔记09】S06L22:Vim 核心操作训练之 —— 文本的搜索、查找与替换操作(第一部分)

文章目录 S06L22 Search, Find, and Replace - Part One1 从光标位置起&#xff0c;正向定位到当前行的首个字符 b2 从光标位置起&#xff0c;反向查找某个字符3 重复上一次字符查找操作4 定位到目标字符的前一个字符5 单字符查找与 Vim 命令的组合6 跨行查找某字符串7 Vim 的增…

力扣 岛屿数量

从某个点找&#xff0c;不断找相邻位置。 题目 岛屿中被“0”隔开后 &#xff0c;是每一小块状的“1”&#xff0c;本题在问有多少块。可以用dfs进行搜索&#xff0c;遍历每一个点&#xff0c;把每一个点的上下左右做搜索检测&#xff0c;当检测到就标记为“0”表示已访问过&a…

Python学习(四)调用函数、定义函数、函数参数、递归函数

目录 一、调用函数1&#xff09;函数介绍2&#xff09;数据类型转换 二、定义函数1&#xff09;定义函数2&#xff09;空函数3&#xff09;参数检查4&#xff09;返回多个值 三、函数的参数1&#xff09;位置参数2&#xff09;默认参数3&#xff09;可变参数4&#xff09;关键字…

汽车基础软件AutoSAR自学攻略(三)-AutoSAR CP分层架构(2)

汽车基础软件AutoSAR自学攻略(三)-AutoSAR CP分层架构(2) 下面我们继续来介绍AutoSAR CP分层架构&#xff0c;下面的文字和图来自AutoSAR官网目前最新的标准R24-11的分层架构手册。该手册详细讲解了AutoSAR分层架构的设计&#xff0c;下面让我们来一起学习一下。 Introductio…

Mac——Docker desktop安装与使用教程

摘要 本文是一篇关于Mac系统下Docker Desktop安装与使用教程的博文。首先介绍连接WiFi网络&#xff0c;然后详细阐述了如何在Mac上安装Docker&#xff0c;包括下载地址以及不同芯片版本的选择。接着讲解了如何下载基础镜像和指定版本镜像&#xff0c;旨在帮助用户在Mac上高效使…

Jenkins内修改allure报告名称

背景&#xff1a; 最近使用Jenkins搭建自动化测试环境时&#xff0c;使用Jenkins的allure插件生成的报告&#xff0c;一直显示默认ALLURE REPORT&#xff0c;想自定义成与项目关联的名称&#xff0c;如图所示&#xff0c;很明显自定义名称显得高大上些&#xff0c;之前…

ROS Action接口

实现自主导航是使用Action接口的主要目的 在实际使用navigation导航系统的时候&#xff0c;机器人需要自主进行导航。不能每次都手动设置导航的目标点。所以需要编写程序代码来实现导航控制。这就需要使用到navigation的导航接口。Navigation的这个导航接口有好几个。Rose官方…