# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
from pathlib import Path
from typing import List, Union, Dict, Optional
from syne_tune.try_import import try_import_aws_message, try_import_yahpo_message
try:
import s3fs as s3fs
from botocore.exceptions import NoCredentialsError
except ImportError:
print(try_import_aws_message())
from syne_tune.blackbox_repository.blackbox import Blackbox
from syne_tune.blackbox_repository.blackbox_offline import (
deserialize as deserialize_offline,
)
from syne_tune.blackbox_repository.blackbox_tabular import (
deserialize as deserialize_tabular,
)
from syne_tune.blackbox_repository.conversion_scripts.scripts.pd1_import import (
deserialize as deserialize_pd1,
)
try:
from syne_tune.blackbox_repository.conversion_scripts.scripts.yahpo_import import (
instantiate_yahpo,
)
except ImportError:
print(try_import_yahpo_message())
# where the blackbox repository is stored on s3
from syne_tune.blackbox_repository.conversion_scripts.recipes import (
generate_blackbox_recipes,
)
from syne_tune.blackbox_repository.conversion_scripts.utils import (
validate_hash,
blackbox_local_path,
blackbox_s3_path,
)
logger = logging.getLogger(__name__)
[docs]
def blackbox_list() -> List[str]:
    """
    Enumerate the blackboxes known to the repository.

    :return: names of all blackboxes for which a generation recipe is
        registered in ``generate_blackbox_recipes``
    """
    recipe_names = generate_blackbox_recipes.keys()
    return list(recipe_names)
