一 核心模块coaction
- 对于每个特征对(feature_pairs)
- weight, bias 来自于P_induction
- P_feed 是 MLP 的输入(input)
举个例子:对用户ID和产品ID做 co-action,其中产品ID作为 induction 端,用户ID作为 feed 端。
- step1 用户ID和产品ID各自映射成一个向量:产品ID通过 parameter lookup 得到一个可学习的 P_induction,其维度等于 micro-MLP 所有层参数量之和 $\sum_{i=1}^{L}(|w_i| + |b_i|)$(L 为 MLP 的层数,$|w_i|$、$|b_i|$ 为第 i 层权重矩阵与偏置的元素个数);用户ID则通过普通的 embedding lookup 得到向量 P_feed
- step2 P_induction 这个向量逐层(MLP层),reshape成MLP网络的weight 和bias;
- step3 weight和bias作为MLP的参数,利用P_feed 作为input,进行MLP前向运算,得到特征交互结果
- 代码解读
#### CAN config #####
# [in, out] shape of every micro-MLP layer whose weights are generated from P_induction.
weight_emb_w = [[16, 8], [8, 4]]
# Bias width per micro-MLP layer; 0 means that layer has no bias.
weight_emb_b = [0, 0]
# Highest explicit feature order: the micro-MLP is applied to P_feed ** c for c = 1..orders.
orders = 3
# When True, every order gets its own micro-MLP parameter set; otherwise one set is shared.
order_indep = False  # True
# Number of P_induction values needed to fill the weights and biases of one micro-MLP.
WEIGHT_EMB_DIM = sum(n_in * n_out for n_in, n_out in weight_emb_w) + sum(weight_emb_b)
# Number of micro-MLP parameter sets stored per induction id.
INDEP_NUM = orders if order_indep else 1
###### 这一部分对应图中绿色和橙色部分,主要用于得到 P_feed 与 P_induction 的嵌入表示 ##########
if self.use_coaction:
    # Inputs per feature:
    #   [0] batch placeholder of the target ids     -> looked up into P_induction below
    #   [1] placeholder of the history id sequence  -> looked up into P_feed below
    #   [2] embedded history sequence (not used in this fragment)
    ph_dict = {
        "item": [self.mid_batch_ph, self.mid_his_batch_ph, self.mid_his_batch_embedded],
        "cate": [self.cate_batch_ph, self.cate_his_batch_ph, self.cate_his_batch_embedded],
    }

    ### P_induction ###
    self.mlp_batch_embedded = []  # induction embeddings, one tensor per feature
    with tf.device(device):
        # Trainable tables holding INDEP_NUM * WEIGHT_EMB_DIM micro-MLP parameters per id;
        # n_mid / n_cate are the numbers of item / category ids.
        self.item_mlp_embeddings_var = tf.get_variable(
            "item_mlp_embedding_var", [n_mid, INDEP_NUM * WEIGHT_EMB_DIM], trainable=True)
        self.cate_mlp_embeddings_var = tf.get_variable(
            "cate_mlp_embedding_var", [n_cate, INDEP_NUM * WEIGHT_EMB_DIM], trainable=True)
        # Pick the rows of the target ids from the tables created above.
        self.mlp_batch_embedded.append(
            tf.nn.embedding_lookup(self.item_mlp_embeddings_var, ph_dict['item'][0]))
        self.mlp_batch_embedded.append(
            tf.nn.embedding_lookup(self.cate_mlp_embeddings_var, ph_dict['cate'][0]))

        ### P_feed (micro-MLP input) ###
        self.input_batch_embedded = []
        # Width weight_emb_w[0][0] * INDEP_NUM: the input size of the first micro-MLP layer,
        # repeated once per independent parameter set.
        self.item_input_embeddings_var = tf.get_variable(
            "item_input_embedding_var", [n_mid, weight_emb_w[0][0] * INDEP_NUM], trainable=True)
        self.cate_input_embeddings_var = tf.get_variable(
            "cate_input_embedding_var", [n_cate, weight_emb_w[0][0] * INDEP_NUM], trainable=True)
        # The feed side is the history sequence (index 1 of each ph_dict entry).
        self.input_batch_embedded.append(
            tf.nn.embedding_lookup(self.item_input_embeddings_var, ph_dict['item'][1]))
        self.input_batch_embedded.append(
            tf.nn.embedding_lookup(self.cate_input_embeddings_var, ph_dict['cate'][1]))
