Hi folks,
I am attempting the UrbanSound8K Kaggle challenge.
I am training a CNN on spectral representations of audio excerpts.
For that, I use PyTorch Lightning.
Please find all my scripts for this challenge on my Github: UrbanSound8K
The ONLY thing that doesn’t follow Lightning’s philosophy is… my dataset.
At the moment, I create a custom PyTorch dataset class that inherits from the Dataset class, instantiate it in the main.py file, and then create my data loaders outside of it.
Below is part of the custom dataset class (contained in dataset.py).
class UrbanSound8KDataset(Dataset):
    """UrbanSound8K dataset that yields dB-scaled mel spectrograms.

    Each item is loaded from disk, mixed down to mono, resampled to a
    target sample rate, cut or right-padded to a fixed number of samples,
    and converted to a dB-scaled mel spectrogram.
    """

    def __init__(self, dataset_dir, transforms_params, device):
        """
        Args:
            dataset_dir: Root folder containing "UrbanSound8K.csv" and the
                fold sub-directories ("fold1", "fold2", ...).
            transforms_params: Dict with keys "target_sample_rate",
                "target_length", "n_samples", "n_fft" and "n_mels".
            device: Torch device the signals and transforms are moved to.
                NOTE(review): moving tensors to a CUDA device inside
                __getitem__ is incompatible with num_workers > 0; prefer
                moving whole batches to the GPU in the training loop.
        """
        self.device = device
        self.dataset_dir = dataset_dir
        self.metadata = pd.read_csv(os.path.join(dataset_dir, "UrbanSound8K.csv"))
        self.n_folds = max(self.metadata["fold"])
        self.n_classes = len(self.metadata["class"].unique())
        # classID -> class name, sorted by classID for a stable mapping.
        self.classes_map = (
            pd.Series(self.metadata["class"].values, index=self.metadata["classID"])
            .sort_index()
            .to_dict()
        )
        self.target_sample_rate = transforms_params["target_sample_rate"]
        self.target_length = transforms_params["target_length"]
        self.n_samples = transforms_params["n_samples"]
        self.n_fft = transforms_params["n_fft"]
        self.n_mels = transforms_params["n_mels"]
        # Build the parameter-independent transforms ONCE instead of on every
        # __getitem__ call: constructing MelSpectrogram per item recomputes
        # the mel filterbank each time, which is pure overhead.
        self._mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.target_sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            window_fn=torch.hann_window,
            power=2,
            normalized=True,
            wkwargs=None,
            center=True,
            pad_mode="reflect",
            onesided=True,
            norm=None,
            mel_scale="htk",
        ).to(self.device)
        self._amplitude_to_db = torchaudio.transforms.AmplitudeToDB(stype="power").to(
            self.device
        )
        # Resample transforms depend on the source sample rate, which varies
        # per file; cache one transform per encountered rate.
        self._resamplers = {}

    def __len__(self):
        """Number of audio events listed in the metadata CSV."""
        return len(self.metadata)

    def __getitem__(self, index):
        """Return (index, audio_name, class_id, mel_spectrogram_db)."""
        audio_name = self._get_event_audio_name(index)
        class_id = torch.tensor(self._get_event_class_id(index), dtype=torch.long)
        signal, sr = self._get_event_signal(index)
        signal = signal.to(self.device)
        # Mix down BEFORE resampling so we resample a single channel only.
        signal = self._mix_down_if_necessary(signal)
        signal = self._resample_if_necessary(signal, sr)
        signal = self._cut_if_necessary(signal)
        signal = self._right_pad_if_necessary(signal)
        mel_spectrogram = self._mel_spectrogram_transform(signal)
        mel_spectrogram_db = self._db_transform(mel_spectrogram)
        return index, audio_name, class_id, mel_spectrogram_db

    def _get_event_class_id(self, index):
        """Integer class label of the event at `index`."""
        return self.metadata.iloc[index]["classID"]

    def _get_event_audio_name(self, index):
        """File name (without fold directory) of the event at `index`."""
        return self.metadata.iloc[index]["slice_file_name"]

    def _get_event_signal(self, index):
        """Load the waveform for the event at `index` as (signal, sample_rate)."""
        event_fold = f"fold{self.metadata.iloc[index]['fold']}"
        event_filename = self.metadata.iloc[index]["slice_file_name"]
        audio_path = os.path.join(self.dataset_dir, event_fold, event_filename)
        signal, sr = torchaudio.load(audio_path, normalize=True)
        return signal, sr

    def _mix_down_if_necessary(self, signal):
        """Average multi-channel audio down to a single mono channel."""
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _resample_if_necessary(self, signal, sr):
        """Resample `signal` from `sr` to the target rate when they differ."""
        if sr != self.target_sample_rate:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                # Lazily build and cache one Resample per source rate.
                resampler = torchaudio.transforms.Resample(
                    sr, self.target_sample_rate
                ).to(self.device)
                self._resamplers[sr] = resampler
            signal = resampler(signal)
        return signal

    def _cut_if_necessary(self, signal):
        """Truncate the signal to at most `n_samples` samples."""
        if signal.shape[1] > self.n_samples:
            signal = signal[:, : self.n_samples]
        return signal

    def _right_pad_if_necessary(self, signal):
        """Zero-pad the signal on the right up to `n_samples` samples."""
        signal_length = signal.shape[1]
        if signal_length < self.n_samples:
            num_missing_samples = self.n_samples - signal_length
            # (left_pad, right_pad) applied to the last (time) dimension.
            last_dim_padding = (0, num_missing_samples)
            signal = nn.functional.pad(signal, last_dim_padding)
        return signal

    def _spectrogram_transform(self, signal):
        """Plain power spectrogram of `signal` (helper, unused by __getitem__)."""
        spectrogram_transform = transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.n_fft // 2,
            pad=0,
            window_fn=torch.hann_window,
            power=2,
            normalized=True,
            wkwargs=None,
            center=False,
            pad_mode="reflect",
            onesided=True,
            return_complex=False,
        )
        spectrogram_transform = spectrogram_transform.to(self.device)
        spectrogram = spectrogram_transform(signal)
        return spectrogram

    def _mel_spectrogram_transform(self, signal):
        """Mel spectrogram of `signal` using the cached transform."""
        return self._mel_spectrogram(signal)

    def _db_transform(self, mel_spectrogram):
        """Convert a power mel spectrogram to dB scale."""
        return self._amplitude_to_db(mel_spectrogram)
