# Source code for sisyphus.graph

from sisyphus.tools import cache_result, extract_paths
import sisyphus.global_settings as gs
from sisyphus.job import Job
from sisyphus.job_path import AbstractPath
from sisyphus.block import Block
import sisyphus.tools as tools
import sisyphus.hash

import atexit
from inspect import isclass
import logging
import collections
import os
import time
import pprint
import threading
from typing import DefaultDict, Optional, List
from multiprocessing.pool import ThreadPool
from datetime import datetime


class Node(object):
    """Lightweight graph node used by ``SISGraph.jobs_sorted``.

    Stores a job id, the job object itself (filled in later) and the ids of
    the neighbouring nodes in both directions. ``__slots__`` keeps instances
    free of a per-instance ``__dict__``.
    """

    __slots__ = ["job", "sis_id", "inputs", "outputs"]

    def __init__(self, sis_id):
        """
        :param sis_id: identifier of the job this node stands for
        """
        self.sis_id = sis_id
        # The job object is attached after construction
        self.job = None
        # Ids of nodes this node depends on / ids of nodes depending on this node
        self.inputs, self.outputs = set(), set()


class OutputTarget:
    """Base class for everything the graph can be asked to produce.

    A target keeps the collection of paths that still have to become
    available (``_required``) and the sorted list of all paths it was
    created with (``required_full_list``).
    """

    def __init__(self, name, inputs):
        """
        :param str name:
        :param inputs: object handed to ``extract_paths`` to collect the required paths
        """
        self._required = extract_paths(inputs)
        self.required_full_list = sorted(self._required)
        self.name = name

    def update_requirements(self, write_output=True, force=False):
        """Drop all paths that are available by now from the set of open requirements.

        :param bool write_output: unused here, accepted so subclasses share one signature
        :param bool force: unused here, accepted so subclasses share one signature
        """
        # ``None`` marks a target that never finishes (see ``is_done``), nothing to filter
        if self._required is None:
            return
        self._required = {out for out in self._required if not out.available()}

    @property
    def required(self):
        """Paths that are still missing (or ``None`` for a target that never finishes)."""
        return self._required

    def is_done(self):
        """
        :return: True if no requirement is left open, False otherwise
        :rtype: bool
        """
        if self._required is None:
            return False
        return len(self._required) == 0

    def run_when_done(self, write_output=None):
        """Hook for subclasses, called once the target is done. Does nothing here."""
        pass

    def __fs_like__(self):
        """Return the single required path, or the full list if there is not exactly one."""
        if len(self.required_full_list) == 1:
            return self.required_full_list[0]
        return self.required_full_list

    def __eq__(self, other):
        # Targets are equal if they are of the exact same class and carry the same state
        return type(self) is type(other) and self.__dict__ == other.__dict__

    def __hash__(self):
        return sisyphus.hash.int_hash(self)


