prepare for selectable backends

This commit is contained in:
2023-11-02 23:12:17 +01:00
parent fc0ca71b4d
commit 8a26f9e1e4
7 changed files with 258 additions and 57 deletions

View File

@ -21,6 +21,8 @@
#include <KLocalizedString>
#include <KSharedConfig>
#include "backend.h"
#include "exllama.h"
#include "kateaiconfigpage.h"
K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin<KateAiPlugin>();)
@ -28,26 +30,21 @@ K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin<Ka
KateAiPlugin::KateAiPlugin(QObject *parent, const QList<QVariant> &)
: KTextEditor::Plugin(parent), m_serverUrl(QStringLiteral("ws://localhost:8642"))
{
connect(&m_webSocket, &QWebSocket::connected, this, &KateAiPlugin::onConnected);
m_ai = new ExLlama();
KateAiPluginView::setAi(m_ai);
readConfig();
}
void KateAiPlugin::onConnected()
{
qDebug()<<__func__<<m_webSocket.isValid();
}
KateAiPlugin::~KateAiPlugin() = default;
void KateAiPlugin::reconnect()
{
m_webSocket.close();
m_webSocket.open(m_serverUrl);
m_ai->open(m_serverUrl);
}
QObject *KateAiPlugin::createView(KTextEditor::MainWindow *mainWindow)
{
auto view = new KateAiPluginView(this, mainWindow, &m_webSocket);
auto view = new KateAiPluginView(this, mainWindow);
connect(view, &KateAiPluginView::reconnect, this, &KateAiPlugin::reconnect);
return view;
}
@ -67,14 +64,12 @@ void KateAiPlugin::readConfig()
reconnect();
}
KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow,
QPointer<QWebSocket> webSocket, bool instruct)
KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct)
: QObject(plugin)
, m_mainWindow(mainwindow)
, m_webSocket(webSocket)
, m_useInstruct(instruct)
{
KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Git Blame"));
KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Kate Ai"));
setXMLFile(QStringLiteral("ui.rc"));
QAction *generateAction = actionCollection()->addAction(QStringLiteral("ai_generate"));
generateAction->setText(QStringLiteral("Generate text using AI"));
@ -82,7 +77,7 @@ KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow
m_mainWindow->guiFactory()->addClient(this);
connect(generateAction, &QAction::triggered, this, &KateAiPluginView::generate);
connect(m_webSocket, &QWebSocket::textMessageReceived, this, &KateAiPluginView::socketMessage);
connect(m_ai, &AiBackend::gotResponse, this, &KateAiPluginView::gotResponse);
}
QPointer<KTextEditor::Document> KateAiPluginView::activeDocument() const
@ -101,29 +96,17 @@ KTextEditor::Cursor KateAiPluginView::getCurrentCursor() const
return KTextEditor::Cursor();
}
void KateAiPluginView::socketMessage(const QString& message)
void KateAiPluginView::gotResponse(AiBackend::Response response)
{
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
if(!idVal.isDouble())
{
qDebug()<<"Got invalid response on socket";
return;
}
int id = idVal.toInt();
QHash<uint32_t, Request>::iterator it = m_requests.find(response.getId());
QHash<int, Request>::iterator it = m_requests.find(id);
if(it != m_requests.end())
{
QJsonValue responseValue = jsonDocument[QStringLiteral("response")];
if(!responseValue.isString())
{
qDebug()<<"Got invalid response on socket";
return;
}
if(it.value().document)
it.value().document->insertText(it.value().cursor, responseValue.toString());
m_requests.erase(it);
it.value().document->insertText(it.value().cursor, response.getText());
if(response.isFinished())
m_requests.erase(it);
else
it.value().cursor = getCurrentCursor();
}
}
@ -198,25 +181,15 @@ QString KateAiPluginView::assembleContext(QPointer<KTextEditor::Document> docume
void KateAiPluginView::generate()
{
qDebug()<<activeDocument()->mimeType();
if(m_webSocket && m_webSocket->isValid())
if(m_ai->ready())
{
KTextEditor::Cursor cursor = getCurrentCursor();
QPointer<KTextEditor::Document> document = activeDocument();
QString text = assembleContext(document, cursor);
int id = QRandomGenerator::global()->bounded(0, std::numeric_limits<int>::max());
QJsonObject json;
json[QStringLiteral("action")] = QStringLiteral("infer");
json[QStringLiteral("request_id")] = id;
json[QStringLiteral("text")] = text;
json[QStringLiteral("max_new_tokens")] = 50;
json[QStringLiteral("stream")] = false;
QJsonDocument jsonDocument(json);
QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
qDebug()<<__func__<<' '<<requestText;
m_webSocket->sendTextMessage(requestText);
m_requests.insert(id, {cursor, document});
AiBackend::Request request(text);
m_ai->generate(request);
m_requests.insert(request.getId(), {cursor, document});
}
else
{