sklearn.datasets.fetch_openml 使用详解:轻松获取开放机器学习数据集

1. fetch_openml 简介

fetch_openml 是 scikit-learn 的 datasets 模块中一个非常重要的函数,它提供了从 OpenML 平台 获取数据集的便捷接口。OpenML 是一个开放的机器学习平台,网址为 https://www.openml.org,它收集了大量真实世界的数据集,供机器学习研究者和实践者使用。

在 scikit-learn 的 0.20 版本后,fetch_mldata 函数已被 fetch_openml 取代,成为获取外部数据集的主要方式。通过这个函数,用户可以轻松访问数千个真实数据集,无需手动下载和处理文件,极大提高了机器学习实验的效率。

2. 函数参数详解

fetch_openml 函数提供了多个参数,让用户能够精确控制要获取的数据集及其返回格式。下面是各个参数的详细说明:

参数 类型 默认值 说明
name str None 数据集的名称标识符(如 “mnist_784”)
data_id int None 数据集的唯一 OpenML ID(如 853)
version int/‘active’ ‘active’ 数据集版本,推荐指定具体版本而非使用 ‘active’
as_frame boolean False 是否返回 Pandas DataFrame 格式
return_X_y boolean False 是否返回 (data, target) 元组而非 Bunch 对象
cache boolean True 是否缓存下载的数据集
parser str ‘auto’ 指定 ARFF 文件解析器(‘liac-arff’、‘pandas’ 等)
target_column str/list/None ‘default-target’ 指定作为目标变量的列
data_home str None 指定数据下载和缓存的目录

重要参数说明

name 与 data_id
这两个参数都用于指定要获取的数据集,但只需提供其中一个。name 是数据集的名称标识符,如 “mnist_784” 表示 MNIST 手写数字数据集;而 data_id 是 OpenML 平台分配给每个数据集的唯一数字 ID。如果同时指定了这两个参数,函数会优先使用 data_id

version 参数的重要性
OpenML 上同一个数据集可能有多个版本。当 version='active' 时,会获取仍处于活动状态的最旧版本。但由于可能存在多个活动版本且它们之间可能有显著差异,强烈建议明确指定版本号以确保实验的可重现性。

as_frame 与 return_X_y 的组合使用

  • as_frame=False, return_X_y=False:返回 Bunch 对象,包含数据、目标值和元数据
  • as_frame=True, return_X_y=False:返回 Bunch 对象,但其中的数据为 Pandas DataFrame
  • as_frame=True/False, return_X_y=True:直接返回 (data, target) 元组

parser 参数的选择
当数据集包含混合数据类型或分类变量时,parser 参数尤为重要。设置 parser="pandas" 可以避免一些数据类型相关的问题,特别是在处理包含字符串特征的数据集时。

3. 返回值结构

fetch_openml 函数的返回值取决于 return_X_y 参数的设置:

return_X_y=False(默认)时

返回一个 Bunch 对象,这是一个类似字典的数据结构,可以通过属性或键访问。主要包含以下属性:

  • data:特征矩阵,形状为 (n_samples, n_features)
  • target:目标变量(标签或回归值)
  • feature_names:特征名称列表
  • target_names:目标变量名称(如果有)
  • DESCR:数据集的详细描述
  • details:来自 OpenML 的更多元数据
  • url:数据集在 OpenML 上的 URL
  • frame:仅当 as_frame=True 时存在,包含完整的 DataFrame

return_X_y=True

直接返回一个元组 (data, target),其中 data 是特征矩阵,target 是目标变量。

4. 实用示例

示例 1:加载 MNIST 手写数字数据集

from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt
import numpy as np

# 加载 MNIST 数据集
mnist = fetch_openml("mnist_784", version=1, as_frame=False, cache=True)
X, y = mnist.data, mnist.target

print("特征矩阵形状:", X.shape)  # 输出: (70000, 784)
print("目标变量形状:", y.shape)  # 输出: (70000,)
print("数据类型:", type(X))     # 输出: <class 'numpy.ndarray'>

# 可视化一个数字样本
index = np.random.randint(0, 70000)
image = X[index].reshape(28, 28)  # 将 784 维向量重塑为 28x28 图像
label = y[index]

plt.imshow(image, cmap="gray")
plt.title(f"标签: {label}")
plt.axis('off')
plt.show()

说明:MNIST 数据集包含 70,000 张 28×28 像素的手写数字图像,每个图像被展平为 784 维的向量。as_frame=False 确保返回 NumPy 数组,适合大多数机器学习算法。

示例 2:加载泰坦尼克数据集(返回 DataFrame)

from sklearn.datasets import fetch_openml

# 加载泰坦尼克数据集,使用 DataFrame 格式
titanic = fetch_openml("titanic", version=1, as_frame=True, parser="pandas")
df = titanic.frame

print("数据形状:", df.shape)
print("\n前几行数据:")
print(df.head())
print("\n数据集描述:")
print(titanic.DESCR[:500])  # 显示前500个字符的描述

说明:当 as_frame=True 时,可以直接获取 Pandas DataFrame,便于数据探索和分析。parser="pandas" 确保使用 Pandas 解析器处理数据,可以避免一些数据类型问题。

示例 3:替代被移除的波士顿房价数据集

自 scikit-learn 1.2 版本起,由于伦理考虑,波士顿房价数据集已被移除。可以使用 fetch_openml 获取替代数据集:

from sklearn import datasets

# 方法1:直接返回特征和目标
data_x, data_y = datasets.fetch_openml(name="boston", version=1, 
                                      as_frame=True, return_X_y=True, 
                                      parser="pandas")

# 方法2:返回 Bunch 对象
boston = datasets.fetch_openml(name="boston", version=1, 
                              as_frame=True, parser="pandas")
X, y = boston.data, boston.target
feature_names = boston.feature_names

5. 使用技巧与注意事项

  1. 数据集排序问题:使用 fetch_openml 获取的 MNIST 数据集是未排序的,而之前的 fetch_mldata 返回的是按目标排序的数据集。如果需要排序,可能需要手动处理。

  2. 数据缓存:默认情况下,cache=True 会缓存下载的数据集,存储位置通常为 ~/scikit_learn_data。可以通过 data_home 参数指定其他位置。

  3. 错误处理:如果遇到数据类型错误,可以尝试以下解决方案:

    • 设置 parser="pandas"
    • 明确转换数据类型,如 X.astype(np.float64)
    • 检查目标变量是否需要类型转换
  4. 大数据集处理:对于大型数据集(如 MNIST),可以考虑只加载部分数据进行初步实验:

    # 随机选择部分样本
    shuffle_index = np.random.permutation(60000)
    X_small, y_small = X[shuffle_index[:10000]], y[shuffle_index[:10000]]
    

6. 总结

fetch_openml 是 scikit-learn 中一个功能强大的工具,它提供了直接访问 OpenML 平台上大量数据集的便捷途径。通过合理设置参数,用户可以灵活控制要获取的数据集版本、返回格式和数据表示方式,满足不同场景下的需求。

无论是进行机器学习教学、算法测试还是实际项目开发,fetch_openml 都能大大简化数据获取和预处理的流程,让研究者更专注于模型本身而非数据准备工作。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