一、基础设置与导入数据

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
import matplotlib.pyplot as plt
import pathlib
import os

# 1. 设置 GPU 
gpus = tf.config.list_physical_devices('GPU')
print("Found GPUs:", gpus)

# 2. 设置数据路径 
data_dir = pathlib.Path("T4_data") 

# 3. 加载数据
img_height = 224  # 这次图片比较大,保留细节
img_width = 224
batch_size = 32

# 训练集 (Training)
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

# 验证集 (Validation)
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(f"检测到的类别: {class_names}") # 应该是 ['Monkeypox', 'Others']

# 4. 性能优化 (AUTOTUNE)
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

二、搭建CNN模型

num_classes = len(class_names)

model = models.Sequential([
    # 第一步必须是归一化 把 0-255 变成 0-1
    layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    
    # 卷积层 1
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    
    # 卷积层 2
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    
    # 卷积层 3
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    
    # 展平 + 全连接
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes)
])

model.summary()

三、编译与训练

# 1. 编译
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 2. 设置回调函数 (ModelCheckpoint)
# 它的作用:监控 val_accuracy,只有当准确率提升时,才保存模型
# save_best_only=True: 只要最好的,不要烂的
checkpointer = callbacks.ModelCheckpoint(
    'best_model.h5', 
    monitor='val_accuracy', 
    verbose=1, 
    save_best_only=True
)

# 3. 训练 (把 callbacks 加进去)
# 这次跑 50 轮,因为我们有 checkpointer 兜底,不怕跑过头
epochs = 50
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[checkpointer] # <--- 记得带上它
)

四、评估结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(acc)) # 自动适应实际跑的轮数

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

五、总结

1. 数据的“高清化”处理

图片尺寸变大:

之前的任务(MNIST/CIFAR)图片都很小(28x28 或 32x32)。

本期处理的是医疗影像,细节很重要,所以图片尺寸统一调整为 224x224。

注意点:在定义模型输入层 (input_shape) 和加载数据 (image_size) 时,都要写成 (224, 224, 3)。

2. 数据加载管道的优化 (Data Pipeline)

为了让训练跑得更快,代码里加了一串“加速咒语”:

cache():把读取过的数据缓存在内存里,下次读就不用去硬盘找了。

shuffle(1000):打乱数据,防止模型死记硬背。

prefetch(buffer_size=AUTOTUNE):预取数据。简单说就是GPU 在吃(训练)的时候,CPU 已经在帮忙把下一口饭喂到嘴边了,这样 GPU 就不会闲着。

3. 核心:回调函数 (ModelCheckpoint)

背景:模型训练很多轮(比如 50 轮),往往中间某几轮的效果最好,但训练到最后反而过拟合变差了。如果我们只拿最后训练完的模型,就亏了。

解决方案:使用 callbacks.ModelCheckpoint。

关键参数:

filepath=‘best_model.h5’:给保存的模型起个名字。

monitor=‘val_accuracy’:死死盯着“验证集准确率”这个指标。

save_best_only=True:只保存最好的。 如果这一轮的分数没有上一轮高,就不存;如果破纪录了,就覆盖保存。

用法:在 model.fit() 里加上 callbacks=[checkpointer]。

4. 显存按需分配 (GPU Setup)

代码:tf.config.experimental.set_memory_growth(gpu, True)。

作用:防止 TensorFlow 一上来就霸占你全部的显卡显存。这行代码告诉它:“用多少拿多少,别太贪心”。(防止显存溢出报错)。

5. 归一化写入模型

代码:layers.Rescaling(1./255)。

位置:放在 Sequential 模型的第一层。

好处:把“像素除以 255”这个预处理步骤直接焊死在模型里。以后拿模型去预测新图片时,不用手动再除以 255 了,模型自己会处理。
在这里插入图片描述

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