/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.local;

import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.llm.online.LLMClient;
import java.util.List;
import java.util.stream.Collectors;

public class HuggingFaceLocalClient {
    public static String getFormattedPromptContent(List<LLMClient.ChatMessage> chatMessages, HuggingFaceLocalConnection.HuggingFaceHandlingMode handlingMode) {
        StringBuilder sb = new StringBuilder();
        switch (handlingMode) {
            case TEXT_GENERATION_LLAMA_2: 
            case TEXT_GENERATION_LLAMA_GUARD: {
                sb.append("<s>[INST] ");
                for (int idx = 0; idx < chatMessages.size(); ++idx) {
                    LLMClient.ChatMessage msg = chatMessages.get(idx);
                    if ("system".equals(msg.role)) {
                        sb.append("<<SYS>>\n");
                        sb.append(msg.getText());
                        sb.append("\n<</SYS>>\n\n");
                    }
                    if ("user".equals(msg.role)) {
                        sb.append(msg.getText());
                        sb.append(" [/INST]");
                    }
                    if (!"assistant".equals(msg.role)) continue;
                    sb.append(" ");
                    sb.append(msg.getText());
                    sb.append("</s>");
                    if (idx >= chatMessages.size() - 1) continue;
                    sb.append("<s>[INST]");
                }
                break;
            }
            case TEXT_GENERATION_LLAMA_3: {
                sb.append("<|begin_of_text|>");
                for (LLMClient.ChatMessage msg : chatMessages) {
                    sb.append("<|start_header_id|>").append(msg.role).append("<|end_header_id|>");
                    sb.append("\n\n");
                    sb.append(msg.getText());
                    sb.append("<|eot_id|>");
                }
                sb.append("<|start_header_id|>assistant<|end_header_id|>\n");
                break;
            }
            case TEXT_GENERATION_PHI_3: {
                sb.append("<s>");
                for (LLMClient.ChatMessage msg : chatMessages) {
                    sb.append("<|").append(msg.role).append("|>");
                    sb.append(msg.getText()).append("<|end|>\n");
                }
                sb.append("<|assistant|>");
                break;
            }
            case TEXT_GENERATION_MPT: {
                for (LLMClient.ChatMessage msg : chatMessages) {
                    if ("system".equals(msg.role)) {
                        sb.append(msg.getText());
                    }
                    if ("user".equals(msg.role)) {
                        sb.append("### Instruction:\n");
                        sb.append(msg.getText());
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append("### Response:\n");
                        sb.append(msg.getText());
                    }
                    sb.append("\n");
                }
                sb.append("### Response:\n");
                break;
            }
            case TEXT_GENERATION_FALCON: {
                sb.append(">>TITLE<<\nFlawless answer\n");
                for (LLMClient.ChatMessage msg : chatMessages) {
                    if ("system".equals(msg.role)) {
                        sb.append(">>CONTEXT<<");
                        sb.append(msg.getText());
                    }
                    if ("user".equals(msg.role)) {
                        sb.append(">>QUESTION<<");
                        sb.append(msg.getText());
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append(">>ANSWER<<");
                        sb.append(msg.getText());
                    }
                    sb.append("\n");
                }
                sb.append(">>ANSWER<<\n");
                break;
            }
            case TEXT_GENERATION_ZEPHYR: {
                for (LLMClient.ChatMessage msg : chatMessages) {
                    sb.append("<|").append(msg.role).append("|>\n");
                    sb.append(msg.getText());
                    sb.append("</s>\n");
                }
                sb.append("<|assistant|>\n");
                break;
            }
            case TEXT_GENERATION_MISTRAL: {
                sb.append("<s>[INST] ");
                for (int idx = 0; idx < chatMessages.size(); ++idx) {
                    LLMClient.ChatMessage msg = chatMessages.get(idx);
                    if ("system".equals(msg.role)) {
                        sb.append("\n");
                        sb.append(msg.getText());
                        sb.append("\n");
                    }
                    if ("user".equals(msg.role)) {
                        sb.append(msg.getText());
                        sb.append(" [/INST]");
                    }
                    if (!"assistant".equals(msg.role)) continue;
                    sb.append(msg.getText());
                    sb.append("</s> ");
                    if (idx >= chatMessages.size() - 1) continue;
                    sb.append("[INST] ");
                }
                break;
            }
            case TEXT_GENERATION_GEMMA: {
                sb.append("<bos>");
                for (LLMClient.ChatMessage msg : chatMessages) {
                    if ("system".equals(msg.role) || "user".equals(msg.role)) {
                        sb.append("<start_of_turn>user\n");
                        sb.append(msg.getText());
                        sb.append("<end_of_turn>\n");
                    }
                    if (!"assistant".equals(msg.role)) continue;
                    sb.append("<start_of_turn>model\n");
                    sb.append(msg.getText());
                    sb.append("<end_of_turn>\n");
                }
                sb.append("<start_of_turn>model");
                break;
            }
            case T5: 
            case TEXT_GENERATION_DEEPSEEK: 
            case TEXT_GENERATION_GPT: 
            case TEXT_GENERATION_QWEN: 
            case TEXT_GENERATION_OPENBMB: 
            case TEXT_GENERATION_GENERIC: {
                sb.append(chatMessages.stream().map(LLMClient.ChatMessage::getText).collect(Collectors.joining("\n")));
                break;
            }
            case TEXT_GENERATION_AUTO: {
                for (LLMClient.ChatMessage msg : chatMessages) {
                    if ("system".equals(msg.role)) {
                        sb.append("System message: ");
                        sb.append(msg.getText());
                    }
                    if ("user".equals(msg.role)) {
                        sb.append("User message: ");
                        sb.append(msg.getText());
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append("Assistant message: ");
                        sb.append(msg.getText());
                    }
                    sb.append("\n");
                }
                break;
            }
            default: {
                throw new IllegalArgumentException("Not a handling mode that needs to format a prompt: " + String.valueOf((Object)handlingMode));
            }
        }
        return sb.toString();
    }
}

