定制化训练DeepSeek模型:LoAR、COT推理与SFT技术应用

大模型机器学习算法

picture.image

向AI转型的程序员都关注公众号 机器学习AI算法工程

一. 前言介绍

本文内容:

  1. 模型加载与预处理

:详细讲解如何加载预训练模型、分词器,并处理输入数据集。 2. LoRA配置

:介绍如何使用LoRA技术配置模型,并高效进行微调,节省计算资源。 3. 训练过程

:展示了如何配置训练参数,使用SFTTrainer进行训练,并通过WandB记录训练日志。 4. 模型保存与评估

:如何保存微调后的模型,以及如何通过合适的评估集对模型进行验证。 5. 模型合并

:展示了如何通过加权平均的方式合并多个模型权重,得到一个更强大的模型。

1.1 项目背景

本文档描述了如何在MAC笔记本上对

DeepSeek-R1-Distill-Llama-1.5B

Qwen架构

进行高效微调,使用** transformers

进行数据处理,并结合

LoRA

技术进行模型微调,使用

WandB

监控训练过程,

ModelScope

下载模型。(训练数据量大约2w条左右)

  • 由于为MAC笔记本本地训练 无显卡支持 故而放弃(DeepSeek-R1-Distill-Qwen-7B Q wen)

下载的服务信息如下:

|

安装服务

|

版本名称

|

作用

| | --- | --- | --- | | Unsloth |

|

用于数据处理和模型微调。

| | Transformers |

|

Hugging Face 提供的模型库,用于加载和微调 DeepSeek-R1。

| | WandB |

|

用于训练过程的实时监控和可视化。

| | LoRA |

|

用于微调的低秩适应技术。

| | ModelScope |

|

用于下载 DeepSeek-R1-8b 模型。

| | python3.11 |

Python 3.11

|

用于执行 Python 脚本和训练任务。

|

1.2 LoRA和 QLoRA 简介

以下是 LoRA 和 QLoRA 的区别表格:

|

特性

|

LoRA (Low-Rank Adaptation)

|

QLoRA (Quantized LoRA)

| | --- | --- | --- | | 核心原理 |

通过低秩矩阵分解减少需要调整的参数量

|

在 LoRA 的基础上结合量化技术,进一步减少存储和计算需求

| | 主要优点 |

降低训练时需要调整的参数数量,提高微调效率

|

除了低秩矩阵,还通过量化减少内存占用,适用于资源有限的环境

| | 存储需求 |

较低,但不如 QLoRA 节省内存

|

显著减少内存使用,适合在内存受限的设备上使用

| | 计算效率 |

提高训练效率,减少计算资源消耗

|

量化后的低精度计算进一步提高了计算效率,降低了开销

| | 适用场景 |

计算资源有限但不需要极限压缩的场景

|

内存和计算资源极其有限的环境,特别是在边缘设备上使用

| | 适用硬件 |

适用于大多数硬件设备,尤其是高性能计算环境

|

特别适合内存有限的硬件,如边缘设备、低内存服务器等

|

1.3 LLaMA 架构和

Qwen 架构

|

特性

|

LLaMA 架构

|

Qwen 架构

| | --- | --- | --- | | 开发者 |

Meta(Facebook)

|

深度求索(DeepSeek)

| | 设计目标 |

高效、轻量化

|

中文优化、多语言支持

| | 参数量 |

7B、13B、33B、65B 等

|

7B、14B 等

| | 开源情况 |

开源

|

部分开源或未完全公开

| | 适用场景 |

资源有限的环境

|

中文任务、多语言任务

|

LLaMA 架构

  • 全称

:Large Language Model Meta AI(LLaMA)

  • 开发者

:由 Meta(原 Facebook)开发。

  • 特点

  • 高效性

:LLaMA 旨在以较少的参数量实现高性能,专注于优化计算效率。

  • 轻量化

:模型参数量相对较小(如 7B、13B、33B、65B),但通过高质量数据和训练方法,性能接近甚至超越更大的模型。

  • 开源

:Meta 发布了 LLaMA 的权重和代码,供研究社区使用。

  • 应用场景

  • 适合资源有限的环境,如本地部署或移动设备。
  • 适用于各种 NLP 任务,尤其是在生成、问答、文本分类等任务中,具有较好的性能和效率。

Qwen 架构

  • 开发者

:由中国的深度求索(DeepSeek)团队开发。

  • 特点

  • 定制化设计

:Qwen 可能是针对中文或特定任务优化的架构,具体细节未完全公开。

  • 多语言支持

:Qwen 系列模型通常对中文有较好的支持,同时在英文和多语言任务上也有不错的表现。

  • 参数量灵活

:Qwen 系列包括不同规模的模型(如 7B、14B 等),适合不同场景。

  • 应用场景

  • Qwen 适用于文本生成、自动化内容创作、对话系统、语音合成等任务。

二. 环境准备

2.1 Unsloth 安装(显卡版本-暂时不用)

  • Unsloth

是一个用于数据处理和模型微调的工具。您可以通过以下命令安装:

  • MAC不试用,需要显卡

      
          

        
   

 
        
 ##官网:https://github.com/unslothai/unsloth
 
        
   

 
        
   

 
        
 #01 创建项目,并设置python虚拟环境,python3.11版本
 
        
   

 
        
   

 
        
 #02 安装 unsloth(cpu版本)
 
        
   

 
        brew 
        
 install
 
         llvm(Homebrew clang version 
        
 19.1
 
        .
        
 7
 
        )
        
   

 
        echo '
        
 export
 
         PATH=
        
 "/opt/homebrew/opt/llvm/bin:$PATH"
 
        ' >> ~/.zshrc
        
   

 
        source ~/.zshrc
        
   

 
        
   

 
        pip 
        
 install
 
         torch
        
   

 
        pip 
        
 install
 
         numpy
        
   

 
        pip 
        
 install
 
        
 "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
 
        
   

 
        
   

 
        
   

 
        
   

 
        
 #03 版本检查
 
        
   

 
        python -c 
        
 "import torch; print(torch.\_\_version\_\_)"
 
        
   

 
        
 2.6
 
        .
        
 0
 
        
   

 
        
   

 
        
 #04 引用
 
        
   

 
        from unsloth import FastLanguageModel
        
   

 
        
   

 
        
   

 
        
   

 
        
   

 
        
   

 
      
    

安装完成后,您可以使用 Unsloth

进行数据的预处理、加载和微调模型。

  • 暂时不使用

      
          

        
   

 
        
 #01 linux 服务建议使用docker
 
        
   

 
        
   

 
        
   

 
        
 #02 拉取镜像
 
        
   

 
        docker pull modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-py310-torch2.3.1-1.22.2
        
   

 
        
   

 
        
 #03 启动
 
        
   

 
        
   

 
      
    

