Source code for sisyphus.son_of_grid_engine

# Author: Jan-Thorsten Peter <peter@cs.rwth-aachen.de>

from typing import Any
import os
import subprocess

import time
import logging

import getpass  # used to get username
import math

import xml.etree.ElementTree
from collections import defaultdict, namedtuple

import sisyphus.global_settings as gs
from sisyphus.engine import EngineBase
from sisyphus.global_settings import STATE_RUNNING, STATE_UNKNOWN, STATE_QUEUE, STATE_QUEUE_ERROR

ENGINE_NAME = "sge"
TaskInfo = namedtuple("TaskInfo", ["job_id", "task_id", "state"])


def escape_name(name):
    """
    :param str name:
    :rtype: str
    """
    return name.replace("/", ".")

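# Illustrative example (not part of the original module): SGE job names may not
# contain slashes, so hierarchical sisyphus job names are flattened:
#
#   >>> escape_name("recognition/search_job/run")
#   'recognition.search_job.run'
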
def try_to_multiply(y, x, backup_value=None):
    """
    Tries to convert y to float, multiplies it by x, and converts it back to a rounded string.
    Returns backup_value if that fails, or y itself if backup_value is None.

    :param str y:
    :param int|float x:
    :param str|None backup_value:
    :rtype: str
    """
    try:
        return str(int(float(y) * x))
    except ValueError:
        if backup_value is None:
            return y
        else:
            return backup_value

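# Illustrative examples (not part of the original module) showing the three
# possible outcomes of try_to_multiply:
#
#   >>> try_to_multiply("2", 60 * 60)            # numeric string: converted and scaled
#   '7200'
#   >>> try_to_multiply("2:00:00", 60 * 60)      # non-numeric, no backup: returned unchanged
#   '2:00:00'
#   >>> try_to_multiply("abc", 60 * 60, "3600")  # non-numeric with backup value
#   '3600'
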
class SonOfGridEngine(EngineBase):
    def __init__(self, default_rqmt, gateway=None, auto_clean_eqw=True, ignore_jobs=None, pe_name="mpi"):
        """
        :param dict default_rqmt: dictionary with the default rqmts
        :param str gateway: ssh to that node and run all sge commands there
        :param bool auto_clean_eqw: if True, jobs in Eqw state are set back to qw automatically
        :param list[str] ignore_jobs: list of job ids that will be ignored during status updates.
            Useful if a job is stuck inside of SGE and can not be deleted.
            Jobs should be listed as "job_number.task_id", e.g.: ['123.1', '123.2', '125.1']
        :param str pe_name: used to select the parallel environment (PE) when multi_node_slots is set
            in rqmt, as `-pe <pe_name> <multi_node_slots>`.
            The default "mpi" is somewhat arbitrarily chosen as we have it in our environment.
        """
        # cache of (job name, task id) -> [(job id, state)], refreshed by queue_state()
        self._task_info_cache = defaultdict(list)
        self._task_info_cache_last_update = 0
        self.gateway = gateway
        self.default_rqmt = default_rqmt
        self.auto_clean_eqw = auto_clean_eqw
        if ignore_jobs is None:
            ignore_jobs = []
        self.ignore_jobs = ignore_jobs
        self.pe_name = pe_name

    def _system_call_timeout_warn_msg(self, command: Any) -> str:
        if self.gateway:
            return f"SSH command timeout: {command!s}"
        return f"Command timeout: {command!s}"

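    # Illustrative setup sketch (all values hypothetical, not from the original
    # module): a typical engine instance as it might be created in a sisyphus
    # settings file:
    #
    #   >>> engine = SonOfGridEngine(
    #   ...     default_rqmt={"cpu": 1, "mem": 1, "time": 1},  # 1 core, 1 GB, 1 hour
    #   ...     gateway="cluster-head-node",  # run all SGE commands there via ssh
    #   ...     ignore_jobs=["123.1"],        # skip a task stuck inside SGE
    #   ... )
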
    def system_call(self, command, send_to_stdin=None):
        """
        :param list[str] command: qsub command
        :param str|None send_to_stdin: shell code, e.g. the command itself to execute
        :return: stdout, stderr, retval
        :rtype: list[bytes], list[bytes], int
        """
        if self.gateway:
            system_command = ["ssh", "-x", self.gateway] + [" ".join(["cd", os.getcwd(), "&&"] + command)]
        else:
            # no gateway given, skip ssh, run locally
            system_command = command

        logging.debug("shell_cmd: %s" % " ".join(system_command))
        p = subprocess.Popen(system_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if send_to_stdin:
            send_to_stdin = send_to_stdin.encode()
        out, err = p.communicate(input=send_to_stdin, timeout=30)

        def fix_output(o):
            """
            Split the output and drop the trailing empty line

            :param bytes o:
            :rtype: list[bytes]
            """
            o = o.split(b"\n")
            assert o[-1] == b"", "output does not end with a newline: %r" % o[-1]
            return o[:-1]

        out = fix_output(out)
        err = fix_output(err)
        retval = p.wait(timeout=30)

        # Check for ssh error
        err_ = []
        for raw_line in err:
            lstart = "ControlSocket"
            lend = "already exists, disabling multiplexing"
            line = raw_line.decode("utf8").strip()
            if line.startswith(lstart) and line.endswith(lend):
                # found ssh connection problem
                ssh_file = line[len(lstart) : -len(lend)].strip()
                logging.warning("SSH Error %s" % line.strip())
                try:
                    os.unlink(ssh_file)
                    logging.info("Deleted file %s" % ssh_file)
                except OSError:
                    logging.warning("Could not delete %s" % ssh_file)
            else:
                err_.append(raw_line)
        return out, err_, retval

    def options(self, rqmt):
        out = []
        try:
            mem = "%iG" % math.ceil(float(rqmt["mem"]))
        except ValueError:
            mem = rqmt["mem"]
        out.append("-l")
        out.append("h_vmem=%s" % mem)

        out.append("-l")
        if "rss" in rqmt:
            try:
                rss = "%iG" % math.ceil(float(rqmt["rss"]))
            except ValueError:
                rss = rqmt["rss"]
            out.append("h_rss=%s" % rss)
        else:
            out.append("h_rss=%s" % mem)

        try:
            file_size = "%iG" % math.ceil(float(rqmt["file_size"]))
        except (ValueError, KeyError):
            # If a different default value is wanted it can be overwritten by adding
            # 'file_size' to the default_rqmt of this engine.
            file_size = rqmt.get("file_size", "50G")
        out.append("-l")
        out.append("h_fsize=%s" % file_size)

        out.append("-l")
        out.append("gpu=%s" % rqmt.get("gpu", 0))
        out.append("-l")
        out.append("num_proc=%s" % rqmt.get("cpu", 1))

        # Try to convert time to float, calculate seconds from it,
        # and convert it back to a rounded string.
        # If that fails, use the string directly.
        task_time = try_to_multiply(rqmt["time"], 60 * 60)  # convert hours to seconds if possible
        out.append("-l")
        out.append("h_rt=%s" % task_time)

        if rqmt.get("multi_node_slots", None):
            out.extend(["-pe", self.pe_name, str(rqmt["multi_node_slots"])])

        qsub_args = rqmt.get("qsub_args", [])
        if isinstance(qsub_args, str):
            qsub_args = qsub_args.split()
        out += qsub_args
        return out

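    # Illustrative example (hypothetical rqmt, not from the original module):
    # the mapping from a sisyphus rqmt dict to qsub resource arguments:
    #
    #   >>> engine.options({"mem": 4, "time": 2, "cpu": 8, "gpu": 1})
    #   ['-l', 'h_vmem=4G', '-l', 'h_rss=4G', '-l', 'h_fsize=50G',
    #    '-l', 'gpu=1', '-l', 'num_proc=8', '-l', 'h_rt=7200']
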
    def submit_call(self, call, logpath, rqmt, name, task_name, task_ids):
        """
        :param list[str] call:
        :param str logpath:
        :param dict[str] rqmt:
        :param str name:
        :param str task_name:
        :param list[int] task_ids:
        :return: ENGINE_NAME, submitted (list of (list of task ids, job id))
        :rtype: (str, list[(list[int],str)])
        """
        if not task_ids:
            # skip empty list
            return

        submitted = []
        start_id, end_id, step_size = (None, None, None)
        for task_id in task_ids:
            if start_id is None:
                start_id = task_id
            elif end_id is None:
                end_id = task_id
                step_size = end_id - start_id
            elif task_id == end_id + step_size:
                end_id = task_id
            else:
                # this id doesn't fit the pattern; this should only happen
                # if only parts of the job are restarted
                job_id = self.submit_helper(call, logpath, rqmt, name, task_name, start_id, end_id, step_size)
                submitted.append((list(range(start_id, end_id + 1, step_size)), job_id))
                start_id, end_id, step_size = (task_id, None, None)

        assert start_id is not None
        if end_id is None:
            end_id = start_id
            step_size = 1
        job_id = self.submit_helper(call, logpath, rqmt, name, task_name, start_id, end_id, step_size)
        submitted.append((list(range(start_id, end_id + 1, step_size)), job_id))
        return ENGINE_NAME, submitted

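    # Illustrative example (hypothetical ids, not from the original module):
    # submit_call greedily groups task ids into arithmetic ranges and submits
    # one job array per range, e.g. for task_ids=[1, 2, 3, 4, 10] it issues
    # two qsub calls:
    #
    #   -t 1-4:1    covering tasks [1, 2, 3, 4]
    #   -t 10-10:1  covering task [10]
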
    def submit_helper(self, call, logpath, rqmt, name, task_name, start_id, end_id, step_size):
        """
        :param list[str] call:
        :param str logpath:
        :param dict[str] rqmt:
        :param str name:
        :param str task_name:
        :param int start_id:
        :param int end_id:
        :param int step_size:
        :rtype: str|None
        """
        name = escape_name(name)
        qsub_call = ["qsub", "-cwd", "-N", name, "-j", "y", "-o", logpath, "-S", "/bin/bash", "-m", "n"]
        qsub_call += self.options(rqmt)
        qsub_call += ["-t", "%i-%i:%i" % (start_id, end_id, step_size)]
        command = " ".join(call) + "\n"
        while True:
            try:
                out, err, retval = self.system_call(qsub_call, command)
            except subprocess.TimeoutExpired:
                logging.warning(self._system_call_timeout_warn_msg(command))
                time.sleep(gs.WAIT_PERIOD_SSH_TIMEOUT)
                continue
            break

        ref_output = ["Your", "job-array", '("%s")' % name, "has", "been", "submitted"]
        ref_output = [i.encode() for i in ref_output]

        job_id = None
        if len(out) == 1:
            sout = out[0].split()
            if len(sout) == 7 and sout[3].startswith(b'("') and sout[3].endswith(b'")'):
                if sout[3][2:-2] != name.encode() and name.encode().startswith(sout[3][2:-2]):
                    # SGE can cut off the job name. Fix that.
                    ref_output[2] = sout[3]
            if retval != 0 or len(err) > 0 or len(sout) != 7 or sout[0:2] + sout[3:] != ref_output:
                logging.error("Failed to submit job")
                logging.error("QSUB command: %s" % " ".join(qsub_call))
                for line in out:
                    logging.error("Output: %s" % line.decode())
                for line in err:
                    logging.error("Error: %s" % line.decode())
                # reset cache after error
                self.reset_cache()
            else:
                sjob_id = sout[2].decode().split(".")
                assert len(sjob_id) == 2
                assert sjob_id[1] == "%i-%i:%i" % (start_id, end_id, step_size)
                job_id = sjob_id[0]
                logging.info("Submitted with job_id: %s %s" % (job_id, name))
                for task_id in range(start_id, end_id + 1, step_size):
                    self._task_info_cache[(name, task_id)].append((job_id, "qw"))
                if False:  # for debugging
                    logging.warning("Boost job!")
                    subprocess.check_call(("qalter", "-p", "300", job_id))
        else:
            logging.error("Failed to submit job, return value: %i" % retval)
            logging.error("QSUB command: %s" % " ".join(qsub_call))
            for line in out:
                logging.error("Output: %s" % line.decode())
            for line in err:
                logging.error("Error: %s" % line.decode())
            # reset cache after error
            self.reset_cache()
        return job_id

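    # Illustrative example (hypothetical arguments, not from the original
    # module): for name="tools.train", logpath="log", start_id=1, end_id=4,
    # step_size=1 and the rqmt from the options() example above, submit_helper
    # runs roughly:
    #
    #   qsub -cwd -N tools.train -j y -o log -S /bin/bash -m n \
    #        -l h_vmem=4G -l h_rss=4G -l h_fsize=50G -l gpu=1 -l num_proc=8 \
    #        -l h_rt=7200 -t 1-4:1
    #
    # and expects stdout of the form:
    #
    #   Your job-array 12345.1-4:1 ("tools.train") has been submitted
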
    def reset_cache(self):
        self._task_info_cache_last_update = -10

    def queue_state(self):
        """Returns a list with all currently running tasks in this queue"""
        if time.time() - self._task_info_cache_last_update < 30:
            # use cached value
            return self._task_info_cache

        # get qstat output
        system_command = ["qstat", "-xml", "-u", getpass.getuser()]
        while True:
            try:
                out, err, retval = self.system_call(system_command)
            except subprocess.TimeoutExpired:
                logging.warning(self._system_call_timeout_warn_msg(system_command))
                time.sleep(gs.WAIT_PERIOD_SSH_TIMEOUT)
                continue
            break

        xml_data = "".join(i.decode("utf8") for i in out)

        # parse qstat output
        try:
            etree = xml.etree.ElementTree.fromstring(xml_data)
        except xml.etree.ElementTree.ParseError:
            logging.warning(
                "qstat -xml parsing error, retrying\n"
                "command: %s\n"
                "stdout: %s\n"
                "stderr: %s\n"
                "return value: %s" % (system_command, out, err, retval)
            )
            time.sleep(gs.WAIT_PERIOD_QSTAT_PARSING)
            return self.queue_state()

        task_infos = defaultdict(list)
        for job in etree.iter("job_list"):
            job_info = {}
            for attr in job:
                text = attr.text
                if text is not None:
                    text = text.strip()
                job_info[attr.tag] = text
            name = job_info["JB_name"].strip()
            state = job_info["state"].strip()
            task_ids = job_info.get("tasks", None)
            job_number = job_info["JB_job_number"].strip()

            def parse_task_ids(string):
                """
                Return a list with all task ids of this task

                :param str|None string:
                :rtype: list[int|None]
                """
                if string is None:
                    # no task id
                    return [None]
                try:
                    # just one task id
                    return [int(string)]
                except ValueError:
                    pass
                if "," in string:
                    # multiple task ids
                    tasks_list = []
                    for i in string.split(","):
                        tasks_list += parse_task_ids(i)
                    return tasks_list
                if ":" in string:
                    # task range with step size, e.g. "1-10:2"
                    start_end, step_size = string.split(":")
                    start, end = start_end.split("-")
                    return list(range(int(start), int(end) + 1, int(step_size)))
                logging.warning("Can not parse task: %s : %s" % (str(name), str(string)))
                return []

            for task_id in parse_task_ids(task_ids):
                # check if this task should be ignored; all sisyphus jobs have a task id
                if task_id is not None and "%s.%i" % (job_number, task_id) not in self.ignore_jobs:
                    task_infos[(name, task_id)].append((job_number, state))

        self._task_info_cache = task_infos
        self._task_info_cache_last_update = time.time()
        return task_infos

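    # Illustrative examples (not part of the original module) for the nested
    # parse_task_ids helper, covering the qstat <tasks> formats it understands:
    #
    #   >>> parse_task_ids(None)       # job without array tasks
    #   [None]
    #   >>> parse_task_ids("7")        # single task id
    #   [7]
    #   >>> parse_task_ids("1-10:2")   # range with step size
    #   [1, 3, 5, 7, 9]
    #   >>> parse_task_ids("3,5-6:1")  # comma-separated list, parsed recursively
    #   [3, 5, 6]
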
    def task_state(self, task, task_id):
        """Return task state:
        'r' == STATE_RUNNING
        'qw' == STATE_QUEUE
        not found == STATE_UNKNOWN
        everything else == STATE_QUEUE_ERROR
        """
        name = task.task_name()
        name = escape_name(name)

        task_name = (name, task_id)
        queue_state = self.queue_state()
        qs = queue_state[task_name]

        # task name should be unique
        if len(qs) > 1:
            logging.warning(
                "More than one matching SGE task, using first match < %s > matches: %s" % (str(task_name), str(qs))
            )

        if qs == []:
            return STATE_UNKNOWN
        state = qs[0][1]
        if state in ["r", "t", "Rr", "Rt"]:
            return STATE_RUNNING
        elif state == "qw":
            return STATE_QUEUE
        elif state == "Eqw":
            if self.auto_clean_eqw:
                logging.info("Clearing job in error state: %s, %s, %s" % (name, task_id, qs))
                self.system_call(["qmod", "-cj", "%s.%s" % (qs[0][0], task_id)])
            return STATE_QUEUE_ERROR
        else:
            return STATE_QUEUE_ERROR

    def start_engine(self):
        """No starting action required with the current implementation"""
        pass

    def stop_engine(self):
        """No stopping action required with the current implementation"""
        pass

    @staticmethod
    def get_task_id(task_id):
        assert task_id is None, "an SGE task should not be started with a task id; it is given via $SGE_TASK_ID"
        task_id = os.getenv("SGE_TASK_ID")
        if task_id in ["undefined", None]:
            # SGE without an array job
            logging.critical("Job started without task_id, this should not happen! Continuing with task_id=1")
            return 1
        else:
            return int(task_id)

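    # Illustrative example (not part of the original module): inside a job
    # array task, SGE sets $SGE_TASK_ID, so a worker started as task 7 sees:
    #
    #   >>> os.environ["SGE_TASK_ID"]  # set by SGE
    #   '7'
    #   >>> SonOfGridEngine.get_task_id(None)
    #   7
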
    def get_default_rqmt(self, task):
        return self.default_rqmt

    def init_worker(self, task):
        # set up the log file by linking to the engine logfile
        task_id = SonOfGridEngine.get_task_id(None)
        logpath = os.path.relpath(task.path(gs.JOB_LOG, task_id))
        if os.path.isfile(logpath):
            os.unlink(logpath)

        engine_logpath = os.getenv("SGE_STDERR_PATH")
        try:
            if os.path.isfile(engine_logpath):
                os.link(engine_logpath, logpath)
            else:
                logging.warning("Could not find engine logfile: %s. Creating a soft link anyway." % engine_logpath)
                os.symlink(os.path.relpath(engine_logpath, os.path.dirname(logpath)), logpath)
        except FileExistsError:
            pass

    def get_logpath(self, logpath_base, task_name, task_id):
        """Returns the log file for the currently running task"""
        return os.getenv("SGE_STDERR_PATH")