prepare for selectable backends

This commit is contained in:
uvos 2023-11-02 23:12:17 +01:00
parent fc0ca71b4d
commit 8a26f9e1e4
7 changed files with 258 additions and 57 deletions

View File

@ -34,6 +34,8 @@ target_sources(
PRIVATE
kateai.cpp
kateaiconfigpage.cpp
backend.cpp
exllama.cpp
plugin.qrc
)

65
backend.cpp Normal file
View File

@ -0,0 +1,65 @@
#include "backend.h"
AiBackend::Request::Request(const QString& textIn, void* userPtrIn):
text(textIn),
userPtr(userPtrIn)
{
id = idCounter++;
}
const QString& AiBackend::Request::getText() const
{
return text;
}
uint32_t AiBackend::Request::getId() const
{
return id;
}
void* AiBackend::Request::getUserPtr() const
{
return userPtr;
}
bool AiBackend::Request::operator==(const Request& in)
{
return id == in.id;
}
AiBackend::Response::Response(QString textIn, uint32_t idIn, bool finishedIn, void* userPtrIn):
AiBackend::Request::Request(textIn, userPtrIn),
finished(finishedIn)
{
id = idIn;
}
bool AiBackend::Response::isFinished() const
{
return finished;
}
void AiBackend::Response::setUserPtr(void* ptr)
{
userPtr = ptr;
}
void AiBackend::generate(const Request& request)
{
m_requests.insert(request.getId(), request);
generateImpl(request);
}
bool AiBackend::isValidId(uint32_t id)
{
return m_requests.find(id) != m_requests.end();
}
void AiBackend::feedResponse(Response response)
{
response.setUserPtr(m_requests[response.getId()].getUserPtr());
if(response.isFinished())
m_requests.remove(response.getId());
gotResponse(response);
}

70
backend.h Normal file
View File

@ -0,0 +1,70 @@
#pragma once
#include <QObject>
#include <QHash>
#include <QString>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <QUrl>
class AiBackend: public QObject
{
Q_OBJECT
public:
class backend_error: public std::runtime_error
{
public:
backend_error(std::string err): runtime_error(err){}
};
class Request
{
protected:
inline static uint32_t idCounter = 0;
QString text;
uint32_t id;
void* userPtr;
public:
Request() = default;
Request(const QString& text, void* userPtr = nullptr);
const QString& getText() const;
uint32_t getId() const;
bool operator==(const Request& in);
void* getUserPtr() const;
};
class Response: public Request
{
private:
bool finished;
public:
Response() = default;
Response(QString text, uint32_t id, bool finished, void* userPtr = nullptr);
bool isFinished() const;
void setUserPtr(void* ptr);
};
protected:
QHash<uint32_t, Request> m_requests;
void feedResponse(Response response);
bool isValidId(uint32_t id);
virtual void generateImpl(const Request& request) = 0;
public:
virtual bool ready() = 0;
void generate(const Request& request);
virtual void abort(uint64_t id){(void)id;}
virtual void open(const QUrl& url){(void)url;};
Q_SIGNAL void gotResponse(Response response);
virtual ~AiBackend() = default;
};
Q_DECLARE_METATYPE(AiBackend::Response);
Q_DECLARE_METATYPE(AiBackend::Request);

67
exllama.cpp Normal file
View File

@ -0,0 +1,67 @@
#include "exllama.h"
#include <QJsonObject>
#include <QJsonDocument>
ExLlama::ExLlama()
{
connect(&m_webSocket, &QWebSocket::textMessageReceived, this, &ExLlama::socketMessage);
}
bool ExLlama::ready()
{
return m_webSocket.isValid();
}
void ExLlama::socketMessage(const QString& message)
{
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
if(!idVal.isDouble())
{
qDebug()<<"Got invalid response on socket";
return;
}
int id = idVal.toInt();
if(!isValidId(id))
{
qDebug()<<"Got unkown response id on socket";
return;
}
QJsonValue responseValue = jsonDocument[QStringLiteral("response")];
if(!responseValue.isString())
{
qDebug()<<"Got invalid response on socket";
return;
}
feedResponse(Response(responseValue.toString(), id, true));
}
void ExLlama::generateImpl(const Request& request)
{
QJsonObject json;
json[QStringLiteral("action")] = QStringLiteral("infer");
json[QStringLiteral("request_id")] = static_cast<double>(request.getId());
json[QStringLiteral("text")] = request.getText();
json[QStringLiteral("max_new_tokens")] = 50;
json[QStringLiteral("stream")] = false;
QJsonDocument jsonDocument(json);
QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
qDebug()<<__func__<<' '<<requestText;
m_webSocket.sendTextMessage(requestText);
}
void ExLlama::open(const QUrl& url)
{
m_webSocket.close();
m_webSocket.open(url);
}
ExLlama::~ExLlama()
{
m_webSocket.disconnect();
}

25
exllama.h Normal file
View File

@ -0,0 +1,25 @@
#pragma once
#include "backend.h"
#include <QUrl>
#include <QWebSocket>
class ExLlama: public AiBackend
{
Q_OBJECT
private:
QWebSocket m_webSocket;
void socketMessage(const QString& message);
protected:
virtual void generateImpl(const Request& request) override;
public:
ExLlama();
virtual bool ready() override;
virtual void open(const QUrl& url) override;
virtual ~ExLlama();
};

View File