2.2 创建Python项目


      
          

        
   

 
        
 #01 环境是python3.11
 
        
   

 
        
   

 
        
 #02 项目目录
 
        
   

 
        Unsloth-DeepSeek-R1-8b/
        
   

 
        ├── data/                    
        
 # 存放训练数据、验证数据等
 
        
   

 
        │   ├── raw/                 
        
 # 原始数据
 
        
   

 
        │   └── processed/           
        
 # 预处理后的数据
 
        
   

 

 
        ├── models/                  
        
 # 存放模型文件
 
        
   

 
        │   ├── checkpoints/         
        
 # 存储训练过程中的模型检查点
 
        
   

 
        │   └── final\_model/         
        
 # 存储最终微调后的模型
 
        
   

 

 
        ├── scripts/                 
        
 # 存放训练脚本、数据处理脚本等
 
        
   

 
        │   ├── train.py             
        
 # 训练脚本
 
        
   

 
        │   ├── data\_preprocessing.py
        
 # 数据预处理脚本
 
        
   

 
        │   └── evaluate.py          
        
 # 评估脚本
 
        
   

 

 
        ├── logs/                    
        
 # 存放训练日志文件
 
        
   

 
        │   └── training\_logs.txt    
        
 # 训练过程中的日志
 
        
   

 

 
        ├── wandb/                   
        
 # 存放 wandb 相关的配置和记录
 
        
   

 
        │   └── wandb\_config.py      
        
 # wandb 配置文件
 
        
   

 

 
        ├── environment/             
        
 # 环境配置文件
 
        
   

 
        │   ├── requirements.txt     
        
 # 项目的 Python 依赖
 
        
   

 
        │   └── environment.yml      
        
 # 如果使用 Conda,可以创建一个环境配置文件
 
        
   

 

 
        ├── main.py                  
        
 # 主运行文件,启动训练或其他任务
 
        
   

 
        └── README.md                
        
 # 项目的描述文件,包含如何使用和运行的说明
 
        
   

 
        
   

 
        
   

 
        
 #03 创建目录
 
        
   

 
        
 # 创建子目录
 
        
   

 
        
 mkdir
 
         -p data/raw
        
   

 
        
 mkdir
 
         -p data/processed
        
   

 
        
 mkdir
 
         -p models/checkpoints
        
   

 
        
 mkdir
 
         -p models/final\_model
        
   

 
        
 mkdir
 
         -p scripts
        
   

 
        
 mkdir
 
         -p logs
        
   

 
        
 mkdir
 
         -p wandb
        
   

 
        
 mkdir
 
         -p environment
        
   

 
        
   

 
        
 # 创建文件
 
        
   

 
        
 touch
 
         scripts/train.py
        
   

 
        
 touch
 
         scripts/data\_preprocessing.py
        
   

 
        
 touch
 
         scripts/evaluate.py
        
   

 
        
 touch
 
         logs/training\_logs.txt
        
   

 
        
 touch
 
         wandb/wandb\_config.py
        
   

 
        
 touch
 
         environment/requirements.txt
        
   

 
        
 touch
 
         environment/environment.yml
        
   

 
        
 touch
 
         main.py
        
   

 
        
 touch
 
         README.md
        
   

 
      
    

2.3 python 依赖库


      
          

        
   

 
        
 #03 安装即可
 
        
   

 
        pip 
        
 install
 
         torch==
        
 2.6
 
        .
        
 0
 
         transformers datasets
        
   

 
        
   

 
        
 #03 更新证书(后续如果有pip网站使用https 会验证该证书)
 
        
   

 
        /Applications/Python\ 
        
 3.11
 
        /
        
 Install
 
        \ Certificates.
        
 command
 
        
   

 
        
   

 
      
    

2.2 LoRA peft 安装

LoRA 和 PEFT 的安装

  • LoRA

PEFT

是用于高效微调的技术。如果你想在 Mac 上使用这些技术来微调 DeepSeek 模型,你需要安装相关的依赖项。

  • PEFT 包含了 LoRA 的实现,并且它使得你能够通过修改模型的一部分参数来进行高效微调,从而不需要调整整个模型的权重。

      
          

        
   

 
        
 #01 安装 peft
 
        
   

 
        pip 
        
 install 
 
        peft
        
   

 
      
    

2.3 WandB 设置

WandB

是一个用于训练过程实时监控和可视化的工具。您可以通过以下步骤设置 WandB

  1. 注册并登录

WandB官网

。 2. 获取您的 API 密钥并配置环境变量:


      
          

        
   

 
        
 #01 aipkey (本人谷歌邮箱)
 
        
   

 
        
   

 
        
   

 
        
 #02 命令
 
        
   

 
        pip install wandb
        
   

 
        wandb login
        
   

 
        
   

 
        
 #02  运行文件
 
        
   

 
        
 import
 
         wandb  
        
 # 导入 wandb 库,用于跟踪和可视化实验
 
        
   

 
        
 import
 
         random  
        
 # 导入 random 库,用于生成随机数
 
        
   

 
        
   

 
        
 # 开始一个新的 wandb 运行来跟踪当前脚本
 
        
   

 
        wandb.init(
        
   

 
            
        
 # 设置 wandb 项目,所有与该运行相关的数据将被记录到这个项目中
 
        
   

 
            project=
        
 "my-awesome-project"
 
        ,  
        
 # 项目名称,你可以在 wandb 仪表盘中看到这个项目
 
        
   

 
            
        
   

 
            
        
 # 追踪超参数和运行的元数据
 
        
   

 
            config={
        
   

 
                
        
 "learning\_rate"
 
        : 
        
 0.02
 
        ,  
        
 # 设置学习率
 
        
   

 
                
        
 "architecture"
 
        : 
        
 "CNN"
 
        ,  
        
 # 模型架构(这里是卷积神经网络)
 
        
   

 
                
        
 "dataset"
 
        : 
        
 "CIFAR-100"
 
        ,  
        
 # 使用的数据集(这里是 CIFAR-100 数据集)
 
        
   

 
                
        
 "epochs"
 
        : 
        
 10
 
        ,  
        
 # 训练的轮数
 
        
   

 
            }
        
   

 
        )
        
   

 
        
   

 
        
 # 模拟训练过程
 
        
   

 
        epochs = 
        
 10
 
        
 # 总训练轮数
 
        
   

 
        offset = random.random() / 
        
 5
 
        
 # 生成一个小的随机偏移量,用于模拟训练过程中一些不确定性
 
        
   

 
        
   

 
        
 # 开始训练循环,模拟 2 到 10 轮的训练过程
 
        
   

 
        
 for
 
         epoch 
        
 in
 
        
 range
 
        (
        
 2
 
        , epochs):  
        
 # 从第二轮开始,到第 10 轮结束
 
        
   

 
            
        
 # 模拟准确率的变化,随着 epoch 的增加,准确率逐渐提升
 
        
   

 
            acc = 
        
 1
 
         - 
        
 2
 
         ** -epoch - random.random() / epoch - offset
        
   

 
            
        
   

 
            
        
 # 模拟损失的变化,随着 epoch 的增加,损失逐渐减少
 
        
   

 
            loss = 
        
 2
 
         ** -epoch + random.random() / epoch + offset
        
   

 
        
   

 
            
        
 # 使用 wandb 记录每一轮的准确率(acc)和损失值(loss)
 
        
   

 
            wandb.log({
        
 "acc"
 
        : acc, 
        
 "loss"
 
        : loss})
        
   

 
        
   

 
        
 # [可选] 结束 wandb 运行,确保数据被正确上传并完成记录
 
        
   

 
        wandb.finish()
        
   

 
        
   

 
        
   

 
        
   

 
      
    

2.4 modelscope pull 模型


      
          

        
   

 
        
 #01 安装modelscope 
 
        
   

 
        pip install modelscope
        
   

 
        
   

 
        
 #02 下载模型文件
 
        
   

 
        
 mkdir
 
         -p ./models/DeepSeek-R1-Distill-Llama-8B
        
   

 
        
 mkdir
 
         -p ./models/DeepSeek-R1-Distill-Qwen-1.5B
        
   

 
        
 mkdir
 
         -p ./models/DeepSeek-R1-Distill-Qwen-7B
        
   

 
        
   

 
        modelscope download --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B --local\_dir ./models/DeepSeek-R1-Distill-Llama-8B
        
   

 
        
   

 
        modelscope download --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --local\_dir ./models/DeepSeek-R1-Distill-Qwen-1.5B
        
   

 
        
   

 
        modelscope download --model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --local\_dir ./models/DeepSeek-R1-Distill-Qwen-7B
        
   

 
        
   

 
        
   

 
        
   

 
        modelscope download --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B --local\_dir ./DeepSeek-R1-Distill-Llama-8B
      
    

