general ui work

This commit is contained in:
2024-06-11 14:15:08 +02:00
parent 8a26f9e1e4
commit b111e15fd5
8 changed files with 272 additions and 47 deletions

View File

@ -1,7 +1,9 @@
#include "exllama.h"
#include "backend.h"
#include <QJsonObject>
#include <QJsonDocument>
#include <cstdint>
ExLlama::ExLlama()
{
@ -17,11 +19,6 @@ void ExLlama::socketMessage(const QString& message)
{
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
if(!idVal.isDouble())
{
qDebug()<<"Got invalid response on socket";
return;
}
int id = idVal.toInt();
if(!isValidId(id))
@ -30,29 +27,90 @@ void ExLlama::socketMessage(const QString& message)
return;
}
QJsonValue responseValue = jsonDocument[QStringLiteral("response")];
Response::type_t type = strToAction(jsonDocument[QStringLiteral("action")].toString());
QJsonValue responseValue;
int64_t tokens = -1;
if(type == Response::INFERENCE)
{
responseValue = jsonDocument[QStringLiteral("response")];
}
else if(type == Response::COUNT_TOKENS)
{
responseValue = QJsonValue(QStringLiteral(""));
QJsonValue tokensValue = jsonDocument[QStringLiteral("num_tokens")];
if(!tokensValue.isDouble())
{
qDebug()<<"Got invalid response on socket, num_tokens missing or not a number";
return;
}
tokens = tokensValue.toInt();
}
else if(type == Response::LEFT_TRIM)
{
responseValue = jsonDocument[QStringLiteral("trimmed_text")];
}
if(!responseValue.isString())
{
qDebug()<<"Got invalid response on socket";
return;
}
feedResponse(Response(responseValue.toString(), id, true));
feedResponse(Response(responseValue.toString(), id, true, type, tokens));
}
void ExLlama::generateImpl(const Request& request)
/**
 * @brief Translates a request type into the action name placed in the
 *        "action" field of outgoing JSON messages.
 *
 * @param type the request type to translate
 * @return "infer", "estimate_token" or "lefttrim_token" for the three known
 *         types; a null QString for every other value.
 */
const QString ExLlama::actionToStr(AiBackend::Request::type_t type)
{
	if(type == Request::INFERENCE)
		return QStringLiteral("infer");

	if(type == Request::COUNT_TOKENS)
		return QStringLiteral("estimate_token");

	if(type == Request::LEFT_TRIM)
		return QStringLiteral("lefttrim_token");

	// No action name exists for this type; signal that with a null string.
	return QString();
}
/**
 * @brief Translates an action name back into the corresponding request type.
 *
 * The comparison is an exact string match against "infer", "estimate_token"
 * and "lefttrim_token".
 *
 * @param str the action name to look up
 * @return the matching type, or Request::UNKOWN when the name is not one of
 *         the three known actions.
 */
AiBackend::Request::type_t ExLlama::strToAction(const QString& str)
{
	struct ActionName
	{
		QString name;
		Request::type_t type;
	};

	const ActionName knownActions[] =
	{
		{QStringLiteral("infer"), Request::INFERENCE},
		{QStringLiteral("estimate_token"), Request::COUNT_TOKENS},
		{QStringLiteral("lefttrim_token"), Request::LEFT_TRIM},
	};

	for(const ActionName& candidate : knownActions)
	{
		if(str == candidate.name)
			return candidate.type;
	}

	return Request::UNKOWN;
}
bool ExLlama::generateImpl(const Request& request)
{
QString action = actionToStr(request.getType());
if(action.isEmpty())
return false;
QJsonObject json;
json[QStringLiteral("action")] = QStringLiteral("infer");
json[QStringLiteral("action")] = action;
json[QStringLiteral("request_id")] = static_cast<double>(request.getId());
json[QStringLiteral("text")] = request.getText();
json[QStringLiteral("max_new_tokens")] = 50;
json[QStringLiteral("stream")] = false;
if(request.getType() == Request::INFERENCE)
{
json[QStringLiteral("max_new_tokens")] = 50;
json[QStringLiteral("stream")] = false;
}
QJsonDocument jsonDocument(json);
QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
qDebug()<<__func__<<' '<<requestText;
m_webSocket.sendTextMessage(requestText);
return true;
}
void ExLlama::open(const QUrl& url)
@ -61,6 +119,16 @@ void ExLlama::open(const QUrl& url)
m_webSocket.open(url);
}
/**
 * @brief Returns the name of this backend.
 *
 * @return the value of backendNameStatic().
 */
QString ExLlama::backendName()
{
	QString name = backendNameStatic();
	return name;
}
/**
 * @brief Returns the fixed name of the ExLlama backend.
 *
 * @return the string "ExLlama".
 */
QString ExLlama::backendNameStatic()
{
	QString name = QStringLiteral("ExLlama");
	return name;
}
ExLlama::~ExLlama()
{
m_webSocket.disconnect();