预测所有的数字,现在我们将mnist.py文件里面从加载标签以后的代码做以修改。

def encode_digit(Y, digit):
    encoded_Y = np.zeros_like(Y)
    n_labels = Y.shape[0]
    for i in range(n_labels):
        if Y[i] == digit:
            encoded_Y[i][0] = 1
    return encoded_Y


TRAINING_LABELS = load_labels("../data/mnist/train-labels-idx1-ubyte.gz")
TEST_LABELS = load_labels("../data/mnist/t10k-labels-idx1-ubyte.gz")


Y_train = []
Y_test = []

for digit in range(10):
    Y_train.append(encode_digit(TRAINING_LABELS, digit))
    Y_test.append(encode_digit(TEST_LABELS, digit))

        修改分类算法里的代码,如下: 

# iterations为迭代次数
def train(X, Y, iterations, lr):
    w = np.zeros((X.shape[1], 1))  # 初始化权重w
    for i in range(iterations):
        # print("Iteration %4d => Loss: %.20f" % (i, loss(X, Y, w)))
        w -= gradient(X, Y, w) * lr
    return w


def test(X, Y, w, digit):
    total_examples = X.shape[0]
    correct_results = np.sum(classify(X, w) == Y)
    success_percent = correct_results * 100 / total_examples
    print("Correct classifications for digit %d: %d/%d (%.2f%%)" %
          (digit, correct_results, total_examples, success_percent))


for digit in range(10):
    # 训练模型
    w = train(mi.X_train, mi.Y_train[digit], iterations=100, lr=1e-5)
    # 模型测试
    test(mi.X_test, mi.Y_test[digit], w, digit)

        运行后,得到如下结果: 

        我们发现预测数字8的效果显然没有其它数字的效果好,进一步地说明我们的训练算法还有很大地改进空间,不过这样的结果还能接受!

参考文献:

Programming Machine Learning: Form Coding to Deep Learning.[M],Paolo Perrotta,2021.6. 

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