2.5 测试模型使用


      
          

        
   

 
        
 """
 
   

 
 
   

 
 
   

 
 训练前询问问题:
 
   

 
   皮质醇增多症患者在血浆ACTH明显升高且大剂量地塞米松抑制试验阳性的情况下,应考虑哪种疾病?
 
   

 
   
 
   

 
 训练后再次询问:
 
   

 
 
   

 
 
   

 
 scripts/test\_inference.py
 
   

 
 
   

 
 """
 
        
   

 
        
   

 
        
   

 
        
 import
 
         os
        
   

 
        
 from
 
         transformers 
        
 import
 
         AutoModelForCausalLM, AutoTokenizer
        
   

 
        
 import
 
         torch
        
   

 
        
   

 
        
 # 获取当前脚本的路径
 
        
   

 
        current\_dir = os.path.dirname(\_\_file\_\_)
        
   

 
        
   

 
        
 # 拼接模型和分词器路径
 
        
   

 
        model\_dir = os.path.join(current\_dir, 
        
 '..'
 
        , 
        
 'models'
 
        , 
        
 'DeepSeek-R1-Distill-Qwen-1.5B'
 
        )
        
   

 
        
   

 
        
 # 打印路径确认
 
        
   

 
        
 print
 
        (
        
 f"Model path: 
 
 {model\_dir}
 
 "
 
        )
        
   

 
        
   

 
        
 # 确保模型和分词器的路径存在
 
        
   

 
        
 if
 
        
 not
 
         os.path.exists(model\_dir):
        
   

 
            
        
 raise
 
         ValueError(
        
 f"Model directory does not exist at 
 
 {model\_dir}
 
 "
 
        )
        
   

 
        
 else
 
        :
        
   

 
            
        
 print
 
        (
        
 "Model directory exists, proceeding with loading."
 
        )
        
   

 
        
   

 
        
 # 加载模型和分词器
 
        
   

 
        
 print
 
        (
        
 "Loading model and tokenizer..."
 
        )
        
   

 
        model = AutoModelForCausalLM.from\_pretrained(model\_dir)
        
   

 
        tokenizer = AutoTokenizer.from\_pretrained(model\_dir)
        
   

 
        
   

 
        
 # 打印模型和分词器的配置信息
 
        
   

 
        
 print
 
        (
        
 f"Model config: 
 
 {model.config}
 
 "
 
        )
        
   

 
        
 print
 
        (
        
 f"Tokenizer config: 
 
 {tokenizer}
 
 "
 
        )
        
   

 
        
   

 
        
 # 输入中文文本
 
        
   

 
        input\_text = 
        
 "皮质醇增多症患者在血浆ACTH明显升高且大剂量地塞米松抑制试验阳性的情况下,应考虑哪种疾病?"
 
        
   

 
        
 print
 
        (
        
 f"User input: 
 
 {input\_text}
 
 "
 
        )
        
   

 
        
   

 
        
 # 结构化的 prompt
 
        
   

 
        prompt\_style\_chat = 
        
 f"""请写出一个恰当的回答来完成当前对话任务。
 
   

 
 
   

 
 ### Instruction:
 
   

 
 你是一名助人为乐的助手。
 
   

 
 
   

 
 ### Question:
 
   

 
 
 {input\_text}
 
 
   

 
 
   

 
 ### Response:
 
   

 
 <think>"""
 
        
   

 
        
   

 
        
 # 使用分词器处理输入文本
 
        
   

 
        inputs = tokenizer(prompt\_style\_chat, return\_tensors=
        
 "pt"
 
        , padding=
        
 True
 
        , truncation=
        
 True
 
        , max\_length=
        
 512
 
        )
        
   

 
        
   

 
        
 # 打印 tokenized 输入
 
        
   

 
        
 print
 
        (
        
 f"Tokenized input: 
 
 {inputs}
 
 "
 
        )
        
   

 
        
   

 
        
 # 打印输入形状
 
        
   

 
        
 print
 
        (
        
 f"Input shape: 
 
 {inputs[
 
 'input\_ids'
 
 ].shape}
 
 "
 
        )
        
   

 
        
   

 
        
 # 打印模型的最大长度
 
        
   

 
        
 print
 
        (
        
 f"Model max length: 
 
 {model.config.max\_position\_embeddings}
 
 "
 
        )
        
   

 
        
   

 
        
 # 将模型移至正确的设备(使用 GPU 如果可用)
 
        
   

 
        device = 
        
 "cuda"
 
        
 if
 
         torch.cuda.is\_available() 
        
 else
 
        
 "cpu"
 
        
   

 
        model.to(device)
        
   

 
        
   

 
        
 # 打印设备信息
 
        
   

 
        
 print
 
        (
        
 f"Using device: 
 
 {device}
 
 "
 
        )
        
   

 
        
   

 
        
 # 打印分词器的 pad\_token\_id
 
        
   

 
        pad\_token\_id = tokenizer.pad\_token\_id 
        
 if
 
         tokenizer.pad\_token\_id 
        
 is
 
        
 not
 
        
 None
 
        
 else
 
         model.config.pad\_token\_id
        
   

 
        
 print
 
        (
        
 f"Using pad\_token\_id: 
 
 {pad\_token\_id}
 
 "
 
        )
        
   

 
        
   

 
        
 # 生成模型输出
 
        
   

 
        
 print
 
        (
        
 "Generating response..."
 
        )
        
   

 
        
 # 使用 max\_new\_tokens 来控制生成长度
 
        
   

 
        
 with
 
         torch.no\_grad():  
        
 # 禁用梯度计算,节省内存
 
        
   

 
            
        
 try
 
        :
        
   

 
                
        
 print
 
        (
        
 "Calling model.generate()..."
 
        )
        
   

 
                outputs = model.generate(
        
   

 
                    inputs[
        
 'input\_ids'
 
        ].to(device),
        
   

 
                    attention\_mask=inputs[
        
 'attention\_mask'
 
        ].to(device),
        
   

 
                    max\_new\_tokens=
        
 1200
 
        ,  
        
 # 设置最大生成的 token 数量
 
        
   

 
                    temperature=
        
 1.0
 
        ,
        
   

 
                    top\_p=
        
 0.9
 
        ,
        
   

 
                    pad\_token\_id=pad\_token\_id
        
   

 
                )
        
   

 
        
   

 
                
        
 print
 
        (
        
 "Model.generate() completed."
 
        )
        
   

 
            
        
 except
 
         Exception 
        
 as
 
         e:
        
   

 
                
        
 print
 
        (
        
 f"Error generating response: 
 
 {e}
 
 "
 
        )
        
   

 
                
        
 raise
 
        
   

 
        
   

 
        
 # 打印生成的输出 ID 和它们的形状
 
        
   

 
        
 print
 
        (
        
 f"Generated output IDs: 
 
 {outputs}
 
 "
 
        )
        
   

 
        
 print
 
        (
        
 f"Shape of generated output: 
 
 {outputs.shape}
 
 "
 
        )
        
   

 
        
   

 
        
 # 解码生成的输出文本
 
        
   

 
        
 try
 
        :
        
   

 
            response = tokenizer.decode(outputs[
        
 0
 
        ], skip\_special\_tokens=
        
 True
 
        )
        
   

 
            
        
 print
 
        (
        
 f"Generated response: 
 
 {response}
 
 "
 
        )
        
   

 
        
 except
 
         Exception 
        
 as
 
         e:
        
   

 
            
        
 print
 
        (
        
 f"Error decoding output: 
 
 {e}
 
 "
 
        )
        
   

 
        
   

 
        
   

 
        
   

 
      
    
  • 问题回答

      
          

        
 User input:
 
         
        
 皮质醇增多症患者在血浆ACTH明显升高且大剂量地塞米松抑制试验阳性的情况下,应考虑哪种疾病?
 
        
   

 
        
 Tokenized input:
 
         {
        
 'input\_ids':
 
        
 tensor(
 
        [[
        
 151646
 
        ,  
        
 14880
 
        , 
        
 112672
 
        ,  
        
 46944
 
        , 
        
 112449
 
        , 
        
 111423
 
        ,  
        
 36407
 
        ,  
        
 60548
 
        ,  
        
 67949
 
        ,
        
   

 
                 
        
 105051
 
        ,  
        
 88802
 
        ,   
        
 3407
 
        ,  
        
 14374
 
        ,  
        
 29051
 
        ,    
        
 510
 
        ,  
        
 56568
 
        , 
        
 110124
 
        ,  
        
 99262
 
        ,
        
   

 
                 
        
 103247
 
        ,  
        
 99350
 
        ,   
        
 9370
 
        , 
        
 110498
 
        ,   
        
 3407
 
        ,  
        
 14374
 
        ,  
        
 15846
 
        ,    
        
 510
 
        ,  
        
 99888
 
        ,
        
   

 
                  
        
 99178
 
        , 
        
 103032
 
        , 
        
 107284
 
        ,  
        
 99769
 
        , 
        
 101924
 
        ,  
        
 18493
 
        ,  
        
 99389
 
        , 
        
 101498
 
        ,   
        
 6823
 
        ,
        
   

 
                     
        
 39
 
        , 
        
 100687
 
        , 
        
 109061
 
        , 
        
 100136
 
        ,  
        
 26288
 
        , 
        
 114786
 
        ,  
        
 29490
 
        , 
        
 101202
 
        ,  
        
 72261
 
        ,
        
   

 
                 
        
 100180
 
        , 
        
 106555
 
        , 
        
 102360
 
        , 
        
 112758
 
        , 
        
 104248
 
        ,   
        
 3837
 
        ,  
        
 50511
 
        , 
        
 101118
 
        , 
        
 113195
 
        ,
        
   

 
                 
        
 101160
 
        ,  
        
 26850
 
        ,  
        
 14374
 
        ,   
        
 5949
 
        ,    
        
 510
 
        , 
        
 151648
 
        ]]
        
 )
 
        , 
        
 'attention\_mask':
 
        
 tensor(
 
        [[
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        ,
        
   

 
                 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        ,
        
   

 
                 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        , 
        
 1
 
        ]]
        
 )
 
        }
        
   

 
        
 Input shape:
 
        
 torch.Size([1,
 
        
 60
 
        
 ])
 
        
   

 
        
 Model max length:
 
        
 131072
 
        
   

 
        
 Using device:
 
        
 cpu
 
        
   

 
        
 Using pad\_token\_id:
 
        
 151643
 
        
   

 
        
 Generating
 
        
 response...
 
        
   

 
        
 Calling
 
        
 model.generate()...
 
        
   

 
        
 Model.generate()
 
        
 completed.
 
        
   

 
        
   

 
        
 Generated response:
 
        
 请写出一个恰当的回答来完成当前对话任务。
 
        
   

 
        
   

 
        
 ### Instruction:
 
        
   

 
        
 你是一名助人为乐的助手。
 
        
   

 
        
   

 
        
 ### Question:
 
        
   

 
        
 皮质醇增多症患者在血浆ACTH明显升高且大剂量地塞米松抑制试验阳性的情况下,应考虑哪种疾病?
 
        
   

 
        
   

 
        
 ### Response:
 
        
   

 
        
 <think>
 
        
   

 
        
 好的,我现在需要仔细分析这个问题并给出一个合适的回答。首先,问题描述的是皮质醇增多症(PHT)患者在血浆ACTH明显升高且大剂量地塞米松抑制试验(SSDS)显示阳性的情况下,应考虑哪种疾病。
 
        
   

 
        
   

 
        
 首先,我记得皮质醇增多症是由于皮质醇分泌异常导致,通常由代谢紊乱或神经退行性疾病引起,比如皮质醇过激释放症、皮质醇过激释放性代谢综合征等。通常,患者可能表现出皮质醇水平升高,血浆ACTH显著升高,这符合题意的第一个条件。
 
        
   

 
        
   

 
        
 接下来,第二个条件是SSDS试验阳性。SSDS试验主要用于检测皮质醇释放的细胞因子,比如PD-L1,这些因子在疾病早期有显著的表观变化。皮质醇增多症患者的皮质醇释放确实受阻,导致细胞因子释放减少,这在SSDS中会被检测出来,所以这种情况属于皮质醇增多症。
 
        
   

 
        
   

 
        
 综合这两个条件,患者的血浆ACTH升高和SSDS阳性,符合皮质醇增多症的特征。因此,这种情况下应考虑的是皮质醇增多症。
 
        
   

 
        
   

 
        
 我需要确保我没有遗漏其他可能导致SSDS试验阳性的情况。比如,是否有一些其他类型的疾病,比如胰岛素素合成障碍或胰岛素缺乏,也会影响皮质醇释放?不过,这些更可能是胰岛素素合成障碍,而不是直接由皮质醇释放受阻引起的。皮质醇增多症通常是由于皮质醇释放异常,因此SSDS阳性更直接与皮质醇释放受阻相关。
 
        
   

 
        
   

 
        
 此外,ACTH升高可能与皮质醇增多症不同,而更可能是由于激素分泌过量或其他激素调节问题。因此,ACTH升高的信号应该更多指向皮质醇增多症。
 
        
   

 
        
   

 
        
 综上所述,这种情况下应该考虑的疾病是皮质醇增多症。
 
        
   

 
        
 </think>
 
        
   

 
        
   

 
        
 应考虑皮质醇增多症(Pantoprazolidone
 
        
 Phenomenon)。
 
        
   

 
        
   

 
        
 因为:
 
        
   

 
        
   

 
        
 1
 
        
 .
 
        
 血浆ACTH显著升高,符合皮质醇增多症的特征。
 
        
   

 
        
 2
 
        
 .
 
        
 SSDS试验阳性,表明皮质醇释放受阻,属于皮质醇增多症的表现。
 
        
   

 
        
   

 
        
   

 
        
      
    
