特征交叉-CAN学习笔记代码解读

一 核心模块coaction

  1. 对于每个特征对(feature_pairs)
  2. weight, bias 来自于P_induction
  3. P_fead是MLP的input

举个例子:如果是用户ID和产品ID的co-action,且产品ID是做induction,用户ID是做feed。

  • step1 用户ID/产品ID都先形成一个向量:对于产品ID,用parameter lookup获取一个可学习的P_induction(这个维度是(wi+bi) * L depth of mlp); 用户ID则直接形成一个向量P_fead
  • step2 P_induction 这个向量逐层(MLP层),reshape成MLP网络的weight 和bias;
  • step3 weight和bias作为MLP的参数,利用P_feed 作为input,进行MLP前向运算,得到特征交互结果
  1. 代码解读
#### CAN config #####
weight_emb_w = [[16, 8], [8,4]] # micro-mlp的参数dimension
weight_emb_b = [0, 0]           # bias参数
orders = 3  # 特征的阶数,文章提到了,要做高阶特征交叉,直接是P_feed^c, c就是阶数
order_indep = False # True
WEIGHT_EMB_DIM = (sum([w[0]*w[1] for w in weight_emb_w]) + sum(weight_emb_b)) # * orders 这个是供每一个micro-mlp拆解w&b需要的dimension总和
INDEP_NUM = 1
if order_indep:INDEP_NUM *= orders
###### 这一部分对应图中绿色和橙色部分,主要是把P_feed&P_induction的嵌入表示得到 ##########
if self.use_coaction:# batch_ph batch输入的数据;his_batch_ph历史批次数据; his_batch_embedded 历史嵌入表示ph_dict = {"item": [self.mid_batch_ph, self.mid_his_batch_ph, self.mid_his_batch_embedded],"cate": [self.cate_batch_ph, self.cate_his_batch_ph, self.cate_his_batch_embedded]}### p_induction ####self.mlp_batch_embedded = [] # induction embeddingwith tf.device(device):# 定义可训练的嵌入矩阵,在这里n_mid是item id的数量self.item_mlp_embeddings_var = tf.get_variable("item_mlp_embedding_var", [n_mid, INDEP_NUM * WEIGHT_EMB_DIM], trainable=True)self.cate_mlp_embeddings_var = tf.get_variable("cate_mlp_embedding_var", [n_cate, INDEP_NUM * WEIGHT_EMB_DIM], trainable=True)# 通过embedding_lookup在上一步初始化好的矩阵中找到对应的embedding表示self.mlp_batch_embedded.append(tf.nn.embedding_lookup(self.item_mlp_embeddings_var, ph_dict['item'][0]))self.mlp_batch_embedded.append(tf.nn.embedding_lookup(self.cate_mlp_embeddings_var, ph_dict['cate'][0]))#########P_feed input ########self.input_batch_embedded = []self.item_input_embeddings_var = tf.get_variable("item_input_embedding_var", [n_mid, weight_emb_w[0][0] * INDEP_NUM], trainable=True)self.cate_input_embeddings_var = tf.get_variable("cate_input_embedding_var", [n_cate, weight_emb_w[0][0] * INDEP_NUM], trainable=True)  self.input_batch_embedded.append(tf.nn.embedding_lookup(self.item_input_embeddings_var, ph_dict['item'][1]))self.input_batch_embedded.append(tf.nn.embedding_lookup(self.cate_input_embeddings_var, ph_dict['cate'][1]))
################这一部分是P_induction&P_feed在MLP的使用#######################
if self.use_coaction:# p_feed/inputinput_batch = self.input_batch_embeddedtmp_sum, tmp_seq = [], []if INDEP_NUM == 2:# 文章说明了是feature pairs,mlp_batch&input_batch都包含了两个部分,要分别组合for i, mlp_batch in enumerate(self.mlp_batch_embedded):for j, input_batch in enumerate(self.input_batch_embedded):coaction_sum, coaction_seq = gen_coaction(mlp_batch[:, WEIGHT_EMB_DIM * j:  WEIGHT_EMB_DIM * (j+1)], input_batch[:, :, weight_emb_w[0][0] * i: weight_emb_w[0][0] * (i+1)],  EMBEDDING_DIM, mode=CALC_MODE,mask=self.mask) tmp_sum.append(coaction_sum)tmp_seq.append(coaction_seq)else:for i, (mlp_batch, input_batch) in enumerate(zip(self.mlp_batch_embedded, self.input_batch_embedded)):coaction_sum, coaction_seq = gen_coaction(mlp_batch[:, :INDEP_NUM * WEIGHT_EMB_DIM], input_batch[:, :, :weight_emb_w[0][0]],  EMBEDDING_DIM, mode=CALC_MODE, mask=self.mask) tmp_sum.append(coaction_sum)tmp_seq.append(coaction_seq)self.coaction_sum = tf.concat(tmp_sum, axis=1) # sum poolingself.cross.append(self.coaction_sum)   # concat              
###### core interaction 核心运算 #########
def gen_coaction(ad, his_items, dim, mode="can", mask=None):"""ad: inducthis_items 待交互seq"""weight, bias = [], []idx = 0weight_orders = []bias_orders = []# 拆解得到weight&bias参数for i in range(orders):for w, b in zip(weight_emb_w, weight_emb_b):weight.append(tf.reshape(ad[:, idx:idx+w[0]*w[1]], [-1, w[0], w[1]]))idx += w[0] * w[1]if b == 0:bias.append(None)else:bias.append(tf.reshape(ad[:, idx:idx+b], [-1, 1, b]))idx += bweight_orders.append(weight)bias_orders.append(bias)if not order_indep:breakif mode == "can":out_seq = []hh = []# 高阶特征处理,explicit deal withfor i in range(orders):hh.append(his_items**(i+1))#hh = [sum(hh)]for i, h in enumerate(hh):if order_indep:weight, bias = weight_orders[i], bias_orders[i]else:weight, bias = weight_orders[0], bias_orders[0]# 模拟MLP forward calculationfor j, (w, b) in enumerate(zip(weight, bias)):h  = tf.matmul(h, w)if b is not None:h = h + bif j != len(weight)-1:h = tf.nn.tanh(h)out_seq.append(h)out_seq = tf.concat(out_seq, 2)if mask is not None:mask = tf.expand_dims(mask, axis=-1) out_seq = out_seq * mask# 序列交互结果做sum_poolingout = tf.reduce_sum(out_seq, 1)if keep_fake_carte_seq and mode=="emb":return out, out_seqreturn out, None

