Python中Tensorflow CNN层中的数据转换



我想使用CNN对音频文件进行分类。通常你可以把它们转换成光谱图,然后训练它们。

例如,Tensorflow的教程中描述了这一点。

然而,我现在正在寻找一种解决方案,它不在训练之前计算谱图,而是通过一层在模型本身中计算谱图。

为什么这样?

因为我想通过超参数优化来改进光谱图的设置。或者,理想情况下,网络本身应该学习最佳参数。

我建立了一个小例子。第一部分与Tensorflow教程中的内容完全相同。

import os
import pathlib
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras import models
DATASET_PATH = 'data/mini_speech_commands'
data_dir = pathlib.Path(DATASET_PATH)
if not data_dir.exists():
  tf.keras.utils.get_file(
      'mini_speech_commands.zip',
      origin="http://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip",
      extract=True,
      cache_dir='.', cache_subdir='data')

filenames = tf.io.gfile.glob(str(data_dir) + '/*/*')
filenames = tf.random.shuffle(filenames)
num_samples = len(filenames)
print('Number of total examples:', num_samples)
print('Example file tensor:', filenames[0])
commands = np.array(tf.io.gfile.listdir(str(data_dir)))
commands = commands[commands != 'README.md']
print('Commands:', commands)
train_files = filenames[:6400]
val_files = filenames[6400: 6400 + 800]
test_files = filenames[-800:]
print('Training set size', len(train_files))
print('Validation set size', len(val_files))
print('Test set size', len(test_files))

def decode_audio(audio_binary):
  audio, _ = tf.audio.decode_wav(contents=audio_binary)
  return tf.squeeze(audio, axis=-1)

def get_label(file_path):
  parts = tf.strings.split(
      input=file_path,
      sep=os.path.sep)
  return parts[-2]

def get_waveform_and_label(file_path):
  label = get_label(file_path)
  audio_binary = tf.io.read_file(file_path)
  waveform = decode_audio(audio_binary)
  return waveform, label
AUTOTUNE = tf.data.AUTOTUNE
files_ds = tf.data.Dataset.from_tensor_slices(train_files)
waveform_ds = files_ds.map(
    map_func=get_waveform_and_label,
    num_parallel_calls=AUTOTUNE)

我试着创建一个Layers的子类来做我想要的事情。

class SpectrogramTransform(layers.Layer):
    def __init__(self, fs=52000, nperseg=64, noverlap=32, nfft=16000):
        super(SpectrogramTransform, self).__init__()
        self.fs = fs
        self.nperseg = nperseg
        self.noverlap = noverlap
        self.nfft = nfft
    def calculate_spectrogram(self, inputs):
        _, _, Sxx = signal.spectrogram(
            x=inputs,
            fs=self.fs,
            nfft=self.nfft,
            nperseg=self.nperseg,
            noverlap=self.noverlap,
            mode="magnitude",
            )
        return Sxx
    
    def call(self, inputs):
        # convert tensor to numpy
        sess = tf.compat.v1.InteractiveSession()
        inputs = inputs.eval(session=sess)
        
        Sxx = self.calculate_spectrogram(inputs)
        # convert numpy back to tensor
        Sxx = tf.convert_to_tensor(Sxx, dtype=tf.float32)
        return Sxx 

test_file = tf.io.read_file(DATASET_PATH+'/down/0a9f9af7_nohash_0.wav')
test_audio, _ = tf.audio.decode_wav(contents=test_file)
print(test_audio.shape)
num_labels = len(commands)
    
model = models.Sequential([
    layers.Input(shape=test_audio.shape),
    SpectrogramTransform(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.Conv2D(64, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.25),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_labels),
])

尺寸似乎至少有问题。但显然还有更多的错误:(

但我也没有得到一个非常有意义的错误信息。

层接收的调用参数";spectrogram_transform";(类型SpectrogramTransform(:•inputs=tf.Tensor(shape=(None,13654,1(,dtype=float32(

你们中有人已经有了想法或类似的事情吗?事先非常感谢。

谢谢@djinn,这并没有解决我的问题,但这通常是一个很好的改进建议。

然而,我现在已经自己找到了解决方案。我想问题出在TF张量和Numpy数组之间的转换上。现在,我直接在张量上进行整个光谱图的计算。参见tfio.audio.spectrogram 的文档

这是我改课的结果。然而,我也对我的数据集做了一些小的更改,但它们在这里并不那么重要。但如果有人也对它感兴趣,请告诉我,我会添加它。

class SpectrogramTransform(layers.Layer):
    def __init__(self, window=512, stride=256, nfft=16000):
        super(SpectrogramTransform, self).__init__()
        self.window = window
        self.stride = stride
        self.nfft = nfft
    def call(self, waveform):
        return tfio.audio.spectrogram(input=waveform, nfft=self.nfft, window=self.window, stride=self.stride)

最新更新