kateai/exllama.cpp
2024-06-11 14:15:08 +02:00

136 lines
3.0 KiB
C++

#include "exllama.h"
#include "backend.h"
#include <QJsonObject>
#include <QJsonDocument>
#include <cstdint>
/**
 * @brief Constructs the backend and wires the web socket's
 *        textMessageReceived signal to socketMessage().
 */
ExLlama::ExLlama()
{
    connect(&m_webSocket,
            &QWebSocket::textMessageReceived,
            this,
            &ExLlama::socketMessage);
}
/**
 * @brief Reports whether the backend can currently be used.
 * @return the result of isValid() on the underlying web socket.
 */
bool ExLlama::ready()
{
    const bool socketUsable = m_webSocket.isValid();
    return socketUsable;
}
void ExLlama::socketMessage(const QString& message)
{
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
int id = idVal.toInt();
if(!isValidId(id))
{
qDebug()<<"Got unkown response id on socket";
return;
}
Response::type_t type = strToAction(jsonDocument[QStringLiteral("action")].toString());
QJsonValue responseValue;
int64_t tokens = -1;
if(type == Response::INFERENCE)
{
responseValue = jsonDocument[QStringLiteral("response")];
}
else if(type == Response::COUNT_TOKENS)
{
responseValue = QJsonValue(QStringLiteral(""));
QJsonValue tokensValue = jsonDocument[QStringLiteral("num_tokens")];
if(!tokensValue.isDouble())
{
qDebug()<<"Got invalid response on socket, num_tokens missing or not a number";
return;
}
tokens = tokensValue.toInt();
}
else if(type == Response::LEFT_TRIM)
{
responseValue = jsonDocument[QStringLiteral("trimmed_text")];
}
if(!responseValue.isString())
{
qDebug()<<"Got invalid response on socket";
return;
}
feedResponse(Response(responseValue.toString(), id, true, type, tokens));
}
/**
 * @brief Maps a request type to the string placed in the "action" field of
 *        outgoing messages.
 * @param type the request type to translate.
 * @return "infer", "estimate_token" or "lefttrim_token" for the three known
 *         types; a null QString for any other value.
 */
const QString ExLlama::actionToStr(AiBackend::Request::type_t type)
{
    if(type == Request::INFERENCE)
        return QStringLiteral("infer");
    if(type == Request::COUNT_TOKENS)
        return QStringLiteral("estimate_token");
    if(type == Request::LEFT_TRIM)
        return QStringLiteral("lefttrim_token");

    // No action string exists for this type.
    return QString();
}
/**
 * @brief Maps an "action" string back to the corresponding request type.
 * @param str the action string to translate.
 * @return the matching type for "infer", "estimate_token" and
 *         "lefttrim_token"; Request::UNKOWN for anything else.
 */
AiBackend::Request::type_t ExLlama::strToAction(const QString& str)
{
    if(str == QStringLiteral("infer"))
    {
        return Request::INFERENCE;
    }
    if(str == QStringLiteral("estimate_token"))
    {
        return Request::COUNT_TOKENS;
    }
    if(str == QStringLiteral("lefttrim_token"))
    {
        return Request::LEFT_TRIM;
    }
    return Request::UNKOWN;
}
bool ExLlama::generateImpl(const Request& request)
{
QString action = actionToStr(request.getType());
if(action.isEmpty())
return false;
QJsonObject json;
json[QStringLiteral("action")] = action;
json[QStringLiteral("request_id")] = static_cast<double>(request.getId());
json[QStringLiteral("text")] = request.getText();
if(request.getType() == Request::INFERENCE)
{
json[QStringLiteral("max_new_tokens")] = 50;
json[QStringLiteral("stream")] = false;
}
QJsonDocument jsonDocument(json);
QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
qDebug()<<__func__<<' '<<requestText;
m_webSocket.sendTextMessage(requestText);
return true;
}
/**
 * @brief (Re)connects the web socket to the given URL.
 *
 * Calls close() on the socket first and then open() with the new URL.
 *
 * @param url address of the server to connect to.
 */
void ExLlama::open(const QUrl& url)
{
    // NOTE(review): close() is followed directly by open(); confirm the
    // socket accepts a new open() while a previous connection may still be
    // closing.
    m_webSocket.close();

    m_webSocket.open(url);
}
/**
 * @brief Returns the name of this backend.
 * @return the same string as backendNameStatic().
 */
QString ExLlama::backendName()
{
    return ExLlama::backendNameStatic();
}
/**
 * @brief Returns the fixed name of this backend type.
 * @return the string "ExLlama".
 */
QString ExLlama::backendNameStatic()
{
    const QString name = QStringLiteral("ExLlama");
    return name;
}
/**
 * @brief Destroys the backend after calling disconnect() on the web socket.
 */
ExLlama::~ExLlama()
{
    // NOTE(review): disconnect() with no arguments looks like
    // QObject::disconnect(), which removes signal connections rather than
    // closing the network connection; confirm this is the intent.
    m_webSocket.disconnect();
}