
Graph-Enhanced Multi-Task Learning of Multi-Level Transition Dynamics for Session-based Recommendation
https://ojs.aaai.org/index.php/AAAI/article/view/16534
1. 背景
基于会话的推荐常用于在线应用,从电子商务到广告业务等。然而现有的工作没有很好地设计方法来捕获复杂动态转移中的时序信息和多层次的相互依赖的关系结构。因此本文提出 Multi-level Transition Dynamics (MTD) 方法。能够以自动和分层的方式联合学习会话内和会话间项目转换动态。
- 开发位置感知的注意力机制学习单个会话中的商品转换规律。
- 提出了一种图结构的层次关系编码器,通过使用全局图上下文执行embedding传播,以高阶连接性的形式显式捕获会话间的项目转换。
2. 方法
令
S={v_1,...,v_m,...,v_M}表示商品的集合,M为集合大小。会话s可以表示为
s=[v_{s,1},...,v_{s,I}],模型的输出为
Y=[y_1,...,y_M],即每个商品被点击的概率。
2.1 Intra-Session Item Relation Learning
为了捕获会话内的转换关系,作者设计了两个模块了学习内部转换模式:位置感知的自注意力网络和会话特定知识表示的注意力聚合。
2.1.1 Self-Attentive Item Embedding Layer
作者利用自注意力网络将会话的原始表征
E_s in mathbb{R}^{Itimes d}映射为潜在表征。公式如下:
left[begin{array}{l}
mathbf{Q} \
mathbf{K} \
mathbf{V}
end{array}right]=mathbf{E}_{s}left[begin{array}{l}
mathbf{W}_{Q} \
mathbf{W}_{K} \
mathbf{W}_{V}
end{array}right] ; operatorname{Att}(mathbf{Q}, mathbf{K}, mathbf{V})=deltaleft(frac{mathbf{Q K}^{T}}{sqrt{d}}right) mathbf{V}令
mathbf{X}_s in mathbb{R}^{I times d}=Att(mathbf{Q,K,V})表示经过attention后的表征,
delta(cdot)表示softmax函数,紧接着再经过FFN增强非线性表征,公式如下,其中
phi表示ReLU激活函数。
widetilde{mathbf{X}}_{s}=operatorname{FFN}left(mathbf{X}_{s}right)=varphileft(mathbf{X}_{s} cdot mathbf{W}_{1} mathbf{b}_{1}right) cdot mathbf{W}_{2} mathbf{b}_{2}
2.1.2 Position-aware Item-wise Aggregation Module
作者进一步设计了位置感知的注意力聚合组件以捕获会话内的商品之间的关系。对于和用户未来会感兴趣的商品更相关的会话内的商品会给予更大的权重,令需要学习的权重为
{alpha_1,...,alpha_I},该权重对应于会话中的每一个embedding
tilde{X}={X_{s,1},...,x_{s,i},...,x_{s,I}},权重个计算公式如下,其中g,w为可学习参数,维度有所不同,g的作用是将其映射为标量。
sigma,delta分别为sigmoid和softmax函数。
alpha_{i}=deltaleft(mathbf{g}^{T} cdot sigmaleft(mathbf{W}_{3} cdot mathbf{x}_{s, I} mathbf{W}_{4} cdot mathbf{x}_{s, i}right)right)
得到权重后,对其进行加权得到聚合后的表征
mathbf{x}_{s}^{*}=sum_{i=1}^{I} alpha_{i} cdot mathbf{x}_{s, i}另一方面,通过注入位置信息进一步增强了会话内 item-wise 融合模块,以捕获项目的特定会话时间顺序信号。位置信号的embedding维度同样为d,也就是和x的维度一致。构建包含位置相对关系的表征为下式,通过相对位置构建权重然后进行聚合得到
p_smathbf{p}_{s}=sum_{i=1}^{I} omega_{i} cdot mathbf{x}_{s, i} ; quad omega_{i}=propto exp (|i-I| 1)
最终的表征为拼接后的表征,具体为
q_s=W_c[x_{s,I},x_s^*,p_s],包含了最后一个商品的embedding,加权聚合后的embedding和包含位置关系的embedding。经过和目标商品
v_m做内积后在经过sigmoid得到最终的分数
tilde{y}_n=sigma(q_s^Tv_n),会话内的损失函数可以构建为:
mathcal{L}_{i n}=-sum_{n}^{N} mathbf{y}_{n} log left(tilde{mathbf{y}}_{n}right) left(1-mathbf{y}_{n}right) log left(1-tilde{mathbf{y}}_{n}right)
2.2 Global Transition Dynamics Modeling
为了捕获会话间的商品转换动态,本文设计了图神经网络架构(如图所示),以将不同会话的高阶相关信号注入会话表示中。

令
mathcal{G}=(mathcal{V},mathcal{E})表示图,V为节点,E为边。每一个会话s看成一条从
v_{s,1}到
v_{s_I}的边,首先通过传统的GCN在图上进行信息传播,公示如下,这里不再赘述其含义。
mathbf{H}^{(l 1)}=varphileft(mathbf{A}, mathbf{H}^{l} mathbf{W}^{l}right)=varphileft(hat{mathbf{D}}^{-frac{1}{2}} hat{mathbf{A}} hat{mathbf{D}}^{-frac{1}{2}} mathbf{H}^{l} mathbf{W}^{l}right)
2.2.1 Global Dependency Representation
在得到
H={h_1,...,h_m,...,h_M}后,捕获来自不同会话的相关项之间的高阶全局依赖关系。首先对H embedding集中的embedding进行聚合
mathbf{b}=tau(mathbf{H}),其中
tau为平均池化。本文同样局部级表征H和图级表征 z 之间的互信息关系来增强跨会话商品编码。正负样本分别为
(h_m,z),(tilde{h}_m,z),负样本的采样方式可以参考mim。然后将正负样本分别送入编码函数,如下式,结果表示给定h和z,节点属于图G的概率。
xileft(mathbf{h}_{m}, mathbf{z}right)=sigmaleft(mathbf{h}_{m}^{T} cdot mathbf{W}_{g} cdot mathbf{z}right) ; mathbb{R}^{d} times mathbb{R}^{d} rightarrow mathbb{R}
最终图级别的损失函数为:
begin{aligned}
mathcal{L}_{c o} &=-frac{1}{N_{p o s} N_{n e g}}left(sum_{i=1}^{N_{p o s}} rholeft(mathbf{h}_{m}, mathbf{z}right) cdot log xileft(mathbf{h}_{m}, mathbf{z}right)right.\
&left. sum_{i=1}^{N_{n e g}} rholeft(widetilde{mathbf{h}}_{m}, mathbf{z}right) cdot log left[1-xileft(widetilde{mathbf{h}}_{m}, mathbf{z}right)right]right)
end{aligned}总损失函数为:
L=L_{cr} lambda_1 L_{in} lambda_2 ||Theta||_2
3. 实验结果

image.png
4. 总结
本文针对会话推荐方面的推荐算法,提出了新的方案。该方法一方面,在会话内部编码时加入了位置信息;另一方面,利用互信息结合局部表征和全局表征来构建损失函数,以获取不同会话之间商品的转换关系。本文和之前的互信息增强图学习的文章类似,可以结合起来看。