[docs]
def load_blackbox(
    name: str,
    skip_if_present: bool = True,
    s3_root: Optional[str] = None,
    generate_if_not_found: bool = True,
    yahpo_kwargs: Optional[dict] = None,
    ignore_hash: bool = True,  # TODO: Switch back to ``False`` once hash computation fixed
) -> Union[Dict[str, Blackbox], Blackbox]:
    """
    Load a blackbox from the repository, fetching or generating its files
    first if they are not available locally.

    The files are looked up in this order: local folder, S3, and finally
    regeneration through the conversion recipe registered for ``name``.

    :param name: name of a blackbox present in the repository, see
        :func:`blackbox_list` to get list of available blackboxes. Syne Tune
        currently provides the following blackboxes evaluations:

        * "nasbench201": 15625 multi-fidelity configurations of computer vision
          architectures evaluated on 3 datasets.
          NAS-Bench-201: Extending the scope of reproducible neural architecture search.
          Dong, X. and Yang, Y. 2020.
        * "fcnet": 62208 multi-fidelity configurations of MLP evaluated on 4 datasets.
          Tabular benchmarks for joint architecture and hyperparameter optimization.
          Klein, A. and Hutter, F. 2019.
        * "lcbench": 2000 multi-fidelity Pytorch model configurations evaluated on many datasets.
          Reference: Auto-PyTorch: Multi-Fidelity MetaLearning for Efficient and Robust AutoDL.
          Lucas Zimmer, Marius Lindauer, Frank Hutter. 2020.
        * "icml-deepar": 2420 single-fidelity configurations of DeepAR forecasting algorithm evaluated on 10 datasets.
          A quantile-based approach for hyperparameter transfer learning.
          Salinas, D., Shen, H., and Perrone, V. 2021.
        * "icml-xgboost": 5000 single-fidelity configurations of XGBoost evaluated on 9 datasets.
          A quantile-based approach for hyperparameter transfer learning.
          Salinas, D., Shen, H., and Perrone, V. 2021.
        * "yahpo-*": Number of different benchmarks from YAHPO Gym. Note that these
          blackboxes come with surrogates already, so no need to wrap them into
          :class:`SurrogateBlackbox`
    :param skip_if_present: skip the download if the file locally exists
    :param s3_root: S3 root directory for blackbox repository. Defaults to
        S3 bucket name of SageMaker session
    :param generate_if_not_found: If the blackbox file is not present locally
        or on S3, should it be generated using its conversion script?
    :param yahpo_kwargs: For a YAHPO blackbox (``name == "yahpo-*"``), these are
        additional arguments to ``instantiate_yahpo``
    :param ignore_hash: do not check if hash of currently stored files matches the
        pre-computed hash. Be careful with this option. If hashes do not match, results
        might not be reproducible.
    :return: blackbox with the given name, download it if not present.
    :raises AssertionError: if the files are found neither locally nor on S3
        and ``generate_if_not_found`` is ``False``
    :raises RuntimeError: if the local files still fail the hash check after
        they have been regenerated
    """
    tgt_folder = blackbox_local_path(name)
    recipe = generate_blackbox_recipes[name]
    expected_hash = recipe.hash
    # Hash validation only applies when requested and a reference hash exists
    check_hash = not ignore_hash and expected_hash is not None

    if check_blackbox_local_files(tgt_folder) and skip_if_present:
        if check_hash and not validate_hash(tgt_folder, expected_hash):
            logger.warning(
                "Files seem to be corrupted (hash mismatch), regenerating it locally and persisting it on S3."
            )
            recipe.generate(s3_root=s3_root)
            if not validate_hash(tgt_folder, expected_hash):
                # The original built this exception without raising it, so a
                # persistent mismatch went unnoticed.
                raise RuntimeError(
                    "The hash of the files do not match the stored hash after regenerations. "
                    "Consider updating the hash and sending a pull-request to change it or set the option ``ignore_hash`` to True."
                )
        logger.info(
            f"Skipping download of {name} as {tgt_folder} already exists, change skip_if_present to redownload"
        )
    else:
        logger.info("Local files not found, attempting to copy from S3.")
        tgt_folder.mkdir(exist_ok=True, parents=True)
        # Pre-bind so that the names exist even if resolving the S3 path
        # itself fails for lack of credentials.
        s3_folder = None
        fs = None
        try:
            s3_folder = blackbox_s3_path(name=name, s3_root=s3_root)
            fs = s3fs.S3FileSystem()
            data_on_s3 = fs.exists(f"{s3_folder}/metadata.json")
        except NoCredentialsError:
            # Without credentials, behave as if nothing is stored on S3
            data_on_s3 = False

        if data_on_s3:
            logger.info("found blackbox on S3, copying it locally")
            # download files from s3 to repository_path
            for src in fs.glob(f"{s3_folder}/*"):
                tgt = tgt_folder / Path(src).name
                logger.info(f"copying {src} to {tgt}")
                fs.get(src, str(tgt))
            if check_hash and not validate_hash(tgt_folder, expected_hash):
                logger.warning(
                    "Files seem to be corrupted (hash mismatch), regenerating it locally and overwrite files on S3."
                )
                recipe.generate(s3_root=s3_root)
        else:
            if not generate_if_not_found:
                # Raised explicitly (instead of ``assert``) so that the check
                # is not stripped when Python runs with ``-O``; the exception
                # type is kept for backward compatibility.
                s3_location = (
                    s3_folder
                    if s3_folder is not None
                    else "the S3 blackbox repository"
                )
                raise AssertionError(
                    "Blackbox files do not exist locally or on S3. If you have "
                    + f"write permissions to {s3_location}, you can set "
                    + "generate_if_not_found=True in order to generate and persist them"
                )
            logger.info(
                "Did not find blackbox files locally nor on S3, regenerating it locally and persisting it on S3."
            )
            recipe.generate(s3_root=s3_root)

    # Pick the loader matching the blackbox family / on-disk format
    if name.startswith("yahpo"):
        if yahpo_kwargs is None:
            yahpo_kwargs = dict()
        return instantiate_yahpo(name, **yahpo_kwargs)
    elif name.startswith("pd1"):
        return deserialize_pd1(tgt_folder)
    elif (tgt_folder / "hyperparameters.parquet").exists():
        return deserialize_tabular(tgt_folder)
    else:
        return deserialize_offline(tgt_folder)
[docs]
def check_blackbox_local_files(tgt_folder) -> bool:
    """
    Check whether blackbox files are present in a local folder.

    :param tgt_folder: local folder expected to hold the blackbox files
    :return: ``True`` if ``tgt_folder`` exists and contains a
        ``metadata.json`` entry, ``False`` otherwise
    """
    if not tgt_folder.exists():
        return False
    metadata_file = tgt_folder / "metadata.json"
    return metadata_file.exists()
if __name__ == "__main__":
    # Show the names of every blackbox the repository knows about
    available_names = blackbox_list()
    print(available_names)

    # Load each blackbox in turn (fetching it if needed) and display it
    for bb_name in available_names:
        print(bb_name)
        loaded = load_blackbox(bb_name)
        print(loaded)