Bohrium
robot
新建

空间站广场

论文
Notebooks
比赛
课程
Apps
我的主页
我的Notebooks
我的论文库
我的足迹

我的工作空间

任务
节点
文件
数据集
镜像
项目
数据库
公开
DP_GAN_Zoo_Branch (3)
中文
Deep Learning
PyTorch
CV
GAN
中文Deep LearningPyTorchCVGAN
suyanyi
发布于 2023-11-07
推荐镜像 :Basic Image:bohrium-notebook:2023-04-07
推荐机型 :c12_m92_1 * NVIDIA V100
1
1
summer2winter_yosemite(v1)

©️ Copyright 2023 @ Authors
作者: 苏沿溢
日期:2023-09-21
共享协议:本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。
快速开始:点击上方的“开始连接”按钮,选择 bohrium-notebook:2023-04-07 镜像和 V100 GPU 配置即可开始。

代码
文本

CycleGAN

代码
文本
[ ]

代码
文本
[1]
import numpy as np
import datetime
import time
import sys
import torch
import os
代码
文本
[2]
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Residual block that keeps both channel count and spatial size.

    Two (reflection-pad 1 -> 3x3 conv -> InstanceNorm) stages with a ReLU between
    them; the block output is added back to its input.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def pad_conv_norm():
            # reflection padding of 1 keeps HxW unchanged through the 3x3 conv
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(in_features, in_features, 3),
                    nn.InstanceNorm2d(in_features)]

        layers = pad_conv_norm() + [nn.ReLU(inplace=True)] + pad_conv_norm()
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 conv stem -> two stride-2 downsampling convs -> ``res_blocks``
    ResidualBlocks -> two stride-2 transposed convs -> 7x7 conv -> Tanh.
    Channel widths go ngf -> 2*ngf -> 4*ngf and back to ngf, so the output has the
    same spatial size as the input when H and W are divisible by 4, with values in [-1, 1].

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        res_blocks: number of residual blocks at the bottleneck.
        ngf: width of the first conv layer (default 64, the previously hard-coded value;
            the default yields exactly the same layers and state_dict keys as before).
    """

    def __init__(self, in_channels, out_channels, res_blocks, ngf=64):
        super().__init__()

        # Initial convolution block: reflection padding of 3 keeps HxW through the 7x7 conv
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(in_channels, ngf, 7),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(inplace=True)]

        # Downsampling: each stride-2 conv halves HxW and doubles the channel count
        in_features = ngf
        for _ in range(2):
            out_features = in_features * 2
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features

        # Residual blocks at the bottleneck resolution
        model += [ResidualBlock(in_features) for _ in range(res_blocks)]

        # Upsampling: each transposed conv doubles HxW and halves the channel count
        for _ in range(2):
            out_features = in_features // 2
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features

        # Output layer: in_features is back to ngf here (previously hard-coded as 64)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(in_features, out_channels, 7),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


##############################
# Discriminator
##############################
class Discriminator_n_layers(nn.Module):
    """PatchGAN-style discriminator with a configurable number of stride-2 blocks.

    Args:
        n_D_layers: number of stride-2 downsampling blocks, 1..5. Per the original
            notes the patch (receptive-field) size is 16, 34, 70, 142 and 286 for
            1..5 layers; 70x70 (3 layers) is the size used in the CycleGAN paper, and
            286 (5 layers) is larger than a 256 image, i.e. whole-image discrimination.
        in_c: number of input channels.

    For an HxW input with H, W divisible by 2**n_D_layers, the output is a
    1-channel map of size (H // 2**n - 2) x (W // 2**n - 2): each stride-2 block
    halves the size and the two final 4x4 stride-1 padding-1 convs remove one pixel each.

    Raises:
        ValueError: if n_D_layers is outside 1..5.
    """

    def __init__(self, n_D_layers, in_c):
        super().__init__()

        n_layers = n_D_layers
        in_channels = in_c
        # Explicit check (an assert is stripped under -O, and n_layers < 1 used to
        # fail later with a NameError on out_filters).
        if not 1 <= n_layers <= 5:
            raise ValueError(f"n_D_layers must be between 1 and 5, got {n_layers}")

        def discriminator_block(in_filters, out_filters, k=4, s=2, p=1, norm=True, sigmoid=False):
            """Return the layers Conv -> [BatchNorm] -> LeakyReLU(0.2) [-> Sigmoid]."""
            layers = [nn.Conv2d(in_filters, out_filters, kernel_size=k, stride=s, padding=p)]
            if norm:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            if sigmoid:
                layers.append(nn.Sigmoid())
            return layers

        # First block has no normalisation: in_channels -> 64, HxW halved
        sequence = discriminator_block(in_channels, 64, norm=False)

        # Blocks k = 1 .. min(n_layers, 4) - 1 double the filters: 64 -> 128 -> 256 -> 512
        for k in range(1, min(n_layers, 4)):
            sequence += discriminator_block(2 ** (5 + k), 2 ** (6 + k))
        if n_layers == 5:
            # Fifth stride-2 block stays at 512 filters (patch size 286 > 256 image size)
            sequence += discriminator_block(2 ** 9, 2 ** 9)
        out_filters = min(64 * 2 ** (n_layers - 1), 2 ** 9)

        num_of_filter = min(2 * out_filters, 2 ** 9)

        # Two stride-1 4x4 convs; each shrinks H and W by one pixel
        sequence += discriminator_block(out_filters, num_of_filter, k=4, s=1, p=1)
        # NOTE(review): discriminator_block always appends a LeakyReLU, so the 1-channel
        # score map is passed through LeakyReLU(0.2). Kept for behaviour compatibility;
        # confirm this is intended (a raw score output is the more common choice).
        sequence += discriminator_block(num_of_filter, 1, k=4, s=1, p=1, norm=False, sigmoid=False)

        self.model = nn.Sequential(*sequence)

    def forward(self, img_input):
        return self.model(img_input)
代码
文本
[3]
def weights_init_normal(m):
    """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

    - Any module whose class name contains 'Conv' (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02); the bias is left as is.
    - BatchNorm2d: weight ~ N(1, 0.02), bias = 0 (skipped when affine=False).
    - Every other module is left untouched.

    Uses the in-place ``normal_`` / ``constant_`` initialisers; the non-underscore
    versions used previously are deprecated and emit a UserWarning on every call.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        # BatchNorm2d(affine=False) has weight/bias set to None
        if getattr(m, 'weight', None) is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
代码
文本
[4]
# Channel counts of the two image domains and network depth settings.
input_nc_A, input_nc_B = 3, 3
n_residual_blocks = 9
n_D_layers = 4

# G_AB translates A -> B and is judged by D_B; G_BA translates B -> A and is judged by D_A.
# Construction order is unchanged so random initialisation happens in the same order.
G_AB = GeneratorResNet(input_nc_A, input_nc_B, n_residual_blocks).cuda()
D_B = Discriminator_n_layers(n_D_layers, input_nc_B).cuda()
G_BA = GeneratorResNet(input_nc_B, input_nc_A, n_residual_blocks).cuda()
D_A = Discriminator_n_layers(n_D_layers, input_nc_A).cuda()

# Re-initialise every sub-module with weights_init_normal. The last call stays the
# cell's final expression so the notebook still echoes D_A.
for _net in (G_AB, D_B, G_BA):
    _net.apply(weights_init_normal)
D_A.apply(weights_init_normal)
/tmp/ipykernel_537/303252567.py:4: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_.
  torch.nn.init.normal(m.weight.data, 0.0, 0.02)
/tmp/ipykernel_537/303252567.py:6: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_.
  torch.nn.init.normal(m.weight.data, 1.0, 0.02)
/tmp/ipykernel_537/303252567.py:7: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.
  torch.nn.init.constant(m.bias.data, 0.0)
Discriminator_n_layers(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): LeakyReLU(negative_slope=0.2, inplace=True)
    (14): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (15): LeakyReLU(negative_slope=0.2, inplace=True)
  )
)
代码
文本
[5]
import torch.utils.data as data
from PIL import Image
class GeneratorDataset(data.Dataset):
    """Unlabelled image dataset over every regular, non-hidden file in one directory.

    Items are PIL images converted to RGB, passed through ``transform`` when given.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        # sorted(): os.listdir order is arbitrary, so sort for a deterministic
        # index -> file mapping. Skip sub-directories and hidden entries
        # (e.g. .DS_Store, .ipynb_checkpoints) that PIL cannot open as images.
        self.filenames = sorted(
            name for name in os.listdir(root_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(root_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = os.path.join(self.root_dir, self.filenames[idx])
        # Context manager closes the underlying file; convert() returns a new image.
        with Image.open(img_path) as img:
            sample = img.convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample
代码
文本
[6]
import torchvision.transforms as transforms
img_height = 256
img_width = 256

# Random augmentation: a (img_height, img_width) random crop and a horizontal flip.
_augmentations = [
    transforms.RandomCrop((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
]
# ToTensor gives values in [0, 1]; Normalize with mean=std=0.5 maps them to [-1, 1],
# the same range as the generators' Tanh output. The scalar mean/std apply to every channel.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
]
transforms_ = transforms.Compose(_augmentations + _to_normalized_tensor)
代码
文本
[7]
import os

# Local archive name and the directory the archive is expected to unpack to.
zip_file = 'cyclegan_v1.zip'
zip_dir = 'summer2winter_yosemite'

# Download the archive only if it is not already present.
# `!` lines are IPython shell escapes, so this cell only runs inside a notebook.
# NOTE(review): the URL's file name (cyclegan-p5yd-v1.zip) differs from zip_file; this
# relies on --content-disposition saving it as 'cyclegan_v1.zip' — confirm, otherwise the
# existence check above never matches and the archive is re-downloaded every run.
if not os.path.exists(zip_file):
!wget --content-disposition https://bohrium-api.dp.tech/ds-dl/cyclegan-p5yd-v1.zip

# Extract into the current working directory unless the dataset folder already exists.
script_dir = os.getcwd()
if not os.path.exists(os.path.join(script_dir, zip_dir)):
    # message: "Extracting dataset (may take five or six minutes)"
    print("解压数据集(可能需要五六分钟)")
    !unzip -q {zip_file} -d {script_dir}
代码
文本
[8]
# List the working directory (IPython shell escape) to check the archive and extracted dataset are present.
!ls
'DP_GAN_Zoo_Branch (16).ipynb'	 Exp-summer2winter   summer2winter_yosemite
'DP_GAN_Zoo_Branch (17).ipynb'	 cyclegan_v1.zip
代码
文本
[9]
# Other dataset roots previously tried here: './horse2zebra/', './monet2photo/'.
DATASET_PATH = './summer2winter_yosemite'


def _split_dataset(split_name):
    """Build a GeneratorDataset over DATASET_PATH/<split_name> with the shared transforms_."""
    return GeneratorDataset(root_dir=os.path.join(DATASET_PATH, split_name),
                            transform=transforms_)


# X = domain A images, Y = domain B images.
# NOTE(review): the test splits also get the random crop/flip augmentation — confirm intended.
train_data_X = _split_dataset("trainA")
train_data_Y = _split_dataset("trainB")
test_data_X = _split_dataset("testA")
test_data_Y = _split_dataset("testB")
代码
文本
[10]
# Report how many images each split contains.
for _dataset, _split in ((train_data_X, 'trainA'), (train_data_Y, 'trainB'),
                         (test_data_X, 'testA'), (test_data_Y, 'testB')):
    print("Found {} images in {}".format(len(_dataset), _split))
Found 1231 images in trainA
Found 962 images in trainB
Found 309 images in testA
Found 238 images in testB
代码
文本
[11]
BATCH_SIZE = 12

# Shared loader settings; only the training loaders shuffle.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0)
train_image_loader_X = torch.utils.data.DataLoader(train_data_X, shuffle=True, **_loader_kwargs)
train_image_loader_Y = torch.utils.data.DataLoader(train_data_Y, shuffle=True, **_loader_kwargs)
test_image_loader_X = torch.utils.data.DataLoader(test_data_X, shuffle=False, **_loader_kwargs)
test_image_loader_Y = torch.utils.data.DataLoader(test_data_Y, shuffle=False, **_loader_kwargs)
代码
文本
[12]
import itertools
learning_rate = 0.0002
_adam_betas = (0.5, 0.999)

# One Adam optimiser updates both generators jointly.
optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                               lr=learning_rate, betas=_adam_betas)
# Each discriminator has its own optimiser at half the generator learning rate.
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=learning_rate / 2, betas=_adam_betas)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=learning_rate / 2, betas=_adam_betas)
代码
文本
[13]
# Mean-squared adversarial loss; L1 for the cycle-consistency and identity terms.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

if torch.cuda.is_available():
    for _criterion in (criterion_GAN, criterion_cycle, criterion_identity):
        _criterion.cuda()
代码
文本
[14]
# Timestamp for an optional per-run experiment name, e.g. exp_name = "Exp_" + current_time.
# It is not used at the moment: exp_name is fixed to "Exp".
current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

exp_name = "Exp"
dataset_name = "summer2winter"

# Sub-directories (under "<exp_name>-<dataset_name>") for sample grids and checkpoints.
img_result_dir = 'result_images'
model_result_dir = 'saved_models'
代码
文本
[15]
def image_recover(image):
    """Convert a batch of images from [-1, 1] floats (N, C, H, W) to uint8 (N, H, W, C).

    Values are mapped [-1, 1] -> [0, 255], rounded to the nearest integer and clamped
    to [0, 255] before the cast. A bare ``.to(torch.uint8)`` truncates instead of
    rounding and wraps out-of-range values.
    """
    image = (image.permute(0, 2, 3, 1) + 1.0) / 2.0
    image = (image * 255.0).round().clamp(0.0, 255.0)
    return image.to(torch.uint8)
代码
文本
[16]
from torch.autograd import Variable
from torchvision.utils import save_image
import random
class ReplayBuffer:
    """Fixed-size pool of previously generated samples for discriminator updates.

    While fewer than ``max_size`` samples are stored, each incoming sample is stored
    and also returned. Once full, each incoming sample is, with probability ~1/2,
    swapped with a randomly chosen stored sample (a clone of the stored one is
    returned); otherwise it is returned unchanged and not stored.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch (first dim = samples) and return a same-size batch drawn via the pool."""
        picked = []
        # .data: iterate over the values without tracking gradients
        for sample in data.data:
            sample = sample.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(sample)
                picked.append(sample)
                continue
            if random.uniform(0, 1) <= 0.5:
                picked.append(sample)
                continue
            slot = random.randint(0, self.max_size - 1)
            picked.append(self.data[slot].clone())
            self.data[slot] = sample
        return Variable(torch.cat(picked))

class LambdaLR:
    """Learning-rate multiplier: 1.0 at first, then a linear decay towards 0.

    ``step(epoch)`` is meant to be used as the ``lr_lambda`` of
    ``torch.optim.lr_scheduler.LambdaLR``. It returns
    ``1 - max(0, epoch + 1 + epoch_start - decay_start_epoch) / (epoch_num - decay_start_epoch)``.
    """

    def __init__(self, epoch_num, epoch_start, decay_start_epoch):
        assert ((epoch_num - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.epoch_num, self.epoch_start, self.decay_start_epoch = epoch_num, epoch_start, decay_start_epoch

    def step(self, epoch):
        decay_span = self.epoch_num - self.decay_start_epoch
        decayed_epochs = max(0, epoch + 1 + self.epoch_start - self.decay_start_epoch)
        return 1.0 - decayed_epochs / decay_span


def sample_images(G_AB, G_BA, iter_image_X, iter_image_Y, epoch, batches_done, file_name):
    """Translate one batch from each domain and save a comparison grid as PNG.

    A and B are the two styles; X is an image in style A, Y an image in style B.
    G_AB maps A -> B and G_BA maps B -> A. The grid rows are, in order:
    real X, G_AB(X), G_BA(G_AB(X)), real Y, G_BA(Y), G_AB(G_BA(Y)).

    Args:
        G_AB, G_BA: the two generators.
        iter_image_X, iter_image_Y: iterators yielding image batches of domain A / B.
            They are now actually used. Previously the global training iterators were
            always read, so the 'test' grid silently showed training images.
        epoch, batches_done: used in the output file name "<batches_done>-<epoch>.png".
        file_name: 'train' or 'test'; selects the output sub-directory.

    Returns:
        Path of the saved image.

    Raises:
        ValueError: if file_name is neither 'train' nor 'test'.
    """
    if file_name not in ('train', 'test'):
        raise ValueError("file_name must be 'train' or 'test', got %r" % (file_name,))

    real_X_A = next(iter_image_X).cuda()
    real_Y_B = next(iter_image_Y).cuda()

    # Visualisation only: no autograd graph needed
    with torch.no_grad():
        fake_X_B = G_AB(real_X_A)   # A style -> B style
        recov_X_A = G_BA(fake_X_B)  # reconstruct A from the fake B
        fake_Y_A = G_BA(real_Y_B)   # B style -> A style
        recov_Y_B = G_AB(fake_Y_A)  # reconstruct B from the fake A

    img_sample = torch.cat((real_X_A, fake_X_B, recov_X_A,
                            real_Y_B, fake_Y_A, recov_Y_B), 0)

    img_path = '%s-%s/%s/%s-%s.png' % (exp_name,
                                       dataset_name,
                                       img_result_dir + '/' + file_name,
                                       batches_done,
                                       epoch)
    save_image(img_sample, img_path, nrow=BATCH_SIZE, normalize=True)
    return img_path

代码
文本
[17]
# Gradient penalty
def gradient_penalty(D, xr, xf, lambda_gp=0.3):
    """WGAN-GP style gradient penalty on random interpolates of real and fake samples.

    Args:
        D: discriminator (any callable returning a tensor).
        xr: batch of real samples; dim 0 is the batch.
        xf: batch of fake samples with the same shape as xr.
        lambda_gp: weight of the penalty (default 0.3, the previously hard-coded value).

    Returns:
        Scalar tensor ``lambda_gp * mean_i (||grad_x D(x_i)||_2 - 1) ** 2``, where the norm
        is taken over the whole gradient of each sample. The graph is kept
        (create_graph=True) so it can be back-propagated into D's parameters.
    """
    # Only the discriminator is constrained: cut the graph back to the generators
    xr = xr.detach()
    xf = xf.detach()

    batch_size = xr.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims, created
    # on the inputs' device (previously a single scalar for the whole batch, forced to CUDA).
    alpha = torch.rand((batch_size,) + (1,) * (xr.dim() - 1), device=xr.device, dtype=xr.dtype)
    interpolates = (alpha * xr + (1 - alpha) * xf).requires_grad_(True)

    disc_interpolates = D(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates,
                                    inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True,
                                    retain_graph=True)[0]

    # Per-sample L2 norm over all non-batch dims (previously dim=1 only, i.e. per pixel
    # over channels for image inputs).
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean() * lambda_gp
代码
文本
[18]
# (C, H, W) shape of the discriminator's per-patch output, used to build the GAN target
# tensors: n_D_layers stride-2 convs divide each side by 2**n_D_layers, and the two final
# 4x4 stride-1 padding-1 convs each remove one more pixel (hence the "- 2").
# Assumes img_height and img_width are divisible by 2**n_D_layers.
patch = (1, img_height//(2**n_D_layers) - 2 , img_width//(2**n_D_layers) - 2)
代码
文本
[19]
epoch_start = 0
epoch_num = 800
decay_epoch = 100

# The LambdaLR helper is stateless, so a single instance can drive all three schedulers:
# factor 1.0 until decay_epoch, then linear decay towards 0 at epoch_num.
_lr_lambda = LambdaLR(epoch_num, epoch_start, decay_epoch).step
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=_lr_lambda)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=_lr_lambda)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=_lr_lambda)
代码
文本
[20]
# History pools of generated images, one per translation direction; they feed the
# discriminators a mix of current and past fakes (default pool size 50).
fake_Y_A_buffer, fake_X_B_buffer = ReplayBuffer(), ReplayBuffer()
代码
文本
[21]
from itertools import cycle
def _endless(loader):
    """Yield batches from ``loader`` forever, re-iterating it on every pass.

    ``itertools.cycle`` saves every element of the first pass and then replays those
    saved batches. With it, the shuffled order and random augmentations were frozen
    after the first pass, and every batch of the dataset was kept in memory.
    Re-iterating the DataLoader draws a fresh order (when shuffle=True) and re-runs
    the dataset's transforms on each pass.

    Raises:
        ValueError: if a full pass over ``loader`` yields no batches (avoids spinning forever).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("loader yielded no batches")


