ResNet-18 Trained on CIFAR-10
import os
import argparse
import logging
import time
from pathlib import Path
try:
# Benchmark-specific imports are done here, in order to avoid import
# errors if the dependencies are not installed (such errors should happen
# only when the code is really called)
from filelock import SoftFileLock, Timeout
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
from torchvision.models import resnet18
except ImportError:
logging.info(
f"Please install benchmark-specific dependencies ({Path(__file__).parent / 'requirements.txt'})"
)
from syne_tune import Reporter
from syne_tune.config_space import randint, uniform, loguniform, add_to_argparse
from syne_tune.utils import (
resume_from_checkpointed_model,
checkpoint_model_at_rung_level,
add_checkpointing_to_argparse,
pytorch_load_save_functions,
)
# Search range and name of the (integer) batch size hyperparameter.
BATCH_SIZE_LOWER, BATCH_SIZE_UPPER = 8, 256
BATCH_SIZE_KEY = "batch_size"

# Names of the values reported back to Syne Tune, and of the command line
# argument holding the maximum number of epochs.
METRIC_NAME = "objective"
RESOURCE_ATTR = "epoch"
MAX_RESOURCE_ATTR = "epochs"
ELAPSED_TIME_ATTR = "elapsed_time"

# Hyperparameter search space (passed to ``add_to_argparse`` in the main block).
_config_space = {
    BATCH_SIZE_KEY: randint(BATCH_SIZE_LOWER, BATCH_SIZE_UPPER),
    "momentum": uniform(0, 0.99),
    "weight_decay": loguniform(1e-5, 1e-3),
    "lr": loguniform(1e-3, 0.1),
}
# ATTENTION: train_dataset, valid_dataset are both based on the CIFAR10
# training set, but train_dataset uses data augmentation. Make sure to
# only use disjoint parts for training and validation further down.
def get_CIFAR10(root):
    """Load the CIFAR-10 training split twice, with and without augmentation.

    The files are downloaded to ``<root>/CIFAR10`` unless already present.

    :param root: Directory under which the ``CIFAR10`` folder lives
    :return: ``(input_size, num_classes, train_dataset, valid_dataset)``.
        Both datasets wrap the CIFAR-10 *training* images: the first applies
        random crop / horizontal flip augmentation, the second does not.
        Both normalize with the same per-channel mean / std.
    """
    image_side = 32
    class_count = 10
    mean_std = [(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)]
    data_dir = os.path.join(root, "CIFAR10")

    def _pipeline(*augmentations):
        # Optional augmentations, followed by tensor conversion + normalization
        return transforms.Compose(
            [*augmentations, transforms.ToTensor(), transforms.Normalize(*mean_std)]
        )

    augmented = datasets.CIFAR10(
        data_dir,
        train=True,
        transform=_pipeline(
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ),
        download=True,
    )
    plain = datasets.CIFAR10(
        data_dir, train=True, transform=_pipeline(), download=True
    )
    return image_side, class_count, augmented, plain
def train(model, train_loader, optimizer):
    """Run one training epoch of ``model`` over ``train_loader``.

    Inputs and targets are moved to the GPU if CUDA is available. The loss
    is ``F.nll_loss``, so the model must output log-probabilities.

    :param model: Network to train (switched to training mode)
    :param train_loader: Iterable of ``(data, target)`` mini-batches
    :param optimizer: Optimizer updating the model parameters
    :return: Mean of the per-batch training losses, or ``float("nan")`` if
        ``train_loader`` yields no batches. Callers may ignore it.
    """
    model.train()
    use_cuda = torch.cuda.is_available()
    batch_losses = []
    for data, target in tqdm(train_loader):
        if use_cuda:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        prediction = model(data)
        loss = F.nll_loss(prediction, target)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    if not batch_losses:
        # Avoid ZeroDivisionError for an empty loader
        return float("nan")
    return sum(batch_losses) / len(batch_losses)
def valid(model, valid_loader):
    """Evaluate ``model`` on ``valid_loader`` without tracking gradients.

    NOTE: the second return value is the classification *accuracy* (fraction
    of correctly classified examples, higher is better), even though it was
    historically stored in a variable called ``valid_error``. It is what
    ``objective`` reports as the metric.

    :param model: Network returning log-probabilities (switched to eval mode)
    :param valid_loader: Loader whose ``sampler`` has a length, yielding
        ``(data, target)`` mini-batches
    :return: ``(loss, accuracy)``, where ``loss`` is the mean negative
        log-likelihood per example as a Python float
    :raises ValueError: if the loader's sampler is empty
    """
    model.eval()
    use_cuda = torch.cuda.is_available()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in valid_loader:
            if use_cuda:
                data = data.cuda()
                target = target.cuda()
            prediction = model(data)
            # ``.item()`` keeps the accumulator a Python float instead of a
            # (possibly GPU-resident) tensor
            total_loss += F.nll_loss(prediction, target, reduction="sum").item()
            predicted_label = prediction.max(dim=1)[1]
            correct += predicted_label.eq(target.view_as(predicted_label)).sum().item()
    n_valid = len(valid_loader.sampler)
    if n_valid == 0:
        raise ValueError("valid_loader has an empty sampler")
    accuracy = correct / n_valid
    return total_loss / n_valid, accuracy
