Multi-Fidelity Hyperparameter Tuning

In our example above, a transformer language model is trained for 40 epochs before being validated. If a configuration performs poorly, we would like to find this out earlier, and a lot of time can be saved by stopping poorly performing trials early. This is what multi-fidelity HPO methods do. There are different variants:

  • Early stopping (“stopping” type): Trials are not just validated after 40 epochs, but at the end of every epoch. If a trial is performing worse than many others trained for the same number of epochs, it is stopped early.

  • Pause and resume (“promotion” type): Trials are generally paused at the end of certain epochs, called rungs. A paused trial gets promoted (i.e., its training is resumed) if it does better than a majority of the trials that reached the same rung.
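The two decision rules can be sketched in plain Python. This is a simplified illustration, not Syne Tune's implementation: the stopping rule compares a trial against the median of other trials at the same epoch (a median-stopping style rule), and the promotion rule promotes the paused trials ranked in the top 1/reduction_factor at their rung, in the spirit of successive halving. The function names and the reduction_factor parameter are assumptions made for this example.

```python
from statistics import median
from typing import Dict, List


def should_stop(loss: float, peer_losses: List[float]) -> bool:
    """Stopping type (hypothetical rule): stop a trial whose validation loss
    at this epoch is worse (higher) than the median loss of the other trials
    at the same epoch."""
    if not peer_losses:
        return False  # nothing to compare against yet
    return loss > median(peer_losses)


def promotable(rung_results: Dict[str, float], reduction_factor: int = 3) -> List[str]:
    """Promotion type (hypothetical rule): among trials paused at a rung,
    return those ranked in the top 1 / reduction_factor by validation loss
    (lower is better)."""
    ranked = sorted(rung_results, key=lambda trial: rung_results[trial])
    num_promote = len(ranked) // reduction_factor
    return ranked[:num_promote]


print(should_stop(5.1, [4.8, 4.9, 5.3]))  # True (median of peers is 4.9)
print(promotable({"a": 5.0, "b": 4.2, "c": 4.7, "d": 6.1, "e": 5.5, "f": 4.9}))  # ['b', 'c']
```

The practical difference: should_stop is applied once per report and its "stop" is final, whereas trials not returned by promotable simply stay paused and could still be promoted later as more results arrive at the rung.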

Syne Tune provides a large number of multi-fidelity HPO methods; more details are given in this tutorial. In this section, you will learn what needs to be done to support multi-fidelity hyperparameter tuning.

Annotating a Training Script for Multi-fidelity Tuning

Clearly, the training script training_script_report_end.py won’t do for multi-fidelity tuning. These methods need to know the validation error of a model after each epoch of training, while the script above only validates the model at the end, after 40 epochs of training. A small modification of our training script, training_script_no_checkpoints.py, enables multi-fidelity tuning. The relevant part is this:

transformer_wikitext2/code/training_script_no_checkpoints.py – objective
def objective(config):
    torch.manual_seed(config["seed"])
    use_cuda = config["use_cuda"]
    if torch.cuda.is_available() and not use_cuda:
        print("WARNING: You have a CUDA device, so you should run with --use-cuda 1")
    device = torch.device("cuda" if use_cuda else "cpu")
    # Download data, setup data loaders
    corpus = download_dataset(config)
    ntokens = len(corpus.dictionary)
    train_data = batchify(corpus.train, bsz=config["batch_size"], device=device)
    valid_data = batchify(corpus.valid, bsz=10, device=device)
    # Used for reporting metrics to Syne Tune
    report = Reporter()
    # Create model and optimizer
    model, optimizer, criterion = create_training_objects(config, ntokens, device)

    for epoch in range(1, config[MAX_RESOURCE_ATTR] + 1):
        train(model, train_data, optimizer, criterion, config, ntokens, epoch)
        val_loss = evaluate(model, valid_data, criterion, config, ntokens)
        print("-" * 89)
        print(
            f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | "
            f"valid ppl {np.exp(val_loss):8.2f}"
        )
        print("-" * 89)
        # Report validation loss back to Syne Tune
        report(**{RESOURCE_ATTR: epoch, METRIC_NAME: val_loss})

Instead of calling report only once, at the end, we evaluate the model and report back at the end of each epoch. We also need to report the number of epochs done, using RESOURCE_ATTR as key. The execution backend receives these reports and relays them to the HPO method, which in turn decides whether the trial may continue or should be stopped.
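To make this reporting contract concrete, here is a self-contained mock. The ListReporter class and the should_continue callback are hypothetical stand-ins for Syne Tune's Reporter and for the backend plus HPO method, and the keys "epoch" and "val_loss" are assumed values for RESOURCE_ATTR and METRIC_NAME. The sketch only shows the shape of the protocol: one report per epoch, each carrying the resource level, followed by a continue-or-stop decision.

```python
from typing import Callable, Dict, List

RESOURCE_ATTR = "epoch"   # assumed key for the number of epochs done
METRIC_NAME = "val_loss"  # assumed key for the validation metric


class ListReporter:
    """Hypothetical stand-in for a reporter: stores every report in a list."""

    def __init__(self) -> None:
        self.reports: List[Dict[str, float]] = []

    def __call__(self, **kwargs: float) -> None:
        self.reports.append(dict(kwargs))


def run_trial(
    losses: List[float],
    report: ListReporter,
    should_continue: Callable[[Dict[str, float]], bool],
) -> int:
    """Simulated training loop: losses[i] is the validation loss after
    epoch i + 1. Returns the number of epochs actually trained."""
    for epoch, val_loss in enumerate(losses, start=1):
        # Report after every epoch, including the resource level
        report(**{RESOURCE_ATTR: epoch, METRIC_NAME: val_loss})
        # Toy decision step: in Syne Tune, the execution backend relays the
        # report to the HPO method and enforces the decision itself.
        if not should_continue(report.reports[-1]):
            return epoch
    return len(losses)


reporter = ListReporter()
# Toy rule: stop if the loss is still above 6.0 after two or more epochs
epochs = run_trial(
    [7.0, 6.5, 6.2, 6.0],
    reporter,
    lambda r: not (r["epoch"] >= 2 and r["val_loss"] > 6.0),
)
print(epochs)  # 2
```

With a single report at the end (as in training_script_report_end.py), the decision step could only run after the last epoch, which is why per-epoch reporting is a prerequisite for multi-fidelity methods.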

Checkpointing

Instead of stopping underperforming trials, some multi-fidelity methods pause them. Any paused trial can be resumed later if there is evidence that it outperforms the majority of other trials. If training is very expensive, pause-and-resume scheduling can work better than early stopping, because a pause decision can be revisited later, while a stopping decision is final. Moreover, pause-and-resume scheduling does not require trials to be stopped, which can incur delays in some execution backends.

However, pause-and-resume scheduling needs checkpointing in order to work well. When a trial is paused, its mutable state is stored on disk. When the trial is resumed, this state is loaded from disk, and training continues exactly where it left off.
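A generic sketch of this store-and-load cycle, independent of Syne Tune, assuming the mutable state fits in a picklable dict. The file name checkpoint.pkl and the helper names are hypothetical. Writing to a temporary file and then renaming it makes the save atomic, so a trial interrupted mid-write does not leave a corrupt checkpoint behind.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Optional

CHECKPOINT_FILE = "checkpoint.pkl"  # hypothetical file name


def save_checkpoint(checkpoint_dir: str, state: Dict[str, Any]) -> None:
    """Store state in checkpoint_dir atomically (temp file, then rename)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Temp file in the same directory, so os.replace stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=checkpoint_dir)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, os.path.join(checkpoint_dir, CHECKPOINT_FILE))


def load_checkpoint(checkpoint_dir: str) -> Optional[Dict[str, Any]]:
    """Return the stored state, or None if no checkpoint exists yet."""
    path = os.path.join(checkpoint_dir, CHECKPOINT_FILE)
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as d:
    print(load_checkpoint(d))  # None: a fresh trial starts from scratch
    save_checkpoint(d, {"epoch": 3, "weights": [0.1, 0.2]})  # trial paused
    print(load_checkpoint(d))  # {'epoch': 3, 'weights': [0.1, 0.2]}
```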

Checkpointing needs to be implemented as part of the training script. Fortunately, Syne Tune provides some tooling to simplify this. Another modification of our training script, training_script.py, enables checkpointing. The relevant part is this:

transformer_wikitext2/code/training_script.py – objective
def objective(config):
    torch.manual_seed(config["seed"])
    use_cuda = config["use_cuda"]
    if torch.cuda.is_available() and not use_cuda:
        print("WARNING: You have a CUDA device, so you should run with --use-cuda 1")
    device = torch.device("cuda" if use_cuda else "cpu")
    # Download data, setup data loaders
    corpus = download_dataset(config)
    ntokens = len(corpus.dictionary)
    train_data = batchify(corpus.train, bsz=config["batch_size"], device=device)
    valid_data = batchify(corpus.valid, bsz=10, device=device)
    # Used for reporting metrics to Syne Tune
    report = Reporter()
    # Create model and optimizer
    model, optimizer, criterion = create_training_objects(config, ntokens, device)
    # [3]
    # Checkpointing
    state_dict_objects = {
        "model": model,
        "optimizer": optimizer,
    }
    if config["precision"] == "half":
        state_dict_objects["amp"] = amp
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        state_dict_objects=state_dict_objects,
    )
    # [2]
    # Resume from checkpoint
    resume_from = resume_from_checkpointed_model(config, load_model_fn)

    for epoch in range(resume_from + 1, config[MAX_RESOURCE_ATTR] + 1):
        train(model, train_data, optimizer, criterion, config, ntokens, epoch)
        val_loss = evaluate(model, valid_data, criterion, config, ntokens)
        print("-" * 89)
        print(
            f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | "
            f"valid ppl {np.exp(val_loss):8.2f}"
        )
        print("-" * 89)
        # [1]
        # Write checkpoint
        checkpoint_model_at_rung_level(config, save_model_fn, epoch)
        # Report validation loss back to Syne Tune
        report(**{RESOURCE_ATTR: epoch, METRIC_NAME: val_loss})

Full details about supporting checkpointing are given in this tutorial. In a nutshell:

  • [1] Checkpoints have to be written at the end of each epoch, to a path passed as a command line argument. A checkpoint needs to include the epoch number at which it was written.

  • [2] Before the training loop starts, a checkpoint should be loaded from the same place. If one is found, the training loop skips all epochs already done. If not, it starts from scratch as usual.

  • [3] Syne Tune provides some checkpointing tooling for PyTorch models.
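Steps [1] and [2] can be illustrated with a toy training loop. This is a sketch under stated assumptions, not Syne Tune code: the in-memory store dict stands in for the checkpoint path passed on the command line, a float stands in for the model state, and the hypothetical pause_after argument simulates the scheduler pausing the trial.

```python
from typing import Dict, List, Optional


def run_training(
    max_epochs: int,
    store: Dict[str, dict],
    pause_after: Optional[int] = None,
) -> List[int]:
    """Toy training script; returns the epochs trained in this invocation."""
    # [2] Before the loop: resume from a checkpoint if one is found
    checkpoint = store.get("checkpoint")
    resume_from = checkpoint["epoch"] if checkpoint else 0
    weight = checkpoint["weight"] if checkpoint else 0.0  # toy model state
    trained = []
    for epoch in range(resume_from + 1, max_epochs + 1):
        weight += 1.0  # stand-in for one epoch of training
        trained.append(epoch)
        # [1] End of epoch: write a checkpoint that records the epoch number
        store["checkpoint"] = {"epoch": epoch, "weight": weight}
        if pause_after is not None and epoch == pause_after:
            break  # simulated pause by the scheduler
    return trained


store: Dict[str, dict] = {}
print(run_training(5, store, pause_after=3))  # [1, 2, 3]
print(run_training(5, store))  # [4, 5]: epochs 1 to 3 are skipped
print(store["checkpoint"])  # {'epoch': 5, 'weight': 5.0}
```

Because the checkpoint records the epoch number, the resumed invocation knows both which epochs to skip and which state to continue from; this mirrors the resume_from variable in training_script.py.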

At this point, we have the final version of our training script, training_script.py, which can be used with all HPO methods in Syne Tune. While the earlier versions are simpler to implement, we recommend including reporting and checkpointing after every epoch in any training script you care about. If checkpoints become very large, you may run into problems with disk space, which can be dealt with as described here.

Note

The pause-and-resume HPO methods in Syne Tune also work if checkpointing is not implemented. However, a resumed trial then restarts training from scratch. This additional overhead makes these methods less attractive. We strongly recommend implementing checkpointing.
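To put rough numbers on this overhead, consider a trial that is paused at epochs 1, 3 and 9 and finally trained to 27 epochs. These rung levels are hypothetical, chosen as powers of 3 as in successive halving. The sketch below counts the epochs trained, assuming every epoch costs the same and ignoring checkpoint I/O.

```python
from typing import List


def epochs_trained(rung_levels: List[int], max_epochs: int, checkpointing: bool) -> int:
    """Total epochs trained by a trial that is paused at each rung level and
    eventually resumed up to max_epochs (equal cost per epoch assumed)."""
    total, start = 0, 0
    for level in sorted(rung_levels) + [max_epochs]:
        # With a checkpoint, training resumes at `start`; without, at epoch 0
        total += level - (start if checkpointing else 0)
        start = level
    return total


print(epochs_trained([1, 3, 9], 27, checkpointing=True))   # 27
print(epochs_trained([1, 3, 9], 27, checkpointing=False))  # 40
```

In this example, the trial trains 13 extra epochs without checkpointing, roughly 48% more compute, and the overhead grows with the number of rungs at which it is paused.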