# Author: Jan-Thorsten Peter <peter@cs.rwth-aachen.de>
import os
import subprocess
import time
import logging
import getpass # used to get username
import math
from xml.dom import minidom
import xml.etree.cElementTree
from collections import defaultdict, namedtuple
import sisyphus.global_settings as gs
from sisyphus.engine import EngineBase
from sisyphus.global_settings import STATE_RUNNING, STATE_UNKNOWN, STATE_QUEUE, STATE_QUEUE_ERROR
ENGINE_NAME = 'sge'
TaskInfo = namedtuple('TaskInfo', ["job_id", "task_id", "state"])
def escape_name(name):
return name.replace('/', '.')
[docs]def try_to_multiply(y, x, backup_value=None):
""" Tries to convert y to float multiply it by x and convert it back
to a rounded string.
return backup_value if it fails
return y if backup_value == None """
try:
return str(int(float(y) * x))
except ValueError:
if backup_value is None:
return y
else:
return backup_value
[docs]class SonOfGridEngine(EngineBase):
def __init__(self, default_rqmt, gateway=None, auto_clean_eqw=True, ignore_jobs=[]):
"""
:param dict default_rqmt: dictionary with the default rqmts
:param str gateway: ssh to that node and run all sge commands there
:param bool auto_clean_eqw: if True jobs in eqw will be set back to qw automatically
:param list[str] ignore_jobs: list of job ids that will be ignored during status updates.
Useful if a job is stuck inside of SGE and can not be deleted.
Job should be listed as "job_number.task_id" e.g.: ['123.1', '123.2', '125.1']
"""
self._task_info_cache_last_update = 0
self.gateway = gateway
self.default_rqmt = default_rqmt
self.auto_clean_eqw = auto_clean_eqw
self.ignore_jobs = ignore_jobs
[docs] def system_call(self, command, send_to_stdin=None):
"""
:param list[str] command: qsub command
:param str|None send_to_stdin: shell code, e.g. the command itself to execute
:return: stdout, stderr, retval
:rtype: list[bytes], list[bytes], int
"""
if self.gateway:
system_command = ['ssh', '-x', self.gateway] + [' '.join(['cd', os.getcwd(), '&&'] + command)]
else:
# no gateway given, skip ssh local
system_command = command
logging.debug('shell_cmd: %s' % ' '.join(system_command))
p = subprocess.Popen(system_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if send_to_stdin:
send_to_stdin = send_to_stdin.encode()
out, err = p.communicate(input=send_to_stdin, timeout=30)
def fix_output(o):
"""
split output and drop last empty line
:param bytes o:
:rtype: list[bytes]
"""
o = o.split(b'\n')
if o[-1] != b'':
print(o[-1])
assert(False)
return o[:-1]
out = fix_output(out)
err = fix_output(err)
retval = p.wait(timeout=30)
# Check for ssh error
err_ = []
for raw_line in err:
lstart = 'ControlSocket'
lend = 'already exists, disabling multiplexing'
line = raw_line.decode('utf8').strip()
if line.startswith(lstart) and line.endswith(lend):
# found ssh connection problem
ssh_file = line[len(lstart):len(lend)].strip()
logging.warning('SSH Error %s' % line.strip())
try:
os.unlink(ssh_file)
logging.info('Delete file %s' % ssh_file)
except OSError:
logging.warning('Could not delete %s' % ssh_file)
else:
err_.append(raw_line)
return out, err_, retval
def options(self, rqmt):
out = []
try:
mem = "%iG" % math.ceil(float(rqmt['mem']))
except ValueError:
mem = rqmt['mem']
# mem = try_to_multiply(s['mem'], 1024*1024*1024) # convert to Gigabyte if possible
out.append('-l')
out.append('h_vmem=%s' % mem)
out.append('-l')
if 'rss' in rqmt:
try:
rss = "%iG" % math.ceil(float(rqmt['rss']))
except ValueError:
rss = rqmt['rss']
# rss = try_to_multiply(s['rss'], 1024*1024*1024) # convert to Gigabyte if possible
out.append('h_rss=%s' % rss)
else:
out.append('h_rss=%s' % mem)
out.append('-l')
out.append('gpu=%s' % rqmt.get('gpu', 0))
out.append('-l')
out.append('num_proc=%s' % rqmt.get('cpu', 1))
# Try to convert time to float, calculate minutes from it
# and convert it back to an rounded string
# If it fails use string directly
task_time = try_to_multiply(rqmt['time'], 60 * 60) # convert to seconds if possible
out.append('-l')
out.append('h_rt=%s' % task_time)
qsub_args = rqmt.get('qsub_args', [])
if isinstance(qsub_args, str):
qsub_args = qsub_args.split()
out += qsub_args
return out
[docs] def submit_call(self, call, logpath, rqmt, name, task_name, task_ids):
"""
:param list[str] call:
:param str logpath:
:param dict[str] rqmt:
:param str name:
:param str task_name:
:param list[int] task_ids:
:return: ENGINE_NAME, submitted (list of (list of task ids, job id))
:rtype: str, list[(list[int],int)]
"""
if not task_ids:
# skip empty list
return
submitted = []
start_id, end_id, step_size = (None, None, None)
for task_id in task_ids:
if start_id is None:
start_id = task_id
elif end_id is None:
end_id = task_id
step_size = end_id - start_id
elif task_id == end_id + step_size:
end_id = task_id
else:
# this id doesn't fit pattern, this should only happen if only parts of the jobs are restarted
job_id = self.submit_helper(call, logpath, rqmt, name, task_name, start_id, end_id, step_size)
submitted.append((list(range(start_id, end_id, step_size)), job_id))
start_id, end_id, step_size = (task_id, None, None)
assert start_id is not None
if end_id is None:
end_id = start_id
step_size = 1
job_id = self.submit_helper(call, logpath, rqmt, name, task_name, start_id, end_id, step_size)
submitted.append((list(range(start_id, end_id, step_size)), job_id))
return ENGINE_NAME, submitted
[docs] def submit_helper(self, call, logpath, rqmt, name, task_name, start_id, end_id, step_size):
"""
:param list[str] call:
:param str logpath:
:param dict[str] rqmt:
:param str name:
:param str task_name:
:param int start_id:
:param int end_id:
:param int step_size:
"""
name = escape_name(name)
qsub_call = [
'qsub',
'-cwd',
'-N',
name,
'-j',
'y',
'-o',
logpath,
'-l',
'h_fsize=50G',
'-S',
'/bin/bash',
'-m',
'n']
qsub_call += self.options(rqmt)
qsub_call += ['-t', '%i-%i:%i' % (start_id, end_id, step_size)]
command = ' '.join(call) + '\n'
try:
out, err, retval = self.system_call(qsub_call, command)
except subprocess.TimeoutExpired:
logging.warning('SSH command timeout %s' % str(command))
time.sleep(gs.WAIT_PERIOD_SSH_TIMEOUT)
return self.submit_helper(call, logpath, rqmt, name, task_name, start_id, end_id, step_size)
ref_output = ['Your', 'job-array', '("%s")' % name, 'has', 'been', 'submitted']
ref_output = [i.encode() for i in ref_output]
job_id = None
if len(out) == 1:
sout = out[0].split()
if retval != 0 or len(err) > 0 or len(sout) != 7 or sout[0:2] + sout[3:] != ref_output:
print(retval, len(err), len(sout), sout[0:2], sout[3:], ref_output)
logging.error("Error to submit job")
logging.error("QSUB command: %s" % ' '.join(qsub_call))
for line in out:
logging.error("Output: %s" % line.decode())
for line in err:
logging.error("Error: %s" % line.decode())
# reset cache, after error
self.reset_cache()
else:
sjob_id = sout[2].decode().split('.')
assert(len(sjob_id) == 2)
assert(sjob_id[1] == '%i-%i:%i' % (start_id, end_id, step_size))
job_id = sjob_id[0]
logging.info("Submitted with job_id: %s %s" % (job_id, name))
for task_id in range(start_id, end_id, step_size):
self._task_info_cache[(name, task_id)].append((job_id, 'qw'))
if False: # for debugging
logging.warning("Boost job!")
subprocess.check_call(('qalter', '-p', '300', job_id))
else:
logging.error("Error to submit job, return value: %i" % retval)
logging.error("QSUB command: %s" % ' '.join(qsub_call))
for line in out:
logging.error("Output: %s" % line.decode())
for line in err:
logging.error("Error: %s" % line.decode())
# reset cache, after error
self.reset_cache()
return job_id
def reset_cache(self):
self._task_info_cache_last_update = -10
[docs] def queue_state(self):
""" Return s list with all currently running tasks in this queue """
if time.time() - self._task_info_cache_last_update < 30:
# use cached value
return self._task_info_cache
# get qstat output
system_command = ['qstat', '-xml', '-u', getpass.getuser()]
try:
out, err, retval = self.system_call(system_command)
except subprocess.TimeoutExpired:
logging.warning('SSH command timeout %s' % str(system_command))
time.sleep(gs.WAIT_PERIOD_SSH_TIMEOUT)
return self.queue_state()
xml_data = ''.join(i.decode('utf8') for i in out)
# parse qstat output
try:
etree = xml.etree.cElementTree.fromstring(xml_data)
except xml.etree.cElementTree.ParseError:
logging.warning('qstat -xml parsing error, retrying\n'
'command: %s\n'
'stdout: %s\n'
'stderr: %s\n'
'return value: %s' % (system_command, out, err, retval))
time.sleep(gs.WAIT_PERIOD_QSTAT_PARSING)
return self.queue_state()
task_infos = defaultdict(list)
for job in etree.getiterator('job_list'):
job_info = {}
for attr in job:
text = attr.text
if text is not None:
text = text.strip()
job_info[attr.tag] = text
name = job_info['JB_name'].strip()
state = job_info['state'].strip()
task_ids = job_info.get('tasks', None)
job_number = job_info['JB_job_number'].strip()
def parse_task_ids(string):
""" Return list with all task ids of this task """
if string is None:
# No task id
return [None]
try:
# just one task id
return [int(string)]
except ValueError:
pass
if ',' in string:
# multiple task ids
tasks_list = []
for i in string.split(','):
tasks_list += parse_task_ids(i)
return tasks_list
if ':' in string:
# taks list
start_end, step_size = string.split(':')
start, end = start_end.split('-')
return list(range(int(start), int(end) + 1, int(step_size)))
logging.warning("Can not parse task: %s : %s" % (str(name), str(string)))
return []
for task_id in parse_task_ids(task_ids):
# Check if this task should be ignored, all sisyphus jobs have a task id
if task_id is not None and "%s.%i" % (job_number, task_id) not in self.ignore_jobs:
task_infos[(name, task_id)].append((job_number, state))
self._task_info_cache = task_infos
self._task_info_cache_last_update = time.time()
return task_infos
[docs] def task_state(self, task, task_id):
""" Return task state:
'r' == STATE_RUNNING
'qw' == STATE_QUEUE
not found == STATE_UNKNOWN
everything else == STATE_QUEUE_ERROR
"""
name = task.task_name()
name = escape_name(name)
task_name = (name, task_id)
queue_state = self.queue_state()
qs = queue_state[task_name]
# task name should be uniq
if len(qs) > 1:
logging.warning(
'More then one matching SGE task, use first match < %s > matches: %s' %
(str(task_name), str(qs)))
if qs == []:
return STATE_UNKNOWN
state = qs[0][1]
if state in ['r', 't']:
return STATE_RUNNING
elif state == 'qw':
return STATE_QUEUE
elif state == 'Eqw':
if self.auto_clean_eqw:
logging.info('Clean job in error state: %s, %s, %s' % (name, task_id, qs))
self.system_call(['qmod', '-cj', "%s.%s" % (qs[0][0], task_id)])
return STATE_QUEUE_ERROR
else:
return STATE_QUEUE_ERROR
[docs] def start_engine(self):
""" No starting action required with the current implementation """
pass
[docs] def stop_engine(self):
""" No stopping action required with the current implementation """
pass
[docs] @staticmethod
def get_task_id(task_id):
assert task_id is None, "SGE task should not be started with task id, it's given via $SGE_TASK_ID"
task_id = os.getenv('SGE_TASK_ID')
if task_id in ['undefined', None]:
# SGE without an array job
logging.critical("Job started without task_id, this should not happen! Continue with task_id=1")
return 1
else:
return int(task_id)
def get_default_rqmt(self, task):
return self.default_rqmt
[docs] def init_worker(self, task):
# setup log file by linking to engine logfile
task_id = SonOfGridEngine.get_task_id(None)
logpath = os.path.relpath(task.path(gs.JOB_LOG, task_id))
if os.path.isfile(logpath):
os.unlink(logpath)
engine_logpath = os.getenv('SGE_STDERR_PATH')
try:
if os.path.isfile(engine_logpath):
os.link(engine_logpath, logpath)
else:
logging.warning("Could not find engine logfile: %s Create soft link anyway." % engine_logpath)
os.symlink(os.path.relpath(engine_logpath, os.path.dirname(logpath)), logpath)
except FileExistsError:
pass
[docs] def get_logpath(self, logpath_base, task_name, task_id):
""" Returns log file for the currently running task """
return os.getenv('SGE_STDERR_PATH')