WGAN outputs "NaN" losses in TF 2.6 when working with a large dataset



I am working on applying a Wasserstein GAN to a large dataset. I had been using a compute server at my institute, but it is small, so I started using a larger server. After some digging, I found that my institute's server runs NumPy 1.18.5 and TF 2.3.1, while the larger server runs NumPy 1.19.5 and TF 2.6.0, in case that is relevant.

The models I used on the smaller server worked fine (they suffered from mode collapse, but I mean the code ran), but when I started running things on the larger server I began getting NaN for both the generator and critic losses. I looked into the issue, and it seems related to the amount of data my WGAN is processing: once I use ~6.5 GB of data it returns NaNs, but below that threshold it seems to run fine. My full dataset is about 9 GB. It outputs NaN after the first training iteration, so I don't think the problem is exploding gradients, but I'm no expert.
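(For the record, a minimal sanity check along these lines, using `critic` and `x_real` from the example below, would be: if the weights are already non-finite after the first update, the update itself blew up; if only the inputs are non-finite, the data is the problem.)

# minimal sketch: run right after the first critic.train_on_batch() call
import numpy as np

bad_weights = any(not np.all(np.isfinite(w)) for w in critic.get_weights())
print("non-finite weights after first update:", bad_weights)
print("non-finite values in input batch:", not np.all(np.isfinite(x_real)))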

I have already tried adding batch normalization to the critic, both before and after the activations, as well as value clipping, but neither seems to change the output on the large dataset.
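Concretely, the two placements I tried look like this (a sketch using the critic layer names from the example below; the full example keeps the "before activation" variant):

# batch normalization before the activation (what the full example below uses)
h = Dense(hidden_c_1_size, kernel_constraint=const)(critic_input)
h = BatchNormalization()(h)
h = Activation("relu")(h)

# the "after activation" variant I also tried
h = Dense(hidden_c_1_size, kernel_constraint=const)(critic_input)
h = Activation("relu")(h)
h = BatchNormalization()(h)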

The full code can be found here: https://pastebin.com/2iJ0gQ5j (I know it's messy).

Here is a minimal working example:

#!/usr/bin/env python
# coding: utf-8
import csv
import sys
import pandas as pd
import numpy as np
import math
import os
import tensorflow as tf
import glob
from pathlib import Path
from tensorflow.keras import backend
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Input, Dropout, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.constraints import Constraint

# implementation of wasserstein loss
def wasserstein_loss(y_true, y_pred):
    return backend.mean(y_true * y_pred)

# clip model weights to a given hypercube
class ClipConstraint(Constraint):
    # set clip value when initialized
    def __init__(self, clip_value):
        self.clip_value = clip_value

    # clip model weights to hypercube
    def __call__(self, weights):
        return backend.clip(weights, -self.clip_value, self.clip_value)

    # get the config
    def get_config(self):
        return {'clip_value': self.clip_value}

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    # generate points in the latent space
    x_input = np.random.randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

# draw a random batch of real samples; real samples are labelled -1
def generate_real_samples(data, n):
    idx = np.random.choice(data.shape[0], n, replace=False)
    return data[idx, :], -np.ones((n, 1))

# have the generator produce a batch of fake samples, labelled +1
def generate_fake_samples(g_model, latent, n):
    return g_model(generate_latent_points(latent, n)), np.ones((n, 1))

SNP_dir = "/my/work/dir/in/"
SNP_vcf_list = glob.glob(SNP_dir+"*.csv")
nstep = 2500
ncritstep = 5

### READ INPUT
sampleIDs = []
sampleData = {}
meaningful = []
linesperfile = []
## data is stored in multiple files, so it is read as a dictionary then converted into an array
for SNP_vcf in SNP_vcf_list:
    n = 0
    with open(SNP_vcf, newline='') as csvfile:
        csvread = csv.reader(csvfile, delimiter=';', quotechar='"')
        for myline in csvread:  # reads csvfile line by line, storing an array of cells in myline
            if myline[0][0] == '#':
                # header line: register the sample IDs
                myline[0] = myline[0][2:]
                for mycell in myline:
                    if mycell not in sampleIDs:
                        sampleIDs.append(mycell)
                        sampleData[mycell] = []
            else:
                n += 1
                ## some lines contain all 0s, so they are removed here.
                mybool = np.std(np.asarray(myline).astype(np.float32)) != 0
                meaningful.append(mybool)
                i = 0
                for mycell in myline:
                    sampleData[sampleIDs[i]].append(float(mycell))
                    i += 1
    linesperfile.append(n)
input_size = sum(meaningful)
print("Read "+str(input_size)+" meaningful encoded variables for "+str(len(sampleIDs))+" samples")

real = []
for k, v in sampleData.items():
    real.append(np.array(v)[meaningful])
real = np.array(real)
print(real.shape)
del(sampleData)


### CREATE CRITIC
hidden_c_1_size = math.ceil(input_size/10)
hidden_c_2_size = math.ceil(input_size/50)
hidden_c_3_size = math.ceil(input_size/100)
print("First hidden critic layer size: "+str(hidden_c_1_size))
print("Second hidden critic layer size: "+str(hidden_c_2_size))
print("Third hidden critic layer size: "+str(hidden_c_3_size))
hidden_g_1_size = math.ceil(input_size/10)
hidden_g_2_size = math.ceil(input_size/50)
hidden_g_3_size = math.ceil(input_size/100)
print("First hidden generator layer size: "+str(hidden_g_3_size))
print("Second hidden generator layer size: "+str(hidden_g_2_size))
print("Third hidden generator layer size: "+str(hidden_g_1_size))
const = ClipConstraint(0.01)
critic_input = Input(shape=(input_size, ))
c_hidden_1 = Dense(hidden_c_1_size, kernel_constraint=const)(critic_input)
c_bnorm_1 = BatchNormalization()(c_hidden_1)
c_act_1 = Activation("relu")(c_bnorm_1)
c_drop_1 = Dropout(0.4)(c_act_1)
c_hidden_2 = Dense(hidden_c_2_size, kernel_constraint=const)(c_drop_1)
c_bnorm_2 = BatchNormalization()(c_hidden_2)
c_act_2 = Activation("relu")(c_bnorm_2)
c_drop_2 = Dropout(0.4)(c_act_2)
c_hidden_3 = Dense(hidden_c_3_size, kernel_constraint=const)(c_drop_2)
c_bnorm_3 = BatchNormalization()(c_hidden_3)
c_act_3 = Activation("relu")(c_bnorm_3)
critic_output = Dense(1, activation='linear')(c_act_3)
critic = Model(critic_input, critic_output)
opt = RMSprop(learning_rate=0.00005, clipvalue=0.01)
critic.compile(optimizer=opt, loss=wasserstein_loss)
### CREATE GENERATOR
latent_dim_size = 1000
generator_input = Input(shape=(latent_dim_size,))
g_hidden_1 = Dense(hidden_g_3_size, activation='relu')(generator_input)
g_drop_1 = Dropout(0.4)(g_hidden_1)
g_hidden_2 = Dense(hidden_g_2_size, activation='relu')(g_drop_1)
g_drop_2 = Dropout(0.4)(g_hidden_2)
g_hidden_3 = Dense(hidden_g_1_size, activation='relu')(g_drop_2)
generator_output = Dense(input_size, activation='relu')(g_hidden_3)
generator = Model(generator_input, generator_output)
critic.trainable = False
GAN = Sequential()
GAN.add(generator)
GAN.add(critic)
opt = RMSprop(learning_rate=0.00005, clipvalue=0.01)
GAN.compile(optimizer=opt, loss=wasserstein_loss)


c_history = []
g_history = []
for i in range(nstep):
    ## Train critic more than generator
    c_loss = 0
    for j in range(ncritstep):
        critic.trainable = True
        x_real, y_real = generate_real_samples(real, 1000)
        x_fake, y_fake = generate_fake_samples(generator, latent_dim_size, 1000)
        c_loss += critic.train_on_batch(x_real, y_real)
        c_loss += critic.train_on_batch(x_fake, y_fake)
    c_history.append(c_loss)
    critic.trainable = False
    x_gan = generate_latent_points(latent_dim_size, 1000)
    y_gan = -np.ones((1000, 1))
    g_loss = GAN.train_on_batch(x_gan, y_gan)
    g_history.append(g_loss)
    if (i + 1) % 10 == 0:
        print(">" + str(i + 1) + ": c_loss= " + str(c_loss) + "; g_loss= " + str(g_loss))


print("Training finished! Moving on to output")
## output removed for simplicity's sake
As for my data: as I mentioned, it is fairly large. What I can tell you is that there are about 2,500 samples with nearly 100K dimensions, all floats. What I'm trying to do with it seems less important to me than the problem I'm running into, but if you really want to see what some of the data looks like, I can find a way to put a sample of it in the cloud or somewhere. As I said, this used to work on this dataset on the smaller server, and it works when I use fewer than ~70K dimensions, but it produces NaNs when run on the full dataset with TF 2.6. Why is this happening? I can't seem to diagnose my network with the information I have.
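Given that threshold behaviour, a sketch of how one could scan the dimensions in blocks for anything suspicious (assuming `real` is the samples-by-dimensions array from the example above):

# scan the ~100K dimensions in blocks and flag any block containing
# non-finite values
block = 10000
for start in range(0, real.shape[1], block):
    chunk = real[:, start:start + block]
    if not np.all(np.isfinite(chunk)):
        print("non-finite values somewhere in dimensions", start, "to", start + block)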

Following gobrewers' suggestion, I added tf.debugging.enable_check_numerics(), and this is the output:
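(For anyone reproducing this: enable_check_numerics() is called once, before any model is built or trained; I put it right after the imports, roughly like this.)

import tensorflow as tf

# must run before any graph/model is built; every op output is then
# checked for Inf/NaN, which slows training but pinpoints the first bad tensor
tf.debugging.enable_check_numerics()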

Traceback (most recent call last):
File "encoded_wgan.py", line 272, in <module>
c_loss += critic.train_on_batch(x, y)
File "/tools/python/3.7.4/lib/python3.7/site-packages/keras/engine/training.py", line 1856, in train_on_batch
logs = self.train_function(iterator)
File "/tools/python/3.7.4/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 885, in __call__
result = self._call(*args, **kwds)
File "/tools/python/3.7.4/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 950, in _call
return self._stateless_fn(*args, **kwds)
File "/tools/python/3.7.4/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3040, in __call__
filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
File "/tools/python/3.7.4/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1964, in _call_flat
ctx, args, cancellation_manager=cancellation_manager))
File "/tools/python/3.7.4/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 596, in call
ctx=ctx)
File "/tools/python/3.7.4/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 138, in execute_with_callbacks
tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
File "/tools/python/3.7.4/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  
!!! Detected Infinity or NaN in output 0 of graph op "IteratorGetNext" (# of outputs: 2) !!!
dtype: <dtype: 'float32'>
shape: (2000, 99557)
Input tensor: Tensor("iterator:0", shape=(), dtype=resource)
Graph name: "train_function"
Stack trace of op's creation ("->": inferred user code):
+ encoded_wgan.py (L272) <module>
-> |   c_loss += critic.train_on_batch(x, y)
+ ...b/python3.7/site-packages/keras/engine/training.py (L1856) train_on_batch
-> |   logs = self.train_function(iterator)
+ ...e-packages/tensorflow/python/eager/def_function.py (L885) __call__
|   result = self._call(*args, **kwds)
+ ...e-packages/tensorflow/python/eager/def_function.py (L950) _call
|   return self._stateless_fn(*args, **kwds)
+ .../site-packages/tensorflow/python/eager/function.py (L3038) __call__
|   filtered_flat_args) = self._maybe_define_function(args, kwargs)
+ .../site-packages/tensorflow/python/eager/function.py (L3463) _maybe_define_function
|   graph_function = self._create_graph_function(args, kwargs)
+ .../site-packages/tensorflow/python/eager/function.py (L3308) _create_graph_function
|   capture_by_value=self._capture_by_value),
+ ...packages/tensorflow/python/framework/func_graph.py (L1007) func_graph_from_py_func
|   func_outputs = python_func(*func_args, **func_kwargs)
+ ...e-packages/tensorflow/python/eager/def_function.py (L668) wrapped_fn
|   out = weak_wrapped_fn().__wrapped__(*args, **kwds)
+ ...packages/tensorflow/python/framework/func_graph.py (L990) wrapper
|   user_requested=True,
+ ...b/python3.7/site-packages/keras/engine/training.py (L853) train_function
-> |   return step_function(self, iterator)
+ ...b/python3.7/site-packages/keras/engine/training.py (L841) step_function
-> |   data = next(iterator)
+ ...ackages/tensorflow/python/data/ops/iterator_ops.py (L761) __next__
|   return self._next_internal()
+ ...ackages/tensorflow/python/data/ops/iterator_ops.py (L738) _next_internal
|   output_shapes=self._flat_output_shapes)
+ ...-packages/tensorflow/python/ops/gen_dataset_ops.py (L2750) iterator_get_next
|   output_shapes=output_shapes, name=name)
+ ...ages/tensorflow/python/framework/op_def_library.py (L750) _apply_op_helper
|   attrs=attr_protos, op_def=op_def)
+ ...packages/tensorflow/python/framework/func_graph.py (L601) _create_op_internal
|   compute_device)
+ ...7/site-packages/tensorflow/python/framework/ops.py (L3569) _create_op_internal
|   op_def=op_def)
+ ...7/site-packages/tensorflow/python/framework/ops.py (L2045) __init__
|   self._traceback = tf_stack.extract_stack_for_node(self._c_op)
: Tensor had NaN values
[[node IteratorGetNext/CheckNumericsV2 (defined at encoded_wgan.py:272) ]] [Op:__inference_train_function_2000]
Function call stack:
train_function

(Note: L272 is the first train_on_batch() call in my full script.)

I am blind.

My data contained NaNs, just in a few very specific places. So when I loaded only part of the data they weren't included, but they were there once I used the full dataset.
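For anyone else hitting this, a minimal sketch of the check that would have caught it immediately (assuming `real` is the array built in the example in the question):

import numpy as np

bad = ~np.isfinite(real)  # True wherever the data is NaN or Inf
if bad.any():
    print("found", bad.sum(), "non-finite values at (sample, dimension):")
    print(np.argwhere(bad))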

Check your data, you fool.
