背景意义

随着医疗影像技术的迅猛发展,X光胸片作为一种常见的影像学检查手段,广泛应用于肺部疾病的筛查与诊断。胸片能够提供重要的生理信息,帮助医生识别和评估多种病症,如肺炎、肺结核、肿瘤等。然而,传统的胸片分析往往依赖于医生的经验,存在主观性强、效率低下等问题。因此,基于计算机视觉和深度学习的自动化图像分析技术逐渐成为研究热点,尤其是在器官图像分割领域。

YOLO(You Only Look Once)系列模型因其高效的实时检测能力而受到广泛关注。YOLOv8作为该系列的最新版本,进一步提升了检测精度和速度,适用于复杂的医疗影像分析任务。针对X光胸片的图像分割,改进YOLOv8模型不仅能够提高分割的准确性,还能在多种病症的早期筛查中发挥重要作用。通过对心脏、左右肺、脊柱及气管等关键器官的精确分割,医生能够更快地获取病灶信息,从而做出更为准确的诊断。

本研究所使用的数据集包含2200幅无异常发现的X光胸片图像,涵盖了心脏、左肺、右肺、脊柱和气管五个类别。这一数据集的构建为模型的训练和验证提供了丰富的样本基础,确保了分割模型在实际应用中的可靠性与有效性。通过对这些器官的精确分割,研究将为临床医生提供更为直观的影像分析工具,提升胸片的解读效率。

此外,基于改进YOLOv8的图像分割系统不仅在技术上具有创新性,也在临床应用中具有重要的现实意义。随着人口老龄化和呼吸系统疾病发病率的上升,X光胸片的使用频率逐年增加,传统的人工解读方式已难以满足日益增长的医疗需求。通过引入深度学习技术,自动化的图像分割系统能够显著减轻医生的工作负担,提高诊断效率,进而改善患者的就医体验。

综上所述,基于改进YOLOv8的X光胸片器官图像分割系统的研究,不仅为医疗影像分析提供了一种新的解决方案,也为相关领域的研究提供了新的思路。随着技术的不断进步,未来该系统有望在更广泛的临床场景中得到应用,为实现智能化医疗贡献力量。通过对X光胸片的自动化分析,推动医学影像学的发展,提升医疗服务的质量与效率,最终实现更为精准的个性化医疗。

图片效果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

数据集信息

在医学影像分析领域,尤其是针对X光胸片的研究,图像分割技术的应用日益受到重视。为了提升YOLOv8-seg模型在X光胸片器官图像分割任务中的表现,我们采用了名为“no-finding”的数据集。该数据集专门设计用于训练和评估图像分割算法,特别是针对胸部器官的精确分割。数据集包含五个主要类别,分别是心脏(Heart)、左肺(Left_Lung)、右肺(Right_Lung)、脊柱(Spine)和气管(Weasand)。这些类别涵盖了胸部X光片中最为关键的解剖结构,能够为后续的医学诊断和治疗提供重要的支持。

“no-finding”数据集的构建遵循了严格的医学标准,确保了数据的高质量和多样性。每个类别的图像均经过专业放射科医生的标注,确保了标注的准确性和一致性。这一过程不仅涉及到图像的选择,还包括对每个器官的边界进行精确的勾勒,以便于后续模型的训练。通过这种方式,数据集能够为YOLOv8-seg模型提供丰富的训练样本,使其在分割任务中具备更强的泛化能力。

在数据集的使用过程中,研究人员可以利用这些标注良好的图像进行监督学习。通过对模型进行反复训练和验证,研究者能够不断优化模型的参数,从而提升其在实际应用中的表现。此外,数据集中的图像样本涵盖了不同年龄、性别和病理状态的患者,使得模型在处理不同类型的X光胸片时,能够更好地适应各种临床场景。

为了进一步提高模型的鲁棒性,研究人员还可以通过数据增强技术对“no-finding”数据集进行扩展。这包括对图像进行旋转、缩放、翻转以及亮度和对比度的调整等操作。这些增强手段不仅能够增加训练样本的多样性,还能有效防止模型的过拟合现象,从而提升其在未知数据上的表现。

在模型训练完成后,研究人员将对其进行评估,主要通过计算分割精度、召回率和F1-score等指标来衡量模型的性能。这些评估指标将为模型的进一步优化提供重要依据,确保其在实际应用中能够达到预期的效果。