二 文章中的应用
整体的模型结构两部分构成:

  • co-action作为核心形成的一部分,对于用户的序列特征,一一作用后做sum-pooling,对于非序列特征,作用后直接输出
  • DIEN作为核心形成的一部分

两部分concat以后加一个DNN常规操作,看起来就像是用co-action做显式的特征交叉,然后DIEN做之前的序列建模。
在这里插入图片描述
三 一些其他细节补充

  1. can 部分高阶特征处理: 直接把待交叉特征p_fead 做c阶运算后,再与p_induction进行作用
  2. 在文章场景,p_induction是target_item,也就是产品

四 用tf2/torch重构

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fclass CAN_Model(nn.Module):def __init__(self, n_uid, n_mid, n_cate, n_carte, EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE, use_negsampling=False, use_softmax=True, use_coaction=False, use_cartes=False):super(CAN_Model, self).__init__()self.n_uid = n_uidself.n_mid = n_midself.n_cate = n_cateself.n_carte = n_carteself.EMBEDDING_DIM = EMBEDDING_DIMself.HIDDEN_SIZE = HIDDEN_SIZEself.ATTENTION_SIZE = ATTENTION_SIZEself.use_negsampling = use_negsamplingself.use_softmax = use_softmaxself.use_coaction = use_coactionself.use_cartes = use_cartesself.uid_embeddings = nn.Embedding(n_uid, EMBEDDING_DIM)self.mid_embeddings = nn.Embedding(n_mid, EMBEDDING_DIM)self.cate_embeddings = nn.Embedding(n_cate, EMBEDDING_DIM)if use_cartes:self.carte_embeddings = nn.ModuleList([nn.Embedding(num, EMBEDDING_DIM) for num in n_carte])if self.use_coaction:self.item_mlp_embeddings = nn.Parameter(torch.randn(n_mid, INDEP_NUM * WEIGHT_EMB_DIM))self.cate_mlp_embeddings = nn.Parameter(torch.randn(n_cate, INDEP_NUM * WEIGHT_EMB_DIM))self.input_batch_embeddings = nn.ModuleList([nn.Embedding(n_mid, weight_emb_w[0][0] * INDEP_NUM), nn.Embedding(n_cate, weight_emb_w[0][0] * INDEP_NUM)])self.fc1 = nn.Linear(200, 80)self.fc2 = nn.Linear(80, 2 if use_softmax else 1)def forward(self, uid, mid, cate, mid_his, cate_his, mask, target, seq_len, lr, carte=None):# Embedding lookupsuid_emb = self.uid_embeddings(uid)mid_emb = self.mid_embeddings(mid)cate_emb = self.cate_embeddings(cate)mid_his_emb = self.mid_embeddings(mid_his)cate_his_emb = self.cate_embeddings(cate_his)if self.use_cartes:carte_emb = [emb(carte[:, i, :]) for i, emb in enumerate(self.carte_embeddings)]# Co-action logic (if enabled)if self.use_coaction:# This is a simplified version of the co-action implementation from the original TensorFlow codemlp_embedded_item = self.item_mlp_embeddings[mid]mlp_embedded_cate = self.cate_mlp_embeddings[cate]input_embedded_item = self.input_batch_embeddings[0](mid_his)input_embedded_cate = self.input_batch_embeddings[1](cate_his)# Further coaction operations can be added based on your logic# Concatenate item and category embeddingsitem_eb = torch.cat([mid_emb, cate_emb], dim=1)item_his_eb = torch.cat([mid_his_emb, cate_his_emb], dim=2)item_his_eb_sum = item_his_eb.sum(dim=1)if self.use_negsampling:# Assuming the negative sampling implementation would need its own logic.pass# FC layersx = self.fc1(item_eb)x = F.relu(x)x = self.fc2(x)# Loss computationif self.use_softmax:y_hat = F.softmax(x, dim=-1)loss = F.cross_entropy(y_hat, target)else:y_hat = torch.sigmoid(x)loss = F.binary_cross_entropy_with_logits(x, target)return loss, y_hatdef auxiliary_loss(self, h_states, click_seq, noclick_seq, mask):mask = mask.float()click_input = torch.cat([h_states, click_seq], dim=-1)noclick_input = torch.cat([h_states, noclick_seq], dim=-1)click_prop = self.auxiliary_net(click_input)[:, :, 0]noclick_prop = self.auxiliary_net(noclick_input)[:, :, 0]click_loss = -torch.log(click_prop) * masknoclick_loss = -torch.log(1.0 - noclick_prop) * maskloss = (click_loss + noclick_loss).mean()return lossdef auxiliary_net(self, in_):x = F.relu(self.fc1(in_))x = F.relu(self.fc2(x))return xdef train_step(self, data, optimizer):optimizer.zero_grad()loss, y_hat = self(data)loss.backward()optimizer.step()return loss.item()def evaluate(self, data):with torch.no_grad():loss, y_hat = self(data)return loss.item(), y_hat# Example of using the model
n_uid = 1000
n_mid = 1000
n_cate = 500
n_carte = [10, 20]  # Example carte sizes
EMBEDDING_DIM = 128
HIDDEN_SIZE = 256
ATTENTION_SIZE = 128model = CAN_Model(n_uid, n_mid, n_cate, n_carte, EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE)# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)# Example data
uid = torch.randint(0, n_uid, (32,))
mid = torch.randint(0, n_mid, (32,))
cate = torch.randint(0, n_cate, (32,))
mid_his = torch.randint(0, n_mid, (32, 5))
cate_his = torch.randint(0, n_cate, (32, 5))
mask = torch.ones(32, 5)
target = torch.randint(0, 2, (32,))
seq_len = torch.randint(1, 5, (32,))
lr = 0.001# Training step
loss = model.train_step((uid, mid, cate, mid_his, cate_his, mask, target, seq_len, lr), optimizer)
print(f"Loss: {loss}")