def _download_data(config):
    """Make sure CIFAR-10 is present under ``config["dataset_path"]`` and load it.

    A soft file lock serializes the download between worker processes on
    the same instance. If the lock cannot be obtained within 120 seconds, a
    warning is printed and the data is loaded without the lock.

    :param config: Must contain ``"dataset_path"``
    :return: ``(input_size, num_classes, train_dataset, valid_dataset)`` as
        returned by ``get_CIFAR10``
    """
    path = config["dataset_path"]
    os.makedirs(path, exist_ok=True)
    # Lock protection is needed for backends which run multiple worker
    # processes on the same instance
    lock = SoftFileLock(os.path.join(path, "lock"))
    try:
        # ``poll_interval`` is the supported keyword; the misspelled
        # ``poll_intervall`` is only a deprecated alias in filelock
        with lock.acquire(timeout=120, poll_interval=1):
            return get_CIFAR10(root=path)
    except Timeout:
        print(
            "WARNING: Could not obtain lock for dataset files. Trying anyway...",
            flush=True,
        )
    return get_CIFAR10(root=path)
def _create_data_loaders(config, train_dataset, valid_dataset):
    """Build loaders over disjoint parts of the CIFAR-10 training images.

    The first 40000 images (through the augmenting ``train_dataset``) are
    used for training. The remaining ones (through ``valid_dataset``, which
    does not augment) are used for validation. Both loaders visit their
    subset in random order via ``SubsetRandomSampler``.

    :return: ``(train_loader, valid_loader)``; the validation batch size is
        fixed to 128
    """
    all_positions = range(train_dataset.data.shape[0])

    def _loader(dataset, batch_size, positions):
        # Random order comes from the sampler, so no ``shuffle`` argument
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=0,
            sampler=SubsetRandomSampler(list(positions)),
            pin_memory=True,
        )

    training = _loader(train_dataset, config[BATCH_SIZE_KEY], all_positions[:40000])
    validation = _loader(valid_dataset, 128, all_positions[40000:])
    return training, validation
def _create_training_objects(config):
    """Create the network, an SGD optimizer and a step learning rate schedule.

    If CUDA is available, the model is moved to the GPU and wrapped in
    ``torch.nn.DataParallel`` over devices ``0 .. config["num_gpus"] - 1``.

    :param config: Must contain ``lr``, ``momentum``, ``weight_decay`` and
        (with CUDA) ``num_gpus``
    :return: ``(model, optimizer, scheduler)``
    """
    net = Model()
    if torch.cuda.is_available():
        net = net.cuda()
        gpu_ids = list(range(config["num_gpus"]))
        net = torch.nn.DataParallel(net, device_ids=gpu_ids).to(torch.device("cuda"))
    sgd = torch.optim.SGD(
        net.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
    )
    # Multiply the learning rate by 0.1 once the scheduler has been stepped
    # 25 and 40 times (``objective`` steps it once per epoch)
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
        sgd, milestones=[25, 40], gamma=0.1
    )
    return net, sgd, lr_schedule
def objective(config):
    """Train ResNet-18 on CIFAR-10, reporting validation accuracy every epoch.

    After each epoch the model is evaluated, optionally checkpointed, and
    ``RESOURCE_ATTR`` (epoch), ``METRIC_NAME`` (validation accuracy) and
    ``ELAPSED_TIME_ATTR`` are reported to Syne Tune. The elapsed time is
    measured before validation and excludes the dataset download.

    :param config: Parsed command line arguments / hyperparameters
    """
    torch.manual_seed(np.random.randint(10000))

    _, _, train_dataset, valid_dataset = _download_data(config)
    train_loader, valid_loader = _create_data_loaders(
        config, train_dataset, valid_dataset
    )
    # The clock starts only now: the first download can take a while
    ts_start = time.time()
    report = Reporter()

    model, optimizer, scheduler = _create_training_objects(config)
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        {"model": model, "optimizer": optimizer, "lr_scheduler": scheduler}
    )
    # Continue after the last checkpointed epoch, if there is one
    first_epoch = resume_from_checkpointed_model(config, load_model_fn) + 1
    last_epoch = config[MAX_RESOURCE_ATTR]

    for epoch in range(first_epoch, last_epoch + 1):
        train(model, train_loader, optimizer)
        scheduler.step()
        elapsed_time = time.time() - ts_start
        checkpoint_model_at_rung_level(config, save_model_fn, epoch)
        # Second value: fraction of correctly classified validation images
        _, accuracy = valid(model, valid_loader)
        metrics = {
            RESOURCE_ATTR: epoch,
            METRIC_NAME: accuracy,
            ELAPSED_TIME_ATTR: elapsed_time,
        }
        report(**metrics)
