分组数据的交叉验证方法
如果生成过程产生了依赖样本的组,那么独立同分布(i.i.d.)假设就会失效。
.
这种数据分组是特定于领域的。例如,医疗数据通常从多位患者中收集,每位患者可能包含多个样本,而这些样本很可能依赖于各自所属的患者组。在这个例子中,每个样本的患者 ID 就是它的组标识符。
.
在这种情况下,我们希望了解,基于某一组群体训练的模型是否能够很好地泛化到未见过的组。为了评估这一点,我们需要确保验证折中的所有样本均来自未在训练折中出现过的组。
.
以下的交叉验证分割方法可用于实现此目的。样本的分组标识符通过 groups 参数指定。
与普通的交叉验证方法不同,Group 分割方法(如 GroupKFold)在划分数据时不仅考虑样本本身,还根据指定的“组”(groups)进行分割,确保同一个组内的样本不会同时出现在训练集和测试集中。这种方式常用于具有相关样本的数据集(例如同一个用户的多条记录),避免组间信息泄露,更好地评估模型的泛化能力。
例如,如果数据来自不同的个体,每个个体包含多个样本,并且模型足够灵活,能够从特定个体的特征中学习,那么该模型可能无法泛化到新的个体上。GroupKFold 能够帮助检测这类过拟合的情况。
- GroupKFold:将数据分成 K 折,但在每一折中,来自同一组的所有样本要么全在训练集,要么全在测试集。
- StratifiedGroupKFold 结合了分层抽样(确保各个类别的样本比例与整个数据集一致)和分组抽样(确保同一组的样本不会同时出现在训练集和测试集中。)的特点,确保在划分数据时保持类别比例和组的独立性。
- GroupShuffleSplit 基于分组数据进行随机抽样,确保同一个组中的样本要么全部出现在训练集,要么全部出现在测试集中。它类似于 ShuffleSplit,但针对分组数据。
- 此外还有LeaveOneGroupOut和LeavePGroupsOut——在每次迭代中,将 一个或者P 个组作为测试集,其他组作为训练集,类似 LeavePOut,但基于组进行分割。
代码汇总
from sklearn.model_selection import GroupKFold
X = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10]
y = ["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"]
groups = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
gkf = GroupKFold(n_splits=3)
for train, test in gkf.split(X, y, groups=groups):print("%s %s" % (train, test))from sklearn.model_selection import StratifiedGroupKFold
X = list(range(18))
y = [1] * 6 + [0] * 12
groups = [1, 2, 3, 3, 4, 4, 1, 1, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6]
sgkf = StratifiedGroupKFold(n_splits=3)
for train, test in sgkf.split(X, y, groups=groups):print("%s %s" % (train, test))from sklearn.model_selection import LeaveOneGroupOut
X = [1, 5, 10, 50, 60, 70, 80]
y = [0, 1, 1, 2, 2, 2, 2]
groups = [1, 1, 2, 2, 3, 3, 3]
logo = LeaveOneGroupOut()
for train, test in logo.split(X, y, groups=groups):print("%s %s" % (train, test))from sklearn.model_selection import LeavePGroupsOut
X = np.arange(6)
y = [1, 1, 1, 2, 2, 2]
groups = [1, 1, 2, 2, 3, 3]
lpgo = LeavePGroupsOut(n_groups=2)
for train, test in lpgo.split(X, y, groups=groups):print("%s %s" % (train, test))from sklearn.model_selection import GroupShuffleSplit
X = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 0.001]
y = ["a", "b", "b", "b", "c", "c", "c", "a"]
groups = [1, 1, 2, 2, 3, 3, 4, 4]
gss = GroupShuffleSplit(n_splits=4, test_size=0.5, random_state=0)
for train, test in gss.split(X, y, groups=groups):print("%s %s" % (train, test))
参考原文:https://scikit-learn.org/stable/modules/cross_validation.html#computing-cross-validated-metrics