Reference:

  1. 文章形成思路历程
  2. CAN: Feature Co-Action for Click-Through Rate Prediction-21年,阿里
  3. Implementation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/488971.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java从入门到工作3 - 框架/工具

3.1、SpringBoot框架结构 在 Spring Boot 或微服务架构中,每个服务的文件目录结构通常遵循一定的约定。以下是一个常见的 Spring Boot 服务目录结构示例,以及各个文件和目录的简要说明: my-service │ ├── src │ ├── main │ │…

基于SpringBoot的青少年心理健康教育网站

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

基于事件驱动的websocket简单实现

websocket的实现 什么是websocket? WebSocket 是一种网络通信协议,旨在为客户端和服务器之间提供全双工、实时的通信通道。它是在 HTML5 规范中引入的,可以让浏览器与服务器进行持久化连接,以便实现低延迟的数据交换。 WebSock…

JavaEE 【知识改变命运】04 多线程(3)

文章目录 多线程带来的风险-线程安全线程不安全的举例分析产出线程安全的原因:1.线程是抢占式的2. 多线程修改同一个变量(程序的要求)3. 原子性4. 内存可见性5. 指令重排序 总结线程安全问题产生的原因解决线程安全问题1. synchronized关键字…

ElasticSearch如何做性能优化?