if __name__ == "__main__":
    # ``Model`` is defined here rather than at module level: its base class
    # ``torch.nn.Module`` needs torch, whose import at the top is guarded

    class Model(torch.nn.Module):
        """ResNet-18 adapted to the 32x32 CIFAR-10 images, emitting log-probabilities."""

        def __init__(self):
            super().__init__()
            self.resnet = resnet18(pretrained=False, num_classes=10)
            # 3x3 stride-1 stem and no max pooling: the small input images
            # are not downsampled before the residual stages
            self.resnet.conv1 = torch.nn.Conv2d(
                3, 64, kernel_size=3, stride=1, padding=1, bias=False
            )
            self.resnet.maxpool = torch.nn.Identity()

        def forward(self, x):
            # Log-probabilities, matching ``F.nll_loss`` in ``train`` / ``valid``
            return F.log_softmax(self.resnet(x), dim=1)

    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(f"--{MAX_RESOURCE_ATTR}", type=int, required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--num_gpus", type=int, default=1)
    add_to_argparse(parser, _config_space)
    add_checkpointing_to_argparse(parser)
    cli_args, _ = parser.parse_known_args()
    objective(config=vars(cli_args))
Transformer Trained on WikiText-2
import argparse
import os
import logging
import math
from pathlib import Path
import time
try:
# Benchmark-specific imports are done here, in order to avoid import
# errors if the dependencies are not installed (such errors should happen
# only when the code is really called)
import numpy as np
from filelock import SoftFileLock, Timeout
import torch
import torch.nn as nn
import torch.nn.functional as F
except ImportError:
logging.info(
f"Please install benchmark-specific dependencies ({Path(__file__).parent / 'requirements.txt'})"
)
try:
from apex import amp
except ImportError:
print("Failed to import apex. You can still train with --precision {float|double}.")
from syne_tune.report import Reporter
from syne_tune.config_space import randint, uniform, loguniform, add_to_argparse
from syne_tune.utils import (
resume_from_checkpointed_model,
checkpoint_model_at_rung_level,
add_checkpointing_to_argparse,
pytorch_load_save_functions,
)
# Search range and name of the (integer) batch size hyperparameter.
BATCH_SIZE_LOWER, BATCH_SIZE_UPPER = 16, 48
BATCH_SIZE_KEY = "batch_size"

# Names of the values reported back to Syne Tune, and of the command line
# argument holding the maximum number of epochs.
METRIC_NAME = "val_loss"
RESOURCE_ATTR = "epoch"
MAX_RESOURCE_ATTR = "epochs"
ELAPSED_TIME_ATTR = "elapsed_time"

# Hyperparameter search space (passed to ``add_to_argparse`` in the main block).
_config_space = {
    "lr": loguniform(1e-6, 1e-3),
    "dropout": uniform(0, 0.99),
    BATCH_SIZE_KEY: randint(BATCH_SIZE_LOWER, BATCH_SIZE_UPPER),
    "momentum": uniform(0, 0.99),
    "clip": uniform(0, 1),
}

# Base URL (despite the name) to which the WikiText-2 file names are appended.
DATASET_PATH = "https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/"
def download_wikitext2_dataset(root, base_url=None):
    """Download the WikiText-2 ``train``/``valid``/``test`` text files.

    Files are stored in ``<root>/wikitext-2``. Files which already exist are
    not downloaded again. Each file is first downloaded under a temporary
    ``.part`` name and renamed on success, so an interrupted download never
    leaves a truncated file that a later call would accept as complete.

    :param root: Directory under which the ``wikitext-2`` folder is created
    :param base_url: URL prefix to which the file names are appended.
        Defaults to ``DATASET_PATH``; pass another value to use a mirror.
    """
    # ``import urllib`` alone does not load the ``urllib.request`` submodule,
    # so ``urllib.request.urlretrieve`` could fail with AttributeError
    import urllib.request

    if base_url is None:
        base_url = DATASET_PATH
    path = os.path.join(root, "wikitext-2")
    for fname in ("train.txt", "valid.txt", "test.txt"):
        target = os.path.join(path, fname)
        if os.path.exists(target):
            continue
        os.makedirs(path, exist_ok=True)
        partial = target + ".part"
        try:
            urllib.request.urlretrieve(base_url + fname, partial)
        except BaseException:
            # Do not leave a half-written file behind (also on Ctrl+C)
            if os.path.exists(partial):
                os.remove(partial)
            raise
        os.replace(partial, target)
class Dictionary:
    """Bidirectional mapping between words and consecutive integer ids.

    Ids are assigned in order of first insertion, starting at 0.
    """

    def __init__(self):
        self.word2idx = {}  # word -> id
        self.idx2word = []  # id -> word

    def add_word(self, word):
        """Return the id of ``word``, assigning the next free id if it is new."""
        idx = self.word2idx.get(word)
        if idx is None:
            idx = len(self.idx2word)
            self.word2idx[word] = idx
            self.idx2word.append(word)
        return idx

    def __len__(self):
        return len(self.idx2word)
class Corpus:
    """Word dictionary plus tokenized ``train`` / ``valid`` / ``test`` splits.

    The splits are read from ``train.txt``, ``valid.txt`` and ``test.txt`` in
    ``path``. The dictionary and the tokenized splits are cached in the same
    directory (``dict.pt``, ``train.pt``, ``valid.pt``, ``test.pt``) and
    loaded from there if all four cache files exist.

    Attributes:
        dictionary: ``Dictionary`` shared by all three splits
        train, valid, test: 1-D int64 tensors of token ids
    """

    _CACHE_FILES = ("dict.pt", "train.pt", "valid.pt", "test.pt")

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = None
        self.valid = None
        self.test = None
        if not self.load_cache(path):
            self.train = self.tokenize(os.path.join(path, "train.txt"))
            self.valid = self.tokenize(os.path.join(path, "valid.txt"))
            self.test = self.tokenize(os.path.join(path, "test.txt"))
            self.save_cache(path)

    @staticmethod
    def _torch_load(file_path):
        """``torch.load`` for cache files which may contain pickled Python objects."""
        import inspect

        # ``dict.pt`` holds a pickled ``Dictionary`` instance. Since PyTorch
        # 2.6, ``torch.load`` defaults to ``weights_only=True`` and refuses
        # to unpickle such objects, so full unpickling is requested
        # explicitly where the argument exists (older versions lack it).
        # NOTE: full unpickling can execute code; only use this on cache
        # files written by ``save_cache``.
        if "weights_only" in inspect.signature(torch.load).parameters:
            return torch.load(file_path, weights_only=False)
        return torch.load(file_path)

    def load_cache(self, path):
        """Load the dictionary and splits from the cache files in ``path``.

        :return: ``True`` if all cache files exist and were loaded, ``False``
            if any is missing (nothing is loaded in that case)
        """
        if not all(
            os.path.exists(os.path.join(path, name)) for name in self._CACHE_FILES
        ):
            return False
        self.dictionary = self._torch_load(os.path.join(path, "dict.pt"))
        self.train = self._torch_load(os.path.join(path, "train.pt"))
        self.valid = self._torch_load(os.path.join(path, "valid.pt"))
        self.test = self._torch_load(os.path.join(path, "test.pt"))
        return True

    def save_cache(self, path):
        """Write the dictionary and the tokenized splits to ``path``."""
        torch.save(self.dictionary, os.path.join(path, "dict.pt"))
        torch.save(self.train, os.path.join(path, "train.pt"))
        torch.save(self.valid, os.path.join(path, "valid.pt"))
        torch.save(self.test, os.path.join(path, "test.pt"))

    def tokenize(self, path):
        """Tokenize a text file, adding unseen words to the dictionary.

        Every line is split on whitespace and terminated by an ``<eos>``
        token.

        :param path: Text file to tokenize
        :return: 1-D int64 tensor with the ids of all tokens in file order
        :raises FileNotFoundError: if ``path`` does not exist
        """
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        # One pass is enough: ``add_word`` returns the id of a word and
        # assigns new ids in order of first appearance, exactly as a
        # separate "build dictionary" pass followed by a lookup pass would
        ids = []
        with open(path, "r", encoding="utf8") as f:
            for line in f:
                ids.extend(
                    self.dictionary.add_word(word) for word in line.split() + ["<eos>"]
                )
        # A single flat list also works for an empty file, where
        # ``torch.cat([])`` would raise
        return torch.tensor(ids, dtype=torch.int64)
def get_batch(source, i, bptt):
    """Cut one chunk of inputs and next-token targets out of ``source``.

    :param source: Tensor whose first dimension is the sequence axis (as
        produced by ``batchify``)
    :param i: Start position along that axis
    :param bptt: Maximum chunk length; shortened near the end of ``source``
        so that every input position has a target
    :return: ``(data, target)``: ``data`` is the window starting at ``i``,
        ``target`` the same window shifted by one position, flattened to 1-D
    """
    end = i + min(bptt, len(source) - 1 - i)
    return source[i:end], source[i + 1 : end + 1].view(-1)
def batchloader(train_data, bptt):
    """Lazily yield consecutive ``get_batch`` chunks covering ``train_data``."""
    starts = range(0, train_data.size(0) - 1, bptt)
    yield from (get_batch(train_data, start, bptt) for start in starts)
def batchify(data, bsz, device):
    """Arrange a token sequence into ``bsz`` parallel columns on ``device``.

    Trailing tokens that do not fill a complete row are dropped. For a 1-D
    ``data``, the result has shape ``(len(data) // bsz, bsz)``; column ``j``
    holds the ``j``-th contiguous slice of the (trimmed) sequence.
    """
    rows = data.size(0) // bsz
    trimmed = data[: rows * bsz]
    columns = trimmed.view(bsz, -1).t().contiguous()
    return columns.to(device)
def setprec(t, precision):
    """Cast ``t`` (tensor or module) to the requested floating point precision.

    :param precision: ``"float"``, ``"double"`` or ``"half"``. For ``"half"``,
        ``t`` is returned unchanged because the cast is left to apex AMP.
    :raises ValueError: for any other ``precision``
    """
    if precision == "float":
        return t.float()
    if precision == "double":
        return t.double()
    if precision != "half":
        raise ValueError(f"invalid precision string {precision}")
    # Half precision: nothing to do here, handled by AMP
    return t
def download_dataset(config):
    """Download WikiText-2 (if needed) into ``config["input_data_dir"]`` and tokenize it.

    A soft file lock serializes the download between worker processes on
    the same instance. If the lock cannot be obtained within 120 seconds, a
    warning is printed and the work is done without the lock.

    :return: ``Corpus`` built from ``<input_data_dir>/wikitext-2``
    """
    path = config["input_data_dir"]
    os.makedirs(path, exist_ok=True)
    corpus_path = os.path.join(path, "wikitext-2")
    # Lock protection is needed for backends which run multiple worker
    # processes on the same instance
    lock = SoftFileLock(os.path.join(path, "lock"))
    try:
        # ``poll_interval`` is the supported keyword; the misspelled
        # ``poll_intervall`` is only a deprecated alias in filelock
        with lock.acquire(timeout=120, poll_interval=1):
            download_wikitext2_dataset(path)
            return Corpus(corpus_path)
    except Timeout:
        print(
            "WARNING: Could not obtain lock for dataset files. Trying anyway...",
            flush=True,
        )
    download_wikitext2_dataset(path)
    return Corpus(corpus_path)
def evaluate(model, valid_data, criterion, config, ntokens):
    """Return the validation loss of ``model`` averaged over all target positions.

    ``valid_data`` is processed in chunks of ``config["bptt"]`` positions.
    Each chunk's mean loss is weighted by its length.
    """
    # Evaluation mode disables dropout
    model.eval()
    bptt = config["bptt"]
    weighted_sum = 0.0
    with torch.no_grad():
        for start in range(0, valid_data.size(0) - 1, bptt):
            inputs, targets = get_batch(valid_data, start, bptt)
            log_probs = model(inputs).view(-1, ntokens)
            weighted_sum += len(inputs) * criterion(log_probs, targets).item()
    return weighted_sum / (len(valid_data) - 1)
def train(model, train_data, optimizer, criterion, config, ntokens, epoch):
    """Run one training epoch over ``train_data`` in chunks of ``config["bptt"]``.

    Every ``config["log_interval"]`` batches, a progress line is printed with
    the mean loss over those batches, its perplexity and the time per batch.
    If the loss becomes NaN, the process exits with status 0.

    :return: ``(epoch_loss, first_loss)``: the training loss averaged over
        all target positions of the epoch, and the first logged interval
        loss (``None`` if no progress line was printed)
    """
    # Turn on training mode which enables dropout
    model.train()
    bptt = config["bptt"]
    precision = config["precision"]
    log_interval = config["log_interval"]
    clip = config["clip"]
    total_loss = 0.0
    epoch_loss = 0.0
    start_time = time.time()
    first_loss = None
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i, bptt)
        # Every chunk is an independent example: unlike an RNN, the
        # Transformer carries no hidden state from one chunk to the next
        optimizer.zero_grad()
        output = model(data).view(-1, ntokens)
        loss = criterion(output, targets)
        loss_value = loss.item()
        if math.isnan(loss_value):
            # ``sys.exit`` rather than the ``exit`` builtin, which is only
            # installed by the ``site`` module and may be missing
            import sys

            print("Loss is NaN, stopping training", flush=True)
            sys.exit(0)
        if precision == "half":
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        if clip > 0:
            # Gradient norm clipping guards against exploding gradients
            if precision == "half":
                params = amp.master_params(optimizer)
            else:
                params = model.parameters()
            torch.nn.utils.clip_grad_norm_(params, clip)
        optimizer.step()

        total_loss += loss_value
        epoch_loss += len(data) * loss_value
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | "
                "loss {:5.2f} | ppl {:8.2f}".format(
                    epoch,
                    batch,
                    len(train_data) // bptt,
                    config["lr"],
                    elapsed * 1000 / log_interval,
                    cur_loss,
                    np.exp(cur_loss),
                )
            )
            total_loss = 0
            start_time = time.time()
            if first_loss is None:
                first_loss = cur_loss
    return epoch_loss / (len(train_data) - 1), first_loss
