/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.hf;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.analysis.ml.hf.ModelCacheService;
import com.dataiku.dip.security.IPBlacklistVerifier;
import com.dataiku.dip.transactions.fs.ConcreteRelFileAttribute;
import com.dataiku.dip.transactions.fs.FileContent;
import com.dataiku.dip.transactions.fs.FileContentFactory;
import com.dataiku.dip.transactions.fs.ReadOnlyFSBase;
import com.dataiku.dip.transactions.fs.RelFile;
import com.dataiku.dip.transactions.fs.ifaces.ReadOnlyFS;
import com.dataiku.dip.transactions.fs.ifaces.RelFileAttribute;
import com.dataiku.dip.transactions.fs.ifaces.StreamSupplier;
import com.dataiku.dip.transactions.fs.utils.RelFileFilter;
import com.dataiku.dip.util.ProxyUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.PerfUtils;
import com.dataiku.dss.shadelib.org.apache.commons.io.IOUtils;
import com.dataiku.dss.shadelib.org.apache.commons.io.input.ProxyInputStream;
import com.dataiku.dss.shadelib.org.apache.http.Header;
import com.dataiku.dss.shadelib.org.apache.http.HttpEntity;
import com.dataiku.dss.shadelib.org.apache.http.client.RedirectStrategy;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.CloseableHttpResponse;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpGet;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpUriRequest;
import com.dataiku.dss.shadelib.org.apache.http.client.utils.URIBuilder;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.CloseableHttpClient;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.HttpClientBuilder;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.LaxRedirectStrategy;
import com.google.common.escape.Escaper;
import com.google.common.io.ByteStreams;
import com.google.common.net.UrlEscapers;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.commons.lang3.function.FailableConsumer;
import org.apache.http.client.HttpResponseException;

