Guidance on accessing DataModule Vocab or tokenisers during inference

Hi.

I’m currently implementing a toy Seq2Seq model for German -> English translation (data: Multi30k). The code works, but I’m having trouble defining an inference function that uses a model loaded from a checkpoint.

During training (at the end of an epoch) I call a function, translate(), that translates a test sentence and logs the result to W&B. This function uses tokenizers and vocabularies that are created and stored in the DataModule. That works in this context because the trainer is still alive and holds a reference to the DataModule at that point.

This is the setup of my DataModule:

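    # Assumes torchtext 0.9-0.11, where Field/BucketIterator/Multi30k live in torchtext.legacy
    import spacy
    from torchtext.legacy.data import BucketIterator, Field
    from torchtext.legacy.datasets import Multi30k

    def setup(self, stage=None):
        # spaCy v3 removed the "de"/"en" shortcut links; load the packaged models by name
        self.spacy_de = spacy.load("de_core_news_sm")
        self.spacy_en = spacy.load("en_core_web_sm")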

        self.german = Field(
            tokenize=self._tokenize_de,
            lower=True,
            init_token="<sos>",
            eos_token="<eos>",
        )

        self.english = Field(
            tokenize=self._tokenize_en,
            lower=True,
            init_token="<sos>",
            eos_token="<eos>",
        )

        multi30k_train, multi30k_val, multi30k_test = Multi30k.splits(
            exts=(".de", ".en"), fields=(self.german, self.english)
        )

        self.german.build_vocab(multi30k_train, max_size=10000, min_freq=2)
        self.english.build_vocab(multi30k_train, max_size=10000, min_freq=2)

        self.train_it, self.val_it, self.test_it = BucketIterator.splits(
            (multi30k_train, multi30k_val, multi30k_test),
            batch_size=self.batch_size,
            sort_within_batch=True,
            sort_key=lambda x: len(x.src),
        )
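
What setup() ultimately has to hand over to inference is a pair of frequency-filtered token-to-index tables: build_vocab(max_size=10000, min_freq=2) keeps tokens seen at least min_freq times, most frequent first, up to max_size. A minimal stdlib sketch of such a table, which shows it is plain data that does not need to live inside the DataModule. The build_vocab function and SPECIALS list below are hypothetical, and the ordering (specials first and not counted toward max_size, frequency ties broken alphabetically) is an approximation of torchtext's legacy Vocab, not its actual code:

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple

# Hypothetical special tokens, in the order a legacy Field with
# init_token="<sos>" and eos_token="<eos>" would (approximately) add them.
SPECIALS = ["<unk>", "<pad>", "<sos>", "<eos>"]


def build_vocab(
    tokenized_sentences: Iterable[List[str]],
    max_size: int = 10000,
    min_freq: int = 2,
) -> Tuple[List[str], Dict[str, int]]:
    """Return (itos, stoi): index-to-token list and token-to-index dict."""
    counts = Counter(tok for sent in tokenized_sentences for tok in sent)
    # Keep tokens seen at least min_freq times; specials are added separately.
    kept = [tok for tok, c in counts.items() if c >= min_freq and tok not in SPECIALS]
    # Most frequent first, ties broken alphabetically.
    kept.sort(key=lambda tok: (-counts[tok], tok))
    # Specials do not count toward max_size.
    itos = SPECIALS + kept[:max_size]
    stoi = {tok: i for i, tok in enumerate(itos)}
    return itos, stoi


corpus = [
    ["ein", "hund", "läuft"],
    ["ein", "mann", "läuft"],
    ["eine", "frau", "singt"],
]
itos, stoi = build_vocab(corpus, max_size=10000, min_freq=2)
print(itos)  # ['<unk>', '<pad>', '<sos>', '<eos>', 'ein', 'läuft']
```

Because itos is just a list of strings, it can be saved next to a checkpoint (or handed to the model) and the same stoi rebuilt at inference time.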

This is a chunk of the translate() function that uses objects from the DataModule (there are more calls like this):

    # example
    tokens = [token.text.lower() for token in model.trainer.datamodule.spacy_de(sentence)]
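
One way to remove this dependency is to pass the tokenizer and the string-to-index mapping into the function explicitly, so the same code runs during training (where both can come from the DataModule) and after loading a checkpoint. A minimal sketch, assuming a stoi dict like the one a torchtext vocabulary exposes and using a whitespace splitter as a stand-in for spaCy; encode_sentence and whitespace_tokenize are hypothetical names:

```python
from typing import Callable, Dict, List


def encode_sentence(
    sentence: str,
    tokenize: Callable[[str], List[str]],
    stoi: Dict[str, int],
    sos: str = "<sos>",
    eos: str = "<eos>",
    unk: str = "<unk>",
) -> List[int]:
    """Tokenize, lower-case, wrap in <sos>/<eos>, and map tokens to indices.

    Nothing here touches model.trainer or a DataModule, so it works
    equally well on a model restored from a checkpoint.
    """
    tokens = [sos] + [tok.lower() for tok in tokenize(sentence)] + [eos]
    # Unknown tokens fall back to the <unk> index.
    return [stoi.get(tok, stoi[unk]) for tok in tokens]


def whitespace_tokenize(text: str) -> List[str]:
    # Stand-in for spaCy; real code might pass
    # lambda s: [t.text for t in spacy_de(s)] instead (assumption).
    return text.split()


stoi = {"<unk>": 0, "<pad>": 1, "<sos>": 2, "<eos>": 3, "ein": 4, "hund": 5}
print(encode_sentence("Ein Hund bellt", whitespace_tokenize, stoi))  # [2, 4, 5, 0, 3]
```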

When I call translate() on a model loaded from a checkpoint, I get an error because model.trainer.datamodule (which I use to reach the tokenizers) is None. I suppose this is a weak reference to the DataModule that isn’t part of my checkpoint, so I’m wondering how people usually split this workload. I guess it’s messy to use things from the DataModule like this?

Are you defining tokenizers and vocabularies outside of the data modules?

Cheers,
C

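Maybe you can pass the tokenizers (here the Field objects, which also hold the vocabularies) into the model's __init__? save_hyperparameters() records the __init__ arguments, so they will be saved in the model checkpoint and passed back to __init__ when you call MyModule.load_from_checkpoint():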

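    import pytorch_lightning as pl

    class MyModule(pl.LightningModule):
        def __init__(self, german, english):
            super().__init__()
            # Records german/english in self.hparams and in every checkpoint
            self.save_hyperparameters()
            self.german = german
            self.english = english

    dm = MyDataModule()
    dm.setup()  # the Fields only exist after setup() has run
    model = MyModule(dm.german, dm.english)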
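
This works because the init arguments travel with the checkpoint and __init__ is called again with them on load. A stdlib-only analogy of that round trip (not Lightning's implementation; TinyTranslator, save and load are hypothetical), assuming you store only the vocabulary lists, for example dm.german.vocab.itos, rather than the Field objects:

```python
import os
import pickle
import tempfile
from typing import Dict, List


class TinyTranslator:
    """Hypothetical stand-in for a LightningModule that owns its vocabularies
    instead of reaching into trainer.datamodule."""

    def __init__(self, src_itos: List[str], trg_itos: List[str]):
        # Analogous to self.save_hyperparameters(): remember the init arguments.
        self.hparams: Dict[str, List[str]] = {"src_itos": src_itos, "trg_itos": trg_itos}
        self.src_stoi = {tok: i for i, tok in enumerate(src_itos)}
        self.trg_itos = trg_itos

    def save(self, path: str) -> None:
        # Analogous to the "hyper_parameters" entry of a checkpoint.
        with open(path, "wb") as f:
            pickle.dump({"hyper_parameters": self.hparams}, f)

    @classmethod
    def load(cls, path: str) -> "TinyTranslator":
        # Analogous to load_from_checkpoint(): call __init__ again with the stored args.
        with open(path, "rb") as f:
            ckpt = pickle.load(f)
        return cls(**ckpt["hyper_parameters"])


src_itos = ["<unk>", "<pad>", "<sos>", "<eos>", "ein", "hund"]
trg_itos = ["<unk>", "<pad>", "<sos>", "<eos>", "a", "dog"]
model = TinyTranslator(src_itos, trg_itos)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.ckpt")
    model.save(path)
    restored = TinyTranslator.load(path)  # no DataModule or trainer involved

print(restored.src_stoi["hund"], restored.trg_itos[5])  # 5 dog
```

If you pass the Field objects themselves, as in the snippet above, everything they reference has to be picklable. In the DataModule shown in the question, tokenize=self._tokenize_de is a bound method, and pickling a bound method also pickles the object it is bound to, so the whole DataModule (spaCy pipelines included) would likely end up in the checkpoint. Passing only the vocabularies and re-creating the spaCy tokenizer at load time avoids that.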