iter_train_image_X = _endless(train_image_loader_X)
iter_train_image_Y = _endless(train_image_loader_Y)
iter_test_image_X = _endless(test_image_loader_X)
iter_test_image_Y = _endless(test_image_loader_Y)
代码
文本
[22]
file_name = '%s-%s' % (exp_name, dataset_name)

# Output tree under "<exp_name>-<dataset_name>": sample grids in result_images/{train,test},
# checkpoints in saved_models. Existing directories are kept.
for _subdir in (img_result_dir,
                img_result_dir + '/train',
                img_result_dir + '/test',
                model_result_dir):
    os.makedirs('%s/%s' % (file_name, _subdir), exist_ok=True)
代码
文本
[23]
# Save model checkpoints every `checkpoint_interval` epochs (epoch 0 included); -1 disables saving.
checkpoint_interval = 50
代码
文本
[ ]
import matplotlib.pyplot as plt
from PIL import Image
import os
from pathlib import Path
prev_time = time.time()
batches_per_epoch = len(train_image_loader_X)

# GAN targets (ones = "real", zeros = "fake") shaped like the discriminator's patch output.
# Every batch used below has exactly BATCH_SIZE images (short batches are skipped), so the
# targets are loop-invariant and built once.
valid = torch.ones((BATCH_SIZE, *patch)).cuda()
fake = torch.zeros((BATCH_SIZE, *patch)).cuda()