public class HuggingFaceClient
implements AutoCloseable {
    /** Base URL of the hub; assigned in the static initializer from the {@code HF_ENDPOINT} environment entry, defaulting to {@code https://huggingface.co}. */
    private static final String HF_ENDPOINT;
    /** {@link #HF_ENDPOINT} followed by {@code /api}. */
    private static final String HF_API_ENDPOINT;
    /** {@link #HF_API_ENDPOINT} followed by {@code /models}. */
    private static final String HF_MODELS_API_ENDPOINT;
    /** {@link #HF_API_ENDPOINT} followed by {@code /whoami-v2}; requested by {@code testAuthentification()}. */
    private static final String HF_WHOAMI_API_ENDPOINT;
    /** HTTP client built in the constructor; released by {@link #close()}. */
    private final CloseableHttpClient httpClient;
    /** Optional token; when non-null it is sent as an {@code Authorization: Bearer} header on every request. */
    @Nullable
    private final String apiKey;
    /** Proxy settings applied to the HTTP client and passed to the IP blacklist check before each request. */
    private final ProxySettings proxySettings = ApplicationConfigurator.getProxySettings();
    /** Logger named {@code dku.huggingface}; assigned in the static initializer. */
    private static final DKULogger logger;

    /**
     * Creates a client that optionally authenticates with the given token.
     *
     * <p>The per-route connection limit is read from {@code dku.ml.hf.httpClientConcurrency}
     * (default 8) and the connection time-to-live from
     * {@code dku.ml.hf.connectionsTimeToLiveInSeconds} (default 30). The instance proxy
     * settings are applied to the builder before the HTTP client is built.
     *
     * @param apiKey token sent as a bearer credential on every request, or {@code null} for anonymous access
     */
    public HuggingFaceClient(@Nullable String apiKey) {
        this.apiKey = apiKey;

        int maxConnectionsPerRoute = DKUApp.getParams().getIntParam("dku.ml.hf.httpClientConcurrency", Integer.valueOf(8));
        int timeToLiveSeconds = DKUApp.getParams().getIntParam("dku.ml.hf.connectionsTimeToLiveInSeconds", Integer.valueOf(30));

        HttpClientBuilder builder = HttpClientBuilder.create();
        builder.setMaxConnPerRoute(maxConnectionsPerRoute);
        builder.setConnectionTimeToLive((long)timeToLiveSeconds, TimeUnit.SECONDS);
        builder.setRedirectStrategy((RedirectStrategy)LaxRedirectStrategy.INSTANCE);
        builder.addInterceptorFirst(PerfUtils.MARK_HTTP_REQUEST_INTERCEPTOR);

        HttpClientBuilder proxied = ProxyUtils.applyProxySettings((ProxySettings)this.proxySettings, (HttpClientBuilder)builder);
        this.httpClient = proxied.build();
    }

    /**
     * Creates an anonymous client, i.e. one that sends no {@code Authorization} header.
     */
    public HuggingFaceClient() {
        this((String)null);
    }

    /**
     * Returns a read-only file-system view over one revision of a model repository.
     * No request is issued here; the returned view fetches listings and content through this client.
     *
     * @param modelId  repository identifier
     * @param revision revision to read
     * @return a new {@link HuggingFaceModelRepo} bound to this client
     */
    public ReadOnlyFS getModelRepo(String modelId, String revision) {
        HuggingFaceModelRepo repo = this.new HuggingFaceModelRepo(modelId, revision);
        return repo;
    }

    /**
     * Lists models from the models API, following pagination until no {@code next} link remains.
     *
     * <p>The request targets {@code HF_MODELS_API_ENDPOINT}, like every other API call of this
     * client, so that a configured {@code HF_ENDPOINT} is honoured here as well.
     *
     * @param filter value of the {@code filter} query parameter, or {@code null} to send none
     * @return a mutable list with the items of all pages, in page order
     * @throws IOException if a request fails or returns a non-200 status
     */
    public List<HFResponseModelItem> listModels(@Nullable String filter) throws IOException {
        URIBuilder url = new URIBuilder(URI.create(HF_MODELS_API_ENDPOINT));
        if (filter != null) {
            url.setParameter("filter", filter);
        }
        List<HFResponseModelItem> out = new ArrayList<>();
        this.paginatedCall(url.toString(), resp -> {
            try (InputStream is = resp.getEntity().getContent()) {
                @SuppressWarnings("unchecked")
                List<HFResponseModelItem> page = (List<HFResponseModelItem>)JSON.parse((InputStream)is, (TypeToken)new TypeToken<List<HFResponseModelItem>>(){});
                // A literal JSON "null" body deserializes to null; treat it as an empty page.
                if (page != null) {
                    out.addAll(page);
                }
            }
        });
        return out;
    }

    /**
     * Probes the models API for the given model, including its revision when one is set.
     *
     * <p>NOTE(review): {@code rawCall} throws {@code HuggingFaceResponseException} for every
     * non-200 status, so as written this method either returns {@code true} or throws; it does
     * not return {@code false} for a missing model. Confirm that callers expect this.
     *
     * @param modelDefinition model name and optional revision to look up
     * @return {@code true} when the API answers with HTTP 200
     * @throws IOException if the request fails or returns a non-200 status
     */
    public boolean modelExists(ModelCacheService.ModelStorageDefinition modelDefinition) throws IOException {
        boolean pinnedToRevision = modelDefinition.hasRevision();
        String url = HF_MODELS_API_ENDPOINT + "/" + HuggingFaceClient.escapeUrlPathSegments(modelDefinition.getModelName());
        if (pinnedToRevision) {
            url = url + "/" + HuggingFaceClient.escapeUrlPathSegments(modelDefinition.getRevision());
        }
        try (CloseableHttpResponse resp = this.rawCall(url)) {
            int status = resp.getStatusLine().getStatusCode();
            return status == 200;
        }
    }

    /**
     * Builds the {@code huggingface.co} tree URL of a model at a given commit.
     * Both arguments are escaped segment by segment; no request is issued.
     *
     * @param modelId    repository identifier
     * @param commitHash commit to point at
     * @return {@code https://huggingface.co/<modelId>/tree/<commitHash>}
     */
    public String getModelUrl(String modelId, String commitHash) {
        String escapedModel = HuggingFaceClient.escapeUrlPathSegments(modelId);
        String escapedCommit = HuggingFaceClient.escapeUrlPathSegments(commitHash);
        StringBuilder link = new StringBuilder("https://huggingface.co/");
        link.append(escapedModel).append("/tree/").append(escapedCommit);
        return link.toString();
    }

    /**
     * Fetches the commit list of a model starting at {@code commitHash} and returns its first entry.
     *
     * @param modelId    repository identifier
     * @param commitHash commit (or revision) passed to the {@code /commits/} endpoint
     * @return the first element of the returned commit list
     * @throws IOException if the request fails, returns a non-200 status, or the list is empty
     */
    public HFCommitInfo getCommitInfos(String modelId, String commitHash) throws IOException {
        String url = HF_MODELS_API_ENDPOINT + "/" + HuggingFaceClient.escapeUrlPathSegments(modelId) + "/commits/" + HuggingFaceClient.escapeUrlPathSegments(commitHash);
        // Resources are closed in reverse order: the body stream first, then the response.
        try (CloseableHttpResponse resp = this.rawCall(url);
             InputStream body = resp.getEntity().getContent()) {
            @SuppressWarnings("unchecked")
            List<HFCommitInfo> commits = (List<HFCommitInfo>)JSON.parse((InputStream)body, (TypeToken)new TypeToken<List<HFCommitInfo>>(){});
            if (commits.isEmpty()) {
                throw new IOException("No commit found: " + modelId + " (" + commitHash + ")");
            }
            return commits.get(0);
        }
    }

    /**
     * Fetches the metadata of a model from the models API, including its revision when one is set.
     *
     * @param modelDefinition model name and optional revision to look up
     * @return the response body deserialized as {@link HFResponseModelDetails}
     * @throws IOException if the request fails or returns a non-200 status
     */
    public HFResponseModelDetails getModelDetails(ModelCacheService.ModelStorageDefinition modelDefinition) throws IOException {
        boolean pinnedToRevision = modelDefinition.hasRevision();
        String url = HF_MODELS_API_ENDPOINT + "/" + HuggingFaceClient.escapeUrlPathSegments(modelDefinition.getModelName());
        if (pinnedToRevision) {
            url = url + "/" + HuggingFaceClient.escapeUrlPathSegments(modelDefinition.getRevision());
        }
        // Resources are closed in reverse order: the body stream first, then the response.
        try (CloseableHttpResponse resp = this.rawCall(url);
             InputStream body = resp.getEntity().getContent()) {
            return (HFResponseModelDetails)JSON.parse((InputStream)body, HFResponseModelDetails.class);
        }
    }

    /**
     * Reads {@code adapter_config.json} from the given model revision and returns its
     * {@code base_model_name_or_path} value.
     *
     * <p>This lookup is best-effort: any exception (failed request, non-200 status,
     * unparsable body, ...) results in {@code null} rather than being propagated.
     *
     * @param modelId  repository identifier
     * @param revision revision to resolve the file from
     * @return the base model key, or {@code null} if it could not be determined
     */
    public String getBaseModelKey(String modelId, String revision) {
        String url = HF_ENDPOINT + "/" + HuggingFaceClient.escapeUrlPathSegments(modelId) + "/resolve/" + HuggingFaceClient.escapeUrlPathSegments(revision) + "/" + HuggingFaceClient.escapeUrlPathSegments("adapter_config.json");
        // Resources are closed in reverse order: the body stream first, then the response.
        try (CloseableHttpResponse resp = this.rawCall(url);
             InputStream body = resp.getEntity().getContent()) {
            HFAdapterConfig adapterConfig = (HFAdapterConfig)JSON.parse((InputStream)body, HFAdapterConfig.class);
            return adapterConfig.base_model_name_or_path;
        }
        catch (Exception ignored) {
            // Deliberately swallowed: callers receive null when no base model can be resolved.
            return null;
        }
    }

    /**
     * Checks the configured credentials by calling the {@code whoami-v2} endpoint.
     *
     * <p>The response is closed immediately: only the status matters here, and leaving it
     * open would keep a pooled connection leased until the whole client is closed.
     *
     * @throws IOException if the request fails or returns a non-200 status
     */
    public void testAuthentification() throws IOException {
        try (CloseableHttpResponse ignored = this.rawCall(HF_WHOAMI_API_ENDPOINT)) {
            // Nothing to read: rawCall already threw if the status was not 200.
        }
    }

    /**
     * Closes the underlying HTTP client. A failure to close is logged and not propagated.
     */
    @Override
    public void close() {
        try {
            this.httpClient.close();
        }
        catch (IOException closeFailure) {
            logger.error((Object)"Error while closing Hugging Face client", (Throwable)closeFailure);
        }
    }

    /**
     * Issues a GET request without asking the server to close the connection afterwards.
     *
     * @param url absolute URL to fetch
     * @return the open response, which the caller must close
     * @throws IOException if the request fails or returns a non-200 status
     */
    private CloseableHttpResponse rawCall(String url) throws IOException {
        final boolean closeConnection = false;
        return this.rawCall(url, closeConnection);
    }

    /**
     * Issues a GET request and returns the open response when the status is 200.
     *
     * <p>The URL is first checked against the IP blacklist. The bearer token is attached
     * when an API key is configured. For any status other than 200 the response is closed
     * and a {@link HuggingFaceResponseException} carrying the status and the error text is thrown.
     * The response is also closed if anything fails while it is being inspected, so that no
     * pooled connection stays leased.
     *
     * @param url             absolute URL to fetch
     * @param closeConnection when {@code true}, sends a {@code Connection: close} header
     * @return the open response, which the caller must close
     * @throws IOException if the request fails or returns a non-200 status
     */
    private CloseableHttpResponse rawCall(String url, boolean closeConnection) throws IOException {
        IPBlacklistVerifier.validateUriNotBlacklisted(url, this.proxySettings);
        logger.info((Object)("Fetching: " + url));
        HttpGet req = new HttpGet(url);
        if (closeConnection) {
            req.setHeader("Connection", "close");
        }
        if (this.apiKey != null) {
            req.setHeader("Authorization", "Bearer " + this.apiKey);
        }
        CloseableHttpResponse resp = this.httpClient.execute((HttpUriRequest)req);
        try {
            // A response may legitimately carry no entity; report its length as unknown (-1).
            HttpEntity respEntity = resp.getEntity();
            long length = respEntity == null ? -1L : respEntity.getContentLength();
            int code = resp.getStatusLine().getStatusCode();
            logger.info((Object)("HTTP " + code + " (" + length + " bytes)"));
            if (code != 200) {
                throw new HuggingFaceResponseException(code, "HTTP " + code + ": " + HuggingFaceClient.readErrorMessage(respEntity));
            }
            return resp;
        }
        catch (Throwable failure) {
            try {
                resp.close();
            }
            catch (IOException closeFailure) {
                // Keep the original failure as the primary exception.
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
    }

    /**
     * Extracts a human-readable message from the body of an error response.
     * At most 1 MiB of the body is read. For a JSON body with an {@code error} field that
     * field is returned; otherwise the raw body is returned.
     */
    private static String readErrorMessage(@Nullable HttpEntity entity) throws IOException {
        if (entity == null) {
            return "";
        }
        InputStream content = entity.getContent();
        if (content == null) {
            return "";
        }
        InputStream limited = ByteStreams.limit((InputStream)content, (long)0x100000L);
        String data = IOUtils.toString((InputStream)limited, (Charset)StandardCharsets.UTF_8);
        Header contentType = entity.getContentType();
        boolean isJson = contentType != null && contentType.getValue() != null && contentType.getValue().startsWith("application/json");
        if (isJson) {
            try {
                HFResponseError parsedError = (HFResponseError)JSON.parse((String)data, HFResponseError.class);
                if (parsedError != null && parsedError.error != null) {
                    return parsedError.error;
                }
            }
            catch (RuntimeException ignored) {
                // Body was announced as JSON but is not parsable: fall back to the raw text so
                // the caller still gets an HTTP error instead of a parser exception.
            }
        }
        return data;
    }

    /**
     * Fetches {@code url} and every following page advertised through a {@code Link} header
     * with {@code rel="next"}, handing each response to {@code responseConsumer}.
     *
     * <p>Each response is closed before the next page is requested, including when the
     * consumer throws. The loop ends as soon as a response has no {@code Link} header or
     * the header has no {@code next} relation.
     *
     * @param url              URL of the first page
     * @param responseConsumer callback invoked once per page with the open response
     * @throws IOException if a request fails, returns a non-200 status, or the consumer throws
     */
    private void paginatedCall(String url, FailableConsumer<CloseableHttpResponse, IOException> responseConsumer) throws IOException {
        String nextUrl = url;
        do {
            try (CloseableHttpResponse resp = this.rawCall(nextUrl)) {
                responseConsumer.accept(resp);
                Header linkHeader = resp.getFirstHeader("Link");
                nextUrl = linkHeader == null ? null : HuggingFaceClient.parseLinkHeader(linkHeader.getValue()).get("next");
            }
        } while (nextUrl != null);
    }

    /**
     * One {@code <target>; ...; rel="relation"} entry of a {@code Link} header.
     * The target cannot contain {@code >}, the parameters skipped before {@code rel} cannot
     * run into the next entry's {@code <}, and the relation stops at its closing quote.
     */
    private static final Pattern LINK_HEADER_ENTRY = Pattern.compile("<([^>]+)>[^<]*?;\\s*rel=\"([^\"]+)\"");

    /**
     * Parses a {@code Link} header into a map from relation name to target URL.
     *
     * <p>The header is scanned entry by entry rather than split on commas, so targets that
     * contain a comma and entries that carry extra parameters (for example
     * {@code rel="next"; title="..."}) are handled. When a relation appears more than once,
     * the last occurrence wins.
     *
     * @param linkHeader raw header value; {@code null} is treated as an empty header
     * @return a mutable map of relation to URL, empty when nothing matches
     */
    private static Map<String, String> parseLinkHeader(String linkHeader) {
        Map<String, String> links = new HashMap<>();
        if (linkHeader == null) {
            return links;
        }
        Matcher entry = LINK_HEADER_ENTRY.matcher(linkHeader);
        while (entry.find()) {
            links.put(entry.group(2), entry.group(1));
        }
        return links;
    }

    /**
     * Escapes each {@code /}-separated segment of a path for use in a URL, keeping the
     * separators themselves. Trailing empty segments are dropped by {@link String#split(String)}.
     *
     * @param url path whose segments should be escaped
     * @return the escaped segments joined with {@code /}
     */
    private static String escapeUrlPathSegments(String url) {
        Escaper segmentEscaper = UrlEscapers.urlPathSegmentEscaper();
        String[] segments = url.split("/");
        StringBuilder escaped = new StringBuilder();
        for (int i = 0; i < segments.length; i++) {
            if (i > 0) {
                escaped.append("/");
            }
            escaped.append(segmentEscaper.escape(segments[i]));
        }
        return escaped.toString();
    }

    static {
        Map dssProcessEnvironment = DKUtils.getEnvironment();
        HF_ENDPOINT = dssProcessEnvironment.getOrDefault("HF_ENDPOINT", "https://huggingface.co");
        HF_API_ENDPOINT = HF_ENDPOINT + "/api";
        HF_MODELS_API_ENDPOINT = HF_API_ENDPOINT + "/models";
        HF_WHOAMI_API_ENDPOINT = HF_API_ENDPOINT + "/whoami-v2";
        logger = DKULogger.getLogger((String)"dku.huggingface");
    }

    public class HuggingFaceModelRepo
    extends ReadOnlyFSBase {
        private final Map<RelFile, Map<RelFile, HFResponseFileItem>> listings = new ConcurrentHashMap<RelFile, Map<RelFile, HFResponseFileItem>>();
        private final String modelId;
        private final String revision;

        public HuggingFaceModelRepo(String modelId, String revision) {
            this.modelId = modelId;
            this.revision = revision;
        }

        public List<RelFile> listFilesUnordered(RelFile directory) throws IOException {
            return new ArrayList<RelFile>(this.fetchListing(directory).keySet());
        }

        public FileContent readContentUnsafe(RelFile file) throws IOException {
            RelFileAttribute attrs = this.getAttributes(file);
            if (attrs == null) {
                throw new IOException("No such file: " + String.valueOf(file));
            }
            final String url = HF_ENDPOINT + "/" + HuggingFaceClient.escapeUrlPathSegments(this.modelId) + "/resolve/" + HuggingFaceClient.escapeUrlPathSegments(this.revision) + "/" + HuggingFaceClient.escapeUrlPathSegments(file.getFullPath());
            return FileContentFactory.DEFAULT.fromUncompressedStreamSupplier(new StreamSupplier(){
                int retryCount = 0;

                public InputStream openStream() throws IOException {
                    final CloseableHttpResponse resp = HuggingFaceClient.this.rawCall(url, this.retryCount > 0);
                    ++this.retryCount;
                    return new ProxyInputStream(resp.getEntity().getContent()){

                        public void close() throws IOException {
                            super.close();
                            resp.close();
                        }
                    };
                }
            }, attrs.getLength());
        }

        public RelFileAttribute getAttributes(RelFile file) throws IOException {
            if (file.isRoot()) {
                return new ConcreteRelFileAttribute(RelFile.root(), RelFileAttribute.FileType.DIRECTORY, -1L, -1L);
            }
            if (!this.isDirectory(file.getParent())) {
                return null;
            }
            HFResponseFileItem item = this.fetchListing(file.getParent()).get(file);
            if (item == null) {
                return null;
            }
            return new ConcreteRelFileAttribute(file, "file".equals(item.type) ? RelFileAttribute.FileType.FILE : RelFileAttribute.FileType.DIRECTORY, -1L, item.size);
        }

        private Map<RelFile, HFResponseFileItem> fetchListing(RelFile directory) throws IOException {
            Map<RelFile, HFResponseFileItem> listing = this.listings.get(directory);
            if (listing != null) {
                return listing;
            }
            HashMap<RelFile, HFResponseFileItem> finalListing = new HashMap<RelFile, HFResponseFileItem>();
            HuggingFaceClient.this.paginatedCall(HF_MODELS_API_ENDPOINT + "/" + HuggingFaceClient.escapeUrlPathSegments(this.modelId) + "/tree/" + HuggingFaceClient.escapeUrlPathSegments(this.revision) + "/" + HuggingFaceClient.escapeUrlPathSegments(directory.getFullPath()), (FailableConsumer<CloseableHttpResponse, IOException>)((FailableConsumer)resp -> {
                try (InputStream is = resp.getEntity().getContent();){
                    for (HFResponseFileItem item : (List)JSON.parse((InputStream)is, (TypeToken)new TypeToken<List<HFResponseFileItem>>(){})) {
                        if (!"file".equals(item.type) && !"directory".equals(item.type)) continue;
                        finalListing.put(RelFile.fromPath((String)item.path), item);
                    }
                }
            }));
            this.listings.put(directory, finalListing);
            return finalListing;
        }
    }

    /**
     * One element of the commit list deserialized by {@code getCommitInfos}.
     * Field names mirror the JSON keys of the response.
     */
    public static class HFCommitInfo {
        public String id;
        public String title;
        public String date;
    }

    /**
     * Model metadata deserialized by {@code getModelDetails}; extends the list item
     * with additional keys. Field names mirror the JSON keys of the response.
     */
    public static class HFResponseModelDetails
    extends HFResponseModelItem {
        public String library_name;
        public List<String> tags;
        public String author;
    }

    /**
     * Subset of {@code adapter_config.json} deserialized by {@code getBaseModelKey}.
     */
    public static class HFAdapterConfig {
        /** Value returned by {@code getBaseModelKey}. */
        public String base_model_name_or_path;
    }

    /**
     * JSON body of a non-200 response, as deserialized by {@code rawCall} when the
     * content type starts with {@code application/json}.
     */
    public static class HFResponseError {
        /** Error text included in the message of the thrown exception. */
        public String error;
    }

    /**
     * Thrown by {@code rawCall} for every response whose status is not 200.
     */
    public static class HuggingFaceResponseException
    extends HttpResponseException {
        /**
         * @param statusCode   HTTP status of the response
         * @param reasonPhrase message describing the failure; forwarded to the superclass
         */
        public HuggingFaceResponseException(int statusCode, String reasonPhrase) {
            super(statusCode, reasonPhrase);
        }
    }

    /**
     * File filter that rejects weight formats that are not wanted from a model repository.
     *
     * <p>{@code .mlpackage} entries are always rejected and other directories always accepted.
     * Among files, {@code .msgpack}, {@code .ot}, {@code .h5} and {@code .pt} are always rejected.
     * When the repository root contains safetensors weights, {@code .bin}, {@code .bin.index.json},
     * {@code .pth} and {@code consolidated*} safetensors files are rejected as well.
     *
     * <p>The result of the root scan is remembered for the last file system seen
     * (compared by identity).
     */
    public static class HuggingFaceWeightsFilter
    implements RelFileFilter {
        /** File system the cached scan result below belongs to. */
        private ReadOnlyFS modelRepo;
        private boolean hasSafetensorsWeightsCache;

        /**
         * Tells whether the root of {@code fs} contains {@code model.safetensors} or
         * {@code model.safetensors.index.json}, scanning the root only when {@code fs}
         * differs from the file system of the previous scan.
         */
        private boolean hasSafetensorsWeights(ReadOnlyFS fs) throws IOException {
            if (this.modelRepo == fs) {
                return this.hasSafetensorsWeightsCache;
            }
            logger.info((Object)"Scanning model repo for .safetensors weight files");
            List<RelFile> rootEntries = fs.listFilesUnordered(RelFile.root());
            this.hasSafetensorsWeightsCache = false;
            for (RelFile entry : rootEntries) {
                boolean isSafetensorsWeights = entry.getLeafName().equals("model.safetensors.index.json") || entry.getLeafName().equals("model.safetensors");
                if (isSafetensorsWeights) {
                    logger.info((Object)"This model has pytorch weights in .safetensors format, .bin weight files will be ignored");
                    this.hasSafetensorsWeightsCache = true;
                    break;
                }
            }
            // Only remember the file system once the scan has completed.
            this.modelRepo = fs;
            if (!this.hasSafetensorsWeightsCache) {
                logger.info((Object)"This model doesn't have pytorch weights in .safetensors format, .bin weight files will not be ignored");
            }
            return this.hasSafetensorsWeightsCache;
        }

        /**
         * @param fs file system the entry belongs to
         * @param rf entry to test
         * @return {@code true} when the entry should be kept
         * @throws IOException if the file system cannot be queried
         */
        public boolean accept(ReadOnlyFS fs, RelFile rf) throws IOException {
            if (rf.getLeafName().endsWith(".mlpackage")) {
                return false;
            }
            if (fs.isDirectory(rf)) {
                return true;
            }
            String leaf = rf.getLeafName();
            if (leaf.endsWith(".msgpack") || leaf.endsWith(".ot") || leaf.endsWith(".h5") || leaf.endsWith(".pt")) {
                return false;
            }
            // The scan is triggered here for every remaining file, whatever its name.
            if (!this.hasSafetensorsWeights(fs)) {
                return true;
            }
            boolean supersededBySafetensors = leaf.endsWith(".bin")
                    || leaf.endsWith(".bin.index.json")
                    || leaf.endsWith(".pth")
                    || leaf.startsWith("consolidated") && leaf.endsWith(".safetensors")
                    || leaf.equals("consolidated.safetensors.index.json");
            return !supersededBySafetensors;
        }
    }

    /**
     * One element of the model list deserialized by {@code listModels}.
     * Field names mirror the JSON keys of the response.
     */
    public static class HFResponseModelItem {
        public String pipeline_tag;
        public String id;
        public String modelId;
    }

    /**
     * One element of a {@code /tree/} listing, as deserialized by
     * {@code HuggingFaceModelRepo.fetchListing}.
     */
    private static class HFResponseFileItem {
        /** Entry kind; only {@code "file"} and {@code "directory"} entries are kept in listings. */
        String type;
        /** Path of the entry, converted with {@code RelFile.fromPath}. */
        String path;
        /** Size reported by the listing; used as the length in file attributes. */
        long size;

        private HFResponseFileItem() {
        }
    }
}