class OutputPath(OutputTarget):
    """Target that links a finished path into the output directory."""

    def __init__(self, output_path, sis_path):
        """
        :param output_path: location of the link, relative to ``gs.OUTPUT_DIR`` (and the alias prefix)
        :param AbstractPath sis_path: path the link should point to
        """
        assert isinstance(sis_path, AbstractPath)
        self._output_path = output_path
        self._sis_path = sis_path
        super().__init__(output_path, sis_path)

    def run_when_done(self, write_output=None):
        """Create the output link(s) for the finished path.

        One link is created per alias prefix of the creating job, or a single
        unprefixed link if there is no creator or it has no alias prefixes.
        Failures to update a link are logged as warnings and do not abort the
        remaining prefixes.

        :param write_output: links are only touched if this is true
        """
        assert self._sis_path.available()
        if not write_output:
            return

        creator = self._sis_path.creator
        if creator is None or len(creator._sis_alias_prefixes) == 0:
            prefixes = [""]
        else:
            prefixes = list(creator._sis_alias_prefixes)

        # Link output file
        for prefix in prefixes:
            outfile_name = os.path.join(gs.OUTPUT_DIR, prefix, self._output_path)
            outfile_dir = os.path.dirname(outfile_name)

            # A plain file sitting where the directory has to be gets removed.
            # If nothing is there or it is a directory an OSError is raised and ignored.
            try:
                os.unlink(outfile_dir)
                logging.warning("Removed file from output directory: %s" % outfile_dir)
            except OSError:
                pass

            try:
                # Create directory if it does not exist yet
                os.makedirs(outfile_dir, exist_ok=True)

                # Check if current link is correct
                if os.path.islink(outfile_name):
                    if os.path.realpath(outfile_name) == os.path.realpath(self._sis_path.get_path()):
                        # This link is fine, but the other prefixes still need to be checked,
                        # so move on to the next one instead of leaving the method.
                        continue
                    os.unlink(outfile_name)

                # Set new link if needed
                os.symlink(os.path.realpath(self._sis_path.get_path()), outfile_name)
                logging.info("Finished output: %s" % outfile_name)

                if gs.FINISHED_LOG:
                    if "/" in gs.FINISHED_LOG and not os.path.exists(os.path.dirname(gs.FINISHED_LOG)):
                        os.makedirs(os.path.dirname(gs.FINISHED_LOG))
                    with open(gs.FINISHED_LOG, "a") as f:
                        f.write(
                            datetime.now().strftime("%Y-%m-%d %H:%M:%S: ") + f"Finished output: {outfile_name}\n"
                        )

            except OSError as e:
                logging.warning("Failed to updated output %s. Exception: %s" % (outfile_name, e))


class OutputCall(OutputTarget):
    """Target that invokes a callback once, after all of its requirements are available."""

    def __init__(self, f, argv, kwargs, required=None):
        """
        :param f: function to call
        :param argv: positional arguments passed to ``f``
        :param kwargs: keyword arguments passed to ``f``
        :param required: ``False`` for a call that never finishes, ``None`` to derive the
            requirements from ``argv`` and ``kwargs``, anything else is used as given
        """
        self._function_call = (f, argv, kwargs)

        if required is False:
            # This call finishes never
            requirements = None
        elif required is None:
            # Extract requirements automatically
            requirements = (argv, kwargs)
        else:
            # Use given requirements
            requirements = required

        callback_name = "callback_%s_%i_%s_%s" % (f.__name__, id(f), gs.SIS_HASH(argv), gs.SIS_HASH(kwargs))
        self._already_called = False
        super().__init__(callback_name, requirements)

    def run_when_done(self, write_output=None):
        """Runs given function if output is available, at most once per target"""
        assert all(out.available() for out in self._required)
        if self._already_called:
            return
        func, args, kwargs = self._function_call
        func(*args, **kwargs)
        self._already_called = True


