diff --git a/CMakeLists.txt b/CMakeLists.txt index d43067c..f5c66ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,8 @@ target_sources( PRIVATE kateai.cpp kateaiconfigpage.cpp + backend.cpp + exllama.cpp plugin.qrc ) diff --git a/backend.cpp b/backend.cpp new file mode 100644 index 0000000..eb91cc2 --- /dev/null +++ b/backend.cpp @@ -0,0 +1,65 @@ +#include "backend.h" + +AiBackend::Request::Request(const QString& textIn, void* userPtrIn): +text(textIn), +userPtr(userPtrIn) +{ + id = idCounter++; +} + +const QString& AiBackend::Request::getText() const +{ + return text; +} + +uint32_t AiBackend::Request::getId() const +{ + return id; +} + +void* AiBackend::Request::getUserPtr() const +{ + return userPtr; +} + +bool AiBackend::Request::operator==(const Request& in) +{ + return id == in.id; +} + +AiBackend::Response::Response(QString textIn, uint32_t idIn, bool finishedIn, void* userPtrIn): +AiBackend::Request::Request(textIn, userPtrIn), +finished(finishedIn) +{ + id = idIn; +} + +bool AiBackend::Response::isFinished() const +{ + return finished; +} + +void AiBackend::Response::setUserPtr(void* ptr) +{ + userPtr = ptr; +} + +void AiBackend::generate(const Request& request) +{ + m_requests.insert(request.getId(), request); + generateImpl(request); +} + +bool AiBackend::isValidId(uint32_t id) +{ + return m_requests.find(id) != m_requests.end(); +} + +void AiBackend::feedResponse(Response response) +{ + response.setUserPtr(m_requests[response.getId()].getUserPtr()); + if(response.isFinished()) + m_requests.remove(response.getId()); + + gotResponse(response); +} diff --git a/backend.h b/backend.h new file mode 100644 index 0000000..e0ab609 --- /dev/null +++ b/backend.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +class AiBackend: public QObject +{ + Q_OBJECT +public: + + class backend_error: public std::runtime_error + { + public: + backend_error(std::string err): runtime_error(err){} + }; + + class Request + { + protected: + inline static uint32_t idCounter = 0; + QString text; + uint32_t id; + void* userPtr; + + public: + Request() = default; + Request(const QString& text, void* userPtr = nullptr); + const QString& getText() const; + uint32_t getId() const; + bool operator==(const Request& in); + void* getUserPtr() const; + }; + + class Response: public Request + { + private: + bool finished; + + public: + Response() = default; + Response(QString text, uint32_t id, bool finished, void* userPtr = nullptr); + bool isFinished() const; + void setUserPtr(void* ptr); + }; + +protected: + QHash m_requests; + + void feedResponse(Response response); + bool isValidId(uint32_t id); + virtual void generateImpl(const Request& request) = 0; + +public: + + virtual bool ready() = 0; + void generate(const Request& request); + virtual void abort(uint64_t id){(void)id;} + virtual void open(const QUrl& url){(void)url;}; + Q_SIGNAL void gotResponse(Response response); + + virtual ~AiBackend() = default; +}; + +Q_DECLARE_METATYPE(AiBackend::Response); +Q_DECLARE_METATYPE(AiBackend::Request); diff --git a/exllama.cpp b/exllama.cpp new file mode 100644 index 0000000..6d8abab --- /dev/null +++ b/exllama.cpp @@ -0,0 +1,67 @@ +#include "exllama.h" + +#include +#include + +ExLlama::ExLlama() +{ + connect(&m_webSocket, &QWebSocket::textMessageReceived, this, &ExLlama::socketMessage); +} + +bool ExLlama::ready() +{ + return m_webSocket.isValid(); +} + +void ExLlama::socketMessage(const QString& message) +{ + QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8()); + QJsonValue idVal = jsonDocument[QStringLiteral("request_id")]; + if(!idVal.isDouble()) + { + qDebug()<<"Got invalid response on socket"; + return; + } + int id = idVal.toInt(); + + if(!isValidId(id)) + { + qDebug()<<"Got unkown response id on socket"; + return; + } + + QJsonValue responseValue = jsonDocument[QStringLiteral("response")]; + if(!responseValue.isString()) + { + qDebug()<<"Got invalid response on socket"; + return; + } + + feedResponse(Response(responseValue.toString(), id, true)); +} + +void ExLlama::generateImpl(const Request& request) +{ + QJsonObject json; + json[QStringLiteral("action")] = QStringLiteral("infer"); + json[QStringLiteral("request_id")] = static_cast(request.getId()); + json[QStringLiteral("text")] = request.getText(); + json[QStringLiteral("max_new_tokens")] = 50; + json[QStringLiteral("stream")] = false; + + QJsonDocument jsonDocument(json); + QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact)); + qDebug()<<__func__<<' '< +#include + +class ExLlama: public AiBackend +{ + Q_OBJECT +private: + QWebSocket m_webSocket; + + void socketMessage(const QString& message); + +protected: + virtual void generateImpl(const Request& request) override; + +public: + ExLlama(); + virtual bool ready() override; + virtual void open(const QUrl& url) override; + + virtual ~ExLlama(); +}; diff --git a/kateai.cpp b/kateai.cpp index 58c7cfe..b6a8621 100644 --- a/kateai.cpp +++ b/kateai.cpp @@ -21,6 +21,8 @@ #include #include +#include "backend.h" +#include "exllama.h" #include "kateaiconfigpage.h" K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin();) @@ -28,26 +30,21 @@ K_PLUGIN_FACTORY_WITH_JSON(KateAiPluginFactory, "kateai.json", registerPlugin &) : KTextEditor::Plugin(parent), m_serverUrl(QStringLiteral("ws://localhost:8642")) { - connect(&m_webSocket, &QWebSocket::connected, this, &KateAiPlugin::onConnected); + m_ai = new ExLlama(); + KateAiPluginView::setAi(m_ai); readConfig(); } -void KateAiPlugin::onConnected() -{ - qDebug()<<__func__<open(m_serverUrl); } QObject *KateAiPlugin::createView(KTextEditor::MainWindow *mainWindow) { - auto view = new KateAiPluginView(this, mainWindow, &m_webSocket); + auto view = new KateAiPluginView(this, mainWindow); connect(view, &KateAiPluginView::reconnect, this, &KateAiPlugin::reconnect); return view; } @@ -67,14 +64,12 @@ void KateAiPlugin::readConfig() reconnect(); } -KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, - QPointer webSocket, bool instruct) +KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct) : QObject(plugin) , m_mainWindow(mainwindow) - , m_webSocket(webSocket) , m_useInstruct(instruct) { - KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Git Blame")); + KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Kate Ai")); setXMLFile(QStringLiteral("ui.rc")); QAction *generateAction = actionCollection()->addAction(QStringLiteral("ai_generate")); generateAction->setText(QStringLiteral("Generate text using AI")); @@ -82,7 +77,7 @@ KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow m_mainWindow->guiFactory()->addClient(this); connect(generateAction, &QAction::triggered, this, &KateAiPluginView::generate); - connect(m_webSocket, &QWebSocket::textMessageReceived, this, &KateAiPluginView::socketMessage); + connect(m_ai, &AiBackend::gotResponse, this, &KateAiPluginView::gotResponse); } QPointer KateAiPluginView::activeDocument() const @@ -101,29 +96,17 @@ KTextEditor::Cursor KateAiPluginView::getCurrentCursor() const return KTextEditor::Cursor(); } -void KateAiPluginView::socketMessage(const QString& message) +void KateAiPluginView::gotResponse(AiBackend::Response response) { - QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8()); - QJsonValue idVal = jsonDocument[QStringLiteral("request_id")]; - if(!idVal.isDouble()) - { - qDebug()<<"Got invalid response on socket"; - return; - } - int id = idVal.toInt(); + QHash::iterator it = m_requests.find(response.getId()); - QHash::iterator it = m_requests.find(id); if(it != m_requests.end()) { - QJsonValue responseValue = jsonDocument[QStringLiteral("response")]; - if(!responseValue.isString()) - { - qDebug()<<"Got invalid response on socket"; - return; - } - if(it.value().document) - it.value().document->insertText(it.value().cursor, responseValue.toString()); - m_requests.erase(it); + it.value().document->insertText(it.value().cursor, response.getText()); + if(response.isFinished()) + m_requests.erase(it); + else + it.value().cursor = getCurrentCursor(); } } @@ -198,25 +181,15 @@ QString KateAiPluginView::assembleContext(QPointer docume void KateAiPluginView::generate() { qDebug()<mimeType(); - if(m_webSocket && m_webSocket->isValid()) + if(m_ai->ready()) { KTextEditor::Cursor cursor = getCurrentCursor(); QPointer document = activeDocument(); QString text = assembleContext(document, cursor); - int id = QRandomGenerator::global()->bounded(0, std::numeric_limits::max()); - QJsonObject json; - json[QStringLiteral("action")] = QStringLiteral("infer"); - json[QStringLiteral("request_id")] = id; - json[QStringLiteral("text")] = text; - json[QStringLiteral("max_new_tokens")] = 50; - json[QStringLiteral("stream")] = false; - - QJsonDocument jsonDocument(json); - QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact)); - qDebug()<<__func__<<' '<sendTextMessage(requestText); - m_requests.insert(id, {cursor, document}); + AiBackend::Request request(text); + m_ai->generate(request); + m_requests.insert(request.getId(), {cursor, document}); } else { diff --git a/kateai.h b/kateai.h index bbc6f94..100ce92 100644 --- a/kateai.h +++ b/kateai.h @@ -7,22 +7,21 @@ #include #include -#include +#include #include -#include #include #include +#include "backend.h" + class KateAiPlugin : public KTextEditor::Plugin { Q_OBJECT private: - QWebSocket m_webSocket; + QPointer m_ai; QUrl m_serverUrl; private: - void onConnected(); - int configPages() const override { return 1; @@ -52,24 +51,24 @@ private: }; KTextEditor::MainWindow *m_mainWindow; - QPointer m_webSocket; - QHash m_requests; + inline static QPointer m_ai; bool m_useInstruct = false; private: void generate(); - void socketMessage(const QString& message); + void gotResponse(AiBackend::Response respons); static QStringList getIncludePaths(const QString& text); QString assembleContext(QPointer document, const KTextEditor::Cursor& cursor); QPointer activeDocument() const; KTextEditor::Cursor getCurrentCursor() const; + QHash m_requests; public: - KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, QPointer webSocket, - bool instruct = false); + KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct = false); ~KateAiPluginView() override; void setInstruct(bool instruct); + static void setAi(QPointer ai){m_ai = ai;} Q_SIGNAL void reconnect(); };