三. 训练数据数据

3.1 准备数据集


      
          

        
   

 
        
 #01 我们使用COT格式 医学领域 medical-o1-reasoning-SFT 数据集
 
        
   

 
        https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT
        
   

 
        
   

 
        
 #02 b本地导入方式()
 
        
   

 
        from datasets import load\_dataset
        
   

 
        ds = load\_dataset(
        
 "FreedomIntelligence/medical-o1-reasoning-SFT"
 
        , 
        
 "zh"
 
        )
        
   

 
        
   

 
      
    
  • Hugging face 数据集

  • modelscope


      
          

        
   

 
        
 #01 使用modelscope 数据集 官网地址
 
        
   

 
        https://www.modelscope.cn/datasets/YIRONGCHEN/PsyDTCorpus/files
        
   

 
        
   

 
        
 #02 下载完整数据集repo
 
        
   

 
        modelscope download --dataset YIRONGCHEN/PsyDTCorpus --local\_dir ./dir
        
   

 
        
   

 
        
   

 
        
 #03 下载单个文件到指定本地文件夹(以下载README.md到当前路径下“dir”目录为例)
 
        
   

 
        modelscope download --dataset YIRONGCHEN/PsyDTCorpus README.md --local\_dir ./dir
        
   

 
        
   

 
      
    

