# Author: Wilfried Michel <michel@cs.rwth-aachen.de>
from typing import Any
import os
import subprocess
import time
import logging
from collections import defaultdict, namedtuple
import sisyphus.global_settings as gs
from sisyphus.engine import EngineBase
from sisyphus.global_settings import STATE_RUNNING, STATE_UNKNOWN, STATE_QUEUE, STATE_QUEUE_ERROR
ENGINE_NAME = "lsf"
TaskInfo = namedtuple("TaskInfo", ["job_id", "task_id", "state"])
def escape_name(name):
    """Replace "/" with "." so the name can be used as an LSF job name."""
    return name.replace("/", ".")
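
# Illustrative example: escape_name("recog/flow.create") returns
# "recog.flow.create" (slashes in the sisyphus job path would be problematic
# in a bsub job name).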
def try_to_multiply(y, x, backup_value=None):
"""Tries to convert y to float multiply it by x and convert it back
to a rounded string.
return backup_value if it fails
return y if backup_value == None"""
try:
return str(int(float(y) * x))
except ValueError:
if backup_value is None:
return y
else:
return backup_value
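
# Illustrative behaviour of try_to_multiply (example values, not from the engine):
#   try_to_multiply("2", 1024)      -> "2048"      (e.g. gigabytes expressed as megabytes)
#   try_to_multiply("2.5", 60)      -> "150"       (e.g. hours expressed as minutes)
#   try_to_multiply("02:00:00", 60) -> "02:00:00"  (not a float, returned unchanged)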
class LoadSharingFacilityEngine(EngineBase):
    def __init__(self, default_rqmt, gateway=None, auto_clean_eqw=True):
        # maps (job name, task id) -> [(job id, state), ...]; initialized here so
        # submit_helper can append to it before the first queue_state() call
        self._task_info_cache = defaultdict(list)
        self._task_info_cache_last_update = 0
        self.gateway = gateway
        self.default_rqmt = default_rqmt
        self.auto_clean_eqw = auto_clean_eqw
def _system_call_timeout_warn_msg(self, command: Any) -> str:
if self.gateway:
return f"SSH command timeout: {command!s}"
return f"Command timeout: {command!s}"
def system_call(self, command, send_to_stdin=None):
if self.gateway:
system_command = ["ssh", "-x", self.gateway] + [" ".join(["cd", os.getcwd(), "&&"] + command)]
else:
# no gateway given, skip ssh local
system_command = command
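        # Illustrative example (gateway host assumed): with gateway="cluster-login"
        # and command=["bjobs", "-w"], system_command becomes
        # ["ssh", "-x", "cluster-login", "cd <cwd> && bjobs -w"]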
logging.debug("shell_cmd: %s" % " ".join(system_command))
p = subprocess.Popen(system_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if send_to_stdin:
send_to_stdin = send_to_stdin.encode()
out, err = p.communicate(input=send_to_stdin, timeout=30)
        def fix_output(o):
            # split the output and drop the trailing empty line caused by the final newline
            o = o.split(b"\n")
            assert o[-1] == b"", "unexpected subprocess output, missing trailing newline: %r" % o[-1]
            return o[:-1]
out = fix_output(out)
err = fix_output(err)
retval = p.wait(timeout=30)
# Check for ssh error
err_ = []
for raw_line in err:
lstart = "ControlSocket"
lend = "already exists, disabling multiplexing"
line = raw_line.decode("utf8").strip()
if line.startswith(lstart) and line.endswith(lend):
# found ssh connection problem
                    # extract the control socket path between the two marker strings
                    ssh_file = line[len(lstart) : -len(lend)].strip()
                    logging.warning("SSH Error %s" % line)
try:
os.unlink(ssh_file)
logging.info("Delete file %s" % ssh_file)
except OSError:
logging.warning("Could not delete %s" % ssh_file)
else:
err_.append(raw_line)
return (out, err_, retval)
def options(self, rqmt):
out = []
        mem = try_to_multiply(rqmt["mem"], 1024)  # convert to megabytes if possible
        out.append("-M %s" % mem)
        if "rss" in rqmt:
            rss = try_to_multiply(rqmt["rss"], 1024)  # convert to megabytes if possible
            out.append("-v %s" % rss)
        else:
            out.append("-v %s" % mem)
if rqmt.get("gpu", 0) > 0:
out.append("-a gpu")
out.append("-n %s" % rqmt.get("cpu", 1))
        # Try to convert the time to a float, compute minutes from it,
        # and convert it back to a rounded string.
        # If the conversion fails, the given string is used directly.
        task_time = try_to_multiply(rqmt["time"], 60)  # convert to minutes if possible
out.append("-W %s" % task_time)
if rqmt.get("multi_node_slots", None):
raise NotImplementedError("Multi-node slots are not implemented for LSF")
bsub_args = rqmt.get("bsub_args", [])
if isinstance(bsub_args, str):
bsub_args = bsub_args.split()
out += bsub_args
return out
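
    # Illustrative example (rqmt values assumed): for
    # rqmt = {"mem": 4, "time": 2, "cpu": 2, "gpu": 1}
    # options() returns ["-M 4096", "-v 4096", "-a gpu", "-n 2", "-W 120"].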
def submit_call(self, call, logpath, rqmt, name, task_name, task_ids):
if not task_ids:
# skip empty list
return
submitted = []
start_id, end_id, entrycounter, submitstring, submitlist = (None, None, 0, "", [])
for task_id in task_ids:
if start_id is None:
start_id = task_id
            elif (end_id is None and task_id == start_id + 1) or (end_id is not None and task_id == end_id + 1):
end_id = task_id
else:
                # this id does not continue the current range; this should only
                # happen when only some tasks of a job are restarted
if end_id is None:
submitstring += "%i," % (start_id)
submitlist += [start_id]
start_id = task_id
else:
submitstring += "%i-%i," % (start_id, end_id)
submitlist += list(range(start_id, end_id + 1))
start_id, end_id = (task_id, None)
            entrycounter += 1
            # The submit string must not exceed 255 characters; with task ids of at
            # most 4 digits, 20 entries stay well below that limit. Flush only once
            # at least one range has been closed, so the range string is never empty.
            if entrycounter >= 20 and submitstring:
                job_id = self.submit_helper(call, logpath, rqmt, name, task_name, submitstring[:-1])
                submitted.append((submitlist, job_id))
                entrycounter, submitstring, submitlist = (0, "", [])
assert start_id is not None
if end_id is None:
end_id = start_id
submitstring += "%i-%i," % (start_id, end_id)
submitlist += list(range(start_id, end_id + 1))
job_id = self.submit_helper(call, logpath, rqmt, name, task_name, submitstring[:-1])
submitted.append((submitlist, job_id))
return (ENGINE_NAME, submitted)
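
    # Illustrative example: task_ids [1, 2, 3, 5, 6, 9] are compressed into the
    # range string "1-3,5-6,9-9", which submit_helper passes to bsub as a job array.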
def submit_helper(self, call, logpath, rqmt, name, task_name, rangestring):
name = escape_name(name)
bsub_call = [
"bsub",
"-J",
"%s[%s]" % (name, rangestring),
"-o",
"%s/%s.%%J.%%I" % (logpath, name.split(".")[-1]),
]
bsub_call += self.options(rqmt)
        # TODO: these commands are very specific to the RWTH cluster and should be made configurable
command = (
". /usr/local_host/etc/bashrc; module unload intel; module load gcc/7; "
"module load python/3.6.0; module load cuda/80; module load intelmkl/2018; "
+ " ".join(call + ["--redirect_output"])
+ "\n"
)
while True:
try:
logging.info("bsub_call: %s" % bsub_call)
logging.info("command: %s" % command)
out, err, retval = self.system_call(bsub_call, command)
except subprocess.TimeoutExpired:
logging.warning(self._system_call_timeout_warn_msg(command))
time.sleep(gs.WAIT_PERIOD_SSH_TIMEOUT)
continue
break
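
        # On success, bsub prints a single line of the form
        # "Job <1234567> is submitted to queue <normal>." and the job id is taken
        # from the second token below; anything else is treated as a failure.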
ref_output = ["Job", "is", "submitted", "to", "queue"]
ref_output = [i.encode() for i in ref_output]
job_id = None
if len(out) == 1:
sout = out[0].split()
            if retval != 0 or len(err) > 0 or len(sout) != 7 or sout[0:1] + sout[2:6] != ref_output:
                logging.error(
                    "Failed to submit job: retval=%s stderr lines=%i bsub output=%s expected=%s"
                    % (retval, len(err), sout, ref_output)
                )
logging.error("BSUB command: %s" % " ".join(bsub_call))
for line in out:
logging.error("Output: %s" % line.decode())
for line in err:
logging.error("Error: %s" % line.decode())
# reset cache, after error
self.reset_cache()
else:
job_id = sout[1].decode()[1:-1]
logging.info("Submitted with job_id: %s %s" % (job_id, name))
for entry in rangestring.split(","):
if "-" in entry:
start_id, end_id = entry.split("-")
for task_id in range(int(start_id), int(end_id) + 1):
self._task_info_cache[(name, task_id)].append((job_id, "PEND"))
else:
self._task_info_cache[(name, int(entry))].append((job_id, "PEND"))
else:
logging.error("Error to submit job, return value: %i" % retval)
logging.error("BSUB command: %s" % " ".join(bsub_call))
for line in out:
logging.error("Output: %s" % line.decode())
for line in err:
logging.error("Error: %s" % line.decode())
# reset cache, after error
self.reset_cache()
return job_id
def reset_cache(self):
self._task_info_cache_last_update = -10
def queue_state(self):
"""Returns list with all currently running tasks in this queue"""
if time.time() - self._task_info_cache_last_update < 30:
# use cached value
return self._task_info_cache
# get bjobs output
system_command = ["bjobs", "-w"]
while True:
try:
out, err, retval = self.system_call(system_command)
except subprocess.TimeoutExpired:
logging.warning(self._system_call_timeout_warn_msg(system_command))
time.sleep(gs.WAIT_PERIOD_SSH_TIMEOUT)
continue
break
task_infos = defaultdict(list)
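        # The parser below assumes the "bjobs -w" columns
        # JOBID USER STAT QUEUE FROM_HOST EXEC_HOST JOB_NAME SUBMIT_TIME,
        # where the JOB_NAME of an array task has the form "name[task_id]"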
for line in out[1:]:
line = line.decode()
try:
field = line.split()
name = "[".join(field[6].split("[")[:-1])
state = field[2]
task = int(field[6].split("[")[-1].split("]")[0])
number = field[0]
task_infos[(name, task)].append((number, state))
except Exception:
logging.warning("Failed to parse bjobs -w output: %s" % line)
self._task_info_cache = task_infos
self._task_info_cache_last_update = time.time()
return task_infos
def task_state(self, task, task_id):
"""Return task state:
'RUN', 'PROV' == STATE_RUNNING
'PEND', 'WAIT' == STATE_QUEUE
not found == STATE_UNKNOWN
everything else == STATE_QUEUE_ERROR
"""
        name = task.task_name()
        # queue_state() stores decoded string names, so the lookup key must be a str, not bytes
        name = escape_name(name)
        task_name = (name, task_id)
queue_state = self.queue_state()
qs = queue_state[task_name]
        # task names should be unique
        if len(qs) > 1:
            logging.warning(
                "More than one matching LSF task, using first match < %s > matches: %s" % (str(task_name), str(qs))
            )
        if not qs:
return STATE_UNKNOWN
state = qs[0][1]
if state in ["RUN", "PROV"]:
return STATE_RUNNING
elif state in ["PEND", "WAIT"]:
return STATE_QUEUE
else:
return STATE_QUEUE_ERROR
def start_engine(self):
"""No starting action required with the current implementation"""
pass
def stop_engine(self):
"""No stopping action required with the current implementation"""
pass
def get_task_id(self, task_id):
        assert task_id is None, "LSF tasks should not be started with a task id; it is given via $LSB_JOBINDEX"
        task_id = os.getenv("LSB_JOBINDEX")
        if task_id in ["undefined", None]:
            # the job was started outside of an LSF array job
            logging.critical("Job started without task_id, this should not happen! Continuing with task_id=1")
return 1
else:
return int(task_id)
def get_default_rqmt(self, task):
return self.default_rqmt
@staticmethod
def get_logpath(logpath_base, task_name, task_id, engine_selector=None):
"""Returns log file for the currently running task"""
return os.path.join(logpath_base, "%s.%s.%i" % (task_name, os.getenv("LSB_JOBID"), task_id))
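
# A minimal usage sketch (all values assumed, not part of this module): in a
# sisyphus settings.py the engine could be selected roughly like this, where the
# gateway host and the default rqmt values are placeholders for a real setup.
#
#     from sisyphus.load_sharing_facility_engine import LoadSharingFacilityEngine
#
#     def engine():
#         return LoadSharingFacilityEngine(
#             default_rqmt={"cpu": 1, "mem": 1, "time": 1},  # hypothetical defaults
#             gateway="cluster-login",  # hypothetical SSH gateway host
#         )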