package pt.ua.dicoogle.server.web.servlets.mlprovider;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import net.kencochrane.raven.marshaller.json.JsonMarshaller;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.core.lookup.StructuredDataLookup;
import org.dcm4che3.imageio.plugins.dcm.DicomMetaData;
import org.restlet.data.Status;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import pt.ua.dicoogle.plugins.PluginController;
import pt.ua.dicoogle.sdk.datastructs.dim.BulkAnnotation;
import pt.ua.dicoogle.sdk.datastructs.dim.DimLevel;
import pt.ua.dicoogle.sdk.datastructs.dim.Point2D;
import pt.ua.dicoogle.sdk.mlprovider.MLInference;
import pt.ua.dicoogle.sdk.mlprovider.MLInferenceRequest;
import pt.ua.dicoogle.sdk.task.Task;
import pt.ua.dicoogle.server.web.dicom.ROIExtractor;
import pt.ua.dicoogle.server.web.dicom.WSISopDescriptor;
import pt.ua.dicoogle.server.web.utils.ResponseUtil;
import pt.ua.dicoogle.server.web.utils.cache.WSICache;

/**
 * Servlet that forwards an inference request to a machine-learning provider plugin and writes the
 * provider's result back to the client.
 *
 * <p>The POST body must be a JSON object with the fields {@code level}, {@code uid},
 * {@code provider} and {@code modelID}. When the optional {@code wsi} flag is {@code true}, the
 * request targets a region of a whole-slide-image instance and must additionally carry
 * {@code points}, {@code type} and {@code baseUID}. Coordinates of the returned annotations are
 * then translated using the bounding box of that region and the {@code baseUID} instance.
 */
public class InferServlet extends HttpServlet {
    private static final Logger log = LoggerFactory.getLogger(InferServlet.class);

    /** Shared mapper: an {@link ObjectMapper} is thread-safe once configured and costly to build. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Prefix of multipart boundary lines (also appended after the closing boundary). */
    private static final String BOUNDARY_DASHES = "--";

    private final WSICache wsiCache = WSICache.getInstance();
    private final ROIExtractor roiExtractor = new ROIExtractor();