3.2 数据清洗


      
          

        
   

 
        
 #01 用于对medical-o1-reasoning-SFT数据集进行修改,Complex\_CoT列和Response列进行拼接,并加上文本结束标记:
 
        
   

 
        
 def
 
        
 formatting\_prompts\_func
 
        (
        
 examples, EOS\_TOKEN
 
        ):
        
   

 
            
        
 """
 
   

 
     格式化数据集中的每个示例,使其符合训练的要求。
 
   

 
 
   

 
     Args:
 
   

 
         examples (dict): 数据集中的输入示例
 
   

 
         EOS\_TOKEN (str): 结束符
 
   

 
 
   

 
     Returns:
 
   

 
         dict: 格式化后的文本数据
 
   

 
     """
 
        
   

 
            train\_prompt\_style = 
        
 """Below is an instruction that describes a task, paired with an input that provides further context. 
 
   

 
     Write a response that appropriately completes the request. 
 
   

 
     Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
 
   

 
 
   

 
     ### Instruction:
 
   

 
     You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
 
   

 
     Please answer the following medical question. 
 
   

 
 
   

 
     ### Question:
 
   

 
     {}
 
   

 
 
   

 
     ### Response:
 
   

 
     <think>
 
   

 
     {}
 
   

 
     </think>
 
   

 
     {}"""
 
        
   

 
        
   

 
            inputs = examples[
        
 "Question"
 
        ]
        
   

 
            cots = examples[
        
 "Complex\_CoT"
 
        ]
        
   

 
            outputs = examples[
        
 "Response"
 
        ]
        
   

 
            texts = []
        
   

 
            
        
 for
 
        
 input
 
        , cot, output 
        
 in
 
        
 zip
 
        (inputs, cots, outputs):
        
   

 
                text = train\_prompt\_style.
        
 format
 
        (
        
 input
 
        , cot, output) + EOS\_TOKEN
        
   

 
                texts.append(text)
        
   

 
            
        
 return
 
         {
        
 "text"
 
        : texts}
        
   

 
        
   

 
        
   

 
        
   

 
        
 """
 
   

 
 
   

 
 问题({}) 被嵌套到 ### Question: 下面,替换掉 {}。
 
   

 
 推理过程({}) 被嵌套到 <think></think> 标签内,替换掉第二个 {}。
 
   

 
 答案({}) 被嵌套到模板的最后,替换掉第三个 {}。
 
   

 
 具体替换流程:
 
   

 
 {} 第一个位置将会被每个样本中的问题(examples["Question"])替换。
 
   

 
 {} 第二个位置将会被每个样本中的推理过程(examples["Complex\_CoT"])替换。
 
   

 
 {} 第三个位置将会被每个样本中的答案(examples["Response"])替换。
 
   

 
 例如,如果输入数据如下:
 
   

 
 
   

 
 问题(Question): "What is the cause of fever?"
 
   

 
 推理过程(Complex\_CoT): "Fever is usually caused by an infection or inflammation. We need to identify the source."
 
   

 
 答案(Response): "The most common causes of fever are bacterial or viral infections."
 
   

 
 
   

 
 """
 
        
   

 
        
   

 
          
      
    
  • 原数据格式

      
          

        {
        
   

 
            
        
 "Question"
 
        : [
        
   

 
                
        
 "What is the cause of headache?"
 
        ,
        
   

 
                
        
 "How do you treat a cold?"
 
        
   

 
            ],
        
   

 
            
        
 "Complex\_CoT"
 
        : [
        
   

 
                
        
 "The causes of headaches are numerous, including tension, dehydration, or sinus issues."
 
        ,
        
   

 
                
        
 "Treating a cold typically involves rest, fluids, and over-the-counter medications for symptoms."
 
        
   

 
            ],
        
   

 
            
        
 "Response"
 
        : [
        
   

 
                
        
 "A headache can be caused by stress, lack of sleep, or a sinus infection."
 
        ,
        
   

 
                
        
 "For a cold, hydration and rest are key. Medications like ibuprofen can help with symptoms."
 
        
   

 
            ]
        
   

 
        }
        
   

 
      
    
  • 格式化后数据

      
          

        {
        
   

 
            
        
 "text"
 
        : [
        
   

 
                
        
 ""
 
        
 "Below is an instruction that describes a task, paired with an input that provides further context. 
 
   

 
         Write a response that appropriately completes the request. 
 
   

 
         Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
 
   

 
 
   

 
         ### Instruction:
 
   

 
         You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
 
   

 
         Please answer the following medical question. 
 
   

 
 
   

 
         ### Question:
 
   

 
         What is the cause of headache?
 
   

 
 
   

 
         ### Response:
 
   

 
         <think>
 
   

 
         The causes of headaches are numerous, including tension, dehydration, or sinus issues.
 
   

 
         </think>
 
   

 
         A headache can be caused by stress, lack of sleep, or a sinus infection. <|endoftext|>
 
   

 
         "
 
        
 ""
 
        ,
        
   

 
                
        
 ""
 
        
 "Below is an instruction that describes a task, paired with an input that provides further context. 
 
   

 
         Write a response that appropriately completes the request. 
 
   

 
         Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
 
   

 
 
   

 
         ### Instruction:
 
   

 
         You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
 
   

 
         Please answer the following medical question. 
 
   

 
 
   

 
         ### Question:
 
   

 
         How do you treat a cold?
 
   

 
 
   

 
         ### Response:
 
   

 
         <think>
 
   

 
         Treating a cold typically involves rest, fluids, and over-the-counter medications for symptoms.
 
   

 
         </think>
 
   

 
         For a cold, hydration and rest are key. Medications like ibuprofen can help with symptoms. <|endoftext|>
 
   

 
         "
 
        
 ""
 
        
   

 
            ]
        
   

 
        }
        
   

 
      
    

3.3 训练数据

  1. setup_wandb

: 配置并登录到 wandb

进行实验跟踪和日志记录。 2. set_paths

: 设置根目录、模型路径、数据集路径和保存微调模型的路径。 3. load_model_and_tokenizer

: 加载预训练模型和分词器,获取结束符。 4. formatting_prompts_func

: 格式化数据集中的问题和回答,以便训练。 5. setup_lora

: 配置并应用LoRA(低秩适配器)到模型。 6. load_dataset_func

: 加载数据集并进行切分,返回训练集和评估集。 7. setup_training_args

: 设置训练参数,包括学习率、批处理大小、训练周期等。 8. train_model

: 使用 SFTTrainer

进行模型训练。 9. save_model

