Source code for ewoksmx.shell_utils.execute_slurm
import os
import pathlib
import re
import subprocess
import threading
import time
from typing import List
from typing import Optional
from typing import Union
from .execute_results import BashExecutionResult
[docs]
def execute_bash_commands(
    shell_commands: List[str],
    working_directory: Union[str, pathlib.Path],
    stdout: bool = True,
    stderr: bool = True,
    stdmerge: bool = True,
    name: str = "run",
    parameters: Optional[dict] = None,
) -> BashExecutionResult:
    """Execute bash shell commands on Slurm and capture output in files.

    A batch script ``<name>.sh`` is written to ``working_directory``
    (created if missing), submitted with ``sbatch`` and polled until the
    job has left the queue. The call blocks for the whole job duration.

    :param shell_commands: commands written to the script, one per line.
        The script runs with ``set -e`` so it stops at the first failure.
    :param working_directory: directory holding the script; ``sbatch`` is
        started from there.
    :param stdout: forwarded to ``BashExecutionResult``.
    :param stderr: forwarded to ``BashExecutionResult``.
    :param stdmerge: forwarded to ``BashExecutionResult``.
    :param name: base name of the script and default Slurm job name.
    :param parameters: extra ``sbatch`` options, written as
        ``#SBATCH --<key>=<value>`` lines. The mapping is copied, never
        modified. ``output``/``error`` entries are overwritten when the
        result object defines the corresponding log files.
    :returns: the result object with ``slurm_id`` and ``return_code`` set.
    :raises RuntimeError: when the job could not be submitted.
    """
    working_directory = pathlib.Path(working_directory)
    working_directory.mkdir(parents=True, exist_ok=True)
    script_path = working_directory / f"{name}.sh"
    result = BashExecutionResult(
        script_path, stdout=stdout, stderr=stderr, stdmerge=stdmerge
    )

    # Work on a copy: the caller's dictionary must not pick up our
    # "job-name"/"output"/"error" entries as a side effect.
    sbatch_options = dict(parameters) if parameters else {}
    sbatch_options.setdefault("job-name", name)
    if result.log_path:
        # Merged log: Slurm sends stderr to the "output" file by default.
        sbatch_options["output"] = result.log_path
    else:
        # Separate files: both options may be needed at the same time.
        # NOTE(review): assumes BashExecutionResult exposes both paths when
        # stdout and stderr are requested unmerged - confirm.
        if result.stdout_path:
            sbatch_options["output"] = result.stdout_path
        if result.stderr_path:
            sbatch_options["error"] = result.stderr_path

    with open(script_path, "w") as script_file:
        # Login shell, as in the original script header.
        script_file.write("#!/bin/bash -l\n")
        script_file.writelines(
            f"#SBATCH --{key}={value}\n" for key, value in sbatch_options.items()
        )
        script_file.write("\n")
        script_file.write("set -e\n\n")
        script_file.writelines(f"{cmd}\n" for cmd in shell_commands)
    os.chmod(script_path, 0o755)

    job_id = _submit_job(script_path)
    result.slurm_id = job_id
    result.return_code = _wait_job_finished(job_id)
    return result
def _submit_job(script_path: pathlib.Path, timeout_sec=120) -> int:
    """Submit a batch script with ``sbatch`` and return the Slurm job ID.

    ``sbatch`` is started from the directory containing the script.

    :param script_path: path of the batch script to submit.
    :param timeout_sec: maximum number of seconds to wait for ``sbatch``
        to answer; the process is killed when the limit is exceeded.
    :returns: the job ID parsed from the ``Submitted batch job <id>`` line.
    :raises RuntimeError: when ``sbatch`` times out or prints no job ID.
    """
    script_path = pathlib.Path(script_path)
    try:
        # subprocess.run kills and reaps the child itself on timeout, so no
        # helper timer thread is needed.
        completed = subprocess.run(
            ["sbatch", str(script_path)],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=script_path.parent,
            timeout=timeout_sec,
        )
    except subprocess.TimeoutExpired as exc:
        # Keep RuntimeError as the failure type callers already see.
        raise RuntimeError(
            "Failed to retrieve job ID from sbatch output: "
            f"no answer within {timeout_sec} seconds"
        ) from exc

    # errors="replace": an odd byte in the output must not hide the job ID.
    stdout = completed.stdout.decode("utf-8", errors="replace")
    match = re.search(r"Submitted batch job (\d+)", stdout)
    if match:
        return int(match.group(1))

    # Report what sbatch said, otherwise the failure cannot be diagnosed.
    stderr = completed.stderr.decode("utf-8", errors="replace")
    details = stderr.strip() or stdout.strip() or "no output"
    raise RuntimeError(
        "Failed to retrieve job ID from sbatch output "
        f"(exit code {completed.returncode}): {details}"
    )
[docs]
def get_job_state(job_id: int) -> Optional[str]:
    """Return the ``JobState`` value that ``scontrol`` prints for a job.

    :param job_id: Slurm job ID to look up.
    :returns: the text following ``JobState=`` in the output of
        ``scontrol show job <job_id>``, or ``None`` when the output
        contains no such field.
    """
    command = ["scontrol", "show", "job", str(job_id)]
    completed = subprocess.run(command, capture_output=True, text=True)
    found = re.search(r"JobState=(\S+)", completed.stdout)
    if found is None:
        return None
    return found.group(1)
def _wait_job_finished(job_id: int, poll_interval: float = 1.0) -> int:
    """Block until a Slurm job has finished and report how it ended.

    The job state is polled with ``get_job_state``.

    :param job_id: Slurm job ID to wait for.
    :param poll_interval: seconds to sleep between two polls.
    :returns: ``0`` when the job state is ``COMPLETED`` or ``COMPLETING``,
        ``1`` for any other final state.
    """
    # States in which the job is still queued, running or will run again.
    # Treating them as final would report a failure for a live job.
    unfinished_states = frozenset(
        (
            "PENDING",
            "CONFIGURING",
            "RUNNING",
            "RESIZING",
            "SIGNALING",
            "STAGE_OUT",
            "STOPPED",
            "SUSPENDED",
            "REQUEUED",
            "REQUEUE_FED",
            "REQUEUE_HOLD",
            "RESV_DEL_HOLD",
        )
    )
    # NOTE(review): COMPLETING is kept as a success state as before; a job
    # may still end in another state afterwards - confirm this is intended.
    success_states = frozenset(("COMPLETED", "COMPLETING"))

    while True:
        job_state = get_job_state(job_id)
        # None means the state could not be read this time: poll again.
        if job_state is None or job_state in unfinished_states:
            time.sleep(poll_interval)
            continue
        return 0 if job_state in success_states else 1