Code in benchmarking/examples/benchmark_dehb

Comparison of DEHB against a number of baselines.

benchmarking/examples/benchmark_dehb/baselines.py
from typing import Dict, Any
from syne_tune.experiments.baselines import (
    convert_categorical_to_ordinal,
    convert_categorical_to_ordinal_numeric,
)
from syne_tune.experiments.default_baselines import (
    ASHA,
    SyncHyperband,
    SyncBOHB,
    DEHB,
    SyncMOBSTER,
)


class Methods:
    """Namespace of the string names identifying the compared methods."""

    # Methods run without an explicit ``config_space`` override
    ASHA, SYNCHB, DEHB, BOHB = "ASHA", "SYNCHB", "DEHB", "BOHB"
    # "-ORD" counterparts of the four methods above
    ASHA_ORD, SYNCHB_ORD, DEHB_ORD, BOHB_ORD = (
        "ASHA-ORD",
        "SYNCHB-ORD",
        "DEHB-ORD",
        "BOHB-ORD",
    )
    # Remaining baselines
    ASHA_STOP = "ASHA-STOP"
    SYNCMOBSTER = "SYNCMOBSTER"


def conv_numeric_then_rest(margs) -> Dict[str, Any]:
    """Return the configuration space of ``margs`` after two conversions.

    ``convert_categorical_to_ordinal_numeric`` is applied first to
    ``margs.config_space`` (with ``kind=margs.fcnet_ordinal``); its result is
    then passed through ``convert_categorical_to_ordinal``.

    :param margs: Object exposing ``config_space`` and ``fcnet_ordinal``
    :return: Output of the second conversion
    """
    numeric_converted = convert_categorical_to_ordinal_numeric(
        margs.config_space, kind=margs.fcnet_ordinal
    )
    return convert_categorical_to_ordinal(numeric_converted)


# Registry mapping each name in ``Methods`` to a factory. Every factory takes
# a single ``method_arguments`` object and returns whatever the corresponding
# baseline constructor (``ASHA``, ``SyncHyperband``, ``DEHB``, ``SyncBOHB``,
# ``SyncMOBSTER``) returns.
methods = {
    # ASHA with ``type="promotion"``
    Methods.ASHA: lambda method_arguments: ASHA(
        method_arguments,
        type="promotion",
    ),
    Methods.SYNCHB: lambda method_arguments: SyncHyperband(method_arguments),
    Methods.DEHB: lambda method_arguments: DEHB(method_arguments),
    Methods.BOHB: lambda method_arguments: SyncBOHB(method_arguments),
    # The "-ORD" entries below differ from their counterparts above only in
    # passing ``config_space=conv_numeric_then_rest(method_arguments)``
    Methods.ASHA_ORD: lambda method_arguments: ASHA(
        method_arguments,
        config_space=conv_numeric_then_rest(method_arguments),
        type="promotion",
    ),
    Methods.SYNCHB_ORD: lambda method_arguments: SyncHyperband(
        method_arguments,
        config_space=conv_numeric_then_rest(method_arguments),
    ),
    Methods.DEHB_ORD: lambda method_arguments: DEHB(
        method_arguments,
        config_space=conv_numeric_then_rest(method_arguments),
    ),
    Methods.BOHB_ORD: lambda method_arguments: SyncBOHB(
        method_arguments,
        config_space=conv_numeric_then_rest(method_arguments),
    ),
    # ASHA with ``type="stopping"`` instead of ``"promotion"``
    Methods.ASHA_STOP: lambda method_arguments: ASHA(
        method_arguments,
        type="stopping",
    ),
    Methods.SYNCMOBSTER: lambda method_arguments: SyncMOBSTER(method_arguments),
}
benchmarking/examples/benchmark_dehb/benchmark_definitions.py
from syne_tune.experiments.benchmark_definitions import (
    nas201_benchmark_definitions,
    fcnet_benchmark_definitions,
    lcbench_selected_benchmark_definitions,
    yahpo_lcbench_selected_benchmark_definitions,
)


# Union of the imported benchmark definition dictionaries. Sources are merged
# in the order listed; if a name occurs in several of them, the entry from
# the later source is the one kept.
benchmark_definitions = {
    name: definition
    for source in (
        nas201_benchmark_definitions,
        fcnet_benchmark_definitions,
        lcbench_selected_benchmark_definitions,
        yahpo_lcbench_selected_benchmark_definitions,
    )
    for name, definition in source.items()
}
benchmarking/examples/benchmark_dehb/hpo_main.py
from typing import Dict, Any

from baselines import methods
from benchmark_definitions import benchmark_definitions
from syne_tune.experiments.launchers.hpo_main_simulator import main
from syne_tune.util import recursive_merge


# Additional command line argument specifications for this benchmark. Each
# entry is a dictionary with keys "name", "type" and "help".
extra_args = [
    {
        "name": "num_brackets",
        "type": int,
        "help": "Number of brackets",
    },
]


def map_method_args(args, method: str, method_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Fold the ``num_brackets`` command line value into ``method_kwargs``.

    :param args: Parsed arguments; only ``args.num_brackets`` is read
    :param method: Name of the method (not used here)
    :param method_kwargs: Keyword arguments collected so far
    :return: ``method_kwargs`` itself if ``args.num_brackets`` is ``None``;
        otherwise the result of ``recursive_merge`` applied to
        ``method_kwargs`` and a dictionary setting
        ``scheduler_kwargs["brackets"]`` to ``args.num_brackets``
    """
    if args.num_brackets is None:
        # Nothing to override
        return method_kwargs
    bracket_override = {"scheduler_kwargs": {"brackets": args.num_brackets}}
    return recursive_merge(method_kwargs, bracket_override)


if __name__ == "__main__":
    # Script entry point: hand the method registry, the benchmark
    # definitions, the extra argument specifications and the argument mapping
    # function defined above to the imported ``main``.
    main(methods, benchmark_definitions, extra_args, map_method_args)
benchmarking/examples/benchmark_dehb/launch_remote.py
from pathlib import Path

from benchmark_definitions import benchmark_definitions
from baselines import methods, Methods
from hpo_main import extra_args
from syne_tune.experiments.launchers.launch_remote_simulator import launch_remote


if __name__ == "__main__":
    # The script handed to ``launch_remote`` as entry point is ``hpo_main.py``,
    # located in the same directory as this file
    script_dir = Path(__file__).parent

    def _is_expensive_method(method: str) -> bool:
        """Return ``True`` only for ``Methods.SYNCMOBSTER``."""
        return method == Methods.SYNCMOBSTER

    launch_remote(
        entry_point=script_dir / "hpo_main.py",
        methods=methods,
        benchmark_definitions=benchmark_definitions,
        extra_args=extra_args,
        is_expensive_method=_is_expensive_method,
    )
benchmarking/examples/benchmark_dehb/requirements.txt
syne-tune[gpsearchers,kde,blackbox-repository,yahpo,aws]
tqdm