大家好,我是锋哥。今天分享关于【ElasticSearch如何做性能优化?】面试题。希望对大家有帮助; ElasticSearch如何做性能优化? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在 Elasticsearch 中,性能优化是…

Chrome浏览器调用ActiveX控件--allWebOffice控件

背景 allWebOffice控件能够实现在浏览器窗口中在线操作文档的应用(阅读、编辑、保存等),支持编辑文档时保留修改痕迹,支持书签位置内容动态填充,支持公文套红,支持文档保护控制等诸多办公功能,…

贪心算法(一)

目录 一、贪心算法 二、柠檬水找零 三、将数组和减半的最少操作次数 四、最大数 五、摆动序列 一、贪心算法 贪心算法的本质是选择每一阶段的局部最优,从而达到全局最优。 贪心策略:1、把解决问题的过程分为若干步。2、解决每一步的时候&#xff…

Scratch节日作品 | 圣诞节礼物——体验节日的温馨与编程的乐趣! ❄️

今天为大家推荐一款充满节日氛围的Scratch作品——《圣诞礼物》!这款程序不仅带来了雪花飘落、圣诞老人和麋鹿的经典场景,还通过编程的形式让小朋友们体验到收到礼物的喜悦。通过这款游戏,小朋友们能学习编程知识、了解圣诞文化,同…

基于Qwen2-VL模型针对 ImageToText 任务进行微调训练 - 数据处理

基于Qwen2-VL模型针对 ImageToText 任务进行微调训练 - 数据处理 flyfish 给定的图像生成一段自然语言描述。它的目标是生成一个或多个句子,能够准确地描述图像中的主要内容、物体、动作、场景等信息。例如,对于一张包含一只狗在草地上奔跑的图像&…

Spring Boot整合 RabbitMQ

