方式一. 联网下载数据集

# pip install pyarrow
from datasets import load_dataset
dataset = load_dataset(path='glue', name='sst2')

正常情况:

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

异常报错:

ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: a16bb2d9-22f4-467e-9900-317e2368f09b)')

方式二. 本地读取数据集

由于网络原因,无法顺利下载,所以直接去官网手动下载数据文件,然后处理成DatasetDict格式,效果一样。

import pandas as pd

# 读取 Parquet 文件
train_df = pd.read_parquet('../data/sst2/train-00000-of-00001.parquet')
validation_df = pd.read_parquet('../data/sst2/validation-00000-of-00001.parquet')
test_df = pd.read_parquet('../data/sst2/test-00000-of-00001.parquet')

# 将 pandas DataFrame 转换为 DatasetDict 格式
from datasets import Dataset, DatasetDict

dataset = DatasetDict({
    'train': Dataset.from_pandas(train_df, preserve_index=False),
    'validation': Dataset.from_pandas(validation_df, preserve_index=False),
    'test': Dataset.from_pandas(test_df, preserve_index=False)
})

输出:

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})
Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