Source code for sisyphus.task

import math
import os
import logging
import sys
import time
import typing
import subprocess as sp
from ast import literal_eval

import sisyphus.tools as tools
import sisyphus.global_settings as gs


class Task(object):
    """
    Object to hold the information about which function should be run with which requirements.
    """

    def __init__(
        self,
        start,
        resume=None,
        rqmt=None,
        args=None,
        mini_task=False,
        update_rqmt=None,
        parallel=0,
        tries=1,
        continuable=False,
    ):
        """
        :param str start: name of the function which will be executed on start
        :param str resume: name of the function which will be executed on resume, often set equal to start
        :param dict[str] rqmt: job requirements
            Might contain:
                "cpu": number of cpus
                "gpu": number of gpus
                "mem": amount of memory, in GB
                "time": amount of time, in hours
                "multi_node_slots": amount of slots, potentially distributed over multiple nodes.
                    E.g. maps to the `--ntasks <multi_node_slots>` parameter for Slurm,
                    and to `-pe <pe_name> <multi_node_slots>` for SGE (SGE parallel environment (PE)).
        :param typing.Sequence[typing.Union[typing.List[object],object]] args: job arguments
        :param bool mini_task: will be run on engine for short jobs if True
        :param (dict[str],dict[str])->dict[str] update_rqmt: function to update job requirements
            for interrupted jobs
        :param int parallel: if set to > 0, groups the jobs for the individual arguments together into
            the number of batches specified here. Will then submit at most `parallel` jobs into the
            engine at a time.
        :param int tries: how often this task is resubmitted after failure
        :param bool continuable: if set to True this task will not set a finished marker; useful for
            tasks that can be continued for arbitrarily long, e.g. adding more epochs to neural
            network training
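
        Example (an illustrative sketch; the ``run`` method name and all values here
        are assumptions for demonstration, not defaults of this module)::

            task = Task(
                "run",                  # must be the name of a method of the owning Job
                resume="run",
                rqmt={"cpu": 1, "mem": 4, "time": 8},
                args=[[epoch] for epoch in range(1, 11)],
                parallel=2,             # the 10 argument lists are grouped into 2 engine tasks
            )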
        """
        if rqmt is None:
            rqmt = {}
        if args is None:
            args = [[]]
        self._start = start
        self._resume = resume
        self._rqmt = rqmt.copy()
        if mini_task:
            self._rqmt["engine"] = "short"
        self._update_rqmt = update_rqmt if update_rqmt else gs.update_engine_rqmt
        self._args = list(args)
        self._parallel = len(self._args) if parallel == 0 else parallel
        self.mini_task = mini_task
        self.reset_cache()
        self.last_state = None
        self.tries = tries
        self.continuable = continuable

    def __repr__(self):
        return "<Task %r job=%r>" % (self._start, getattr(self, "_job", None))

    def reset_cache(self):
        self._state_cache = {}
        self._state_cache_time = {}

    def set_job(self, job):
        """
        :param sisyphus.job.Job job:
        """
        self._job = job
        for name in self._start, self._resume:
            try:
                if name is not None:
                    getattr(self._job, name)
            except AttributeError:
                logging.critical("Trying to create a task with an invalid function name")
                logging.critical("Job name: %s" % str(job))
                logging.critical("Function name: %s" % str(name))
                raise

    def get_f(self, name):
        return getattr(self._job, name)

    def task_ids(self):
        """
        :return: list with all valid task ids
        :rtype: list[int]
        """
        return list(range(1, self._parallel + 1))

    def rqmt(self):
        if callable(self._rqmt):
            rqmt = self._rqmt()
        else:
            rqmt = self._rqmt
        # Ensure that the requested memory is a float representing GB
        if "mem" in rqmt:
            rqmt["mem"] = tools.str_to_GB(rqmt["mem"])
        if "time" in rqmt:
            rqmt["time"] = tools.str_to_hours(rqmt["time"])
        return rqmt
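
    # A sketch of the normalization done by rqmt() above, assuming the usual string
    # formats accepted by tools.str_to_GB / tools.str_to_hours (values made up):
    #   {"cpu": 1, "mem": "4G", "time": 8}  ->  {"cpu": 1, "mem": 4.0, "time": 8.0}
    # Engine code can then rely on "mem" (GB) and "time" (hours) being numbers.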
    def name(self):
        return self._start

    def resumeable(self):
        return self._resume is not None

    def run(self, task_id, resume_job=False, logging_thread=None):
        """
        This function is executed to run this job.

        :param int task_id:
        :param bool resume_job:
        :param sisyphus.worker.LoggingThread logging_thread:
        """
        logging.debug("Task name: %s id: %s" % (self.name(), task_id))
        job = self._job
        logging.info("Start Job: %s Task: %s" % (job, self.name()))
        logging.info("Inputs:")
        for i in sorted(self._job._sis_inputs):
            if i.path_type == "Path":
                logging.info(i.get_path())
            else:
                logging.info("%s (Variable: %s, %s)" % (i.get_path(), str(i), type(i.get())))

            if gs.WAIT_PERIOD_FOR_INPUTS_AVAILABLE:
                for _ in range(math.ceil(gs.WAIT_PERIOD_FOR_INPUTS_AVAILABLE)):
                    if os.path.exists(i.get_path()):
                        break
                    logging.warning("Input path does not exist, waiting: %s" % i.get_path())
                    time.sleep(1)

            # Each input must be at least X seconds old.
            # If an input file is too young, it may not be synced in a network filesystem yet.
            try:
                input_age = time.time() - os.stat(i.get_path()).st_mtime
                time.sleep(max(0, gs.WAIT_PERIOD_MTIME_OF_INPUTS - input_age))
            except FileNotFoundError:
                (logging.error if gs.TASK_INPUTS_MUST_BE_AVAILABLE else logging.warning)(
                    "Input path does not exist: %s" % i.get_path()
                )
                if gs.TASK_INPUTS_MUST_BE_AVAILABLE:
                    raise

        tools.get_system_informations(sys.stdout)
        sys.stdout.flush()

        try:
            if resume_job:
                if self._resume is not None:
                    task = self._resume
                else:
                    task = self._start
                    logging.warning(
                        "No resume function set (changed tasks after job was initialized?) "
                        "Fallback to normal start function: %s" % task
                    )
            else:
                task = self._start
            assert task is not None, "Error loading task"

            # save current directory and change into work directory
            with tools.execute_in_dir(self.path(gs.JOB_WORK_DIR)):
                f = getattr(self._job, task)
                # get job arguments
                for arg_id in self._get_arg_idx_for_task_id(task_id):
                    args = self._args[arg_id]
                    if not isinstance(args, (list, tuple)):
                        args = [args]
                    logging.info("-" * 60)
                    logging.info("Starting subtask for arg id: %d args: %s" % (arg_id, str(args)))
                    logging.info("-" * 60)
                    f(*args)
        except sp.CalledProcessError as e:
            if e.returncode == 137:
                # TODO move this into engine class
                logging.error("Command got killed by SGE (probably out of memory):")
                logging.error("Cmd: %s" % e.cmd)
                logging.error("Args: %s" % str(e.args))
                logging.error("Return-Code: %s" % e.returncode)
                logging_thread.out_of_memory = True
                logging_thread.stop()
            else:
                logging.error("Executed command failed:")
                logging.error("Cmd: %s" % e.cmd)
                logging.error("Args: %s" % str(e.args))
                logging.error("Return-Code: %s" % e.returncode)
                logging_thread.stop()
            self.error(task_id, True)
        except Exception:
            # Job failed
            logging.error("Job failed, traceback:")
            sys.excepthook(*sys.exc_info())
            logging_thread.stop()
            self.error(task_id, True)
            # TODO handle failed job
        else:
            # Job finished normally
            logging_thread.stop()
            if not self.continuable:
                self.finished(task_id, True)
            sys.stdout.flush()
            sys.stderr.flush()
            logging.info("Job finished successfully")

    def task_name(self):
        return "%s.%s" % (self._job._sis_id(), self.name())

    def path(self, path_type=None, task_id=None):
        if path_type not in (None, gs.JOB_WORK_DIR, gs.JOB_SAVE, gs.JOB_LOG_ENGINE):
            path_type = "%s.%s" % (path_type, self.name())
        return self._job._sis_path(path_type, task_id)

    def check_state(self, state, task_id=None, update=None, combine=all, minimal_time_since_change=0):
        """
        :param state: name of state
        :param int|list[int]|None task_id:
        :param bool|None update: if not None change state to this value
        :param combine: how the per-task states are combined, e.g. the task is only finished
            if all its jobs are finished => all; the error state is true if one or more jobs
            have the error flag => any
        :param minimal_time_since_change: only true if the state change is at least that old
        :return: whether this state is currently set
        :rtype: bool
        """
        if task_id is None:
            task_id = self.task_ids()
        current_state = self._job._sis_file_logging(
            state + "." + self.name(),
            task_id,
            update=update,
            combine=combine,
            minimal_file_age=minimal_time_since_change,
        )
        return current_state
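
    # A sketch of the combine parameter, assuming marker files named
    # <state>.<task name>.<task id> inside the job directory (the naming is an
    # assumption for illustration):
    #   combine=all: finished.run.1 exists, finished.run.2 missing -> not finished yet
    #   combine=any: a single error.run.3 file is enough to flag the whole task as failed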
    def finished(self, task_id=None, update=None) -> bool:
        minimal_time_since_change = 0
        if not gs.SKIP_IS_FINISHED_TIMEOUT:
            minimal_time_since_change = gs.WAIT_PERIOD_JOB_FS_SYNC + gs.WAIT_PERIOD_JOB_CLEANUP
        if self.check_state(
            gs.STATE_FINISHED, task_id, update=update, combine=all, minimal_time_since_change=minimal_time_since_change
        ):
            return True
        else:
            return False

    def error(self, task_id=None, update=None):
        """
        :param int|list[int]|None task_id:
        :param bool|None update:
        :return: true if the job or task is in an error state
        :rtype: bool
        """
        if update:
            # set error flag
            self.check_state(gs.STATE_ERROR, task_id, update=update, combine=any)
            return True

        if isinstance(task_id, int):
            task_ids = [task_id]
        elif task_id is None:
            task_ids = self.task_ids()
        elif isinstance(task_id, list):
            task_ids = task_id
        else:
            raise Exception("unexpected task_id %r" % (task_id,))
        assert isinstance(task_ids, list)

        for task_id in task_ids:
            error_file = self._job._sis_path(gs.STATE_ERROR + "." + self.name(), task_id)
            error_file = os.path.realpath(error_file)
            if os.path.isfile(error_file):
                # task is in error state,
                # move log file and remove error file if an unused try is left
                for i in range(1, self.tries):
                    log_file = self._job._sis_path(gs.JOB_LOG + "." + self.name(), task_id)
                    new_name = "%s.error.%02i" % (log_file, i)
                    if not os.path.isfile(new_name):
                        if os.path.isfile(log_file):
                            os.rename(log_file, new_name)
                        os.remove(error_file)
                        break
            if os.path.isfile(error_file):
                # task is still in error state
                return True
        return False

    def started(self, task_id=None):
        """True if job execution has started"""
        path = self.path(gs.JOB_LOG, task_id)
        state = os.path.isfile(path)
        return state

    def print_error(self, lines=0):
        for task_id in self.task_ids():
            if self.error(task_id):
                logging.error("Job: %s Task: %s %s" % (self._job._sis_id(), self.name(), task_id))
                logpath = self.path(gs.JOB_LOG, task_id)
                if os.path.exists(logpath):
                    with open(logpath) as log:
                        logging.error("Logfile:")
                        print()
                        if lines > 0:
                            print("".join(log.readlines()[-lines:]), end="")
                        else:
                            print(log.read())

    def state(self, engine, task_id=None, force=False):
        if force or time.time() - self._state_cache_time.get(task_id, -20) >= 10:
            state = self._get_state(engine, task_id)
            self._state_cache[task_id] = state
            self._state_cache_time[task_id] = time.time()
        return self._state_cache[task_id]

    def _get_state(self, engine, task_id=None):
        """Store return of helper as last state"""
        self.last_state = self._get_state_helper(engine, task_id)
        return self.last_state

    def _get_state_helper(self, engine, task_id=None):
        """Return state of this task given by external code"""
        # Handling external states
        if self.finished(task_id):
            return gs.STATE_FINISHED
        elif self.error(task_id):
            return gs.STATE_ERROR
        else:
            # Task is not finished and not in error state, time to check the engine
            if task_id is None:
                # Check all task ids of this task, return the 'worst' state
                engine_states = [self.state(engine, i) for i in self.task_ids()]
                for engine_state in (
                    gs.STATE_ERROR,
                    gs.STATE_QUEUE_ERROR,
                    gs.STATE_INTERRUPTED_RESUMABLE,
                    gs.STATE_INTERRUPTED_NOT_RESUMABLE,
                    gs.STATE_RUNNABLE,
                    gs.STATE_QUEUE,
                    gs.STATE_RUNNING,
                    gs.STATE_RETRY_ERROR,
                    gs.STATE_FINISHED,
                ):
                    if engine_state in engine_states:
                        return engine_state
                logging.critical("Could not determine state of task: %s" % str(engine_states))
                assert False  # This code point should be unreachable
            else:
                # check state for the given task id
                if engine is None:
                    engine_state = gs.STATE_UNKNOWN
                else:
                    engine_state = engine.task_state(self, task_id)
                    assert engine_state in (gs.STATE_QUEUE, gs.STATE_QUEUE_ERROR, gs.STATE_RUNNING, gs.STATE_UNKNOWN)

                # force cache update to avoid caching problems if the last state was not also UNKNOWN
                if (
                    engine_state == gs.STATE_UNKNOWN
                    and self.last_state
                    and self.last_state != gs.STATE_UNKNOWN
                    and self.started(task_id)
                ):
                    engine.reset_cache()
                    engine_state = engine.task_state(self, task_id)
                    assert engine_state in (
                        gs.STATE_QUEUE,
                        gs.STATE_QUEUE_ERROR,
                        gs.STATE_RUNNING,
                        gs.STATE_UNKNOWN,
                    )

                if engine_state == gs.STATE_UNKNOWN:
                    if self.started(task_id):
                        # check again if it finished or crashed while retrieving the state
                        if self.finished(task_id):
                            return gs.STATE_FINISHED
                        elif self.error(task_id):
                            return gs.STATE_ERROR
                        # job logging file got updated recently, assume job is still running;
                        # used to avoid wrongly marking jobs as interrupted due to slow filesystem updates
                        elif self.running(task_id):
                            return gs.STATE_RUNNING

                        history = [] if engine is None else engine.get_submit_history(self)
                        if history and len(history[task_id]) > gs.MAX_SUBMIT_RETRIES:
                            # more than gs.MAX_SUBMIT_RETRIES tries to run this task, something is wrong
                            return gs.STATE_RETRY_ERROR
                        else:
                            # Task was started, but isn't running anymore => interrupted
                            if self._resume is None:
                                return gs.STATE_INTERRUPTED_NOT_RESUMABLE
                            else:
                                return gs.STATE_INTERRUPTED_RESUMABLE
                    else:
                        return gs.STATE_RUNNABLE
                else:
                    if engine_state == gs.STATE_RUNNING and self.running(task_id) is False:
                        # Warn if the job is marked as running but doesn't update its logging file anymore
                        logging.warning(
                            "Job marked as running but logging file has not been updated: "
                            "%s assume it is running" % str(self._job)
                        )
                    return engine_state

    def running(self, task_id):
        """
        :return: True if the usage file changed recently, None if the usage file doesn't exist,
            False otherwise
        """
        usage_file = self._job._sis_path(gs.PLOGGING_FILE + "." + self.name(), task_id, abspath=True)
        maximal_file_age = gs.WAIT_PERIOD_JOB_FS_SYNC + gs.PLOGGING_UPDATE_FILE_PERIOD + gs.WAIT_PERIOD_JOB_CLEANUP
        if not os.path.isfile(usage_file):
            return None
        if maximal_file_age > time.time() - os.path.getmtime(usage_file):
            return True
        log_file = self._job._sis_path(gs.JOB_LOG + "." + self.name(), task_id)
        return os.path.isfile(log_file) and maximal_file_age > time.time() - os.path.getmtime(log_file)

    def _get_arg_idx_for_task_id(self, task_id):
        """
        :param int task_id:
        :rtype: list[int]
        """
        assert task_id > 0, "this function assumes task_ids start at 1"
        nargs = len(self._args)
        chunk_size = nargs // self._parallel
        overflow = nargs % self._parallel
        if task_id - 1 < overflow:
            start = (chunk_size + 1) * (task_id - 1)
            return range(start, start + chunk_size + 1)
        else:
            start = (chunk_size + 1) * overflow + chunk_size * (task_id - 1 - overflow)
            return range(start, start + chunk_size)
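
    # Worked example for _get_arg_idx_for_task_id (values chosen for illustration):
    # with len(self._args) == 5 and self._parallel == 2:
    #   chunk_size = 2, overflow = 1
    #   task_id 1 -> range(0, 3)   (arg ids 0, 1, 2)
    #   task_id 2 -> range(3, 5)   (arg ids 3, 4)
    # i.e. the first `overflow` task ids get one extra argument list each.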
    def update_rqmt(self, last_rqmt, task_id):
        """Update the task requirements of an interrupted job"""
        last_rqmt = last_rqmt.copy()
        # Make sure mem and time are numbers and not str
        last_rqmt["mem"] = tools.str_to_GB(last_rqmt["mem"])
        last_rqmt["time"] = tools.str_to_hours(last_rqmt["time"])
        usage_file = self._job._sis_path(gs.PLOGGING_FILE + "." + self.name(), task_id, abspath=True)
        try:
            last_usage = literal_eval(open(usage_file).read())
        except (SyntaxError, IOError):
            # we don't know anything if no usage file is written or it is invalid, just reuse the last rqmt
            return last_rqmt
        return self._update_rqmt(last_rqmt=last_rqmt, last_usage=last_usage)

    def get_process_logging_path(self, task_id):
        return self._job._sis_path(gs.PLOGGING_FILE + "." + self.name(), task_id, abspath=True)

    def __str__(self):
        return "Task < workdir(%s) name(%s) ids(%s) >" % (
            self.path(),
            self.name(),
            ",".join(str(i) for i in self.task_ids()),
        )

    def get_worker_call(self, task_id=None):
        if isinstance(gs.SIS_COMMAND, list):
            call = gs.SIS_COMMAND[:]
        else:
            call = gs.SIS_COMMAND.split()
        call += [gs.CMD_WORKER, os.path.relpath(self.path()), self.name()]
        if task_id is not None:
            call.append(str(task_id))
        if hasattr(self, "_job"):
            call = self._job._sis_worker_wrapper(self._job, self.name(), call)
        else:
            logging.warning(f"Task {self.name()} run without an associated Job. Using global worker_wrapper.")
            call = gs.worker_wrapper(None, self.name(), call)
        return call
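
# A sketch of what get_worker_call() returns -- the concrete entries depend on
# gs.SIS_COMMAND, gs.CMD_WORKER and the job directory layout, so the values
# below are assumptions for illustration:
#   task.get_worker_call(task_id=1)
#   -> ["python3", "./sis", "worker", "work/path/to/JobName.HASH", "run", "1"]
# (the SIS_COMMAND entries, then CMD_WORKER, the relative job path, the task
#  name, and finally the task id as a string)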