总之,“no-finding”数据集为改进YOLOv8-seg的X光胸片器官图像分割系统提供了坚实的基础。通过对该数据集的有效利用,研究人员不仅能够提升模型的分割精度,还能为医学影像分析领域的进一步研究奠定基础。随着技术的不断进步和数据集的不断完善,未来的X光胸片图像分割系统将能够更好地服务于临床实践,为患者的健康管理提供更加精准的支持。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

核心代码


```python
# Ultralytics YOLO 🚀, AGPL-3.0 license

from typing import List
from urllib.parse import urlsplit
import numpy as np

class TritonRemoteModel:
    """
    与远程Triton推理服务器模型交互的客户端。

    属性:
        endpoint (str): Triton服务器上模型的名称。
        url (str): Triton服务器的URL。
        triton_client: Triton客户端(HTTP或gRPC)。
        InferInput: Triton客户端的输入类。
        InferRequestedOutput: Triton客户端的输出请求类。
        input_formats (List[str]): 模型输入的数据类型。
        np_input_formats (List[type]): 模型输入的numpy数据类型。
        input_names (List[str]): 模型输入的名称。
        output_names (List[str]): 模型输出的名称。
    """

    def __init__(self, url: str, endpoint: str = '', scheme: str = ''):
        """
        初始化TritonRemoteModel。

        参数可以单独提供或从形式为
            <scheme>://<netloc>/<endpoint>/<task_name>的集合'url'参数中解析。

        参数:
            url (str): Triton服务器的URL。
            endpoint (str): Triton服务器上模型的名称。
            scheme (str): 通信方案('http'或'gRPC')。
        """
        # 如果没有提供endpoint和scheme,则从URL字符串解析所有参数
        if not endpoint and not scheme:
            splits = urlsplit(url)
            endpoint = splits.path.strip('/').split('/')[0]  # 获取模型名称
            scheme = splits.scheme  # 获取通信方案
            url = splits.netloc  # 获取服务器地址

        self.endpoint = endpoint  # 设置模型名称
        self.url = url  # 设置服务器URL

        # 根据通信方案选择Triton客户端
        if scheme == 'http':
            import tritonclient.http as client  # 导入HTTP客户端
            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)  # 获取模型配置
        else:
            import tritonclient.grpc as client  # 导入gRPC客户端
            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)['config']  # 获取模型配置

        # 按字母顺序排序输出名称
        config['output'] = sorted(config['output'], key=lambda x: x.get('name'))

        # 定义模型属性
        type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput  # 设置输出请求类
        self.InferInput = client.InferInput  # 设置输入类
        self.input_formats = [x['data_type'] for x in config['input']]  # 获取输入数据类型
        self.np_input_formats = [type_map[x] for x in self.input_formats]  # 转换为numpy数据类型
        self.input_names = [x['name'] for x in config['input']]  # 获取输入名称
        self.output_names = [x['name'] for x in config['output']]  # 获取输出名称

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        使用给定的输入调用模型。

        参数:
            *inputs (List[np.ndarray]): 输入数据。

        返回:
            List[np.ndarray]: 模型输出。
        """
        infer_inputs = []  # 存储推理输入
        input_format = inputs[0].dtype  # 获取输入数据类型
        for i, x in enumerate(inputs):
            # 如果输入数据类型与预期不符,则转换数据类型
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            # 创建InferInput对象并设置数据
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace('TYPE_', ''))
            infer_input.set_data_from_numpy(x)  # 从numpy数组设置数据
            infer_inputs.append(infer_input)  # 添加到推理输入列表

        # 创建InferRequestedOutput对象以请求输出
        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        # 调用Triton客户端进行推理
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        # 返回输出结果,转换为原始输入数据类型
        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]

代码分析

  1. 类的定义TritonRemoteModel类用于与Triton推理服务器进行交互,主要包含模型的基本信息和推理功能。
  2. 初始化方法:在__init__方法中,解析URL并初始化Triton客户端,获取模型配置并设置输入输出的相关属性。
  3. 调用方法__call__方法实现了模型的推理过程,接收输入数据,处理数据类型,并返回模型的输出结果。```
    这个文件定义了一个名为 TritonRemoteModel 的类,用于与远程的 Triton 推理服务器模型进行交互。Triton 是一个用于高效部署和推理机器学习模型的服务器。该类的主要功能是封装与 Triton 服务器的通信,简化模型的推理过程。

