# billiards_models **Repository Path**: oftenlin/billiards_models ## Basic Information - **Project Name**: billiards_models - **Description**: No description available - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-04-25 - **Last Updated**: 2025-12-11 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 台球球杆检测系统 基于PyTorch的台球球杆位置检测模型,能够检测图片中最多3个球杆的位置和端点坐标。 ## 功能特点 - 🎯 **多球杆检测**: 支持检测最多3个球杆 - 📍 **端点定位**: 精确定位每个球杆的两个端点坐标 - 🔍 **置信度评估**: 提供每个检测结果的置信度分数 - 🎨 **可视化工具**: 丰富的可视化功能,包括预测结果、训练历史等 - 🚀 **推理脚本**: 支持单张图片、批量图片和视频检测 ## 项目结构 ``` billiards_models/ ├── cue_detection/ │ ├── config/ │ │ └── config.py # 配置文件 │ ├── utils/ │ │ ├── data_analysis.py # 数据分析工具 │ │ ├── data_processing.py # 数据预处理 │ │ └── visualization.py # 可视化工具 │ ├── dataset/ │ │ └── cue_dataset.py # 数据集类 │ ├── models/ │ │ └── cue_model.py # 模型定义 │ ├── train/ │ │ └── trainer.py # 训练器 │ ├── inference/ │ │ └── inference.py # 推理脚本 │ ├── data/ │ │ └── cue_annotations.json # 标注数据 │ ├── train_model.py # 训练脚本 │ └── outputs/ # 输出目录 ├── requirements.txt └── README.md ``` ## 安装依赖 ```bash pip install -r requirements.txt ``` ## 数据分析 首先分析数据集的基本信息: ```bash python cue_detection/utils/data_analysis.py ``` 这将生成数据分析报告和可视化图表。 ## 模型训练 ### 基础训练 ```bash python cue_detection/train_model.py ``` ### 自定义参数训练 ```bash python cue_detection/train_model.py \ --backbone resnet50 \ --batch_size 16 \ --learning_rate 0.001 \ --num_epochs 100 ``` ### 恢复训练 ```bash python cue_detection/train_model.py --resume ``` ### 仅评估模型 ```bash python cue_detection/train_model.py --eval_only ``` ## 推理预测 ### 单张图片推理 ```bash python cue_detection/inference/inference.py \ --model cue_detection/outputs/models/best_model.pth \ --image path/to/image.jpg \ --confidence 0.5 \ --visualize ``` ### 批量图片推理 ```bash python cue_detection/inference/inference.py \ --model cue_detection/outputs/models/best_model.pth \ --directory path/to/images/ \ --output results.json \ --save_vis output_visualizations/ ``` ### 视频检测 ```bash python -c " from cue_detection.utils.visualization import create_detection_video create_detection_video('input_video.mp4', 'best_model.pth', 'output_video.mp4') " ``` ## 配置说明 主要配置参数在 `cue_detection/config/config.py` 中: ```python # 数据配置 annotations_file = "cue_detection/data/cue_annotations.json" images_dir = "cue_detection/data/images" # 模型配置 backbone = "resnet50" # 主干网络 num_cues = 3 # 最大球杆数量 img_size = (512, 512) # 输入图片尺寸 # 训练配置 batch_size = 16 learning_rate = 1e-3 num_epochs = 100 ``` ## 模型架构 模型采用两阶段设计: 1. **特征提取**: 使用ResNet作为骨干网络提取图像特征 2. **多任务头**: - **分类头**: 判断每个位置是否有球杆(3个二分类) - **回归头**: 预测球杆端点坐标(12个坐标值) ## 损失函数 ```python 总损失 = 分类权重 × 分类损失 + 回归权重 × 回归损失 ``` - **分类损失**: BCEWithLogitsLoss - **回归损失**: SmoothL1Loss (只对存在球杆的位置计算) ## 评估指标 - **分类准确率**: 球杆存在性判断的准确率 - **端点误差**: 预测端点与真实端点的像素距离误差 - **平均精度**: 不同置信度阈值下的检测性能 ## 可视化功能 ### 训练历史可视化 ```python from cue_detection.utils.visualization import visualize_training_history visualize_training_history('cue_detection/outputs/logs/train_history.json') ``` ### 数据样本可视化 ```python from cue_detection.utils.visualization import visualize_data_samples from cue_detection.utils.data_processing import CueDataProcessor processor = CueDataProcessor(config) train_data, _, _ = processor.split_dataset() processed_data = processor.process_dataset(train_data[:6]) visualize_data_samples(processed_data) ``` ## 输出文件说明 训练完成后,会在 `cue_detection/outputs/` 目录下生成: - `models/best_model.pth`: 最佳模型权重 - `models/latest_checkpoint.pth`: 最新检查点 - `logs/training.log`: 训练日志 - `logs/train_history.json`: 训练历史数据 - `visualizations/`: 可视化图表 - `evaluation_results.json`: 评估结果 ## 注意事项 1. **图片路径**: 确保图片文件存在于 `cue_detection/data/images/` 目录中 2. **内存使用**: 如果出现内存不足,可以减小 `batch_size` 或使用 `resnet18` 3. **数据增强**: 如果没有安装 `albumentations`,系统会自动使用简化版数据加载器 4. **GPU支持**: 系统会自动检测并使用GPU加速训练 ## 故障排除 ### 常见问题 1. **ImportError**: 确保所有依赖包都已安装 2. **CUDA out of memory**: 减小batch_size或使用更小的模型 3. **图片加载失败**: 检查图片路径是否正确 ### 性能优化 1. **使用GPU**: 确保PyTorch支持CUDA 2. **数据加载**: 增加 `num_workers` 参数 3. **混合精度**: 可以使用 `torch.cuda.amp` 进行混合精度训练 ## 扩展功能 系统具有良好的扩展性,可以轻松添加: - 更多的backbone网络 - 不同的损失函数 - 新的数据增强策略 - 其他评估指标 ## 许可证 本项目仅用于学习和研究目的。 ## 联系方式 如有问题或建议,请通过以下方式联系: - 创建Issue - 提交Pull Request