Source code for syne_tune.report

# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import re
import sys
import json
import logging
from ast import literal_eval
from typing import List, Dict, Any
from time import time, perf_counter
from dataclasses import dataclass

from syne_tune.constants import (
    ST_INSTANCE_TYPE,
    ST_INSTANCE_COUNT,
    ST_WORKER_TIME,
    ST_WORKER_COST,
    ST_WORKER_TIMESTAMP,
    ST_WORKER_ITER,
    ST_SAGEMAKER_METRIC_TAG,
)
from syne_tune.util import dump_json_with_numpy

# this is required so that metrics are written
from syne_tune.backend.sagemaker_backend.instance_info import InstanceInfos

# Apply the default logging configuration when this module is imported, then
# create the module-level logger used by ``Reporter`` for informational
# messages.
logging.basicConfig()
logger = logging.getLogger(__name__)


@dataclass
class Reporter:
    """
    Callback for reporting metric values from a training script back to Syne
    Tune.

    Example:

    .. code-block:: python

       from syne_tune import Reporter

       report = Reporter()
       for epoch in range(1, epochs + 1):
           # ...
           report(epoch=epoch, accuracy=accuracy)

    :param add_time: If True (default), the time (in secs) since creation of
        the :class:`Reporter` object is reported automatically as
        :const:`~syne_tune.constants.ST_WORKER_TIME`
    :param add_cost: If True (default), estimated dollar cost since creation
        of :class:`Reporter` object is reported automatically as
        :const:`~syne_tune.constants.ST_WORKER_COST`. This is available for
        SageMaker backend only. Requires ``add_time=True``.
    """

    add_time: bool = True
    add_cost: bool = True

    def __post_init__(self):
        """Set up the report counter, the timer and the per-second cost."""
        # The counter is attached to every report by ``__call__``, so it has
        # to exist even when ``add_time=False``.
        self.iter = 0
        if self.add_time:
            self.start = perf_counter()
            # TODO dollar-cost computation is not available for file-based backends, what would be
            #  needed to add support for those backends will be to add a way to access instance-type
            #  information.
            if self.add_cost:
                # add instance_type and instance count so that cost can be computed easily
                self.instance_type = os.getenv(
                    f"SM_HP_{ST_INSTANCE_TYPE.upper()}", None
                )
                self.instance_count = literal_eval(
                    os.getenv(f"SM_HP_{ST_INSTANCE_COUNT.upper()}", "1")
                )
                if self.instance_type is not None:
                    logger.info(
                        f"detected instance-type/instance-count to {self.instance_type}/{self.instance_count}"
                    )
                    instance_infos = InstanceInfos()
                    if self.instance_type in instance_infos.instances:
                        cost_per_hour = instance_infos(
                            instance_type=self.instance_type
                        ).cost_per_hour
                        # Hourly price of all instances, converted to a
                        # price per second so it can be multiplied with the
                        # elapsed seconds in ``__call__``.
                        self.dollar_cost = cost_per_hour * self.instance_count / 3600

    def __call__(self, **kwargs) -> None:
        """Report metric values from training function back to Syne Tune

        A time stamp :const:`~syne_tune.constants.ST_WORKER_TIMESTAMP` and a
        report counter :const:`~syne_tune.constants.ST_WORKER_ITER` are added.
        See :attr:`add_time`, :attr:`add_cost` comments for other
        automatically added metrics.

        :param kwargs: Keyword arguments for metrics to be reported, for
            instance :code:`report(epoch=1, loss=1.2)`. Values must be
            serializable with json, keys should not start with ``st_`` which
            is a reserved namespace for Syne Tune internals.
        :raises AssertionError: If a value is ``None`` or a key starts with
            ``st_``
        """
        self._check_reported_values(kwargs)
        # Raised explicitly (instead of ``assert``) so that the check is not
        # stripped when Python runs with ``-O``.
        if any(key.startswith("st_") for key in kwargs):
            raise AssertionError(
                "The metric prefix 'st_' is used by Syne Tune internals, "
                "please use a metric name that does not start with 'st_'."
            )
        kwargs[ST_WORKER_TIMESTAMP] = time()
        if self.add_time:
            seconds_spent = perf_counter() - self.start
            kwargs[ST_WORKER_TIME] = seconds_spent
            # second cost will only be there if we were able to properly detect the instance-type and instance-count
            # from the environment
            if hasattr(self, "dollar_cost"):
                kwargs[ST_WORKER_COST] = seconds_spent * self.dollar_cost
        kwargs[ST_WORKER_ITER] = self.iter
        self.iter += 1
        _report_logger(**kwargs)

    @staticmethod
    def _check_reported_values(kwargs: Dict[str, Any]) -> None:
        """Reject reports which contain a ``None`` value.

        :param kwargs: Metrics passed to :meth:`__call__`
        :raises AssertionError: If any value in ``kwargs`` is ``None``
        """
        if any(v is None for v in kwargs.values()):
            raise AssertionError(f"Invalid value in report: kwargs = {kwargs}")
def _report_logger(**kwargs):
    """Print the metrics as one tagged line on stdout and flush immediately.

    The line has the form ``[<ST_SAGEMAKER_METRIC_TAG>]: <json>``, which is
    the format :func:`retrieve` scans for.

    :param kwargs: Metrics to be reported
    """
    print(f"[{ST_SAGEMAKER_METRIC_TAG}]: {_serialize_report_dict(kwargs)}", flush=True)


def _serialize_report_dict(report_dict: Dict[str, Any]) -> str:
    """
    :param report_dict: a dictionary of metrics to be serialized
    :return: serialized string of the reported metrics, an exception is raised if the size is too large or
        if the dictionary values are not JSON-serializable
    :raises TypeError: If ``report_dict`` cannot be serialized
    :raises AssertionError: If the serialized string is too large
    """
    max_size = 50_000
    try:
        report_str = dump_json_with_numpy(report_dict)
    except TypeError:
        print("The dictionary set to be reported does not seem to be serializable.")
        raise
    # Explicit check instead of ``assert``, so the size limit is still
    # enforced when Python runs with ``-O``. The exception type is kept for
    # callers which catch ``AssertionError``.
    if sys.getsizeof(report_str) >= max_size:
        print("The dictionary set to be reported is too large.")
        raise AssertionError(
            f"Serialized report exceeds the size limit of {max_size}"
        )
    return report_str
def retrieve(log_lines: List[str]) -> List[Dict[str, float]]:
    """Retrieves metrics reported with :func:`_report_logger` given log lines.

    Every occurrence of ``[<ST_SAGEMAKER_METRIC_TAG>]: {...}`` in the log is
    located, and the part in braces is decoded as JSON.

    :param log_lines: Lines in log file to be scanned for metric reports
    :return: list of metrics retrieved from the log lines.
    """
    pattern = re.compile(r"\[" + ST_SAGEMAKER_METRIC_TAG + r"\]: (\{.*\})")
    log_text = "\n".join(log_lines)
    return [json.loads(found.group(1)) for found in pattern.finditer(log_text)]