看了下 Slurm 官方提供的 REST API,理解起来比较复杂,所以尝试自己封装了一套 API,只保留一些常用的简单参数(够用就行)。

以下是大致的接口列表,后续会持续迭代。

整体思路:通过 JSch 以 SSH 方式连接远程服务器,在服务器上执行 Slurm 命令,并抓取命令返回的标准输出、标准错误和退出码。

首先封装 JSch 工具类。由于所有命令都在同一台 Slurm 节点上执行,整个应用只需维护一个 SSH 会话,因此采用单例模式:

package com.easy.slurm.slurm.Jsch;

import com.alibaba.fastjson.JSONObject;
import com.jcraft.jsch.*;
import lombok.extern.slf4j.Slf4j;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Process-wide SSH client that runs shell commands on a single remote host via JSch.
 *
 * <p>One {@link Session} is shared by all callers. Each command gets its own {@code exec}
 * channel, plus one reader task per output stream. Call {@link #init} once before use and
 * {@link #destroy} on shutdown.
 *
 * @author bing.bai
 * @create 2025/5/30
 */
@Slf4j
public class JschClient {

    private static final JschClient INSTANCE = new JschClient();
    /** Default command timeout in seconds; also used as the SSH connect timeout. */
    private static final int DEFAULT_TIMEOUT = 30;
    /** Seconds {@link #destroy()} waits for reader threads before forcing shutdown. */
    private static final int THREAD_TIMEOUT = 10;
    /** Poll interval while waiting for the channel to close (and deliver its exit status). */
    private static final long CLOSE_POLL_MILLIS = 20;

    private final AtomicBoolean inited = new AtomicBoolean(false);
    private Session session;
    private ExecutorService executorService;

    private JschClient() {
    }

    /** Returns the singleton instance. */
    public static JschClient getInstance() {
        return INSTANCE;
    }

    /**
     * Opens the shared SSH session (port 22, password auth). A second call while initialized
     * only logs a warning.
     *
     * @param host     remote host name or address
     * @param user     SSH user name
     * @param password SSH password
     * @throws JSchException if the session cannot be established within the timeout
     */
    public synchronized void init(String host, String user, String password) throws JSchException {
        if (inited.get()) {
            log.warn("JschClient already initialized");
            return;
        }

        JSch jsch = new JSch();
        Session newSession = jsch.getSession(user, host, 22);
        newSession.setPassword(password);
        // NOTE(review): host-key verification is disabled (MITM risk); consider a known_hosts file.
        newSession.setConfig("StrictHostKeyChecking", "no");
        newSession.setServerAliveInterval(30 * 1000);
        // Bounded connect so an unreachable host cannot hang application startup indefinitely.
        newSession.connect(DEFAULT_TIMEOUT * 1000);
        session = newSession;

        // Two reader tasks per in-flight command. A cached pool grows with demand, so readers of
        // concurrent commands never queue behind each other (a fixed pool of 2 could only serve
        // one command at a time). Daemon threads do not keep the JVM alive.
        executorService = Executors.newCachedThreadPool(runnable -> {
            Thread thread = new Thread(runnable, "jsch-stream-reader");
            thread.setDaemon(true);
            return thread;
        });
        inited.set(true);
        log.info("JschClient initialized successfully for host: {}", host);
    }

    /**
     * Runs {@code command} with the default timeout.
     *
     * @see #runCommand(String, int)
     */
    public RunCommandResult runCommand(String command) throws JSchException, IOException {
        return runCommand(command, DEFAULT_TIMEOUT);
    }

    /**
     * Runs {@code command} on the remote host and collects its stdout, stderr and exit status.
     *
     * <p>On timeout the lines read so far are still returned. The exit status is then -1,
     * which is JSch's value when no exit status has been received. The remote process is not
     * killed; only the channel is closed.
     *
     * @param command        shell command line executed by the remote user's shell
     * @param timeoutSeconds overall time budget for reading output and receiving the exit status
     * @return the collected result, never {@code null}
     * @throws IllegalStateException if {@link #init} has not been called
     * @throws JSchException         if the exec channel cannot be opened or connected
     * @throws IOException           if the channel's streams cannot be obtained
     */
    public RunCommandResult runCommand(String command, int timeoutSeconds) throws JSchException, IOException {
        if (!inited.get()) {
            throw new IllegalStateException("JschClient not initialized");
        }
        log.info("run command :{}", command);

        long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(timeoutSeconds);
        RunCommandResult result = new RunCommandResult();
        // Readers append here, so partial output survives a timeout.
        ConcurrentLinkedQueue<String> stdoutLines = new ConcurrentLinkedQueue<>();
        ConcurrentLinkedQueue<String> stderrLines = new ConcurrentLinkedQueue<>();
        ChannelExec channel = null;

        try {
            channel = (ChannelExec) session.openChannel("exec");
            channel.setCommand(command);
            channel.setInputStream(null);

            // Obtain the streams before connect() so no early output is missed.
            InputStream in = channel.getInputStream();
            InputStream err = channel.getErrStream();

            channel.connect();

            // Read both streams in parallel so neither one can fill up and stall the other.
            CompletableFuture<Void> stdoutFuture = readStreamAsync(in, "stdout", stdoutLines);
            CompletableFuture<Void> stderrFuture = readStreamAsync(err, "stderr", stderrLines);

            try {
                CompletableFuture.allOf(stdoutFuture, stderrFuture).get(timeoutSeconds, TimeUnit.SECONDS);
                // Stream EOF can precede the exit-status message. getExitStatus() stays -1 until
                // the channel closes, so wait for that within the same deadline.
                awaitChannelClosed(channel, deadline);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.warn("Command execution interrupted", e);
            } catch (ExecutionException e) {
                log.error("Error reading command output", e.getCause());
            } catch (TimeoutException e) {
                log.warn("Command execution timed out after {} seconds", timeoutSeconds);
                // CompletableFuture.cancel does not interrupt running readers. They are expected
                // to end once channel.disconnect() below tears the channel down.
                stdoutFuture.cancel(true);
                stderrFuture.cancel(true);
            }

            result.setStdout(new ArrayList<>(stdoutLines));
            result.setStderr(new ArrayList<>(stderrLines));
            result.setExitStatus(channel.getExitStatus());
            if (log.isDebugEnabled()) {
                log.debug("command result: {}", JSONObject.toJSONString(result));
            }
            return result;
        } finally {
            if (channel != null) {
                channel.disconnect();
            }
        }
    }

    /** Polls until {@code channel} is closed or {@code deadlineNanos} (a nanoTime value) passes. */
    private static void awaitChannelClosed(ChannelExec channel, long deadlineNanos) throws InterruptedException {
        while (!channel.isClosed()) {
            if (System.nanoTime() - deadlineNanos >= 0) {
                log.warn("Channel not closed before timeout; exit status may be unavailable");
                return;
            }
            Thread.sleep(CLOSE_POLL_MILLIS);
        }
    }

    /**
     * Reads {@code stream} line by line (UTF-8) on the reader pool, appending each line to
     * {@code sink}. I/O errors are logged, not propagated.
     */
    private CompletableFuture<Void> readStreamAsync(InputStream stream, String streamType,
                                                    ConcurrentLinkedQueue<String> sink) {
        return CompletableFuture.runAsync(() -> {
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(stream, StandardCharsets.UTF_8))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    sink.add(line);
                    // Honors interrupts from executorService.shutdownNow() in destroy().
                    if (Thread.currentThread().isInterrupted()) {
                        log.debug("{} reading interrupted", streamType);
                        break;
                    }
                }
            } catch (IOException e) {
                log.error("Error reading {} stream", streamType, e);
            }
        }, executorService);
    }

    /**
     * Disconnects the session and shuts down the reader pool. It waits up to
     * {@code THREAD_TIMEOUT} seconds before forcing the pool down. Calling it when not
     * initialized only logs a warning.
     */
    public synchronized void destroy() {
        if (!inited.get()) {
            log.warn("JschClient not initialized, nothing to destroy");
            return;
        }

        if (session != null && session.isConnected()) {
            session.disconnect();
        }

        if (executorService != null) {
            try {
                executorService.shutdown();
                if (!executorService.awaitTermination(THREAD_TIMEOUT, TimeUnit.SECONDS)) {
                    executorService.shutdownNow();
                }
            } catch (InterruptedException e) {
                executorService.shutdownNow();
                Thread.currentThread().interrupt();
            }
        }

        inited.set(false);
        log.info("JschClient destroyed successfully");
    }

    /** Output of one remote command: stdout/stderr lines plus the exit status. */
    public static class RunCommandResult {
        private List<String> stdout = new ArrayList<>();
        private List<String> stderr = new ArrayList<>();
        private int exitStatus;

        public List<String> getStdout() {
            return stdout;
        }

        public void setStdout(List<String> stdout) {
            this.stdout = stdout;
        }

        public List<String> getStderr() {
            return stderr;
        }

        public void setStderr(List<String> stderr) {
            this.stderr = stderr;
        }

        public int getExitStatus() {
            return exitStatus;
        }

        public void setExitStatus(int exitStatus) {
            this.exitStatus = exitStatus;
        }

        /** Returns {@code true} when the exit status is 0. */
        public boolean isSuccess() {
            return exitStatus == 0;
        }

        /** Returns all stderr lines concatenated with no separator. */
        public String getErrorMsg() {
            return String.join("", stderr);
        }

        /** Returns all stdout lines concatenated with no separator. */
        public String getSuccessMsg() {
            return String.join("", stdout);
        }

    }
}