def create_training_objects(config, ntokens, device):
    """Build the Transformer language model, its optimizer and the loss.

    :param config: Hyperparameters and architecture settings
    :param ntokens: Vocabulary size
    :param device: Device the model is moved to
    :return: ``(model, optimizer, criterion)``. With ``precision == "half"``,
        model and optimizer are the objects returned by ``amp.initialize``.
    :raises ValueError: if ``config["optimizer_name"]`` is not ``sgd``/``adam``
    """
    precision = config["precision"]
    width = config["d_model"]
    net = TransformerModel(
        ntokens,
        ninp=width,
        nhead=config["nhead"],
        nhid=width * config["ffn_ratio"],
        nlayers=config["nlayers"],
        dropout=config["dropout"],
    )
    net = setprec(net.to(device), precision)
    criterion = nn.NLLLoss()

    name = config["optimizer_name"]
    if name == "sgd":
        opt = torch.optim.SGD(
            net.parameters(), lr=config["lr"], momentum=config["momentum"]
        )
    elif name == "adam":
        # The ``momentum`` hyperparameter serves as Adam's beta1
        opt = torch.optim.Adam(
            net.parameters(), lr=config["lr"], betas=(config["momentum"], 0.999)
        )
    else:
        raise ValueError(f"optimizer_name = {config['optimizer_name']} not supported")

    if precision == "half":
        # Mixed precision through apex AMP
        net, opt = amp.initialize(
            net, opt, opt_level="O1", min_loss_scale=0.0001, verbosity=0
        )
    return net, opt, criterion
