{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tune XGBoost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install dependencies\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "ci-skip-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%pip install 'syne-tune[extra]'\n",
    "%pip install xgboost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from syne_tune import Tuner, StoppingCriterion\n",
    "from syne_tune.backend import PythonBackend\n",
    "from syne_tune.config_space import randint, uniform, loguniform\n",
    "from syne_tune.optimizer.baselines import CQR\n",
    "from syne_tune.experiments import load_experiment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the training function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(n_estimators: int, max_depth: int, gamma: float, reg_lambda: float):\n",
    "    ''' Training function (the function to be tuned) with hyperparameters passed in as function arguments\n",
    "\n",
    "    This example demonstrates training an XGBoost model on the UCI ML hand-written digits dataset.\n",
    "\n",
    "    Note that the training function must be totally self-contained as it needs to be serialized.\n",
    "    Everything (including variables and dependencies) must be defined or imported inside the function scope.\n",
    "\n",
    "    For more information on XGBoost's hyperparameters, see https://xgboost.readthedocs.io/en/stable/parameter.html\n",
    "    For more information about the dataset, see https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html\n",
    "\n",
    "    Args:\n",
    "        n_estimators: number of boosting rounds (trees) to fit\n",
    "        max_depth: maximum depth of each tree\n",
    "        gamma: minimum loss reduction required to split a leaf further\n",
    "        reg_lambda: L2 regularization strength on the leaf weights\n",
    "\n",
    "    The validation accuracy is reported back to Syne Tune under the metric name `accuracy`.\n",
    "    '''\n",
    "\n",
    "    from sklearn.datasets import load_digits\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    from syne_tune import Reporter\n",
    "    import xgboost\n",
    "    import numpy as np\n",
    "\n",
    "    X, y = load_digits(return_X_y=True)\n",
    "\n",
    "    # 60/20/20 train/validation/test split. The test part is held out and is\n",
    "    # deliberately not touched during tuning, so it is discarded here.\n",
    "    X_trainval, _, y_trainval, _ = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "    X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=0.25, random_state=42)\n",
    "\n",
    "    report = Reporter()\n",
    "\n",
    "    clf = xgboost.XGBClassifier(\n",
    "        n_estimators=n_estimators,\n",
    "        reg_lambda=reg_lambda,\n",
    "        gamma=gamma,\n",
    "        max_depth=max_depth,\n",
    "    )\n",
    "    clf.fit(X_train, y_train)\n",
    "\n",
    "    y_pred = clf.predict(X_val)\n",
    "    # fraction of correct validation predictions, as a built-in float\n",
    "    accuracy = float(np.mean(y_pred == y_val))\n",
    "\n",
    "    # report metrics back to syne tune\n",
    "    report(accuracy=accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the tuning parameters\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter configuration space\n",
    "config_space = {\n",
    "    \"max_depth\": randint(1, 10),\n",
    "    \"gamma\": uniform(1, 10),\n",
    "    \"reg_lambda\": loguniform(1e-7, 1),\n",
    "    \"n_estimators\": randint(5, 15),\n",
    "}\n",
    "\n",
    "# Scheduler (i.e., HPO algorithm)\n",
    "scheduler = CQR(\n",
    "    config_space,\n",
    "    metric=\"accuracy\",\n",
    "    do_minimize=False,\n",
    ")\n",
    "\n",
    "tuner = Tuner(\n",
    "    trial_backend=PythonBackend(tune_function=train, config_space=config_space),\n",
    "    scheduler=scheduler,\n",
    "    stop_criterion=StoppingCriterion(max_wallclock_time=30),\n",
    "    n_workers=4,  # how many trials are evaluated in parallel\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the tuning\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tuner.run()\n",
    "\n",
    "tuning_experiment = load_experiment(tuner.name)\n",
    "\n",
    "print(f\"best result found: {tuning_experiment.best_config()}\")\n",
    "\n",
    "tuning_experiment.plot()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}