【高性能计算】java连接slurm提交作业,展示作业队列等
看了下slurm官方提供的rest-api,理解起来比较复杂,所以尝试自己封装了一套api,只保留一些简单的参数(够用就行)。思路主要是通过jsch连接远程服务器,在服务器上执行slurm命令,并抓取服务器的返回。这边主要讲作业提交:封装了slurm的api,ResultVo为自定义的返回类型;主要是将用户的输入拼接成slurm的命令,在服务器上执行。下面是大概的接口列表,未来会持续迭代。
·
看了下slurm官方提供的rest-api,看懂比较复杂,所以尝试了下自己封装了个api,保留一些简单的参数(够用就行)

这是大概的接口列表,未来会持续迭代
思路的话主要是通过jsch连接远程服务器,通过服务器去执行slurm命令,并抓取服务器的返回
首先封装jsch工具类。考虑到命令都是直接在slurm的工作节点上执行,这里使用单例模式:
package com.easy.slurm.slurm.Jsch;
import com.alibaba.fastjson.JSONObject;
import com.jcraft.jsch.*;
import lombok.extern.slf4j.Slf4j;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Singleton SSH client built on JSch. It keeps one {@link Session} open and runs
 * shell commands on the remote host through short-lived "exec" channels,
 * collecting stdout, stderr and the exit status of each command.
 *
 * <p>Lifecycle: call {@link #init(String, String, String)} once, then
 * {@link #runCommand(String)} as often as needed, and {@link #destroy()} to
 * release the session and the reader threads.
 *
 * @author bing.bai
 * @create 2025/5/30
 */
@Slf4j
public class JschClient {
    private static final JschClient INSTANCE = new JschClient();
    /** Default command timeout, in seconds. */
    private static final int DEFAULT_TIMEOUT = 30;
    /** How long {@link #destroy()} waits for reader threads to finish, in seconds. */
    private static final int THREAD_TIMEOUT = 10;
    /** Maximum time to wait for the channel to close after both streams reached EOF. */
    private static final long EXIT_STATUS_GRACE_MILLIS = 2000;
    /** Polling interval while waiting for the channel to close. */
    private static final long EXIT_STATUS_POLL_MILLIS = 50;

    private final AtomicBoolean inited = new AtomicBoolean(false);
    private Session session;
    private ExecutorService executorService;

    private JschClient() {
    }

    /**
     * @return the process-wide client instance
     */
    public static JschClient getInstance() {
        return INSTANCE;
    }

    /**
     * Opens the SSH session (port 22, password authentication) and creates the
     * thread pool used to read command output. Calling it again while the client
     * is initialized only logs a warning.
     *
     * @param host     remote host name or address
     * @param user     SSH user
     * @param password SSH password
     * @throws JSchException if the session cannot be created or connected
     */
    public synchronized void init(String host, String user, String password) throws JSchException {
        if (inited.get()) {
            log.warn("JschClient already initialized");
            return;
        }
        JSch jsch = new JSch();
        Session newSession = jsch.getSession(user, host, 22);
        newSession.setPassword(password);
        // NOTE(security): host key verification is disabled, which leaves the connection
        // open to man-in-the-middle attacks. Prefer a known_hosts file in production.
        newSession.setConfig("StrictHostKeyChecking", "no");
        newSession.setServerAliveInterval(30 * 1000);
        newSession.connect();
        // Publish the session only once it is connected.
        session = newSession;
        // Every command needs two reader threads (stdout + stderr). A cached pool lets
        // concurrent commands read in parallel instead of queueing behind a fixed pool
        // and burning their timeout while waiting for a free thread.
        executorService = Executors.newCachedThreadPool();
        inited.set(true);
        log.info("JschClient initialized successfully for host: {}", host);
    }

    /**
     * Runs a command with the default timeout.
     *
     * @param command shell command to execute on the remote host
     * @return captured output and exit status
     * @throws JSchException if the exec channel cannot be opened or connected
     * @throws IOException   if the channel streams cannot be obtained
     */
    public RunCommandResult runCommand(String command) throws JSchException, IOException {
        log.info("run command :{}",command);
        return runCommand(command, DEFAULT_TIMEOUT);
    }

    /**
     * Runs a command and waits up to {@code timeoutSeconds} for its output.
     *
     * <p>If the wait times out or is interrupted, the lines read so far are still
     * returned and the exit status is whatever the channel reports at that moment
     * (JSch reports -1 while the status is unknown).
     *
     * @param command        shell command to execute on the remote host
     * @param timeoutSeconds maximum time to wait for both output streams to end
     * @return captured output and exit status
     * @throws IllegalStateException if the client has not been initialized
     * @throws JSchException         if the exec channel cannot be opened or connected
     * @throws IOException           if the channel streams cannot be obtained
     */
    public RunCommandResult runCommand(String command, int timeoutSeconds) throws JSchException, IOException {
        if (!inited.get()) {
            throw new IllegalStateException("JschClient not initialized");
        }
        RunCommandResult result = new RunCommandResult();
        ChannelExec channel = null;
        try {
            channel = (ChannelExec) session.openChannel("exec");
            channel.setCommand(command);
            channel.setInputStream(null);
            // Streams must be obtained before connect().
            InputStream in = channel.getInputStream();
            InputStream err = channel.getErrStream();
            channel.connect();

            // Readers append to thread-safe queues owned by this method, so a partial
            // snapshot can be taken safely even while a reader is still running.
            ConcurrentLinkedQueue<String> stdoutLines = new ConcurrentLinkedQueue<>();
            ConcurrentLinkedQueue<String> stderrLines = new ConcurrentLinkedQueue<>();
            CompletableFuture<Void> stdoutFuture = readStreamAsync(in, "stdout", stdoutLines);
            CompletableFuture<Void> stderrFuture = readStreamAsync(err, "stderr", stderrLines);

            boolean streamsFinished = false;
            try {
                CompletableFuture.allOf(stdoutFuture, stderrFuture).get(timeoutSeconds, TimeUnit.SECONDS);
                streamsFinished = true;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.warn("Command execution interrupted", e);
            } catch (ExecutionException e) {
                log.error("Error reading command output", e.getCause());
            } catch (TimeoutException e) {
                // CompletableFuture.cancel(true) would not interrupt the reader threads and
                // would make the futures throw on access; the readers stop when the channel
                // is disconnected in the finally block below.
                log.warn("Command execution timed out after {} seconds", timeoutSeconds);
            }

            if (streamsFinished) {
                // EOF on the streams can arrive before the exit status has been processed;
                // the status is reliable only once the channel is closed.
                awaitChannelClosed(channel);
            }

            result.setStdout(new ArrayList<>(stdoutLines));
            result.setStderr(new ArrayList<>(stderrLines));
            result.setExitStatus(channel.getExitStatus());
            if (log.isDebugEnabled()) {
                log.debug("command result: {}", JSONObject.toJSONString(result));
            }
            return result;
        } finally {
            if (channel != null) {
                channel.disconnect();
            }
        }
    }

    /**
     * Waits a bounded time for the channel to report that it is closed. Returns early,
     * with the interrupt flag restored, if the calling thread is interrupted.
     */
    private void awaitChannelClosed(ChannelExec channel) {
        long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(EXIT_STATUS_GRACE_MILLIS);
        while (!channel.isClosed() && System.nanoTime() - deadline < 0) {
            try {
                Thread.sleep(EXIT_STATUS_POLL_MILLIS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /**
     * Reads {@code stream} line by line (UTF-8) on the reader pool and appends every
     * line to {@code sink}. I/O errors are logged and end the read.
     */
    private CompletableFuture<Void> readStreamAsync(InputStream stream, String streamType,
                                                    ConcurrentLinkedQueue<String> sink) {
        return CompletableFuture.runAsync(() -> {
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(stream, StandardCharsets.UTF_8))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    sink.add(line);
                    // Set when the pool is shut down with shutdownNow().
                    if (Thread.currentThread().isInterrupted()) {
                        log.debug("{} reading interrupted", streamType);
                        break;
                    }
                }
            } catch (IOException e) {
                log.error("Error reading {} stream", streamType, e);
            }
        }, executorService);
    }

    /**
     * Disconnects the session and shuts the reader pool down. Safe to call when the
     * client was never initialized (a warning is logged).
     */
    public synchronized void destroy() {
        if (!inited.get()) {
            log.warn("JschClient not initialized, nothing to destroy");
            return;
        }
        if (session != null && session.isConnected()) {
            session.disconnect();
        }
        if (executorService != null) {
            try {
                executorService.shutdown();
                if (!executorService.awaitTermination(THREAD_TIMEOUT, TimeUnit.SECONDS)) {
                    executorService.shutdownNow();
                }
            } catch (InterruptedException e) {
                executorService.shutdownNow();
                Thread.currentThread().interrupt();
            }
        }
        inited.set(false);
        log.info("JschClient destroyed successfully");
    }

    /**
     * Output of one remote command: stdout lines, stderr lines and the exit status.
     */
    public static class RunCommandResult {
        private List<String> stdout = new ArrayList<>();
        private List<String> stderr = new ArrayList<>();
        private int exitStatus;

        public List<String> getStdout() {
            return stdout;
        }

        public void setStdout(List<String> stdout) {
            this.stdout = stdout;
        }

        public List<String> getStderr() {
            return stderr;
        }

        public void setStderr(List<String> stderr) {
            this.stderr = stderr;
        }

        public int getExitStatus() {
            return exitStatus;
        }

        public void setExitStatus(int exitStatus) {
            this.exitStatus = exitStatus;
        }

        /**
         * @return {@code true} when the exit status is 0
         */
        public boolean isSuccess() {
            return exitStatus == 0;
        }

        /**
         * @return all stderr lines concatenated without a separator
         */
        public String getErrorMsg() {
            return String.join("", stderr);
        }

        /**
         * @return all stdout lines concatenated without a separator
         */
        public String getSuccessMsg() {
            return String.join("", stdout);
        }
    }
}
jsch的配置类:
package com.easy.slurm.slurm.Jsch;
import com.jcraft.jsch.JSchException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
/**
 * Spring component that ties the {@link JschClient} singleton to the application
 * lifecycle: it connects on startup using the {@code jsch.host}, {@code jsch.user}
 * and {@code jsch.password} properties and releases the client on shutdown.
 *
 * @author bing.bai
 * @create 2025/5/30
 */
@Component
public class JschConfig {
    @Value("${jsch.host}")
    private String host;
    @Value("${jsch.user}")
    private String user;
    @Value("${jsch.password}")
    private String password;

    /**
     * Initializes the shared client once the properties are injected. Nothing is done
     * when {@code jsch.host} is blank, so the client stays uninitialized in that case.
     *
     * @throws JSchException if the SSH session cannot be established
     */
    @PostConstruct
    public void init() throws JSchException {
        if (StringUtils.isNotBlank(host)) {
            JschClient instance = JschClient.getInstance();
            instance.init(host,user,password);
        }
    }

    /**
     * Releases the SSH session and the reader thread pool when the context shuts down,
     * so they are not leaked. JschClient.destroy() only logs a warning when the client
     * was never initialized.
     */
    @PreDestroy
    public void destroy() {
        JschClient.getInstance().destroy();
    }
}
这边主要讲作业提交:封装了slurm的api,ResultVo为自定义的返回类型:

主要是将用户的输入拼接成slurm的命令,在服务器上执行
/**
 * Submits a batch job: builds the sbatch command from the request and runs it
 * through the shared {@link JschClient}.
 *
 * @param sbatchJobSubmit job parameters; the job script path must not be blank
 * @return a success result carrying the command's stdout, or a failure result
 *         carrying its stderr when the command did not succeed
 * @throws BussinessException if the job script path is blank
 * @throws JSchException      if the SSH channel cannot be opened
 * @throws IOException        if the command streams cannot be obtained
 */
@Override
public ResultVo submitSbatchJob(SbatchJobSubmit sbatchJobSubmit) throws JSchException, IOException {
    // The script path is mandatory; reject the request before contacting the cluster.
    final boolean scriptPathMissing = StringUtils.isBlank(sbatchJobSubmit.getJobFileUrl());
    if (scriptPathMissing) {
        throw new BussinessException(ResultCode.FILEURLNOTEXIST);
    }

    final String commandLine = combineSbatchJobCommand(sbatchJobSubmit);
    final JschClient.RunCommandResult outcome = JschClient.getInstance().runCommand(commandLine);

    if (!outcome.isSuccess()) {
        return ResultVo.fail(outcome.getErrorMsg());
    }
    return ResultVo.success(outcome.getSuccessMsg());
}
/**
 * Builds the sbatch command line for a job request.
 *
 * <p>The working directory is always set to {@code WORKDIR/<tenantId>}; every other
 * option is added only when the matching field is set, and the job script path comes
 * last. All user-supplied strings are single-quoted because the command is run by a
 * remote shell. As a consequence, shell expansions such as {@code ~} or {@code $VAR}
 * inside those values are no longer applied.
 *
 * @param sbatchJobSubmit job parameters
 * @return the complete command line
 */
private String combineSbatchJobCommand(SbatchJobSubmit sbatchJobSubmit) {
    StringBuilder command = new StringBuilder(SBATCH);
    // Only the tenant part is quoted; WORKDIR is a project constant and is left as-is.
    command.append(" -D ").append(SlurmConstants.WORKDIR).append("/")
            .append(shellQuote(sbatchJobSubmit.getTenantId()));
    if (StringUtils.isNotBlank(sbatchJobSubmit.getJobName())) {
        command.append(" -J ").append(shellQuote(sbatchJobSubmit.getJobName()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getPartition())) {
        command.append(" -p ").append(shellQuote(sbatchJobSubmit.getPartition()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getTime())) {
        command.append(" -t ").append(shellQuote(sbatchJobSubmit.getTime()));
    }
    if (sbatchJobSubmit.getNodes() != null) {
        command.append(" --nodes=").append(sbatchJobSubmit.getNodes());
    }
    if (sbatchJobSubmit.getNtasks() != null) {
        command.append(" --ntasks=").append(sbatchJobSubmit.getNtasks());
    }
    if (sbatchJobSubmit.getCpusPerTas() != null) {
        command.append(" --cpus-per-task=").append(sbatchJobSubmit.getCpusPerTas());
    }
    // No space after '=': otherwise the option gets an empty value and the real value
    // is parsed as a separate argument.
    if (StringUtils.isNotBlank(sbatchJobSubmit.getMem())) {
        command.append(" --mem=").append(shellQuote(sbatchJobSubmit.getMem()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getMemPerCpu())) {
        command.append(" --mem-per-cpu=").append(shellQuote(sbatchJobSubmit.getMemPerCpu()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getGres())) {
        command.append(" --gres=").append(shellQuote(sbatchJobSubmit.getGres()));
    }
    if (sbatchJobSubmit.getExclusive() != null && sbatchJobSubmit.getExclusive()) {
        // Plain flag, no value.
        command.append(" --exclusive");
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getNodelist())) {
        command.append(" --nodelist=").append(shellQuote(sbatchJobSubmit.getNodelist()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getExclude())) {
        command.append(" --exclude=").append(shellQuote(sbatchJobSubmit.getExclude()));
    }
    command.append(" ").append(shellQuote(sbatchJobSubmit.getJobFileUrl()));
    return command.toString();
}

/**
 * Wraps a value in single quotes for a POSIX shell, escaping embedded single quotes,
 * so the value is passed as one literal argument. A null value is rendered as the
 * text "null", matching plain string concatenation.
 */
private String shellQuote(String value) {
    return "'" + String.valueOf(value).replace("'", "'\\''") + "'";
}
用户输入类:
package com.easy.slurm.slurm.bean;
import lombok.Data;
/**
* Request parameters for submitting a batch job with sbatch. Each field maps to one
* sbatch option; the example in each comment shows the resulting command-line form.
*
* @author bing.bai
* @create 2025/6/5
*/
@Data
public class SbatchJobSubmit {
// Tenant identifier; appended to the work directory passed to sbatch with -D.
private String tenantId;
// Path of the sbatch script to submit.
private String jobFileUrl;
// Job name, e.g. sbatch -J my_job
private String jobName;
// Partition (queue) to submit to, e.g. -p gpu
private String partition;
// Job time limit (format: D-HH:MM:SS), e.g. -t 2-12:00:00
private String time;
// Number of nodes, e.g. --nodes=4
private Integer nodes;
// Total number of tasks (MPI processes), e.g. --ntasks=128
private Integer ntasks;
// CPU cores per task, e.g. --cpus-per-task=4
// NOTE(review): the name looks like a typo of "cpusPerTask"; left unchanged because
// the generated getter name is part of the class's interface.
private Integer cpusPerTas;
// Memory per node (unit: M/G), e.g. --mem=8G
private String mem;
// Memory per CPU core, e.g. --mem-per-cpu=1G
private String memPerCpu;
// Generic resources (such as GPUs), e.g. --gres=gpu:2
private String gres;
// Exclusive node allocation (even if resources are not fully used), e.g. sbatch --exclusive
private Boolean exclusive;
// Explicit node list, e.g. --nodelist=node[1-5,7]
private String nodelist;
// Nodes to exclude, e.g. --exclude=node6
private String exclude;
}
提交作业展示:
附上gitee代码地址:
easy-slurm:java封装slurm的api,方便用户通过http接口操作slurm:sbatch 提交作业、squeue 作业列表、cancel 取消作业、sinfo 节点信息
DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐

所有评论(0)