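Install and check the environment
CLIP is a large model, so inference should run on a GPU. The cells in this section install the CLIP repo and its dependencies, and check that PyTorch 1.7.1 or later is installed.
The Jupyter version in this public image is too old to use tqdm, so long-running cells give no progress feedback. Please be patient.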
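The version check described above can be sketched roughly as follows. This is a minimal sketch rather than the notebook's original cell: the helper names `parse_version`, `meets_minimum`, and `check_torch` are hypothetical, and the install commands in the comments are an assumption based on the packages this section names.

```python
import re

# Assumed install commands (run in a shell or a notebook cell):
#   pip install ftfy regex tqdm
#   pip install git+https://github.com/openai/CLIP.git


def parse_version(version: str) -> tuple:
    """Turn '2.0.1+cu117' into (2, 0, 1), ignoring the local build suffix."""
    public = version.split("+", 1)[0]
    return tuple(int(p) for p in re.findall(r"\d+", public)[:3])


def meets_minimum(version: str, minimum: str = "1.7.1") -> bool:
    """Return True if `version` is at least `minimum` (numeric parts only)."""
    return parse_version(version) >= parse_version(minimum)


def check_torch() -> None:
    """Print the installed PyTorch version and fail early if it is too old."""
    import torch  # imported lazily so the helpers above work without PyTorch

    print("Torch version:", torch.__version__)
    assert meets_minimum(torch.__version__), "CLIP needs PyTorch 1.7.1 or later"
    print("CUDA available:", torch.cuda.is_available())
```

Calling `check_torch()` once after installation is enough. The comparison looks only at the first three numeric components, so pre-release suffixes are ignored.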
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: ftfy in /opt/mamba/lib/python3.10/site-packages (6.1.1)
Requirement already satisfied: regex in /opt/mamba/lib/python3.10/site-packages (2023.8.8)
Requirement already satisfied: tqdm in /opt/mamba/lib/python3.10/site-packages (4.64.1)
Collecting git+https://github.com/openai/CLIP.git
Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
Found existing installation: torch 2.0.0+cu118
Successfully uninstalled torch-2.0.0+cu118
Successfully installed clip-1.0 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 pillow-10.0.1 torch-2.0.1 torchvision-0.15.2
Torch version: 2.0.1+cu117
Load the model
The cells in this section download and load the CLIP model. Use clip.available_models() to list the models that can be used.
['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
Here we use the smallest ViT model, ViT-B/32. It generally performs better than the similarly sized ResNet variants and does not require much compute.
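Loading the model might look like the sketch below. The helper names `load_clip` and `describe` are hypothetical; `clip.load` and the `input_resolution`, `context_length`, and `vocab_size` attributes come from the openai/CLIP package. The fallback to CPU is an assumption added for machines without a GPU.

```python
def load_clip(name: str = "ViT-B/32"):
    """Download (on first use) and load a CLIP model and its image preprocessing."""
    import torch
    import clip  # installed from https://github.com/openai/CLIP.git

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load(name, device=device)
    model.eval()  # inference only, no training
    return model, preprocess, device


def describe(model) -> dict:
    """Collect the basic model facts printed in the summary."""
    return {
        "parameters": sum(p.numel() for p in model.parameters()),
        "input_resolution": model.visual.input_resolution,
        "context_length": model.context_length,
        "vocab_size": model.vocab_size,
    }
```

A typical call would be `model, preprocess, device = load_clip("ViT-B/32")`, followed by `describe(model)` to report the parameter count, input resolution, context length, and vocabulary size.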
100%|███████████████████████████████████████| 338M/338M [00:28<00:00, 12.4MiB/s]
Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408
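Prepare the CIFAR-10 labels and prompt templates
During CLIP's original training, each image was paired one-to-one with a piece of text. When testing on a new dataset, we can therefore turn each label name into a text description of its class.
We first define the 10 classes of the CIFAR-10 dataset. (The original paper ran this experiment on ImageNet, but ImageNet is too large, so we use the smaller CIFAR-10 instead.)
CLIP defines 80 prompt templates. Substituting a class label into each of the 80 templates yields a set of prompts for that class. The more prompts we use, the more accurately the resulting text features tend to represent the class.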
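Building the prompts from labels and templates can be sketched as below. The class names are the standard CIFAR-10 labels. The four templates are only an illustrative subset written in the style of CLIP's templates, not the full list of 80, and `build_prompts` is a hypothetical helper name.

```python
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Illustrative subset only; the notebook uses 80 templates.
templates = [
    "a photo of a {}.",
    "a blurry photo of a {}.",
    "a low resolution photo of a {}.",
    "a bright photo of a {}.",
]


def build_prompts(classnames, templates):
    """Fill every template with every class name: {class: [prompt, ...]}."""
    return {name: [t.format(name) for t in templates] for name in classnames}


prompts = build_prompts(cifar10_classes, templates)
print(f"{len(cifar10_classes)} classes, {len(templates)} templates")
print(prompts["cat"][0])
```

Each class ends up with one prompt per template, so the full setup produces 10 × 80 = 800 prompts.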
10 classes, 80 templates
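Load the image dataset
Here we simply use the CIFAR-10 dataset that ships with torchvision. (The download is somewhat slow; if it times out, download the data to local storage beforehand.)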
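A minimal sketch of loading the data with torchvision follows. The helper name `load_cifar10_test`, the `./data` directory, the batch size, and the choice of the test split are all assumptions; `preprocess` is expected to be the transform returned by `clip.load`.

```python
def load_cifar10_test(preprocess, root="./data", batch_size=32, num_workers=2):
    """Return the CIFAR-10 test split and a DataLoader over it."""
    from torch.utils.data import DataLoader
    from torchvision.datasets import CIFAR10

    # download=True fetches the archive into `root` only if it is not already there
    dataset = CIFAR10(root=root, train=False, download=True, transform=preprocess)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return dataset, loader
```

If the download times out, placing the already-downloaded archive under `root` lets the same call reuse it.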
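Zero-shot prediction
Zero-shot prediction means that a model trained on a sufficiently large dataset can already make inferences about the classes of other datasets.
As long as this ability is elicited in a suitable way, the model can achieve good zero-shot performance.
Compared with traditional transfer learning, zero-shot prediction needs no additional training on the target domain; the model can be regarded as having acquired a "general-purpose" capability.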
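One common way to elicit this ability with CLIP is to average the text embeddings of all prompts for a class, then assign each image to the class whose averaged embedding is most similar. The sketch below shows only that arithmetic on toy 2-D vectors. In the real pipeline the vectors would come from `model.encode_text(clip.tokenize(...))` and `model.encode_image(...)`; the function names and numbers here are hypothetical.

```python
import math


def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def ensemble(prompt_embeddings):
    """Average unit-normalized prompt embeddings, then renormalize."""
    unit = [normalize(e) for e in prompt_embeddings]
    dim = len(unit[0])
    mean = [sum(e[i] for e in unit) / len(unit) for i in range(dim)]
    return normalize(mean)


def classify(image_embedding, class_weights):
    """Return the best class and the cosine similarity to every class."""
    image = normalize(image_embedding)
    scores = {
        name: sum(a * b for a, b in zip(image, w))
        for name, w in class_weights.items()
    }
    return max(scores, key=scores.get), scores


# Toy embeddings standing in for encoder outputs (two prompts per class).
toy_text = {
    "cat": [[1.0, 0.1], [0.9, 0.2]],
    "dog": [[0.1, 1.0], [0.2, 0.8]],
}
weights = {name: ensemble(embs) for name, embs in toy_text.items()}
label, scores = classify([0.8, 0.3], weights)
print(label)
```

No parameter is updated anywhere in this procedure, which is what makes the prediction zero-shot.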
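Running zero-shot prediction
With the preparation above complete, we can now run zero-shot prediction.
As the output shows, zero-shot CLIP reaches 89.56% Top-1 accuracy on CIFAR-10, which indicates that CLIP already generalizes very well.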
Top-1 accuracy: 89.56
Top-5 accuracy: 99.43
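The Top-1 and Top-5 figures count how often the true label is among the k highest-scoring classes. A minimal pure-Python sketch of that metric on toy scores is shown below. The function name `topk_correct` and the data are hypothetical; a real run would apply the same counting to the image-text similarity scores of each batch.

```python
def topk_correct(logits, targets, ks=(1, 5)):
    """For each k, count samples whose true label is among the k highest scores."""
    counts = {k: 0 for k in ks}
    for row, target in zip(logits, targets):
        # Class indices sorted from highest to lowest score
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        for k in ks:
            if target in ranked[:k]:
                counts[k] += 1
    return counts


# Toy scores for 3 samples over 3 classes
logits = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
]
targets = [1, 1, 0]

counts = topk_correct(logits, targets, ks=(1, 2))
n = len(targets)
for k, c in counts.items():
    print(f"Top-{k} accuracy: {100 * c / n:.2f}")
```

Accuracy is the count divided by the number of samples, expressed as a percentage.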