|
@@ -7,6 +7,7 @@ import com.spotify.docker.client.exceptions.DockerException;
|
|
|
import com.spotify.docker.client.messages.Image;
|
|
|
import io.minio.errors.*;
|
|
|
import io.renren.common.annotation.SysLog;
|
|
|
+import io.renren.common.exception.RRException;
|
|
|
import io.renren.common.utils.*;
|
|
|
import io.renren.common.validator.ValidatorUtils;
|
|
|
import io.renren.common.validator.group.AddGroup;
|
|
@@ -21,6 +22,7 @@ import io.renren.modules.sys.service.impl.AlgTrainServiceImpl;
|
|
|
import io.renren.modules.sys.service.impl.AlgsModelsServiceImpl;
|
|
|
import io.renren.modules.sys.service.impl.AlgsServiceImpl;
|
|
|
import io.swagger.models.auth.In;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
import org.apache.shiro.authz.annotation.RequiresPermissions;
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
@@ -39,6 +41,9 @@ import java.security.NoSuchAlgorithmException;
|
|
|
import java.text.SimpleDateFormat;
|
|
|
import java.util.*;
|
|
|
import java.util.concurrent.ConcurrentHashMap;
|
|
|
+import java.util.concurrent.CountDownLatch;
|
|
|
+import java.util.concurrent.ExecutorService;
|
|
|
+import java.util.concurrent.Executors;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
import static io.renren.common.utils.ShiroUtils.getUserId;
|
|
@@ -47,6 +52,7 @@ import static io.renren.common.utils.ShiroUtils.getUserId;
|
|
|
* @Author: Ivan Q
|
|
|
* @Date: 2021/6/11 16:28
|
|
|
*/
|
|
|
+@Slf4j
|
|
|
@RestController
|
|
|
@RequestMapping("/algstrain")
|
|
|
public class algTrainController {
|
|
@@ -77,6 +83,8 @@ public class algTrainController {
|
|
|
|
|
|
public final TestSubscriber subscriber = new TestSubscriber();
|
|
|
public final TestPublisher publisher = new TestPublisher();
|
|
|
+ ExecutorService executorService = Executors.newSingleThreadExecutor(); // 创建单线程执行器
|
|
|
+
|
|
|
|
|
|
private static String returnFileName;
|
|
|
|
|
@@ -192,7 +200,8 @@ public class algTrainController {
|
|
|
//以下部分是创建容器部分
|
|
|
|
|
|
//首先获取内存大小限制,并转换为以字节为单位
|
|
|
- Long memoryMB=Long.parseLong(map.get("memory"));
|
|
|
+// Long memoryMB=Long.parseLong(map.get("memory"));
|
|
|
+ Long memoryMB = 512L;
|
|
|
Long memoryByte=memoryMB*1024*1024;
|
|
|
|
|
|
//选择一个未使用的端口进行映射,并标记已使用
|
|
@@ -208,7 +217,7 @@ public class algTrainController {
|
|
|
}*/
|
|
|
//获取该算法在被创建时所选的算法框架
|
|
|
String baseImageName=baseImageService.getById(alg.getFrameId()).getBaseImageName();
|
|
|
- String containerId=DockerClientUtils.createContainer("algTrain"+algTrain.getAlgorithmTrainingId(),String.valueOf(portInt),memoryByte,baseImageName,map.get("setCpus"));
|
|
|
+ String containerId=DockerClientUtils.createContainer("algTrain"+algTrain.getAlgorithmTrainingId(),String.valueOf(portInt),memoryByte,baseImageName,"1");
|
|
|
|
|
|
//保存该训练任务所在容器id
|
|
|
algTrain.setContainerId(containerId);
|
|
@@ -246,29 +255,39 @@ public class algTrainController {
|
|
|
* @return
|
|
|
*/
|
|
|
@GetMapping("/startTraining")
|
|
|
- public R startTraining(String algorithmTrainingId) throws DockerException, InterruptedException, IOException, InvalidResponseException, InvalidKeyException, NoSuchAlgorithmException, ErrorResponseException, XmlParserException, InvalidBucketNameException, InsufficientDataException, InternalException {
|
|
|
+ public R startTraining(String algorithmTrainingId) throws DockerException, InterruptedException, RegionConflictException, InvalidBucketNameException, InsufficientDataException, ErrorResponseException, IOException, NoSuchAlgorithmException, InvalidKeyException, InvalidResponseException, XmlParserException, InternalException {
|
|
|
AlgTrain algTrain=algTrainService.selectByPrimaryKey(Long.parseLong(algorithmTrainingId));
|
|
|
- returnFileName = DockerClientUtils.execPython(algTrain.getContainerId(),"algTrain"+algTrain.getAlgorithmTrainingId(),algTrain.getRunfileName());
|
|
|
-
|
|
|
- //完成训练后,将任务状态改为已结束,并保存结束时间
|
|
|
- algTrain.setMissStatus((byte) 3);
|
|
|
- //将对应版本训练状态改为已训练
|
|
|
- algTrain.setHasRun(1);
|
|
|
+ try {
|
|
|
+// if(!DockerClientUtils.isContainerRunning(algTrain.getContainerId())){
|
|
|
+// DockerClientUtils.startContainer(algTrain.getContainerId());
|
|
|
+// Thread.sleep(7000);
|
|
|
+// }
|
|
|
+ String minioPath = DockerClientUtils.execPython(algTrain.getContainerId(), "algTrain" + algTrain.getAlgorithmTrainingId(), algTrain.getRunfileName(), algorithmTrainingId);
|
|
|
+
|
|
|
+ //完成训练后,将任务状态改为已结束,并保存结束时间
|
|
|
+ algTrain.setMissStatus((byte) 3);
|
|
|
+ //将对应版本训练状态改为已训练
|
|
|
+ algTrain.setHasRun(1);
|
|
|
/*Version version=versionService.getById(algTrain.getVersionId());
|
|
|
version.setVersionStatus((byte) 1);
|
|
|
versionService.update(version);*/
|
|
|
|
|
|
- Date date = new Date();
|
|
|
- //SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
|
|
|
- algTrain.setMissStopTime(date);
|
|
|
- algTrainService.update(algTrain);
|
|
|
-
|
|
|
- //保存日志到数据库
|
|
|
- String algorithmTrainingLogContent= (String) getOutput(algorithmTrainingId).get("output");
|
|
|
- AlgTrainLog algTrainLog=new AlgTrainLog();
|
|
|
- algTrainLog.setAlgorithmTrainingId(Long.parseLong(algorithmTrainingId));
|
|
|
- algTrainLog.setAlgorithmTrainingLogContent(algorithmTrainingLogContent);
|
|
|
- algTrainLogService.save(algTrainLog);
|
|
|
+ Date date = new Date();
|
|
|
+ //SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
|
|
|
+ algTrain.setMissStopTime(date);
|
|
|
+ algTrainService.update(algTrain);
|
|
|
+
|
|
|
+ //保存日志到数据库
|
|
|
+ String algorithmTrainingLogContent = (String) getOutput(algorithmTrainingId, minioPath).get("output");
|
|
|
+ AlgTrainLog algTrainLog = new AlgTrainLog();
|
|
|
+ algTrainLog.setAlgorithmTrainingId(Long.parseLong(algorithmTrainingId));
|
|
|
+ algTrainLog.setAlgorithmTrainingLogContent(algorithmTrainingLogContent);
|
|
|
+ algTrainLog.setAlgorithmTrainingLogMinioPath(minioPath);
|
|
|
+ algTrainLogService.save(algTrainLog);
|
|
|
+ //运行结束之后停掉docker
|
|
|
+ } finally {
|
|
|
+// DockerClientUtils.stopContainer(algTrain.getContainerId());
|
|
|
+ }
|
|
|
|
|
|
// 产生数据
|
|
|
// publisher.publishMessage(
|
|
@@ -296,41 +315,74 @@ public class algTrainController {
|
|
|
|
|
|
@Async
|
|
|
@Scheduled(fixedRate = 6000)//每6秒执行一次,获取消息
|
|
|
- public void checkAlgRequest() throws DockerException, InvalidBucketNameException, InsufficientDataException, ErrorResponseException, IOException, NoSuchAlgorithmException, InvalidKeyException, InterruptedException, InvalidResponseException, XmlParserException, InternalException {
|
|
|
+ public void checkAlgRequest() throws DockerException, InvalidBucketNameException, InsufficientDataException, ErrorResponseException, IOException, NoSuchAlgorithmException, InvalidKeyException, InterruptedException, InvalidResponseException, XmlParserException, InternalException, RegionConflictException {
|
|
|
if (subscriber.listener != null && subscriber.listener.isRequest == 0) {
|
|
|
publisher.publishMessage(
|
|
|
"4",
|
|
|
"111",
|
|
|
"算法文件使用请求"
|
|
|
);
|
|
|
+
|
|
|
}
|
|
|
+ // 创建 CountDownLatch,初始化为1
|
|
|
+ CountDownLatch latch = new CountDownLatch(1);
|
|
|
+ // 启动一个线程检查 publisher 初始化状态
|
|
|
+ new Thread(() -> {
|
|
|
+ while (!publisher.isInitialized) {
|
|
|
+ try {
|
|
|
+ // 等待一段时间后再次检查
|
|
|
+ Thread.sleep(100);
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ Thread.currentThread().interrupt(); // 恢复中断状态
|
|
|
+ }
|
|
|
+ }
|
|
|
+ latch.countDown(); // 初始化成功,释放锁
|
|
|
+ }).start();
|
|
|
+ // 等待 publisher 初始化完成
|
|
|
+ System.out.println("Waiting for publisher to initialize...");
|
|
|
+ latch.await();
|
|
|
publisher.publishMessage(
|
|
|
"4",
|
|
|
"8",
|
|
|
"这是心跳信息"
|
|
|
);
|
|
|
|
|
|
- if(subscriber.listener != null && subscriber.listener.isRequest == 1) {
|
|
|
+ if(subscriber.listener != null && subscriber.listener.isRequest == 1 && !subscriber.listener.algMap.isEmpty()) {
|
|
|
for(Map.Entry<String, String> map : subscriber.listener.algMap.entrySet()){
|
|
|
String key = map.getKey();
|
|
|
String missName = map.getValue();
|
|
|
AlgTrain algTrain = algTrainService.selectByMissName(missName);
|
|
|
+ if (algTrain == null) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
Long algorithmTrainingId = algTrain.getAlgorithmTrainingId();
|
|
|
- startTraining(String.valueOf(algorithmTrainingId));
|
|
|
+ executorService.submit(() -> {
|
|
|
+ try {
|
|
|
+ startTraining(String.valueOf(algorithmTrainingId));
|
|
|
+ } catch (Exception e) {
|
|
|
+ throw new RRException(e.getMessage());
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
}
|
|
|
+ subscriber.listener.algMap.clear();
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- @PostConstruct
|
|
|
- public void init() throws DockerException, InvalidBucketNameException, InsufficientDataException, ErrorResponseException, IOException, NoSuchAlgorithmException, InvalidKeyException, InterruptedException, InvalidResponseException, XmlParserException, InternalException {
|
|
|
- // 初始化时运行
|
|
|
- if (subscriber.listener.isRequest == 1) {
|
|
|
- algRun(subscriber.listener.algMap);
|
|
|
- }
|
|
|
- subscriber.listener.isRequest = 0;
|
|
|
- }
|
|
|
+// @PostConstruct
|
|
|
+// public void init() throws DockerException, InvalidBucketNameException, InsufficientDataException, ErrorResponseException, IOException, NoSuchAlgorithmException, InvalidKeyException, InterruptedException, InvalidResponseException, XmlParserException, InternalException, RegionConflictException {
|
|
|
+// // 初始化时运行
|
|
|
+// int flag = 0; //标志位,判断是否初始化成功
|
|
|
+// if (subscriber.listener != null && subscriber.listener.isRequest == 1) {
|
|
|
+//// algRun(subscriber.listener.algMap);
|
|
|
+// flag = 1;
|
|
|
+// }
|
|
|
+// if (flag == 1) {
|
|
|
+// subscriber.listener.isRequest = 0;
|
|
|
+// }
|
|
|
+// }
|
|
|
|
|
|
- private void algRun(Map<String, String> algMap) throws DockerException, InvalidBucketNameException, InsufficientDataException, ErrorResponseException, IOException, NoSuchAlgorithmException, InvalidKeyException, InterruptedException, InvalidResponseException, XmlParserException, InternalException {
|
|
|
+ private void algRun(Map<String, String> algMap) throws DockerException, InvalidBucketNameException, InsufficientDataException, ErrorResponseException, IOException, NoSuchAlgorithmException, InvalidKeyException, InterruptedException, InvalidResponseException, XmlParserException, InternalException, RegionConflictException {
|
|
|
for(Map.Entry<String, String> map : algMap.entrySet()){
|
|
|
String key = map.getKey();
|
|
|
String missName = map.getValue();
|
|
@@ -414,19 +466,39 @@ public class algTrainController {
|
|
|
* @return
|
|
|
*/
|
|
|
@GetMapping("/getOutput")
|
|
|
- public R getOutput(String algorithmTrainingId) throws DockerException, InterruptedException, IOException, InvalidResponseException, InvalidKeyException, NoSuchAlgorithmException, ErrorResponseException, XmlParserException, InvalidBucketNameException, InsufficientDataException, InternalException {
|
|
|
- AlgTrain algTrain = algTrainService.selectByPrimaryKey(Long.parseLong(algorithmTrainingId));
|
|
|
- DockerClientUtils.copyFile(algTrain.getContainerId(),"/" + returnFileName,"/opt/algTrain" + algorithmTrainingId);
|
|
|
- InputStream inputStream = FTPUtils.downloadFile("/opt/uploadFile/algTrain" + algorithmTrainingId + "/" + returnFileName);
|
|
|
- if (inputStream == null) {
|
|
|
- return R.error("文件不存在");
|
|
|
+ public R getOutput(String algorithmTrainingId, String minioPath) throws DockerException, InterruptedException, IOException, InvalidResponseException, InvalidKeyException, NoSuchAlgorithmException, ErrorResponseException, XmlParserException, InvalidBucketNameException, InsufficientDataException, InternalException {
|
|
|
+// AlgTrain algTrain = algTrainService.selectByPrimaryKey(Long.parseLong(algorithmTrainingId));
|
|
|
+// DockerClientUtils.copyFile(algTrain.getContainerId(),"/" + returnFileName,"/opt/algTrain" + algorithmTrainingId);
|
|
|
+// InputStream inputStream = FTPUtils.downloadFile("/opt/uploadFile/algTrain" + algorithmTrainingId + "/" + returnFileName);
|
|
|
+// if (inputStream == null) {
|
|
|
+// return R.error("结果不存在");
|
|
|
+// }
|
|
|
+// String result = new BufferedReader(new InputStreamReader(inputStream))
|
|
|
+// .lines().collect(Collectors.joining("\n"));
|
|
|
+//
|
|
|
+// inputStream.close();
|
|
|
+//// System.out.println(result);
|
|
|
+//
|
|
|
+// return R.ok().put("output",result);
|
|
|
+ String objectName = "";
|
|
|
+ if (minioPath == null) {
|
|
|
+ AlgTrainLog algTrainLog = algTrainLogService.selectByAlgTrainId(Long.parseLong(algorithmTrainingId));
|
|
|
+ objectName = algTrainLog.getAlgorithmTrainingLogMinioPath();
|
|
|
+ } else {
|
|
|
+ objectName = minioPath;
|
|
|
+ }
|
|
|
+ log.info("获取文件从minio的objectName {}", objectName);
|
|
|
+ try (InputStream fileInputStream = MinIoUtils.getFileInputStream("algorithm-train-task", objectName)) {
|
|
|
+ // 继续处理文件流
|
|
|
+ String result = new BufferedReader(new InputStreamReader(fileInputStream))
|
|
|
+ .lines().collect(Collectors.joining("\n"));
|
|
|
+ log.info("result {}", result);
|
|
|
+ return R.ok().put("output", result);
|
|
|
+ } catch (ErrorResponseException e){
|
|
|
+ if (e.getMessage().equals("The specified key does not exist."))
|
|
|
+ return R.error("文件不存在,请确认文件位置或者重新上传");
|
|
|
+ else return R.error(e.getMessage());
|
|
|
}
|
|
|
- String result = new BufferedReader(new InputStreamReader(inputStream))
|
|
|
- .lines().collect(Collectors.joining("\n"));
|
|
|
- inputStream.close();
|
|
|
- System.out.println(result);
|
|
|
-
|
|
|
- return R.ok().put("output",result);
|
|
|
}
|
|
|
|
|
|
/**
|