    /**
     * Validates the JSON request body, builds the matching inference task and runs it. The response
     * is written by the task's completion callback.
     *
     * <p>Malformed or incomplete requests are answered with 400. If no task could be built, the
     * response is 500.
     */
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response)
            throws ServletException, IOException {
        String body = IOUtils.toString(request.getReader());
        JsonNode root;
        try {
            root = MAPPER.readTree(body);
        } catch (IOException e) {
            // Jackson reports malformed JSON with an IOException subtype: that is a client error.
            root = null;
        }
        if (root == null) {
            badRequest(response, "Request body is not valid JSON");
            return;
        }
        if (!root.has("level")) {
            badRequest(response, "DIM level provided was invalid");
            return;
        }
        if (!root.has("uid")) {
            badRequest(response, "DIM UID provided was invalid");
            return;
        }
        if (!root.has("provider")) {
            badRequest(response, "Provider provided was invalid");
            return;
        }
        if (!root.has("modelID")) {
            badRequest(response, "Model identifier was invalid");
            return;
        }

        String provider = root.get("provider").asText();
        String modelID = root.get("modelID").asText();
        String uid = root.get("uid").asText();
        boolean wsi = root.has("wsi") && root.get("wsi").asBoolean();

        DimLevel level;
        try {
            // Locale.ROOT: default-locale casing breaks e.g. "instance" under a Turkish locale.
            level = DimLevel.valueOf(root.get("level").asText().toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            badRequest(response, "DIM level provided was invalid");
            return;
        }

        Task<MLInference> task;
        if (!wsi) {
            task = sendRequest(provider, modelID, level, uid, response);
        } else {
            if (!root.has("points") || !root.has("type") || !root.has("baseUID")) {
                badRequest(response, "Insufficient data to build request");
                return;
            }
            if (level != DimLevel.INSTANCE) {
                badRequest(response, "Only Instance level is supported with WSI");
                return;
            }
            BulkAnnotation.AnnotationType annotationType;
            try {
                annotationType = BulkAnnotation.AnnotationType.valueOf(root.get("type").asText());
            } catch (IllegalArgumentException e) {
                badRequest(response, "Annotation type provided was invalid");
                return;
            }
            List<Point2D> points;
            try {
                points = MAPPER.readValue(root.get("points").toString(),
                        new TypeReference<List<Point2D>>() {});
            } catch (IOException e) {
                badRequest(response, "Points provided were invalid");
                return;
            }
            task = sendWSIRequest(provider, modelID, root.get("baseUID").asText(), uid,
                    annotationType, points, response);
        }

        if (task == null) {
            ResponseUtil.sendError(response, Status.SERVER_ERROR_INTERNAL.getCode(),
                    "Could not build prediction request");
        } else {
            task.run();
        }
    }

    /**
     * Builds an inference task for a region of interest of a whole-slide-image instance.
     *
     * @param provider name of the ML provider plugin
     * @param modelID identifier of the model to run
     * @param baseUID SOP instance UID of the instance whose pixel space the result is mapped into
     * @param uid SOP instance UID of the instance the ROI is extracted from
     * @param annotationType shape of the ROI
     * @param points points delimiting the ROI
     * @param response response written by the task's completion callback
     * @return the task, or {@code null} if the ROI could not be prepared or the plugin returned
     *     no task
     */
    private Task<MLInference> sendWSIRequest(String provider, String modelID, String baseUID,
            String uid, BulkAnnotation.AnnotationType annotationType, List<Point2D> points,
            HttpServletResponse response) {
        MLInferenceRequest inferenceRequest =
                new MLInferenceRequest(true, DimLevel.INSTANCE, uid, modelID);
        BulkAnnotation roi = new BulkAnnotation();
        roi.setPoints(points);
        roi.setAnnotationType(annotationType);

        try {
            DicomMetaData metadata = getDicomMetadata(uid);
            inferenceRequest.setRoi(roiExtractor.extractROI(metadata, roi));
            Task<MLInference> task = PluginController.getInstance().infer(provider, inferenceRequest);
            if (task != null) {
                task.onCompletion(() -> {
                    MLInference inference = awaitInference(task, response);
                    if (inference == null) {
                        return; // error response already sent
                    }
                    try {
                        if (inference.getAnnotations() != null
                                && !inference.getAnnotations().isEmpty()) {
                            double scale = resolutionScale(metadata, getDicomMetadata(baseUID));
                            convertCoordinates(inference, roi.getBoundingBox().get(0), scale);
                        }
                    } catch (IOException e) {
                        log.error("Could not read base instance {} to convert prediction coordinates",
                                baseUID, e);
                        sendPredictionError(response);
                        return;
                    }
                    try {
                        writeJson(inference, response);
                    } catch (IOException e) {
                        throw new UncheckedIOException(e);
                    }
                });
            }
            return task;
        } catch (IOException e) {
            log.error("Could not prepare ROI of instance {} for inference", uid, e);
            return null;
        }
    }

    /**
     * Builds an inference task for a whole DIM object (patient/study/series/instance).
     *
     * @param provider name of the ML provider plugin
     * @param modelID identifier of the model to run
     * @param level DIM level of {@code uid}
     * @param uid identifier of the DIM object
     * @param response response written by the task's completion callback
     * @return the task, or {@code null} if the plugin returned no task
     */
    private Task<MLInference> sendRequest(String provider, String modelID, DimLevel level,
            String uid, HttpServletResponse response) {
        Task<MLInference> task = PluginController.getInstance()
                .infer(provider, new MLInferenceRequest(false, level, uid, modelID));
        if (task != null) {
            task.onCompletion(() -> {
                MLInference inference = awaitInference(task, response);
                if (inference == null) {
                    return; // error response already sent
                }
                try {
                    writeInference(inference, response);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                } finally {
                    // Clean up the provider's temporary files even if writing the response failed.
                    deleteResources(inference);
                }
            });
        }
        return task;
    }

    /**
     * Waits for the task result. If the task failed or produced {@code null}, logs the problem,
     * sends a 500 and returns {@code null}.
     */
    private static MLInference awaitInference(Task<MLInference> task,
            HttpServletResponse response) {
        try {
            MLInference inference = task.get();
            if (inference == null) {
                log.error("Provider returned null prediction");
                sendPredictionError(response);
            }
            return inference;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt for the caller
            log.error("Could not make prediction", e);
        } catch (ExecutionException e) {
            log.error("Could not make prediction", e);
        }
        sendPredictionError(response);
        return null;
    }

    /** Sends the generic 500 "Could not make prediction" error from inside a completion callback. */
    private static void sendPredictionError(HttpServletResponse response) {
        try {
            ResponseUtil.sendError(response, Status.SERVER_ERROR_INTERNAL.getCode(),
                    "Could not make prediction");
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Sends a 400 with the given message. */
    private static void badRequest(HttpServletResponse response, String message)
            throws IOException {
        ResponseUtil.sendError(response, Status.CLIENT_ERROR_BAD_REQUEST.getCode(), message);
    }

    /**
     * Writes a non-WSI inference result. The format depends on what the result contains:
     * <ul>
     *   <li>no DICOM-SEG: JSON;</li>
     *   <li>a DICOM-SEG and no other results: the raw DICOM file;</li>
     *   <li>both: a multipart body with the JSON and the DICOM file.</li>
     * </ul>
     */
    private static void writeInference(MLInference inference, HttpServletResponse response)
            throws IOException {
        if (inference.getDicomSEG() == null) {
            writeJson(inference, response);
        } else if (!inference.hasResults()) {
            response.setContentType("application/dicom");
            ServletOutputStream out = response.getOutputStream();
            Files.copy(inference.getDicomSEG(), out);
            out.flush();
        } else {
            writeMultipart(inference, response);
        }
    }

    /** Serializes the inference as the JSON response body. */
    private static void writeJson(MLInference inference, HttpServletResponse response)
            throws IOException {
        response.setContentType("application/json");
        PrintWriter writer = response.getWriter();
        MAPPER.writeValue(writer, inference);
        // Jackson closes the target by default; flush-then-close keeps this correct otherwise.
        writer.flush();
        writer.close();
    }

    /**
     * Writes a multipart body with a JSON part named {@code params} (the serialized inference) and
     * a file part named {@code dicomseg} (the DICOM-SEG file).
     */
    private static void writeMultipart(MLInference inference, HttpServletResponse response)
            throws IOException {
        String boundary = UUID.randomUUID().toString();
        response.setContentType("multipart/form-data; boundary=" + boundary);
        ServletOutputStream out = response.getOutputStream();

        out.print(BOUNDARY_DASHES + boundary);
        out.println();
        out.print("Content-Disposition: form-data; name=\"params\"");
        out.println();
        out.print("Content-Type: application/json");
        out.println();
        out.println();
        // Write UTF-8 bytes: ServletOutputStream.print(String) only handles ISO-8859-1 characters.
        out.write(MAPPER.writeValueAsBytes(inference));
        out.println();

        out.print(BOUNDARY_DASHES + boundary);
        out.println();
        out.print("Content-Disposition: form-data; name=\"dicomseg\"; filename=\"dicomseg.dcm\"");
        out.println();
        out.print("Content-Type: application/dicom");
        out.println();
        out.println();
        Files.copy(inference.getDicomSEG(), out);
        out.println();

        out.print(BOUNDARY_DASHES + boundary + BOUNDARY_DASHES);
        out.flush();
        out.close();
    }

    /** Best-effort deletion of the inference's temporary resources folder, if it has one. */
    private static void deleteResources(MLInference inference) {
        String folder = inference.getResourcesFolder();
        if (StringUtils.isBlank(folder)) {
            return;
        }
        try {
            FileUtils.deleteDirectory(new File(folder));
        } catch (IOException e) {
            log.warn("Could not delete temporary file", e);
        }
    }

    /**
     * Ratio of the total pixel-matrix rows of the ROI instance to those of the base instance.
     * {@link #convertCoordinates} divides ROI-instance coordinates by this value.
     */
    private static double resolutionScale(DicomMetaData roiInstance, DicomMetaData baseInstance)
            throws IOException {
        WSISopDescriptor roiDescriptor = new WSISopDescriptor();
        roiDescriptor.extractData(roiInstance.getAttributes());
        WSISopDescriptor baseDescriptor = new WSISopDescriptor();
        baseDescriptor.extractData(baseInstance.getAttributes());
        // NOTE(review): the decompiled source lost both operand names (it read r0/r0). This is
        // reconstructed as ROI rows / base rows, so that dividing maps ROI-level coordinates onto
        // the base level. Confirm against upstream.
        return (roiDescriptor.getTotalPixelMatrixRows() * 1.0d)
                / baseDescriptor.getTotalPixelMatrixRows();
    }

    /** Looks up the DICOM metadata of the given SOP instance UID through the WSI cache. */
    private DicomMetaData getDicomMetadata(String sopInstanceUID) throws IOException {
        return this.wsiCache.get(sopInstanceUID);
    }

    /**
     * Moves every annotation point from ROI-relative coordinates into the target pixel space.
     * Each point is offset by {@code topLeft} and then divided by {@code scale}, in place.
     */
    private static void convertCoordinates(MLInference inference, Point2D topLeft, double scale) {
        for (BulkAnnotation annotation : inference.getAnnotations()) {
            for (Point2D point : annotation.getPoints()) {
                point.setX((point.getX() + topLeft.getX()) / scale);
                point.setY((point.getY() + topLeft.getY()) / scale);
            }
        }
    }
}