: 保存训练好的模型到指定路径。


      
          

        
 import
 
         os
        
   

 
        
 import
 
         torch
        
   

 
        
 from
 
         transformers 
        
 import
 
         AutoModelForCausalLM, AutoTokenizer, TrainingArguments
        
   

 
        
 from
 
         datasets 
        
 import
 
         load\_dataset
        
   

 
        
 from
 
         peft 
        
 import
 
         get\_peft\_model, LoraConfig
        
   

 
        
 from
 
         trl 
        
 import
 
         SFTTrainer  
        
 # 使用 SFTTrainer
 
        
   

 
        
 import
 
         wandb
        
   

 
        
 from
 
         config 
        
 import
 
         setting
        
   

 
        
   

 
        
 # 设置环境变量,禁用tokenizer的并行化
 
        
   

 
        os.environ[
        
 "TOKENIZERS\_PARALLELISM"
 
        ] = 
        
 "false"
 
        
   

 
        
   

 
        
   

 
        
 # 登录wandb
 
        
   

 
        
 def
 
        
 setup\_wandb
 
        ():
        
   

 
            
        
 """
 
   

 
     登录到wandb以便记录训练过程中的日志和指标。
 
   

 
     """
 
        
   

 
            wandb.login()
        
   

 
        
   

 
        
   

 
        
 # 设置路径
 
        
   

 
        
 def
 
        
 set\_paths
 
        ():
        
   

 
            
        
 """
 
   

 
     设置项目根目录、模型路径、数据集路径和最终模型保存路径。
 
   

 
 
   

 
     Returns:
 
   

 
         model\_dir (str): 模型文件路径
 
   

 
         dataset\_path (str): 数据集路径
 
   

 
         final\_model\_dir (str): 微调后模型的保存路径
 
   

 
     """
 
        
   

 
            root\_dir = setting.root\_dir  
        
 # 项目根路径
 
        
   

 
            model\_dir = os.path.join(root\_dir, 
        
 'models'
 
        , 
        
 'DeepSeek-R1-Distill-Qwen-1.5B'
 
        )  
        
 # 模型文件路径
 
        
   

 
            dataset\_path = os.path.join(root\_dir, 
        
 'data'
 
        , 
        
 'medical-o1-reasoning-SFT'
 
        )  
        
 # 数据集路径
 
        
   

 
            final\_model\_dir = os.path.join(root\_dir, 
        
 'models'
 
        , 
        
 'final\_model'
 
        )  
        
 # 高效微调后模型保存路径
 
        
   

 
            
        
 print
 
        (
        
 f'设置模型路径:
 
 {model\_dir}
 
  | 数据集位置:
 
 {dataset\_path}
 
 '
 
        )
        
   

 
            
        
 return
 
         model\_dir, dataset\_path, final\_model\_dir
        
   

 
        
   

 
        
   

 
        
 # 加载模型和分词器
 
        
   

 
        
 def
 
        
 load\_model\_and\_tokenizer
 
        (
        
 model\_dir
 
        ):
        
   

 
            
        
 """
 
   

 
     加载预训练模型和对应的分词器,并获取结束符(EOS\_TOKEN)。
 
   

 
 
   

 
     Args:
 
   

 
         model\_dir (str): 模型的文件路径
 
   

 
 
   

 
     Returns:
 
   

 
         model (AutoModelForCausalLM): 加载的模型
 
   

 
         tokenizer (AutoTokenizer): 加载的分词器
 
   

 
         EOS\_TOKEN (str): 模型的结束符(如果没有,使用默认值)
 
   

 
     """
 
        
   

 
            
        
 print
 
        (
        
 "加载分词器:Loading model and tokenizer..."
 
        )
        
   

 
            model = AutoModelForCausalLM.from\_pretrained(model\_dir)
        
   

 
            tokenizer = AutoTokenizer.from\_pretrained(model\_dir)
        
   

 
        
   

 
            EOS\_TOKEN = tokenizer.eos\_token
        
   

 
            
        
 if
 
         EOS\_TOKEN 
        
 is
 
        
 None
 
        :
        
   

 
                EOS\_TOKEN = 
        
 "<|endoftext|>"
 
        
   

 
        
   

 
            
        
 print
 
        (
        
 f'结束符:
 
 {EOS\_TOKEN}
 
 '
 
        )
        
   

 
            
        
 return
 
         model, tokenizer, EOS\_TOKEN
        
   

 
        
   

 
        
   

 
        
 # 格式化训练数据
 
        
   

 
        
 def
 
        
 formatting\_prompts\_func
 
        (
        
 examples, EOS\_TOKEN
 
        ):
        
   

 
            
        
 """
 
   

 
     格式化数据集中的每个示例,使其符合训练的要求。
 
   

 
 
   

 
     Args:
 
   

 
         examples (dict): 数据集中的输入示例
 
   

 
         EOS\_TOKEN (str): 结束符
 
   

 
 
   

 
     Returns:
 
   

 
         dict: 格式化后的文本数据
 
   

 
     """
 
        
   

 
            train\_prompt\_style = 
        
 """Below is an instruction that describes a task, paired with an input that provides further context. 
 
   

 
     Write a response that appropriately completes the request. 
 
   

 
     Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
 
   

 
 
   

 
     ### Instruction:
 
   

 
     You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
 
   

 
     Please answer the following medical question. 
 
   

 
 
   

 
     ### Question:
 
   

 
     {}
 
   

 
 
   

 
     ### Response:
 
   

 
     <think>
 
   

 
     {}
 
   

 
     </think>
 
   

 
     {}"""
 
        
   

 
        
   

 
            inputs = examples[
        
 "Question"
 
        ]
        
   

 
            cots = examples[
        
 "Complex\_CoT"
 
        ]
        
   

 
            outputs = examples[
        
 "Response"
 
        ]
        
   

 
            texts = []
        
   

 
            
        
 for
 
        
 input
 
        , cot, output 
        
 in
 
        
 zip
 
        (inputs, cots, outputs):
        
   

 
                text = train\_prompt\_style.
        
 format
 
        (
        
 input
 
        , cot, output) + EOS\_TOKEN
        
   

 
                texts.append(text)
        
   

 
            
        
 return
 
         {
        
 "text"
 
        : texts}
        
   

 
        
   

 
        
   

 
        
 # 设置LoRA配置
 
        
   

 
        
 def
 
        
 setup\_lora
 
        (
        
 model
 
        ):
        
   

 
            
        
 """
 
   

 
     设置LoRA(低秩适配器)配置,并将其应用到模型。
 
   

 
 
   

 
     Args:
 
   

 
         model (AutoModelForCausalLM): 加载的模型
 
   

 
 
   

 
     Returns:
 
   

 
         model (AutoModelForCausalLM): 应用LoRA后的模型
 
   

 
     """
 
        
   

 
            
        
 print
 
        (
        
 "设置LoRA: Setting up LoRA configuration..."
 
        )
        
   

 
            lora\_config = LoraConfig(
        
   

 
                r=
        
 8
 
        ,  
        
 # adapter的秩
 
        
   

 
                lora\_alpha=
        
 32
 
        ,  
        
 # 缩放因子
 
        
   

 
                lora\_dropout=
        
 0.1
 
        ,  
        
 # LoRA层的dropout
 
        
   

 
                bias=
        
 "none"
 
        ,  
        
 # LoRA的偏置项
 
        
   

 
            )
        
   

 
            
        
 return
 
         get\_peft\_model(model, lora\_config)
        
   

 
        
   

 
        
   

 
        
 # 加载数据集
 
        
   

 
        
 def
 
        
 load\_dataset\_func
 
        (
        
 dataset\_path, train\_size=
 
 100
 
 
        ):
        
   

 
            
        
 """
 
   

 
     从指定路径加载数据集,训练集大小为 train\_size,评估集为训练集的10%,但至少为1。
 
   

 
     """
 
        
   

 
            
        
 print
 
        (
        
 f"从 
 
 {dataset\_path}
 
  加载数据集..."
 
        )
        
   

 
            
        
 # 加载数据集
 
        
   

 
            dataset = load\_dataset(dataset\_path, 
        
 "en"
 
        , split=
        
 "train"
 
        , trust\_remote\_code=
        
 True
 
        )
        
   

 
        
   

 
            
        
 # 计算评估集大小
 
        
   

 
            eval\_size = 
        
 max
 
        (
        
 1
 
        , 
        
 int
 
        (train\_size * 
        
 0.1
 
        ))  
        
 # 评估集大小是训练集的10%,但至少为1
 
        
   

 
        
   

 
            
        
 # 切分数据集
 
        
   

 
            train\_dataset = dataset.select(
        
 range
 
        (train\_size))  
        
 # 使用前 train\_size 条作为训练集
 
        
   

 
            eval\_dataset = dataset.select(
        
 range
 
        (train\_size, train\_size + eval\_size))  
        
 # 剩余部分作为评估集
 
        
   

 
        
   

 
            
        
 print
 
        (
        
 f"训练集大小: 
 
 {
 
 len
 
 (train\_dataset)}
 
 , 评估集大小: 
 
 {
 
 len
 
 (eval\_dataset)}
 
 "
 
        )
        
   

 
            
        
 return
 
         train\_dataset, eval\_dataset
        
   

 
        
   

 
        
   

 
        
 # 配置训练参数
 
        
   

 
        
 def
 
        
 setup\_training\_args
 
        (
        
 final\_model\_dir, enable\_evaluation=
 
 True
 
 
        ):
        
   

 
            
        
 """
 
   

 
     设置训练参数,包括输出目录、学习率、批处理大小等,并根据参数控制是否启用评估。
 
   

 
 
   

 
     Args:
 
   

 
         final\_model\_dir (str): 微调后模型保存的路径
 
   

 
         enable\_evaluation (bool): 是否启用评估。默认为True,启用评估;为False时禁用评估。
 
   

 
 
   

 
     Returns:
 
   

 
         training\_args (TrainingArguments): 训练参数
 
   

 
     """
 
        
   

 
            
        
 # 根据是否启用评估设置 evaluation\_strategy
 
        
   

 
            evaluation\_strategy = 
        
 "epoch"
 
        
 if
 
         enable\_evaluation 
        
 else
 
        
 "no"
 
        
   

 
        
   

 
            training\_args = TrainingArguments(
        
   

 
                output\_dir=final\_model\_dir,
        
   

 
                evaluation\_strategy=evaluation\_strategy,  
        
 # 控制评估策略
 
        
   

 
                learning\_rate=
        
 5e-5
 
        ,
        
   

 
                per\_device\_train\_batch\_size=
        
 2
 
        ,  
        
 # 适当减少批处理大小(根据M3 Pro的内存限制)
 
        
   

 
                gradient\_accumulation\_steps=
        
 4
 
        ,  
        
 # 使用梯度累积,模拟更大的批量
 
        
   

 
                num\_train\_epochs=
        
 3
 
        ,  
        
 # 训练3个周期
 
        
   

 
                report\_to=
        
 "wandb"
 
        ,  
        
 # 使用wandb进行训练日志记录
 
        
   

 
                weight\_decay=
        
 0.01
 
        ,
        
   

 
                logging\_dir=os.path.join(setting.root\_dir, 
        
 'logs'
 
        ),
        
   

 
                logging\_steps=
        
 50
 
        ,  
        
 # 减少日志记录频率
 
        
   

 
                save\_steps=
        
 500
 
        ,  
        
 # 增加模型保存的步数频率,减少频繁保存
 
        
   

 
                save\_total\_limit=
        
 2
 
        ,  
        
 # 保存最多2个模型
 
        
   

 
                dataloader\_num\_workers=
        
 4
 
        ,  
        
 # 设置数据加载器的并行数(根据需要调整)
 
        
   

 
            )
        
   

 
            
        
 return
 
         training\_args
        
   

 
        
   

 
        
   

 
        
   

 
        
 # 训练模型
 
        
   

 
        
 def
 
        
 train\_model
 
        (
        
 model, training\_args, dataset, eval\_dataset, tokenizer, enable\_evaluation=
 
 True
 
 
        ):
        
   

 
            
        
 """
 
   

 
     使用SFTTrainer进行模型训练。
 
   

 
 
   

 
     Args:
 
   

 
         model (AutoModelForCausalLM): 需要训练的模型
 
   

 
         training\_args (TrainingArguments): 训练参数
 
   

 
         dataset (Dataset): 用于训练的数据集
 
   

 
         eval\_dataset (Dataset): 用于评估的数据集
 
   

 
         tokenizer (AutoTokenizer): 分词器
 
   

 
         enable\_evaluation (bool): 是否进行评估
 
   

 
 
   

 
     Returns:
 
   

 
         trainer (SFTTrainer): 训练器实例
 
   

 
     """
 
        
   

 
            
        
 # 如果启用了评估,传递评估集
 
        
   

 
            trainer = SFTTrainer(
        
   

 
                model=model,
        
   

 
                args=training\_args,
        
   

 
                train\_dataset=dataset,
        
   

 
                eval\_dataset=eval\_dataset 
        
 if
 
         enable\_evaluation 
        
 else
 
        
 None
 
        ,  
        
 # 根据参数决定是否传递评估集
 
        
   

 
                tokenizer=tokenizer,
        
   

 
                data\_collator=
        
 None
 
        ,  
        
 # 可以选择合适的data collator
 
        
   

 
            )
        
   

 
            trainer.train()
        
   

 
            
        
 return
 
         trainer
        
   

 
        
   

 
        
   

 
        
 # 保存模型
 
        
   

 
        
 def
 
        
 save\_model
 
        (
        
 trainer, final\_model\_dir
 
        ):
        
   

 
            
        
 """
 
   

 
     保存训练后的模型到指定目录。
 
   

 
 
   

 
     Args:
 
   

 
         trainer (SFTTrainer): 训练器实例
 
   

 
         final\_model\_dir (str): 模型保存路径
 
   

 
     """
 
        
   

 
            
        
 print
 
        (
        
 "Saving model..."
 
        )
        
   

 
            trainer.save\_model(final\_model\_dir)
        
   

 
        
   

 
        
   

 
        
   

 
        
 def
 
        
 merge\_models
 
        (
        
 models, weights, device=
 
 "cpu"
 
 
        ):
        
   

 
            
        
 """
 
   

 
     合并多个模型的权重(加权平均)。
 
   

 
 
   

 
     Args:
 
   

 
         models (list): 模型列表
 
   

 
         weights (list): 权重列表,权重数量与模型数量一致
 
   

 
         device (str): 设备,可以是 "cuda" 或 "cpu"
 
   

 
 
   

 
     Returns:
 
   

 
         merged\_model (nn.Module): 合并后的模型
 
   

 
     """
 
        
   

 
            
        
 # 确保模型数量与权重数量一致
 
        
   

 
            
        
 assert
 
        
 len
 
        (models) == 
        
 len
 
        (weights), 
        
 "模型数量与权重数量不一致"
 
        
   

 
        
   

 
            
        
 # 将所有模型加载到相同的设备
 
        
   

 
            
        
 for
 
         i 
        
 in
 
        
 range
 
        (
        
 len
 
        (models)):
        
   

 
                models[i] = models[i].to(device)
        
   

 
        
   

 
            
        
 # 获取第一个模型的状态字典
 
        
   

 
            merged\_state\_dict = models[
        
 0
 
        ].state\_dict()
        
   

 
        
   

 
            
        
 # 对每一层的权重进行加权平均
 
        
   

 
            
        
 for
 
         key 
        
 in
 
         merged\_state\_dict.keys():
        
   

 
                merged\_state\_dict[key] = torch.zeros\_like(merged\_state\_dict[key])
        
   

 
                
        
 for
 
         model, weight 
        
 in
 
        
 zip
 
        (models, weights):
        
   

 
                    merged\_state\_dict[key] += model.state\_dict()[key] * weight
        
   

 
        
   

 
            
        
 # 创建一个新的模型并加载合并后的权重
 
        
   

 
            merged\_model = models[
        
 0
 
        ].\_\_class\_\_.from\_pretrained(models[
        
 0
 
        ].config)
        
   

 
            merged\_model.load\_state\_dict(merged\_state\_dict)
        
   

 
            
        
 return
 
         merged\_model
        
   

 
        
   

 
        
   

 
        
 # 主函数
 
        
   

 
        
 def
 
        
 main
 
        ():
        
   

 
            
        
 """
 
   

 
     主函数,执行整个训练流程:设置路径、加载模型、训练并保存模型。
 
   

 
 
   

 
     参数设置:
 
   

 
             enable\_evaluation = False  # 设置为False以禁用评估 如果性能慢可以设置 False
 
   

 
 
   

 
     加载数据集:
 
   

 
         train\_size=10 设置数据集大小,评估集是数据集百分之10(如果小于1 则等于1 )
 
   

 
         train\_dataset, eval\_dataset = load\_dataset\_func(dataset\_path, train\_size=10)
 
   

 
 
   

 
 
   

 
     """
 
        
   

 
            setup\_wandb()  
        
 # 登录wandb
 
        
   

 
            model\_dir, dataset\_path, final\_model\_dir = set\_paths()  
        
 # 设置路径
 
        
   

 
        
   

 
            model, tokenizer, EOS\_TOKEN = load\_model\_and\_tokenizer(model\_dir)  
        
 # 加载模型和分词器
 
        
   

 
        
   

 
            train\_dataset, eval\_dataset = load\_dataset\_func(dataset\_path, train\_size=
        
 5
 
        )  
        
 # 加载数据集
 
        
   

 
            train\_dataset = train\_dataset.
        
 map
 
        (
        
 lambda
 
         examples: formatting\_prompts\_func(examples, EOS\_TOKEN), batched=
        
 True
 
        )  
        
 # 格式化数据集
 
        
   

 
            eval\_dataset = eval\_dataset.
        
 map
 
        (
        
 lambda
 
         examples: formatting\_prompts\_func(examples, EOS\_TOKEN), batched=
        
 True
 
        )  
        
 # 格式化评估集
 
        
   

 
            
        
 print
 
        (train\_dataset[
        
 "text"
 
        ][
        
 0
 
        ])  
        
 # 打印格式化后的数据
 
        
   

 
        
   

 
            model = setup\_lora(model)  
        
 # 配置LoRA
 
        
   

 
            
        
 # 设置是否开启评估
 
        
   

 
            enable\_evaluation = 
        
 True
 
        
 # 设置为False以禁用评估
 
        
   

 
            training\_args = setup\_training\_args(final\_model\_dir,enable\_evaluation)  
        
 # 配置训练参数
 
        
   

 
            trainer = train\_model(model, training\_args, train\_dataset, eval\_dataset, tokenizer, enable\_evaluation)  
        
 # 开始训练
 
        
   

 
        
   

 
            save\_model(trainer, final\_model\_dir)  
        
 # 保存模型
 
        
   

 
            wandb.finish()  
        
 # 完成wandb记录
 
        
   

 
        
   

 
        
   

 
        
   

 
        
   

 
        
 # 执行主函数
 
        
   

 
        
 if
 
         \_\_name\_\_ == 
        
 "\_\_main\_\_"
 
        :
        
   

 
            main()
        
   

 
      
    

