TabPFN 开源项目使用教程

【免费下载链接】TabPFN Official implementation of the TabPFN paper (https://arxiv.org/abs/2207.01848) and the tabpfn package. 【免费下载链接】TabPFN 项目地址: https://gitcode.com/gh_mirrors/ta/TabPFN

1. 项目介绍

TabPFN 是一个基于 Transformer 架构的开源表格数据预测工具,专门为小型表格分类问题设计。该项目由 Prior Labs 团队开发,提供了官方的 PyTorch 实现,支持 CUDA 加速计算。

主要特点

  • 极速预测:能够在不到一秒的时间内完成小型表格数据的分类任务
  • 无需超参数调优:内置预训练模型已经过优化,用户无需手动调整参数
  • GPU 加速支持:提供完整的 CUDA 支持,显著提升计算性能
  • 端到端处理:内置数据预处理功能,无需额外数据清洗步骤

2. 项目快速启动

环境要求

  • Python 3.9 或更高版本
  • PyTorch 2.1+
  • 建议使用 GPU(8GB VRAM 以上)

安装方式

方式一:通过 pip 安装

pip install tabpfn

方式二:从源码安装

git clone https://gitcode.com/gh_mirrors/ta/TabPFN
cd TabPFN
pip install -e .

方式三:开发环境安装(包含测试工具)

pip install -e ".[dev]"

快速使用示例

分类任务示例
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier

# 加载乳腺癌数据集
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# 初始化分类器
classifier = TabPFNClassifier()
classifier.fit(X_train, y_train)

# 进行预测
predictions = classifier.predict(X_test)
print("准确率:", accuracy_score(y_test, predictions))
回归任务示例
from sklearn.datasets import fetch_openml
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNRegressor

# 加载波士顿房价数据集
df = fetch_openml(data_id=531, as_frame=True)
X = df.data
y = df.target.astype(float)

# 划分训练测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# 初始化回归器
regressor = TabPFNRegressor()
regressor.fit(X_train, y_train)

# 预测并评估
predictions = regressor.predict(X_test)
mse = mean_squared_error(y_test, predictions)
print("均方误差:", mse)

3. 核心功能特性

3.1 内置预处理

TabPFN 内置了完整的数据预处理流水线,包括:

  • 缺失值处理
  • 特征缩放
  • 分类特征编码
  • 异常值检测

3.2 模型版本管理

支持多个版本的预训练模型:

from tabpfn.constants import ModelVersion

# 使用 TabPFN v2 版本
classifier_v2 = TabPFNClassifier.create_default_for_version(ModelVersion.V2)

# 使用最新的 TabPFN 2.5 版本(默认)
classifier_latest = TabPFNClassifier()

3.3 GPU 加速配置

# 使用 GPU 加速
classifier = TabPFNClassifier(device='cuda')

# 强制在 CPU 上运行大数据集(不推荐)
import os
os.environ['TABPFN_ALLOW_CPU_LARGE_DATASET'] = 'true'

4. 高级用法

4.1 模型保存与加载

from tabpfn.model_loading import save_fitted_tabpfn_model, load_fitted_tabpfn_model

# 保存训练好的模型
classifier.fit(X_train, y_train)
save_fitted_tabpfn_model(classifier, "my_model.tabpfn_fit")

# 加载模型
loaded_classifier = load_fitted_tabpfn_model("my_model.tabpfn_fit", device="cpu")

4.2 缓存优化

# 使用 KV 缓存加速预测
classifier = TabPFNClassifier(fit_mode='fit_with_cache')
classifier.fit(X_train, y_train)  # 训练时构建缓存
predictions = classifier.predict(X_test)  # 预测时使用缓存加速

4.3 集成配置调整

# 调整集成配置数量
classifier = TabPFNClassifier(N_ensemble_configurations=64)

5. 最佳实践

5.1 数据准备建议

  • 避免过度预处理:TabPFN 内置了完整的预处理流程
  • 分类数据处理:使用 OrdinalEncoder 而非 One-Hot Encoding
  • 缺失值处理:保留 NaN 值,模型会自动处理

5.2 性能优化技巧

  • 使用 GPU:确保启用 CUDA 支持
  • 合理设置批次大小:根据显存大小调整
  • 启用缓存:对于重复预测任务使用 KV 缓存

5.3 错误处理

# 处理模型加载错误
try:
    classifier = TabPFNClassifier()
    classifier.fit(X_train, y_train)
except Exception as e:
    print(f"模型加载失败: {e}")
    # 检查网络连接和模型文件完整性

6. 项目结构说明

TabPFN 项目采用模块化设计:

TabPFN/
├── src/tabpfn/           # 核心源代码
│   ├── classifier.py     # 分类器实现
│   ├── regressor.py      # 回归器实现
│   ├── model/           # 模型架构
│   └── preprocessors/   # 预处理模块
├── examples/            # 使用示例
├── tests/              # 测试代码
└── scripts/            # 辅助脚本

7. 常见问题解答

7.1 安装问题

Q: 安装时出现依赖冲突怎么办? A: 建议使用虚拟环境,并确保 Python 版本为 3.9+

Q: 模型下载失败怎么办? A: 检查网络连接,或手动下载模型文件到缓存目录

7.2 使用问题

Q: 数据集大小有限制吗? A: TabPFN 优化用于 5 万行以下的数据集,更大数据集需要特殊处理

Q: 支持多分类问题吗? A: 支持,但类别数量建议不超过 10 个

7.3 性能问题

Q: 为什么 CPU 上运行很慢? A: TabPFN 设计为 GPU 加速,CPU 仅适合小数据集

Q: 如何优化内存使用? A: 调整批次大小和集成配置数量

8. 开发与贡献

8.1 开发环境搭建

# 创建虚拟环境
python -m venv venv
source venv/bin/activate

# 安装开发依赖
pip install -e ".[dev]"
pre-commit install

8.2 运行测试

# 运行所有测试
pytest tests/

# 运行特定测试模块
pytest tests/test_classifier_interface.py

8.3 代码规范

项目使用 ruff 进行代码格式化:

# 自动格式化代码
ruff check --fix
ruff format

TabPFN 作为一个高效的表格数据预测工具,为机器学习从业者提供了快速解决小样本表格分类问题的强大工具。其简单的 API 设计和优秀的性能使其成为表格数据领域的理想选择。

【免费下载链接】TabPFN Official implementation of the TabPFN paper (https://arxiv.org/abs/2207.01848) and the tabpfn package. 【免费下载链接】TabPFN 项目地址: https://gitcode.com/gh_mirrors/ta/TabPFN

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