/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.lambda.controllers;

import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import com.dataiku.lambda.LambdaContext;
import com.dataiku.lambda.auth.AuthVerificationService;
import com.dataiku.lambda.controllers.LambdaControllerUtils;
import com.dataiku.lambda.controllers.RequestMetadata;
import com.dataiku.lambda.endpoints.predictcommon.PredictionEndpointHandlerBase;
import com.dataiku.lambda.model.api.ForecastQuery;
import com.dataiku.lambda.model.api.MultiplePredictionQuery;
import com.dataiku.lambda.model.api.PredictionResponse;
import com.dataiku.lambda.model.api.SinglePredictionQuery;
import com.dataiku.lambda.model.api.SinglePredictionResponse;
import com.dataiku.lambda.model.serverconfig.QueryAPIKey;
import com.dataiku.lambda.server.LambdaAPIControllerBase;
import com.dataiku.lambda.services.ServiceManager;
import com.dataiku.lambda.services.ServicesService;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.log4j.MDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;

/**
 * Controller for the public prediction endpoints under {@code /public/api/v1/{serviceId}/{endpointId}/...}:
 * single-record prediction ({@code predict}, {@code predict-simple}), multi-record prediction
 * ({@code predict-multi}), the causal "effect" variants ({@code predict-effect}, {@code predict-effect-multi})
 * and forecasting ({@code forecast}).
 *
 * <p>Every handler pushes {@code svc:<serviceId>} and {@code ep:<endpointId>} onto the {@link ErrorContext} NDC
 * for the duration of the call. It then resolves the service, checks authentication, acquires a ref-counted
 * endpoint and releases it in a {@code finally} block once the JSON response has been written or an exception
 * has been thrown.
 */
@Controller
public class PredictionEndpointsController extends LambdaAPIControllerBase {
    private static final DKULogger logger = DKULogger.getLogger("dku.lambda.prediction.controller");

    /** MDC key carrying the service id of the request being handled. */
    private static final String MDC_SERVICE_ID = "serviceId";
    /** MDC key carrying the endpoint id of the request being handled. */
    private static final String MDC_ENDPOINT_ID = "endpointId";

    @Autowired
    private LambdaContext lambdaContext;
    @Autowired
    private ServicesService serviceService;
    @Autowired
    private AuthVerificationService authService;

    /**
     * {@code POST .../predict}: scores one record read from the request body as a {@link SinglePredictionQuery}.
     */
    @RequestMapping(value = {"/public/api/v1/{serviceId}/{endpointId}/predict"}, method = {RequestMethod.POST})
    public void execPrediction(HttpServletRequest req, HttpServletResponse resp,
                               @PathVariable String serviceId, @PathVariable String endpointId) throws Exception {
        try (ErrorContext.ACNDC svcCtx = ErrorContext.pushWithNDC("svc:" + serviceId);
             ErrorContext.ACNDC epCtx = ErrorContext.pushWithNDC("ep:" + endpointId)) {
            SinglePredictionQuery spq = (SinglePredictionQuery) getRequestBodyAs(req, SinglePredictionQuery.class);
            replySimple(req, resp, serviceId, endpointId, spq);
        }
    }

    /**
     * {@code POST .../predict-multi}: scores several records read from the request body as a
     * {@link MultiplePredictionQuery}.
     */
    @RequestMapping(value = {"/public/api/v1/{serviceId}/{endpointId}/predict-multi"}, method = {RequestMethod.POST})
    public void execPredictionMulti(HttpServletRequest req, HttpServletResponse resp,
                                    @PathVariable String serviceId, @PathVariable String endpointId) throws Exception {
        try (ErrorContext.ACNDC svcCtx = ErrorContext.pushWithNDC("svc:" + serviceId);
             ErrorContext.ACNDC epCtx = ErrorContext.pushWithNDC("ep:" + endpointId)) {
            MultiplePredictionQuery mpq = (MultiplePredictionQuery) getRequestBodyAs(req, MultiplePredictionQuery.class);
            replyMultiple(req, resp, serviceId, endpointId, mpq);
        }
    }

    /**
     * {@code POST .../forecast}: runs a forecast for the {@link ForecastQuery} read from the request body.
     */
    @RequestMapping(value = {"/public/api/v1/{serviceId}/{endpointId}/forecast"}, method = {RequestMethod.POST})
    public void execForecast(HttpServletRequest req, HttpServletResponse resp,
                             @PathVariable String serviceId, @PathVariable String endpointId) throws Exception {
        try (ErrorContext.ACNDC svcCtx = ErrorContext.pushWithNDC("svc:" + serviceId);
             ErrorContext.ACNDC epCtx = ErrorContext.pushWithNDC("ep:" + endpointId)) {
            ForecastQuery forecastQuery = (ForecastQuery) getRequestBodyAs(req, ForecastQuery.class);
            replyForecast(req, resp, serviceId, endpointId, forecastQuery);
        }
    }