在类的文档字符串中,列出了该类的主要属性,包括模型的名称、服务器的 URL、Triton 客户端、输入输出的格式和名称等。这些属性为后续的推理操作提供了必要的信息。

类的构造函数 __init__ 接受三个参数:urlendpointscheme。如果没有提供 endpointscheme,则会从 url 中解析出这些信息。解析过程使用了 urlsplit 函数,将 URL 分解为不同的部分。接着,根据提供的通信方案(HTTP 或 gRPC),导入相应的 Triton 客户端,并获取模型的配置。

模型的输出名称会按照字母顺序进行排序,以便后续处理。构造函数还定义了输入输出的数据类型映射,确保输入数据的格式与模型的要求相匹配。输入和输出的名称及格式被存储在类的属性中,以便后续使用。

类的 __call__ 方法允许用户以函数的方式调用模型。该方法接受一个或多个 NumPy 数组作为输入,并返回模型的输出。首先,方法会检查输入数据的类型是否与模型要求的类型一致,如果不一致,则进行类型转换。然后,创建 InferInput 对象,将输入数据设置到 Triton 客户端中。接着,构建输出请求,并通过 Triton 客户端进行推理。最后,返回的输出会被转换为原始输入的类型,并以列表的形式返回。

总的来说,这个文件提供了一个简洁的接口,使得用户能够方便地与 Triton 推理服务器进行交互,进行模型推理操作。
```以下是经过简化和注释的核心代码部分,主要包含了 DCNv3Function 类及其前向和反向传播的实现,以及一些辅助函数。注释详细解释了每个部分的功能和作用。

import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

class DCNv3Function(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group, group_channels, offset_scale, im2col_step, remove_center):
        # 保存前向传播所需的参数到上下文中
        ctx.kernel_h = kernel_h
        ctx.kernel_w = kernel_w
        ctx.stride_h = stride_h
        ctx.stride_w = stride_w
        ctx.pad_h = pad_h
        ctx.pad_w = pad_w
        ctx.dilation_h = dilation_h
        ctx.dilation_w = dilation_w
        ctx.group = group
        ctx.group_channels = group_channels
        ctx.offset_scale = offset_scale
        ctx.im2col_step = im2col_step
        ctx.remove_center = remove_center

        # 准备前向传播的参数
        args = [input, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group, group_channels, offset_scale, ctx.im2col_step]
        if remove_center:
            args.append(remove_center)

        # 调用 DCNv3 的前向函数
        output = DCNv3.dcnv3_forward(*args)
        ctx.save_for_backward(input, offset, mask)  # 保存输入以备反向传播使用

        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        # 从上下文中获取保存的张量
        input, offset, mask = ctx.saved_tensors

        # 准备反向传播的参数
        args = [input, offset, mask, ctx.kernel_h, ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h, ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group, ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step]
        if ctx.remove_center:
            args.append(ctx.remove_center)

        # 调用 DCNv3 的反向函数
        grad_input, grad_offset, grad_mask = DCNv3.dcnv3_backward(*args)

        return grad_input, grad_offset, grad_mask, None, None, None, None, None, None, None, None, None, None, None, None

def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
    # 计算参考点的函数
    _, H_, W_, _ = spatial_shapes
    H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
    W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1

    # 生成参考点的网格
    ref_y, ref_x = torch.meshgrid(
        torch.linspace((dilation_h * (kernel_h - 1)) // 2 + 0.5, (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, H_out, dtype=torch.float32, device=device),
        torch.linspace((dilation_w * (kernel_w - 1)) // 2 + 0.5, (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, W_out, dtype=torch.float32, device=device))
    
    # 归一化参考点
    ref_y = ref_y.reshape(-1)[None] / H_
    ref_x = ref_x.reshape(-1)[None] / W_
    ref = torch.stack((ref_x, ref_y), -1).reshape(1, H_out, W_out, 1, 2)

    return ref

def dcnv3_core_pytorch(input, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group, group_channels, offset_scale, remove_center):
    # DCNv3 的核心实现,执行卷积操作
    if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
        raise ValueError('remove_center 仅与奇数的方形卷积核兼容。')

    # 对输入进行填充
    input = F.pad(input, [0, 0, pad_h, pad_h, pad_w, pad_w])
    N_, H_in, W_in, _ = input.shape
    _, H_out, W_out, _ = offset.shape

    # 获取参考点和生成膨胀网格
    ref = _get_reference_points(input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
    # 省略了生成膨胀网格的实现细节

    # 计算采样位置
    sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
    if remove_center:
        sampling_locations = remove_center_sampling_locations(sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
    
    # 进行采样并计算输出
    output = (sampling_input_ * mask).sum(-1).view(N_, group * group_channels, H_out * W_out)
    return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()

主要功能说明:

  1. DCNv3Function: 该类实现了 DCNv3 的前向和反向传播,利用了 PyTorch 的自定义函数功能。
  2. _get_reference_points: 计算卷积操作中所需的参考点,用于生成采样位置。
  3. dcnv3_core_pytorch: 实现了 DCNv3 的核心卷积操作,包括输入填充、参考点计算和最终输出的生成。

注意事项:

  • 代码中涉及到的 DCNv3 的具体实现细节(如 DCNv3.dcnv3_forwardDCNv3.dcnv3_backward)未在此处提供,假设这些函数在其他地方定义并实现。
  • 该实现依赖于 PyTorch 的自动求导机制,适用于深度学习模型的训练和推理。```
    这个程序文件 dcnv3_func.py 实现了 DCNv3(Deformable Convolutional Networks v3)中的一些核心功能,主要用于在深度学习模型中进行可变形卷积操作。该文件使用 PyTorch 框架,并包含前向和反向传播的实现。

