Turn off ddp_sharded during evaluation

Hi there,

I am using ddp_sharded with FairScale, and it works fine during training with the Lightning Trainer. However, I found that during evaluation on the dev/test set, ddp_sharded is still active, i.e. the dataset is split into shards and each shard is evaluated separately, which makes it difficult to compute evaluation metrics (e.g. accuracy) or to use early stopping. Is there any way to keep ddp_sharded during training but turn it off for evaluation, so that the dev/test set is evaluated on a single GPU?
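
For reference, ddp_sharded is enabled through the Trainer plugins, roughly like this (a minimal sketch; gpus, max_epochs and the dataloader names are placeholders, not my real settings):

    trainer = pl.Trainer(
        gpus=4,
        accelerator="ddp",
        plugins="ddp_sharded",   # FairScale sharded DDP
        max_epochs=10,
    )
    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_loader)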

Here is a code snippet of my LightningModule; the model it wraps is a simple PyTorch classifier using Hugging Face Transformers.

import itertools

import torch
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score


class CLSTrainerINTCLS(pl.LightningModule):
    def __init__(self, args, model):
        super().__init__()
        self.args = args
        self.model = model

def training_step(self, batch, batch_idx):
    output_dicts = self.model(**batch)
    preds = torch.argmax(output_dicts["logits"], dim=-1)
    return output_dicts["loss"]

def validation_step(self, batch, batch_idx):
    return self.eval_step(batch, batch_idx, split="dev")

def test_step(self, batch, batch_idx):
    return self.eval_step(batch, batch_idx, split="test")

def eval_step(self, batch, batch_idx, split=None):
    output_dicts = self.model(**batch)
    preds = torch.argmax(output_dicts["logits"], dim=-1)
    return {
        f"{split}_loss": output_dicts["loss"].detach().cpu(),
        f"{split}_gold": batch["labels"].tolist(),
        f"{split}_pred": preds.tolist(),
    }

def validation_epoch_end(self, val_step_outputs):
    return self.eval_epoch_end(val_step_outputs, split="dev")

def test_epoch_end(self, test_step_outputs):
    return self.eval_epoch_end(test_step_outputs, split="test")

def eval_epoch_end(self, eval_step_outputs, split=None):
    loss = torch.mean(
        torch.stack([t[f"{split}_loss"] for t in eval_step_outputs])
    ).detach()
    golds = list(itertools.chain(*[t[f"{split}_gold"] for t in eval_step_outputs]))
    preds = list(itertools.chain(*[t[f"{split}_pred"] for t in eval_step_outputs]))
    acc = accuracy_score(golds, preds) * 100
    self.log(f"{split}_acc", acc, prog_bar=False)

pytorch: v1.8.1
pytorch lightning: v1.3.8
fairscale: v0.3.8
transformers: v4.6.1