TensorFlow学习笔记:猴痘病识别
1. 数据的“高清化”处理图片尺寸变大:之前的任务(MNIST/CIFAR)图片都很小(28x28 或 32x32)。本期处理的是医疗影像,细节很重要,所以图片尺寸统一调整为 224x224。注意点:在定义模型输入层 (input_shape) 和加载数据 (image_size) 时,都要写成 (224, 224, 3)。2. 数据加载管道的优化 (Data Pipeline)为了让训练跑得更快
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
一、基础设置与导入数据
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
import matplotlib.pyplot as plt
import pathlib
import os
# 1. 设置 GPU
gpus = tf.config.list_physical_devices('GPU')
print("Found GPUs:", gpus)
# 2. 设置数据路径
data_dir = pathlib.Path("T4_data")
# 3. 加载数据
img_height = 224 # 这次图片比较大,保留细节
img_width = 224
batch_size = 32
# 训练集 (Training)
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
# 验证集 (Validation)
val_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
class_names = train_ds.class_names
print(f"检测到的类别: {class_names}") # 应该是 ['Monkeypox', 'Others']
# 4. 性能优化 (AUTOTUNE)
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
二、搭建CNN模型
num_classes = len(class_names)
model = models.Sequential([
# 第一步必须是归一化 把 0-255 变成 0-1
layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
# 卷积层 1
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
# 卷积层 2
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
# 卷积层 3
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
# 展平 + 全连接
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
model.summary()
三、编译与训练
# 1. 编译
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 2. 设置回调函数 (ModelCheckpoint)
# 它的作用:监控 val_accuracy,只有当准确率提升时,才保存模型
# save_best_only=True: 只要最好的,不要烂的
checkpointer = callbacks.ModelCheckpoint(
'best_model.h5',
monitor='val_accuracy',
verbose=1,
save_best_only=True
)
# 3. 训练 (把 callbacks 加进去)
# 这次跑 50 轮,因为我们有 checkpointer 兜底,不怕跑过头
epochs = 50
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs,
callbacks=[checkpointer] # <--- 记得带上它
)
四、评估结果
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(len(acc)) # 自动适应实际跑的轮数
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

五、总结
1. 数据的“高清化”处理
图片尺寸变大:
之前的任务(MNIST/CIFAR)图片都很小(28x28 或 32x32)。
本期处理的是医疗影像,细节很重要,所以图片尺寸统一调整为 224x224。
注意点:在定义模型输入层 (input_shape) 和加载数据 (image_size) 时,都要写成 (224, 224, 3)。
2. 数据加载管道的优化 (Data Pipeline)
为了让训练跑得更快,代码里加了一串“加速咒语”:
cache():把读取过的数据缓存在内存里,下次读就不用去硬盘找了。
shuffle(1000):打乱数据,防止模型死记硬背。
prefetch(buffer_size=AUTOTUNE):预取数据。简单说就是GPU 在吃(训练)的时候,CPU 已经在帮忙把下一口饭喂到嘴边了,这样 GPU 就不会闲着。
3. 核心:回调函数 (ModelCheckpoint)
背景:模型训练很多轮(比如 50 轮),往往中间某几轮的效果最好,但训练到最后反而过拟合变差了。如果我们只拿最后训练完的模型,就亏了。
解决方案:使用 callbacks.ModelCheckpoint。
关键参数:
filepath=‘best_model.h5’:给保存的模型起个名字。
monitor=‘val_accuracy’:死死盯着“验证集准确率”这个指标。
save_best_only=True:只保存最好的。 如果这一轮的分数没有上一轮高,就不存;如果破纪录了,就覆盖保存。
用法:在 model.fit() 里加上 callbacks=[checkpointer]。
4. 显存按需分配 (GPU Setup)
代码:tf.config.experimental.set_memory_growth(gpu, True)。
作用:防止 TensorFlow 一上来就霸占你全部的显卡显存。这行代码告诉它:“用多少拿多少,别太贪心”。(防止显存溢出报错)。
5. 归一化写入模型
代码:layers.Rescaling(1./255)。
位置:放在 Sequential 模型的第一层。
好处:把“像素除以 255”这个预处理步骤直接焊死在模型里。以后拿模型去预测新图片时,不用手动再除以 255 了,模型自己会处理。
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐



所有评论(0)