136 lines
3.0 KiB
C++
136 lines
3.0 KiB
C++
#include "exllama.h"
|
|
#include "backend.h"
|
|
|
|
#include <QJsonObject>
|
|
#include <QJsonDocument>
|
|
#include <cstdint>
|
|
|
|
/**
 * @brief Constructs the backend and wires up the websocket.
 *
 * Every text message received on m_webSocket is forwarded to socketMessage().
 */
ExLlama::ExLlama()
{
	QObject::connect(&m_webSocket,
	                 &QWebSocket::textMessageReceived,
	                 this,
	                 &ExLlama::socketMessage);
}
|
|
|
|
bool ExLlama::ready()
|
|
{
|
|
return m_webSocket.isValid();
|
|
}
|
|
|
|
void ExLlama::socketMessage(const QString& message)
|
|
{
|
|
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
|
|
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
|
|
int id = idVal.toInt();
|
|
|
|
if(!isValidId(id))
|
|
{
|
|
qDebug()<<"Got unkown response id on socket";
|
|
return;
|
|
}
|
|
|
|
Response::type_t type = strToAction(jsonDocument[QStringLiteral("action")].toString());
|
|
QJsonValue responseValue;
|
|
int64_t tokens = -1;
|
|
|
|
if(type == Response::INFERENCE)
|
|
{
|
|
responseValue = jsonDocument[QStringLiteral("response")];
|
|
}
|
|
else if(type == Response::COUNT_TOKENS)
|
|
{
|
|
responseValue = QJsonValue(QStringLiteral(""));
|
|
QJsonValue tokensValue = jsonDocument[QStringLiteral("num_tokens")];
|
|
if(!tokensValue.isDouble())
|
|
{
|
|
qDebug()<<"Got invalid response on socket, num_tokens missing or not a number";
|
|
return;
|
|
}
|
|
tokens = tokensValue.toInt();
|
|
}
|
|
else if(type == Response::LEFT_TRIM)
|
|
{
|
|
responseValue = jsonDocument[QStringLiteral("trimmed_text")];
|
|
}
|
|
|
|
if(!responseValue.isString())
|
|
{
|
|
qDebug()<<"Got invalid response on socket";
|
|
return;
|
|
}
|
|
|
|
feedResponse(Response(responseValue.toString(), id, true, type, tokens));
|
|
}
|
|
|
|
/**
 * @brief Maps a request type to the action name used in the JSON protocol.
 * @param type request type to translate
 * @return "infer", "estimate_token" or "lefttrim_token"; a null QString for
 *         any other type
 */
const QString ExLlama::actionToStr(AiBackend::Request::type_t type)
{
	if(type == Request::INFERENCE)
		return QStringLiteral("infer");

	if(type == Request::COUNT_TOKENS)
		return QStringLiteral("estimate_token");

	if(type == Request::LEFT_TRIM)
		return QStringLiteral("lefttrim_token");

	return QString();
}
|
|
|
|
/**
 * @brief Maps a protocol action name back to its request type.
 * @param str action name as found in a JSON message
 * @return the matching type for "infer", "estimate_token" or
 *         "lefttrim_token"; Request::UNKOWN for anything else
 */
AiBackend::Request::type_t ExLlama::strToAction(const QString& str)
{
	if(str == QStringLiteral("infer"))
	{
		return Request::INFERENCE;
	}

	if(str == QStringLiteral("estimate_token"))
	{
		return Request::COUNT_TOKENS;
	}

	if(str == QStringLiteral("lefttrim_token"))
	{
		return Request::LEFT_TRIM;
	}

	return Request::UNKOWN;
}
|
|
|
|
bool ExLlama::generateImpl(const Request& request)
|
|
{
|
|
QString action = actionToStr(request.getType());
|
|
|
|
if(action.isEmpty())
|
|
return false;
|
|
|
|
QJsonObject json;
|
|
json[QStringLiteral("action")] = action;
|
|
json[QStringLiteral("request_id")] = static_cast<double>(request.getId());
|
|
json[QStringLiteral("text")] = request.getText();
|
|
|
|
if(request.getType() == Request::INFERENCE)
|
|
{
|
|
json[QStringLiteral("max_new_tokens")] = 50;
|
|
json[QStringLiteral("stream")] = false;
|
|
}
|
|
|
|
QJsonDocument jsonDocument(json);
|
|
QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
|
|
qDebug()<<__func__<<' '<<requestText;
|
|
m_webSocket.sendTextMessage(requestText);
|
|
|
|
return true;
|
|
}
|
|
|
|
void ExLlama::open(const QUrl& url)
|
|
{
|
|
m_webSocket.close();
|
|
m_webSocket.open(url);
|
|
}
|
|
|
|
/**
 * @brief Returns the name of this backend.
 * @return the same string as backendNameStatic()
 */
QString ExLlama::backendName()
{
	return ExLlama::backendNameStatic();
}
|
|
|
|
/**
 * @brief Returns the backend's name.
 * @return the string "ExLlama"
 */
QString ExLlama::backendNameStatic()
{
	const QString name = QStringLiteral("ExLlama");
	return name;
}
|
|
|
|
/**
 * @brief Destroys the backend.
 *
 * Removes every signal connection originating from the websocket.
 */
ExLlama::~ExLlama()
{
	// Wildcard form: all signals of m_webSocket, all receivers, all slots.
	QObject::disconnect(&m_webSocket, nullptr, nullptr, nullptr);
}
|