    /**
     * {@code POST .../predict-effect}: causal prediction for a single record.
     *
     * <p>Uses exactly the same code path as {@link #execPrediction}. Presumably the endpoint handler decides
     * how a causal model is scored (TODO confirm).
     */
    @RequestMapping(value = {"/public/api/v1/{serviceId}/{endpointId}/predict-effect"}, method = {RequestMethod.POST})
    public void execCausalPrediction(HttpServletRequest req, HttpServletResponse resp,
                                     @PathVariable String serviceId, @PathVariable String endpointId) throws Exception {
        try (ErrorContext.ACNDC svcCtx = ErrorContext.pushWithNDC("svc:" + serviceId);
             ErrorContext.ACNDC epCtx = ErrorContext.pushWithNDC("ep:" + endpointId)) {
            SinglePredictionQuery spq = (SinglePredictionQuery) getRequestBodyAs(req, SinglePredictionQuery.class);
            replySimple(req, resp, serviceId, endpointId, spq);
        }
    }

    /**
     * {@code POST .../predict-effect-multi}: causal prediction for several records. It shares the code path of
     * {@link #execPredictionMulti}.
     */
    @RequestMapping(value = {"/public/api/v1/{serviceId}/{endpointId}/predict-effect-multi"}, method = {RequestMethod.POST})
    public void execCausalPredictionMulti(HttpServletRequest req, HttpServletResponse resp,
                                          @PathVariable String serviceId, @PathVariable String endpointId) throws Exception {
        try (ErrorContext.ACNDC svcCtx = ErrorContext.pushWithNDC("svc:" + serviceId);
             ErrorContext.ACNDC epCtx = ErrorContext.pushWithNDC("ep:" + endpointId)) {
            MultiplePredictionQuery mpq = (MultiplePredictionQuery) getRequestBodyAs(req, MultiplePredictionQuery.class);
            replyMultiple(req, resp, serviceId, endpointId, mpq);
        }
    }

    /**
     * {@code GET .../predict-simple}: single-record prediction whose query is built from the GET request by
     * {@link LambdaControllerUtils#httpGETReqToPredictionQuery}.
     */
    @RequestMapping(value = {"/public/api/v1/{serviceId}/{endpointId}/predict-simple"}, method = {RequestMethod.GET})
    public void execPredictionSimple(HttpServletRequest req, HttpServletResponse resp,
                                     @PathVariable String serviceId, @PathVariable String endpointId) throws Exception {
        try (ErrorContext.ACNDC svcCtx = ErrorContext.pushWithNDC("svc:" + serviceId);
             ErrorContext.ACNDC epCtx = ErrorContext.pushWithNDC("ep:" + endpointId)) {
            SinglePredictionQuery spq = LambdaControllerUtils.httpGETReqToPredictionQuery(req);
            replySimple(req, resp, serviceId, endpointId, spq);
        }
    }

    /**
     * Validates and scores a single-record query, then writes a {@link SinglePredictionResponse} as JSON.
     *
     * @throws Exception if {@code spq.features} is missing (via {@code require}), if the service cannot be
     *                   resolved, if authentication fails, or if the prediction itself fails
     */
    private void replySimple(HttpServletRequest req, HttpServletResponse resp, String serviceId,
                             String endpointId, SinglePredictionQuery spq) throws Exception {
        // Captured first so that the reported timing includes validation and endpoint acquisition.
        long startTimeN = System.nanoTime();
        require(spq.features != null, "'features' is required");
        // Previously the MDC entries were put here and never removed, so they leaked into later requests
        // served by the same (pooled) thread. The scope restores the prior MDC state on exit.
        try (MDCScope mdc = new MDCScope(serviceId, endpointId)) {
            logger.trace(() -> "Got predict request: " + JSON.json((Object) spq));
            ServiceManager manager = serviceService.getServiceManagerCheck(serviceId);
            authService.checkAuth(serviceId, manager.getConfig(), req);
            // NOTE(review): the other paths pass the raw serviceId here; this one passes manager.getServiceId().
            // Kept as-is. Confirm whether the two can ever differ.
            QueryAPIKey apiKey = authService.getApiKeyForInternalCalls(manager.getServiceId(), manager.getConfig(), req);
            ServiceManager.RefcountedEndpoint re = manager.acquireEndpoint(endpointId, spq.dispatch);
            try {
                PredictionEndpointHandlerBase peh = (PredictionEndpointHandlerBase) re.getHandler();
                boolean isRequestMetadataEnabled = lambdaContext.getMandatoryConfig().isRequestMetadataEnabled;
                writeJSON2(resp, new SinglePredictionResponse(
                        peh.predict(startTimeN, re, spq, apiKey, RequestMetadata.extractFrom(req, isRequestMetadataEnabled))));
            } finally {
                // Always hand the endpoint back, even when prediction or serialization fails.
                manager.releaseEndpoint(re);
            }
        }
    }