class OutputReport(OutputTarget):
    """Target that periodically writes a textual report into the output directory."""

    def __init__(self, output_path, report_values, report_template=None, required=None, update_frequency=300):
        """
        :param output_path: report location relative to ``gs.OUTPUT_DIR``, ``None`` for an anonymous report
        :param report_values: values to report, may also be a callable producing the report
        :param report_template: optional format string, filled with ``report_values`` as keyword arguments
        :param required: explicit requirements, defaults to ``report_values``
        :param update_frequency: minimal time between two non-forced updates, compared against ``time.time()``
        """
        requirements = report_values if required is None else required
        super().__init__(output_path, requirements)
        self._report_template = report_template
        self._report_values = report_values
        self._output_path = output_path
        self._update_frequency = update_frequency
        self._last_update = -update_frequency
        self._last_values = None

    def update_values(self, report_values):
        """Replace the reported values (unless ``None`` is given) and recompute the requirements from them."""
        if report_values is not None:
            self._report_values = report_values
        self._required = extract_paths(self._report_values)
        self.required_full_list = sorted(list(self._required))

    def update_requirements(self, write_output=True, force=False):
        """Update current report if enough time has passed since last update"""
        too_early = time.time() - self._last_update < self._update_frequency
        if too_early and not force:
            return
        self._last_update = time.time()

        if write_output:
            self.write_report()

    def write_report(self):
        """Write the report file. IO errors are logged as warning instead of being raised."""
        # Allow for anonymous reports
        if self._output_path is None:
            return

        outfile_name = os.path.join(gs.OUTPUT_DIR, self._output_path)
        outfile_dir = os.path.dirname(outfile_name)
        try:
            if not os.path.isdir(outfile_dir):
                os.makedirs(outfile_dir)

            # Remove link to avoid overwriting other files
            if os.path.islink(outfile_name):
                os.unlink(outfile_name)

            # Actually write report
            with open(outfile_name, "w") as report_file:
                if self._report_template:
                    content = self._report_template.format(**self._report_values)
                elif callable(self._report_values):
                    content = str(self._report_values())
                else:
                    content = pprint.pformat(self._report_values, width=140) + "\n"
                report_file.write(content)
        except IOError as e:
            logging.warning("Error while updating %s:  %s" % (outfile_name, str(e)))

    def run_when_done(self, write_output=None):
        """Write the report a final time if ``write_output`` is true."""
        if not write_output:
            return
        self.write_report()

    def __fs_like__(self):
        """Fall back to the base class representation for callables, otherwise describe the report settings."""
        if callable(self._report_values):
            return super().__fs_like__()
        return {
            "template": self._report_template,
            "values": self._report_values,
            "frequency": self._update_frequency,
        }