首先,文件中引入了一些必要的库,包括 PyTorch 和一些功能性模块。DCNv3Function 类继承自 torch.autograd.Function,它定义了可变形卷积的前向和反向传播方法。通过使用 @custom_fwd@custom_bwd 装饰器,前向和反向传播的操作可以被自定义,以便更好地利用 GPU 加速。

forward 方法中,输入参数包括输入张量、偏移量、掩码以及卷积核的各种参数(如大小、步幅、填充等)。该方法首先将这些参数存储在上下文中,以便在反向传播时使用。然后,它调用 DCNv3.dcnv3_forward 函数执行前向计算,并返回输出结果。

backward 方法实现了反向传播,计算梯度。它从上下文中恢复保存的张量,并调用 DCNv3.dcnv3_backward 函数来计算输入、偏移量和掩码的梯度。

此外,文件中还定义了一些辅助函数,例如 _get_reference_points_generate_dilation_grids,用于生成参考点和膨胀网格。这些函数通过计算空间形状和卷积参数,生成在卷积操作中需要的采样位置。

remove_center_sampling_locations 函数用于移除中心采样位置,确保在某些情况下的采样位置符合要求。dcnv3_core_pytorch 函数则是实现可变形卷积的核心逻辑,包括输入的填充、参考点的计算、采样位置的生成以及最终的输出计算。

总的来说,这个文件实现了 DCNv3 中的可变形卷积操作,提供了前向和反向传播的功能,能够在深度学习模型中灵活应用,以提高模型对形状变化的适应能力。