@ -21,6 +21,8 @@
#include <KLocalizedString>
#include <KSharedConfig>
#include "backend.h"
#include "exllama.h"
#include "kateaiconfigpage.h"
K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin<KateAiPlugin>();)
@ -28,26 +30,21 @@ K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin<Ka
KateAiPlugin::KateAiPlugin(QObject *parent, const QList<QVariant> &)
: KTextEditor::Plugin(parent), m_serverUrl(QStringLiteral("ws://localhost:8642"))
{
connect(&m_webSocket, &QWebSocket::connected, this, &KateAiPlugin::onConnected);
m_ai = new ExLlama();
KateAiPluginView::setAi(m_ai);
readConfig();
}
void KateAiPlugin::onConnected()
{
qDebug()<<__func__<<m_webSocket.isValid();
}
KateAiPlugin::~KateAiPlugin() = default;
void KateAiPlugin::reconnect()
{
m_webSocket.close();
m_webSocket.open(m_serverUrl);
m_ai->open(m_serverUrl);
}
QObject *KateAiPlugin::createView(KTextEditor::MainWindow *mainWindow)
{
auto view = new KateAiPluginView(this, mainWindow, &m_webSocket);
auto view = new KateAiPluginView(this, mainWindow);
connect(view, &KateAiPluginView::reconnect, this, &KateAiPlugin::reconnect);
return view;
}
@ -67,14 +64,12 @@ void KateAiPlugin::readConfig()
reconnect();
}
KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow,
QPointer<QWebSocket> webSocket, bool instruct)
KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct)
: QObject(plugin)
, m_mainWindow(mainwindow)
, m_webSocket(webSocket)
, m_useInstruct(instruct)
{
KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Git Blame"));
KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Kate Ai"));
setXMLFile(QStringLiteral("ui.rc"));
QAction *generateAction = actionCollection()->addAction(QStringLiteral("ai_generate"));
generateAction->setText(QStringLiteral("Generate text using AI"));
@ -82,7 +77,7 @@ KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow
m_mainWindow->guiFactory()->addClient(this);
connect(generateAction, &QAction::triggered, this, &KateAiPluginView::generate);
connect(m_webSocket, &QWebSocket::textMessageReceived, this, &KateAiPluginView::socketMessage);
connect(m_ai, &AiBackend::gotResponse, this, &KateAiPluginView::gotResponse);
}
QPointer<KTextEditor::Document> KateAiPluginView::activeDocument() const
@ -101,29 +96,17 @@ KTextEditor::Cursor KateAiPluginView::getCurrentCursor() const
return KTextEditor::Cursor();
}
void KateAiPluginView::socketMessage(const QString& message)
void KateAiPluginView::gotResponse(AiBackend::Response response)
{
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
if(!idVal.isDouble())
{
qDebug()<<"Got invalid response on socket";
return;
}
int id = idVal.toInt();
QHash<uint32_t, Request>::iterator it = m_requests.find(response.getId());
QHash<int, Request>::iterator it = m_requests.find(id);
if(it != m_requests.end())
{
QJsonValue responseValue = jsonDocument[QStringLiteral("response")];
if(!responseValue.isString())
{
qDebug()<<"Got invalid response on socket";
return;
}
if(it.value().document)
it.value().document->insertText(it.value().cursor, responseValue.toString());
m_requests.erase(it);
it.value().document->insertText(it.value().cursor, response.getText());
if(response.isFinished())
m_requests.erase(it);
else
it.value().cursor = getCurrentCursor();
}
}
@ -198,25 +181,15 @@ QString KateAiPluginView::assembleContext(QPointer<KTextEditor::Document> docume
void KateAiPluginView::generate()
{
qDebug()<<activeDocument()->mimeType();
if(m_webSocket && m_webSocket->isValid())
if(m_ai->ready())
{
KTextEditor::Cursor cursor = getCurrentCursor();
QPointer<KTextEditor::Document> document = activeDocument();
QString text = assembleContext(document, cursor);
int id = QRandomGenerator::global()->bounded(0, std::numeric_limits<int>::max());
QJsonObject json;
json[QStringLiteral("action")] = QStringLiteral("infer");
json[QStringLiteral("request_id")] = id;
json[QStringLiteral("text")] = text;
json[QStringLiteral("max_new_tokens")] = 50;
json[QStringLiteral("stream")] = false;
QJsonDocument jsonDocument(json);
QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
qDebug()<<__func__<<' '<<requestText;
m_webSocket->sendTextMessage(requestText);
m_requests.insert(id, {cursor, document});
AiBackend::Request request(text);
m_ai->generate(request);
m_requests.insert(request.getId(), {cursor, document});
}
else
{

View File

@ -7,22 +7,21 @@
#include <KTextEditor/Plugin>
#include <KXMLGUIClient>
#include <QList>
#include <QSet>
#include <QAction>
#include <QWebSocket>
#include <QPointer>
#include <QtCore>
#include "backend.h"
class KateAiPlugin : public KTextEditor::Plugin
{
Q_OBJECT
private:
QWebSocket m_webSocket;
QPointer<AiBackend> m_ai;
QUrl m_serverUrl;
private:
void onConnected();
int configPages() const override
{
return 1;
@ -52,24 +51,24 @@ private:
};
KTextEditor::MainWindow *m_mainWindow;
QPointer<QWebSocket> m_webSocket;
QHash<int, Request> m_requests;
inline static QPointer<AiBackend> m_ai;
bool m_useInstruct = false;
private:
void generate();
void socketMessage(const QString& message);
void gotResponse(AiBackend::Response respons);
static QStringList getIncludePaths(const QString& text);
QString assembleContext(QPointer<KTextEditor::Document> document, const KTextEditor::Cursor& cursor);
QPointer<KTextEditor::Document> activeDocument() const;
KTextEditor::Cursor getCurrentCursor() const;
QHash<uint32_t, Request> m_requests;
public:
KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, QPointer<QWebSocket> webSocket,
bool instruct = false);
KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct = false);
~KateAiPluginView() override;
void setInstruct(bool instruct);
static void setAi(QPointer<AiBackend> ai){m_ai = ai;}
Q_SIGNAL void reconnect();
};