prepare for selectable backends
This commit is contained in:
@ -34,6 +34,8 @@ target_sources(
|
|||||||
PRIVATE
|
PRIVATE
|
||||||
kateai.cpp
|
kateai.cpp
|
||||||
kateaiconfigpage.cpp
|
kateaiconfigpage.cpp
|
||||||
|
backend.cpp
|
||||||
|
exllama.cpp
|
||||||
plugin.qrc
|
plugin.qrc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
65
backend.cpp
Normal file
65
backend.cpp
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
#include "backend.h"
|
||||||
|
|
||||||
|
AiBackend::Request::Request(const QString& textIn, void* userPtrIn):
|
||||||
|
text(textIn),
|
||||||
|
userPtr(userPtrIn)
|
||||||
|
{
|
||||||
|
id = idCounter++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const QString& AiBackend::Request::getText() const
|
||||||
|
{
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t AiBackend::Request::getId() const
|
||||||
|
{
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* AiBackend::Request::getUserPtr() const
|
||||||
|
{
|
||||||
|
return userPtr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AiBackend::Request::operator==(const Request& in)
|
||||||
|
{
|
||||||
|
return id == in.id;
|
||||||
|
}
|
||||||
|
|
||||||
|
AiBackend::Response::Response(QString textIn, uint32_t idIn, bool finishedIn, void* userPtrIn):
|
||||||
|
AiBackend::Request::Request(textIn, userPtrIn),
|
||||||
|
finished(finishedIn)
|
||||||
|
{
|
||||||
|
id = idIn;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AiBackend::Response::isFinished() const
|
||||||
|
{
|
||||||
|
return finished;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiBackend::Response::setUserPtr(void* ptr)
|
||||||
|
{
|
||||||
|
userPtr = ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiBackend::generate(const Request& request)
|
||||||
|
{
|
||||||
|
m_requests.insert(request.getId(), request);
|
||||||
|
generateImpl(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AiBackend::isValidId(uint32_t id)
|
||||||
|
{
|
||||||
|
return m_requests.find(id) != m_requests.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiBackend::feedResponse(Response response)
|
||||||
|
{
|
||||||
|
response.setUserPtr(m_requests[response.getId()].getUserPtr());
|
||||||
|
if(response.isFinished())
|
||||||
|
m_requests.remove(response.getId());
|
||||||
|
|
||||||
|
gotResponse(response);
|
||||||
|
}
|
70
backend.h
Normal file
70
backend.h
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QObject>
|
||||||
|
#include <QHash>
|
||||||
|
#include <QString>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <QUrl>
|
||||||
|
|
||||||
|
class AiBackend: public QObject
|
||||||
|
{
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
|
||||||
|
class backend_error: public std::runtime_error
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
backend_error(std::string err): runtime_error(err){}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Request
|
||||||
|
{
|
||||||
|
protected:
|
||||||
|
inline static uint32_t idCounter = 0;
|
||||||
|
QString text;
|
||||||
|
uint32_t id;
|
||||||
|
void* userPtr;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Request() = default;
|
||||||
|
Request(const QString& text, void* userPtr = nullptr);
|
||||||
|
const QString& getText() const;
|
||||||
|
uint32_t getId() const;
|
||||||
|
bool operator==(const Request& in);
|
||||||
|
void* getUserPtr() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Response: public Request
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
bool finished;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Response() = default;
|
||||||
|
Response(QString text, uint32_t id, bool finished, void* userPtr = nullptr);
|
||||||
|
bool isFinished() const;
|
||||||
|
void setUserPtr(void* ptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
protected:
|
||||||
|
QHash<uint32_t, Request> m_requests;
|
||||||
|
|
||||||
|
void feedResponse(Response response);
|
||||||
|
bool isValidId(uint32_t id);
|
||||||
|
virtual void generateImpl(const Request& request) = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
virtual bool ready() = 0;
|
||||||
|
void generate(const Request& request);
|
||||||
|
virtual void abort(uint64_t id){(void)id;}
|
||||||
|
virtual void open(const QUrl& url){(void)url;};
|
||||||
|
Q_SIGNAL void gotResponse(Response response);
|
||||||
|
|
||||||
|
virtual ~AiBackend() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
Q_DECLARE_METATYPE(AiBackend::Response);
|
||||||
|
Q_DECLARE_METATYPE(AiBackend::Request);
|
67
exllama.cpp
Normal file
67
exllama.cpp
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
#include "exllama.h"
|
||||||
|
|
||||||
|
#include <QJsonObject>
|
||||||
|
#include <QJsonDocument>
|
||||||
|
|
||||||
|
ExLlama::ExLlama()
|
||||||
|
{
|
||||||
|
connect(&m_webSocket, &QWebSocket::textMessageReceived, this, &ExLlama::socketMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ExLlama::ready()
|
||||||
|
{
|
||||||
|
return m_webSocket.isValid();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExLlama::socketMessage(const QString& message)
|
||||||
|
{
|
||||||
|
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
|
||||||
|
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
|
||||||
|
if(!idVal.isDouble())
|
||||||
|
{
|
||||||
|
qDebug()<<"Got invalid response on socket";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int id = idVal.toInt();
|
||||||
|
|
||||||
|
if(!isValidId(id))
|
||||||
|
{
|
||||||
|
qDebug()<<"Got unkown response id on socket";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
QJsonValue responseValue = jsonDocument[QStringLiteral("response")];
|
||||||
|
if(!responseValue.isString())
|
||||||
|
{
|
||||||
|
qDebug()<<"Got invalid response on socket";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
feedResponse(Response(responseValue.toString(), id, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExLlama::generateImpl(const Request& request)
|
||||||
|
{
|
||||||
|
QJsonObject json;
|
||||||
|
json[QStringLiteral("action")] = QStringLiteral("infer");
|
||||||
|
json[QStringLiteral("request_id")] = static_cast<double>(request.getId());
|
||||||
|
json[QStringLiteral("text")] = request.getText();
|
||||||
|
json[QStringLiteral("max_new_tokens")] = 50;
|
||||||
|
json[QStringLiteral("stream")] = false;
|
||||||
|
|
||||||
|
QJsonDocument jsonDocument(json);
|
||||||
|
QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
|
||||||
|
qDebug()<<__func__<<' '<<requestText;
|
||||||
|
m_webSocket.sendTextMessage(requestText);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExLlama::open(const QUrl& url)
|
||||||
|
{
|
||||||
|
m_webSocket.close();
|
||||||
|
m_webSocket.open(url);
|
||||||
|
}
|
||||||
|
|
||||||
|
ExLlama::~ExLlama()
|
||||||
|
{
|
||||||
|
m_webSocket.disconnect();
|
||||||
|
}
|
25
exllama.h
Normal file
25
exllama.h
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "backend.h"
|
||||||
|
|
||||||
|
#include <QUrl>
|
||||||
|
#include <QWebSocket>
|
||||||
|
|
||||||
|
class ExLlama: public AiBackend
|
||||||
|
{
|
||||||
|
Q_OBJECT
|
||||||
|
private:
|
||||||
|
QWebSocket m_webSocket;
|
||||||
|
|
||||||
|
void socketMessage(const QString& message);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual void generateImpl(const Request& request) override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ExLlama();
|
||||||
|
virtual bool ready() override;
|
||||||
|
virtual void open(const QUrl& url) override;
|
||||||
|
|
||||||
|
virtual ~ExLlama();
|
||||||
|
};
|
67
kateai.cpp
67
kateai.cpp
@ -21,6 +21,8 @@
|
|||||||
#include <KLocalizedString>
|
#include <KLocalizedString>
|
||||||
#include <KSharedConfig>
|
#include <KSharedConfig>
|
||||||
|
|
||||||
|
#include "backend.h"
|
||||||
|
#include "exllama.h"
|
||||||
#include "kateaiconfigpage.h"
|
#include "kateaiconfigpage.h"
|
||||||
|
|
||||||
K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin<KateAiPlugin>();)
|
K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin<KateAiPlugin>();)
|
||||||
@ -28,26 +30,21 @@ K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin<Ka
|
|||||||
KateAiPlugin::KateAiPlugin(QObject *parent, const QList<QVariant> &)
|
KateAiPlugin::KateAiPlugin(QObject *parent, const QList<QVariant> &)
|
||||||
: KTextEditor::Plugin(parent), m_serverUrl(QStringLiteral("ws://localhost:8642"))
|
: KTextEditor::Plugin(parent), m_serverUrl(QStringLiteral("ws://localhost:8642"))
|
||||||
{
|
{
|
||||||
connect(&m_webSocket, &QWebSocket::connected, this, &KateAiPlugin::onConnected);
|
m_ai = new ExLlama();
|
||||||
|
KateAiPluginView::setAi(m_ai);
|
||||||
readConfig();
|
readConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
void KateAiPlugin::onConnected()
|
|
||||||
{
|
|
||||||
qDebug()<<__func__<<m_webSocket.isValid();
|
|
||||||
}
|
|
||||||
|
|
||||||
KateAiPlugin::~KateAiPlugin() = default;
|
KateAiPlugin::~KateAiPlugin() = default;
|
||||||
|
|
||||||
void KateAiPlugin::reconnect()
|
void KateAiPlugin::reconnect()
|
||||||
{
|
{
|
||||||
m_webSocket.close();
|
m_ai->open(m_serverUrl);
|
||||||
m_webSocket.open(m_serverUrl);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
QObject *KateAiPlugin::createView(KTextEditor::MainWindow *mainWindow)
|
QObject *KateAiPlugin::createView(KTextEditor::MainWindow *mainWindow)
|
||||||
{
|
{
|
||||||
auto view = new KateAiPluginView(this, mainWindow, &m_webSocket);
|
auto view = new KateAiPluginView(this, mainWindow);
|
||||||
connect(view, &KateAiPluginView::reconnect, this, &KateAiPlugin::reconnect);
|
connect(view, &KateAiPluginView::reconnect, this, &KateAiPlugin::reconnect);
|
||||||
return view;
|
return view;
|
||||||
}
|
}
|
||||||
@ -67,14 +64,12 @@ void KateAiPlugin::readConfig()
|
|||||||
reconnect();
|
reconnect();
|
||||||
}
|
}
|
||||||
|
|
||||||
KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow,
|
KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct)
|
||||||
QPointer<QWebSocket> webSocket, bool instruct)
|
|
||||||
: QObject(plugin)
|
: QObject(plugin)
|
||||||
, m_mainWindow(mainwindow)
|
, m_mainWindow(mainwindow)
|
||||||
, m_webSocket(webSocket)
|
|
||||||
, m_useInstruct(instruct)
|
, m_useInstruct(instruct)
|
||||||
{
|
{
|
||||||
KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Git Blame"));
|
KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Kate Ai"));
|
||||||
setXMLFile(QStringLiteral("ui.rc"));
|
setXMLFile(QStringLiteral("ui.rc"));
|
||||||
QAction *generateAction = actionCollection()->addAction(QStringLiteral("ai_generate"));
|
QAction *generateAction = actionCollection()->addAction(QStringLiteral("ai_generate"));
|
||||||
generateAction->setText(QStringLiteral("Generate text using AI"));
|
generateAction->setText(QStringLiteral("Generate text using AI"));
|
||||||
@ -82,7 +77,7 @@ KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow
|
|||||||
m_mainWindow->guiFactory()->addClient(this);
|
m_mainWindow->guiFactory()->addClient(this);
|
||||||
|
|
||||||
connect(generateAction, &QAction::triggered, this, &KateAiPluginView::generate);
|
connect(generateAction, &QAction::triggered, this, &KateAiPluginView::generate);
|
||||||
connect(m_webSocket, &QWebSocket::textMessageReceived, this, &KateAiPluginView::socketMessage);
|
connect(m_ai, &AiBackend::gotResponse, this, &KateAiPluginView::gotResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
QPointer<KTextEditor::Document> KateAiPluginView::activeDocument() const
|
QPointer<KTextEditor::Document> KateAiPluginView::activeDocument() const
|
||||||
@ -101,29 +96,17 @@ KTextEditor::Cursor KateAiPluginView::getCurrentCursor() const
|
|||||||
return KTextEditor::Cursor();
|
return KTextEditor::Cursor();
|
||||||
}
|
}
|
||||||
|
|
||||||
void KateAiPluginView::socketMessage(const QString& message)
|
void KateAiPluginView::gotResponse(AiBackend::Response response)
|
||||||
{
|
{
|
||||||
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
|
QHash<uint32_t, Request>::iterator it = m_requests.find(response.getId());
|
||||||
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
|
|
||||||
if(!idVal.isDouble())
|
|
||||||
{
|
|
||||||
qDebug()<<"Got invalid response on socket";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
int id = idVal.toInt();
|
|
||||||
|
|
||||||
QHash<int, Request>::iterator it = m_requests.find(id);
|
|
||||||
if(it != m_requests.end())
|
if(it != m_requests.end())
|
||||||
{
|
{
|
||||||
QJsonValue responseValue = jsonDocument[QStringLiteral("response")];
|
it.value().document->insertText(it.value().cursor, response.getText());
|
||||||
if(!responseValue.isString())
|
if(response.isFinished())
|
||||||
{
|
m_requests.erase(it);
|
||||||
qDebug()<<"Got invalid response on socket";
|
else
|
||||||
return;
|
it.value().cursor = getCurrentCursor();
|
||||||
}
|
|
||||||
if(it.value().document)
|
|
||||||
it.value().document->insertText(it.value().cursor, responseValue.toString());
|
|
||||||
m_requests.erase(it);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,25 +181,15 @@ QString KateAiPluginView::assembleContext(QPointer<KTextEditor::Document> docume
|
|||||||
void KateAiPluginView::generate()
|
void KateAiPluginView::generate()
|
||||||
{
|
{
|
||||||
qDebug()<<activeDocument()->mimeType();
|
qDebug()<<activeDocument()->mimeType();
|
||||||
if(m_webSocket && m_webSocket->isValid())
|
if(m_ai->ready())
|
||||||
{
|
{
|
||||||
KTextEditor::Cursor cursor = getCurrentCursor();
|
KTextEditor::Cursor cursor = getCurrentCursor();
|
||||||
QPointer<KTextEditor::Document> document = activeDocument();
|
QPointer<KTextEditor::Document> document = activeDocument();
|
||||||
QString text = assembleContext(document, cursor);
|
QString text = assembleContext(document, cursor);
|
||||||
int id = QRandomGenerator::global()->bounded(0, std::numeric_limits<int>::max());
|
|
||||||
|
|
||||||
QJsonObject json;
|
AiBackend::Request request(text);
|
||||||
json[QStringLiteral("action")] = QStringLiteral("infer");
|
m_ai->generate(request);
|
||||||
json[QStringLiteral("request_id")] = id;
|
m_requests.insert(request.getId(), {cursor, document});
|
||||||
json[QStringLiteral("text")] = text;
|
|
||||||
json[QStringLiteral("max_new_tokens")] = 50;
|
|
||||||
json[QStringLiteral("stream")] = false;
|
|
||||||
|
|
||||||
QJsonDocument jsonDocument(json);
|
|
||||||
QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
|
|
||||||
qDebug()<<__func__<<' '<<requestText;
|
|
||||||
m_webSocket->sendTextMessage(requestText);
|
|
||||||
m_requests.insert(id, {cursor, document});
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
19
kateai.h
19
kateai.h
@ -7,22 +7,21 @@
|
|||||||
#include <KTextEditor/Plugin>
|
#include <KTextEditor/Plugin>
|
||||||
#include <KXMLGUIClient>
|
#include <KXMLGUIClient>
|
||||||
|
|
||||||
#include <QList>
|
#include <QSet>
|
||||||
#include <QAction>
|
#include <QAction>
|
||||||
#include <QWebSocket>
|
|
||||||
#include <QPointer>
|
#include <QPointer>
|
||||||
#include <QtCore>
|
#include <QtCore>
|
||||||
|
|
||||||
|
#include "backend.h"
|
||||||
|
|
||||||
class KateAiPlugin : public KTextEditor::Plugin
|
class KateAiPlugin : public KTextEditor::Plugin
|
||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
private:
|
private:
|
||||||
QWebSocket m_webSocket;
|
QPointer<AiBackend> m_ai;
|
||||||
QUrl m_serverUrl;
|
QUrl m_serverUrl;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void onConnected();
|
|
||||||
|
|
||||||
int configPages() const override
|
int configPages() const override
|
||||||
{
|
{
|
||||||
return 1;
|
return 1;
|
||||||
@ -52,24 +51,24 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
KTextEditor::MainWindow *m_mainWindow;
|
KTextEditor::MainWindow *m_mainWindow;
|
||||||
QPointer<QWebSocket> m_webSocket;
|
inline static QPointer<AiBackend> m_ai;
|
||||||
QHash<int, Request> m_requests;
|
|
||||||
bool m_useInstruct = false;
|
bool m_useInstruct = false;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void generate();
|
void generate();
|
||||||
void socketMessage(const QString& message);
|
void gotResponse(AiBackend::Response respons);
|
||||||
static QStringList getIncludePaths(const QString& text);
|
static QStringList getIncludePaths(const QString& text);
|
||||||
QString assembleContext(QPointer<KTextEditor::Document> document, const KTextEditor::Cursor& cursor);
|
QString assembleContext(QPointer<KTextEditor::Document> document, const KTextEditor::Cursor& cursor);
|
||||||
|
|
||||||
QPointer<KTextEditor::Document> activeDocument() const;
|
QPointer<KTextEditor::Document> activeDocument() const;
|
||||||
KTextEditor::Cursor getCurrentCursor() const;
|
KTextEditor::Cursor getCurrentCursor() const;
|
||||||
|
QHash<uint32_t, Request> m_requests;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, QPointer<QWebSocket> webSocket,
|
KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct = false);
|
||||||
bool instruct = false);
|
|
||||||
~KateAiPluginView() override;
|
~KateAiPluginView() override;
|
||||||
void setInstruct(bool instruct);
|
void setInstruct(bool instruct);
|
||||||
|
static void setAi(QPointer<AiBackend> ai){m_ai = ai;}
|
||||||
|
|
||||||
Q_SIGNAL void reconnect();
|
Q_SIGNAL void reconnect();
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user