```python
# 导入必要的库
from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr

try:
    import os
    import mlflow  # 导入 MLflow 库

    # 确保在测试环境中不记录日志
    assert not TESTS_RUNNING or 'test_mlflow' in os.environ.get('PYTEST_CURRENT_TEST', '')
    # 确保 MLflow 集成已启用
    assert SETTINGS['mlflow'] is True  
    assert hasattr(mlflow, '__version__')  # 确保 mlflow 包已正确导入

    PREFIX = colorstr('MLflow: ')  # 设置日志前缀

except (ImportError, AssertionError):
    mlflow = None  # 如果导入失败,则将 mlflow 设置为 None


def on_pretrain_routine_end(trainer):
    """
    在预训练例程结束时记录训练参数到 MLflow。

    参数:
        trainer (ultralytics.engine.trainer.BaseTrainer): 包含要记录的参数的训练对象。

    全局变量:
        mlflow: 用于记录的 MLflow 模块。

    环境变量:
        MLFLOW_TRACKING_URI: MLflow 跟踪的 URI,默认为 'runs/mlflow'。
        MLFLOW_EXPERIMENT_NAME: MLflow 实验的名称,默认为 trainer.args.project。
        MLFLOW_RUN: MLflow 运行的名称,默认为 trainer.args.name。
    """
    global mlflow

    # 获取跟踪 URI
    uri = os.environ.get('MLFLOW_TRACKING_URI') or str(RUNS_DIR / 'mlflow')
    LOGGER.debug(f'{PREFIX} tracking uri: {uri}')
    mlflow.set_tracking_uri(uri)  # 设置 MLflow 跟踪 URI

    # 设置实验和运行名称
    experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
    run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
    mlflow.set_experiment(experiment_name)  # 设置实验名称

    mlflow.autolog()  # 启用自动日志记录
    try:
        # 开始一个新的运行
        active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
        LOGGER.info(f'{PREFIX}logging run_id({active_run.info.run_id}) to {uri}')
        # 提示用户查看日志
        LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
        LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
        mlflow.log_params(dict(trainer.args))  # 记录训练参数
    except Exception as e:
        LOGGER.warning(f'{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n'
                       f'{PREFIX}WARNING ⚠️ Not tracking this run')


def on_fit_epoch_end(trainer):
    """在每个训练周期结束时记录训练指标到 MLflow。"""
    if mlflow:
        # 清理指标名称并记录
        sanitized_metrics = {k.replace('(', '').replace(')', ''): float(v) for k, v in trainer.metrics.items()}
        mlflow.log_metrics(metrics=sanitized_metrics, step=trainer.epoch)


def on_train_end(trainer):
    """在训练结束时记录模型工件。"""
    if mlflow:
        # 记录最佳模型和其他文件
        mlflow.log_artifact(str(trainer.best.parent))  # 记录最佳模型的目录
        for f in trainer.save_dir.glob('*'):  # 记录保存目录中的所有文件
            if f.suffix in {'.png', '.jpg', '.csv', '.pt', '.yaml'}:
                mlflow.log_artifact(str(f))

        mlflow.end_run()  # 结束当前运行
        LOGGER.info(f'{PREFIX}results logged to {mlflow.get_tracking_uri()}\n'
                    f"{PREFIX}disable with 'yolo settings mlflow=False'")


# 定义回调函数字典
callbacks = {
    'on_pretrain_routine_end': on_pretrain_routine_end,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_train_end': on_train_end} if mlflow else {}

代码说明:

  1. 导入库:导入必要的库和模块,包括 mlflow,用于记录实验的相关信息。
  2. 环境变量和设置:通过环境变量设置 MLflow 的跟踪 URI、实验名称和运行名称。
  3. 记录训练参数:在预训练结束时,记录训练参数到 MLflow。
  4. 记录训练指标:在每个训练周期结束时,记录当前的训练指标。
  5. 记录模型工件:在训练结束时,记录模型的相关文件和最佳模型。
  6. 回调函数:定义了在不同训练阶段调用的回调函数,用于自动记录相关信息。```
    这个程序文件是用于Ultralytics YOLO模型的MLflow日志记录功能。MLflow是一个开源平台,用于管理机器学习生命周期,包括实验跟踪、模型管理和部署等。该模块的主要作用是记录训练过程中的各种参数、指标和模型工件,以便后续分析和可视化。

文件开头包含了一些基本信息和使用说明,包括如何设置项目名称、运行名称、启动本地MLflow服务器以及如何终止正在运行的MLflow服务器实例。这些命令通过环境变量或参数传递给程序,以便用户可以灵活配置。

在代码中,首先导入了一些必要的模块和常量,包括日志记录器、运行目录、设置和颜色字符串。接着,尝试导入MLflow模块,并进行了一些基本的验证,如检查是否在测试环境中运行以及MLflow集成是否启用。如果导入失败或验证不通过,则将mlflow设置为None。

接下来的三个函数分别对应训练过程中的不同阶段:

  1. on_pretrain_routine_end:在预训练例程结束时记录训练参数。该函数根据环境变量和训练器的参数设置MLflow的跟踪URI、实验名称和运行名称。如果没有活动的运行,则启动一个新的MLflow运行。最后,它会记录训练器的参数。如果在初始化过程中出现异常,则会记录警告信息。

  2. on_fit_epoch_end:在每个训练周期结束时记录训练指标。它会清理指标的名称,去掉括号,并将指标值转换为浮点数,然后将这些指标记录到MLflow中。

  3. on_train_end:在训练结束时记录模型工件。它会记录最佳模型的目录及其他保存的文件(如图像、CSV、模型权重等),并结束当前的MLflow运行。

最后,代码定义了一个回调字典,包含了上述三个函数,只有在成功导入MLflow的情况下才会被定义。这样,程序在训练过程中可以自动调用这些回调函数,确保所有重要的信息都被记录到MLflow中,以便后续的分析和可视化。

源码文件

在这里插入图片描述

源码获取

欢迎大家点赞、收藏、关注、评论啦 、查看👇🏻获取联系方式👇🏻

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