def objective(config):
    """Train the Transformer language model on WikiText-2.

    After every epoch, the validation loss is printed, a checkpoint may be
    written, and ``RESOURCE_ATTR`` (epoch), ``METRIC_NAME`` (validation
    loss) and ``ELAPSED_TIME_ATTR`` are reported to Syne Tune. Training can
    be stopped early with Ctrl+C.

    If ``use_cuda`` is set but CUDA is not available, a warning is printed
    and training runs on the CPU.

    :param config: Parsed command line arguments / hyperparameters
    """
    torch.manual_seed(config["seed"])
    use_cuda = config["use_cuda"]
    cuda_available = torch.cuda.is_available()
    if cuda_available and not use_cuda:
        # The command line flag is ``--use_cuda`` (see the argument parser)
        print("WARNING: You have a CUDA device, so you should run with --use_cuda 1")
    elif use_cuda and not cuda_available:
        # Fall back to the CPU rather than failing later when tensors are
        # moved to a CUDA device which does not exist
        print(
            "WARNING: --use_cuda 1 was requested, but CUDA is not available. "
            "Training on the CPU instead."
        )
        use_cuda = False
    device = torch.device("cuda" if use_cuda else "cpu")

    # Download data, setup data loaders
    corpus = download_dataset(config)
    ntokens = len(corpus.dictionary)
    train_data = batchify(corpus.train, bsz=config[BATCH_SIZE_KEY], device=device)
    valid_data = batchify(corpus.valid, bsz=10, device=device)
    # Do not want to count the time to download the dataset, which can be
    # substantial the first time
    ts_start = time.time()
    # Used for reporting metrics to Syne Tune
    report = Reporter()

    # Create model and optimizer
    model, optimizer, criterion = create_training_objects(config, ntokens, device)

    # Checkpointing; with half precision the ``amp`` object is included too
    state_dict_objects = {"model": model, "optimizer": optimizer}
    if config["precision"] == "half":
        state_dict_objects["amp"] = amp
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        state_dict_objects=state_dict_objects,
    )
    # Resume from checkpoint (optional)
    resume_from = resume_from_checkpointed_model(config, load_model_fn)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(resume_from + 1, config[MAX_RESOURCE_ATTR] + 1):
            epoch_start_time = time.time()
            train(model, train_data, optimizer, criterion, config, ntokens, epoch)
            val_loss = evaluate(model, valid_data, criterion, config, ntokens)
            curr_ts = time.time()
            elapsed_time = curr_ts - ts_start
            duration = curr_ts - epoch_start_time
            print("-" * 89)
            print(
                "| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | "
                "valid ppl {:8.2f}".format(epoch, duration, val_loss, np.exp(val_loss))
            )
            print("-" * 89)
            # Write checkpoint (optional)
            checkpoint_model_at_rung_level(config, save_model_fn, epoch)
            # Report metrics back to Syne Tune
            report_kwargs = {
                RESOURCE_ATTR: epoch,
                METRIC_NAME: val_loss,
                ELAPSED_TIME_ATTR: elapsed_time,
            }
            report(**report_kwargs)
    except KeyboardInterrupt:
        print("-" * 89)
        print("Exiting from training early")
