机器学习 - RankNet - 通过神经网络学习排序偏好
RankNet 简介
RankNet 是一种基于神经网络的排序学习算法,由微软在 2007 年提出。其核心思想是通过文档对的相对比较来训练模型,预测两个文档的排序关系(例如某个文档是否应排在另一个文档前面)。RankNet 不直接预测文档的绝对得分,而是学习文档之间的排序偏好。
核心原理
输入与输出:
- 输入:文档对(Document Pair)的特征表示(如关键词匹配度、点击率、文档长度等)。
- 输出:两个文档的排序概率(例如,文档 A 排在文档 B 前面的概率)。
概率建模:
- 对两个文档 (i) 和 (j),模型计算出它们的得分 (si) 和 (s_j),用两者的得分差计算排序概率:
[
P{ij} = \frac{1}{1 + e^{-\sigma(s_i - s_j)}}
]
其中 (\sigma) 是调节参数,控制概率的陡峭程度。
损失函数:
- 使用交叉熵损失衡量预测概率 (P{ij}) 和真实标签 (\bar{P}{ij}) 的差异:
[
\text{Loss} = -\bar{P}{ij} \log P{ij} - (1 - \bar{P}{ij}) \log (1 - P{ij})
] - 真实标签 (\bar{P}_{ij}) 通常是 1(文档 (i) 更相关)或 0(文档 (j) 更相关)。
训练方式:
- 通过反向传播和梯度下降更新神经网络的权重,最小化损失函数。
示例:搜索引擎结果排序
假设我们需要对搜索词“机器学习教程”返回的文档进行排序,构造以下训练数据:
查询-文档对 |
特征(关键词匹配度, 点击率) |
真实相关性标签 |
文档 A |
[0.95, 0.8] |
高相关(应排第一) |
文档 B |
[0.80, 0.6] |
中相关(应排第二) |
文档 C |
[0.70, 0.3] |
低相关(应排第三) |
训练步骤
生成文档对:
- 从标签中构造对比对:((A, B))、((A, C))、((B, C))。
模型预测:
- 用神经网络计算每个文档的得分:
- (s_A = 2.0)(假设初始模型输出)
- (s_B = 1.5)
- (s_C = 1.0)
计算排序概率(例如 (\sigma=1)):
- 对于对 ((A, B)):
[
P_{AB} = \frac{1}{1 + e^{-(2.0 - 1.5)}} = \frac{1}{1 + e^{-0.5}} \approx 0.62
]
- 真实标签 (\bar{P}_{AB} = 1)(因为 A 应排在 B 前面)。
计算损失与更新参数:
- 损失:(-\log(0.62) \approx 0.48)
- 反向传播调整神经网络权重,使得 (sA - s_B) 增大,从而提高 (P{AB})。
重复迭代:
- 对每个文档对计算损失并更新参数,最终模型能正确排序。
RankNet 的优缺点
优点:
- 适合处理隐含的排序偏好(如点击数据)。
- 对噪声标签有一定的鲁棒性。
缺点:
- 仅考虑文档对,未全局优化整个列表(Listwise 方法如 LambdaMART 更优)。
- 计算复杂度随文档对数量增加而上升。
应用场景
- 搜索引擎结果排序
- 推荐系统中的商品排序
- 广告竞价排名
RankNet 是排序学习(Learning to Rank)的重要基础,后续改进算法(如 LambdaRank、ListNet)在其思路上进一步发展。