ResNet-18 Trained on CIFAR-10

benchmarking/training_scripts/resnet_cifar10/resnet_cifar10.py
import os
import argparse
import logging
import time
from pathlib import Path

try:
    # Benchmark-specific imports are done here, in order to avoid import
    # errors if the dependencies are not installed (such errors should only
    # surface when the script is actually run)
    from filelock import SoftFileLock, Timeout
    import numpy as np
    from tqdm import tqdm
    import torch
    import torch.nn.functional as F
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision import datasets, transforms
    from torchvision.models import resnet18
except ImportError:
    logging.info(
        f"Please install benchmark-specific dependencies ({Path(__file__).parent / 'requirements.txt'})"
    )

from syne_tune import Reporter
from syne_tune.config_space import randint, uniform, loguniform, add_to_argparse
from syne_tune.utils import (
    resume_from_checkpointed_model,
    checkpoint_model_at_rung_level,
    add_checkpointing_to_argparse,
    pytorch_load_save_functions,
)


BATCH_SIZE_LOWER = 8

BATCH_SIZE_UPPER = 256

BATCH_SIZE_KEY = "batch_size"

METRIC_NAME = "objective"

RESOURCE_ATTR = "epoch"

MAX_RESOURCE_ATTR = "epochs"

ELAPSED_TIME_ATTR = "elapsed_time"


_config_space = {
    BATCH_SIZE_KEY: randint(BATCH_SIZE_LOWER, BATCH_SIZE_UPPER),
    "momentum": uniform(0, 0.99),
    "weight_decay": loguniform(1e-5, 1e-3),
    "lr": loguniform(1e-3, 0.1),
}


# ATTENTION: train_dataset, valid_dataset are both based on the CIFAR10
# training set, but train_dataset uses data augmentation. Make sure to
# only use disjoint parts for training and validation further down.
def get_CIFAR10(root):
    input_size = 32
    num_classes = 10
    normalize = [(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)]
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(*normalize),
        ]
    )
    local_path = os.path.join(root, "CIFAR10")
    train_dataset = datasets.CIFAR10(
        local_path, train=True, transform=train_transform, download=True
    )

    valid_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(*normalize),
        ]
    )
    valid_dataset = datasets.CIFAR10(
        local_path, train=True, transform=valid_transform, download=True
    )

    return input_size, num_classes, train_dataset, valid_dataset


def train(model, train_loader, optimizer):
    model.train()
    total_loss = []
    for data, target in tqdm(train_loader):
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        prediction = model(data)
        loss = F.nll_loss(prediction, target)
        loss.backward()
        optimizer.step()
        total_loss.append(loss.item())
    avg_loss = sum(total_loss) / len(total_loss)
    return avg_loss


def valid(model, valid_loader):
    model.eval()
    loss = 0
    correct = 0
    for data, target in valid_loader:
        with torch.no_grad():
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            prediction = model(data)
            loss += F.nll_loss(prediction, target, reduction="sum")
            prediction = prediction.max(1)[1]
            correct += prediction.eq(target.view_as(prediction)).sum().item()
    n_valid = len(valid_loader.sampler)
    loss /= n_valid
    accuracy = correct / n_valid
    return loss, accuracy


def _download_data(config):
    path = config["dataset_path"]
    os.makedirs(path, exist_ok=True)
    # Lock protection is needed for backends which run multiple worker
    # processes on the same instance
    lock_path = os.path.join(path, "lock")
    lock = SoftFileLock(lock_path)
    try:
        with lock.acquire(timeout=120, poll_intervall=1):
            input_size, num_classes, train_dataset, valid_dataset = get_CIFAR10(
                root=path
            )
    except Timeout:
        print(
            "WARNING: Could not obtain lock for dataset files. Trying anyway...",
            flush=True,
        )
        input_size, num_classes, train_dataset, valid_dataset = get_CIFAR10(root=path)
    return input_size, num_classes, train_dataset, valid_dataset


def _create_data_loaders(config, train_dataset, valid_dataset):
    indices = list(range(train_dataset.data.shape[0]))
    train_idx, valid_idx = indices[:40000], indices[40000:]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config[BATCH_SIZE_KEY],
        # shuffle=True,
        num_workers=0,
        sampler=train_sampler,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=128,
        # shuffle=False,
        num_workers=0,
        sampler=valid_sampler,
        pin_memory=True,
    )
    return train_loader, valid_loader


def _create_training_objects(config):
    model = Model()
    if torch.cuda.is_available():
        model = model.cuda()
        device = torch.device("cuda")
        model = torch.nn.DataParallel(
            model, device_ids=[i for i in range(config["num_gpus"])]
        ).to(device)
    milestones = [25, 40]
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.1
    )
    return model, optimizer, scheduler


def objective(config):
    torch.manual_seed(np.random.randint(10000))
    # Download data, setup data loaders
    input_size, num_classes, train_dataset, valid_dataset = _download_data(config)
    train_loader, valid_loader = _create_data_loaders(
        config, train_dataset, valid_dataset
    )
    # Do not want to count the time to download the dataset, which can be
    # substantial the first time
    ts_start = time.time()
    report = Reporter()
    # Create model, optimizer, LR scheduler
    model, optimizer, scheduler = _create_training_objects(config)
    # Checkpointing for PyTorch model
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        {"model": model, "optimizer": optimizer, "lr_scheduler": scheduler}
    )
    # Resume from checkpoint (optional)
    resume_from = resume_from_checkpointed_model(config, load_model_fn)

    for epoch in range(resume_from + 1, config[MAX_RESOURCE_ATTR] + 1):
        train(model, train_loader, optimizer)
        scheduler.step()
        elapsed_time = time.time() - ts_start
        # Write checkpoint (optional)
        checkpoint_model_at_rung_level(config, save_model_fn, epoch)
        # Evaluate and send metrics back to Syne Tune
        _, accuracy = valid(model, valid_loader)
        report(
            **{
                RESOURCE_ATTR: epoch,
                METRIC_NAME: accuracy,
                ELAPSED_TIME_ATTR: elapsed_time,
            }
        )


if __name__ == "__main__":
    # Superclass reference torch.nn.Module requires torch to be defined
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.resnet = resnet18(pretrained=False, num_classes=10)
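            # CIFAR-10 images are only 32x32, so the ImageNet stem (7x7 stride-2
            # convolution followed by max pooling) is replaced by a 3x3 stride-1
            # convolution, and the max pooling layer is disabled below.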
            self.resnet.conv1 = torch.nn.Conv2d(
                3, 64, kernel_size=3, stride=1, padding=1, bias=False
            )
            self.resnet.maxpool = torch.nn.Identity()

        def forward(self, x):
            x = self.resnet(x)
            x = F.log_softmax(x, dim=1)
            return x

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(f"--{MAX_RESOURCE_ATTR}", type=int, required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--num_gpus", type=int, default=1)
    add_to_argparse(parser, _config_space)
    add_checkpointing_to_argparse(parser)

    args, _ = parser.parse_known_args()

    objective(config=vars(args))
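
This script only implements the training loop and metric reporting; Syne Tune runs it as a black box, passing hyperparameters as command-line arguments. A minimal local launcher might look like the following sketch (not part of the benchmark package): the entry-point location, dataset path, epoch budget, wallclock limit, and number of workers are assumptions to adapt to your setup.

from syne_tune import Tuner, StoppingCriterion
from syne_tune.backend import LocalBackend
from syne_tune.optimizer.baselines import ASHA

from resnet_cifar10 import (
    _config_space,
    METRIC_NAME,
    RESOURCE_ATTR,
    MAX_RESOURCE_ATTR,
)

# Fixed parameters of the training script are appended to the search space
config_space = {
    **_config_space,
    MAX_RESOURCE_ATTR: 27,  # assumed epoch budget per trial
    "dataset_path": "./data",  # assumed dataset location
    "num_gpus": 1,
}

tuner = Tuner(
    # Assumes the launcher lives next to resnet_cifar10.py
    trial_backend=LocalBackend(entry_point="resnet_cifar10.py"),
    scheduler=ASHA(
        config_space,
        metric=METRIC_NAME,  # validation accuracy, to be maximized
        mode="max",
        resource_attr=RESOURCE_ATTR,
        max_resource_attr=MAX_RESOURCE_ATTR,
    ),
    stop_criterion=StoppingCriterion(max_wallclock_time=3 * 3600),
    n_workers=4,
)
tuner.run()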

Transformer Trained on WikiText-2

benchmarking/training_scripts/transformer_wikitext2/training_script.py
import argparse
import os
import logging
import math
from pathlib import Path
import time

try:
    # Benchmark-specific imports are done here, in order to avoid import
    # errors if the dependencies are not installed (such errors should only
    # surface when the script is actually run)
    import numpy as np
    from filelock import SoftFileLock, Timeout
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except ImportError:
    logging.info(
        f"Please install benchmark-specific dependencies ({Path(__file__).parent / 'requirements.txt'})"
    )
try:
    from apex import amp
except ImportError:
    print("Failed to import apex. You can still train with --precision {float|double}.")

from syne_tune.report import Reporter
from syne_tune.config_space import randint, uniform, loguniform, add_to_argparse
from syne_tune.utils import (
    resume_from_checkpointed_model,
    checkpoint_model_at_rung_level,
    add_checkpointing_to_argparse,
    pytorch_load_save_functions,
)


BATCH_SIZE_LOWER = 16

BATCH_SIZE_UPPER = 48

BATCH_SIZE_KEY = "batch_size"

METRIC_NAME = "val_loss"

RESOURCE_ATTR = "epoch"

MAX_RESOURCE_ATTR = "epochs"

ELAPSED_TIME_ATTR = "elapsed_time"


_config_space = {
    "lr": loguniform(1e-6, 1e-3),
    "dropout": uniform(0, 0.99),
    BATCH_SIZE_KEY: randint(BATCH_SIZE_LOWER, BATCH_SIZE_UPPER),
    "momentum": uniform(0, 0.99),
    "clip": uniform(0, 1),
}


DATASET_PATH = "https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/"


def download_wikitext2_dataset(root):
    import urllib

    path = os.path.join(root, "wikitext-2")
    for fname in ("train.txt", "valid.txt", "test.txt"):
        fh = os.path.join(path, fname)
        if not os.path.exists(fh):
            os.makedirs(path, exist_ok=True)
            urllib.request.urlretrieve(DATASET_PATH + fname, fh)


class Dictionary(object):
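    """Bidirectional mapping between corpus words and integer token ids."""
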
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
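    """WikiText-2 corpus: tokenizes the train/valid/test splits and caches the resulting tensors on disk."""
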
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = None
        self.valid = None
        self.test = None
        if not self.load_cache(path):
            self.train = self.tokenize(os.path.join(path, "train.txt"))
            self.valid = self.tokenize(os.path.join(path, "valid.txt"))
            self.test = self.tokenize(os.path.join(path, "test.txt"))
            self.save_cache(path)

    def load_cache(self, path):
        for cache in ["dict.pt", "train.pt", "valid.pt", "test.pt"]:
            cache_path = os.path.join(path, cache)
            if not os.path.exists(cache_path):
                return False
        self.dictionary = torch.load(os.path.join(path, "dict.pt"))
        self.train = torch.load(os.path.join(path, "train.pt"))
        self.valid = torch.load(os.path.join(path, "valid.pt"))
        self.test = torch.load(os.path.join(path, "test.pt"))
        return True

    def save_cache(self, path):
        torch.save(self.dictionary, os.path.join(path, "dict.pt"))
        torch.save(self.train, os.path.join(path, "train.pt"))
        torch.save(self.valid, os.path.join(path, "valid.pt"))
        torch.save(self.test, os.path.join(path, "test.pt"))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, "r", encoding="utf8") as f:
            for line in f:
                words = line.split() + ["<eos>"]
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, "r", encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ["<eos>"]
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids


def get_batch(source, i, bptt):
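    # Slice a (seq_len, batch) block of inputs starting at position i, together
    # with the next-token targets flattened for the loss criterion.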
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i : i + seq_len]
    target = source[i + 1 : i + 1 + seq_len].view(-1)
    return data, target


def batchloader(train_data, bptt):
    for i in range(0, train_data.size(0) - 1, bptt):
        yield get_batch(train_data, i, bptt)


def batchify(data, bsz, device):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)


def setprec(t, precision):
    if precision == "half":
        # do nothing since this is handled by AMP
        return t
    elif precision == "float":
        return t.float()
    elif precision == "double":
        return t.double()
    else:
        raise ValueError(f"invalid precision string {precision}")


def download_dataset(config):
    path = config["input_data_dir"]
    os.makedirs(path, exist_ok=True)
    # Lock protection is needed for backends which run multiple worker
    # processes on the same instance
    lock_path = os.path.join(path, "lock")
    lock = SoftFileLock(lock_path)
    try:
        with lock.acquire(timeout=120, poll_intervall=1):
            # Make sure files are present locally
            download_wikitext2_dataset(path)
            corpus = Corpus(os.path.join(path, "wikitext-2"))
    except Timeout:
        print(
            "WARNING: Could not obtain lock for dataset files. Trying anyway...",
            flush=True,
        )
        # Make sure files are present locally
        download_wikitext2_dataset(path)
        corpus = Corpus(os.path.join(path, "wikitext-2"))
    return corpus


def evaluate(model, valid_data, criterion, config, ntokens):
    # Turn on evaluation mode which disables dropout
    model.eval()
    bptt = config["bptt"]
    total_loss = 0.0
    with torch.no_grad():
        for i in range(0, valid_data.size(0) - 1, bptt):
            data, targets = get_batch(valid_data, i, bptt)
            output = model(data)
            output = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output, targets).item()
    return total_loss / (len(valid_data) - 1)


def train(model, train_data, optimizer, criterion, config, ntokens, epoch):
    # Turn on training mode which enables dropout
    model.train()
    bptt = config["bptt"]
    precision = config["precision"]
    log_interval = config["log_interval"]
    total_loss = 0.0
    epoch_loss = 0.0
    start_time = time.time()
    first_loss = None
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i, bptt)
        # Reset gradients accumulated from the previous batch. Unlike the RNN
        # version of this example, the Transformer keeps no hidden state that
        # would need to be detached between batches.
        optimizer.zero_grad()
        output = model(data)
        output = output.view(-1, ntokens)
        loss = criterion(output, targets)
        if torch.isnan(loss):
            exit(0)
        if precision == "half":
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        clip = config["clip"]
        if clip > 0:
            # ``clip_grad_norm_`` helps prevent the exploding gradient problem.
            if precision == "half":
                params = amp.master_params(optimizer)
            else:
                params = model.parameters()
            torch.nn.utils.clip_grad_norm_(params, clip)
        optimizer.step()
        total_loss += loss.item()
        epoch_loss += len(data) * loss.item()
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | "
                "loss {:5.2f} | ppl {:8.2f}".format(
                    epoch,
                    batch,
                    len(train_data) // bptt,
                    config["lr"],
                    elapsed * 1000 / log_interval,
                    cur_loss,
                    np.exp(cur_loss),
                )
            )
            total_loss = 0
            start_time = time.time()
            if first_loss is None:
                first_loss = cur_loss
    return epoch_loss / (len(train_data) - 1), first_loss


def create_training_objects(config, ntokens, device):
    precision = config["precision"]
    d_model = config["d_model"]
    model = TransformerModel(
        ntokens,
        ninp=d_model,
        nhead=config["nhead"],
        nhid=d_model * config["ffn_ratio"],
        nlayers=config["nlayers"],
        dropout=config["dropout"],
    )
    model = model.to(device)
    model = setprec(model, precision)
    criterion = nn.NLLLoss()
    if config["optimizer_name"] == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=config["lr"],
            momentum=config["momentum"],
        )
    elif config["optimizer_name"] == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config["lr"],
            betas=(config["momentum"], 0.999),
        )
    else:
        raise ValueError(f"optimizer_name = {config['optimizer_name']} not supported")
    # half-precision black magic
    if precision == "half":
        model, optimizer = amp.initialize(
            model, optimizer, opt_level="O1", min_loss_scale=0.0001, verbosity=0
        )
    return model, optimizer, criterion


def objective(config):
    torch.manual_seed(config["seed"])
    use_cuda = config["use_cuda"]
    if torch.cuda.is_available() and not use_cuda:
        print("WARNING: You have a CUDA device, so you should run with --use_cuda 1")
    device = torch.device("cuda" if use_cuda else "cpu")
    # Download data, setup data loaders
    corpus = download_dataset(config)
    ntokens = len(corpus.dictionary)
    train_data = batchify(corpus.train, bsz=config["batch_size"], device=device)
    valid_data = batchify(corpus.valid, bsz=10, device=device)
    # Do not want to count the time to download the dataset, which can be
    # substantial the first time
    ts_start = time.time()
    # Used for reporting metrics to Syne Tune
    report = Reporter()
    # Create model and optimizer
    model, optimizer, criterion = create_training_objects(config, ntokens, device)
    # Checkpointing
    state_dict_objects = {
        "model": model,
        "optimizer": optimizer,
    }
    if config["precision"] == "half":
        state_dict_objects["amp"] = amp
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        state_dict_objects=state_dict_objects,
    )
    # Resume from checkpoint (optional)
    resume_from = resume_from_checkpointed_model(config, load_model_fn)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(resume_from + 1, config[MAX_RESOURCE_ATTR] + 1):
            epoch_start_time = time.time()
            train(model, train_data, optimizer, criterion, config, ntokens, epoch)
            val_loss = evaluate(model, valid_data, criterion, config, ntokens)
            curr_ts = time.time()
            elapsed_time = curr_ts - ts_start
            duration = curr_ts - epoch_start_time
            print("-" * 89)
            print(
                "| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | "
                "valid ppl {:8.2f}".format(epoch, duration, val_loss, np.exp(val_loss))
            )
            print("-" * 89)
            # Write checkpoint (optional)
            checkpoint_model_at_rung_level(config, save_model_fn, epoch)
            # Report metrics back to Syne Tune
            report_kwargs = {
                RESOURCE_ATTR: epoch,
                METRIC_NAME: val_loss,
                ELAPSED_TIME_ATTR: elapsed_time,
            }
            report(**report_kwargs)
    except KeyboardInterrupt:
        print("-" * 89)
        print("Exiting from training early")


if __name__ == "__main__":
    # Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
    class PositionalEncoding(nn.Module):
        r"""Inject some information about the relative or absolute position of the tokens
            in the sequence. The positional encodings have the same dimension as
            the embeddings, so that the two can be summed. Here, we use sine and cosine
            functions of different frequencies.
        .. math::
            \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
            \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
            \text{where } pos \text{ is the word position and } i \text{ is the embed idx}
        Args:
            d_model: the embed dim (required).
            dropout: the dropout value (default=0.1).
            max_len: the max. length of the incoming sequence (default=5000).
        Examples:
            >>> pos_encoder = PositionalEncoding(d_model)
        """

        def __init__(self, d_model, dropout=0.1, max_len=5000):
            super(PositionalEncoding, self).__init__()
            self.dropout = nn.Dropout(p=dropout)

            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
            )
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0).transpose(0, 1)
            self.register_buffer("pe", pe)

        def forward(self, x):
            r"""Inputs of forward function
            Args:
                x: the sequence fed to the positional encoder model (required).
            Shape:
                x: [sequence length, batch size, embed dim]
                output: [sequence length, batch size, embed dim]
            Examples:
                >>> output = pos_encoder(x)
            """

            x = x + self.pe[: x.size(0), :]
            return self.dropout(x)

    class TransformerModel(nn.Module):
        """Container module with a token embedding, a Transformer encoder, and a linear decoder."""

        def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
            super(TransformerModel, self).__init__()
            try:
                from torch.nn import TransformerEncoder, TransformerEncoderLayer
            except ImportError:
                raise ImportError(
                    "TransformerEncoder module does not exist in PyTorch 1.1 or lower."
                )
            self.model_type = "Transformer"
            self.src_mask = None
            self.pos_encoder = PositionalEncoding(ninp, dropout)
            encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
            self.encoder = nn.Embedding(ntoken, ninp)
            self.ninp = ninp
            self.decoder = nn.Linear(ninp, ntoken)

            self.init_weights()

        @staticmethod
        def _generate_square_subsequent_mask(sz):
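            # Causal attention mask: a position may attend to itself and to
            # earlier positions (0.0), while future positions are masked (-inf).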
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = (
                mask.float()
                .masked_fill(mask == 0, float("-inf"))
                .masked_fill(mask == 1, float(0.0))
            )
            return mask

        def init_weights(self):
            initrange = 0.1
            nn.init.uniform_(self.encoder.weight, -initrange, initrange)
            nn.init.zeros_(self.decoder.bias)
            nn.init.uniform_(self.decoder.weight, -initrange, initrange)

        def forward(self, src, has_mask=True):
            if has_mask:
                device = src.device
                if self.src_mask is None or self.src_mask.size(0) != len(src):
                    mask = self._generate_square_subsequent_mask(len(src)).to(device)
                    self.src_mask = mask
            else:
                self.src_mask = None

            src = self.encoder(src) * math.sqrt(self.ninp)
            src = self.pos_encoder(src)
            output = self.transformer_encoder(src, self.src_mask)
            output = self.decoder(output)
            return F.log_softmax(output, dim=-1)

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    parser = argparse.ArgumentParser(
        description="PyTorch Wikitext-2 Transformer Language Model",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--" + MAX_RESOURCE_ATTR, type=int, default=40, help="upper epoch limit"
    )
    parser.add_argument("--use_cuda", type=int, default=1)
    parser.add_argument(
        "--input_data_dir",
        type=str,
        default="./",
        help="location of the data corpus",
    )
    parser.add_argument(
        "--optimizer_name", type=str, default="sgd", choices=["sgd", "adam"]
    )
    parser.add_argument("--bptt", type=int, default=35, help="sequence length")
    parser.add_argument("--seed", type=int, default=1111, help="random seed")
    parser.add_argument(
        "--precision", type=str, default="float", help="float | double | half"
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=200,
        help="report interval",
    )
    # These could become hyperparameters as well (more like NAS)
    parser.add_argument("--d_model", type=int, default=256, help="width of the model")
    parser.add_argument(
        "--ffn_ratio", type=int, default=1, help="the ratio of d_ffn to d_model"
    )
    parser.add_argument("--nlayers", type=int, default=2, help="number of layers")
    parser.add_argument(
        "--nhead",
        type=int,
        default=2,
        help="the number of heads in the encoder/decoder of the transformer model",
    )
    add_to_argparse(parser, _config_space)
    add_checkpointing_to_argparse(parser)

    args, _ = parser.parse_known_args()
    args.use_cuda = bool(args.use_cuda)

    objective(config=vars(args))
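
As with the other benchmarks, Syne Tune runs this script as a black box; architecture arguments such as --d_model or --nhead keep their argparse defaults unless added to the search space. A minimal local launcher might look like the sketch below; the entry-point path, data directory, epoch budget, and stopping criterion are assumptions to adapt.

from syne_tune import Tuner, StoppingCriterion
from syne_tune.backend import LocalBackend
from syne_tune.optimizer.baselines import ASHA

from training_script import (
    _config_space,
    METRIC_NAME,
    RESOURCE_ATTR,
    MAX_RESOURCE_ATTR,
)

config_space = {
    **_config_space,
    MAX_RESOURCE_ATTR: 40,  # assumed epoch budget per trial
    "input_data_dir": "./wikitext-2-data",  # assumed download location
}

tuner = Tuner(
    # Assumes the launcher lives next to training_script.py
    trial_backend=LocalBackend(entry_point="training_script.py"),
    scheduler=ASHA(
        config_space,
        metric=METRIC_NAME,  # validation loss, to be minimized
        mode="min",
        resource_attr=RESOURCE_ATTR,
        max_resource_attr=MAX_RESOURCE_ATTR,
    ),
    stop_criterion=StoppingCriterion(max_wallclock_time=5 * 3600),
    n_workers=4,
)
tuner.run()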

Multi-layer Perceptron Trained on Fashion MNIST

benchmarking/training_scripts/mlp_on_fashion_mnist/mlp_on_fashion_mnist.py
"""
Two-layer MLP trained on Fashion MNIST
"""
import os
import argparse
import logging
import time
from pathlib import Path

try:
    # Benchmark-specific imports are done here, in order to avoid import
    # errors if the dependencies are not installed (such errors should only
    # surface when the script is actually run)
    from filelock import SoftFileLock, Timeout
    import torch
    import torch.nn as nn
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision import datasets
    from torchvision import transforms
except ImportError:
    logging.info(
        f"Please install benchmark-specific dependencies ({Path(__file__).parent / 'requirements.txt'})"
    )

from syne_tune import Reporter
from syne_tune.config_space import (
    randint,
    lograndint,
    uniform,
    loguniform,
    add_to_argparse,
)
from syne_tune.utils import (
    resume_from_checkpointed_model,
    checkpoint_model_at_rung_level,
    add_checkpointing_to_argparse,
    pytorch_load_save_functions,
    parse_bool,
)


NUM_UNITS_1 = "n_units_1"

NUM_UNITS_2 = "n_units_2"

METRIC_NAME = "accuracy"

RESOURCE_ATTR = "epoch"

ELAPSED_TIME_ATTR = "elapsed_time"


_config_space = {
    NUM_UNITS_1: lograndint(4, 1024),
    NUM_UNITS_2: lograndint(4, 1024),
    "batch_size": randint(8, 128),
    "dropout_1": uniform(0, 0.99),
    "dropout_2": uniform(0, 0.99),
    "learning_rate": loguniform(1e-6, 1),
    "weight_decay": loguniform(1e-8, 1),
}


# Boilerplate for objective


def download_data(config):
    path = os.path.join(config["dataset_path"], "FashionMNIST")
    os.makedirs(path, exist_ok=True)
    # Lock protection is needed for backends which run multiple worker
    # processes on the same instance
    lock_path = os.path.join(path, "lock")
    lock = SoftFileLock(lock_path)
    try:
        with lock.acquire(timeout=120, poll_intervall=1):
            data_train = datasets.FashionMNIST(
                root=path, train=True, download=True, transform=transforms.ToTensor()
            )
    except Timeout:
        print(
            "WARNING: Could not obtain lock for dataset files. Trying anyway...",
            flush=True,
        )
        data_train = datasets.FashionMNIST(
            root=path, train=True, download=True, transform=transforms.ToTensor()
        )
    return data_train


def split_data(config, data_train):
    # We use 50000 samples for training and 10000 samples for validation
    indices = list(range(data_train.data.shape[0]))
    train_idx, valid_idx = indices[:50000], indices[50000:]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    batch_size = config["batch_size"]
    train_loader = torch.utils.data.DataLoader(
        data_train, batch_size=batch_size, sampler=train_sampler, drop_last=True
    )
    valid_loader = torch.utils.data.DataLoader(
        data_train, batch_size=batch_size, sampler=valid_sampler, drop_last=True
    )
    return train_loader, valid_loader


def model_and_optimizer(config):
    n_units_1 = config["n_units_1"]
    n_units_2 = config["n_units_2"]
    dropout_1 = config["dropout_1"]
    dropout_2 = config["dropout_2"]
    learning_rate = config["learning_rate"]
    weight_decay = config["weight_decay"]
    # Define the network architecture
    comp_list = [
        nn.Linear(28 * 28, n_units_1),
        nn.Dropout(p=dropout_1),
        nn.ReLU(),
        nn.Linear(n_units_1, n_units_2),
        nn.Dropout(p=dropout_2),
        nn.ReLU(),
        nn.Linear(n_units_2, 10),
    ]
    model = nn.Sequential(*comp_list)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()
    return {"model": model, "optimizer": optimizer, "criterion": criterion}


def train_model(config, state, train_loader):
    model = state["model"]
    optimizer = state["optimizer"]
    criterion = state["criterion"]
    batch_size = config["batch_size"]
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data.view(batch_size, -1))
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def validate_model(config, state, valid_loader):
    batch_size = config["batch_size"]
    model = state["model"]
    model.eval()
    correct = 0
    total = 0
    for data, target in valid_loader:
        output = model(data.view(batch_size, -1))
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
    return correct / total  # Validation accuracy


def objective(config):
    report_current_best = parse_bool(config["report_current_best"])

    data_train = download_data(config)

    # Do not want to count the time to download the dataset, which can be
    # substantial the first time
    ts_start = time.time()
    report = Reporter()

    train_loader, valid_loader = split_data(config, data_train)

    state = model_and_optimizer(config)

    # Checkpointing
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        {"model": state["model"], "optimizer": state["optimizer"]}
    )
    # Resume from checkpoint (optional)
    resume_from = resume_from_checkpointed_model(config, load_model_fn)

    current_best = None
    for epoch in range(resume_from + 1, config["epochs"] + 1):
        train_model(config, state, train_loader)
        accuracy = validate_model(config, state, valid_loader)
        elapsed_time = time.time() - ts_start
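        # Track the best validation accuracy seen so far; it is reported instead
        # of the per-epoch accuracy if --report_current_best is set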
        if current_best is None or accuracy > current_best:
            current_best = accuracy
        # Write checkpoint (optional)
        checkpoint_model_at_rung_level(config, save_model_fn, epoch)
        # Feed the score back to Syne Tune
        metric_value = current_best if report_current_best else accuracy
        report(
            **{
                RESOURCE_ATTR: epoch,
                METRIC_NAME: metric_value,
                ELAPSED_TIME_ATTR: elapsed_time,
            }
        )


if __name__ == "__main__":
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--report_current_best", type=str, default="False")
    add_to_argparse(parser, _config_space)
    add_checkpointing_to_argparse(parser)

    args, _ = parser.parse_known_args()

    objective(config=vars(args))
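
A matching local launcher could look like the sketch below, here with the BayesianOptimization baseline as one possible scheduler choice; the entry-point path, dataset path, epoch budget, and wallclock limit are assumptions to adapt to your setup.

from syne_tune import Tuner, StoppingCriterion
from syne_tune.backend import LocalBackend
from syne_tune.optimizer.baselines import BayesianOptimization

from mlp_on_fashion_mnist import _config_space, METRIC_NAME

config_space = {
    **_config_space,
    "epochs": 81,  # assumed epoch budget per trial
    "dataset_path": "./data",  # assumed dataset location
}

tuner = Tuner(
    # Assumes the launcher lives next to mlp_on_fashion_mnist.py
    trial_backend=LocalBackend(entry_point="mlp_on_fashion_mnist.py"),
    scheduler=BayesianOptimization(
        config_space,
        metric=METRIC_NAME,  # validation accuracy, to be maximized
        mode="max",
    ),
    stop_criterion=StoppingCriterion(max_wallclock_time=3600),
    n_workers=4,
)
tuner.run()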