JSch 的配置类(应用启动时读取配置并初始化连接):

package com.easy.slurm.slurm.Jsch;

import com.jcraft.jsch.JSchException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;

/**
 * Connects the shared {@link JschClient} from the {@code jsch.*} properties at startup and
 * releases it on shutdown. When {@code jsch.host} is blank or absent, no SSH connection is made.
 *
 * @author bing.bai
 * @create 2025/5/30
 */
@Component
public class JschConfig {

    // Empty defaults let the application start without SSH settings; init() then skips connecting.
    @Value("${jsch.host:}")
    private String host;

    @Value("${jsch.user:}")
    private String user;

    @Value("${jsch.password:}")
    private String password;

    /**
     * Initializes the singleton client when a host is configured.
     *
     * @throws JSchException if the SSH session cannot be established
     */
    @PostConstruct
    public void init() throws JSchException {
        if (StringUtils.isNotBlank(host)) {
            JschClient.getInstance().init(host, user, password);
        }
    }

    /**
     * Closes the SSH session and the client's reader threads when the Spring context shuts down.
     * It is skipped when no host was configured, because the client was never initialized then.
     */
    @PreDestroy
    public void destroy() {
        if (StringUtils.isNotBlank(host)) {
            JschClient.getInstance().destroy();
        }
    }
}

