展示整体代码

from sklearn import tree
import numpy as np


dataset = np.load('mnist.npz')

x_train = dataset['x_train']
y_train = dataset['y_train']
x_test = dataset['x_test']
y_test = dataset['y_test']

classifier = tree.DecisionTreeClassifier()
x_train = x_train.reshape(60000,784)
x_test = x_test.reshape(10000,784)

classifier.fit(x_train,y_train)
score = classifier.score(x_test,y_test)
print(score)

1.导入相应模块

from sklearn import tree
import numpy as np
  • 使用的是科学计算的库 numpy
  • 做机器学习的库 sklearn 中的 tree

2. 加载、提取数据集的数据

dataset = np.load('mnist.npz')

x_train = dataset['x_train']
y_train = dataset['y_train']
x_test = dataset['x_test']
y_test = dataset['y_test']

3. 对数据进行维度调整

x_train = x_train.reshape(60000,784)
x_test = x_test.reshape(10000,784)

数据的本来维度是:训练集(60000,28,28) 测试集(10000,28,28)现在调整为代码中所示

4. 创建决策树

classifier = tree.DecisionTreeClassifier()

在这里插入图片描述

当然你也可以根据自己的想法来决定你的决策树的深度,以及一些剪枝的策略

5. 投喂数据进行训练

classifier.fit(x_train,y_train)

6. 评估模型

score = classifier.score(x_test,y_test)
print(score)

在这里插入图片描述

疑问:

虽然用分类树可以将手写数字识别的精度达到一个比较高的水平,但是如何建立决策树的过程,通过哪些特征建立的决策树,怎么对这些特征进行可视化,我尚且还不知道,希望哪位大神可以帮忙可视化一下手写数字识别的决策树建立过程;感激不尽。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