3.3 训练模型并保存


      
          

        
   

 
        
 """
 
   

 
 保存在本地 models/final\_model 路径下
 
   

 
 
   

 
 """
 
        
   

 
        
   

 
        
 def
 
         
        
 save\_model
 
        (
        
 trainer, final\_model\_dir
 
        ):
        
   

 
            
        
 """
 
   

 
     保存训练后的模型到指定目录。
 
   

 
 
   

 
     Args:
 
   

 
         trainer (SFTTrainer): 训练器实例
 
   

 
         final\_model\_dir (str): 模型保存路径
 
   

 
     """
 
        
   

 
            
        
 print
 
        (
        
 "Saving model..."
 
        )
        
   

 
            trainer.save\_model(final\_model\_dir)
        
   

 
        
   

 
            
        
   

 
        
   

 
      
    

3.4 合并模型文件


      
          

        
   

 
        
 #01 执行即可
 
        
   

 
        new\_model\_local = 
        
 "DeepSeek-R1-Medical-COT-Tiny"
 
        
   

 
        model.save\_pretrained(new\_model\_local) 
        
   

 
        tokenizer.save\_pretrained(new\_model\_local)
        
   

 
        model.save\_pretrained\_merged(new\_model\_local, tokenizer, save\_method = 
        
 "merged\_16bit"
 
        ,)
        
   

 
      
    