class SISGraph(object):
    """This graph contains all targets that need to be calculated and, through their dependencies,
    all required jobs. These jobs can be searched and modified using the provided functions.
    Most interesting functions are::

        # Lists all jobs
        jobs()
        # Find jobs by matching substring
        find(pattern)
        # Execute function for all nodes
        for_all_nodes(f)
        # Dictionaries with jobs sorted by current status:
        get_jobs_by_status()
    """

    def __init__(self):
        self._targets = set()  # type: set[OutputTarget]
        self._active_targets = []  # type: list[OutputTarget]
        # Thread pool is created lazily, see ``pool``
        self._pool = None
        # Output paths (including alias prefix) claimed by the targets added so far
        self.used_output_path = set()

    @property
    def pool(self):
        """Thread pool with ``gs.GRAPH_WORKER`` workers, created on first access and closed at exit."""
        if self._pool is None:
            self._pool = ThreadPool(gs.GRAPH_WORKER)
            atexit.register(self._pool.close)
        return self._pool

    @property
    def targets(self):
        """Set with all targets added to this graph."""
        return self._targets

    @property
    def active_targets(self):
        """List with the targets that were not done when added and have not been removed since."""
        return self._active_targets

    def remove_from_active_targets(self, target):
        """Remove every entry equal to ``target`` from the list of active targets.

        :param OutputTarget target:
        """
        self._active_targets = [out for out in self._active_targets if out != target]

    @property
    def targets_dict(self):
        """
        :return: dict name -> target
        :rtype: dict[str,OutputTarget]
        """
        return {t.name: t for t in self._targets}

    @property
    def output(self):
        """Deprecated: used for backwards compatibility, only supports path outputs"""
        out = {}
        for t in self._targets:
            if len(t.required_full_list) == 1:
                out[t.name] = t.required_full_list[0]
            else:
                # Targets without exactly one path get one numbered entry per path
                for pos, path in enumerate(t.required_full_list):
                    out["%s_%02i" % (t.name, pos)] = path
        return out

    def add_target(self, target):
        """Register a target, warn if its output path is already taken by another target.

        :param OutputTarget target:
        """
        # Avoid adding the same target multiple times
        if target in self._targets:
            return
        self._targets.add(target)

        # check if output path is already used
        try:
            output_path = target._output_path
            # The creator has to be taken from the target, the graph itself has no _sis_path
            creator = target._sis_path.creator
            if creator is None or len(creator._sis_alias_prefixes) == 0:
                prefixes = [""]
            else:
                prefixes = list(creator._sis_alias_prefixes)
        except AttributeError:
            # Targets that are not linked into the output directory are not checked
            prefixes = []

        for prefix in prefixes:
            # Same combination of prefix and path as used when the output link is created
            prefixed_path = os.path.join(prefix, output_path)
            if prefixed_path in self.used_output_path:
                logging.warning("Output path is used more than once: %s" % output_path)
            self.used_output_path.add(prefixed_path)

        if not target.is_done():
            self._active_targets.append(target)

    def update_nodes(self):
        """Update all nodes to get the most current dependency graph"""
        start = time.time()

        def update_node(job):
            job._sis_runnable()
            return True

        self.for_all_nodes(update_node)
        logging.debug("All graph nodes updated (time needed: %.2f)" % (time.time() - start))

    @cache_result(gs.FILESYSTEM_CACHE_TIME)
    def id_to_job_dict(self):
        """
        :return: dict job id -> job for all jobs in the graph
        """
        return {job._sis_id(): job for job in self.jobs()}

    def __contains__(self, item):
        assert isinstance(item, Job)
        return item._sis_id() in self.id_to_job_dict()

    @cache_result(gs.FILESYSTEM_CACHE_TIME)
    def job_directory_structure(self):
        """
        :return: nested dicts following the "/" separated parts of each job id, with the job as leaf
        :rtype: dict
        """
        d = {}
        for job in self.jobs():
            current = d
            path = job._sis_id().split("/")
            # Walk down, creating intermediate levels on demand
            for step in path[:-1]:
                if step not in current:
                    current[step] = {}
                current = current[step]
            current[path[-1]] = job
        return d

    def job_by_id(self, sis_id):
        """
        :param sis_id: id of the requested job
        :return: the job with the given id or None if there is no such job
        """
        return self.id_to_job_dict().get(sis_id)

    def jobs(self):
        """
        :return ([Job, ...]): List with all jobs in graph
        """
        job_list = []

        def collect(job):
            job_list.append(job)
            return True

        self.for_all_nodes(collect)
        return job_list

    def find(self, pattern, mode="all"):
        """Returns a list with all jobs and paths that partly match the pattern

        :param pattern(str): Pattern to match
        :param mode(str): Select if jobs, paths or both should be returned. Possible values: all, path, job
        :return ([Job/Path, ...]): List with all matching jobs/paths
        """
        out = set()
        for j in self.jobs():
            if mode in ("all", "job"):
                vis_name = j.get_vis_name()
                aliases = j._sis_aliases if j._sis_aliases is not None else set()
                if (
                    pattern in j._sis_path()
                    or (vis_name is not None and pattern in vis_name)
                    or any(pattern in a for a in aliases)
                ):
                    out.add(j)
            if mode in ("all", "path"):
                for p in j._sis_inputs:
                    if pattern in str(p):
                        out.add(p)
        return list(out)

    def jobs_sorted(self):
        """Yields jobs in an order so that each job is yielded after all jobs it depends on

        :return (generator Job): jobs sorted by dependency
        """
        id_to_job = {}

        def get_job(sis_id):
            # Nodes may be requested as dependency before the job itself was visited
            if sis_id not in id_to_job:
                id_to_job[sis_id] = Node(sis_id)
            return id_to_job[sis_id]

        stack = []
        for job in self.jobs():
            node = get_job(job._sis_id())
            node.job = job

            for i in job._sis_inputs:
                if i.creator:
                    node.inputs.add(i.creator._sis_id())
                    get_job(i.creator._sis_id()).outputs.add(job._sis_id())

            id_to_job[node.sis_id] = node
            # Jobs without any job as input are the starting points
            if not node.inputs:
                stack.append(node)

        stack.sort(key=lambda n: n.sis_id)

        def recursive_depth(node):
            yield node.job
            for sis_id in sorted(list(node.outputs)):
                next_node = id_to_job[sis_id]
                # A node is released as soon as its last open input was yielded
                next_node.inputs.remove(node.sis_id)
                if not next_node.inputs:
                    for i in recursive_depth(next_node):
                        yield i

        for node in stack:
            for i in recursive_depth(node):
                yield i

    def get_jobs_by_status(
        self, nodes: Optional[List] = None, engine: Optional = None, skip_finished: bool = False
    ) -> DefaultDict[str, List[Job]]:
        """Return all jobs needed to finish output in dictionary with current status as key

        Input paths without a creator are listed as well, under ``gs.STATE_INPUT_PATH`` or
        ``gs.STATE_INPUT_MISSING``. The values of the returned dictionary are sets.

        :param nodes: all nodes that will be checked, defaults to all output nodes in graph
        :param sisyphus.engine.EngineBase engine: passed on to ``task.state`` when the state of
            an unfinished task is requested (default: None)
        :param bool skip_finished: Stop checking subtrees of finished nodes to save time
        :return: Dictionary with all jobs sorted by current state
        :rtype: dict[str,set[Job]]
        """
        states = collections.defaultdict(set)
        # The callback may be executed by several worker threads at once
        lock = threading.Lock()

        def get_unfinished_jobs(job):
            """
            Sort the job into ``states``.

            :param Job job:
            :return: False if the inputs of this job should not be visited
            :rtype: bool
            """
            # job not visited in this run, need to calculate dependencies
            if skip_finished and job._sis_finished():
                # stop on finished
                return False

            new_state = None
            if job._sis_runnable():
                if job._sis_setup():
                    if job._sis_is_set_to_hold():
                        new_state = gs.STATE_HOLD
                    elif job._sis_finished():
                        new_state = gs.STATE_FINISHED
                    else:
                        # check state of tasks, the first unfinished task decides
                        for task in job._sis_tasks():
                            if not task.finished():
                                new_state = task.state(engine)
                                break
                        # Job finished since previous check
                        if new_state is None:
                            # Stop here
                            if skip_finished:
                                return False
                            else:
                                new_state = gs.STATE_FINISHED
                else:
                    if job._sis_is_set_to_hold():
                        new_state = gs.STATE_HOLD
                    else:
                        new_state = gs.STATE_RUNNABLE
            else:
                new_state = gs.STATE_WAITING

            # List input paths
            for i in job._sis_inputs:
                if i.creator is None:
                    path_state = gs.STATE_INPUT_PATH if i.available() else gs.STATE_INPUT_MISSING
                    with lock:
                        states[path_state].add(i.get_path())

            assert new_state is not None
            with lock:
                states[new_state].add(job)
            return True

        self.for_all_nodes(get_unfinished_jobs, nodes=nodes)
        return states

    def for_all_nodes(self, f, nodes=None, bottom_up=False):
        """
        Run function f for each node and ancestor for `nodes` from top down,
        stop expanding tree branch if functions returns False. Does not stop on None to allow functions with no
        return value to run for every node.

        If ``gs.GRAPH_WORKER`` is 1 everything runs in the calling thread, otherwise the work is
        handed to the thread pool of this graph.

        :param (Job)->bool f: function will be executed for all nodes
        :param nodes: all nodes that will be checked, defaults to all output nodes in graph
        :param bool bottom_up: ignore return value of f; in the single threaded case f is called in
            reversed visiting order after all nodes were collected
        :return: set with all visited nodes, holding ``id(job)`` in the single threaded case
            and the job ids otherwise
        """
        # fill nodes with all output nodes if none are given
        if nodes is None:
            nodes = []
            for target in self._targets:
                for path in target.required_full_list:
                    if path.creator:
                        nodes.append(path.creator)

        if gs.GRAPH_WORKER == 1:
            visited_set = set()
            visited_list = []
            queue = list(reversed(nodes))
            while queue:
                job = queue.pop(-1)
                if id(job) in visited_set:
                    continue
                visited_set.add(id(job))
                job._sis_runnable()
                if bottom_up:
                    # execute in reverse order at the end
                    visited_list.append(job)
                else:
                    res = f(job)
                    # Stop if function has a not None but false return value
                    if res is not None and not res:
                        continue
                for path in job._sis_inputs:
                    if path.creator:
                        if id(path.creator) not in visited_set:
                            queue.append(path.creator)
            if bottom_up:
                for job in reversed(visited_list):
                    f(job)
            return visited_set

        # job id -> async result of the call that handles this job
        visited = {}
        finished = 0

        pool_lock = threading.Lock()
        finished_lock = threading.Lock()
        pool = self.pool

        # recursive function to run through tree
        def runner(job):
            """
            :param Job job:
            """
            sis_id = job._sis_id()
            with pool_lock:
                # Schedule every job only once
                if sis_id not in visited:
                    visited[sis_id] = pool.apply_async(
                        tools.default_handle_exception_interrupt_main_thread(runner_helper), (job,)
                    )

        def runner_helper(job):
            """
            :param Job job:
            """
            nonlocal finished
            # make sure all inputs are updated
            job._sis_runnable()

            if bottom_up:
                for path in job._sis_inputs:
                    if path.creator:
                        runner(path.creator)
                f(job)
            else:
                res = f(job)
                # Stop if function has a not None but false return value
                if res is None or res:
                    for path in job._sis_inputs:
                        if path.creator:
                            runner(path.creator)
            with finished_lock:
                finished += 1

        for node in nodes:
            runner(node)

        # Check if all jobs are finished
        while len(visited) != finished:
            time.sleep(0.1)

        # Check again and create output set
        out = set()
        for k, v in visited.items():
            v.get()
            out.add(k)
        return out

    def path_to_all_nodes(self):
        """Yield ``(path, job)`` pairs for the objects having a ``_sis_id`` reachable from the path outputs.

        ``path`` is the list of keys, indices and attribute names leading from the output name to
        the object. Objects found inside a set with more than one element are not yielded from
        there; if they are not reached any other way a warning is logged for them.
        """
        # Holds id(obj) for every object seen and additionally the job id for objects that have one
        visited = {}
        # id(set) -> (set, [paths leading to this set]) for sets that could not be expanded
        check_later = {}

        # recursive function to run through tree
        def runner(obj, path, only_check):
            if id(obj) in visited:
                return
            else:
                visited[id(obj)] = obj

            if not isclass(obj):
                try:
                    sis_id = obj._sis_id()
                    if sis_id in visited:
                        return
                    else:
                        visited[sis_id] = obj
                        if only_check:
                            logging.warning(
                                "Could not export %s since it's only reachable "
                                "via sets. %s" % (obj, only_check)
                            )
                        else:
                            yield path, obj
                except AttributeError:
                    pass

            if isinstance(obj, set):
                if len(obj) == 1:
                    for name, value in enumerate(obj):
                        yield from runner(value, path + [name], only_check=only_check)
                elif only_check:
                    # check all values in the given set
                    for name, value in enumerate(obj):
                        yield from runner(value, path + [name], only_check=only_check)
                else:
                    # we can not handle this case since a set can be sorted different every time
                    # check later if we have any jobs we could not map in the end
                    if id(obj) in check_later:
                        check_later[id(obj)][1].append(path)
                    else:
                        check_later[id(obj)] = (obj, [path])
                    return
            elif isinstance(obj, list):
                for name, value in enumerate(obj):
                    yield from runner(value, path + [name], only_check=only_check)
            elif isinstance(obj, dict):
                for name, value in obj.items():
                    assert is_literal(name), "Can not export %s (type: %s) as directory key" % (name, type(name))
                    yield from runner(value, path + [name], only_check=only_check)
            elif isinstance(obj, AbstractPath):
                yield from runner(obj.creator, path + ["creator"], only_check=only_check)
            else:
                try:
                    for name, value in obj.__dict__.items():
                        if not name.startswith("_sis") and not isinstance(obj, Block):
                            yield from runner(value, path + [name], only_check=only_check)
                except AttributeError:
                    pass

        # check all outputs
        for target in self._targets:
            if isinstance(target, OutputPath):
                path = target._output_path
                obj = target._sis_path
                if obj.creator:
                    yield from runner(obj.creator, [path, "creator"], only_check=False)

        # check if there are any jobs that could not be reached due to sets
        for _, (obj, possible_paths) in check_later.items():
            del visited[id(obj)]  # remove to avoid early aborting
            # Only run for the warnings, nothing is yielded while only_check is set
            for i in runner(obj, [], only_check=possible_paths):
                pass

    def get_job_from_path(self, path):
        """The reverse function for path_to_all_nodes

        :param path: sequence of keys, indices and attribute names, starting with the output name
        :return: the object found at the end of the path, or None if the path can not be followed
        """
        # extract dict from targets
        current = {}
        for t in self._targets:
            if len(t.required_full_list) == 1:
                current[t.name] = t.required_full_list[0]
            else:
                for pos, required_path in enumerate(t.required_full_list):
                    current["%s_%02i" % (t.name, pos)] = required_path

        for step in path:
            if isinstance(current, dict):
                current = current.get(step)
            elif isinstance(current, (list, tuple)):
                if 0 <= step < len(current):
                    current = current[step]
                else:
                    return None
            elif hasattr(current, "__dict__"):
                current = current.__dict__.get(step)
            else:
                return None
        return current

    def set_job_targets(self, engine=None):
        """Add a target to all jobs (if possible) to have a more informative output

        :param engine: not used by this method
        """

        # Reset all caches
        def reset_targets(job):
            try:
                job._sis_needed_for_which_targets = set()
            except AttributeError:
                pass

        self.for_all_nodes(reset_targets)

        for target in self.targets:
            if isinstance(target, OutputPath):
                name = target._output_path
                out = target._sis_path
                if out.creator is not None:
                    logging.info(
                        "Add target %s to jobs (used for more informativ output, "
                        "disable with SHOW_JOB_TARGETS=False)" % name
                    )

                    def add_target_name(job):
                        # Stop descending once a job already lists the maximal number of targets
                        if gs.SHOW_JOB_TARGETS is True or len(job._sis_needed_for_which_targets) < gs.SHOW_JOB_TARGETS:
                            job._sis_needed_for_which_targets.add(name)
                            return True
                        return False

                    self.for_all_nodes(f=add_target_name, nodes=[out.creator])
def is_literal(obj, visited=None):
    """Check if an object is built only from str, bytes, int, float, None, list, tuple, set and dict.

    :param obj: object to check
    :param set visited: ids of the containers currently being checked further up in the
        recursion, used to detect containers that contain themselves
    :return: True if obj is such a literal, False otherwise, also for self-containing containers
    :rtype: bool
    """
    # The most likely checks at the beginning
    if isinstance(obj, (str, bytes, int, float, type(None))):
        return True

    # Avoid being stuck in a loop
    if visited is None:
        visited = set()
    obj_id = id(obj)
    if obj_id in visited:
        return False
    visited.add(obj_id)

    try:
        if isinstance(obj, (list, tuple, set)):
            return all(is_literal(item, visited) for item in obj)
        if isinstance(obj, dict):
            return all(is_literal(k, visited) and is_literal(v, visited) for k, v in obj.items())
        return False
    finally:
        # Only the current recursion path counts as a loop, the same object may
        # legitimately show up again in another branch (e.g. ``[(), ()]``)
        visited.discard(obj_id)


# Graph instance created when this module is imported
graph = SISGraph()