文章目录 一. 引入依赖二. 添加配置三. Work Queue(工作队列模式)声明队列生产者消费者 四. Publish/Subscribe(发布订阅模式)声明队列和交换机生产者消费者 五. Routing(路由模式)声明队列和交换机生产者消费者 六. Topics(通配符模式)声明队列和交换机生产者消费者 一. 引入依…

谷粒商城—分布式基础

1. 整体介绍 1)安装vagrant 2)安装Centos7 $ vagrant init centos/7 A `Vagrantfile` has been placed in this directory. You are now ready to `vagrant up` your first virtual environment! Please read the comments in the Vagrantfile as well as documentation on…

【考前预习】1.计算机网络概述

往期推荐 子网掩码、网络地址、广播地址、子网划分及计算-CSDN博客 一文搞懂大数据流式计算引擎Flink【万字详解,史上最全】-CSDN博客 浅学React和JSX-CSDN博客 浅谈云原生--微服务、CICD、Serverless、服务网格_云原生 serverless-CSDN博客 浅谈维度建模、数据分析…

计算机视觉与医学的结合:推动医学领域研究的新机遇

目录 引言医学领域面临的发文难题计算机视觉与医学的结合:发展趋势计算机视觉结合医学的研究方向高区位参考文章结语 引言 计算机视觉(Computer Vision, CV)技术作为人工智能的重要分支,已经在多个领域取得了显著的应用成果&…

AI智算-k8s部署大语言模型管理工具Ollama

文章目录 简介k8s部署OllamaOpen WebUI访问Open-WebUI 简介 Github:https://github.com/ollama/ollama 官网:https://ollama.com/ API:https://github.com/ollama/ollama/blob/main/docs/api.md Ollama 是一个基于 Go 语言开发的可以本地运…

PyQt事件机制练习

一、思维导图 二、代码 import sysfrom PyQt6.QtTextToSpeech import QTextToSpeech from PyQt6.QtWidgets import QApplication, QWidget, QLabel, QPushButton, QLineEdit from PyQt6 import uic from PyQt6.QtCore import Qt, QTimerEvent, QTimeclass MyWidget(QWidget):d…

硬件设计 | Altium Designer软件PCB规则设置

基于Altium Designer(24.9.1)版本 嘉立创PCB工艺加工能力范围说明-嘉立创PCB打样专业工厂-线路板打样 规则参考-嘉立创 注意事项 1.每次设置完规则参数都要点击应用保存 2.每次创建PCB,都要设置好参数 3.可以设置默认规则,将…

【计算机学习笔记】GB2312、GBK、Unicode等字符编码的理解

之前编写win32程序时没怎么关注过宽字符到底是个啥东西,最近在编写网络框架又遇到字符相关的问题,所以写一篇文章记录一下(有些部分属于个人理解,如果有错误欢迎指出) 目录 几个常见的编码方式Unicode和UTF-8、UTF-16、…

深入理解 CSS 文本换行: overflow-wrap 和 word-break

前言 正常情况下,在固定宽度的盒子中的中文会自动换行。但是,当遇到非常长的英文单词或者很长的 URL 时,文本可能就不会自动换行,而会溢出所在容器。幸运的是,CSS 为我们提供了一些和文本换行相关的属性;今…

centos9升级OpenSSH

需求 Centos9系统升级OpenSSH和OpenSSL OpenSSH升级为openssh-9.8p1 OpenSSL默认为OpenSSL-3.2.2(根据需求进行升级) 将源码包编译为rpm包 查看OpenSSH和OpenSSL版本 ssh -V下载源码包并上传到服务器 openssh最新版本下载地址 wget https://cdn.openb…

Pull requests 和Merge Request其实是一个意思

Pull requests的定义 在Git中,PR(Pull Request)是一种协作开发的常用方式。它允许开发者将自己的代码变更(通常是一个分支)提交到项目的仓库中,然后请求负责代码审查的人员将这些变更合并到主分支中。通过…