// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

#include "llamallmproxy.h"

#include "llama.h"
#include "common/common.h"
#include <QDebug>

GLOBAL_USE_NAMESPACE

/**
 * @brief Creates a llama.cpp-backed proxy identified by @p name.
 *
 * Both base classes are explicitly initialized in the member-initializer list,
 * and @p name is stored so that name() can report it later.
 *
 * @param name Identifier returned by name().
 */
LlamaLLMProxy::LlamaLLMProxy(const std::string &name)
    : LLMProxy(),
      LlamaModelWrapper(),
      modelName(name)
{
}

std::string LlamaLLMProxy::name() const
{
    return modelName;
}

/**
 * @brief Converts @p prompt into token ids using the loaded model (gModel).
 *
 * Both boolean flags of llama_tokenize() are passed as true. In llama.cpp's
 * common.h these are presumably `add_special` / `parse_special`; verify this
 * against the bundled llama.cpp version.
 *
 * @param prompt Text to tokenize.
 * @param params Currently ignored.
 * @return Token ids for @p prompt.
 */
std::vector<int32_t> LlamaLLMProxy::tokenize(const std::string &prompt, const std::map<std::string, std::string> &params)
{
    (void)params;

    const bool addSpecial = true;
    const bool parseSpecial = true;
    return llama_tokenize(gModel, prompt, addSpecial, parseSpecial);
}

/**
 * @brief Turns token ids back into text.
 *
 * Each token is rendered with llama_token_to_piece() on gCtx, with its boolean
 * flag set to false. The pieces are concatenated in input order.
 *
 * @param tokens Token ids to render.
 * @param params Currently ignored.
 * @return The concatenated text (empty for an empty @p tokens).
 */
std::string LlamaLLMProxy::detokenize(const std::vector<int32_t> &tokens, const std::map<std::string, std::string> &params)
{
    (void)params;

    std::string text;
    for (const int32_t token : tokens)
        text += llama_token_to_piece(gCtx, token, false);

    return text;
}

std::vector<int32_t> LlamaLLMProxy::generate(const std::vector<int32_t> &input, const std::map<std::string, std::string> &params, generateStream stream, void *user) const
{
    std::vector<int32_t> output;
    if (input.empty())
        return output;

    const int n_ctx = llama_n_ctx(gCtx);

    if (input.size() > n_ctx - 4) {
        std::cerr << QString("prompt is too long (%0 tokens, max %0)").arg(input.size()).arg(n_ctx - 4).toStdString() << std::endl;
        return output;
    }

    int n_consumed = 0;
    llama_sampling_context * ctx_sampling = llama_sampling_init(gParams->sparams);

    std::vector<llama_token> embd;
    while (input.size() > n_consumed) {
        embd.push_back(input[n_consumed]);
        llama_sampling_accept(ctx_sampling, gCtx, input[n_consumed], false);
        ++n_consumed;
    }

    const int n_batch = llama_n_batch(gCtx);
    const int defaultSeqID = 0;

    // only a single seq_id per token is needed
    llama_batch batch = llama_batch_init(n_batch, 0, 1);

    int n_past = 0;
    LLMGenerateContext slot;

    while (true) {
        int n_eval = embd.size();
        bool decodeError = false;
        for (int32_t i = 0; i < n_eval; i += n_batch) {
            llama_batch_clear(batch);

            const int32_t n_tokens = std::min(n_batch, n_eval - i);
            int j = 0;
            for (; j < n_tokens; ++j) {
                llama_batch_add(batch, embd[i + j], n_past, {defaultSeqID}, false);
                n_past++;
            }

            if (i + j == n_eval)
                batch.logits[batch.n_tokens - 1] = true;

            if (llama_decode(gCtx, batch)) {
                std::cerr << "llama_decode: failed to eval" << std::endl;
                decodeError = true;
                break;
            }
        }

        if (decodeError)
            break;

        embd.clear();

        const llama_token id = llama_sampling_sample(ctx_sampling, gCtx, nullptr);
        llama_sampling_accept(ctx_sampling, gCtx, id, true);

        output.push_back(id);
        embd.push_back(id);

        if (stream) {
            auto send = processToken(id, slot);
            if (!send.empty()) {
                if (!stream(send, user))
                    break;
            }
        }

        if (embd.back() == llama_token_eos(gModel))
           break;

        if (n_past >= n_ctx - 4)
            break;
    }

    llama_sampling_free(ctx_sampling);
    llama_kv_cache_clear(gCtx);
    llama_batch_free(batch);

    return output;
}

// Adapted from process_token() in llama.cpp/examples/server/server.cpp
/**
 * @brief Appends the text of @p token to @p slot and returns what can be emitted.
 *
 * The piece (llama_token_to_piece() on gCtx, flag false) is appended to
 * slot.generatedText. If that text now ends in a truncated multi-byte UTF-8
 * sequence, nothing is returned and the bytes stay buffered for a later call.
 * Otherwise, everything from slot.pushedPos (clamped to the text length) onward
 * is returned, and slot.pushedPos is advanced by its size.
 *
 * @param token Token id just produced by the sampler.
 * @param slot  Per-generation accumulation state.
 * @return Newly emittable text, possibly empty.
 */
std::string LlamaLLMProxy::processToken(int32_t token, LLMGenerateContext &slot) const
{
    slot.generatedText += llama_token_to_piece(gCtx, token, false);
    const std::string &text = slot.generatedText;

    // Scan at most the last four bytes backwards for a UTF-8 lead byte. Then
    // compare the sequence length it announces with the bytes present after it.
    bool incomplete = false;
    const size_t scanLimit = std::min<size_t>(4, text.size());
    for (size_t back = 1; back <= scanLimit; ++back) {
        const unsigned char byte = static_cast<unsigned char>(text[text.size() - back]);
        if ((byte & 0xC0) == 0x80)
            continue; // continuation byte 10xxxxxx: keep looking for the lead byte

        size_t needed = 1; // ASCII or invalid byte: treated as complete
        if ((byte & 0xE0) == 0xC0)
            needed = 2; // 110xxxxx
        else if ((byte & 0xF0) == 0xE0)
            needed = 3; // 1110xxxx
        else if ((byte & 0xF8) == 0xF0)
            needed = 4; // 11110xxx

        incomplete = back < needed;
        break;
    }

    if (incomplete)
        return std::string();

    const size_t start = std::min(slot.pushedPos, slot.generatedText.size());
    std::string fresh = text.substr(start);
    slot.pushedPos += fresh.size();
    return fresh;
}