if __name__ == "__main__":
    # NOTE(review): the model classes live inside the main guard, presumably
    # because their base class ``nn.Module`` only exists when the guarded
    # torch imports at the top of the file succeeded.
    # Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
    class PositionalEncoding(nn.Module):
        r"""Inject some information about the relative or absolute position of the tokens
        in the sequence. The positional encodings have the same dimension as
        the embeddings, so that the two can be summed. Here, we use sine and cosine
        functions of different frequencies.

        .. math::
            \text{PosEncoder}(pos, 2i) = \sin\left(pos / 10000^{2i / d_{model}}\right)

            \text{PosEncoder}(pos, 2i+1) = \cos\left(pos / 10000^{2i / d_{model}}\right)

        where :math:`pos` is the word position and :math:`i` is the embedding index.

        Args:
            d_model: the embed dim (required). Must be even.
            dropout: the dropout value (default=0.1).
            max_len: the max. length of the incoming sequence (default=5000).

        Examples:
            >>> pos_encoder = PositionalEncoding(d_model)
        """

        def __init__(self, d_model, dropout=0.1, max_len=5000):
            super(PositionalEncoding, self).__init__()
            self.dropout = nn.Dropout(p=dropout)
            # Table of encodings for positions 0 .. max_len - 1: sines in even
            # columns, cosines in odd columns. ``d_model`` must be even,
            # otherwise the cosine assignment below fails with a shape mismatch
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            # 1 / 10000^(2i / d_model), computed in log space
            div_term = torch.exp(
                torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
            )
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            # Shape (max_len, 1, d_model), so it broadcasts over the batch axis
            pe = pe.unsqueeze(0).transpose(0, 1)
            # Registered as a non-trainable buffer (see ``nn.Module.register_buffer``)
            self.register_buffer("pe", pe)

        def forward(self, x):
            r"""Inputs of forward function
            Args:
                x: the sequence fed to the positional encoder model (required).
            Shape:
                x: [sequence length, batch size, embed dim]
                output: [sequence length, batch size, embed dim]
            Examples:
                >>> output = pos_encoder(x)
            """
            # Add the encodings of the first x.size(0) positions, then dropout
            x = x + self.pe[: x.size(0), :]
            return self.dropout(x)

    class TransformerModel(nn.Module):
        """Transformer language model.

        Token embedding (scaled by ``sqrt(ninp)``) plus sinusoidal positional
        encoding, followed by ``nlayers`` Transformer encoder layers (with a
        causal attention mask by default) and a linear decoder. ``forward``
        returns log-probabilities over the ``ntoken`` vocabulary entries.

        Args:
            ntoken: vocabulary size.
            ninp: embedding / model width.
            nhead: number of attention heads.
            nhid: width of the feed-forward sublayers.
            nlayers: number of encoder layers.
            dropout: dropout probability (default=0.5).
        """

        def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
            super(TransformerModel, self).__init__()
            try:
                from torch.nn import TransformerEncoder, TransformerEncoderLayer
            except ImportError:
                raise ImportError(
                    "TransformerEncoder module does not exist in PyTorch 1.1 or lower."
                )
            self.model_type = "Transformer"
            # Cached attention mask, rebuilt in ``forward`` when the sequence
            # length changes
            self.src_mask = None
            # NOTE: submodules are created and registered in this order. The
            # order fixes the order of ``self.parameters()`` (and hence of
            # optimizer state in checkpoints) and of random number use during
            # initialization, so do not reorder these lines
            self.pos_encoder = PositionalEncoding(ninp, dropout)
            encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
            self.encoder = nn.Embedding(ntoken, ninp)
            self.ninp = ninp
            self.decoder = nn.Linear(ninp, ntoken)
            self.init_weights()

        @staticmethod
        def _generate_square_subsequent_mask(sz):
            # Float mask of shape (sz, sz): 0.0 where column <= row (a position
            # may attend to itself and earlier positions), -inf elsewhere
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = (
                mask.float()
                .masked_fill(mask == 0, float("-inf"))
                .masked_fill(mask == 1, float(0.0))
            )
            return mask

        def init_weights(self):
            # Uniform(-0.1, 0.1) embedding and decoder weights, zero decoder bias
            initrange = 0.1
            nn.init.uniform_(self.encoder.weight, -initrange, initrange)
            nn.init.zeros_(self.decoder.bias)
            nn.init.uniform_(self.decoder.weight, -initrange, initrange)

        def forward(self, src, has_mask=True):
            # ``src`` holds token ids with the sequence along the first
            # dimension (``len(src)`` is used as sequence length); presumably
            # shape (seq_len, batch) as produced by ``batchify``/``get_batch``
            if has_mask:
                device = src.device
                # Reuse the cached mask unless the sequence length changed.
                # NOTE(review): the cached mask is float32 and keeps the device
                # it was created on; confirm this is fine with --precision double
                if self.src_mask is None or self.src_mask.size(0) != len(src):
                    mask = self._generate_square_subsequent_mask(len(src)).to(device)
                    self.src_mask = mask
            else:
                self.src_mask = None
            src = self.encoder(src) * math.sqrt(self.ninp)
            src = self.pos_encoder(src)
            output = self.transformer_encoder(src, self.src_mask)
            output = self.decoder(output)
            # Log-probabilities, matching the ``nn.NLLLoss`` criterion
            return F.log_softmax(output, dim=-1)

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    parser = argparse.ArgumentParser(
        description="PyTorch Wikitext-2 Transformer Language Model",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--" + MAX_RESOURCE_ATTR, type=int, default=40, help="upper epoch limit"
    )
    parser.add_argument("--use_cuda", type=int, default=1)
    parser.add_argument(
        "--input_data_dir",
        type=str,
        default="./",
        help="location of the data corpus",
    )
    parser.add_argument(
        "--optimizer_name", type=str, default="sgd", choices=["sgd", "adam"]
    )
    parser.add_argument("--bptt", type=int, default=35, help="sequence length")
    parser.add_argument("--seed", type=int, default=1111, help="random seed")
    parser.add_argument(
        "--precision", type=str, default="float", help="float | double | half"
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=200,
        help="report interval",
    )
    # These could become hyperparameters as well (more like NAS)
    parser.add_argument("--d_model", type=int, default=256, help="width of the model")
    parser.add_argument(
        "--ffn_ratio", type=int, default=1, help="the ratio of d_ffn to d_model"
    )
    parser.add_argument("--nlayers", type=int, default=2, help="number of layers")
    parser.add_argument(
        "--nhead",
        type=int,
        default=2,
        help="the number of heads in the encoder/decoder of the transformer model",
    )
    add_to_argparse(parser, _config_space)
    add_checkpointing_to_argparse(parser)
    args, _ = parser.parse_known_args()
    # ``--use_cuda`` is parsed as int (0/1) and converted to bool here
    args.use_cuda = bool(args.use_cuda)
    objective(config=vars(args))
