Training a GAN: generator gradients are always 0 while the Discriminator performs well



When I try to train the GAN, the gradients are a mess. I use this line to inspect the generator's gradient, but it is always 0:

print(self.ganout.ganout1[0].weight.grad)

[screenshot: the printed gradient tensor is all zeros]

Here is the full code. What am I doing wrong?

import os
import time
from collections import OrderedDict
# from pl_bolts.models.gans import DCGAN
import numpy
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import pytorch_lightning as pyl
from torch.autograd import Variable
from torch.autograd._functions import tensor
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from PIL import Image
class cc_block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.inn = 0
        self.convv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        self.inn = x
        aa = self.convv(x)
        cc = torch.add(self.inn, aa)
        # b = self.ganj2(a.view(3*22*17*4))
        return cc
class ResBlock(nn.Module):
    def __init__(self, n_chans, n_chans1):
        super(ResBlock, self).__init__()
        self.conv = nn.Conv2d(n_chans, n_chans1,
                              kernel_size=(8, 8), stride=(2, 2),
                              bias=False)
        self.batch_norm = nn.BatchNorm2d(num_features=n_chans)  # <5>
        torch.nn.init.kaiming_normal_(self.conv.weight,
                                      nonlinearity='relu')  # <6>
        torch.nn.init.constant_(self.batch_norm.weight, 0.5)  # <7>
        torch.nn.init.zeros_(self.batch_norm.bias)

    def forward(self, x):
        out = self.conv(x)
        out = self.batch_norm(out)
        out = torch.relu(out)
        return out + x
class jian(nn.Module):  # discriminator
    def __init__(self):
        super().__init__()
        self.ganj = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=300, kernel_size=(8, 8), stride=(2, 2)),
            nn.BatchNorm2d(300), nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=300, out_channels=300, kernel_size=(8, 8), stride=(2, 2)),
            # ResBlock(300, 300),
            nn.BatchNorm2d(300), nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=300, out_channels=3, kernel_size=(8, 8), stride=(2, 2)),
            nn.LeakyReLU(0.2))  # output feature map: 22 x 17
        self.ganj2 = nn.Sequential(nn.Linear(3 * 22 * 17, 1), nn.Sigmoid())
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.00001)

    def forward(self, x):
        a = self.ganj(x)
        b = self.ganj2(a.view(3 * 22 * 17))
        # b = self.ganj2(a.view(3*22*17*4))
        return b
class grent(nn.Module):  # generator
    def __init__(self):
        super().__init__()
        self.ganout1 = nn.Sequential(nn.Linear(100, 6 * 22 * 17), nn.SELU())
        # self.ganout1 = nn.Sequential(nn.Linear(1, 3 * 22 * 17), nn.LeakyReLU(0.2), nn.Linear(3 * 22 * 17, 3 * 50 * 30), nn.LeakyReLU(0.2), nn.Linear(3 * 50 * 30, 3 * 150 * 50), nn.LeakyReLU(0.2), nn.Linear(3 * 150 * 100, 3 * 200 * 150), nn.SELU(), nn.Linear(3 * 200 * 150, 3 * 218 * 178), nn.Sigmoid())
        # self.ganout1 = nn.Sequential(nn.Linear(10, 3 * 218), nn.LeakyReLU(0.2), nn.Linear(3 * 218, 3000), nn.LeakyReLU(0.2), nn.Linear(3000, 3 * 218 * 178), nn.Sigmoid())
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.00001)

        # self.ganout = nn.Sequential(
        #     nn.ConvTranspose2d(in_channels=6, out_channels=400, kernel_size=(8, 8), stride=(2, 2)), nn.BatchNorm2d(400),
        #     nn.SELU(), nn.ConvTranspose2d(in_channels=400, out_channels=400, kernel_size=(8, 8), stride=(2, 2)),
        #     nn.BatchNorm2d(400), nn.SELU(),
        #     nn.ConvTranspose2d(in_channels=400, out_channels=3, kernel_size=(8, 8), stride=(2, 2)),
        #     nn.Sigmoid())
        self.ganout = nn.Sequential(
            nn.ConvTranspose2d(in_channels=6, out_channels=400, kernel_size=(8, 8), stride=(2, 2)),
            nn.SELU(),
            nn.ConvTranspose2d(in_channels=400, out_channels=400, kernel_size=(8, 8), stride=(2, 2)),
            nn.SELU(),
            nn.ConvTranspose2d(in_channels=400, out_channels=3, kernel_size=(8, 8), stride=(2, 2)),
            nn.Sigmoid())

    def forward(self, x):
        a = self.ganout1(x)
        # a = a.view((1, 3, 218, 178))
        b = self.ganout(a.view((1, 6, 22, 17)))
        # b = self.ganj2(a.view(3*22*17*4))
        return b

class main_modle(pyl.LightningModule):
    def __init__(self):
        super().__init__()
        self.ganout = grent()  # generator
        self.ganj = jian()     # discriminator

    def forward(self, inputs):
        a = self.ganout(inputs)
        # b = self.ganj2(a.view(3*22*17*4))
        return a

    def configure_optimizers(self):
        optimizer = jian().optimizer
        # optimizer = torch.optim.Adam(self.ganj.parameters(), lr=0.00001)
        # optimizer2 = torch.optim.Adam(self.ganout.parameters(), lr=0.0003)
        optimizer2 = grent().optimizer
        return optimizer2, optimizer
        # return optimizer,

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    def training_step(self, batch, batch_idx, optimizer_idx):
        # def training_step(self, batch, batch_idx):
        #     x, y = batch
        self.loss_function = nn.BCELoss()
        self.zero_grad()

        x = batch[0]
        if optimizer_idx == 1:  # discriminator step
            if batch_idx % 1 == 0:
                outt = self.ganj(x)
                if torch.isnan(outt).any():
                    outt = Variable(torch.tensor([1.0]), requires_grad=True).type_as(x)
                aa = torch.as_tensor(torch.Tensor([1.0])).type_as(x)
                # realloss = self.loss_function(outt, aa)
                realloss = self.adversarial_loss(outt, aa)
                aaas = torch.round(self.forward(torch.randn([100]).type_as(x)).detach() * 255)
                outt2 = self.ganj(aaas)
                if torch.isnan(outt2).any():
                    outt2 = Variable(torch.tensor([0.0]), requires_grad=True).type_as(x)
                aa2 = torch.as_tensor(torch.Tensor([0.0])).type_as(x)
                # fuckloss = self.loss_function(outt2, aa2)
                fuckloss = self.adversarial_loss(outt2, aa2)
                mainloss = (realloss + fuckloss) / 2
                csdfsd = 0
                tqdm_dict = {"d_loss": mainloss}
                output = OrderedDict({"loss": mainloss, "progress_bar": tqdm_dict, "log": tqdm_dict})
                # return output
                self.log_dict(tqdm_dict)
                return mainloss
            # else:
            #     mainloss = fuckloss
            #     tqdm_dict = {"d_loss": mainloss}
            #     csdfsd = 0
            #     output = OrderedDict({"loss": mainloss, "progress_bar": tqdm_dict, "log": tqdm_dict})
            #     return output
        if optimizer_idx == 0:  # generator step
            self.loss_function2 = nn.MSELoss()
            outt2 = self.ganj(torch.round(self.ganout(torch.randn([100]).type_as(x)) * 255))
            if torch.isnan(outt2).any():
                outt2 = Variable(torch.tensor([0.0]), requires_grad=True).type_as(x)
            asds = torch.Tensor([1.0]).type_as(x)
            # losss = self.loss_function2(outt2, asds)
            losss = self.adversarial_loss(outt2, torch.Tensor([1.0]).type_as(x))
            csdfsd = 0
            if batch_idx % 100 == 0:
                aaas = torch.round(self.ganout(torch.randn([100]).type_as(x)) * 255)
                print(self.ganout.state_dict().keys())
                # print(self.ganout.ganout1[0].weight)
                print(self.ganout.ganout1[0].weight.grad)
                img_1 = aaas[0].cpu().detach().numpy()
                img_1 = img_1.astype('uint8')
                img_1 = np.transpose(img_1, (1, 2, 0))  # move the channel axis to the end
                cv2.imwrite('./114/' + str(batch_idx) + '.png', img_1)
            # torch.autograd.set_detect_anomaly(True)
            # losss.backward(retain_graph=True)
            # a = self.configure_optimizers()
            # a[0].step()
            tqdm_dict = {"g_loss": losss}
            output = OrderedDict({"loss": losss, "progress_bar": tqdm_dict, "log": tqdm_dict})
            # return output
            self.log_dict(tqdm_dict)
            return losss

        # a = x.shape
        # c = self.forward(x)
        # print(c.shape)
        aaaaasss = 0
        # return self.loss_function(c, aa)

    def test_step(self, batch, batch_idx):
        pass

class dataset2(Dataset):
    def __init__(self, csc_file):
        # self.data_df = pd.read_csv(csc_file, header=None)
        self.filll = os.listdir(csc_file)
        # print(self.filll)
        self.num = len(self.filll)
        if csc_file[-1] == '/':
            self.path = csc_file
        else:
            self.path = csc_file + '/'

    def __len__(self):
        return self.num
        # return self.num + 212000

    def __getitem__(self, index):
        if index <= self.num:
            # if 0:
            filllsd = self.filll[index]
            img = cv2.imread(self.path + filllsd)
            image1 = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            imgg = torch.FloatTensor(img)
            img = numpy.transpose(imgg, (2, 0, 1))
            # input = torch.from_numpy(img)
            input1 = torch.FloatTensor(img)
            ttttfff = torch.Tensor([1])
            # imgg = imgg.resize((1, 3, 218, 178))
            # image = Variable(torch.unsqueeze(image, dim=0).float(), requires_grad=False)
        else:
            acc = torch.rand([3, 218, 178]) * 255
            acc = torch.round(acc)
            input1 = acc
            ttttfff = torch.Tensor([0])
        # return input1, ttttfff
        return input1,
        # return image, target

aaa = dataset2(r'K:\aaaaa\Img\img_align_celeba')  # CelebA image folder
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger('tb_logs', name='my_model')
trainer = pyl.Trainer(gpus=1, logger=logger)

trainer.fit(model=main_modle(), train_dataloader=DataLoader(aaa, shuffle=True, batch_size=1))

Try replacing the sigmoid activation with tanh. When D is very confident, the sigmoid saturates toward 0 and the gradient vanishes.
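A minimal sketch of that change, keeping the rest of the question's grent class as-is and only swapping the final activation (the real images in dataset2 would then need rescaling to the same range):

import torch.nn as nn

# grent.ganout with Tanh instead of Sigmoid: outputs land in [-1, 1],
# where the gradient saturates less abruptly than sigmoid's does.
ganout = nn.Sequential(
    nn.ConvTranspose2d(in_channels=6, out_channels=400, kernel_size=(8, 8), stride=(2, 2)),
    nn.SELU(),
    nn.ConvTranspose2d(in_channels=400, out_channels=400, kernel_size=(8, 8), stride=(2, 2)),
    nn.SELU(),
    nn.ConvTranspose2d(in_channels=400, out_channels=3, kernel_size=(8, 8), stride=(2, 2)),
    nn.Tanh())  # instead of nn.Sigmoid()

# Matching normalization for real images, e.g. in dataset2.__getitem__:
# input1 = input1 / 127.5 - 1.0   # maps pixel values [0, 255] -> [-1, 1]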

For stability, you should train the Discriminator and the Generator in alternation. I would suggest first training the generator alone, without updating the discriminator's parameters, to see whether it can train at all or whether the implementation has a problem. I would also suggest adding a regularization term.
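As a minimal sketch of that isolation test, reusing the question's grent and jian classes (set_requires_grad is a helper written here for illustration, not part of the original code):

import torch
import torch.nn.functional as F

def set_requires_grad(module, flag):
    # Toggle gradient tracking for every parameter of a module.
    for p in module.parameters():
        p.requires_grad_(flag)

generator = grent()
discriminator = jian()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-5)

# One generator-only step: the discriminator's weights stay frozen.
set_requires_grad(discriminator, False)
g_optimizer.zero_grad()
# Feed the raw generator output to D; note there is no torch.round() here,
# since rounding has zero derivative almost everywhere and blocks the gradient.
fake = generator(torch.randn(100))
g_loss = F.binary_cross_entropy(discriminator(fake), torch.ones(1))
g_loss.backward()
print(generator.ganout1[0].weight.grad)  # should be non-zero now
g_optimizer.step()
set_requires_grad(discriminator, True)   # unfreeze before the next D step

If weight.grad is non-zero in this standalone setup but stays zero inside training_step, the problem lies in the Lightning training loop rather than in the generator itself.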
