Training an Improved Intelligent Titration Algorithm
Tags: Machine Learning, Python
Author: zzh
Updated 2024-11-02
Recommended image: DeePMD-kit:3.0.0b4-cuda12.1
Recommended machine type: c32_m64_cpu
Optimizing the training algorithm
Simplifying the data augmentation strategy
Introducing early stopping
Real-time logging and visualization
Original algorithm model (Dalian University of Technology)
First revision: poor results, discarded
Second revision: changed the early-stopping mechanism in the code; poor results, discarded
Third revision: good results
How the training logs are written

Intelligent titration is a modern chemical-analysis technique that combines automated control with intelligent algorithms to achieve high precision and high efficiency in sample titration. Here the technique is applied to color-indicator titration: the system judges the titration endpoint automatically through image capture and color recognition. Using computing hardware, a sensor system, and data-processing algorithms, the intelligent titration system monitors color changes in real time during titration and automatically adjusts the drip rate, ensuring the accuracy and stability of the experiment.
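The endpoint decision described above can be sketched as a tiny control loop. Everything in this sketch, including the two class labels, the consecutive-frame confirmation rule, and the helper names, is an illustrative assumption rather than the system's actual control code:

```python
def detect_endpoint(classify_frame, frames, confirm_n=3):
    """Return the index of the frame that confirms the endpoint, or None.

    Illustrative logic only: declare the endpoint once `confirm_n`
    consecutive camera frames are classified as the endpoint color,
    which guards against a single misclassified frame stopping the drip.
    """
    streak = 0
    for i, frame in enumerate(frames):
        if classify_frame(frame) == "endpoint":
            streak += 1
            if streak >= confirm_n:
                return i
        else:
            streak = 0  # a non-endpoint frame resets the run
    return None
```

A real system would pair this decision with drip-rate control, for example slowing the pump as soon as the first endpoint-colored frame appears.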

In this intelligent titration system, the improvements to the image-recognition algorithm focus on the following areas:

Optimizing the training algorithm

The optimizer is switched from Adam to SGD, with a momentum term and a learning-rate scheduler (StepLR) added. In classification tasks this combination tends to be more accurate and stable, especially on small datasets. The scheduler lowers the learning rate late in training, helping the model refine its solution and generalize better.
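StepLR multiplies the learning rate by a constant factor `gamma` once every `step_size` epochs. A dependency-free sketch of that schedule (the values `base_lr=0.01`, `step_size=30`, `gamma=0.1` are illustrative defaults, not hyperparameters taken from this project):

```python
def step_lr(base_lr, epoch, step_size=30, gamma=0.1):
    """Learning rate that a StepLR schedule would yield at a given epoch:
    base_lr is multiplied by gamma once per completed step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)

# The rate drops by 10x at epochs 30, 60, 90, ..., letting SGD take large
# steps early and fine-tune with small steps late in training.
schedule = [step_lr(0.01, e) for e in (0, 29, 30, 60)]
```

In PyTorch this corresponds to `optim.SGD(params, lr=..., momentum=...)` wrapped with `optim.lr_scheduler.StepLR(optimizer, step_size=..., gamma=...)`, calling `scheduler.step()` once per epoch.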

Simplifying the data augmentation strategy

This revision trims the augmentation pipeline down to the transforms that preserve the core features, such as random cropping and horizontal flipping. Dropping unnecessary, more aggressive transforms keeps the focus on the image's main features; in particular, since the endpoint is detected by color, augmentations that distort color would corrupt the very signal the model must learn. This helps the model recognize color changes during color-indicator titration more accurately.
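What the two retained augmentations do can be illustrated on a toy 2-D pixel grid without any imaging library (the helper names here are hypothetical; the notebook itself uses `transforms.RandomResizedCrop(224)` and `transforms.RandomHorizontalFlip()`):

```python
import random

def random_crop(img, size):
    """Cut a random size x size window out of a 2-D pixel grid."""
    top = random.randint(0, len(img) - size)
    left = random.randint(0, len(img[0]) - size)
    return [row[left:left + size] for row in img[top:top + size]]

def random_horizontal_flip(img, p=0.5):
    """Mirror the grid left-to-right with probability p."""
    return [row[::-1] for row in img] if random.random() < p else img
```

Both transforms move or mirror pixels but leave their values untouched, which is why they are safe for a color-based classifier.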

Introducing early stopping

To prevent overfitting, an early-stopping strategy was added: when the validation accuracy stops improving, the system halts training automatically, saving training time and improving the model's robustness.
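The mechanism can be sketched as a small counter object (the `patience` threshold is illustrative; the text does not state which value was used):

```python
class EarlyStopper:
    """Signal a stop when validation accuracy has not improved
    for `patience` consecutive epochs."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_acc = 0.0
        self.bad_epochs = 0

    def should_stop(self, val_acc):
        if val_acc > self.best_acc:
            self.best_acc = val_acc   # new best: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1      # no improvement this epoch
        return self.bad_epochs >= self.patience
```

In the training loop this becomes a `break` right after the validation accuracy is computed: `if stopper.should_stop(val_accurate): break`.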

Real-time logging and visualization

TensorBoard records the loss, accuracy, and training time throughout training and visualizes them for analysis. This helps tune the training process and confirms that the model finishes training at its best accuracy and efficiency.
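The actual API here is `torch.utils.tensorboard.SummaryWriter`, whose `add_scalar(tag, value, step)` call writes one point per metric per epoch. As a dependency-free stand-in, a sketch that records the same three scalars:

```python
import csv
import time

class MetricsLogger:
    """Collect per-epoch loss, validation accuracy, and elapsed seconds --
    the same scalars a SummaryWriter would stream to TensorBoard."""

    def __init__(self):
        self.start = time.time()
        self.rows = []

    def log(self, epoch, train_loss, val_acc):
        self.rows.append({"epoch": epoch,
                          "train_loss": train_loss,
                          "val_acc": val_acc,
                          "elapsed_s": round(time.time() - self.start, 1)})

    def save(self, path="metrics.csv"):
        # Persist the run so it can be plotted or compared later
        with open(path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(self.rows[0]))
            writer.writeheader()
            writer.writerows(self.rows)
```

With the real writer, the equivalent per-epoch calls would be `writer.add_scalar("train/loss", ..., epoch)` and `writer.add_scalar("val/accuracy", ..., epoch)` (the tag names are illustrative).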

With these improvements, the intelligent titration system can recognize the endpoint color change more accurately and quickly, adapt to complex experimental environments, and achieve higher experimental precision and stability.


Install the torch, resnet, tensorflow, and torchvision libraries. (As the install log below shows, the PyPI `resnet` package pulls in Keras; the training code itself uses `torchvision.models.resnet34`, so torch and torchvision are the essential dependencies.)

[2]
!pip install torch
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torch
  Downloading torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl (906.4 MB)
Installing collected packages: mpmath, sympy, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, networkx, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch
Successfully installed filelock-3.16.1 fsspec-2024.10.0 mpmath-1.3.0 networkx-3.4.2 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.4.127 sympy-1.13.1 torch-2.5.1 triton-3.1.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[4]
!pip install torchvision
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torchvision
  Downloading torchvision-0.20.1-cp310-cp310-manylinux1_x86_64.whl (7.2 MB)
Requirement already satisfied: torch==2.5.1 in /opt/mamba/lib/python3.10/site-packages (from torchvision) (2.5.1)
Installing collected packages: torchvision
Successfully installed torchvision-0.20.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[7]
!pip install resnet
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting resnet
  Downloading resnet-0.1.tar.gz (5.8 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: resnet
  Building wheel for resnet (setup.py) ... done
Successfully built resnet
Installing collected packages: namex, optree, ml-dtypes, mdurl, absl-py, markdown-it-py, rich, keras, resnet
Successfully installed absl-py-2.1.0 keras-3.6.0 markdown-it-py-3.0.0 mdurl-0.1.2 ml-dtypes-0.5.0 namex-0.0.8 optree-0.13.0 resnet-0.1 rich-13.9.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[9]
!pip install tensorflow
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting tensorflow
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (615.3 MB)
Installing collected packages: libclang, flatbuffers, wrapt, werkzeug, termcolor, tensorflow-io-gcs-filesystem, tensorboard-data-server, protobuf, opt-einsum, numpy, markdown, grpcio, google-pasta, gast, astunparse, tensorboard, ml-dtypes, h5py, tensorflow
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
scipy 1.10.1 requires numpy<1.27.0,>=1.19.5, but you have numpy 2.0.2 which is incompatible.
Successfully installed astunparse-1.6.3 flatbuffers-24.3.25 gast-0.6.0 google-pasta-0.2.0 grpcio-1.67.1 h5py-3.12.1 libclang-18.1.1 markdown-3.7 ml-dtypes-0.4.1 numpy-2.0.2 opt-einsum-3.4.0 protobuf-5.28.3 tensorboard-2.18.0 tensorboard-data-server-0.7.2 tensorflow-2.18.0 tensorflow-io-gcs-filesystem-0.37.1 termcolor-2.5.0 werkzeug-3.0.6 wrapt-1.16.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Original algorithm model (Dalian University of Technology)

[12]
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from torchvision.models import resnet34


def main():
    # Use the GPU if one is available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Data transforms for training and validation
    data_transform = {
        "train": transforms.Compose([
            # Randomly crop and rescale the image to 224x224
            transforms.RandomResizedCrop(224),
            # Randomly flip the image horizontally
            transforms.RandomHorizontalFlip(),
            # Convert the image to a tensor
            transforms.ToTensor(),
            # Normalize the image
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            # Resize the image to 224x224
            transforms.Resize((224, 224)),
            # Convert the image to a tensor
            transforms.ToTensor(),
            # Normalize the image
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # Root path of the dataset
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # Build the image dataset path (note: the leading "/" makes this an
    # absolute path, so os.path.join effectively ignores data_root)
    image_path = os.path.join(data_root, "/personal/Auto_Titration/Picture_Train/data/")
    # Make sure the dataset path exists
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    # Number of training samples
    train_num = len(train_dataset)

    # Class-name -> index mapping
    flower_list = train_dataset.class_to_idx
    # Invert it to get index -> class-name
    cla_dict = dict((val, key) for key, val in flower_list.items())

    # Write the index -> class-name mapping to a JSON file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    # Batch size
    batch_size = 32
    # Number of dataloader worker processes (computed and printed, but the
    # loaders below are created with num_workers=0)
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    # Training data loader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)
    # Validation dataset, loaded with ImageFolder and the validation transforms
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    # Number of validation samples
    val_num = len(validate_dataset)
    # Validation data loader
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=0)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # (Unused) example of drawing one batch from the validation loader
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)

    # Model name
    model_name = "resnet34-3"
    # Instantiate ResNet-34 with 2 output classes
    net = resnet34(num_classes=2)
    # Move the model to the chosen device (CPU or GPU)
    net.to(device)
    # Cross-entropy loss
    loss_function = nn.CrossEntropyLoss()
    # Adam optimizer with learning rate 0.0001
    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    # Number of training epochs
    epochs = 100
    # Best validation accuracy seen so far
    best_acc = 0.0
    # Checkpoint path
    save_path = './{}Net.pth'.format(model_name)
    # Number of training steps per epoch
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # Training phase
        net.train()
        running_loss = 0.0
        # tqdm progress bar over the training batches
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            # Reset gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = net(images.to(device))
            # Compute the loss
            loss = loss_function(outputs, labels.to(device))
            # Backpropagate
            loss.backward()
            # Update parameters
            optimizer.step()
            # Accumulate the training loss
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # Validation phase
        net.eval()  # evaluation mode: disables dropout, freezes batch-norm statistics
        acc = 0.0   # running count of correct predictions
        # No gradients needed during validation
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]  # predicted class index
                # Count correct predictions
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        # Validation accuracy
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        # Save a checkpoint whenever validation accuracy improves
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print('Finished Training')


if __name__ == '__main__':
    main()

using cpu device.
Using 2 dataloader workers every process
using 604 images for training, 58 images for validation.
train epoch[1/100] loss:0.170: 100%|██████████| 19/19 [02:31<00:00,  7.96s/it]
100%|██████████| 2/2 [00:05<00:00,  2.83s/it]
[epoch 1] train_loss: 0.400  val_accuracy: 0.552
Finished Training
train epoch[2/100] loss:0.116: 100%|██████████| 19/19 [02:12<00:00,  6.98s/it]
100%|██████████| 2/2 [00:04<00:00,  2.12s/it]
[epoch 2] train_loss: 0.159  val_accuracy: 0.966
Finished Training
train epoch[3/100] loss:0.029: 100%|██████████| 19/19 [02:12<00:00,  6.96s/it]
100%|██████████| 2/2 [00:04<00:00,  2.11s/it]
[epoch 3] train_loss: 0.156  val_accuracy: 0.983
Finished Training
train epoch[4/100] loss:0.060: 100%|██████████| 19/19 [02:12<00:00,  6.96s/it]
100%|██████████| 2/2 [00:04<00:00,  2.07s/it]
[epoch 4] train_loss: 0.113  val_accuracy: 0.966
Finished Training
train epoch[5/100] loss:0.240: 100%|██████████| 19/19 [02:13<00:00,  7.04s/it]
100%|██████████| 2/2 [00:04<00:00,  2.14s/it]
[epoch 5] train_loss: 0.178  val_accuracy: 0.862
Finished Training
train epoch[6/100] loss:0.187: 100%|██████████| 19/19 [02:15<00:00,  7.15s/it]
100%|██████████| 2/2 [00:04<00:00,  2.10s/it]
[epoch 6] train_loss: 0.132  val_accuracy: 0.966
Finished Training
train epoch[7/100] loss:0.173: 100%|██████████| 19/19 [02:13<00:00,  7.04s/it]
100%|██████████| 2/2 [00:04<00:00,  2.17s/it]
[epoch 7] train_loss: 0.138  val_accuracy: 1.000
Finished Training
train epoch[8/100] loss:0.076: 100%|██████████| 19/19 [02:13<00:00,  7.03s/it]
100%|██████████| 2/2 [00:04<00:00,  2.19s/it]
[epoch 8] train_loss: 0.109  val_accuracy: 1.000
Finished Training
train epoch[9/100] loss:0.041: 100%|██████████| 19/19 [02:15<00:00,  7.11s/it]
100%|██████████| 2/2 [00:04<00:00,  2.16s/it]
[epoch 9] train_loss: 0.107  val_accuracy: 1.000
Finished Training
train epoch[10/100] loss:0.093: 100%|██████████| 19/19 [02:13<00:00,  7.04s/it]
100%|██████████| 2/2 [00:04<00:00,  2.13s/it]
[epoch 10] train_loss: 0.091  val_accuracy: 1.000
Finished Training
train epoch[11/100] loss:0.019: 100%|██████████| 19/19 [02:12<00:00,  6.97s/it]
100%|██████████| 2/2 [00:04<00:00,  2.09s/it]
[epoch 11] train_loss: 0.097  val_accuracy: 0.983
Finished Training
train epoch[12/100] loss:0.296: 100%|██████████| 19/19 [02:13<00:00,  7.03s/it]
100%|██████████| 2/2 [00:04<00:00,  2.09s/it]
[epoch 12] train_loss: 0.145  val_accuracy: 1.000
Finished Training
train epoch[13/100] loss:0.042: 100%|██████████| 19/19 [02:14<00:00,  7.08s/it]
100%|██████████| 2/2 [00:04<00:00,  2.18s/it]
[epoch 13] train_loss: 0.106  val_accuracy: 1.000
Finished Training
train epoch[14/100] loss:0.012: 100%|██████████| 19/19 [02:12<00:00,  6.99s/it]
100%|██████████| 2/2 [00:04<00:00,  2.14s/it]
[epoch 14] train_loss: 0.073  val_accuracy: 1.000
Finished Training
train epoch[15/100] loss:0.049: 100%|██████████| 19/19 [02:13<00:00,  7.01s/it]
100%|██████████| 2/2 [00:04<00:00,  2.07s/it]
[epoch 15] train_loss: 0.099  val_accuracy: 1.000
Finished Training
train epoch[16/100] loss:0.028: 100%|██████████| 19/19 [02:13<00:00,  7.05s/it]
100%|██████████| 2/2 [00:04<00:00,  2.14s/it]
[epoch 16] train_loss: 0.091  val_accuracy: 1.000
Finished Training
train epoch[17/100] loss:0.040: 100%|██████████| 19/19 [02:14<00:00,  7.08s/it]
100%|██████████| 2/2 [00:04<00:00,  2.14s/it]
[epoch 17] train_loss: 0.091  val_accuracy: 0.948
Finished Training
train epoch[18/100] loss:0.141: 100%|██████████| 19/19 [02:15<00:00,  7.13s/it]
100%|██████████| 2/2 [00:04<00:00,  2.14s/it]
[epoch 18] train_loss: 0.090  val_accuracy: 1.000
Finished Training
train epoch[19/100] loss:0.148:   5%|▌         | 1/19 [00:09<02:57,  9.85s/it]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[12], line 158
    154         print('Finished Training')
    157 if __name__ == '__main__':
--> 158     main()

Cell In[12], line 120, in main()
    118 optimizer.zero_grad()
    119 # 将图片和标签移动到指定的设备上
--> 120 outputs = net(images.to(device))
    121 # 计算损失
    122 loss = loss_function(outputs, labels.to(device))

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/mamba/lib/python3.10/site-packages/torchvision/models/resnet.py:285, in ResNet.forward(self, x)
    284 def forward(self, x: Tensor) -> Tensor:
--> 285     return self._forward_impl(x)

File /opt/mamba/lib/python3.10/site-packages/torchvision/models/resnet.py:276, in ResNet._forward_impl(self, x)
    274 x = self.layer2(x)
    275 x = self.layer3(x)
--> 276 x = self.layer4(x)
    278 x = self.avgpool(x)
    279 x = torch.flatten(x, 1)

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input)
    248 def forward(self, input):
    249     for module in self:
--> 250         input = module(input)
    251     return input

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/mamba/lib/python3.10/site-packages/torchvision/models/resnet.py:92, in BasicBlock.forward(self, x)
     89 def forward(self, x: Tensor) -> Tensor:
     90     identity = x
---> 92     out = self.conv1(x)
     93     out = self.bn1(out)
     94     out = self.relu(out)

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/conv.py:554, in Conv2d.forward(self, input)
    553 def forward(self, input: Tensor) -> Tensor:
--> 554     return self._conv_forward(input, self.weight, self.bias)

File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/conv.py:549, in Conv2d._conv_forward(self, input, weight, bias)
    537 if self.padding_mode != "zeros":
    538     return F.conv2d(
    539         F.pad(
    540             input, self._reversed_padding_repeated_twice, mode=self.padding_mode
   (...)
    547         self.groups,
    548     )
--> 549 return F.conv2d(
    550     input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    551 )

KeyboardInterrupt: 

The traceback above was raised by manually force-stopping the training run (a `KeyboardInterrupt`); it is unrelated to the correctness of the code itself.
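If a run may be stopped by hand, the interrupt can be caught so it exits cleanly (and checkpoints) instead of surfacing as a long traceback mid-forward-pass. This is a minimal sketch, not from the original notebook; `step_fn` and `checkpoint_fn` are illustrative stand-ins for one training epoch and a `torch.save(...)` call.

```python
def train_with_interrupt_guard(epochs, step_fn, checkpoint_fn):
    """Run a training loop; on a manual interrupt, checkpoint and exit cleanly.

    step_fn(epoch) stands in for one epoch of training, and
    checkpoint_fn(completed) for saving the model state -- both names
    are illustrative, not part of the original code.
    """
    completed = 0
    try:
        for epoch in range(epochs):
            step_fn(epoch)
            completed += 1
    except KeyboardInterrupt:
        # A manual stop (Ctrl+C / notebook "interrupt kernel") lands here
        # instead of aborting somewhere inside the model's forward pass.
        checkpoint_fn(completed)
        print("interrupted after {} epochs".format(completed))
    return completed
```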

[2]
!pip install torchvision

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torchvision
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a2/f6/7ff89a9f8703f623f5664afd66c8600e3f09fe188e1e0b7e6f9a8617f865/torchvision-0.20.1-cp310-cp310-manylinux1_x86_64.whl (7.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.2/7.2 MB 52.0 MB/s eta 0:00:00
Requirement already satisfied: numpy in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision) (1.26.4)
Collecting torch==2.5.1 (from torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2a/ef/834af4a885b31a0b32fff2d80e1e40f771e1566ea8ded55347502440786a/torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl (906.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 906.4/906.4 MB 9.1 MB/s eta 0:00:0000:0100:03
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision) (10.4.0)
Requirement already satisfied: filelock in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.5.1->torchvision) (3.16.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.5.1->torchvision) (4.12.2)
Requirement already satisfied: networkx in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.5.1->torchvision) (3.3)
Requirement already satisfied: jinja2 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.5.1->torchvision) (3.1.4)
Collecting fsspec (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/b2/454d6e7f0158951d8a78c2e1eb4f69ae81beb8dca5fee9809c6c99e9d0d0/fsspec-2024.10.0-py3-none-any.whl (179 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 24.4 MB/s eta 0:00:0000:0100:01
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 9.1 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 14.8 MB/s eta 0:00:00a 0:00:01
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 10.1 MB/s eta 0:00:0000:0100:02
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 11.3 MB/s eta 0:00:0000:0100:01
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 12.6 MB/s eta 0:00:0000:0100:01
Collecting nvidia-curand-cu12==10.3.5.147 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 15.6 MB/s eta 0:00:0000:0100:01
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 12.6 MB/s eta 0:00:0000:0100:01
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 12.3 MB/s eta 0:00:0000:0100:01
Collecting nvidia-nccl-cu12==2.21.5 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl (188.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 188.7/188.7 MB 12.8 MB/s eta 0:00:0000:0100:01
Collecting nvidia-nvtx-cu12==12.4.127 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (99 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 28.6 MB/s eta 0:00:00a 0:00:01
Collecting triton==3.1.0 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.5/209.5 MB 12.0 MB/s eta 0:00:0000:0100:01
Collecting sympy==1.13.1 (from torch==2.5.1->torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl (6.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.2/6.2 MB 15.4 MB/s eta 0:00:00a 0:00:01
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from sympy==1.13.1->torch==2.5.1->torchvision) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from jinja2->torch==2.5.1->torchvision) (2.1.5)
Installing collected packages: triton, sympy, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, fsspec, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, torchvision
  Attempting uninstall: sympy
    Found existing installation: sympy 1.13.2
    Uninstalling sympy-1.13.2:
      Successfully uninstalled sympy-1.13.2
  Attempting uninstall: torch
    Found existing installation: torch 2.0.0.post200
    Uninstalling torch-2.0.0.post200:
      Successfully uninstalled torch-2.0.0.post200
Successfully installed fsspec-2024.10.0 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.4.127 sympy-1.13.1 torch-2.5.1 torchvision-0.20.1 triton-3.1.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
[4]
!pip uninstall torch torchvision
!pip install torch==2.0.1 torchvision==0.15.2
Found existing installation: torch 2.5.1
Uninstalling torch-2.5.1:
  Would remove:
    /opt/deepmd-kit-3.0.0b4/bin/convert-caffe2-to-onnx
    /opt/deepmd-kit-3.0.0b4/bin/convert-onnx-to-caffe2
    /opt/deepmd-kit-3.0.0b4/bin/torchfrtrace
    /opt/deepmd-kit-3.0.0b4/bin/torchrun
    /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/functorch/*
    /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torch-2.5.1.dist-info/*
    /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torch/*
    /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torchgen/*
Proceed (Y/n)? ^C
ERROR: Operation cancelled by user
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torch==2.0.1
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8c/4d/17e07377c9c3d1a0c4eb3fde1c7c16b5a0ce6133ddbabc08ceef6b7f2645/torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl (619.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 619.9/619.9 MB 10.2 MB/s eta 0:00:0000:0100:02
Collecting torchvision==0.15.2
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/87/0f/88f023bf6176d9af0f85feedf4be129f9cf2748801c4d9c690739a10c100/torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl (6.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.0/6.0 MB 27.0 MB/s eta 0:00:00a 0:00:01
Requirement already satisfied: filelock in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (3.16.1)
Requirement already satisfied: typing-extensions in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (4.12.2)
Requirement already satisfied: sympy in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (1.13.1)
Requirement already satisfied: networkx in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (3.3)
Requirement already satisfied: jinja2 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (3.1.4)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ef/25/922c5996aada6611b79b53985af7999fc629aee1d5d001b6a22431e18fec/nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.0/21.0 MB 22.8 MB/s eta 0:00:00a 0:00:01
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/36/92/89cf558b514125d2ebd8344dd2f0533404b416486ff681d5434a5832a019/nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 849.3/849.3 kB 16.1 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e6/9d/dd0cdcd800e642e3c82ee3b5987c751afd4f3fb9cc2752517f42c3bc6e49/nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl (11.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.8/11.8 MB 11.5 MB/s eta 0:00:0000:010:01
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/dc/30/66d4347d6e864334da5bb1c7571305e501dcb11b9155971421bb7bb5315f/nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl (557.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 557.1/557.1 MB 10.5 MB/s eta 0:00:0000:0100:02
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ce/41/fdeb62b5437996e841d83d7d2714ca75b886547ee8017ee2fe6ea409d983/nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 317.1/317.1 MB 11.5 MB/s eta 0:00:0000:0100:01
Collecting nvidia-cufft-cu11==10.9.0.58 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/64/c8/133717b43182ba063803e983e7680a94826a9f4ff5734af0ca315803f1b3/nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux2014_x86_64.whl (168.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 168.4/168.4 MB 12.8 MB/s eta 0:00:0000:0100:01
Collecting nvidia-curand-cu11==10.2.10.91 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8f/11/af78d54b2420e64a4dd19e704f5bb69dcb5a6a3138b4465d6a48cdf59a21/nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl (54.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.6/54.6 MB 15.7 MB/s eta 0:00:0000:0100:01
Collecting nvidia-cusolver-cu11==11.4.0.1 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3e/77/66149e3153b19312fb782ea367f3f950123b93916a45538b573fe373570a/nvidia_cusolver_cu11-11.4.0.1-2-py3-none-manylinux1_x86_64.whl (102.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 102.6/102.6 MB 12.7 MB/s eta 0:00:0000:0100:01
Collecting nvidia-cusparse-cu11==11.7.4.91 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ea/6f/6d032cc1bb7db88a989ddce3f4968419a7edeafda362847f42f614b1f845/nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl (173.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 173.2/173.2 MB 12.3 MB/s eta 0:00:0000:0100:01
Collecting nvidia-nccl-cu11==2.14.3 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/55/92/914cdb650b6a5d1478f83148597a25e90ea37d739bd563c5096b0e8a5f43/nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl (177.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 177.1/177.1 MB 12.5 MB/s eta 0:00:0000:0100:01
Collecting nvidia-nvtx-cu11==11.7.91 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/23/d5/09493ff0e64fd77523afbbb075108f27a13790479efe86b9ffb4587671b5/nvidia_nvtx_cu11-11.7.91-py3-none-manylinux1_x86_64.whl (98 kB)
Collecting triton==2.0.0 (from torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ca/31/ff6be541195daf77aa5c72303b2354661a69e717967d44d91eb4f3fdce32/triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (63.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.3/63.3 MB 13.0 MB/s eta 0:00:0000:0100:01
Requirement already satisfied: numpy in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision==0.15.2) (1.26.4)
Requirement already satisfied: requests in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision==0.15.2) (2.32.3)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision==0.15.2) (10.4.0)
Requirement already satisfied: setuptools in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.1) (75.1.0)
Requirement already satisfied: wheel in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.1) (0.44.0)
Collecting cmake (from triton==2.0.0->torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/4b/87/3b01bc707c7a021f3c2c2d133af405618b1580613e22a9ead4a9d9d49903/cmake-3.30.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 26.9/26.9 MB 10.3 MB/s eta 0:00:00a 0:00:01
Collecting lit (from triton==2.0.0->torch==2.0.1)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/96/06/b36f150fa7c5bcc96a31a4d19a20fddbd1d965b6f02510b57a3bb8d4b930/lit-18.1.8-py3-none-any.whl (96 kB)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from jinja2->torch==2.0.1) (2.1.5)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (2024.8.30)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from sympy->torch==2.0.1) (1.3.0)
Installing collected packages: lit, nvidia-nvtx-cu11, nvidia-nccl-cu11, nvidia-cusparse-cu11, nvidia-curand-cu11, nvidia-cufft-cu11, nvidia-cuda-runtime-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-cupti-cu11, nvidia-cublas-cu11, cmake, nvidia-cusolver-cu11, nvidia-cudnn-cu11, triton, torch, torchvision
  Attempting uninstall: triton
    Found existing installation: triton 3.1.0
    Uninstalling triton-3.1.0:
      Successfully uninstalled triton-3.1.0
  Attempting uninstall: torch
    Found existing installation: torch 2.5.1
    Uninstalling torch-2.5.1:
      Successfully uninstalled torch-2.5.1
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.20.1
    Uninstalling torchvision-0.20.1:
      Successfully uninstalled torchvision-0.20.1
Successfully installed cmake-3.30.5 lit-18.1.8 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 torch-2.0.1 torchvision-0.15.2 triton-2.0.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
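The reinstall above pins a matching release pair: torch and torchvision are built against each other, and a mismatched pair typically surfaces as an "undefined symbol" warning when torchvision loads its compiled extensions. A minimal sketch of a compatibility check, with the table populated only from the two pairs that appear in this notebook (the `versions_match` helper is illustrative):

```python
# Known-compatible (torch, torchvision) release pairs seen in this notebook;
# torchvision 0.20.1 declares torch==2.5.1, and 0.15.2 pairs with 2.0.1.
COMPATIBLE = {
    "2.0.1": "0.15.2",
    "2.5.1": "0.20.1",
}


def versions_match(torch_version, torchvision_version):
    """Return True if the installed pair is a known-compatible release pair."""
    base = torch_version.split("+")[0]  # strip local build tags like "+cu117"
    return COMPATIBLE.get(base) == torchvision_version
```

In practice one would compare `torch.__version__` and `torchvision.__version__` against such a table before starting a long training run.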

First revision: results were poor, so it was discarded.

[6]
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from torchvision.models import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "/personal/Auto_Titration/Picture_Train/data/")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count() // 2, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    model_name = "resnet34-3"
    net = resnet34(num_classes=2)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    epochs = 100
    best_acc = 0.0
    patience_counter = 0
    patience_limit = 10
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)

    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > patience_limit:
                print("Early stopping triggered.")
                break

    print('Finished Training')


if __name__ == '__main__':
    main()

/opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKSs'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
using cpu device.
Using 8 dataloader workers every process
using 604 images for training, 58 images for validation.
train epoch[1/100] loss:0.609: 100%|██████████| 19/19 [01:03<00:00,  3.35s/it]
100%|██████████| 2/2 [00:08<00:00,  4.31s/it]
[epoch 1] train_loss: 0.697  val_accuracy: 0.741
train epoch[2/100] loss:0.711: 100%|██████████| 19/19 [01:06<00:00,  3.52s/it]
100%|██████████| 2/2 [00:02<00:00,  1.37s/it]
[epoch 2] train_loss: 0.595  val_accuracy: 0.379
train epoch[3/100] loss:0.575: 100%|██████████| 19/19 [01:02<00:00,  3.31s/it]
100%|██████████| 2/2 [00:02<00:00,  1.44s/it]
[epoch 3] train_loss: 0.588  val_accuracy: 0.741
train epoch[4/100] loss:0.775: 100%|██████████| 19/19 [01:12<00:00,  3.81s/it]
100%|██████████| 2/2 [00:03<00:00,  1.53s/it]
[epoch 4] train_loss: 0.574  val_accuracy: 0.724
train epoch[5/100] loss:0.667: 100%|██████████| 19/19 [01:17<00:00,  4.09s/it]
100%|██████████| 2/2 [00:03<00:00,  1.59s/it]
[epoch 5] train_loss: 0.565  val_accuracy: 0.741
train epoch[6/100] loss:0.601: 100%|██████████| 19/19 [01:08<00:00,  3.61s/it]
100%|██████████| 2/2 [00:03<00:00,  1.53s/it]
[epoch 6] train_loss: 0.554  val_accuracy: 0.879
train epoch[7/100] loss:0.642: 100%|██████████| 19/19 [01:15<00:00,  3.97s/it]
100%|██████████| 2/2 [00:02<00:00,  1.41s/it]
[epoch 7] train_loss: 0.572  val_accuracy: 0.741
train epoch[8/100] loss:0.465: 100%|██████████| 19/19 [01:12<00:00,  3.79s/it]
100%|██████████| 2/2 [00:05<00:00,  2.89s/it]
[epoch 8] train_loss: 0.564  val_accuracy: 0.741
train epoch[9/100] loss:0.473: 100%|██████████| 19/19 [01:08<00:00,  3.60s/it]
100%|██████████| 2/2 [00:02<00:00,  1.36s/it]
[epoch 9] train_loss: 0.540  val_accuracy: 0.759
train epoch[10/100] loss:0.583: 100%|██████████| 19/19 [01:08<00:00,  3.60s/it]
100%|██████████| 2/2 [00:04<00:00,  2.31s/it]
[epoch 10] train_loss: 0.532  val_accuracy: 0.810
train epoch[11/100] loss:0.754: 100%|██████████| 19/19 [01:01<00:00,  3.23s/it]
100%|██████████| 2/2 [00:02<00:00,  1.43s/it]
[epoch 11] train_loss: 0.539  val_accuracy: 0.741
train epoch[12/100] loss:0.441: 100%|██████████| 19/19 [01:07<00:00,  3.57s/it]
100%|██████████| 2/2 [00:02<00:00,  1.38s/it]
[epoch 12] train_loss: 0.532  val_accuracy: 0.741
train epoch[13/100] loss:0.516: 100%|██████████| 19/19 [01:04<00:00,  3.41s/it]
100%|██████████| 2/2 [00:02<00:00,  1.35s/it]
[epoch 13] train_loss: 0.498  val_accuracy: 0.707
train epoch[14/100] loss:0.511: 100%|██████████| 19/19 [01:07<00:00,  3.56s/it]
100%|██████████| 2/2 [00:03<00:00,  1.52s/it]
[epoch 14] train_loss: 0.553  val_accuracy: 0.741
train epoch[15/100] loss:0.543: 100%|██████████| 19/19 [01:09<00:00,  3.63s/it]
100%|██████████| 2/2 [00:03<00:00,  1.59s/it]
[epoch 15] train_loss: 0.510  val_accuracy: 0.759
train epoch[16/100] loss:0.286: 100%|██████████| 19/19 [01:12<00:00,  3.82s/it]
100%|██████████| 2/2 [00:02<00:00,  1.45s/it]
[epoch 16] train_loss: 0.466  val_accuracy: 0.759
train epoch[17/100] loss:0.498: 100%|██████████| 19/19 [01:04<00:00,  3.38s/it]
100%|██████████| 2/2 [00:02<00:00,  1.45s/it]
[epoch 17] train_loss: 0.524  val_accuracy: 0.638
Early stopping triggered.
Finished Training

Second revision: modified the early-stopping mechanism in the code; the results were poor, so it was discarded.
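The patience-based early stopping used in the training cells below can be reduced to a small standalone sketch (a hypothetical helper, not part of the notebook code): training stops once the validation accuracy has failed to improve for more than `patience_limit` consecutive epochs.

```python
def early_stop_epoch(val_accuracies, patience_limit=10):
    """Return the 1-based epoch at which training would stop,
    or None if the patience limit is never exceeded."""
    best_acc = 0.0
    patience_counter = 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best_acc:
            # New best accuracy: save checkpoint (omitted here) and reset patience
            best_acc = acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > patience_limit:
                return epoch
    return None
```

Note that the model checkpoint is saved only on improvement, so the weights on disk always correspond to the best validation accuracy, even though training continues for up to `patience_limit` more epochs.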

[7]
import os
import sys
import json
import time  # new
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter  # new
from torchvision.models import resnet34

def main():
    # Initialize TensorBoard
    writer = SummaryWriter('runs/experiment')  # log directory

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Data augmentation transforms
    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }

    # Dataset path (the absolute second argument makes os.path.join ignore data_root)
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "/personal/Auto_Titration/Picture_Train/data/")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    # Load the training set
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count() // 2, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    # Load the validation set
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    model_name = "resnet34-3"
    net = resnet34(num_classes=2)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    epochs = 100
    best_acc = 0.0
    patience_counter = 0
    patience_limit = 10
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)

    # Start the overall timer
    start_time = time.time()

    for epoch in range(epochs):
        # Record the start time of this epoch
        epoch_start_time = time.time()

        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # Validation phase
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        epoch_time = time.time() - epoch_start_time  # duration of this epoch

        # Print current epoch loss, validation accuracy, and duration
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f epoch_time: %.2f sec' %
              (epoch + 1, running_loss / train_steps, val_accurate, epoch_time))

        # Log metrics to TensorBoard
        writer.add_scalar('Loss/train', running_loss / train_steps, epoch)
        writer.add_scalar('Accuracy/val', val_accurate, epoch)
        writer.add_scalar('Time/epoch', epoch_time, epoch)

        # Early stopping
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > patience_limit:
                print("Early stopping triggered.")
                break

    # Record the total training time after the loop
    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")
    writer.add_text('Training Time', f'Total training time: {total_time:.2f} seconds')

    # Close the TensorBoard writer
    writer.close()
    print('Finished Training')

if __name__ == '__main__':
    main()

2024-10-30 22:57:14.344638: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-30 22:57:14.344680: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-30 22:57:14.344687: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-30 22:57:14.478038: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
using cpu device.
Using 8 dataloader workers every process
using 604 images for training, 58 images for validation.
train epoch[1/100] loss:0.648: 100%|██████████| 19/19 [01:07<00:00,  3.57s/it]
100%|██████████| 2/2 [00:03<00:00,  1.58s/it]
[epoch 1] train_loss: 0.690  val_accuracy: 0.741  epoch_time: 70.99 sec
train epoch[2/100] loss:0.603: 100%|██████████| 19/19 [01:20<00:00,  4.23s/it]
100%|██████████| 2/2 [00:03<00:00,  1.56s/it]
[epoch 2] train_loss: 0.601  val_accuracy: 0.741  epoch_time: 83.52 sec
train epoch[3/100] loss:0.615: 100%|██████████| 19/19 [01:03<00:00,  3.35s/it]
100%|██████████| 2/2 [00:02<00:00,  1.49s/it]
[epoch 3] train_loss: 0.611  val_accuracy: 0.741  epoch_time: 66.54 sec
train epoch[4/100] loss:0.750: 100%|██████████| 19/19 [01:11<00:00,  3.74s/it]
100%|██████████| 2/2 [00:03<00:00,  1.54s/it]
[epoch 4] train_loss: 0.596  val_accuracy: 0.741  epoch_time: 74.12 sec
train epoch[5/100] loss:0.694: 100%|██████████| 19/19 [01:08<00:00,  3.63s/it]
100%|██████████| 2/2 [00:03<00:00,  1.57s/it]
[epoch 5] train_loss: 0.570  val_accuracy: 0.776  epoch_time: 72.13 sec
train epoch[6/100] loss:0.706: 100%|██████████| 19/19 [01:18<00:00,  4.11s/it]
100%|██████████| 2/2 [00:03<00:00,  1.52s/it]
[epoch 6] train_loss: 0.554  val_accuracy: 0.741  epoch_time: 81.20 sec
train epoch[7/100] loss:0.569: 100%|██████████| 19/19 [01:11<00:00,  3.76s/it]
100%|██████████| 2/2 [00:03<00:00,  1.65s/it]
[epoch 7] train_loss: 0.569  val_accuracy: 0.741  epoch_time: 74.66 sec
train epoch[8/100] loss:0.612: 100%|██████████| 19/19 [01:11<00:00,  3.77s/it]
100%|██████████| 2/2 [00:02<00:00,  1.34s/it]
[epoch 8] train_loss: 0.559  val_accuracy: 0.741  epoch_time: 74.28 sec
train epoch[9/100] loss:0.706: 100%|██████████| 19/19 [01:20<00:00,  4.22s/it]
100%|██████████| 2/2 [00:02<00:00,  1.40s/it]
[epoch 9] train_loss: 0.547  val_accuracy: 0.741  epoch_time: 82.92 sec
train epoch[10/100] loss:0.775: 100%|██████████| 19/19 [01:12<00:00,  3.81s/it]
100%|██████████| 2/2 [00:02<00:00,  1.37s/it]
[epoch 10] train_loss: 0.603  val_accuracy: 0.741  epoch_time: 75.11 sec
train epoch[11/100] loss:0.585: 100%|██████████| 19/19 [01:31<00:00,  4.81s/it]
100%|██████████| 2/2 [00:02<00:00,  1.38s/it]
[epoch 11] train_loss: 0.546  val_accuracy: 0.655  epoch_time: 94.20 sec
train epoch[12/100] loss:0.430: 100%|██████████| 19/19 [01:06<00:00,  3.51s/it]
100%|██████████| 2/2 [00:02<00:00,  1.43s/it]
[epoch 12] train_loss: 0.545  val_accuracy: 0.741  epoch_time: 69.57 sec
train epoch[13/100] loss:0.548: 100%|██████████| 19/19 [01:11<00:00,  3.76s/it]
100%|██████████| 2/2 [00:02<00:00,  1.39s/it]
[epoch 13] train_loss: 0.540  val_accuracy: 0.569  epoch_time: 74.31 sec
train epoch[14/100] loss:0.499: 100%|██████████| 19/19 [01:04<00:00,  3.39s/it]
100%|██████████| 2/2 [00:02<00:00,  1.31s/it]
[epoch 14] train_loss: 0.536  val_accuracy: 0.552  epoch_time: 66.94 sec
train epoch[15/100] loss:0.536: 100%|██████████| 19/19 [01:05<00:00,  3.44s/it]
100%|██████████| 2/2 [00:03<00:00,  1.55s/it]
[epoch 15] train_loss: 0.520  val_accuracy: 0.741  epoch_time: 68.49 sec
train epoch[16/100] loss:0.633: 100%|██████████| 19/19 [01:26<00:00,  4.57s/it]
100%|██████████| 2/2 [00:02<00:00,  1.36s/it]
[epoch 16] train_loss: 0.553  val_accuracy: 0.741  epoch_time: 89.62 sec
Early stopping triggered.
Training completed in 1219.07 seconds
Finished Training

Third revision: achieved very good results.

[18]
import os
import sys
import json
import time
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet34

def main():
    writer = SummaryWriter('/personal/Auto_Titration/Picture_Train/experiment_optimized')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Simplified data augmentation
    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }

    # Dataset path (the absolute second argument makes os.path.join ignore data_root)
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "/personal/Auto_Titration/Picture_Train/data/")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count() // 2, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    model_name = "resnet34-optimized"
    net = resnet34(num_classes=2)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)  # SGD optimizer with momentum
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # learning-rate scheduler

    epochs = 100
    best_acc = 0.0
    patience_counter = 0
    patience_limit = 10  # early-stopping patience

    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)

    start_time = time.time()

    for epoch in range(epochs):
        epoch_start_time = time.time()

        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        scheduler.step()  # update the learning rate

        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        epoch_time = time.time() - epoch_start_time

        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f epoch_time: %.2f sec' %
              (epoch + 1, running_loss / train_steps, val_accurate, epoch_time))

        writer.add_scalar('Loss/train', running_loss / train_steps, epoch)
        writer.add_scalar('Accuracy/val', val_accurate, epoch)
        writer.add_scalar('Time/epoch', epoch_time, epoch)

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > patience_limit:
                print("Early stopping triggered.")
                break

    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")
    writer.add_text('Training Time', f'Total training time: {total_time:.2f} seconds')

    writer.close()
    print('Finished Training')

if __name__ == '__main__':
    main()

using cpu device.
Using 8 dataloader workers every process
using 604 images for training, 58 images for validation.
train epoch[1/100] loss:0.635: 100%|██████████| 19/19 [00:59<00:00,  3.12s/it]
100%|██████████| 2/2 [00:03<00:00,  1.59s/it]
[epoch 1] train_loss: 0.648  val_accuracy: 0.293  epoch_time: 62.48 sec
train epoch[2/100] loss:0.578: 100%|██████████| 19/19 [01:12<00:00,  3.80s/it]
100%|██████████| 2/2 [00:03<00:00,  1.68s/it]
[epoch 2] train_loss: 0.580  val_accuracy: 0.741  epoch_time: 75.62 sec
train epoch[3/100] loss:0.194: 100%|██████████| 19/19 [01:01<00:00,  3.26s/it]
100%|██████████| 2/2 [00:03<00:00,  1.69s/it]
[epoch 3] train_loss: 0.379  val_accuracy: 0.776  epoch_time: 65.35 sec
train epoch[4/100] loss:0.452: 100%|██████████| 19/19 [01:05<00:00,  3.47s/it]
100%|██████████| 2/2 [00:03<00:00,  1.60s/it]
[epoch 4] train_loss: 0.196  val_accuracy: 1.000  epoch_time: 69.10 sec
train epoch[5/100] loss:0.036: 100%|██████████| 19/19 [01:13<00:00,  3.89s/it]
100%|██████████| 2/2 [00:03<00:00,  1.55s/it]
[epoch 5] train_loss: 0.147  val_accuracy: 0.948  epoch_time: 76.95 sec
train epoch[6/100] loss:0.046: 100%|██████████| 19/19 [01:07<00:00,  3.53s/it]
100%|██████████| 2/2 [00:06<00:00,  3.43s/it]
[epoch 6] train_loss: 0.120  val_accuracy: 0.810  epoch_time: 73.97 sec
train epoch[7/100] loss:0.261: 100%|██████████| 19/19 [01:17<00:00,  4.07s/it]
100%|██████████| 2/2 [00:02<00:00,  1.50s/it]
[epoch 7] train_loss: 0.154  val_accuracy: 0.966  epoch_time: 80.25 sec
train epoch[8/100] loss:0.116: 100%|██████████| 19/19 [01:02<00:00,  3.29s/it]
100%|██████████| 2/2 [00:03<00:00,  1.67s/it]
[epoch 8] train_loss: 0.137  val_accuracy: 0.983  epoch_time: 65.79 sec
train epoch[9/100] loss:0.112: 100%|██████████| 19/19 [01:21<00:00,  4.28s/it]
100%|██████████| 2/2 [00:03<00:00,  1.58s/it]
[epoch 9] train_loss: 0.126  val_accuracy: 0.966  epoch_time: 84.52 sec
train epoch[10/100] loss:0.218: 100%|██████████| 19/19 [01:21<00:00,  4.28s/it]
100%|██████████| 2/2 [00:03<00:00,  1.64s/it]
[epoch 10] train_loss: 0.157  val_accuracy: 0.552  epoch_time: 84.60 sec
train epoch[11/100] loss:0.246: 100%|██████████| 19/19 [01:10<00:00,  3.73s/it]
100%|██████████| 2/2 [00:03<00:00,  1.69s/it]
[epoch 11] train_loss: 0.120  val_accuracy: 0.983  epoch_time: 74.34 sec
train epoch[12/100] loss:0.111: 100%|██████████| 19/19 [01:24<00:00,  4.44s/it]
100%|██████████| 2/2 [00:02<00:00,  1.44s/it]
[epoch 12] train_loss: 0.096  val_accuracy: 0.948  epoch_time: 87.29 sec
train epoch[13/100] loss:0.142: 100%|██████████| 19/19 [01:13<00:00,  3.88s/it]
100%|██████████| 2/2 [00:03<00:00,  1.64s/it]
[epoch 13] train_loss: 0.126  val_accuracy: 1.000  epoch_time: 77.02 sec
train epoch[14/100] loss:0.079: 100%|██████████| 19/19 [01:16<00:00,  4.02s/it]
100%|██████████| 2/2 [00:03<00:00,  1.63s/it]
[epoch 14] train_loss: 0.095  val_accuracy: 1.000  epoch_time: 79.73 sec
train epoch[15/100] loss:0.009: 100%|██████████| 19/19 [01:15<00:00,  3.96s/it]
100%|██████████| 2/2 [00:03<00:00,  1.66s/it]
[epoch 15] train_loss: 0.077  val_accuracy: 1.000  epoch_time: 78.54 sec
Early stopping triggered.
Training completed in 1136.46 seconds
Finished Training
[14]
!pip install tensorboard
Output hidden

How to display the training logs

[16]
%tensorboard --logdir /personal/Auto_Titration/Picture_Train/experiment_optimized
Output hidden
[ ]
%load_ext tensorboard
%tensorboard --logdir /personal/Auto_Titration/Picture_Train/experiment_optimized