Multi-layer Perceptron Trained on Fashion MNIST
"""
MLP with two hidden layers trained on Fashion MNIST
"""
import os
import argparse
import logging
import time
from pathlib import Path
try:
# Benchmark-specific imports are done here, in order to avoid import
# errors if the dependencies are not installed (such errors should happen
# only when the code is really called)
from filelock import SoftFileLock, Timeout
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets
from torchvision import transforms
except ImportError:
logging.info(
f"Please install benchmark-specific dependencies ({Path(__file__).parent / 'requirements.txt'})"
)
from syne_tune import Reporter
from syne_tune.config_space import (
randint,
lograndint,
uniform,
loguniform,
add_to_argparse,
)
from syne_tune.utils import (
resume_from_checkpointed_model,
checkpoint_model_at_rung_level,
add_checkpointing_to_argparse,
pytorch_load_save_functions,
parse_bool,
)
# Names of the hyperparameters holding the widths of the two hidden layers.
NUM_UNITS_1, NUM_UNITS_2 = "n_units_1", "n_units_2"

# Names of the values reported back to Syne Tune.
METRIC_NAME = "accuracy"
RESOURCE_ATTR = "epoch"
ELAPSED_TIME_ATTR = "elapsed_time"

# Hyperparameter search space (passed to ``add_to_argparse`` in the main block).
_config_space = {
    NUM_UNITS_1: lograndint(4, 1024),
    NUM_UNITS_2: lograndint(4, 1024),
    "batch_size": randint(8, 128),
    "dropout_1": uniform(0, 0.99),
    "dropout_2": uniform(0, 0.99),
    "learning_rate": loguniform(1e-6, 1),
    "weight_decay": loguniform(1e-8, 1),
}
# Boilerplate for objective
def download_data(config):
    """Download (if needed) and load the Fashion-MNIST training split.

    Data lives in ``<config["dataset_path"]>/FashionMNIST``. A soft file lock
    serializes the download between worker processes on the same instance.
    If the lock cannot be obtained within 120 seconds, a warning is printed
    and the data is loaded without the lock.

    :return: ``datasets.FashionMNIST`` training split with a ``ToTensor`` transform
    """
    path = os.path.join(config["dataset_path"], "FashionMNIST")
    os.makedirs(path, exist_ok=True)
    # Lock protection is needed for backends which run multiple worker
    # processes on the same instance
    lock = SoftFileLock(os.path.join(path, "lock"))

    def _load():
        return datasets.FashionMNIST(
            root=path, train=True, download=True, transform=transforms.ToTensor()
        )

    try:
        # ``poll_interval`` is the supported keyword; the misspelled
        # ``poll_intervall`` is only a deprecated alias in filelock
        with lock.acquire(timeout=120, poll_interval=1):
            return _load()
    except Timeout:
        print(
            "WARNING: Could not obtain lock for dataset files. Trying anyway...",
            flush=True,
        )
    return _load()