下面以作业提交为例,说明如何封装 Slurm 的 API,其中 ResultVo 为自定义的返回结构:

核心逻辑是把用户输入拼接成 sbatch 命令,再在服务器上执行。注意:用户输入在拼接前必须做 shell 转义,否则存在命令注入风险。

  /**
   * Submits a batch job by running the generated {@code sbatch} command on the remote host.
   *
   * @param sbatchJobSubmit job parameters; {@code jobFileUrl} is required
   * @return a success result carrying the command's stdout when it exits with status 0,
   *     otherwise a failure result carrying its stderr
   * @throws BussinessException if {@code jobFileUrl} is blank
   * @throws JSchException if the SSH exec channel cannot be opened
   * @throws IOException if the command's output streams cannot be obtained
   */
  @Override
  public ResultVo submitSbatchJob(SbatchJobSubmit sbatchJobSubmit) throws JSchException, IOException {
    if (StringUtils.isBlank(sbatchJobSubmit.getJobFileUrl())) {
      throw new BussinessException(ResultCode.FILEURLNOTEXIST);
    }
    JschClient client = JschClient.getInstance();
    JschClient.RunCommandResult outcome = client.runCommand(combineSbatchJobCommand(sbatchJobSubmit));
    if (!outcome.isSuccess()) {
      return ResultVo.fail(outcome.getErrorMsg());
    }
    return ResultVo.success(outcome.getSuccessMsg());
  }

  /**
   * Builds the {@code sbatch} command line for a job submission.
   *
   * <p>The command is executed through an SSH exec channel, so the remote user's shell parses
   * it. Every user-supplied string is therefore single-quoted. Otherwise input such as
   * {@code "x; rm -rf ~"} would run arbitrary commands, and values such as
   * {@code node[1-5,7]} would be subject to glob expansion. Integer options are appended as-is.
   * Options are written as {@code --opt=value} with no space; a space would give sbatch an
   * empty value and turn the value into a stray positional argument.
   *
   * @param sbatchJobSubmit submission parameters; {@code jobFileUrl} is expected to be non-blank
   * @return the complete command string
   */
  private String combineSbatchJobCommand(SbatchJobSubmit sbatchJobSubmit) {
    StringBuilder command = new StringBuilder(SBATCH);
    // WORKDIR is a constant and is left unquoted so its existing shell interpretation is kept.
    // NOTE(review): quoting stops injection but not traversal; a tenantId like "../x" still
    // escapes the tenant directory. Consider validating tenantId upstream.
    command.append(" -D ").append(SlurmConstants.WORKDIR).append('/')
        .append(shellQuote(String.valueOf(sbatchJobSubmit.getTenantId())));
    if (StringUtils.isNotBlank(sbatchJobSubmit.getJobName())) {
      command.append(" -J ").append(shellQuote(sbatchJobSubmit.getJobName()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getPartition())) {
      command.append(" -p ").append(shellQuote(sbatchJobSubmit.getPartition()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getTime())) {
      command.append(" -t ").append(shellQuote(sbatchJobSubmit.getTime()));
    }
    if (sbatchJobSubmit.getNodes() != null) {
      command.append(" --nodes=").append(sbatchJobSubmit.getNodes());
    }
    if (sbatchJobSubmit.getNtasks() != null) {
      command.append(" --ntasks=").append(sbatchJobSubmit.getNtasks());
    }
    if (sbatchJobSubmit.getCpusPerTas() != null) {
      command.append(" --cpus-per-task=").append(sbatchJobSubmit.getCpusPerTas());
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getMem())) {
      command.append(" --mem=").append(shellQuote(sbatchJobSubmit.getMem()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getMemPerCpu())) {
      command.append(" --mem-per-cpu=").append(shellQuote(sbatchJobSubmit.getMemPerCpu()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getGres())) {
      command.append(" --gres=").append(shellQuote(sbatchJobSubmit.getGres()));
    }
    if (Boolean.TRUE.equals(sbatchJobSubmit.getExclusive())) {
      // Plain flag: "--exclusive=" would pass an empty argument.
      command.append(" --exclusive");
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getNodelist())) {
      // Previously the node list value itself was never appended.
      command.append(" --nodelist=").append(shellQuote(sbatchJobSubmit.getNodelist()));
    }
    if (StringUtils.isNotBlank(sbatchJobSubmit.getExclude())) {
      command.append(" --exclude=").append(shellQuote(sbatchJobSubmit.getExclude()));
    }
    String jobFile = String.valueOf(sbatchJobSubmit.getJobFileUrl());
    if (jobFile.startsWith("~/")) {
      // Keep "~/" outside the quotes so the remote shell still expands it to $HOME.
      command.append(" ~/").append(shellQuote(jobFile.substring(2)));
    } else {
      command.append(' ').append(shellQuote(jobFile));
    }
    return command.toString();
  }

  /**
   * Wraps {@code value} in single quotes for a POSIX shell. Each embedded {@code '} becomes
   * {@code '"'"'}, which closes the quote, emits a literal quote, and reopens it.
   */
  private static String shellQuote(String value) {
    return "'" + value.replace("'", "'\"'\"'") + "'";
  }

用户输入参数类:

package com.easy.slurm.slurm.bean;

import lombok.Data;

/**
 * Request parameters for submitting a Slurm batch job. Each optional field corresponds to one
 * sbatch option; the option and an example value are noted on the field.
 *
 * @author bing.bai
 * @create 2025/6/5
 */
@Data
public class SbatchJobSubmit {

    // Tenant identifier; used when building the job's working directory.
    private String tenantId;

    // Location of the sbatch script (required).
    private String jobFileUrl;

    // Job name, e.g. sbatch -J my_job
    private String jobName;

    // Partition (queue) to submit to, e.g. -p gpu
    private String partition;

    // Job time limit (format D-HH:MM:SS), e.g. -t 2-12:00:00
    private String time;

    // Number of nodes, e.g. --nodes=4
    private Integer nodes;

    // Total number of tasks (MPI processes), e.g. --ntasks=128
    private Integer ntasks;

    // CPU cores per task, e.g. --cpus-per-task=4
    // NOTE: "cpusPerTas" is a typo of "cpusPerTask". It is kept because renaming would change the
    // Lombok-generated getter/setter (and presumably the JSON property name) that callers use.
    private Integer cpusPerTas;

    // Memory per node (unit M/G), e.g. --mem=8G
    private String mem;

    // Memory per CPU core, e.g. --mem-per-cpu=1G
    private String memPerCpu;

    // Generic resources (such as GPUs), e.g. --gres=gpu:2
    private String gres;

    // Allocate nodes exclusively (even if resources are not fully used), e.g. sbatch --exclusive
    private Boolean exclusive;

    // Specific nodes to run on, e.g. --nodelist=node[1-5,7]
    private String nodelist;

    // Nodes to exclude, e.g. --exclude=node6
    private String exclude;


}

提交作业效果展示:

附上 Gitee 代码地址:

easy-slurm:用 Java 封装 Slurm 的 API,方便用户通过 HTTP 接口操作 Slurm,目前支持:sbatch 提交作业、squeue 作业列表、cancel 取消作业、sinfo 节点信息。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