# train_class **Repository Path**: mingmingxr/train_class ## Basic Information - **Project Name**: train_class - **Description**: No description available - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 2 - **Created**: 2025-09-24 - **Last Updated**: 2025-09-24 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 医疗命名实体识别模型训练项目 本项目是基于ERNIE模型的医疗命名实体识别(NER)任务的微调训练项目。通过监督微调(SFT)方法,训练模型从医疗文本中抽取疾病(dis)和医疗程序(pro)类型的实体。 ## 目录结构 ``` train_class/ ├── config/ # 配置文件目录 ├── data/ # 数据目录 │ ├── trian_data/ # 训练数据 │ ├── test_data/ # 测试数据 │ └── test_model/ # 模型测试结果 ├── evaluates/ # 评估相关代码 ├── model/ # 模型相关代码 ├── src/ # 核心源代码 │ ├── sftdata.py # 数据处理模块 │ └── training.py # 训练流程模块 ├── test/ # 测试代码 ├── requirements.txt # 项目依赖 ├── sft_train.py # SFT训练主程序 ├── test_model.py # 模型测试程序 └── README.md # 项目说明文档 ``` [//]: # (```mermaid) [//]: # (graph TD) [//]: # ( %% 标题: 核电施工安全多模态大模型 (HedianAn-VLM) 训练流程) [//]: # ( %% 全局样式定义) [//]: # ( classDef stage fill:#f9f9f9,stroke:#ddd,stroke-width:2px,rx:10,ry:10;) [//]: # ( classDef model fill:#e6f7ff,stroke:#91d5ff,stroke-width:2px,rx:5,ry:5;) [//]: # ( classDef data fill:#f6ffed,stroke:#b7eb8f,stroke-width:2px,rx:5,ry:5;) [//]: # ( classDef process fill:#fff0f6,stroke:#ffadd2,stroke-width:2px,rx:20,ry:20;) [//]: # ( classDef finalModel fill:#fffbe6,stroke:#ffe58f,stroke-width:2px,rx:5,ry:5;) [//]: # () [//]: # ( %% 阶段一: 领域多模态持续预训练) [//]: # ( subgraph S1 [阶段一: 领域持续预训练]) [//]: # ( direction TB) [//]: # ( subgraph MultiModalData [多模态领域数据]) [//]: # ( direction LR) [//]: # ( A["文本数据:核电安全规范、施工方案与日志、事故/未遂案例"]:::data) [//]: # ( B["视觉数据:现场施工图像/视频、工程图纸(CAD/P&ID)、设备状态照片"]:::data) [//]: # ( end) [//]: # ( D["通用多模态数据"]:::data) [//]: # ( BaseVLM["开源视觉语言大模型(如 Qwen-VL-Max)"]:::model ) [//]: # ( %% 修复类名拼写错误(m→model)) [//]: # ( HedianAnBase[HedianAn-VLM_Base]:::model) [//]: # ( ) [//]: # ( MultiModalData --> BaseVLM) [//]: # ( D --> BaseVLM) [//]: # ( BaseVLM -->|领域特定持续预训练| HedianAnBase) [//]: # ( end) [//]: # () [//]: # ( %% 阶段二: 多模态指令微调 (SFT)) [//]: # ( subgraph S2 [阶段二: 多模态指令微调]) [//]: # ( direction TB) [//]: # ( F1["视觉隐患识别 LoRA\n输入: 图像/视频\n输出: 隐患类型与位置"]:::process) [//]: # ( F2["规范遵从性问答 LoRA\n输入: 图像+文本规范\n输出: 是否合规及原因"]:::process) [//]: # ( F3["安全巡检报告生成 LoRA\n输入: 现场图像集+文本描述\n输出: 结构化报告"]:::process) [//]: # ( HedianAnSFT[HedianAn-VLM_SFT]:::model) [//]: # ( ) [//]: # ( HedianAnBase --> F1) [//]: # ( HedianAnBase --> F2) [//]: # ( HedianAnBase --> F3) [//]: # ( ) [//]: # ( subgraph MergedSFT [能力融合]) [//]: # ( direction LR) [//]: # ( F1 --> HedianAnSFT) [//]: # ( F2 --> HedianAnSFT) [//]: # ( F3 --> HedianAnSFT) [//]: # ( end) [//]: # ( end) [//]: # () [//]: # ( %% 阶段三: 对齐与强化) [//]: # ( subgraph S3 [阶段三: 对齐与强化]) [//]: # ( direction TB) [//]: # ( G1["风险评估准确性\nDPO强化对隐患等级与影响的判断精度"]:::process) [//]: # ( G2["纠正措施合理性\nPPO强化生成措施的可行性、有效性"]:::process) [//]: # ( G3["解释与溯源能力\nDPO强化模型决策的透明度(引用规范、说明推理)"]:::process) [//]: # ( HedianAnFinal[HedianAn-VLM]:::finalModel) [//]: # ( ) [//]: # ( HedianAnSFT --> G1) [//]: # ( HedianAnSFT --> G2 ) [//]: # ( HedianAnSFT --> G3) [//]: # ( ) [//]: # ( subgraph FinalModelIntegration [最终模型]) [//]: # ( direction LR) [//]: # ( G1 --> HedianAnFinal) [//]: # ( G2 --> HedianAnFinal ) [//]: # ( G3 --> HedianAnFinal) [//]: # ( end) [//]: # ( end) [//]: # ( ) [//]: # ( %% 阶段间连接) [//]: # ( HedianAnBase --> HedianAnSFT) [//]: # ( HedianAnSFT --> HedianAnFinal) [//]: # (```) ## 环境依赖 ### 硬件要求 - **镜像**: - PyTorch 2.7.0 - Python 3.12 (Ubuntu 22.04) - CUDA 12.8 - **GPU**: - RTX 5090 (32GB) × 1 - **CPU**: - 25 vCPU Intel(R) Xeon(R) Platinum 8470Q - **内存**: - 120GB - **硬盘**: - 系统盘:30 GB - 数据盘: - 免费额度:50GB SSD - 付费额度:0GB ### 软件依赖 - Python 3.8+ - CUDA 11.8+ (如果使用GPU) - PyTorch 2.0+ - 其他依赖见[requirements.txt] ## 运行步骤 1. **AutoDL环境启动** - 登录 [AutoDL快速入门](https://www.autodl.com/docs/quick_start/) 页面。 - 创建一个GPU实例: ![创建GPU实例](images/img.png) - 在租用实例页面进行以下配置: - 选择计费方式、地区、GPU型号和GPU数量。 - 选择合适的空闲主机和镜像(平台提供了多种内置深度学习框架的基础镜像和社区镜像)。 ![创建GPU实例2](images/img5.png) - 确认配置后,点击创建实例,创建实例时选择这个镜像会自动完成初步环境配置。 ![img_2.png](images/img4.png) - 启动实例后,等待实例启动完成。 - 创建完成后等待自动开机,今后主要用到的操作入口见截图中。 ![img_1.png](images/img3.png) - 登录AutoDL控制台,进入实例详情页面。 - 使用ssh工具链接(如xshell,xftfp等工具) 2. 克隆项目代码: ```bash mkdir /data cd /data git clone https://gitee.com/yang-zhibo-01/train_class.git cd train_class ``` 3. 创建Conda环境(推荐): ```bash conda create -n med_ner python=3.12 conda activate med_ner ``` 4. 安装依赖: ```bash pip install -r requirements.txt ``` ## 数据格式 ### 训练数据格式 训练数据采用JSONL格式,每一行是一个JSON对象,包含以下字段: ```json { "instruction": "任务指令,描述模型需要完成的任务", "input": "输入文本,包含需要识别实体的医疗文本", "output": "期望输出,以JSON数组格式表示的实体列表" } ``` 示例: ```json { "instruction": "你是一个医疗命名实体识别的专家。请从以下文本中抽取出实体并对实体进行分类,并以JSON格式返回。", "input": "慢性Q热的其他表现还有血管受累、动脉瘤、骨髓炎、长期发热、肺炎、肝炎或紫癜样皮疹。", "output": "[{\"entity\": \"慢性Q热\", \"label\": \"dis\"}, {\"entity\": \"血管受累\", \"label\": \"sym\"}, {\"entity\": \"动脉瘤\", \"label\": \"sym\"}]" } ``` ### 输出格式 模型输出为JSON数组格式,包含识别出的实体: ```json [ {"entity": "实体文本", "label": "实体类型"}, {"entity": "慢性Q热", "label": "dis"}, {"entity": "血管受累", "label": "sym"} ] ``` 实体类型说明: - `dis`: 疾病(Disease) - `pro`: 医疗程序(Procedure) - `sym`: 症状(Symptom) - `bod`: 身体部位(Body) ## 使用方法 ### 1. 模型训练 下载模型: ``` modelscope download --model PaddlePaddle/ERNIE-4.5-0.3B-PT --local_dir /data/baidu/ERNIE-4.5-0.3B-PT # 使用国内源下载模型 ``` ### 推荐超参数配置表 | 学习率 (learning_rate) | 批次大小 (batch_size) | 训练样本数 (num_samples) | 训练轮数 (epochs) | 适用场景 | |---------------------|-------------------|-------------------------|------------------|----------| | 5e-5 | 8 | 1000 | 3 | 默认推荐配置 | | 4e-4 | 8 | 2000 | 5 | 高精度要求场景 | | 5e-5 | 4 | 500 | 2 | 快速训练验证 | 运行SFT训练脚本: ``` python sft_train.py # 采取默认超参数 python sft_train.py --learning_rate 6e-6 --batch_size 4 --num_samples 2000 --epochs 1 #自定义超参数 ``` 训练配置在[sft_train.py](train_class\sft_train.py)文件中定义,包括: - 模型名称: `baidu/ERNIE-4.5-0.3B-PT` - 超参数网格: 学习率、批次大小、训练样本数、训练轮数等 - 数据路径: 训练数据和测试数据路径 训练过程中会自动保存模型到指定路径,格式为: ``` ./ernie_sft_lora_tuning_results/lr_{学习率}_bs_{批次大小}_samples_{样本数}/final_model/ ``` ### 2. 模型测试 训练完成后,可以使用测试脚本对模型进行推理: ``` export MODEL_API_KEY="YOUR_API_KEY" # 填写智普API密钥 python test_model.py --lr 6e-6 --bs 4 --samples 2000 ``` 测试脚本会: - 加载训练好的模型 - 读取测试数据 - 对每条测试数据进行推理 - 保存结果到[data/test_model/]目录 ### 3. 自定义训练 如果需要自定义训练参数,可以修改[sft_train.py]中的以下配置: ```python # 超参数网格定义 hyperparameter_grid = [ {'learning_rate': 5e-5, 'batch_size': 4, 'num_samples': 1000, 'epochs': 3}, # 可以添加更多参数组合 ] # 数据路径 MODEL_NAME = "/data/baidu/ERNIE-4.5-0.3B-PT" TRAIN_DATASET_PATH = "/data/train_class/data/trian_data/training_data.jsonl" TEST_DATASET_PATH = "/data/train_class/data/test_data/test_data.jsonl" ``` ## 技术特点 ### 1. LoRA微调 项目使用LoRA(Low-Rank Adaptation)参数高效微调技术: - 微调更多层: 包括注意力机制和前馈网络的所有投影层 - 高秩值: 使用r=64提高模型表达能力 - 显存优化: 相比全参数微调大幅减少显存使用 ### 2. QLoRA量化 使用4位量化技术进一步减少显存占用: - 4位量化配置 - NF4量化类型 - bfloat16计算精度 ### 3. 数据处理 - 自动格式化输入提示(prompt) - Tokenization处理 - 序列截断和填充 ### 4. 模型保存 - 按超参数自动组织模型保存路径 - 保存完整的模型权重和tokenizer ## 项目模块说明 ### [sft_train.py](./sft_train.py) 超参数搜索和训练主程序,支持多组超参数组合训练。 ### [src/training.py](./src/training.py) 核心训练流程,包括模型加载、数据处理、训练和模型保存。 ### [src/sftdata.py](./src/sftdata.py) 数据处理模块,负责数据加载、格式化和tokenization。 ### [model/ERNIE.py](./model/ERNIE.py) ERNIE模型微调器封装,基于HuggingFace Transformers Trainer。 ### [evaluates/Metrics.py](evaluates/metrics.py) 评估指标计算,包括精确率、召回率和F1分数(目前暂时不可用)。 ### [test_model.py](./test_model.py) 模型测试脚本,用于加载训练好的模型并进行推理。 ## 常见问题 ### 1. 显存不足 如果遇到CUDA Out of Memory错误,可以尝试: - 减小[batch_size]参数 - 减少[max_seq_length] - 启用更激进的量化设置 ### 2. 模型加载失败 确保模型路径正确,并且包含以下文件: - config.json - pytorch_model.bin 或 adapter_model.bin - tokenizer_config.json - special_tokens_map.json ### 3. 数据格式错误 确保训练数据符合JSONL格式,每行一个JSON对象。 ## 结果分析 训练结果保存在以下位置: - 模型文件: `ernie_sft_lora_tuning_results/`目录下对应参数的子目录 - 训练日志: 控制台输出 - 测试结果: [data/test_model/]目录下的JSON文件 ## 未来规划 ### 1. 准确率和召回率计算方式(正在完成) - 评估指标计算,目前暂时不可用。 - 计划使用大模型来判断NER提取是否完全,以及NER的分类是否正确。 ### 2. 实时输出准确率 - 将训练和推理代码结合,实现实时准确率输出功能。 ### 3. 强化学习更新(待验证) - 探索使用强化学习方法进一步优化模型性能(不一定成功,模型参数太小了)。