################这一部分是P_induction&P_feed在MLP的使用#######################
if self.use_coaction:
    # Apply the co-action unit to every (P_induction, P_feed) pair and collect the
    # sum-pooled results (tmp_sum) and the per-position results (tmp_seq).
    tmp_sum, tmp_seq = [], []
    if INDEP_NUM == 2:
        # Feature pairs: both mlp_batch and input_batch are split into two slices,
        # and every induction slice j is crossed with every feed slice i.
        # NOTE(review): INDEP_NUM is 1 or `orders`, so this branch only runs when
        # order_indep is True and orders == 2; gen_coaction then expects
        # 2 * WEIGHT_EMB_DIM induction values but receives WEIGHT_EMB_DIM — verify.
        for i, mlp_batch in enumerate(self.mlp_batch_embedded):
            for j, input_batch in enumerate(self.input_batch_embedded):
                coaction_sum, coaction_seq = gen_coaction(
                    mlp_batch[:, WEIGHT_EMB_DIM * j: WEIGHT_EMB_DIM * (j + 1)],
                    input_batch[:, :, weight_emb_w[0][0] * i: weight_emb_w[0][0] * (i + 1)],
                    EMBEDDING_DIM, mode=CALC_MODE, mask=self.mask)
                tmp_sum.append(coaction_sum)
                tmp_seq.append(coaction_seq)
    else:
        # Item induction with item history, category induction with category history.
        for mlp_batch, input_batch in zip(self.mlp_batch_embedded, self.input_batch_embedded):
            coaction_sum, coaction_seq = gen_coaction(
                mlp_batch[:, :INDEP_NUM * WEIGHT_EMB_DIM],
                input_batch[:, :, :weight_emb_w[0][0]],
                EMBEDDING_DIM, mode=CALC_MODE, mask=self.mask)
            tmp_sum.append(coaction_sum)
            tmp_seq.append(coaction_seq)
    self.coaction_sum = tf.concat(tmp_sum, axis=1)  # sum-pooled co-action features
    self.cross.append(self.coaction_sum)  # concatenated into the DNN input later
###### core interaction 核心运算 #########
def gen_coaction(ad, his_items, dim, mode="can", mask=None):
    """Core co-action unit: run a micro-MLP whose parameters are sliced from ``ad``.

    Args:
        ad: P_induction, a rank-2 tensor (batch, n_sets * WEIGHT_EMB_DIM) with
            n_sets = orders if order_indep else 1. It is sliced into per-layer
            weights of shape weight_emb_w[i] (and biases of width weight_emb_b[i]).
        his_items: P_feed, the sequence to interact with. Presumably
            (batch, seq_len, weight_emb_w[0][0]) — inferred from the batched
            matmul and the axis-1 pooling below; confirm at call sites.
        dim: unused; kept so existing call sites keep working.
        mode: only "can" is implemented.
        mask: optional mask over the sequence axis; masked positions are zeroed
            before pooling.

    Returns:
        (out, None): out is the sum over axis 1 of the concatenated micro-MLP
        layer outputs for every order 1..orders.

    Raises:
        ValueError: if ``mode`` is not "can".
    """
    if mode != "can":
        raise ValueError(f"unsupported co-action mode: {mode!r}")

    # Slice `ad` into weight matrices / biases. Without order_indep a single
    # parameter set is shared by every order.
    n_sets = orders if order_indep else 1
    weight_orders, bias_orders = [], []
    idx = 0
    for _ in range(n_sets):
        # Fresh lists per order: reusing one list would hand every order the
        # layers of all orders and break the matmul shapes when order_indep=True.
        weight, bias = [], []
        for (n_in, n_out), b in zip(weight_emb_w, weight_emb_b):
            weight.append(tf.reshape(ad[:, idx:idx + n_in * n_out], [-1, n_in, n_out]))
            idx += n_in * n_out
            if b == 0:
                bias.append(None)
            else:
                bias.append(tf.reshape(ad[:, idx:idx + b], [-1, 1, b]))
                idx += b
        weight_orders.append(weight)
        bias_orders.append(bias)

    out_seq = []
    for order in range(1, orders + 1):
        # Explicit high-order features: order c feeds his_items ** c.
        h = his_items ** order
        k = order - 1 if order_indep else 0
        weight, bias = weight_orders[k], bias_orders[k]
        n_layers = len(weight)
        # Emulate the MLP forward pass (tanh on all but the last layer).
        for j, (w, b) in enumerate(zip(weight, bias)):
            h = tf.matmul(h, w)
            if b is not None:
                h = h + b
            if j != n_layers - 1:
                h = tf.nn.tanh(h)
            # NOTE(review): every layer's output (not only the last) is kept;
            # confirm against the reference CAN implementation.
            out_seq.append(h)
    out_seq = tf.concat(out_seq, 2)
    if mask is not None:
        mask = tf.expand_dims(mask, axis=-1)
        out_seq = out_seq * mask
    # Sum-pool the interaction results over the sequence axis.
    out = tf.reduce_sum(out_seq, 1)
    return out, None