def _show_sample_grid(img_path, title):
    """Display a saved sample grid inline, with all ticks and tick labels hidden."""
    with Image.open(img_path) as img:
        pixels = np.asarray(img)
    plt.figure(figsize=(10, 5))
    plt.imshow(pixels)
    plt.title(title, fontsize=30)
    plt.tick_params(axis='both', left=False, top=False, right=False, bottom=False,
                    labelleft=False, labeltop=False, labelright=False, labelbottom=False)
    plt.show()


def _save_checkpoint(epoch):
    """Save the state_dict of all four networks, tagged with ``epoch``."""
    for net, tag in ((G_AB, 'G__AB'), (G_BA, 'G__BA'), (D_A, 'D__A'), (D_B, 'D__B')):
        torch.save(net.state_dict(),
                   '%s-%s/%s/%s_%d.pth' % (exp_name, dataset_name, model_result_dir, tag, epoch))


for epoch in range(epoch_start, epoch_num):
    for i in range(batches_per_epoch):
        # A and B are the two styles; X is an image in style A, Y an image in style B.
        # G_AB changes style A -> B, G_BA changes style B -> A.

        # ---- Model input: keep drawing until both domains give a full batch ----
        while True:
            sample_X = next(iter_train_image_X)
            sample_Y = next(iter_train_image_Y)
            if sample_X.shape[0] == BATCH_SIZE and sample_Y.shape[0] == BATCH_SIZE:
                break
        real_X_A = sample_X.float().cuda()
        real_Y_B = sample_Y.float().cuda()

        # ------------------
        # Train Generators
        # ------------------
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image that is already in its *output* domain
        # should return it unchanged (G_AB(B) ~ B, G_BA(A) ~ A). The previous code fed each
        # generator its *input* domain (G_AB(A) ~ A), which pushes G_AB towards the identity
        # on A and directly opposes the translation objective.
        loss_id_A = criterion_identity(G_AB(real_Y_B), real_Y_B)  # identity term of G_AB
        loss_id_B = criterion_identity(G_BA(real_X_A), real_X_A)  # identity term of G_BA
        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss: each generator tries to make the other domain's discriminator say "real"
        fake_X_B = G_AB(real_X_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_X_B), valid)
        fake_Y_A = G_BA(real_Y_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_Y_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: A -> B -> A and B -> A -> B must reconstruct the inputs
        recov_X_A = G_BA(fake_X_B)
        loss_cycle_A = criterion_cycle(recov_X_A, real_X_A)
        recov_Y_B = G_AB(fake_Y_A)
        loss_cycle_B = criterion_cycle(recov_Y_B, real_Y_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        loss_G = loss_GAN + 10 * loss_cycle + 0.3 * loss_identity
        loss_G.backward()
        optimizer_G.step()

        # -----------------------------------------------------------
        # Train Discriminator A (real A images vs. generated B -> A)
        # -----------------------------------------------------------
        optimizer_D_A.zero_grad()
        loss_real = criterion_GAN(D_A(real_X_A), valid)
        # Fake loss on a mix of current and previously generated samples
        fake_Y_A_ = fake_Y_A_buffer.push_and_pop(fake_Y_A).detach()
        loss_fake = criterion_GAN(D_A(fake_Y_A_), fake)
        gp = gradient_penalty(D_A, real_X_A, fake_Y_A_)
        loss_D_A = (loss_real + loss_fake) / 2 + gp
        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------------------------------------------
        # Train Discriminator B (real B images vs. generated A -> B)
        # -----------------------------------------------------------
        optimizer_D_B.zero_grad()
        loss_real = criterion_GAN(D_B(real_Y_B), valid)
        fake_X_B_ = fake_X_B_buffer.push_and_pop(fake_X_B).detach()
        loss_fake = criterion_GAN(D_B(fake_X_B_), fake)
        gp = gradient_penalty(D_B, real_Y_B, fake_X_B_)
        loss_D_B = (loss_real + loss_fake) / 2 + gp
        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        # Log progress
        # --------------
        # ETA = remaining batches x duration of the last batch
        batches_done = epoch * batches_per_epoch + i
        batches_left = epoch_num * batches_per_epoch - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        print("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f, A: %f, B: %f ] [G loss: %f, G_AB: %f, cyc_A: %f, id_A: %f G_BA: %f, cyc_B: %f, id_B: %f] ETA: %s" %
              (epoch + 1, epoch_num,
               i, batches_per_epoch,
               loss_D.item(), loss_D_A.item(), loss_D_B.item(),
               loss_G.item(),
               loss_GAN_AB.item(), loss_cycle_A.item(), loss_id_A.item(),
               loss_GAN_BA.item(), loss_cycle_B.item(), loss_id_B.item(),
               time_left))

        # Every 50 batches save and show sample grids (test first, then train)
        if batches_done % 50 == 0:
            train_img_path = sample_images(G_AB, G_BA, iter_train_image_X, iter_train_image_Y,
                                           epoch, batches_done, 'train')
            test_img_path = sample_images(G_AB, G_BA, iter_test_image_X, iter_test_image_Y,
                                          epoch, batches_done, 'test')
            _show_sample_grid(test_img_path, 'Test')
            _show_sample_grid(train_img_path, 'Train')

    # Update learning rates.
    # NOTE: step(epoch) (deprecated in PyTorch) sets the next epoch's factor to
    # LambdaLR.step(epoch); together with the "+ 1" inside LambdaLR.step this gives epoch k
    # the factor 1 - max(0, k + epoch_start - decay_epoch) / (epoch_num - decay_epoch).
    # Keep both in sync if either changes.
    lr_scheduler_G.step(epoch)
    lr_scheduler_D_B.step(epoch)
    lr_scheduler_D_A.step(epoch)

    # Save checkpoints every checkpoint_interval epochs, and always after the final epoch
    # (previously the final epoch was never saved).
    is_last_epoch = epoch == epoch_num - 1
    if checkpoint_interval != -1 and (epoch % checkpoint_interval == 0 or is_last_epoch):
        _save_checkpoint(epoch)
代码
文本
[ ]

代码
文本
[ ]

代码
文本
[ ]

代码
文本
[ ]

代码
文本
[ ]

代码
文本
[ ]

代码
文本
[ ]

代码
文本
[ ]

代码
文本
[ ]

代码
文本
[ ]

代码
文本
中文
Deep Learning
PyTorch
CV
GAN
中文Deep LearningPyTorchCVGAN
点个赞吧
本文被以下合集收录
DP GAN Zoo Branch
suyanyi
更新于 2023-09-21
3 篇0 人关注
推荐阅读
公开
Demo.ipynb
PyTorch
PyTorch
xuxh@dp.tech
更新于 2024-08-21
公开
dynamo.ipynb
AI
AI
xuxh@dp.tech
更新于 2024-09-03