3.4 评估和监控训练过程

评估( eval/

)相关信息:

  • eval/runtime 18.3908

: 评估过程总共耗时18.39秒。

  • eval/samples\_per\_second 0.054

: 每秒处理的样本数为0.054,表示评估的速度较慢。

  • eval/steps\_per\_second 0.054

: 每秒进行评估步数为0.054,说明每个评估步骤的时间消耗较大。

训练( train/

)相关信息:

  • train/epoch 0

: 当前训练轮次是第0轮。

  • train/global\_step 0

: 当前全局步骤为0,表示尚未进行任何训练步骤。

  • train\_loss 14435.36663

: 当前训练的损失为14435.37,表明模型的表现尚不理想,通常需要更多的训练来降低损失。

  • train/runtime 251.7582

: 训练总时间为251.76秒。

  • train/samples\_per\_second 0.06

: 每秒处理的训练样本数为0.06,训练的速度较慢。

  • train/steps\_per\_second 0.012

: 每秒进行的训练步数为0.012,表示每个训练步骤消耗的时间较长。


      
          

        
   

 
        
   

 
        
 #02 详细日志
 
        
   

 
        wandb: ⭐️ View project at https://wandb.ai/z15119911990-beijing/huggingface
        
   

 
        wandb: 🚀 View run at https://wandb.ai/z15119911990-beijing/huggingface/runs/mgrko2jv
        
   

 
          0%|          | 0/3 [00:00<?, ?it/s]
        
   

 
        {
        
 'eval\_runtime'
 
        : 14.8693, 
        
 'eval\_samples\_per\_second'
 
        : 0.067, 
        
 'eval\_steps\_per\_second'
 
        : 0.067, 
        
 'epoch'
 
        : 0}
        
   

 
                                             
        
   

 
          0%|          | 0/3 [00:30<?, ?it/s]
        
   

 
        100%|██████████| 1/1 [00:00<00:00, 1461.94it/s]
        
   

 
                                                       
        
   

 
                                             
        
   

 
        {
        
 'eval\_runtime'
 
        : 21.2073, 
        
 'eval\_samples\_per\_second'
 
        : 0.047, 
        
 'eval\_steps\_per\_second'
 
        : 0.047, 
        
 'epoch'
 
        : 0}
        
   

 
          0%|          | 0/3 [02:11<?, ?it/s]
        
   

 
        100%|██████████| 1/1 [00:00<00:00, 33.69it/s]
        
   

 
                                                     
        
   

 
                                             
        
   

 
          0%|          | 0/3 [04:02<?, ?it/s]
        
   

 
        100%|██████████| 1/1 [00:00<00:00, 334.66it/s]
        
   

 
                                                      {
        
 'eval\_runtime'
 
        : 18.3908, 
        
 'eval\_samples\_per\_second'
 
        : 0.054, 
        
 'eval\_steps\_per\_second'
 
        : 0.054, 
        
 'epoch'
 
        : 0}
        
   

 
        {
        
 'train\_runtime'
 
        : 251.7582, 
        
 'train\_samples\_per\_second'
 
        : 0.06, 
        
 'train\_steps\_per\_second'
 
        : 0.012, 
        
 'train\_loss'
 
        : 14435.3666305542, 
        
 'epoch'
 
        : 0}
        
   

 
          0%|          | 0/3 [04:10<?, ?it/s]
        
   

 
        wandb:                                                                                
        
   

 
        wandb: 
        
   

 
        wandb: Run 
        
 history
 
        :
        
   

 
        wandb:            
        
 eval
 
        /runtime ▁█▅
        
   

 
        wandb: 
        
 eval
 
        /samples\_per\_second █▁▃
        
   

 
        wandb:   
        
 eval
 
        /steps\_per\_second █▁▃
        
   

 
        wandb:             train/epoch ▁▁▁▁
        
   

 
        wandb:       train/global\_step ▁▁▁▁
        
   

 
        wandb: 
        
   

 
        wandb: Run summary:
        
   

 
        wandb:             
        
 eval
 
        /runtime 18.3908
        
   

 
        wandb:  
        
 eval
 
        /samples\_per\_second 0.054
        
   

 
        wandb:    
        
 eval
 
        /steps\_per\_second 0.054
        
   

 
        wandb:               total\_flos 43804457687040.0
        
   

 
        wandb:              train/epoch 0
        
   

 
        wandb:        train/global\_step 0
        
   

 
        wandb:               train\_loss 14435.36663
        
   

 
        wandb:            train\_runtime 251.7582
        
   

 
        wandb: train\_samples\_per\_second 0.06
        
   

 
        wandb:   train\_steps\_per\_second 0.012
        
   

 
        wandb: 
        
   

 
        wandb: 🚀 View run /Users/ningcaichen/Documents/02-python相关文档/01-AI系列/LoRA-DeepSeek-R1/models/final\_model at: https://wandb.ai/z15119911990-beijing/huggingface/runs/mgrko2jv
        
   

 
        wandb: ⭐️ View project at: https://wandb.ai/z15119911990-beijing/huggingface
        
   

 
        wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
        
   

 
        wandb: Find logs at: ./wandb/run-20250212\_133457-mgrko2jv/logs
        
   

 
        
   

 
        
   

 
      
    

机器学习算法AI大数据技术

搜索公众号添加: datanlp

picture.image

长按图片,识别二维码

阅读过本文的人还看了以下文章:

实时语义分割ENet算法,提取书本/票据边缘

整理开源的中文大语言模型,以规模较小、可私有化部署、训练成本较低的模型为主

《大语言模型》PDF下载

动手学深度学习-(李沐)PyTorch版本

YOLOv9电动车头盔佩戴检测,详细讲解模型训练

TensorFlow 2.0深度学习案例实战

基于40万表格数据集TableBank,用MaskRCNN做表格检测

《基于深度学习的自然语言处理》中/英PDF

Deep Learning 中文版初版-周志华团队

【全套视频课】最全的目标检测算法系列讲解,通俗易懂!

《美团机器学习实践》_美团算法团队.pdf

《深度学习入门:基于Python的理论与实现》高清中文PDF+源码

《深度学习:基于Keras的Python实践》PDF和代码

特征提取与图像处理(第二版).pdf

python就业班学习视频,从入门到实战项目

2019最新《PyTorch自然语言处理》英、中文版PDF+源码

《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码

《深度学习之pytorch》pdf+附书源码

PyTorch深度学习快速实战入门《pytorch-handbook》

【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》

《Python数据分析与挖掘实战》PDF+完整源码

汽车行业完整知识图谱项目实战视频(全23课)

李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材

笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!

《神经网络与深度学习》最新2018版中英PDF+源码

将机器学习模型部署为REST API

FashionAI服装属性标签图像识别Top1-5方案分享

重要开源!CNN-RNN-CTC 实现手写汉字识别

yolo3 检测出图像中的不规则汉字

同样是机器学习算法工程师,你的面试为什么过不了?

前海征信大数据算法:风险概率预测

【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类

VGG16迁移学习,实现医学图像识别分类工程项目

特征工程(一)

特征工程(二) :文本数据的展开、过滤和分块

特征工程(三):特征缩放,从词袋到 TF-IDF

特征工程(四): 类别特征

特征工程(五): PCA 降维

特征工程(六): 非线性特征提取和模型堆叠

特征工程(七):图像特征提取和深度学习

如何利用全新的决策树集成级联结构gcForest做特征工程并打分?

Machine Learning Yearning 中文翻译稿

蚂蚁金服2018秋招-算法工程师(共四面)通过

全球AI挑战-场景分类的比赛源码(多模型融合)

斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)

python+flask搭建CNN在线识别手写中文网站

中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程

不断更新资源

深度学习、机器学习、数据分析、python

搜索公众号添加: datayx

picture.image

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动 XR 技术的探索与实践
火山引擎开发者社区技术大讲堂第二期邀请到了火山引擎 XR 技术负责人和火山引擎创作 CV 技术负责人,为大家分享字节跳动积累的前沿视觉技术及内外部的应用实践,揭秘现代炫酷的视觉效果背后的技术实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论