Training the Improved Intelligent-Titration Algorithm
zzh
Recommended image: DeePMD-kit:3.0.0b4-cuda12.1
Recommended machine type: c32_m64_cpu
Intelligent titration is a modern chemical-analysis technique that combines automated control with intelligent algorithms, aiming for high precision and high efficiency in sample titration. The work here targets color-change titrations: the titration endpoint is judged automatically through image capture and color recognition. With advanced computing, a sensor system, and data-processing algorithms, the system monitors color changes in real time during the titration and adjusts the dosing rate automatically, keeping the experiment accurate and stable.
In this system, the improvements to the image-recognition algorithm focus on the following areas:
Optimized training algorithm
The optimizer was switched from Adam to SGD, with a momentum term and a learning-rate scheduler (StepLR) added. On classification tasks this combination often gives higher accuracy and more stable training, especially on small datasets. Lowering the learning rate late in training lets the model refine its solution and generalize better.
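The arithmetic behind this change can be sketched in plain Python. The helper names below are hypothetical, written only to illustrate the update rules; the actual training code would use `torch.optim.SGD` and `torch.optim.lr_scheduler.StepLR`:

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Learning rate after `epoch` epochs under a StepLR schedule:
    the rate is multiplied by `gamma` once every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)

def sgd_momentum_step(w, grad, buf, lr, momentum=0.9):
    """One torch-style SGD-with-momentum update for a single scalar weight:
    the velocity buffer accumulates past gradients, and the parameter
    moves against the accumulated velocity."""
    buf = momentum * buf + grad
    w = w - lr * buf
    return w, buf

# With base_lr=0.01, gamma=0.1, step_size=30, the rate stays at 0.01 for
# epochs 0-29, then decays by 10x every 30 epochs (~0.001 at epoch 40).
lr_epoch_0 = step_lr(0.01, 0.1, 30, 0)     # 0.01
lr_epoch_40 = step_lr(0.01, 0.1, 30, 40)   # ~0.001
```

The momentum buffer is what smooths noisy gradients on a small dataset, while the step decay is what lets the model settle into a finer solution late in training.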
Simplified data-augmentation strategy
The augmentation pipeline was pared down to the operations that matter for feature extraction, such as random cropping and horizontal flipping. Dropping unnecessary, more aggressive transforms keeps the focus on the image's dominant features and helps the model recognize the color change during the titration more reliably.
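Why these augmentations are safe for color-based endpoint detection can be seen from the flip operation itself: it rearranges pixels but leaves every pixel value untouched. A dependency-free illustration (hypothetical helper, not part of the project code; the real pipeline uses torchvision's `RandomHorizontalFlip`):

```python
import random

def random_hflip(image, p=0.5, rng=random.random):
    """Flip a 2-D image (a list of rows) left-to-right with probability p.
    The pixel values, and hence the color distribution the titration
    model relies on, are unchanged; only their positions move."""
    if rng() < p:
        return [row[::-1] for row in image]
    return image

img = [[1, 2, 3],
       [4, 5, 6]]
flipped = random_hflip(img, p=1.0)  # force the flip for demonstration
# flipped == [[3, 2, 1], [6, 5, 4]]
```

Color-jitter or hue-shift transforms, by contrast, would alter exactly the signal the endpoint detector depends on, which is why they were dropped.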
Early stopping
To guard against overfitting, an early-stopping strategy was added: when validation accuracy stops improving, training halts automatically, saving training time and improving the model's robustness.
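The stopping criterion described above can be captured in a few lines. This is a minimal sketch under the assumption that "no longer improving" means "no new best validation accuracy for `patience` consecutive evaluations":

```python
class EarlyStopping:
    """Stop training when validation accuracy has not improved
    for `patience` consecutive evaluations."""
    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_acc):
        """Record one validation result; return True when training should stop."""
        if val_acc > self.best:
            self.best = val_acc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=3)
for acc in [0.70, 0.85, 0.85, 0.84, 0.83]:
    if stopper.step(acc):
        break  # stops after three evaluations without improving on 0.85
```

In the training loop this replaces the fixed `for epoch in range(epochs)` bound: the loop still saves the best checkpoint, but exits as soon as `stopper.step(val_accurate)` returns True.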
Real-time logging and visualization
TensorBoard records loss, accuracy, and training time throughout the run and visualizes them. This helps tune the training process and confirm that the model finishes training at its best accuracy and efficiency.
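The TensorBoard integration boils down to `add_scalar(tag, value, step)` calls on a `torch.utils.tensorboard.SummaryWriter`. A dependency-free stand-in with the same call shape (a hypothetical class, useful for seeing what gets recorded; the example values are illustrative):

```python
class ScalarLog:
    """Minimal stand-in for SummaryWriter.add_scalar: records
    (step, value) pairs per tag so they can be plotted later."""
    def __init__(self):
        self.scalars = {}

    def add_scalar(self, tag, value, step):
        self.scalars.setdefault(tag, []).append((step, value))

log = ScalarLog()
# In the real loop these calls would sit next to the per-epoch print,
# one pair per epoch: writer.add_scalar("train/loss", ..., epoch) etc.
for epoch, (loss, acc) in enumerate([(0.40, 0.55), (0.16, 0.97), (0.16, 0.98)]):
    log.add_scalar("train/loss", loss, epoch)
    log.add_scalar("val/accuracy", acc, epoch)
# log.scalars["val/accuracy"] == [(0, 0.55), (1, 0.97), (2, 0.98)]
```

With the real `SummaryWriter`, running `tensorboard --logdir runs` then renders these curves live while training proceeds.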
With these improvements, the intelligent titration system identifies the endpoint color change more accurately and quickly, adapts to complex experimental environments, and achieves higher experimental precision and stability.
Install the torch, torchvision, resnet, and tensorflow libraries. (Note: the training code below only uses torchvision's built-in ResNet-34; the standalone `resnet` package on PyPI is a separate Keras implementation and is not strictly required.)
[2]
!pip install torch
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torch
  Downloading torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl (906.4 MB)
  (downloads of the CUDA 12.4 runtime wheels, triton, sympy, networkx, fsspec and filelock omitted)
Installing collected packages: mpmath, sympy, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, networkx, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch
Successfully installed filelock-3.16.1 fsspec-2024.10.0 mpmath-1.3.0 networkx-3.4.2 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.4.127 sympy-1.13.1 torch-2.5.1 triton-3.1.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[4]
!pip install torchvision
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torchvision
  Downloading torchvision-0.20.1-cp310-cp310-manylinux1_x86_64.whl (7.2 MB)
  (torch 2.5.1, pillow, numpy and the CUDA wheels already satisfied)
Installing collected packages: torchvision
Successfully installed torchvision-0.20.1
[7]
!pip install resnet
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting resnet
  Downloading resnet-0.1.tar.gz (5.8 kB)
  (downloads of keras>=2.0 and its dependencies omitted)
Building wheels for collected packages: resnet
  Building wheel for resnet (setup.py) ... done
Successfully built resnet
Installing collected packages: namex, optree, ml-dtypes, mdurl, absl-py, markdown-it-py, rich, keras, resnet
Successfully installed absl-py-2.1.0 keras-3.6.0 markdown-it-py-3.0.0 mdurl-0.1.2 ml-dtypes-0.5.0 namex-0.0.8 optree-0.13.0 resnet-0.1 rich-13.9.3
[9]
!pip install tensorflow
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting tensorflow
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (615.3 MB)
  (downloads of tensorboard 2.18.0, numpy 2.0.2, protobuf, grpcio, h5py and the remaining dependencies omitted)
Attempting uninstall: numpy
  Found existing installation: numpy 1.24.2
  Successfully uninstalled numpy-1.24.2
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. scipy 1.10.1 requires numpy<1.27.0,>=1.19.5, but you have numpy 2.0.2 which is incompatible.
Successfully installed astunparse-1.6.3 flatbuffers-24.3.25 gast-0.6.0 google-pasta-0.2.0 grpcio-1.67.1 h5py-3.12.1 libclang-18.1.1 markdown-3.7 ml-dtypes-0.4.1 numpy-2.0.2 opt-einsum-3.4.0 protobuf-5.28.3 tensorboard-2.18.0 tensorboard-data-server-0.7.2 tensorflow-2.18.0 tensorflow-io-gcs-filesystem-0.37.1 termcolor-2.5.0 werkzeug-3.0.6 wrapt-1.16.0
Original algorithm model (Dalian University of Technology)
[12]
import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from torchvision.models import resnet34


def main():
    # Use the GPU if one is available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Data transforms for training and validation
    data_transform = {
        "train": transforms.Compose([
            # Randomly crop, then resize to 224x224
            transforms.RandomResizedCrop(224),
            # Random horizontal flip
            transforms.RandomHorizontalFlip(),
            # Convert the image to a tensor
            transforms.ToTensor(),
            # Normalize each channel to roughly [-1, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            # Resize to 224x224 (no random cropping at validation time)
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # Root path of the dataset
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # Path to the image dataset. Note that os.path.join discards data_root
    # here because the second argument is an absolute path.
    image_path = os.path.join(data_root, "/personal/Auto_Titration/Picture_Train/data/")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    # Number of training samples
    train_num = len(train_dataset)
    # Mapping from class name to index
    flower_list = train_dataset.class_to_idx
    # Invert it to get an index-to-class mapping
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # Write the index-to-class mapping to a JSON file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # Number of dataloader worker processes
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    # Number of validation samples
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    model_name = "resnet34-3"
    # ResNet-34 with two output classes (before/after the color change)
    net = resnet34(num_classes=2)
    net.to(device)
    # Cross-entropy loss and Adam with learning rate 1e-4
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 100
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)

    for epoch in range(epochs):
        # Training phase
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # Validation phase: eval mode disables dropout and freezes batch-norm stats
        net.eval()
        acc = 0.0  # running count of correct predictions
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # Index of the highest logit is the predicted class
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Save the model whenever validation accuracy improves
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
using cpu device.
Using 2 dataloader workers every process
using 604 images for training, 58 images for validation.
train epoch[1/100] loss:0.170: 100%|██████████| 19/19 [02:31<00:00, 7.96s/it]
100%|██████████| 2/2 [00:05<00:00, 2.83s/it]
[epoch 1] train_loss: 0.400 val_accuracy: 0.552
train epoch[2/100] loss:0.116: 100%|██████████| 19/19 [02:12<00:00, 6.98s/it]
[epoch 2] train_loss: 0.159 val_accuracy: 0.966
(epochs 3-17 omitted; val_accuracy fluctuates between 0.862 and 1.000, first reaching 1.000 at epoch 7)
train epoch[18/100] loss:0.141: 100%|██████████| 19/19 [02:15<00:00, 7.13s/it]
[epoch 18] train_loss: 0.090 val_accuracy: 1.000
train epoch[19/100] loss:0.148:   5%|▌ | 1/19 [00:09<02:57, 9.85s/it]
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Cell In[12], line 158 154 print('Finished Training') 157 if __name__ == '__main__': --> 158 main() Cell In[12], line 120, in main() 118 optimizer.zero_grad() 119 # 将图片和标签移动到指定的设备上 --> 120 outputs = net(images.to(device)) 121 # 计算损失 122 loss = loss_function(outputs, labels.to(device)) File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs) File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): -> 1747 return forward_call(*args, **kwargs) 1749 result = None 1750 called_always_called_hooks = set() File /opt/mamba/lib/python3.10/site-packages/torchvision/models/resnet.py:285, in ResNet.forward(self, x) 284 def forward(self, x: Tensor) -> Tensor: --> 285 return self._forward_impl(x) File /opt/mamba/lib/python3.10/site-packages/torchvision/models/resnet.py:276, in ResNet._forward_impl(self, x) 274 x = self.layer2(x) 275 x = self.layer3(x) --> 276 x = self.layer4(x) 278 x = self.avgpool(x) 279 x = torch.flatten(x, 1) File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs) File 
/opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): -> 1747 return forward_call(*args, **kwargs) 1749 result = None 1750 called_always_called_hooks = set() File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input) 248 def forward(self, input): 249 for module in self: --> 250 input = module(input) 251 return input File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs) File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): -> 1747 return forward_call(*args, **kwargs) 1749 result = None 1750 called_always_called_hooks = set() File /opt/mamba/lib/python3.10/site-packages/torchvision/models/resnet.py:92, in BasicBlock.forward(self, x) 89 def forward(self, x: Tensor) -> Tensor: 90 identity = x ---> 92 out = self.conv1(x) 93 out = self.bn1(out) 94 out = self.relu(out) File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs) File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): -> 1747 return forward_call(*args, **kwargs) 1749 result = None 1750 called_always_called_hooks = set() File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/conv.py:554, in Conv2d.forward(self, input) 553 def forward(self, input: Tensor) -> Tensor: --> 554 return self._conv_forward(input, self.weight, self.bias) File /opt/mamba/lib/python3.10/site-packages/torch/nn/modules/conv.py:549, in Conv2d._conv_forward(self, input, weight, bias) 537 if self.padding_mode != "zeros": 538 return F.conv2d( 539 F.pad( 540 input, self._reversed_padding_repeated_twice, mode=self.padding_mode (...) 
547 self.groups, 548 ) --> 549 return F.conv2d( 550 input, weight, bias, self.stride, self.padding, self.dilation, self.groups 551 ) KeyboardInterrupt:
The traceback above comes from manually interrupting the training run (KeyboardInterrupt); it is not an error in the code itself.
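Since the run above was stopped by hand, a KeyboardInterrupt traceback is expected. A minimal sketch of how the loop could catch the interrupt and checkpoint instead of crashing; `train_one_epoch` and `save_checkpoint` are hypothetical placeholders, not functions from this notebook:

```python
def run_training(train_one_epoch, save_checkpoint, epochs):
    """Run up to `epochs` iterations; on KeyboardInterrupt, checkpoint and stop cleanly."""
    completed = 0
    try:
        for epoch in range(epochs):
            train_one_epoch(epoch)
            completed += 1
    except KeyboardInterrupt:
        # Manual stop, as in the traceback above: save progress instead of raising.
        save_checkpoint(completed)
    return completed
```

With this wrapper, interrupting the kernel mid-epoch still leaves a checkpoint of the epochs that finished.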
[2]
!pip install torchvision
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Collecting torchvision Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a2/f6/7ff89a9f8703f623f5664afd66c8600e3f09fe188e1e0b7e6f9a8617f865/torchvision-0.20.1-cp310-cp310-manylinux1_x86_64.whl (7.2 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.2/7.2 MB 52.0 MB/s eta 0:00:00 Requirement already satisfied: numpy in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision) (1.26.4) Collecting torch==2.5.1 (from torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2a/ef/834af4a885b31a0b32fff2d80e1e40f771e1566ea8ded55347502440786a/torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl (906.4 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 906.4/906.4 MB 9.1 MB/s eta 0:00:0000:0100:03 Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision) (10.4.0) Requirement already satisfied: filelock in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.5.1->torchvision) (3.16.1) Requirement already satisfied: typing-extensions>=4.8.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.5.1->torchvision) (4.12.2) Requirement already satisfied: networkx in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.5.1->torchvision) (3.3) Requirement already satisfied: jinja2 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.5.1->torchvision) (3.1.4) Collecting fsspec (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/b2/454d6e7f0158951d8a78c2e1eb4f69ae81beb8dca5fee9809c6c99e9d0d0/fsspec-2024.10.0-py3-none-any.whl (179 kB) Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB) 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 24.4 MB/s eta 0:00:0000:0100:01 Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 9.1 MB/s eta 0:00:00 Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 14.8 MB/s eta 0:00:00a 0:00:01 Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 10.1 MB/s eta 0:00:0000:0100:02 Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 11.3 MB/s eta 0:00:0000:0100:01 Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 12.6 MB/s eta 0:00:0000:0100:01 Collecting nvidia-curand-cu12==10.3.5.147 (from torch==2.5.1->torchvision) Downloading 
https://pypi.tuna.tsinghua.edu.cn/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 15.6 MB/s eta 0:00:0000:0100:01 Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 12.6 MB/s eta 0:00:0000:0100:01 Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 12.3 MB/s eta 0:00:0000:0100:01 Collecting nvidia-nccl-cu12==2.21.5 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl (188.7 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 188.7/188.7 MB 12.8 MB/s eta 0:00:0000:0100:01 Collecting nvidia-nvtx-cu12==12.4.127 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (99 kB) Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 28.6 MB/s eta 0:00:00a 0:00:01 Collecting triton==3.1.0 (from 
torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.5/209.5 MB 12.0 MB/s eta 0:00:0000:0100:01 Collecting sympy==1.13.1 (from torch==2.5.1->torchvision) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl (6.2 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.2/6.2 MB 15.4 MB/s eta 0:00:00a 0:00:01 Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from sympy==1.13.1->torch==2.5.1->torchvision) (1.3.0) Requirement already satisfied: MarkupSafe>=2.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from jinja2->torch==2.5.1->torchvision) (2.1.5) Installing collected packages: triton, sympy, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, fsspec, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, torchvision Attempting uninstall: sympy Found existing installation: sympy 1.13.2 Uninstalling sympy-1.13.2: Successfully uninstalled sympy-1.13.2 Attempting uninstall: torch Found existing installation: torch 2.0.0.post200 Uninstalling torch-2.0.0.post200: Successfully uninstalled torch-2.0.0.post200 Successfully installed fsspec-2024.10.0 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.4.127 sympy-1.13.1 torch-2.5.1 torchvision-0.20.1 
triton-3.1.0 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
[4]
!pip uninstall torch torchvision
!pip install torch==2.0.1 torchvision==0.15.2
Found existing installation: torch 2.5.1 Uninstalling torch-2.5.1: Would remove: /opt/deepmd-kit-3.0.0b4/bin/convert-caffe2-to-onnx /opt/deepmd-kit-3.0.0b4/bin/convert-onnx-to-caffe2 /opt/deepmd-kit-3.0.0b4/bin/torchfrtrace /opt/deepmd-kit-3.0.0b4/bin/torchrun /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/functorch/* /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torch-2.5.1.dist-info/* /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torch/* /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torchgen/* Proceed (Y/n)? ^C ERROR: Operation cancelled by user Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Collecting torch==2.0.1 Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8c/4d/17e07377c9c3d1a0c4eb3fde1c7c16b5a0ce6133ddbabc08ceef6b7f2645/torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl (619.9 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 619.9/619.9 MB 10.2 MB/s eta 0:00:0000:0100:02 Collecting torchvision==0.15.2 Downloading https://pypi.tuna.tsinghua.edu.cn/packages/87/0f/88f023bf6176d9af0f85feedf4be129f9cf2748801c4d9c690739a10c100/torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl (6.0 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.0/6.0 MB 27.0 MB/s eta 0:00:00a 0:00:01 Requirement already satisfied: filelock in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (3.16.1) Requirement already satisfied: typing-extensions in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (4.12.2) Requirement already satisfied: sympy in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (1.13.1) Requirement already satisfied: networkx in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (3.3) Requirement already satisfied: jinja2 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torch==2.0.1) (3.1.4) Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1) Downloading 
https://pypi.tuna.tsinghua.edu.cn/packages/ef/25/922c5996aada6611b79b53985af7999fc629aee1d5d001b6a22431e18fec/nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.0/21.0 MB 22.8 MB/s eta 0:00:00a 0:00:01 Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/36/92/89cf558b514125d2ebd8344dd2f0533404b416486ff681d5434a5832a019/nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 849.3/849.3 kB 16.1 MB/s eta 0:00:00 Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e6/9d/dd0cdcd800e642e3c82ee3b5987c751afd4f3fb9cc2752517f42c3bc6e49/nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl (11.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.8/11.8 MB 11.5 MB/s eta 0:00:0000:010:01 Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/dc/30/66d4347d6e864334da5bb1c7571305e501dcb11b9155971421bb7bb5315f/nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl (557.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 557.1/557.1 MB 10.5 MB/s eta 0:00:0000:0100:02 Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ce/41/fdeb62b5437996e841d83d7d2714ca75b886547ee8017ee2fe6ea409d983/nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 317.1/317.1 MB 11.5 MB/s eta 0:00:0000:0100:01 Collecting nvidia-cufft-cu11==10.9.0.58 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/64/c8/133717b43182ba063803e983e7680a94826a9f4ff5734af0ca315803f1b3/nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux2014_x86_64.whl (168.4 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 168.4/168.4 MB 12.8 MB/s eta 0:00:0000:0100:01 Collecting 
nvidia-curand-cu11==10.2.10.91 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8f/11/af78d54b2420e64a4dd19e704f5bb69dcb5a6a3138b4465d6a48cdf59a21/nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl (54.6 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.6/54.6 MB 15.7 MB/s eta 0:00:0000:0100:01 Collecting nvidia-cusolver-cu11==11.4.0.1 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3e/77/66149e3153b19312fb782ea367f3f950123b93916a45538b573fe373570a/nvidia_cusolver_cu11-11.4.0.1-2-py3-none-manylinux1_x86_64.whl (102.6 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 102.6/102.6 MB 12.7 MB/s eta 0:00:0000:0100:01 Collecting nvidia-cusparse-cu11==11.7.4.91 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ea/6f/6d032cc1bb7db88a989ddce3f4968419a7edeafda362847f42f614b1f845/nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl (173.2 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 173.2/173.2 MB 12.3 MB/s eta 0:00:0000:0100:01 Collecting nvidia-nccl-cu11==2.14.3 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/55/92/914cdb650b6a5d1478f83148597a25e90ea37d739bd563c5096b0e8a5f43/nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl (177.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 177.1/177.1 MB 12.5 MB/s eta 0:00:0000:0100:01 Collecting nvidia-nvtx-cu11==11.7.91 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/23/d5/09493ff0e64fd77523afbbb075108f27a13790479efe86b9ffb4587671b5/nvidia_nvtx_cu11-11.7.91-py3-none-manylinux1_x86_64.whl (98 kB) Collecting triton==2.0.0 (from torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ca/31/ff6be541195daf77aa5c72303b2354661a69e717967d44d91eb4f3fdce32/triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (63.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.3/63.3 MB 13.0 MB/s eta 0:00:0000:0100:01 Requirement already satisfied: numpy in 
/opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision==0.15.2) (1.26.4) Requirement already satisfied: requests in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision==0.15.2) (2.32.3) Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from torchvision==0.15.2) (10.4.0) Requirement already satisfied: setuptools in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.1) (75.1.0) Requirement already satisfied: wheel in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.1) (0.44.0) Collecting cmake (from triton==2.0.0->torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/4b/87/3b01bc707c7a021f3c2c2d133af405618b1580613e22a9ead4a9d9d49903/cmake-3.30.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.9 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 26.9/26.9 MB 10.3 MB/s eta 0:00:00a 0:00:01 Collecting lit (from triton==2.0.0->torch==2.0.1) Downloading https://pypi.tuna.tsinghua.edu.cn/packages/96/06/b36f150fa7c5bcc96a31a4d19a20fddbd1d965b6f02510b57a3bb8d4b930/lit-18.1.8-py3-none-any.whl (96 kB) Requirement already satisfied: MarkupSafe>=2.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from jinja2->torch==2.0.1) (2.1.5) Requirement already satisfied: charset-normalizer<4,>=2 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (3.3.2) Requirement already satisfied: idna<4,>=2.5 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (3.10) Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (2.2.3) Requirement already satisfied: certifi>=2017.4.17 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (2024.8.30) Requirement already 
satisfied: mpmath<1.4,>=1.1.0 in /opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages (from sympy->torch==2.0.1) (1.3.0) Installing collected packages: lit, nvidia-nvtx-cu11, nvidia-nccl-cu11, nvidia-cusparse-cu11, nvidia-curand-cu11, nvidia-cufft-cu11, nvidia-cuda-runtime-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-cupti-cu11, nvidia-cublas-cu11, cmake, nvidia-cusolver-cu11, nvidia-cudnn-cu11, triton, torch, torchvision Attempting uninstall: triton Found existing installation: triton 3.1.0 Uninstalling triton-3.1.0: Successfully uninstalled triton-3.1.0 Attempting uninstall: torch Found existing installation: torch 2.5.1 Uninstalling torch-2.5.1: Successfully uninstalled torch-2.5.1 Attempting uninstall: torchvision Found existing installation: torchvision 0.20.1 Uninstalling torchvision-0.20.1: Successfully uninstalled torchvision-0.20.1 Successfully installed cmake-3.30.5 lit-18.1.8 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 torch-2.0.1 torchvision-0.15.2 triton-2.0.0 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
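The downgrade above pins a matching pair: torchvision 0.15.2 is built against torch 2.0.x, while torchvision 0.20.1 pairs with torch 2.5.x. A mismatched pair is what produces "undefined symbol" warnings when importing `torchvision.io`. A small hypothetical checker covering only the two pairs used in this notebook:

```python
# Known-compatible pairs used in this notebook: torchvision minor -> torch minor.
# This table is illustrative, not an official compatibility matrix.
COMPATIBLE = {"0.15": "2.0", "0.20": "2.5"}

def versions_match(torch_version, torchvision_version):
    """Return True if the two version strings form a compatible pair."""
    t_minor = ".".join(torch_version.split(".")[:2])
    tv_minor = ".".join(torchvision_version.split(".")[:2])
    return COMPATIBLE.get(tv_minor) == t_minor
```

For example, `versions_match("2.5.1", "0.15.2")` is False, which is the mismatch the reinstall above is fixing.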
First revision: the results were poor, so it was discarded.
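The introduction describes replacing Adam with SGD plus momentum and a StepLR scheduler; note that both revisions shown below still use Adam. A minimal sketch of that optimizer setup, with assumed (not final) hyperparameters `lr=0.01`, `momentum=0.9`, `step_size=30`, `gamma=0.1`, and a tiny stand-in model:

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

net = nn.Linear(4, 2)  # stand-in for resnet34(num_classes=2)
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# StepLR multiplies the learning rate by `gamma` every `step_size` epochs,
# lowering it late in training so the model can fine-tune.
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(60):   # per-epoch loop (training steps omitted)
    optimizer.step()      # would normally follow loss.backward()
    scheduler.step()      # advance the LR schedule once per epoch
```

After 60 scheduled epochs the learning rate has decayed twice, from 0.01 to 0.0001.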
[6]
import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from torchvision.models import resnet34
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "/personal/Auto_Titration/Picture_Train/data/")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count() // 2, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    model_name = "resnet34-3"
    net = resnet34(num_classes=2)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 100
    best_acc = 0.0
    patience_counter = 0
    patience_limit = 10
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)

    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > patience_limit:
                print("Early stopping triggered.")
                break

    print('Finished Training')


if __name__ == '__main__':
    main()
/opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/opt/deepmd-kit-3.0.0b4/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKSs'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source? warn( using cpu device. Using 8 dataloader workers every process using 604 images for training, 58 images for validation. train epoch[1/100] loss:0.609: 100%|██████████| 19/19 [01:03<00:00, 3.35s/it] 100%|██████████| 2/2 [00:08<00:00, 4.31s/it] [epoch 1] train_loss: 0.697 val_accuracy: 0.741 train epoch[2/100] loss:0.711: 100%|██████████| 19/19 [01:06<00:00, 3.52s/it] 100%|██████████| 2/2 [00:02<00:00, 1.37s/it] [epoch 2] train_loss: 0.595 val_accuracy: 0.379 train epoch[3/100] loss:0.575: 100%|██████████| 19/19 [01:02<00:00, 3.31s/it] 100%|██████████| 2/2 [00:02<00:00, 1.44s/it] [epoch 3] train_loss: 0.588 val_accuracy: 0.741 train epoch[4/100] loss:0.775: 100%|██████████| 19/19 [01:12<00:00, 3.81s/it] 100%|██████████| 2/2 [00:03<00:00, 1.53s/it] [epoch 4] train_loss: 0.574 val_accuracy: 0.724 train epoch[5/100] loss:0.667: 100%|██████████| 19/19 [01:17<00:00, 4.09s/it] 100%|██████████| 2/2 [00:03<00:00, 1.59s/it] [epoch 5] train_loss: 0.565 val_accuracy: 0.741 train epoch[6/100] loss:0.601: 100%|██████████| 19/19 [01:08<00:00, 3.61s/it] 100%|██████████| 2/2 [00:03<00:00, 1.53s/it] [epoch 6] train_loss: 0.554 val_accuracy: 0.879 train epoch[7/100] loss:0.642: 100%|██████████| 19/19 [01:15<00:00, 3.97s/it] 100%|██████████| 2/2 [00:02<00:00, 1.41s/it] [epoch 7] train_loss: 0.572 val_accuracy: 0.741 train epoch[8/100] loss:0.465: 100%|██████████| 19/19 [01:12<00:00, 3.79s/it] 100%|██████████| 2/2 [00:05<00:00, 2.89s/it] [epoch 8] train_loss: 
0.564 val_accuracy: 0.741 train epoch[9/100] loss:0.473: 100%|██████████| 19/19 [01:08<00:00, 3.60s/it] 100%|██████████| 2/2 [00:02<00:00, 1.36s/it] [epoch 9] train_loss: 0.540 val_accuracy: 0.759 train epoch[10/100] loss:0.583: 100%|██████████| 19/19 [01:08<00:00, 3.60s/it] 100%|██████████| 2/2 [00:04<00:00, 2.31s/it] [epoch 10] train_loss: 0.532 val_accuracy: 0.810 train epoch[11/100] loss:0.754: 100%|██████████| 19/19 [01:01<00:00, 3.23s/it] 100%|██████████| 2/2 [00:02<00:00, 1.43s/it] [epoch 11] train_loss: 0.539 val_accuracy: 0.741 train epoch[12/100] loss:0.441: 100%|██████████| 19/19 [01:07<00:00, 3.57s/it] 100%|██████████| 2/2 [00:02<00:00, 1.38s/it] [epoch 12] train_loss: 0.532 val_accuracy: 0.741 train epoch[13/100] loss:0.516: 100%|██████████| 19/19 [01:04<00:00, 3.41s/it] 100%|██████████| 2/2 [00:02<00:00, 1.35s/it] [epoch 13] train_loss: 0.498 val_accuracy: 0.707 train epoch[14/100] loss:0.511: 100%|██████████| 19/19 [01:07<00:00, 3.56s/it] 100%|██████████| 2/2 [00:03<00:00, 1.52s/it] [epoch 14] train_loss: 0.553 val_accuracy: 0.741 train epoch[15/100] loss:0.543: 100%|██████████| 19/19 [01:09<00:00, 3.63s/it] 100%|██████████| 2/2 [00:03<00:00, 1.59s/it] [epoch 15] train_loss: 0.510 val_accuracy: 0.759 train epoch[16/100] loss:0.286: 100%|██████████| 19/19 [01:12<00:00, 3.82s/it] 100%|██████████| 2/2 [00:02<00:00, 1.45s/it] [epoch 16] train_loss: 0.466 val_accuracy: 0.759 train epoch[17/100] loss:0.498: 100%|██████████| 19/19 [01:04<00:00, 3.38s/it] 100%|██████████| 2/2 [00:02<00:00, 1.45s/it] [epoch 17] train_loss: 0.524 val_accuracy: 0.638 Early stopping triggered. Finished Training
Second revision: modified the early-stopping mechanism in the code; the results were still poor, so it was discarded.
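The early-stopping rule used in these revisions (stop once validation accuracy fails to improve for more than `patience_limit` consecutive epochs) can be sketched on its own:

```python
class EarlyStopping:
    """Track validation accuracy; signal a stop after too many epochs without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best_acc = 0.0
        self.counter = 0

    def step(self, val_acc):
        """Return True when training should stop."""
        if val_acc > self.best_acc:
            self.best_acc = val_acc  # new best: save checkpoint here and reset the counter
            self.counter = 0
            return False
        self.counter += 1
        return self.counter > self.patience
```

With `patience=10` this reproduces the behavior of the training loops below, where the counter resets on every new best accuracy and training breaks once it exceeds the limit.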
[7]
import os
import sys
import json
import time  # new in this revision
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter  # new in this revision
from torchvision.models import resnet34


def main():
    # Initialize TensorBoard logging
    writer = SummaryWriter('runs/experiment')  # log file directory
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Data augmentation transforms
    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }

    # Dataset path (the second argument is absolute, so os.path.join returns it as-is)
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "/personal/Auto_Titration/Picture_Train/data/")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    # Load the training set
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # Save the index-to-class-name mapping for later inference
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count() // 2, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    # Load the validation set
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    model_name = "resnet34-3"
    net = resnet34(num_classes=2)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 100
    best_acc = 0.0
    patience_counter = 0
    patience_limit = 10
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)

    # Start the overall timer
    start_time = time.time()
    for epoch in range(epochs):
        # Record when this epoch starts
        epoch_start_time = time.time()
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # Validation phase
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        epoch_time = time.time() - epoch_start_time  # duration of this epoch

        # Report loss, validation accuracy and duration for this epoch
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f epoch_time: %.2f sec' %
              (epoch + 1, running_loss / train_steps, val_accurate, epoch_time))

        # Log to TensorBoard
        writer.add_scalar('Loss/train', running_loss / train_steps, epoch)
        writer.add_scalar('Accuracy/val', val_accurate, epoch)
        writer.add_scalar('Time/epoch', epoch_time, epoch)

        # Early stopping: keep the best checkpoint, stop when accuracy plateaus
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > patience_limit:
                print("Early stopping triggered.")
                break

    # Record the total training time
    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")
    writer.add_text('Training Time', f'Total training time: {total_time:.2f} seconds')

    # Close the TensorBoard writer
    writer.close()
    print('Finished Training')


if __name__ == '__main__':
    main()
2024-10-30 22:57:14.344638: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-30 22:57:14.344680: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-30 22:57:14.344687: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-30 22:57:14.478038: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
using cpu device.
Using 8 dataloader workers every process
using 604 images for training, 58 images for validation.
train epoch[1/100] loss:0.648: 100%|██████████| 19/19 [01:07<00:00, 3.57s/it] 100%|██████████| 2/2 [00:03<00:00, 1.58s/it] [epoch 1] train_loss: 0.690 val_accuracy: 0.741 epoch_time: 70.99 sec
train epoch[2/100] loss:0.603: 100%|██████████| 19/19 [01:20<00:00, 4.23s/it] 100%|██████████| 2/2 [00:03<00:00, 1.56s/it] [epoch 2] train_loss: 0.601 val_accuracy: 0.741 epoch_time: 83.52 sec
train epoch[3/100] loss:0.615: 100%|██████████| 19/19 [01:03<00:00, 3.35s/it] 100%|██████████| 2/2 [00:02<00:00, 1.49s/it] [epoch 3] train_loss: 0.611 val_accuracy: 0.741 epoch_time: 66.54 sec
train epoch[4/100] loss:0.750: 100%|██████████| 19/19 [01:11<00:00, 3.74s/it] 100%|██████████| 2/2 [00:03<00:00, 1.54s/it] [epoch 4] train_loss: 0.596 val_accuracy: 0.741 epoch_time: 74.12 sec
train epoch[5/100] loss:0.694: 100%|██████████| 19/19 [01:08<00:00, 3.63s/it] 100%|██████████| 2/2 [00:03<00:00, 1.57s/it] [epoch 5] train_loss: 0.570 val_accuracy: 0.776 epoch_time: 72.13 sec
train epoch[6/100] loss:0.706: 100%|██████████| 19/19 [01:18<00:00, 4.11s/it] 100%|██████████| 2/2 [00:03<00:00, 1.52s/it] [epoch 6] train_loss: 0.554 val_accuracy: 0.741 epoch_time: 81.20 sec
train epoch[7/100] loss:0.569: 100%|██████████| 19/19 [01:11<00:00, 3.76s/it] 100%|██████████| 2/2 [00:03<00:00, 1.65s/it] [epoch 7] train_loss: 0.569 val_accuracy: 0.741 epoch_time: 74.66 sec
train epoch[8/100] loss:0.612: 100%|██████████| 19/19 [01:11<00:00, 3.77s/it] 100%|██████████| 2/2 [00:02<00:00, 1.34s/it] [epoch 8] train_loss: 0.559 val_accuracy: 0.741 epoch_time: 74.28 sec
train epoch[9/100] loss:0.706: 100%|██████████| 19/19 [01:20<00:00, 4.22s/it] 100%|██████████| 2/2 [00:02<00:00, 1.40s/it] [epoch 9] train_loss: 0.547 val_accuracy: 0.741 epoch_time: 82.92 sec
train epoch[10/100] loss:0.775: 100%|██████████| 19/19 [01:12<00:00, 3.81s/it] 100%|██████████| 2/2 [00:02<00:00, 1.37s/it] [epoch 10] train_loss: 0.603 val_accuracy: 0.741 epoch_time: 75.11 sec
train epoch[11/100] loss:0.585: 100%|██████████| 19/19 [01:31<00:00, 4.81s/it] 100%|██████████| 2/2 [00:02<00:00, 1.38s/it] [epoch 11] train_loss: 0.546 val_accuracy: 0.655 epoch_time: 94.20 sec
train epoch[12/100] loss:0.430: 100%|██████████| 19/19 [01:06<00:00, 3.51s/it] 100%|██████████| 2/2 [00:02<00:00, 1.43s/it] [epoch 12] train_loss: 0.545 val_accuracy: 0.741 epoch_time: 69.57 sec
train epoch[13/100] loss:0.548: 100%|██████████| 19/19 [01:11<00:00, 3.76s/it] 100%|██████████| 2/2 [00:02<00:00, 1.39s/it] [epoch 13] train_loss: 0.540 val_accuracy: 0.569 epoch_time: 74.31 sec
train epoch[14/100] loss:0.499: 100%|██████████| 19/19 [01:04<00:00, 3.39s/it] 100%|██████████| 2/2 [00:02<00:00, 1.31s/it] [epoch 14] train_loss: 0.536 val_accuracy: 0.552 epoch_time: 66.94 sec
train epoch[15/100] loss:0.536: 100%|██████████| 19/19 [01:05<00:00, 3.44s/it] 100%|██████████| 2/2 [00:03<00:00, 1.55s/it] [epoch 15] train_loss: 0.520 val_accuracy: 0.741 epoch_time: 68.49 sec
train epoch[16/100] loss:0.633: 100%|██████████| 19/19 [01:26<00:00, 4.57s/it] 100%|██████████| 2/2 [00:02<00:00, 1.36s/it] [epoch 16] train_loss: 0.553 val_accuracy: 0.741 epoch_time: 89.62 sec
Early stopping triggered.
Training completed in 1219.07 seconds
Finished Training
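The patience logic embedded in the training loop can be factored into a small standalone helper, which makes the stopping rule easier to test and tune in isolation. The following is a sketch of our own (the `EarlyStopping` class name and interface are not part of the notebook's code); it mirrors the loop's behavior: reset the counter on a new best validation accuracy, otherwise increment it, and stop once the counter exceeds the patience limit.

```python
class EarlyStopping:
    """Stop training once validation accuracy has failed to improve for
    more than `patience` consecutive epochs (mirrors the
    patience_counter / patience_limit logic in the training loop)."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_acc = 0.0
        self.counter = 0

    def step(self, val_acc):
        """Record one epoch's validation accuracy; return True to stop."""
        if val_acc > self.best_acc:
            self.best_acc = val_acc   # new best: this is where a checkpoint would be saved
            self.counter = 0
        else:
            self.counter += 1
        return self.counter > self.patience


# Example: accuracy improves twice, then plateaus and degrades.
stopper = EarlyStopping(patience=3)
stops = [stopper.step(acc) for acc in [0.70, 0.75, 0.75, 0.74, 0.73, 0.72]]
# counter after each epoch: 0, 0, 1, 2, 3, 4 -> stop is signalled on the last epoch
```

Keeping the rule in one place also makes it easy to experiment with variants, e.g. requiring a minimum improvement margin before resetting the counter.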
Third revision: this version achieved very good results.
[18]
import os
import sys
import json
import time
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet34


def main():
    writer = SummaryWriter('/personal/Auto_Titration/Picture_Train/experiment_optimized')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Simplified data augmentation: keep only random crop and horizontal flip
    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }

    # Dataset path (the second argument is absolute, so os.path.join returns it as-is)
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "/personal/Auto_Titration/Picture_Train/data/")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # Save the index-to-class-name mapping for later inference
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count() // 2, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    model_name = "resnet34-optimized"
    net = resnet34(num_classes=2)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)  # SGD with momentum
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # halve the lr every 10 epochs

    epochs = 100
    best_acc = 0.0
    patience_counter = 0
    patience_limit = 10  # early-stopping patience
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)

    start_time = time.time()
    for epoch in range(epochs):
        epoch_start_time = time.time()
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
        scheduler.step()  # update the learning rate once per epoch

        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        epoch_time = time.time() - epoch_start_time
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f epoch_time: %.2f sec' %
              (epoch + 1, running_loss / train_steps, val_accurate, epoch_time))
        writer.add_scalar('Loss/train', running_loss / train_steps, epoch)
        writer.add_scalar('Accuracy/val', val_accurate, epoch)
        writer.add_scalar('Time/epoch', epoch_time, epoch)

        # Early stopping: keep the best checkpoint, stop when accuracy plateaus
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter > patience_limit:
                print("Early stopping triggered.")
                break

    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")
    writer.add_text('Training Time', f'Total training time: {total_time:.2f} seconds')
    writer.close()
    print('Finished Training')


if __name__ == '__main__':
    main()
using cpu device.
Using 8 dataloader workers every process
using 604 images for training, 58 images for validation.
train epoch[1/100] loss:0.635: 100%|██████████| 19/19 [00:59<00:00, 3.12s/it] 100%|██████████| 2/2 [00:03<00:00, 1.59s/it] [epoch 1] train_loss: 0.648 val_accuracy: 0.293 epoch_time: 62.48 sec
train epoch[2/100] loss:0.578: 100%|██████████| 19/19 [01:12<00:00, 3.80s/it] 100%|██████████| 2/2 [00:03<00:00, 1.68s/it] [epoch 2] train_loss: 0.580 val_accuracy: 0.741 epoch_time: 75.62 sec
train epoch[3/100] loss:0.194: 100%|██████████| 19/19 [01:01<00:00, 3.26s/it] 100%|██████████| 2/2 [00:03<00:00, 1.69s/it] [epoch 3] train_loss: 0.379 val_accuracy: 0.776 epoch_time: 65.35 sec
train epoch[4/100] loss:0.452: 100%|██████████| 19/19 [01:05<00:00, 3.47s/it] 100%|██████████| 2/2 [00:03<00:00, 1.60s/it] [epoch 4] train_loss: 0.196 val_accuracy: 1.000 epoch_time: 69.10 sec
train epoch[5/100] loss:0.036: 100%|██████████| 19/19 [01:13<00:00, 3.89s/it] 100%|██████████| 2/2 [00:03<00:00, 1.55s/it] [epoch 5] train_loss: 0.147 val_accuracy: 0.948 epoch_time: 76.95 sec
train epoch[6/100] loss:0.046: 100%|██████████| 19/19 [01:07<00:00, 3.53s/it] 100%|██████████| 2/2 [00:06<00:00, 3.43s/it] [epoch 6] train_loss: 0.120 val_accuracy: 0.810 epoch_time: 73.97 sec
train epoch[7/100] loss:0.261: 100%|██████████| 19/19 [01:17<00:00, 4.07s/it] 100%|██████████| 2/2 [00:02<00:00, 1.50s/it] [epoch 7] train_loss: 0.154 val_accuracy: 0.966 epoch_time: 80.25 sec
train epoch[8/100] loss:0.116: 100%|██████████| 19/19 [01:02<00:00, 3.29s/it] 100%|██████████| 2/2 [00:03<00:00, 1.67s/it] [epoch 8] train_loss: 0.137 val_accuracy: 0.983 epoch_time: 65.79 sec
train epoch[9/100] loss:0.112: 100%|██████████| 19/19 [01:21<00:00, 4.28s/it] 100%|██████████| 2/2 [00:03<00:00, 1.58s/it] [epoch 9] train_loss: 0.126 val_accuracy: 0.966 epoch_time: 84.52 sec
train epoch[10/100] loss:0.218: 100%|██████████| 19/19 [01:21<00:00, 4.28s/it] 100%|██████████| 2/2 [00:03<00:00, 1.64s/it] [epoch 10] train_loss: 0.157 val_accuracy: 0.552 epoch_time: 84.60 sec
train epoch[11/100] loss:0.246: 100%|██████████| 19/19 [01:10<00:00, 3.73s/it] 100%|██████████| 2/2 [00:03<00:00, 1.69s/it] [epoch 11] train_loss: 0.120 val_accuracy: 0.983 epoch_time: 74.34 sec
train epoch[12/100] loss:0.111: 100%|██████████| 19/19 [01:24<00:00, 4.44s/it] 100%|██████████| 2/2 [00:02<00:00, 1.44s/it] [epoch 12] train_loss: 0.096 val_accuracy: 0.948 epoch_time: 87.29 sec
train epoch[13/100] loss:0.142: 100%|██████████| 19/19 [01:13<00:00, 3.88s/it] 100%|██████████| 2/2 [00:03<00:00, 1.64s/it] [epoch 13] train_loss: 0.126 val_accuracy: 1.000 epoch_time: 77.02 sec
train epoch[14/100] loss:0.079: 100%|██████████| 19/19 [01:16<00:00, 4.02s/it] 100%|██████████| 2/2 [00:03<00:00, 1.63s/it] [epoch 14] train_loss: 0.095 val_accuracy: 1.000 epoch_time: 79.73 sec
train epoch[15/100] loss:0.009: 100%|██████████| 19/19 [01:15<00:00, 3.96s/it] 100%|██████████| 2/2 [00:03<00:00, 1.66s/it] [epoch 15] train_loss: 0.077 val_accuracy: 1.000 epoch_time: 78.54 sec
Early stopping triggered.
Training completed in 1136.46 seconds
Finished Training
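The `StepLR` scheduler used in this version halves the learning rate every 10 epochs (`step_size=10, gamma=0.5`). Its effect can be reproduced with a one-line formula, which is handy for sanity-checking the schedule without running training. This is a pure-Python sketch of our own (the `steplr` helper is not part of the notebook's code); when `scheduler.step()` is called once per epoch, PyTorch's `StepLR` applies exactly this decay:

```python
def steplr(base_lr, epoch, step_size=10, gamma=0.5):
    """Learning rate that StepLR(step_size, gamma) yields at a given
    epoch (epochs counted from 0, scheduler stepped once per epoch)."""
    return base_lr * gamma ** (epoch // step_size)

# The notebook's settings: base_lr=0.001, halved every 10 epochs.
for epoch in (0, 9, 10, 19, 20, 30):
    print(epoch, steplr(0.001, epoch))
# epochs 0-9 train at 0.001, epochs 10-19 at 0.0005, epochs 20-29 at 0.00025, ...
```

Since the run above stopped early at epoch 15, only the first halving (to 0.0005) actually took effect; the later refinement steps would matter for longer runs.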
[14]
!pip install tensorboard
Output hidden
How to display the training logs
[16]
%tensorboard --logdir /personal/Auto_Titration/Picture_Train/experiment_optimized
Output hidden
[ ]
%load_ext tensorboard
%tensorboard --logdir /personal/Auto_Titration/Picture_Train/experiment_optimized
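Besides the interactive dashboard, the logged scalars can be read back programmatically, e.g. to export the training curves for a report. A sketch using TensorBoard's `EventAccumulator` API (the log directory is the one written by the training script above; the `os.path.isdir` guard is our own addition so the snippet is a no-op when the logs are absent):

```python
import os
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

logdir = '/personal/Auto_Titration/Picture_Train/experiment_optimized'
if os.path.isdir(logdir):
    acc = EventAccumulator(logdir)
    acc.Reload()                        # parse the event files on disk
    print(acc.Tags()['scalars'])        # tags written with writer.add_scalar(...)
    # Each scalar event carries (wall_time, step, value)
    for event in acc.Scalars('Accuracy/val'):
        print(event.step, event.value)  # (epoch, validation accuracy)
```

This reads the same event files the `%tensorboard` dashboard visualizes, so the exported values match the plotted curves exactly.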