def split_data(config, data_train):
    """Split the Fashion-MNIST training set into training and validation loaders.

    The first 50000 examples are used for training and the remaining ones
    for validation. Both loaders visit their subset in random order.

    The training loader drops the last incomplete batch, so every training
    step uses ``config["batch_size"]`` examples. The validation loader keeps
    it, so accuracy is computed over the whole validation split rather than
    a batch-size-dependent subset of it.

    :return: ``(train_loader, valid_loader)``
    """
    # We use 50000 samples for training and 10000 samples for validation
    indices = list(range(data_train.data.shape[0]))
    train_idx, valid_idx = indices[:50000], indices[50000:]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    batch_size = config["batch_size"]
    train_loader = torch.utils.data.DataLoader(
        data_train, batch_size=batch_size, sampler=train_sampler, drop_last=True
    )
    valid_loader = torch.utils.data.DataLoader(
        data_train, batch_size=batch_size, sampler=valid_sampler, drop_last=False
    )
    return train_loader, valid_loader


def model_and_optimizer(config):
    """Create the MLP (two hidden layers with dropout + ReLU), Adam and the loss.

    :return: dict with keys ``"model"``, ``"optimizer"``, ``"criterion"``
    """
    n_units_1 = config["n_units_1"]
    n_units_2 = config["n_units_2"]
    dropout_1 = config["dropout_1"]
    dropout_2 = config["dropout_2"]
    learning_rate = config["learning_rate"]
    weight_decay = config["weight_decay"]
    # Define the network architecture: 28x28 images in, 10 class logits out
    comp_list = [
        nn.Linear(28 * 28, n_units_1),
        nn.Dropout(p=dropout_1),
        nn.ReLU(),
        nn.Linear(n_units_1, n_units_2),
        nn.Dropout(p=dropout_2),
        nn.ReLU(),
        nn.Linear(n_units_2, 10),
    ]
    model = nn.Sequential(*comp_list)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()
    return {"model": model, "optimizer": optimizer, "criterion": criterion}


def _flatten(data):
    # Flatten each image to a vector, using the actual number of examples in
    # the batch so that incomplete (last) batches work as well
    return data.view(data.size(0), -1)


def train_model(config, state, train_loader):
    """Run one training epoch of ``state["model"]`` over ``train_loader``.

    ``config`` is unused but kept for interface compatibility.
    """
    model = state["model"]
    optimizer = state["optimizer"]
    criterion = state["criterion"]
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(_flatten(data))
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def validate_model(config, state, valid_loader):
    """Return the classification accuracy of ``state["model"]`` on ``valid_loader``.

    ``config`` is unused but kept for interface compatibility.

    :raises ValueError: if ``valid_loader`` yields no examples
    """
    model = state["model"]
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation; this avoids building the
    # autograd graph for every batch
    with torch.no_grad():
        for data, target in valid_loader:
            output = model(_flatten(data))
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        raise ValueError("valid_loader did not yield any examples")
    return correct / total  # Validation accuracy
def objective(config):
    """Train the MLP on Fashion-MNIST, reporting validation accuracy every epoch.

    If ``config["report_current_best"]`` parses to true, the best accuracy
    seen so far in this run is reported instead of the current one. The
    elapsed time excludes the dataset download.

    :param config: Parsed command line arguments / hyperparameters
    """
    report_current_best = parse_bool(config["report_current_best"])
    data_train = download_data(config)
    # The clock starts after the (first time possibly slow) download
    ts_start = time.time()
    report = Reporter()

    train_loader, valid_loader = split_data(config, data_train)
    state = model_and_optimizer(config)
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        {"model": state["model"], "optimizer": state["optimizer"]}
    )
    # Continue after the last checkpointed epoch, if there is one
    resume_from = resume_from_checkpointed_model(config, load_model_fn)

    best_accuracy = None
    for epoch in range(resume_from + 1, config["epochs"] + 1):
        train_model(config, state, train_loader)
        accuracy = validate_model(config, state, valid_loader)
        elapsed_time = time.time() - ts_start
        best_accuracy = accuracy if best_accuracy is None else max(best_accuracy, accuracy)
        checkpoint_model_at_rung_level(config, save_model_fn, epoch)
        # Named so that it does not shadow the enclosing function ``objective``
        reported_value = best_accuracy if report_current_best else accuracy
        report(
            **{
                RESOURCE_ATTR: epoch,
                METRIC_NAME: reported_value,
                ELAPSED_TIME_ATTR: elapsed_time,
            }
        )
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    # Parsed with ``parse_bool`` inside ``objective``
    parser.add_argument("--report_current_best", type=str, default="False")
    add_to_argparse(parser, _config_space)
    add_checkpointing_to_argparse(parser)
    cli_args, _ = parser.parse_known_args()
    objective(config=vars(cli_args))