    /**
     * Validates and scores a multi-record query, then writes the resulting {@link PredictionResponse} as JSON.
     *
     * @throws Exception if {@code mpq.items} is missing (via {@code require}), if the service cannot be
     *                   resolved, if authentication fails, or if the prediction itself fails
     */
    private void replyMultiple(HttpServletRequest req, HttpServletResponse resp, String serviceId,
                               String endpointId, MultiplePredictionQuery mpq) throws Exception {
        long startTimeN = System.nanoTime();
        require(mpq.items != null, "'items' is required");
        try (MDCScope mdc = new MDCScope(serviceId, endpointId)) {
            ServiceManager manager = serviceService.getServiceManagerCheck(serviceId);
            authService.checkAuth(serviceId, manager.getConfig(), req);
            QueryAPIKey apiKey = authService.getApiKeyForInternalCalls(serviceId, manager.getConfig(), req);
            ServiceManager.RefcountedEndpoint re = manager.acquireEndpoint(endpointId, mpq.dispatch);
            try {
                PredictionEndpointHandlerBase peh = (PredictionEndpointHandlerBase) re.getHandler();
                boolean isRequestMetadataEnabled = lambdaContext.getMandatoryConfig().isRequestMetadataEnabled;
                PredictionResponse presp = peh.predict(startTimeN, re, mpq, apiKey,
                        RequestMetadata.extractFrom(req, isRequestMetadataEnabled));
                writeJSON2(resp, presp);
            } finally {
                manager.releaseEndpoint(re);
            }
        }
    }

    /**
     * Validates a forecast query, runs it on the endpoint handler and writes the {@link PredictionResponse} as JSON.
     *
     * @throws Exception if {@code forecastQuery.items} is missing (via {@code require}), if the service cannot
     *                   be resolved, if authentication fails, or if the forecast itself fails
     */
    private void replyForecast(HttpServletRequest req, HttpServletResponse resp, String serviceId,
                               String endpointId, ForecastQuery forecastQuery) throws Exception {
        // Unlike the predict paths, the original forecast path validates before starting the clock; order kept.
        require(forecastQuery.items != null, "'items' is required");
        long startTimeN = System.nanoTime();
        try (MDCScope mdc = new MDCScope(serviceId, endpointId)) {
            ServiceManager manager = serviceService.getServiceManagerCheck(serviceId);
            authService.checkAuth(serviceId, manager.getConfig(), req);
            QueryAPIKey apiKey = authService.getApiKeyForInternalCalls(serviceId, manager.getConfig(), req);
            ServiceManager.RefcountedEndpoint re = manager.acquireEndpoint(endpointId, forecastQuery.dispatch);
            try {
                PredictionEndpointHandlerBase peh = (PredictionEndpointHandlerBase) re.getHandler();
                PredictionResponse presp = peh.forecast(startTimeN, re, forecastQuery, apiKey);
                writeJSON2(resp, presp);
            } finally {
                manager.releaseEndpoint(re);
            }
        }
    }

    /**
     * Publishes the service and endpoint ids in the log4j {@link MDC} for the duration of a try-with-resources
     * block. On close it restores whatever values were present before, or removes the keys if none were set.
     * This keeps thread-local MDC state from leaking into later requests handled by the same thread.
     */
    private static final class MDCScope implements AutoCloseable {
        private final Object previousServiceId;
        private final Object previousEndpointId;

        // Values are typed as Object so that MDC.put(String, Object) is selected, as in the original code.
        MDCScope(Object serviceId, Object endpointId) {
            this.previousServiceId = MDC.get(MDC_SERVICE_ID);
            this.previousEndpointId = MDC.get(MDC_ENDPOINT_ID);
            MDC.put(MDC_SERVICE_ID, serviceId);
            MDC.put(MDC_ENDPOINT_ID, endpointId);
        }

        @Override
        public void close() {
            restore(MDC_SERVICE_ID, previousServiceId);
            restore(MDC_ENDPOINT_ID, previousEndpointId);
        }

        /** Puts back {@code previous} under {@code key}, or removes the key when there was no prior value. */
        private static void restore(String key, Object previous) {
            if (previous == null) {
                MDC.remove(key);
            } else {
                MDC.put(key, previous);
            }
        }
    }
}

