# Source code for sisyphus.task

import math
import os
import logging
import sys
import time
from typing import Optional, Union, Any, Sequence, Dict, List
import subprocess as sp
from ast import literal_eval

import sisyphus.tools as tools
import sisyphus.global_settings as gs


class Task:
    """
    Object to hold information what function should be run with which requirements.
    """

    def __init__(
        self,
        start: str,
        resume: Optional[str] = None,
        rqmt: Optional[Dict[str, Any]] = None,
        args: Optional[Sequence[Union[List[Any], Any]]] = None,
        mini_task: bool = False,
        update_rqmt=None,
        parallel: int = 0,
        tries: int = 1,
        continuable: bool = False,
    ):
        """
        :param start: name of the function which will be executed on start
        :param resume: name of the function which will be executed on resume, often set equal to start
        :param rqmt: job requirements
            Might contain:
                "cpu": number of cpus
                "gpu": number of gpus
                "mem": amount of memory, in GB
                "time": amount of time, in hours
                "multi_node_slots": amount of slots, distributed potentially over multiple nodes.
                    E.g. maps to `--ntasks <multi_node_slots>` parameter for Slurm,
                    and to `-pe <pe_name> <multi_node_slots>` for SGE (SGE parallel environment (PE)).
        :param args: job arguments
        :param mini_task: will be run on engine for short jobs if True
        :param (dict[str],dict[str])->dict[str] update_rqmt: function to update job requirements for interrupted jobs
        :param parallel: if set to > 0, groups jobs for individual arguments together into the number of batches
            specified here. Will then submit at max `parallel` jobs into the engine at a time.
        :param tries: how often this task is resubmitted after failure
        :param continuable: If set to True this task will not set a finished marker, useful for tasks that can be
            continued for arbitrarily long, e.g. adding more epochs to neural network training
        """
        if rqmt is None:
            rqmt = {}
        if args is None:
            # a single call without any arguments
            args = [[]]

        self._start = start
        self._resume = resume
        # work on a copy so the caller's dict is never modified
        self._rqmt = rqmt.copy()
        if mini_task:
            self._rqmt["engine"] = "short"
        self._update_rqmt = update_rqmt if update_rqmt else gs.update_engine_rqmt
        self._args = list(args)
        # parallel == 0 means one task id per argument entry
        self._parallel = len(self._args) if parallel == 0 else parallel
        self.mini_task = mini_task
        self.reset_cache()
        self.last_state = None
        self.tries = tries
        self.continuable = continuable

    def __repr__(self):
        return "<Task %r job=%r>" % (self._start, getattr(self, "_job", None))

    def reset_cache(self):
        """Drop all cached task states and the timestamps they were cached at."""
        self._state_cache = {}
        self._state_cache_time = {}

    def set_job(self, job):
        """
        Attach this task to a job and check that the start/resume functions exist on it.

        :param sisyphus.job.Job job:
        :raises AttributeError: if the job has no attribute named like the start or resume function
        """
        self._job = job
        for name in self._start, self._resume:
            try:
                if name is not None:
                    getattr(self._job, name)
            except AttributeError:
                logging.critical("Trying to create a task with an invalid function name")
                logging.critical("Job name: %s" % str(job))
                logging.critical("Function name: %s" % str(name))
                raise

    def get_f(self, name):
        """Return the attribute `name` of the attached job."""
        return getattr(self._job, name)

    def task_ids(self):
        """
        :return: list with all valid task ids
        :rtype: list[int]
        """
        return list(range(1, self._parallel + 1))

    def rqmt(self):
        """
        Return the requirements of this task.

        "mem" and "time" entries, if present, are normalized in place through
        `tools.str_to_GB` and `tools.str_to_hours`.
        """
        if callable(self._rqmt):
            rqmt = self._rqmt()
        else:
            rqmt = self._rqmt
        # Ensure that the requested memory is a float representing GB
        if "mem" in rqmt:
            rqmt["mem"] = tools.str_to_GB(rqmt["mem"])
        if "time" in rqmt:
            rqmt["time"] = tools.str_to_hours(rqmt["time"])
        return rqmt

    def name(self):
        """Return the name of the start function, which doubles as the task name."""
        return self._start

    def resumeable(self):
        """True if a resume function was given."""
        return self._resume is not None

    @staticmethod
    def _stop_logging_thread(logging_thread, out_of_memory=False):
        """
        Stop the logging thread if there is one.

        `run` accepts `logging_thread=None`, so every access has to tolerate a missing thread.

        :param logging_thread: object with a `stop()` method, or None
        :param bool out_of_memory: if True, flag the thread as out of memory before stopping it
        """
        if logging_thread is None:
            return
        if out_of_memory:
            logging_thread.out_of_memory = True
        logging_thread.stop()

    def run(self, task_id, resume_job=False, logging_thread=None):
        """
        This function is executed to run this job.

        :param int task_id:
        :param bool resume_job:
        :param sisyphus.worker.LoggingThread logging_thread: stopped once the task ended, may be None
        """
        logging.debug("Task name: %s id: %s" % (self.name(), task_id))
        job = self._job

        logging.info("Start Job: %s Task: %s" % (job, self.name()))
        logging.info("Inputs:")
        for i in sorted(self._job._sis_inputs):
            if i.path_type == "Path":
                logging.info(i.get_path())
            else:
                logging.info("%s (Variable: %s, %s)" % (i.get_path(), str(i), type(i.get())))

            if gs.WAIT_PERIOD_FOR_INPUTS_AVAILABLE:
                # poll once per second until the input shows up or the wait period is over
                for _ in range(math.ceil(gs.WAIT_PERIOD_FOR_INPUTS_AVAILABLE)):
                    if os.path.exists(i.get_path()):
                        break
                    logging.warning("Input path does not exist, waiting: %s" % i.get_path())
                    time.sleep(1)

            # each input must be at least X seconds old
            # if an input file is too young it may not be synced in a network filesystem yet
            try:
                input_age = time.time() - os.stat(i.get_path()).st_mtime
                time.sleep(max(0, gs.WAIT_PERIOD_MTIME_OF_INPUTS - input_age))
            except FileNotFoundError:
                (logging.error if gs.TASK_INPUTS_MUST_BE_AVAILABLE else logging.warning)(
                    "Input path does not exist: %s" % i.get_path()
                )
                if gs.TASK_INPUTS_MUST_BE_AVAILABLE:
                    raise

        tools.get_system_informations(sys.stdout)
        sys.stdout.flush()

        try:
            if resume_job:
                if self._resume is not None:
                    task = self._resume
                else:
                    task = self._start
                    logging.warning(
                        "No resume function set (changed tasks after job was initialized?) "
                        "Fallback to normal start function: %s" % task
                    )
            else:
                task = self._start
            assert task is not None, "Error loading task"

            # save current directory and change into work directory
            with tools.execute_in_dir(self.path(gs.JOB_WORK_DIR)):
                f = getattr(self._job, task)

                # get job arguments
                for arg_id in self._get_arg_idx_for_task_id(task_id):
                    args = self._args[arg_id]
                    if not isinstance(args, (list, tuple)):
                        args = [args]
                    logging.info("-" * 60)
                    logging.info("Starting subtask for arg id: %d args: %s" % (arg_id, str(args)))
                    logging.info("-" * 60)
                    f(*args)
        except sp.CalledProcessError as e:
            if e.returncode == 137:
                # TODO move this into engine class
                logging.error("Command got killed by SGE (probably out of memory):")
                logging.error("Cmd: %s" % e.cmd)
                logging.error("Args: %s" % str(e.args))
                logging.error("Return-Code: %s" % e.returncode)
                self._stop_logging_thread(logging_thread, out_of_memory=True)
            else:
                logging.error("Executed command failed:")
                logging.error("Cmd: %s" % e.cmd)
                logging.error("Args: %s" % str(e.args))
                logging.error("Return-Code: %s" % e.returncode)
                self._stop_logging_thread(logging_thread)
            self.error(task_id, True)
        except Exception:
            # Job failed
            logging.error("Job failed, traceback:")
            sys.excepthook(*sys.exc_info())
            self._stop_logging_thread(logging_thread)
            self.error(task_id, True)
            # TODO handle failed job
        else:
            # Job finished normally
            self._stop_logging_thread(logging_thread)
            if not self.continuable:
                self.finished(task_id, True)
            sys.stdout.flush()
            sys.stderr.flush()
            logging.info("Job finished successfully")

    def task_name(self):
        """Return "<job id>.<task name>"."""
        return "%s.%s" % (self._job._sis_id(), self.name())

    def path(self, path_type=None, task_id=None):
        """
        Return the job path for `path_type`.

        Every path type except None, the work dir, the save file and the engine log
        gets the task name appended as suffix.
        """
        if path_type not in (None, gs.JOB_WORK_DIR, gs.JOB_SAVE, gs.JOB_LOG_ENGINE):
            path_type = "%s.%s" % (path_type, self.name())
        return self._job._sis_path(path_type, task_id)

    def check_state(self, state, task_id=None, update=None, combine=all, minimal_time_since_change=0):
        """
        :param state: name of state
        :param int|list[int]|None task_id: defaults to all task ids of this task
        :param bool|None update: if not None change state to this value
        :param combine: how states are combined, e.g. only finished if all jobs are finished => all,
            error state is true if one or more has the error flag => any
        :param minimal_time_since_change: only true if state change is at least that old
        :return: if this state is currently set or not
        :rtype: bool
        """
        if task_id is None:
            task_id = self.task_ids()

        current_state = self._job._sis_file_logging(
            state + "." + self.name(),
            task_id,
            update=update,
            combine=combine,
            minimal_file_age=minimal_time_since_change,
        )
        return current_state

    def finished(self, task_id=None, update=None) -> bool:
        """
        :param int|list[int]|None task_id:
        :param bool|None update: if not None change the finished state to this value
        :return: True if all selected task ids are marked as finished
        """
        minimal_time_since_change = 0
        if not gs.SKIP_IS_FINISHED_TIMEOUT:
            minimal_time_since_change = gs.WAIT_PERIOD_JOB_FS_SYNC + gs.WAIT_PERIOD_JOB_CLEANUP
        return bool(
            self.check_state(
                gs.STATE_FINISHED,
                task_id,
                update=update,
                combine=all,
                minimal_time_since_change=minimal_time_since_change,
            )
        )

    def error(self, task_id=None, update=None):
        """
        :param int|list[int]|None task_id:
        :param bool|None update:
        :return: true if job or task is in error state.
        :rtype: bool
        """
        if update:
            # set error flag
            self.check_state(gs.STATE_ERROR, task_id, update=update, combine=any)
            return True

        if isinstance(task_id, int):
            task_ids = [task_id]
        elif task_id is None:
            task_ids = self.task_ids()
        elif isinstance(task_id, list):
            task_ids = task_id
        else:
            raise Exception("unexpected task_id %r" % (task_id,))
        assert isinstance(task_ids, list)

        for task_id in task_ids:
            error_file = self._job._sis_path(gs.STATE_ERROR + "." + self.name(), task_id)
            error_file = os.path.realpath(error_file)
            if os.path.isfile(error_file):
                # task is in error state
                # move log file and remove error file if an unused try is left
                for i in range(1, self.tries):
                    log_file = self._job._sis_path(gs.JOB_LOG + "." + self.name(), task_id)
                    new_name = "%s.error.%02i" % (log_file, i)
                    if not os.path.isfile(new_name):
                        if os.path.isfile(log_file):
                            os.rename(log_file, new_name)
                        os.remove(error_file)
                        break

                if os.path.isfile(error_file):
                    # task is still in error state
                    return True
        return False

    def started(self, task_id=None):
        """True if job execution has started"""
        path = self.path(gs.JOB_LOG, task_id)
        state = os.path.isfile(path)
        return state

    def print_error(self, lines=0):
        """
        Log every task id in error state and print its log file.

        :param int lines: print only the last `lines` lines of the log file, everything if 0
        """
        for task_id in self.task_ids():
            if self.error(task_id):
                logging.error("Job: %s Task: %s %s" % (self._job._sis_id(), self.name(), task_id))
                logpath = self.path(gs.JOB_LOG, task_id)
                if os.path.exists(logpath):
                    with open(logpath) as log:
                        logging.error("Logfile:")
                        print()
                        if lines > 0:
                            print("".join(log.readlines()[-lines:]), end="")
                        else:
                            print(log.read())

    def state(self, engine, task_id=None, force=False):
        """
        Return the state of this task, re-computed if `force` is set or the cached value is
        at least 10 seconds old.
        """
        if force or time.time() - self._state_cache_time.get(task_id, -20) >= 10:
            state = self._get_state(engine, task_id)
            self._state_cache[task_id] = state
            self._state_cache_time[task_id] = time.time()
        return self._state_cache[task_id]

    def _get_state(self, engine, task_id=None):
        """Store return of helper as value as last state"""
        self.last_state = self._get_state_helper(engine, task_id)
        return self.last_state

    def _get_state_helper(self, engine, task_id=None):
        """Return state of this task given by external code"""
        # Handling external states
        if self.finished(task_id):
            return gs.STATE_FINISHED
        elif self.error(task_id):
            return gs.STATE_ERROR
        else:
            # Task is not finished and not in error state, time to check the engine
            if task_id is None:
                # Check all task_id of this task, return the 'worst' state
                engine_states = [self.state(engine, i) for i in self.task_ids()]
                for engine_state in (
                    gs.STATE_ERROR,
                    gs.STATE_QUEUE_ERROR,
                    gs.STATE_INTERRUPTED_RESUMABLE,
                    gs.STATE_INTERRUPTED_NOT_RESUMABLE,
                    gs.STATE_RUNNABLE,
                    gs.STATE_QUEUE,
                    gs.STATE_RUNNING,
                    gs.STATE_RETRY_ERROR,
                    gs.STATE_FINISHED,
                ):
                    if engine_state in engine_states:
                        return engine_state
                logging.critical("Could not determine state of task: %s" % str(engine_states))
                assert False  # This code point should be unreachable
            else:
                # check state for the given task id
                if engine is None:
                    engine_state = gs.STATE_UNKNOWN
                else:
                    engine_state = engine.task_state(self, task_id)
                    assert engine_state in (gs.STATE_QUEUE, gs.STATE_QUEUE_ERROR, gs.STATE_RUNNING, gs.STATE_UNKNOWN)

                    # force cache update to avoid caching problems if last state was not also UNKNOWN
                    # (kept inside the branch with a real engine, there is nothing to refresh without one)
                    if (
                        engine_state == gs.STATE_UNKNOWN
                        and self.last_state
                        and self.last_state != gs.STATE_UNKNOWN
                        and self.started(task_id)
                    ):
                        engine.reset_cache()
                        engine_state = engine.task_state(self, task_id)
                        assert engine_state in (
                            gs.STATE_QUEUE,
                            gs.STATE_QUEUE_ERROR,
                            gs.STATE_RUNNING,
                            gs.STATE_UNKNOWN,
                        )

                if engine_state == gs.STATE_UNKNOWN:
                    if self.started(task_id):
                        # check again if it finished or crashed while retrieving the state
                        if self.finished(task_id):
                            return gs.STATE_FINISHED
                        elif self.error(task_id):
                            return gs.STATE_ERROR
                        # job logging file got updated recently, assume job is still running.
                        # used to avoid wrongly marking jobs as interrupted due to slow filesystem updates
                        elif self.running(task_id):
                            return gs.STATE_RUNNING

                        history = [] if engine is None else engine.get_submit_history(self)[task_id]
                        # Filter entries with less progress than the most recent entry.
                        # https://github.com/rwth-i6/sisyphus/issues/295
                        completed_fraction = (history[-1].get("completed_fraction") or 0) if history else 0
                        history = [d for d in history if (d.get("completed_fraction") or 0) >= completed_fraction]
                        if len(history) > gs.MAX_SUBMIT_RETRIES:
                            # More than gs.MAX_SUBMIT_RETRIES tries to run this task, something is wrong
                            return gs.STATE_RETRY_ERROR
                        else:
                            # Task was started, but isn't running anymore => interrupted
                            if self._resume is None:
                                return gs.STATE_INTERRUPTED_NOT_RESUMABLE
                            else:
                                return gs.STATE_INTERRUPTED_RESUMABLE
                    else:
                        return gs.STATE_RUNNABLE
                else:
                    if engine_state == gs.STATE_RUNNING and self.running(task_id) is False:
                        # Warn if job is running but doesn't update logging file anymore
                        logging.warning(
                            "Job marked as running but logging file has not been updated: "
                            "%s assume it is running" % str(self._job)
                        )
                    return engine_state

    def running(self, task_id):
        """
        :return: None if the usage file doesn't exist,
            True if the usage file changed recently,
            otherwise whether the log file exists and changed recently
        """
        usage_file = self._job._sis_path(gs.PLOGGING_FILE + "." + self.name(), task_id, abspath=True)
        maximal_file_age = gs.WAIT_PERIOD_JOB_FS_SYNC + gs.PLOGGING_UPDATE_FILE_PERIOD + gs.WAIT_PERIOD_JOB_CLEANUP

        if not os.path.isfile(usage_file):
            return None
        if maximal_file_age > time.time() - os.path.getmtime(usage_file):
            return True

        # the usage file is stale, fall back to the age of the log file
        log_file = self._job._sis_path(gs.JOB_LOG + "." + self.name(), task_id)
        return os.path.isfile(log_file) and maximal_file_age > time.time() - os.path.getmtime(log_file)

    def _get_arg_idx_for_task_id(self, task_id):
        """
        Return the indices into the argument list that are handled by `task_id`.

        The arguments are split into `self._parallel` consecutive chunks,
        the first `len(args) % parallel` chunks hold one extra element.

        :param int task_id: 1-based task id
        :return: consecutive argument indices
        :rtype: range
        """
        assert task_id > 0, "this function assumes task_ids start at 1"
        nargs = len(self._args)
        chunk_size = nargs // self._parallel
        overflow = nargs % self._parallel
        if task_id - 1 < overflow:
            start = (chunk_size + 1) * (task_id - 1)
            return range(start, start + chunk_size + 1)
        else:
            start = (chunk_size + 1) * overflow + chunk_size * (task_id - 1 - overflow)
            return range(start, start + chunk_size)

    def update_rqmt(self, last_rqmt, task_id):
        """
        Update task requirements of interrupted job

        :param dict[str] last_rqmt: requirements of the last run, must contain "mem" and "time"
        :param int task_id:
        :return: result of the update function given at construction,
            or the normalized copy of `last_rqmt` if no valid usage file could be read
        """
        last_rqmt = last_rqmt.copy()
        # Make sure mem and time are numbers and not str
        last_rqmt["mem"] = tools.str_to_GB(last_rqmt["mem"])
        last_rqmt["time"] = tools.str_to_hours(last_rqmt["time"])

        usage_file = self._job._sis_path(gs.PLOGGING_FILE + "." + self.name(), task_id, abspath=True)
        try:
            with open(usage_file) as usage:
                last_usage = literal_eval(usage.read())
        except (SyntaxError, ValueError, TypeError, OSError):
            # we don't know anything if no usage file is written or is invalid, just reuse last rqmts
            # literal_eval reports malformed content as ValueError/TypeError as well, not only SyntaxError
            return last_rqmt
        return self._update_rqmt(last_rqmt=last_rqmt, last_usage=last_usage)

    def get_process_logging_path(self, task_id):
        """Return the absolute path of the process logging (usage) file for `task_id`."""
        return self._job._sis_path(gs.PLOGGING_FILE + "." + self.name(), task_id, abspath=True)

    def __str__(self):
        return "Task < workdir(%s) name(%s) ids(%s) >" % (
            self.path(),
            self.name(),
            ",".join(str(i) for i in self.task_ids()),
        )

    def get_worker_call(self, task_id=None):
        """
        Build the command line that starts the worker for this task.

        :param int|None task_id: appended to the call if given
        :return: command as list, passed through the worker wrapper of the job
            (or the global one if no job is attached)
        """
        if isinstance(gs.SIS_COMMAND, list):
            call = gs.SIS_COMMAND[:]
        else:
            call = gs.SIS_COMMAND.split()
        call += [gs.CMD_WORKER, os.path.relpath(self.path()), self.name()]
        if task_id is not None:
            call.append(str(task_id))
        if hasattr(self, "_job"):
            call = self._job._sis_worker_wrapper(self._job, self.name(), call)
        else:
            logging.warning(f"Task {self.name()} run without an associated Job. Using global worker_wrapper.")
            call = gs.worker_wrapper(None, self.name(), call)
        return call