/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.util;

import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.exceptions.CodedIOException;
import com.dataiku.dip.scheduler.runnables.StoppableWithTimeoutService;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.DSSEventListener;
import com.dataiku.dip.server.notifications.frontend.FrontendEvent;
import com.dataiku.dip.server.notifications.frontend.TimeoutableTaskKeepAlive;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.util.GpuMonitoringNotificationRouter;
import com.dataiku.dip.util.SystemMonitoringCodes;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.ExceptionUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.PostConstruct;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Periodically runs {@code nvidia-smi} on a background thread and pushes the parsed GPU statistics
 * to the users who recently asked for them.
 *
 * <p>A {@code gpu-monitoring-start} event records the sending user as interested and starts the
 * polling thread (or refreshes its timeout if it is already registered). A
 * {@code timeoutable-task-keepalive} event for the {@code "gpu-monitoring"} task only refreshes the
 * sending user's interest. A user stops receiving stats once nothing has been heard from them for
 * {@link #GPU_MONITORING_TIMEOUT_MINS} minutes.
 *
 * <p>Mutable state ({@code lastRefreshDateByUser}, {@code runningThread}) is only accessed from
 * methods synchronized on this service instance.
 */
@Service
public class GPUMonitoringService {
    /** Delay between two consecutive nvidia-smi polls, in milliseconds. */
    private static final int GPU_POLL_TIMEOUT_MILLS = 800;
    /**
     * Inactivity timeout in minutes. It is used both to drop users who stopped refreshing and as the
     * timeout handed to {@link StoppableWithTimeoutService#stopWithTimeout} (assumed to be minutes
     * there too — confirm against StoppableWithTimeoutService).
     */
    private static final float GPU_MONITORING_TIMEOUT_MINS = 1.0f;
    private static final long MILLIS_PER_MINUTE = 60_000L;
    /** Columns requested from {@code nvidia-smi --query-gpu}, in output order. */
    private static final String[] NVSMI_FIELDS = new String[]{"uuid", "name", "index", "utilization.gpu", "memory.total", "memory.free", "memory.used", "utilization.memory"};
    /** Task id used with StoppableWithTimeoutService and in keep-alive events. */
    private static final String GPU_MONITORING = "gpu-monitoring";
    private static final DKULogger logger = DKULogger.getLogger("com.dataiku.dip.util.GPUMonitoringService");
    @Autowired
    private PubSubService pubSub;
    @Autowired
    private StoppableWithTimeoutService stoppableWithTimeoutService;
    @Autowired
    private GpuMonitoringNotificationRouter gpuMonitoringNotificationRouter;
    /** Last start/keep-alive time (epoch millis, from System.currentTimeMillis) per user login. */
    private final Map<String, Long> lastRefreshDateByUser = new HashMap<>();
    /**
     * Most recently started polling thread (written under this service's lock). The stop callback
     * deliberately does not read this field; see {@link #registerOrRefresh()}.
     */
    private GPUMonitoringThread runningThread;

    /**
     * Runs nvidia-smi once and parses its output.
     *
     * @return the parsed statistics, one entry per GPU row reported by nvidia-smi
     * @throws CodedIOException with code NVIDIA_SMI_NOT_FOUND when the executable cannot be found
     * @throws IOException when nvidia-smi could not be run for another reason
     * @throws InterruptedException propagated from DKUtils.execAndGetOutput
     */
    public GPUStatsResponse fetchGpuStats() throws IOException, InterruptedException {
        return this.parseGpuStats(this.runNvidiaSmi());
    }

    /** Subscribes to the frontend events that start and keep alive GPU monitoring. */
    @PostConstruct
    public void postConstruct() {
        this.pubSub.subscribe(GpuStatsStartEvent.NAME, (DSSEventListener)new DSSEventListener<GpuStatsStartEvent>(){

            public void on(GpuStatsStartEvent evt) {
                GPUMonitoringService.this.recordLastRefreshDate(evt.getUserLogin());
                GPUMonitoringService.this.registerOrRefresh();
            }
        });
        this.pubSub.subscribe("timeoutable-task-keepalive", (DSSEventListener)new DSSEventListener<TimeoutableTaskKeepAlive>(){

            public void on(TimeoutableTaskKeepAlive evt) {
                // Keep-alives are shared across tasks; only ours refresh GPU-monitoring interest.
                if (GPUMonitoringService.GPU_MONITORING.equals(evt.taskId)) {
                    GPUMonitoringService.this.recordLastRefreshDate(evt.getUserLogin());
                }
            }
        });
    }

    /** Marks {@code userLogin} as interested in GPU stats as of now. */
    private synchronized void recordLastRefreshDate(String userLogin) {
        this.lastRefreshDateByUser.put(userLogin, System.currentTimeMillis());
    }

    /** Returns a snapshot of the logins that should currently receive GPU stats. */
    private synchronized Set<String> getInterestedUsers() {
        return new HashSet<>(this.lastRefreshDateByUser.keySet());
    }

    /**
     * Forgets users whose last refresh is older than {@link #GPU_MONITORING_TIMEOUT_MINS}.
     *
     * <p>The cutoff is computed in {@code long} milliseconds. Converting epoch millis to
     * {@code float} (as this used to do) leaves a spacing of about 131 s between representable
     * values. That made the one-minute timeout essentially arbitrary.
     */
    private synchronized void unsubscribeNotInterestedUsers() {
        long timeoutMillis = (long) (GPU_MONITORING_TIMEOUT_MINS * MILLIS_PER_MINUTE);
        long cutoff = System.currentTimeMillis() - timeoutMillis;
        Iterator<Long> lastRefreshes = this.lastRefreshDateByUser.values().iterator();
        while (lastRefreshes.hasNext()) {
            if (lastRefreshes.next() < cutoff) {
                lastRefreshes.remove();
            }
        }
    }

    /**
     * Starts a polling thread if no GPU-monitoring task is registered; otherwise refreshes the
     * registered task's timeout.
     *
     * <p>The new thread's first action is {@link #unsubscribeNotInterestedUsers()}, which needs this
     * instance's lock. So it cannot fetch, kill or expire anything until this method has finished
     * registering the task below.
     */
    private synchronized void registerOrRefresh() {
        if (this.stoppableWithTimeoutService.isRunning(GPU_MONITORING)) {
            this.stoppableWithTimeoutService.refreshTask(GPU_MONITORING);
            return;
        }
        final GPUMonitoringThread thread = new GPUMonitoringThread();
        this.runningThread = thread;
        thread.start();
        this.stoppableWithTimeoutService.stopWithTimeout(GPU_MONITORING, new StoppableWithTimeoutService.Stoppable(){

            @Override
            public void stop() {
                // Interrupt the thread this registration belongs to. Reading runningThread here
                // could hit a newer thread started after this task was killed, leaving the old one
                // polling forever.
                thread.interrupt();
            }
        }, GPU_MONITORING_TIMEOUT_MINS);
    }

    /**
     * Parses the CSV printed by {@code nvidia-smi --format=csv,nounits --query-gpu=...}.
     *
     * <p>The first line is the header. Each column name is trimmed and cut at its first space
     * (dropping e.g. a unit suffix). Every following non-blank line becomes one {@link GPUStats}.
     * Cells wrapped in square brackets (e.g. {@code [Not Supported]}; nvidia-smi presumably uses
     * the same form for other unavailable values such as {@code [N/A]}) are treated as missing and
     * leave the corresponding field {@code null}. So are cells absent from a short row.
     *
     * @param nvidiaSmi raw nvidia-smi output
     * @return a response with status OK and one stats entry per data row
     * @throws NumberFormatException if a numeric column holds a non-numeric, non-bracketed value
     */
    @VisibleForTesting
    GPUStatsResponse parseGpuStats(String nvidiaSmi) {
        GPUStatsResponse gpuStatsResponse = new GPUStatsResponse();
        String[] lines = nvidiaSmi.split("\n");
        if (lines.length == 0) {
            // e.g. "\n".split("\n") yields no elements at all
            return gpuStatsResponse;
        }
        List<String> header = parseHeader(lines[0]);
        for (int i = 1; i < lines.length; ++i) {
            if (lines[i].trim().isEmpty()) {
                continue;
            }
            Map<String, String> gpuStatsMap = parseRow(header, lines[i]);
            GPUStats gpuStats = new GPUStats.GPUStatsBuilder()
                    .setUuid(gpuStatsMap.get("uuid"))
                    .setName(gpuStatsMap.get("name"))
                    .setIndex(toInteger(gpuStatsMap.get("index")))
                    .setUtilizationGpu(toInteger(gpuStatsMap.get("utilization.gpu")))
                    .setMemoryTotal(toInteger(gpuStatsMap.get("memory.total")))
                    .setMemoryFree(toInteger(gpuStatsMap.get("memory.free")))
                    .setMemoryUsed(toInteger(gpuStatsMap.get("memory.used")))
                    .setUtilizationMemory(toInteger(gpuStatsMap.get("utilization.memory")))
                    .build();
            gpuStatsResponse.stats.add(gpuStats);
        }
        return gpuStatsResponse;
    }

    /** Splits the CSV header line into column keys, keeping only the text before the first space. */
    private static List<String> parseHeader(String headerLine) {
        List<String> header = new ArrayList<>();
        for (String column : headerLine.split(",")) {
            String key = column.trim();
            int space = key.indexOf(' ');
            header.add(space >= 0 ? key.substring(0, space) : key);
        }
        return header;
    }

    /** Maps header keys to the trimmed cells of one CSV row, skipping missing/unavailable cells. */
    private static Map<String, String> parseRow(List<String> header, String row) {
        String[] cells = row.split(",");
        Map<String, String> values = new HashMap<>();
        int columns = Math.min(header.size(), cells.length);
        for (int j = 0; j < columns; ++j) {
            String value = cells[j].trim();
            if (!isUnavailable(value)) {
                values.put(header.get(j), value);
            }
        }
        return values;
    }

    /** True for bracketed placeholder cells such as {@code [Not Supported]}. */
    private static boolean isUnavailable(String value) {
        return value.startsWith("[") && value.endsWith("]");
    }

    /**
     * Runs nvidia-smi and returns its standard output decoded as UTF-8.
     *
     * @throws CodedIOException (NVIDIA_SMI_NOT_FOUND) if the failure has a cause whose message is
     *     "No such file or directory"; the original exception is attached as suppressed
     * @throws IOException for any other failure reported by DKUtils.execAndGetOutput. These used to
     *     be swallowed, which reported an OK response listing no GPUs.
     */
    private String runNvidiaSmi() throws InterruptedException, CodedIOException, IOException {
        ProcessBuilder pb = new ProcessBuilder(
                "nvidia-smi", "--format=csv,nounits", "--query-gpu=" + Joiner.on(",").join(NVSMI_FIELDS));
        byte[] bytes;
        try {
            bytes = DKUtils.execAndGetOutput(pb);
        }
        catch (IOException e) {
            if (ExceptionUtils.hasCauseWithMessage((Throwable) e, "No such file or directory")) {
                CodedIOException notFound = new CodedIOException(
                        (InfoMessage.MessageCode) SystemMonitoringCodes.NVIDIA_SMI_NOT_FOUND,
                        "Failed to collect gpu stats");
                // Keep the underlying failure visible in logs and stack traces.
                notFound.addSuppressed(e);
                throw notFound;
            }
            throw e;
        }
        return new String(bytes, StandardCharsets.UTF_8);
    }

    /** Parses a decimal integer, mapping {@code null} (missing column) to {@code null}. */
    private static Integer toInteger(String val) {
        return val == null ? null : Integer.valueOf(val);
    }

    /** Payload pushed to users: either a list of GPU stats (OK) or an error message (ERROR). */
    public static class GPUStatsResponse {
        public GPUStatsResponseStatus status = GPUStatsResponseStatus.OK;
        /** Human-readable error, set only when {@link #status} is ERROR. */
        public String error;
        public List<GPUStats> stats = new ArrayList<GPUStats>();

        public static enum GPUStatsResponseStatus {
            OK,
            ERROR;

        }
    }

    /** Frontend event by which a user asks to start (or keep receiving) GPU monitoring. */
    public static class GpuStatsStartEvent
    extends FrontendEvent {
        public static final String NAME = "gpu-monitoring-start";

        public GpuStatsStartEvent(String userLogin, String webSocketSessionId) {
            super(userLogin, webSocketSessionId);
        }

        public String getName() {
            return NAME;
        }
    }

    /**
     * Background poller. Every {@link #GPU_POLL_TIMEOUT_MILLS} ms it drops stale users, fetches
     * stats and sends them to the remaining users. It then asks the StoppableWithTimeoutService
     * whether the task expired.
     *
     * <p>It exits when interrupted or after the first failure. On a fetch/send failure it sends an
     * ERROR response and kills the task registration, so a later start event can spawn a new
     * thread.
     */
    private class GPUMonitoringThread
    extends Thread {
        private GPUMonitoringThread() {
            super(GPU_MONITORING);
        }

        @Override
        public void run() {
            Exception fetchException = null;
            while (!Thread.interrupted() && fetchException == null) {
                List<StoppableWithTimeoutService.Stoppable> stoppables;
                try {
                    GPUMonitoringService.this.unsubscribeNotInterestedUsers();
                    GPUStatsResponse gpuStatsResponse = GPUMonitoringService.this.fetchGpuStats();
                    GPUMonitoringService.this.gpuMonitoringNotificationRouter.sendStats(
                            new GpuStatsResponseEvent(gpuStatsResponse), GPUMonitoringService.this.getInterestedUsers());
                    // Stoppables to run if the task timed out (presumably empty otherwise).
                    stoppables = GPUMonitoringService.this.stoppableWithTimeoutService.killIfExpired(GPU_MONITORING);
                }
                catch (Exception e) {
                    String message = "Failed to fetch GPU stats, maybe nvidia-smi utility is not found";
                    logger.error((Object) message, (Throwable) e);
                    GPUStatsResponse response = new GPUStatsResponse();
                    response.status = GPUStatsResponse.GPUStatsResponseStatus.ERROR;
                    response.error = message;
                    GPUMonitoringService.this.gpuMonitoringNotificationRouter.sendStats(
                            new GpuStatsResponseEvent(response), GPUMonitoringService.this.getInterestedUsers());
                    // Unregister so the next start event creates a fresh thread.
                    stoppables = GPUMonitoringService.this.stoppableWithTimeoutService.kill(GPU_MONITORING);
                    fetchException = e;
                }
                try {
                    for (StoppableWithTimeoutService.Stoppable stoppable : stoppables) {
                        stoppable.stop();
                    }
                }
                catch (Exception e) {
                    logger.error((Object) "Failed to stop GPU monitoring", (Throwable) e);
                    // Keep the first failure; any failure ends the loop.
                    if (fetchException == null) {
                        fetchException = e;
                    }
                }
                if (fetchException != null) {
                    // Exiting anyway; do not wait one more poll interval first.
                    break;
                }
                try {
                    Thread.sleep(GPU_POLL_TIMEOUT_MILLS);
                }
                catch (InterruptedException e) {
                    // Restore the flag so the loop condition observes it and the thread exits.
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /** Statistics for a single GPU; numeric fields are {@code null} when nvidia-smi did not report them. */
    public static class GPUStats {
        public String uuid;
        public String name;
        public Integer index;
        public Integer utilizationGpu;
        public Integer memoryTotal;
        public Integer memoryFree;
        public Integer memoryUsed;
        public Integer utilizationMemory;

        private GPUStats(GPUStatsBuilder builder) {
            this.uuid = builder.uuid;
            this.name = builder.name;
            this.index = builder.index;
            this.utilizationGpu = builder.utilizationGpu;
            this.memoryTotal = builder.memoryTotal;
            this.memoryFree = builder.memoryFree;
            this.memoryUsed = builder.memoryUsed;
            this.utilizationMemory = builder.utilizationMemory;
        }

        /** Fluent builder for {@link GPUStats}; unset fields stay {@code null}. */
        public static class GPUStatsBuilder {
            private String uuid;
            private String name;
            private Integer index;
            private Integer utilizationGpu;
            private Integer memoryTotal;
            private Integer memoryFree;
            private Integer memoryUsed;
            private Integer utilizationMemory;

            public GPUStatsBuilder setUuid(String uuid) {
                this.uuid = uuid;
                return this;
            }

            public GPUStatsBuilder setName(String name) {
                this.name = name;
                return this;
            }

            public GPUStatsBuilder setIndex(Integer index) {
                this.index = index;
                return this;
            }

            public GPUStatsBuilder setUtilizationGpu(Integer utilizationGpu) {
                this.utilizationGpu = utilizationGpu;
                return this;
            }

            public GPUStatsBuilder setMemoryTotal(Integer memoryTotal) {
                this.memoryTotal = memoryTotal;
                return this;
            }

            public GPUStatsBuilder setMemoryFree(Integer memoryFree) {
                this.memoryFree = memoryFree;
                return this;
            }

            public GPUStatsBuilder setMemoryUsed(Integer memoryUsed) {
                this.memoryUsed = memoryUsed;
                return this;
            }

            public GPUStatsBuilder setUtilizationMemory(Integer utilizationMemory) {
                this.utilizationMemory = utilizationMemory;
                return this;
            }

            public GPUStats build() {
                return new GPUStats(this);
            }
        }
    }

    /** Event wrapping one {@link GPUStatsResponse}, delivered via GpuMonitoringNotificationRouter. */
    public static class GpuStatsResponseEvent
    implements DSSEvent {
        public static final String NAME = "gpu-stats-response";
        public GPUStatsResponse response;

        public GpuStatsResponseEvent(GPUStatsResponse response) {
            this.response = response;
        }

        public String getName() {
            return NAME;
        }
    }
}

