# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Callable, Any, Optional, Dict
import argparse
import os
from syne_tune.constants import ST_CHECKPOINT_DIR
def add_checkpointing_to_argparse(parser: argparse.ArgumentParser):
    """
    To be called for the argument parser in the endpoint script.
    Arguments added here are optional. If checkpointing is not supported,
    they are simply not parsed.

    :param parser: Parser to add extra arguments to
    """
    # Optional: scripts run without checkpointing simply never receive it
    parser.add_argument(f"--{ST_CHECKPOINT_DIR}", type=str)
def resume_from_checkpointed_model(
    config: Dict[str, Any], load_model_fn: Callable[[str], int]
) -> int:
    """
    Checks whether there is a checkpoint to be resumed from. If so, the
    checkpoint is loaded by calling ``load_model_fn``. This function takes
    a local pathname (to which it appends a filename). It returns
    ``resume_from``, the resource value (e.g., epoch) the checkpoint was
    written at. If it fails to load the checkpoint, it may return 0. This
    skips resuming from a checkpoint.

    If checkpointing is not supported in ``config``, or no checkpoint is
    found, ``resume_from = 0`` is returned.

    :param config: Configuration the training script is called with
    :param load_model_fn: See above, must return ``resume_from``. See
        :func:`pytorch_load_save_functions` for an example
    :return: ``resume_from`` (0 if no checkpoint has been loaded)
    """
    resume_from = 0
    local_path = config.get(ST_CHECKPOINT_DIR)
    # Only attempt a load when a checkpoint dir was passed AND it exists
    if local_path is not None and os.path.exists(local_path):
        resume_from = load_model_fn(local_path)
        trial_id = config.get("trial_id")
        if trial_id is not None:
            print(
                f"Trial {trial_id}: Loading checkpoint [resume_from = "
                f"{resume_from}, local_path = {local_path}]"
            )
    return resume_from
def checkpoint_model_at_rung_level(
    config: Dict[str, Any], save_model_fn: Callable[[str, int], Any], resource: int
):
    """
    If checkpointing is supported, checks whether a checkpoint is to be
    written. This is the case if the checkpoint dir is set in ``config``.
    A checkpoint is written by calling ``save_model_fn``, passing the
    local pathname and resource.

    Note: Why is ``resource`` passed here? In the future, we want to support
    writing checkpoints only for certain resource levels. This is useful if
    writing the checkpoint is expensive compared to the time needed to
    run one resource unit.

    :param config: Configuration the training script is called with
    :param save_model_fn: See above. See :func:`pytorch_load_save_functions` for
        an example
    :param resource: Current resource level (e.g., number of epochs done)
    """
    local_path = config.get(ST_CHECKPOINT_DIR)
    # Checkpointing is supported iff the checkpoint dir is in the config
    if local_path is not None:
        save_model_fn(local_path, resource)
        trial_id = config.get("trial_id")
        if trial_id is not None:
            print(
                f"Trial {trial_id}: Saving checkpoint [resource = "
                f"{resource}, local_path = {local_path}]"
            )
# Keys used inside the checkpoint dict written by ``save_model_fn``:
RESOURCE_NAME = "st_resource"  # resource level (e.g., epoch) the checkpoint was written at
STATE_DICT_PREFIX = "st_state_dict_"  # prefix for per-object ``state_dict`` entries
MUTABLE_STATE_PREFIX = "st_mutable_"  # prefix for elementary mutable-state entries


def pytorch_load_save_functions(
    state_dict_objects: Dict[str, Any],
    mutable_state: Optional[dict] = None,
    fname: str = "checkpoint.json",
):
    """
    Provides default ``load_model_fn``, ``save_model_fn`` functions for standard
    PyTorch models (arguments to :func:`resume_from_checkpointed_model`,
    :func:`checkpoint_model_at_rung_level`).

    :param state_dict_objects: Dict of PyTorch objects implementing ``state_dict``
        and ``load_state_dict``
    :param mutable_state: Optional. Additional dict with elementary value
        types. It is updated in place by ``load_model_fn``
    :param fname: Name of local file (path is taken from config)
    :return: ``load_model_fn, save_model_fn``
    """
    import torch

    # Both closures share this dict object, so values restored by
    # ``load_model_fn`` are visible to the caller who passed ``mutable_state``
    _mutable_state = dict() if mutable_state is None else mutable_state

    def load_model_fn(local_path: str) -> int:
        """Load checkpoint from ``local_path``; return resume_from (0 on failure)."""
        local_filename = os.path.join(local_path, fname)
        try:
            checkpoint = torch.load(local_filename)
            resume_from = int(checkpoint[RESOURCE_NAME])
            for k, v in state_dict_objects.items():
                v.load_state_dict(checkpoint[STATE_DICT_PREFIX + k])
            for k in _mutable_state:
                v = checkpoint[MUTABLE_STATE_PREFIX + k]
                # Cast the restored value to the type of the current entry,
                # so e.g. a float stays a float after the round-trip
                v_old = _mutable_state.get(k)
                if v_old is not None:
                    v = type(v_old)(v)
                _mutable_state[k] = v
        except Exception:
            # Best effort: any failure (missing/corrupt file, key mismatch)
            # means training starts from scratch
            resume_from = 0
        return resume_from

    def save_model_fn(local_path: str, resource: int):
        """Write checkpoint for resource level ``resource`` to ``local_path``."""
        os.makedirs(local_path, exist_ok=True)
        local_filename = os.path.join(local_path, fname)
        checkpoint = {
            STATE_DICT_PREFIX + k: v.state_dict()
            for k, v in state_dict_objects.items()
        }
        checkpoint[RESOURCE_NAME] = resource
        for k, v in _mutable_state.items():
            checkpoint[MUTABLE_STATE_PREFIX + k] = v
        torch.save(checkpoint, local_filename)

    return load_model_fn, save_model_fn