二 文章中的应用
整体的模型结构两部分构成:
- co-action 部分:对用户的序列特征,逐个与目标商品作用后做 sum-pooling;对非序列特征,作用后直接输出
- DIEN 部分:负责用户行为序列的兴趣建模
两部分的输出 concat 后再接一个常规的 DNN。可以理解为:co-action 负责显式的特征交叉,DIEN 则沿用原有的序列建模。
三 一些其他细节补充
- can 部分的高阶特征处理:先把待交叉特征 P_feed 做 c 次幂运算(c = 1, …, orders),再分别与 P_induction 生成的 micro-MLP 作用,各阶结果拼接
- 在文章场景,p_induction是target_item,也就是产品
四 用tf2/torch重构
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class CAN_Model(nn.Module):
    """Simplified PyTorch port of the CAN click-through-rate model.

    Co-action (``use_coaction=True``): the target item/category id looks up a
    P_induction row that is sliced into the weights/biases of a small MLP, and
    every history embedding (P_feed, raised to powers 1..coaction_orders) is run
    through it. Outputs are masked, sum-pooled over the sequence and added to
    the DNN input.

    The DNN input is the concatenation of:
      - the user embedding;
      - the target (mid, cate) embedding;
      - the summed history embedding;
      - optional summed cartesian-product embeddings;
      - optional co-action features.

    The DIEN interest-extraction part of the original model is not ported.

    The ``coaction_*`` defaults mirror the module-level CAN config
    (weight_emb_w, weight_emb_b, orders, order_indep).
    """

    def __init__(self, n_uid, n_mid, n_cate, n_carte, EMBEDDING_DIM, HIDDEN_SIZE,
                 ATTENTION_SIZE, use_negsampling=False, use_softmax=True,
                 use_coaction=False, use_cartes=False,
                 coaction_weight_shapes=((16, 8), (8, 4)), coaction_bias_dims=(0, 0),
                 coaction_orders=3, coaction_order_indep=False):
        super().__init__()
        self.n_uid = n_uid
        self.n_mid = n_mid
        self.n_cate = n_cate
        self.n_carte = n_carte
        self.EMBEDDING_DIM = EMBEDDING_DIM
        self.HIDDEN_SIZE = HIDDEN_SIZE
        self.ATTENTION_SIZE = ATTENTION_SIZE
        self.use_negsampling = use_negsampling
        self.use_softmax = use_softmax
        self.use_coaction = use_coaction
        self.use_cartes = use_cartes

        self.uid_embeddings = nn.Embedding(n_uid, EMBEDDING_DIM)
        self.mid_embeddings = nn.Embedding(n_mid, EMBEDDING_DIM)
        self.cate_embeddings = nn.Embedding(n_cate, EMBEDDING_DIM)
        # uid + target (mid, cate) + summed history (mid, cate)
        inp_dim = 5 * EMBEDDING_DIM

        if use_cartes:
            self.carte_embeddings = nn.ModuleList(
                [nn.Embedding(num, EMBEDDING_DIM) for num in n_carte])
            inp_dim += len(n_carte) * EMBEDDING_DIM

        if use_coaction:
            # Two co-action pairs: (item, item history) and (cate, cate history).
            inp_dim += 2 * self._init_coaction(
                n_mid, n_cate, coaction_weight_shapes, coaction_bias_dims,
                coaction_orders, coaction_order_indep)

        if use_negsampling:
            # DIEN-style auxiliary classifier, kept separate from the main DNN.
            # NOTE(review): assumes h_states of width HIDDEN_SIZE concatenated with
            # (mid, cate) history embeddings of width 2 * EMBEDDING_DIM — confirm.
            self.aux_fc1 = nn.Linear(HIDDEN_SIZE + 2 * EMBEDDING_DIM, 100)
            self.aux_fc2 = nn.Linear(100, 50)
            self.aux_fc3 = nn.Linear(50, 2)

        # Main DNN: inp_dim -> 200 -> 80 -> (2 | 1). fc1/fc2 keep their original shapes.
        self.fc0 = nn.Linear(inp_dim, 200)
        self.fc1 = nn.Linear(200, 80)
        self.fc2 = nn.Linear(80, 2 if use_softmax else 1)

    def _init_coaction(self, n_mid, n_cate, weight_shapes, bias_dims, orders, order_indep):
        """Validate the micro-MLP config, create the co-action tables, return one pair's output width."""
        shapes = [tuple(shape) for shape in weight_shapes]
        biases = list(bias_dims)
        if not shapes or len(shapes) != len(biases):
            raise ValueError("coaction_weight_shapes and coaction_bias_dims must be "
                             "non-empty and of equal length")
        for (_, n_out), (n_in_next, _) in zip(shapes, shapes[1:]):
            if n_out != n_in_next:
                raise ValueError(f"micro-MLP layer shapes do not chain: {shapes}")
        for (_, n_out), b in zip(shapes, biases):
            if b not in (0, n_out):
                raise ValueError(f"bias dims must be 0 or the layer output width: {biases}")
        if orders < 1:
            raise ValueError(f"coaction_orders must be >= 1, got {orders}")

        self.coaction_weight_shapes = shapes
        self.coaction_bias_dims = biases
        self.coaction_orders = orders
        self.coaction_order_indep = order_indep
        # P_induction values per micro-MLP parameter set (the WEIGHT_EMB_DIM of the TF code).
        self._coaction_weight_emb_dim = sum(n_in * n_out for n_in, n_out in shapes) + sum(biases)
        indep_num = orders if order_indep else 1
        feed_dim = shapes[0][0] * indep_num
        self.item_mlp_embeddings = nn.Parameter(
            torch.randn(n_mid, indep_num * self._coaction_weight_emb_dim))
        self.cate_mlp_embeddings = nn.Parameter(
            torch.randn(n_cate, indep_num * self._coaction_weight_emb_dim))
        self.input_batch_embeddings = nn.ModuleList(
            [nn.Embedding(n_mid, feed_dim), nn.Embedding(n_cate, feed_dim)])
        # Every layer's output of every order is concatenated (see _coaction).
        return orders * sum(n_out for _, n_out in shapes)

    def _coaction(self, induction, feed, mask):
        """Run the micro-MLP generated from ``induction`` over ``feed`` and sum-pool it.

        induction: (B, indep_num * weight_emb_dim) P_induction rows.
        feed: (B, T, indep_num * in_dim) P_feed; only the first ``in_dim`` columns
            are used, matching the TF fragment this ports.
        mask: (B, T) or None; masked positions are zeroed before pooling.
        Returns a (B, orders * sum(layer output widths)) tensor.
        """
        batch = induction.size(0)
        base = feed[..., :self.coaction_weight_shapes[0][0]]
        n_layers = len(self.coaction_weight_shapes)
        outputs = []
        for order in range(1, self.coaction_orders + 1):
            offset = (order - 1) * self._coaction_weight_emb_dim if self.coaction_order_indep else 0
            h = base ** order  # explicit order-c feature
            layers = zip(self.coaction_weight_shapes, self.coaction_bias_dims)
            for j, ((n_in, n_out), b) in enumerate(layers):
                w = induction[:, offset:offset + n_in * n_out].reshape(batch, n_in, n_out)
                offset += n_in * n_out
                h = torch.matmul(h, w)  # (B, T, n_in) @ (B, n_in, n_out) -> (B, T, n_out)
                if b:
                    h = h + induction[:, offset:offset + b].reshape(batch, 1, b)
                    offset += b
                if j != n_layers - 1:
                    h = torch.tanh(h)
                # NOTE(review): every layer's output is kept, mirroring the TF
                # gen_coaction; confirm against the reference CAN implementation.
                outputs.append(h)
        out_seq = torch.cat(outputs, dim=2)
        if mask is not None:
            out_seq = out_seq * mask.unsqueeze(-1).to(out_seq.dtype)
        return out_seq.sum(dim=1)

    def forward(self, uid, mid, cate, mid_his, cate_his, mask, target, seq_len, lr, carte=None):
        """Compute ``(loss, y_hat)`` for one batch.

        uid/mid/cate are per-example ids and mid_his/cate_his are history id
        sequences; mask has the same shape as the histories.

        With softmax, ``target`` holds class indices or class probabilities and
        ``y_hat`` has shape (B, 2). With sigmoid, ``y_hat`` has shape (B, 1).
        ``seq_len`` and ``lr`` are accepted for interface compatibility and are unused.
        """
        uid_emb = self.uid_embeddings(uid)
        item_eb = torch.cat([self.mid_embeddings(mid), self.cate_embeddings(cate)], dim=1)
        item_his_eb = torch.cat(
            [self.mid_embeddings(mid_his), self.cate_embeddings(cate_his)], dim=2)
        item_his_eb_sum = item_his_eb.sum(dim=1)
        features = [uid_emb, item_eb, item_his_eb_sum]

        if self.use_cartes:
            if carte is None:
                raise ValueError("carte is required when use_cartes=True")
            # Assumes carte is (B, len(n_carte), T), aligned with mask — confirm.
            carte_mask = mask.unsqueeze(-1).to(item_eb.dtype)
            features.extend(
                (emb(carte[:, i, :]) * carte_mask).sum(dim=1)
                for i, emb in enumerate(self.carte_embeddings))

        if self.use_coaction:
            features.append(self._coaction(
                self.item_mlp_embeddings[mid], self.input_batch_embeddings[0](mid_his), mask))
            features.append(self._coaction(
                self.cate_mlp_embeddings[cate], self.input_batch_embeddings[1](cate_his), mask))

        # Negative sampling needs DIEN GRU states (see auxiliary_loss), which this
        # port does not compute, so it does not contribute to this loss.

        x = torch.cat(features, dim=1)
        x = F.relu(self.fc0(x))
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)

        if self.use_softmax:
            y_hat = F.softmax(logits, dim=-1)
            # cross_entropy applies log-softmax itself, so it must receive raw logits.
            loss = F.cross_entropy(logits, target)
        else:
            y_hat = torch.sigmoid(logits)
            logit = logits.squeeze(-1)
            loss = F.binary_cross_entropy_with_logits(
                logit, target.reshape(logit.shape).to(logit.dtype))
        return loss, y_hat

    def auxiliary_loss(self, h_states, click_seq, noclick_seq, mask):
        """DIEN auxiliary loss: clicked items should score high, sampled non-clicks low."""
        mask = mask.float()
        click_input = torch.cat([h_states, click_seq], dim=-1)
        noclick_input = torch.cat([h_states, noclick_seq], dim=-1)
        click_prop = self.auxiliary_net(click_input)[:, :, 0]
        noclick_prop = self.auxiliary_net(noclick_input)[:, :, 0]
        eps = 1e-8  # keeps log() finite when a probability saturates
        click_loss = -torch.log(click_prop.clamp_min(eps)) * mask
        noclick_loss = -torch.log((1.0 - noclick_prop).clamp_min(eps)) * mask
        return (click_loss + noclick_loss).mean()

    def auxiliary_net(self, in_):
        """Return 2-class softmax probabilities along the last dim (needs use_negsampling=True)."""
        if not self.use_negsampling:
            raise RuntimeError("auxiliary_net requires use_negsampling=True")
        x = torch.sigmoid(self.aux_fc1(in_))
        x = torch.sigmoid(self.aux_fc2(x))
        return F.softmax(self.aux_fc3(x), dim=-1)

    def _run(self, data):
        """Call forward with ``data`` given as a tuple/list of positional args or a dict of kwargs."""
        return self(**data) if isinstance(data, dict) else self(*data)

    def train_step(self, data, optimizer):
        """Run one optimisation step on ``data`` and return the loss as a float."""
        optimizer.zero_grad()
        loss, _ = self._run(data)
        loss.backward()
        optimizer.step()
        return loss.item()

    def evaluate(self, data):
        """Return ``(loss, y_hat)`` for ``data`` without tracking gradients."""
        with torch.no_grad():
            loss, y_hat = self._run(data)
        return loss.item(), y_hat
