syne_tune.utils.checkpoint module
- syne_tune.utils.checkpoint.add_checkpointing_to_argparse(parser)[source]
To be called for the argument parser in the endpoint script. Arguments added here are optional. If checkpointing is not supported, they are simply not parsed.
- Parameters:
parser (ArgumentParser) – Parser to add extra arguments to
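To illustrate why the added arguments can simply go unparsed, here is a minimal sketch, assuming the added argument is the checkpoint directory and is named `--st_checkpoint_dir` (the actual argument names are an assumption; check the real implementation):

```python
import argparse

def add_checkpointing_to_argparse(parser):
    # Sketch only: the real function may add more arguments. The argument
    # is optional (default None), so parsing succeeds whether or not the
    # backend passes it.
    parser.add_argument("--st_checkpoint_dir", type=str, default=None)

parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, required=True)
add_checkpointing_to_argparse(parser)

# Backend supports checkpointing: the extra argument is passed and parsed
args = parser.parse_args(["--epochs", "3", "--st_checkpoint_dir", "/tmp/ckpt"])
# Backend does not support it: parsing still succeeds, the value stays None
args_no_ckpt = parser.parse_args(["--epochs", "3"])
```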
- syne_tune.utils.checkpoint.resume_from_checkpointed_model(config, load_model_fn)[source]
Checks whether there is a checkpoint to be resumed from. If so, the checkpoint is loaded by calling load_model_fn. This function takes a local pathname (to which it appends a filename) and returns resume_from, the resource value (e.g., epoch) at which the checkpoint was written. If it fails to load the checkpoint, it may return 0, which skips resuming from a checkpoint. This resume_from value is returned. If checkpointing is not supported in config, or no checkpoint is found, resume_from = 0 is returned.
- Parameters:
config (Dict[str, Any]) – Configuration the training script is called with
load_model_fn (Callable[[str], int]) – See above, must return resume_from. See pytorch_load_save_functions() for an example
- Return type:
int
- Returns:
resume_from (0 if no checkpoint has been loaded)
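A load_model_fn with this contract might look as follows. This is a minimal sketch: the checkpoint filename and JSON format are assumptions for illustration, and model restoration is elided.

```python
import json
import os
import tempfile

def load_model_fn(local_path):
    # Sketch of a load_model_fn to pass to resume_from_checkpointed_model.
    # It appends a filename to the local pathname, restores state if the
    # file exists, and returns the resource (epoch) the checkpoint was
    # written at. Returning 0 means "do not resume".
    fname = os.path.join(local_path, "checkpoint.json")
    try:
        with open(fname) as f:
            state = json.load(f)
    except FileNotFoundError:
        return 0  # no checkpoint found: train from scratch
    # ... restore model / optimizer state from `state` here ...
    return int(state["epoch"])

# Behaviour with and without a checkpoint file:
with tempfile.TemporaryDirectory() as ckpt_dir:
    resume_from = load_model_fn(ckpt_dir)  # no file yet, so 0
    with open(os.path.join(ckpt_dir, "checkpoint.json"), "w") as f:
        json.dump({"epoch": 7}, f)
    resumed = load_model_fn(ckpt_dir)  # checkpoint written at epoch 7
```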
- syne_tune.utils.checkpoint.checkpoint_model_at_rung_level(config, save_model_fn, resource)[source]
If checkpointing is supported, checks whether a checkpoint is to be written. This is the case if the checkpoint dir is set in config. A checkpoint is written by calling save_model_fn, passing the local pathname and resource.
Note: Why is resource passed here? In the future, we want to support writing checkpoints only for certain resource levels. This is useful if writing the checkpoint is expensive compared to the time needed to run one resource unit.
- Parameters:
config (Dict[str, Any]) – Configuration the training script is called with
save_model_fn (Callable[[str, int], Any]) – See above. See pytorch_load_save_functions() for an example
resource (int) – Current resource level (e.g., number of epochs done)
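The matching save_model_fn can be sketched as below. Again, the checkpoint filename and JSON format are assumptions for illustration; real model state is elided.

```python
import json
import os
import tempfile

def save_model_fn(local_path, resource):
    # Sketch of a save_model_fn for checkpoint_model_at_rung_level: write
    # the current state, together with the resource level (epoch) it was
    # written at, into the local checkpoint directory.
    state = {"epoch": resource}  # ... add model / optimizer state here ...
    with open(os.path.join(local_path, "checkpoint.json"), "w") as f:
        json.dump(state, f)

with tempfile.TemporaryDirectory() as ckpt_dir:
    for epoch in range(1, 4):
        # ... one resource unit (epoch) of training here ...
        save_model_fn(ckpt_dir, epoch)  # in a real script, this call is
                                        # made via checkpoint_model_at_rung_level
    with open(os.path.join(ckpt_dir, "checkpoint.json")) as f:
        last_epoch = json.load(f)["epoch"]
```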
- syne_tune.utils.checkpoint.pytorch_load_save_functions(state_dict_objects, mutable_state=None, fname='checkpoint.json')[source]
Provides default load_model_fn, save_model_fn functions for standard PyTorch models (arguments to resume_from_checkpointed_model(), checkpoint_model_at_rung_level()).
- Parameters:
state_dict_objects (Dict[str, Any]) – Dict of PyTorch objects implementing state_dict and load_state_dict
mutable_state (Optional[dict]) – Optional. Additional dict with elementary value types
fname (str) – Name of local file (path is taken from config)
- Returns:
load_model_fn, save_model_fn
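The sketch below shows how such a pair of functions can be built from objects implementing state_dict / load_state_dict. It is a simplified stand-in, not the real implementation: the helper name make_load_save_functions is hypothetical, JSON replaces torch.save/torch.load so the example runs without PyTorch, and recovering the resource level from the stored mutable state is an assumption.

```python
import json
import os
import tempfile

class TinyModel:
    # Duck-typed like a PyTorch module for this sketch: it implements
    # state_dict / load_state_dict over plain values.
    def __init__(self, weight=0.0):
        self.weight = weight
    def state_dict(self):
        return {"weight": self.weight}
    def load_state_dict(self, d):
        self.weight = d["weight"]

def make_load_save_functions(state_dict_objects, mutable_state=None,
                             fname="checkpoint.json"):
    # Hypothetical stand-in for pytorch_load_save_functions. Assumptions:
    # JSON serialization instead of torch.save/torch.load, and the epoch
    # is stored alongside the mutable state.
    mutable_state = {} if mutable_state is None else mutable_state

    def save_model_fn(local_path, resource):
        state = {k: obj.state_dict() for k, obj in state_dict_objects.items()}
        state["mutable_state"] = dict(mutable_state, epoch=resource)
        with open(os.path.join(local_path, fname), "w") as f:
            json.dump(state, f)

    def load_model_fn(local_path):
        try:
            with open(os.path.join(local_path, fname)) as f:
                state = json.load(f)
        except FileNotFoundError:
            return 0  # no checkpoint: resume_from = 0
        for k, obj in state_dict_objects.items():
            obj.load_state_dict(state[k])
        mutable_state.update(state["mutable_state"])
        return int(mutable_state["epoch"])

    return load_model_fn, save_model_fn

# Round trip: save at epoch 5, then restore into a fresh model.
model = TinyModel(weight=1.5)
_, save_model_fn = make_load_save_functions({"model": model})
with tempfile.TemporaryDirectory() as ckpt_dir:
    save_model_fn(ckpt_dir, 5)
    model2 = TinyModel()
    load_model_fn, _ = make_load_save_functions({"model": model2})
    resume_from = load_model_fn(ckpt_dir)
```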