From PyTorch's custom Dataset to Lightning's custom DataModule

Hi folks,

I am attempting the UrbanSound8K Kaggle challenge.
I am training a CNN on spectral representations of audio excerpts.
For that, I use PyTorch Lightning.

Please find all my scripts for this challenge on my GitHub: UrbanSound8K

The ONLY thing that doesn’t follow Lightning’s philosophy is… my dataset.

At the moment, I create a custom PyTorch dataset class that inherits from Dataset, instantiate it, and then create my data loaders outside of it.

Below is part of the custom dataset class (contained in dataset.py):

import os

import pandas as pd
import torch
import torch.nn as nn
import torchaudio
from torchaudio import transforms
from torch.utils.data import Dataset

class UrbanSound8KDataset(Dataset):
    def __init__(self, dataset_dir, transforms_params, device):
        self.device = device
        self.dataset_dir = dataset_dir
        self.metadata = pd.read_csv(os.path.join(dataset_dir, "UrbanSound8K.csv"))
        self.n_folds = max(self.metadata["fold"])
        self.n_classes = len(self.metadata["class"].unique())
        self.classes_map = pd.Series(self.metadata["class"].values,index=self.metadata["classID"]).sort_index().to_dict()
        self.target_sample_rate = transforms_params["target_sample_rate"]
        self.target_length = transforms_params["target_length"]
        self.n_samples = transforms_params["n_samples"]
        self.n_fft = transforms_params["n_fft"]
        self.n_mels = transforms_params["n_mels"]
    def __len__(self):
        return len(self.metadata)
    def __getitem__(self, index):
        audio_name = self._get_event_audio_name(index)
        class_id = torch.tensor(self._get_event_class_id(index), dtype=torch.long)
        signal, sr = self._get_event_signal(index)
        signal = signal.to(self.device)
        signal = self._mix_down_if_necessary(signal)
        signal = self._resample_if_necessary(signal, sr)
        signal = self._cut_if_necessary(signal)
        signal = self._right_pad_if_necessary(signal)
        mel_spectrogram = self._mel_spectrogram_transform(signal)
        mel_spectrogram_db = self._db_transform(mel_spectrogram)
        return index, audio_name, class_id, mel_spectrogram_db
    def _get_event_class_id(self, index):
        return self.metadata.iloc[index]["classID"]
    def _get_event_audio_name(self, index):
        return self.metadata.iloc[index]["slice_file_name"]
    def _get_event_signal(self, index):
        event_fold = f"fold{self.metadata.iloc[index]['fold']}"
        event_filename = self.metadata.iloc[index]["slice_file_name"]
        audio_path = os.path.join(self.dataset_dir, event_fold, event_filename)
        signal, sr = torchaudio.load(audio_path, normalize=True)
        return signal, sr
    def _mix_down_if_necessary(self, signal):
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal
    def _resample_if_necessary(self, signal, sr):
        if sr != self.target_sample_rate:
            resample_transform = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            resample_transform = resample_transform.to(self.device)
            signal = resample_transform(signal)
        return signal
    def _cut_if_necessary(self, signal):
        if signal.shape[1] > self.n_samples:
            signal = signal[:, :self.n_samples]
        return signal
    def _right_pad_if_necessary(self, signal):
        signal_length = signal.shape[1]
        if signal_length < self.n_samples:
            num_missing_samples = self.n_samples - signal_length
            last_dim_padding = (0, num_missing_samples)
            signal = nn.functional.pad(signal, last_dim_padding)
        return signal
    def _spectrogram_transform(self, signal):
        spectrogram_transform = transforms.Spectrogram(
                                                        n_fft = self.n_fft,
                                                        win_length = self.n_fft,
                                                        hop_length = self.n_fft // 2,
                                                        pad = 0,
                                                        window_fn = torch.hann_window,
                                                        power = 2,
                                                        normalized = True,
                                                        wkwargs = None,
                                                        center = False,
                                                        pad_mode = "reflect",
                                                        onesided = True,
                                                        return_complex = False)
        spectrogram_transform = spectrogram_transform.to(self.device)
        spectrogram = spectrogram_transform(signal)
        return spectrogram
    def _mel_spectrogram_transform(self, signal):
        mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
                                                        sample_rate = self.target_sample_rate,
                                                        n_fft = self.n_fft,
                                                        n_mels = self.n_mels,
                                                        window_fn = torch.hann_window,
                                                        power = 2,
                                                        normalized = True,
                                                        wkwargs = None,
                                                        center = True,
                                                        pad_mode = "reflect",
                                                        onesided = True,
                                                        norm = None,
                                                        mel_scale = "htk")
        mel_spectrogram_transform = mel_spectrogram_transform.to(self.device)
        mel_spectrogram = mel_spectrogram_transform(signal)
        return mel_spectrogram
    def _db_transform(self, mel_spectrogram):
        db_transform = torchaudio.transforms.AmplitudeToDB(stype="power")
        db_transform = db_transform.to(self.device)
        mel_spectrogram_db = db_transform(mel_spectrogram)
        return mel_spectrogram_db
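
As a quick sanity check of the _cut_if_necessary / _right_pad_if_necessary logic, here is a tiny standalone snippet (the n_samples value is made up for illustration):

```python
import torch
import torch.nn.functional as F

n_samples = 8  # made-up target length, just for this demo

short = torch.ones(1, 5)     # mono signal shorter than the target
too_long = torch.ones(1, 12) # mono signal longer than the target

# Right-pad the last dimension, as in _right_pad_if_necessary:
# (0, k) means "0 samples on the left, k samples on the right"
padded = F.pad(short, (0, n_samples - short.shape[1]))
# Truncate, as in _cut_if_necessary
cut = too_long[:, :n_samples]

print(padded.shape)  # torch.Size([1, 8])
print(cut.shape)     # torch.Size([1, 8])
```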

Then, in my training script, I instantiate an object of this class like below (the class lives in dataset.py, hence the dataset. prefix).

ds = dataset.UrbanSound8KDataset("dataset", transforms_params, args.device)

And finally, I create the data loaders.

        # Get the train and validation sets (fold i is held out for validation)
        train_metadata = ds.metadata.drop(ds.metadata[ds.metadata["fold"] == i].index)
        train_indices = train_metadata.index.to_list()
        train_sampler = SubsetRandomSampler(train_indices)
        validation_indices = ds.metadata[ds.metadata["fold"] == i].index.to_list()
        validation_sampler = SubsetRandomSampler(validation_indices)
        # Create the train and validation dataloaders (batch_size comes from my config)
        train_dataloader = DataLoader(ds, batch_size=batch_size, sampler=train_sampler)
        validation_dataloader = DataLoader(ds, batch_size=batch_size, sampler=validation_sampler)
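
To make the fold split above concrete, here is a minimal standalone version with a made-up metadata table (the real one comes from UrbanSound8K.csv):

```python
import pandas as pd

# Tiny stand-in for the UrbanSound8K metadata: 5 clips across 3 folds
metadata = pd.DataFrame({"fold": [1, 1, 2, 3, 2], "classID": [0, 1, 0, 2, 1]})
i = 2  # fold held out for validation in this cross-validation iteration

# Same logic as above: drop the validation fold's rows to get the train set
train_metadata = metadata.drop(metadata[metadata["fold"] == i].index)
train_indices = train_metadata.index.to_list()
validation_indices = metadata[metadata["fold"] == i].index.to_list()

print(train_indices)       # [0, 1, 3]
print(validation_indices)  # [2, 4]
```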

What I would like to do now is turn all this into a LightningDataModule.

I have used one already for CIFAR-100, but that was much less complex: the images were loaded all at once and little pre-processing was done. Here I have a lot of audio files and many pre-processing steps, and I don't know where they should go in a LightningDataModule.

Can you please tell me if this is possible in this particular case?
If so, what would be the advantages of converting to a LightningDataModule?

The advantages I see to how I am doing it now: I can index my dataset object to get back a pre-processed audio excerpt (a spectrogram image in my case). But that may also be feasible with a LightningDataModule, I don't know.

The downsides I see to how I am doing it now: I do it in several steps, whereas a LightningDataModule seems to wrap up the data setup and dataloader creation in one place.

Please help me on this one.