# Example of using the model
n_uid = 1000
n_mid = 1000
n_cate = 500
n_carte = [10, 20]  # Example carte sizes
EMBEDDING_DIM = 128
HIDDEN_SIZE = 256
ATTENTION_SIZE = 128
model = CAN_Model(n_uid, n_mid, n_cate, n_carte, EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE)
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Example data: a batch of users, each with a behaviour history of up to max_len items.
batch_size, max_len = 32, 5
uid = torch.randint(0, n_uid, (batch_size,))
mid = torch.randint(0, n_mid, (batch_size,))
cate = torch.randint(0, n_cate, (batch_size,))
mid_his = torch.randint(0, n_mid, (batch_size, max_len))
cate_his = torch.randint(0, n_cate, (batch_size, max_len))
# History lengths in [1, max_len]; the mask marks the valid prefix of each history
# so that it agrees with seq_len.
seq_len = torch.randint(1, max_len + 1, (batch_size,))
mask = (torch.arange(max_len) < seq_len.unsqueeze(1)).float()
target = torch.randint(0, 2, (batch_size,))
lr = 0.001
# Training step
loss = model.train_step((uid, mid, cate, mid_his, cate_his, mask, target, seq_len, lr), optimizer)
print(f"Loss: {loss}")
Reference:
- 文章形成思路历程
- CAN: Feature Co-Action for Click-Through Rate Prediction-21年,阿里
- Implementation