Then, in main.py, I instantiate an object of this class like below (the class is in a dataset.py file, hence the dataset.).
# Single shared Dataset instance: "dataset" is the dataset.py module,
# "ds" is the instantiated UrbanSound8KDataset rooted at the "dataset" dir.
ds = dataset.UrbanSound8KDataset("dataset", transforms_params, args.device)
And finally, I create the data loaders.
# Get the train and validation sets.
# Fold `i` is held out for validation; all remaining folds form the
# training split. BUG FIX: the metadata lives on the dataset instance,
# so it must be read from `ds` (the object created above), not from
# `dataset` (the dataset.py module), which has no `metadata` attribute.
train_metadata = ds.metadata.drop(ds.metadata[ds.metadata["fold"] == i].index)
train_indices = train_metadata.index.to_list()
train_sampler = SubsetRandomSampler(train_indices)
validation_indices = ds.metadata[ds.metadata["fold"] == i].index.to_list()
# Create the train and validation dataloaders.
train_dataloader = DataLoader(
    ds,
    batch_size=args.bs,
    sampler=train_sampler,
    num_workers=args.workers,
)
# Validation iterates the held-out fold in a fixed order; a plain list of
# indices is a valid `sampler` (DataLoader accepts any Iterable of indices).
validation_dataloader = DataLoader(
    ds,
    sampler=validation_indices,
    batch_size=args.bs,
    num_workers=args.workers,
)
What I would like to do now is turn all this into a LightningDataModule.
I have used it already for CIFAR-100, but that was way less complex: the images were loaded all at once and little pre-processing was done. Here I have a lot of audio files and many pre-processing steps, and I don’t know where these should go in a LightningDataModule.
Can you please tell me if this is possible in this particular case?
If so, what would be the advantages of converting to a LightningDataModule?
The advantage I see in how I am doing it now: I can index my dataset object to get back a pre-processed audio excerpt (a spectrogram image in my case). But that may also be feasible with a LightningDataModule, I don’t know.
The downsides I see to how I am doing it now: I do it in several steps whereas a LightningDataModule seems to wrap up the data setup and dataloaders creation.
Please help me on this one.
Cheers
Antoine