Prompt Engineering for CIFAR-10 with the CLIP Model
许瑞晗
Published 2023-09-25
Recommended image: Basic Image:ubuntu:22.04-py3.10-pytorch2.0
Recommended machine type: c12_m46_1 * NVIDIA GPU B
CIFAR-10(v1)

Installing and Checking the Environment

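Because CLIP is a large model, inference should be GPU-accelerated. The following cells install the CLIP repo and its dependencies, and check that PyTorch 1.7.1 or later is installed.

The Jupyter version in this public image is too old to render tqdm progress bars, so some long-running cells show no progress feedback. Please be patient.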
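The PyTorch 1.7.1 requirement comes down to comparing version numbers numerically rather than as strings. Compared as strings, "1.10.0" sorts before "1.7.1". The pure-Python sketch below shows one way to do this.

- `parse_version` and `meets_minimum` are hypothetical helper names, not part of PyTorch or CLIP.
- The parser assumes only the leading numeric major.minor.patch part matters, so build tags such as `+cu117` are dropped.
- In the notebook you would pass `torch.__version__`.

```python
import re


def parse_version(version_string):
    """Turn a version string such as '2.0.1+cu117' into a tuple of ints (2, 0, 1).

    Assumption: only the leading numeric 'major.minor.patch' part matters, so local
    build tags like '+cu117' and any pre-release suffix are ignored.
    """
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version_string)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version_string!r}")
    # Missing minor/patch parts count as 0, e.g. '1.10' -> (1, 10, 0)
    return tuple(int(part) if part is not None else 0 for part in match.groups())


def meets_minimum(version_string, minimum="1.7.1"):
    """Return True if version_string is at least `minimum`, comparing tuples of ints."""
    return parse_version(version_string) >= parse_version(minimum)


print(meets_minimum("2.0.1+cu117"))  # True: the version this notebook ends up with
print(meets_minimum("1.6.0"))        # False: older than 1.7.1
```

Comparing tuples of integers gives the natural ordering (1, 10, 0) > (1, 7, 1), which plain string comparison gets wrong.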

[1]
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: ftfy in /opt/mamba/lib/python3.10/site-packages (6.1.1)
Requirement already satisfied: regex in /opt/mamba/lib/python3.10/site-packages (2023.8.8)
Requirement already satisfied: tqdm in /opt/mamba/lib/python3.10/site-packages (4.64.1)
Requirement already satisfied: wcwidth>=0.2.5 in /opt/mamba/lib/python3.10/site-packages (from ftfy) (0.2.6)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-1vggs3p6
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-1vggs3p6
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... done
Requirement already satisfied: ftfy in /opt/mamba/lib/python3.10/site-packages (from clip==1.0) (6.1.1)
Requirement already satisfied: regex in /opt/mamba/lib/python3.10/site-packages (from clip==1.0) (2023.8.8)
Requirement already satisfied: tqdm in /opt/mamba/lib/python3.10/site-packages (from clip==1.0) (4.64.1)
Requirement already satisfied: torch in /opt/mamba/lib/python3.10/site-packages (from clip==1.0) (2.0.0+cu118)
Collecting torchvision
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/87/0f/88f023bf6176d9af0f85feedf4be129f9cf2748801c4d9c690739a10c100/torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl (6.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.0/6.0 MB 32.7 MB/s eta 0:00:00
Requirement already satisfied: wcwidth>=0.2.5 in /opt/mamba/lib/python3.10/site-packages (from ftfy->clip==1.0) (0.2.6)
Requirement already satisfied: triton==2.0.0 in /opt/mamba/lib/python3.10/site-packages (from torch->clip==1.0) (2.0.0)
Requirement already satisfied: networkx in /opt/mamba/lib/python3.10/site-packages (from torch->clip==1.0) (3.0)
Requirement already satisfied: sympy in /opt/mamba/lib/python3.10/site-packages (from torch->clip==1.0) (1.11.1)
Requirement already satisfied: filelock in /opt/mamba/lib/python3.10/site-packages (from torch->clip==1.0) (3.10.0)
Requirement already satisfied: jinja2 in /opt/mamba/lib/python3.10/site-packages (from torch->clip==1.0) (3.1.2)
Requirement already satisfied: typing-extensions in /opt/mamba/lib/python3.10/site-packages (from torch->clip==1.0) (4.5.0)
Requirement already satisfied: lit in /opt/mamba/lib/python3.10/site-packages (from triton==2.0.0->torch->clip==1.0) (15.0.7)
Requirement already satisfied: cmake in /opt/mamba/lib/python3.10/site-packages (from triton==2.0.0->torch->clip==1.0) (3.26.0)
Collecting pillow!=8.3.*,>=5.3.0
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/7a/07/e896b096a77375e78e02ce222ae4fd6014928cd76c691d312060a1645dfa/Pillow-10.0.1-cp310-cp310-manylinux_2_28_x86_64.whl (3.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 63.2 MB/s eta 0:00:00
Requirement already satisfied: requests in /opt/mamba/lib/python3.10/site-packages (from torchvision->clip==1.0) (2.28.1)
Collecting torch
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8c/4d/17e07377c9c3d1a0c4eb3fde1c7c16b5a0ce6133ddbabc08ceef6b7f2645/torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl (619.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 619.9/619.9 MB 1.7 MB/s eta 0:00:00
Requirement already satisfied: numpy in /opt/mamba/lib/python3.10/site-packages (from torchvision->clip==1.0) (1.24.2)
Collecting nvidia-cusolver-cu11==11.4.0.1
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3e/77/66149e3153b19312fb782ea367f3f950123b93916a45538b573fe373570a/nvidia_cusolver_cu11-11.4.0.1-2-py3-none-manylinux1_x86_64.whl (102.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 102.6/102.6 MB 1.9 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu11==8.5.0.96
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/dc/30/66d4347d6e864334da5bb1c7571305e501dcb11b9155971421bb7bb5315f/nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl (557.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 557.1/557.1 MB 1.6 MB/s eta 0:00:00
Collecting nvidia-nccl-cu11==2.14.3
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/55/92/914cdb650b6a5d1478f83148597a25e90ea37d739bd563c5096b0e8a5f43/nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl (177.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 177.1/177.1 MB 3.4 MB/s eta 0:00:00
Collecting nvidia-cublas-cu11==11.10.3.66
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ce/41/fdeb62b5437996e841d83d7d2714ca75b886547ee8017ee2fe6ea409d983/nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 317.1/317.1 MB 3.5 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu11==11.7.99
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ef/25/922c5996aada6611b79b53985af7999fc629aee1d5d001b6a22431e18fec/nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.0/21.0 MB 44.4 MB/s eta 0:00:00
Collecting nvidia-curand-cu11==10.2.10.91
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8f/11/af78d54b2420e64a4dd19e704f5bb69dcb5a6a3138b4465d6a48cdf59a21/nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl (54.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.6/54.6 MB 15.2 MB/s eta 0:00:00
Collecting nvidia-cufft-cu11==10.9.0.58
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/74/79/b912a77e38e41f15a0581a59f5c3548d1ddfdda3225936fb67c342719e7a/nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl (168.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 168.4/168.4 MB 3.6 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu11==11.7.101
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e6/9d/dd0cdcd800e642e3c82ee3b5987c751afd4f3fb9cc2752517f42c3bc6e49/nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl (11.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.8/11.8 MB 19.8 MB/s eta 0:00:00
Collecting nvidia-nvtx-cu11==11.7.91
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/23/d5/09493ff0e64fd77523afbbb075108f27a13790479efe86b9ffb4587671b5/nvidia_nvtx_cu11-11.7.91-py3-none-manylinux1_x86_64.whl (98 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 98.6/98.6 kB 2.2 MB/s eta 0:00:00
Collecting nvidia-cusparse-cu11==11.7.4.91
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ea/6f/6d032cc1bb7db88a989ddce3f4968419a7edeafda362847f42f614b1f845/nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl (173.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 173.2/173.2 MB 6.5 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu11==11.7.99
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/36/92/89cf558b514125d2ebd8344dd2f0533404b416486ff681d5434a5832a019/nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 849.3/849.3 kB 11.5 MB/s eta 0:00:00
Requirement already satisfied: setuptools in /opt/mamba/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch->clip==1.0) (65.5.0)
Requirement already satisfied: wheel in /opt/mamba/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch->clip==1.0) (0.37.1)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/mamba/lib/python3.10/site-packages (from jinja2->torch->clip==1.0) (2.1.2)
Requirement already satisfied: idna<4,>=2.5 in /opt/mamba/lib/python3.10/site-packages (from requests->torchvision->clip==1.0) (3.4)
Requirement already satisfied: charset-normalizer<3,>=2 in /opt/mamba/lib/python3.10/site-packages (from requests->torchvision->clip==1.0) (2.1.1)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/mamba/lib/python3.10/site-packages (from requests->torchvision->clip==1.0) (1.26.11)
Requirement already satisfied: certifi>=2017.4.17 in /opt/mamba/lib/python3.10/site-packages (from requests->torchvision->clip==1.0) (2022.9.24)
Requirement already satisfied: mpmath>=0.19 in /opt/mamba/lib/python3.10/site-packages (from sympy->torch->clip==1.0) (1.3.0)
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... done
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369500 sha256=86b6673b75d2fa4e7ece895105752f79cea212fef95d3617cfd6f73eaecd12fd
  Stored in directory: /tmp/pip-ephem-wheel-cache-7mvxr5li/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b112deb897d5b50f5ad9a37e4
Successfully built clip
Installing collected packages: pillow, nvidia-nvtx-cu11, nvidia-nccl-cu11, nvidia-cusparse-cu11, nvidia-curand-cu11, nvidia-cufft-cu11, nvidia-cuda-runtime-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-cupti-cu11, nvidia-cublas-cu11, nvidia-cusolver-cu11, nvidia-cudnn-cu11, torch, torchvision, clip
  Attempting uninstall: torch
    Found existing installation: torch 2.0.0+cu118
    Uninstalling torch-2.0.0+cu118:
      Successfully uninstalled torch-2.0.0+cu118
Successfully installed clip-1.0 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 pillow-10.0.1 torch-2.0.1 torchvision-0.15.2
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
[2]
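import numpy as np
import torch
import clip
from packaging import version  # robust version comparison (handles tags like "+cu117")
from tqdm.notebook import tqdm

print("Torch version:", torch.__version__)
# CLIP requires PyTorch 1.7.1 or later
assert version.parse(torch.__version__) >= version.parse("1.7.1"), "PyTorch 1.7.1 or later is required"
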
Torch version: 2.0.1+cu117

Loading the Model

The following cells download and load the CLIP model. Use clip.available_models() to list the models that are available.

[3]
clip.available_models()
['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

Here we use the smallest ViT model, ViT-B/32. It outperforms the smaller ResNet variants (RN50, RN101) while keeping compute requirements modest.

[4]
model, preprocess = clip.load("ViT-B/32")
100%|███████████████████████████████████████| 338M/338M [00:28<00:00, 12.4MiB/s]
[5]
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408

Preparing CIFAR-10 Labels and Prompt Templates

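CLIP was trained on image-text pairs, with each image matched to its own caption. When testing on a new dataset, we can therefore turn each class label's name into text and use it as the language-side representation of that class.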

First we define the 10 classes of CIFAR-10. The original CLIP work ran this experiment on ImageNet, but ImageNet is very large, so we use the smaller CIFAR-10 instead.

[6]
# CIFAR-10 class names (the variable name is kept from the original ImageNet version)
imagenet_classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

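CLIP defines the 80 prompt templates below. Inserting each class label into all 80 templates yields a set of prompts for that class. Averaging over many prompts (prompt ensembling) generally gives a more accurate class representation than a single prompt.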
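Filling templates is plain Python string formatting. The sketch below fills a few of the templates with three class names, mirroring the `template.format(classname)` step used later. The short lists are a hand-picked subset chosen for illustration.

```python
# A hand-picked subset of the 80 templates and three of the CIFAR-10 class names
templates = ['a bad photo of a {}.', 'a photo of many {}.', 'a photo of the {}.']
classnames = ["airplane", "cat", "ship"]

# One list of prompts per class: every template filled with that class name
prompts = {name: [template.format(name) for template in templates] for name in classnames}

for name, texts in prompts.items():
    print(name, texts)

# Total number of prompts that will go through the text encoder
print(sum(len(texts) for texts in prompts.values()))  # 3 classes x 3 templates = 9
```

With the full lists, the same construction produces 10 × 80 = 800 prompts, 80 per class.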

代码
文本
[7]
imagenet_templates = [
'a bad photo of a {}.',
'a photo of many {}.',
'a sculpture of a {}.',
'a photo of the hard to see {}.',
'a low resolution photo of the {}.',
'a rendering of a {}.',
'graffiti of a {}.',
'a bad photo of the {}.',
'a cropped photo of the {}.',
'a tattoo of a {}.',
'the embroidered {}.',
'a photo of a hard to see {}.',
'a bright photo of a {}.',
'a photo of a clean {}.',
'a photo of a dirty {}.',
'a dark photo of the {}.',
'a drawing of a {}.',
'a photo of my {}.',
'the plastic {}.',
'a photo of the cool {}.',
'a close-up photo of a {}.',
'a black and white photo of the {}.',
'a painting of the {}.',
'a painting of a {}.',
'a pixelated photo of the {}.',
'a sculpture of the {}.',
'a bright photo of the {}.',
'a cropped photo of a {}.',
'a plastic {}.',
'a photo of the dirty {}.',
'a jpeg corrupted photo of a {}.',
'a blurry photo of the {}.',
'a photo of the {}.',
'a good photo of the {}.',
'a rendering of the {}.',
'a {} in a video game.',
'a photo of one {}.',
'a doodle of a {}.',
'a close-up photo of the {}.',
'a photo of a {}.',
'the origami {}.',
'the {} in a video game.',
'a sketch of a {}.',
'a doodle of the {}.',
'a origami {}.',
'a low resolution photo of a {}.',
'the toy {}.',
'a rendition of the {}.',
'a photo of the clean {}.',
'a photo of a large {}.',
'a rendition of a {}.',
'a photo of a nice {}.',
'a photo of a weird {}.',
'a blurry photo of a {}.',
'a cartoon {}.',
'art of a {}.',
'a sketch of the {}.',
'a embroidered {}.',
'a pixelated photo of a {}.',
'itap of the {}.',
'a jpeg corrupted photo of the {}.',
'a good photo of a {}.',
'a plushie {}.',
'a photo of the nice {}.',
'a photo of the small {}.',
'a photo of the weird {}.',
'the cartoon {}.',
'art of the {}.',
'a drawing of the {}.',
'a photo of the large {}.',
'a black and white photo of a {}.',
'the plushie {}.',
'a dark photo of a {}.',
'itap of a {}.',
'graffiti of the {}.',
'a toy {}.',
'itap of my {}.',
'a photo of a cool {}.',
'a photo of a small {}.',
'a tattoo of the {}.',
]

print(f"{len(imagenet_classes)} classes, {len(imagenet_templates)} templates")
10 classes, 80 templates

Loading the Image Dataset

Here we simply use torchvision's built-in CIFAR-10 dataset. Downloading the data is slow; if it times out, download it locally in advance.

[9]
!pip install torchvision==0.15.2

from torchvision.datasets import CIFAR10

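# train=True is torchvision's default and selects the 50,000-image training split;
# pass train=False to evaluate on the 10,000-image test split instead
images = CIFAR10(root="/bohr/cifar-10-garl/v1/cifar-10-python", train=True, download=False, transform=preprocess)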
loader = torch.utils.data.DataLoader(images, batch_size=32, num_workers=2)
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: torchvision==0.15.2 in /opt/mamba/lib/python3.10/site-packages (0.15.2)
Requirement already satisfied: torch==2.0.1 in /opt/mamba/lib/python3.10/site-packages (from torchvision==0.15.2) (2.0.1)
Requirement already satisfied: numpy in /opt/mamba/lib/python3.10/site-packages (from torchvision==0.15.2) (1.24.2)
Requirement already satisfied: requests in /opt/mamba/lib/python3.10/site-packages (from torchvision==0.15.2) (2.28.1)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/mamba/lib/python3.10/site-packages (from torchvision==0.15.2) (10.0.1)
Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (11.4.0.1)
Requirement already satisfied: sympy in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (1.11.1)
Requirement already satisfied: networkx in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (3.0)
Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (11.7.4.91)
Requirement already satisfied: filelock in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (3.10.0)
Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (11.7.99)
Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (10.9.0.58)
Requirement already satisfied: triton==2.0.0 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (2.0.0)
Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (11.7.99)
Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (11.7.101)
Requirement already satisfied: jinja2 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (3.1.2)
Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (11.10.3.66)
Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (2.14.3)
Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (8.5.0.96)
Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (10.2.10.91)
Requirement already satisfied: typing-extensions in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (4.5.0)
Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /opt/mamba/lib/python3.10/site-packages (from torch==2.0.1->torchvision==0.15.2) (11.7.91)
Requirement already satisfied: wheel in /opt/mamba/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.1->torchvision==0.15.2) (0.37.1)
Requirement already satisfied: setuptools in /opt/mamba/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.1->torchvision==0.15.2) (65.5.0)
Requirement already satisfied: cmake in /opt/mamba/lib/python3.10/site-packages (from triton==2.0.0->torch==2.0.1->torchvision==0.15.2) (3.26.0)
Requirement already satisfied: lit in /opt/mamba/lib/python3.10/site-packages (from triton==2.0.0->torch==2.0.1->torchvision==0.15.2) (15.0.7)
Requirement already satisfied: idna<4,>=2.5 in /opt/mamba/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (3.4)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/mamba/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (1.26.11)
Requirement already satisfied: certifi>=2017.4.17 in /opt/mamba/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (2022.9.24)
Requirement already satisfied: charset-normalizer<3,>=2 in /opt/mamba/lib/python3.10/site-packages (from requests->torchvision==0.15.2) (2.1.1)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/mamba/lib/python3.10/site-packages (from jinja2->torch==2.0.1->torchvision==0.15.2) (2.1.2)
Requirement already satisfied: mpmath>=0.19 in /opt/mamba/lib/python3.10/site-packages (from sympy->torch==2.0.1->torchvision==0.15.2) (1.3.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Zero-Shot Prediction

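Zero-shot prediction means the following: after a model has been trained on a sufficiently large dataset, it can already recognize the classes of other datasets without being trained on them.

If this ability is drawn out in a suitable way, here by turning class names into text prompts, the model achieves good zero-shot performance.

Compared with traditional transfer learning, zero-shot prediction needs no additional training on the target domain. In this sense the model can be said to have acquired a "general-purpose" capability.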
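The mechanics can be sketched in pure Python, mirroring what `zeroshot_classifier` and the evaluation loop below do with tensors:

1. Unit-normalize each prompt embedding.
2. Average the normalized embeddings per class and renormalize the result.
3. Pick the class whose embedding has the highest cosine similarity with the normalized image embedding.

The 3-dimensional vectors are made up for illustration (real ViT-B/32 embeddings are 512-dimensional). The helper names are hypothetical.

```python
import math


def normalize(vector):
    """Scale a vector to unit L2 norm."""
    length = math.sqrt(sum(x * x for x in vector))
    return [x / length for x in vector]


def class_embedding(prompt_embeddings):
    """Average the unit-normalized prompt embeddings of one class, then renormalize."""
    unit = [normalize(e) for e in prompt_embeddings]
    mean = [sum(column) / len(unit) for column in zip(*unit)]
    return normalize(mean)


def zero_shot_predict(image_embedding, class_embeddings):
    """Return the index of the class with the highest cosine similarity to the image."""
    image = normalize(image_embedding)
    # Both sides are unit vectors, so the dot product equals the cosine similarity
    scores = [sum(a * b for a, b in zip(image, c)) for c in class_embeddings]
    return max(range(len(scores)), key=scores.__getitem__)


# Hypothetical 3-d "text embeddings": two prompts per class for two classes
cat_prompts = [[1.0, 0.2, 0.0], [0.9, 0.0, 0.1]]
ship_prompts = [[0.0, 1.0, 0.3], [0.1, 0.8, 0.0]]
classes = [class_embedding(cat_prompts), class_embedding(ship_prompts)]

print(zero_shot_predict([0.8, 0.1, 0.0], classes))  # 0: closest to the "cat" direction
print(zero_shot_predict([0.0, 0.9, 0.2], classes))  # 1: closest to the "ship" direction
```

No parameters are trained at any point: the "classifier" is just the stack of averaged text embeddings, which is why the approach transfers to new label sets.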

[11]
def zeroshot_classifier(classnames, templates):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            texts = [template.format(classname) for template in templates]  # fill every template with the class name
            texts = clip.tokenize(texts).cuda()  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with the text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)  # unit-normalize each prompt embedding
            class_embedding = class_embeddings.mean(dim=0)  # average over templates (prompt ensembling)
            class_embedding /= class_embedding.norm()  # renormalize the averaged embedding
            zeroshot_weights.append(class_embedding)
        # Shape (embed_dim, num_classes): one column per class
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights


zeroshot_weights = zeroshot_classifier(imagenet_classes, imagenet_templates)

Running Zero-Shot Prediction

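With this preparation done, we can run zero-shot prediction.

As the output below shows, zero-shot CLIP reaches 89.56% Top-1 accuracy (99.43% Top-5). This is measured on the CIFAR-10 images loaded above, which are the 50,000-image training split, because torchvision's CIFAR10 uses train=True by default. The result indicates that CLIP generalizes very well without any training on CIFAR-10.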
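Top-k accuracy counts a sample as correct when its true label is among the k highest-scoring classes. The pure-Python sketch below follows the same idea as the `accuracy` function that follows. The similarity values and labels are invented for illustration, and `topk_correct` is a hypothetical helper name.

The sketch also shows why the `100.` factor applied to the logits does not affect accuracy. Multiplying all scores by a positive constant preserves their order; it only sharpens a softmax, if one is applied.

```python
def topk_correct(logits, targets, k):
    """Count samples whose true label is among the k highest-scoring classes.

    logits: list of per-sample score lists; targets: list of true class indices.
    """
    correct = 0
    for scores, target in zip(logits, targets):
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        if target in ranked[:k]:
            correct += 1
    return correct


# Hypothetical cosine similarities for 4 images over 3 classes
similarities = [
    [0.30, 0.25, 0.10],
    [0.05, 0.20, 0.28],
    [0.22, 0.27, 0.21],
    [0.12, 0.11, 0.31],
]
targets = [0, 2, 0, 1]
logits = [[100.0 * s for s in row] for row in similarities]  # same scaling as the evaluation cell

top1 = 100 * topk_correct(logits, targets, k=1) / len(targets)
top2 = 100 * topk_correct(logits, targets, k=2) / len(targets)
print(f"Top-1: {top1:.2f}%, Top-2: {top2:.2f}%")  # Top-1: 50.00%, Top-2: 75.00%
```

In this toy example, the third image's label ranks second (a Top-2 hit but a Top-1 miss), and the fourth image's label ranks last (a miss for both).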

[12]
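def accuracy(output, target, topk=(1,)):
    # Indices of the max(topk) highest-scoring classes per sample, transposed to (max_k, batch)
    pred = output.topk(max(topk), 1, True, True)[1].t()
    # correct[j, b] is True when the j-th ranked prediction for sample b equals its label
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # For each k, the number of samples whose label appears in the top-k predictions
    return [float(correct[:k].reshape(-1).float().sum().item()) for k in topk]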
[14]
with torch.no_grad():
    top1, top5, n = 0., 0., 0.
    for i, (images, target) in enumerate(loader):
        images = images.cuda()
        target = target.cuda()

        # predict: unit-normalized image embeddings against the per-class text embeddings
        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100. * image_features @ zeroshot_weights

        # measure accuracy
        acc1, acc5 = accuracy(logits, target, topk=(1, 5))
        top1 += acc1
        top5 += acc5
        n += images.size(0)

top1 = (top1 / n) * 100
top5 = (top5 / n) * 100

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")
Top-1 accuracy: 89.56
Top-5 accuracy: 99.43