diff --git a/backend.cpp b/backend.cpp index eb91cc2..d86d436 100644 --- a/backend.cpp +++ b/backend.cpp @@ -1,8 +1,11 @@ #include "backend.h" +#include "exllama.h" +#include -AiBackend::Request::Request(const QString& textIn, void* userPtrIn): +AiBackend::Request::Request(const QString& textIn, type_t typeIn, void* userPtrIn): text(textIn), -userPtr(userPtrIn) +userPtr(userPtrIn), +type(typeIn) { id = idCounter++; } @@ -27,13 +30,21 @@ bool AiBackend::Request::operator==(const Request& in) return id == in.id; } -AiBackend::Response::Response(QString textIn, uint32_t idIn, bool finishedIn, void* userPtrIn): -AiBackend::Request::Request(textIn, userPtrIn), -finished(finishedIn) +AiBackend::Request::type_t AiBackend::Request::getType() const +{ + return type; +} + +AiBackend::Response::Response(QString textIn, uint32_t idIn, bool finishedIn, + AiBackend::Request::type_t type, int64_t tokensIn, void* userPtrIn): +AiBackend::Request::Request(textIn, type, userPtrIn), +finished(finishedIn), +tokens(tokensIn) { id = idIn; } + bool AiBackend::Response::isFinished() const { return finished; @@ -44,10 +55,32 @@ void AiBackend::Response::setUserPtr(void* ptr) userPtr = ptr; } -void AiBackend::generate(const Request& request) +std::vector AiBackend::getAvailableBackendNames() { - m_requests.insert(request.getId(), request); - generateImpl(request); + return {QStringLiteral("KoboldAI"), ExLlama::backendNameStatic()}; +} + +std::shared_ptr AiBackend::createBackend(const QString& name) +{ + if(name == QStringLiteral("KoboldAI")) + return nullptr; + if(name == ExLlama::backendNameStatic()) + return std::shared_ptr(new ExLlama()); + + return nullptr; +} + +bool AiBackend::generate(const Request& request) +{ + if(request.getType() == Request::UNKOWN) + return false; + + bool ret = generateImpl(request); + + if(ret) + m_requests.insert(request.getId(), request); + + return ret; } bool AiBackend::isValidId(uint32_t id) diff --git a/backend.h b/backend.h index e0ab609..88b9d83 100644 
--- a/backend.h +++ b/backend.h @@ -7,6 +7,7 @@ #include #include #include +#include class AiBackend: public QObject { @@ -21,17 +22,28 @@ public: class Request { + public: + + typedef enum { + INFERENCE, + COUNT_TOKENS, + LEFT_TRIM, + UNKOWN + } type_t; + protected: inline static uint32_t idCounter = 0; QString text; uint32_t id; void* userPtr; + type_t type; public: Request() = default; - Request(const QString& text, void* userPtr = nullptr); + Request(const QString& text, type_t type = UNKOWN, void* userPtr = nullptr); const QString& getText() const; uint32_t getId() const; + type_t getType() const; bool operator==(const Request& in); void* getUserPtr() const; }; @@ -40,11 +52,14 @@ public: { private: bool finished; + int64_t tokens = -1; public: Response() = default; - Response(QString text, uint32_t id, bool finished, void* userPtr = nullptr); + Response(QString text, uint32_t id, bool finished, type_t type = UNKOWN, int64_t tokens= -1, void* userPtr = nullptr); bool isFinished() const; + int64_t getTokens() const; + void setUserPtr(void* ptr); }; @@ -53,12 +68,15 @@ protected: void feedResponse(Response response); bool isValidId(uint32_t id); - virtual void generateImpl(const Request& request) = 0; + virtual bool generateImpl(const Request& request) = 0; public: + static std::vector getAvailableBackendNames(); + static std::shared_ptr createBackend(const QString& name); + virtual QString backendName() = 0; virtual bool ready() = 0; - void generate(const Request& request); + bool generate(const Request& request); virtual void abort(uint64_t id){(void)id;} virtual void open(const QUrl& url){(void)url;}; Q_SIGNAL void gotResponse(Response response); @@ -68,3 +86,4 @@ public: Q_DECLARE_METATYPE(AiBackend::Response); Q_DECLARE_METATYPE(AiBackend::Request); + diff --git a/exllama.cpp b/exllama.cpp index 6d8abab..4afc27c 100644 --- a/exllama.cpp +++ b/exllama.cpp @@ -1,7 +1,9 @@ #include "exllama.h" +#include "backend.h" #include #include +#include 
ExLlama::ExLlama() { @@ -17,11 +19,6 @@ void ExLlama::socketMessage(const QString& message) { QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8()); QJsonValue idVal = jsonDocument[QStringLiteral("request_id")]; - if(!idVal.isDouble()) - { - qDebug()<<"Got invalid response on socket"; - return; - } int id = idVal.toInt(); if(!isValidId(id)) @@ -30,29 +27,90 @@ void ExLlama::socketMessage(const QString& message) return; } - QJsonValue responseValue = jsonDocument[QStringLiteral("response")]; + Response::type_t type = strToAction(jsonDocument[QStringLiteral("action")].toString()); + QJsonValue responseValue; + int64_t tokens = -1; + + if(type == Response::INFERENCE) + { + responseValue = jsonDocument[QStringLiteral("response")]; + } + else if(type == Response::COUNT_TOKENS) + { + responseValue = QJsonValue(QStringLiteral("")); + QJsonValue tokensValue = jsonDocument[QStringLiteral("num_tokens")]; + if(!tokensValue.isDouble()) + { + qDebug()<<"Got invalid response on socket, num_tokens missing or not a number"; + return; + } + tokens = tokensValue.toInt(); + } + else if(type == Response::LEFT_TRIM) + { + responseValue = jsonDocument[QStringLiteral("trimmed_text")]; + } + if(!responseValue.isString()) { qDebug()<<"Got invalid response on socket"; return; } - feedResponse(Response(responseValue.toString(), id, true)); + feedResponse(Response(responseValue.toString(), id, true, type, tokens)); } -void ExLlama::generateImpl(const Request& request) +const QString ExLlama::actionToStr(AiBackend::Request::type_t type) { + switch(type) + { + case Request::INFERENCE: + return QStringLiteral("infer"); + case Request::COUNT_TOKENS: + return QStringLiteral("estimate_token"); + case Request::LEFT_TRIM: + return QStringLiteral("lefttrim_token"); + default: + return QString(); + } +} + +AiBackend::Request::type_t ExLlama::strToAction(const QString& str) +{ + if(str == QStringLiteral("infer")) + return Request::INFERENCE; + else if(str == 
QStringLiteral("estimate_token")) + return Request::COUNT_TOKENS; + else if(str == QStringLiteral("lefttrim_token")) + return Request::LEFT_TRIM; + else + return Request::UNKOWN; +} + +bool ExLlama::generateImpl(const Request& request) +{ + QString action = actionToStr(request.getType()); + + if(action.isEmpty()) + return false; + QJsonObject json; - json[QStringLiteral("action")] = QStringLiteral("infer"); + json[QStringLiteral("action")] = action; json[QStringLiteral("request_id")] = static_cast(request.getId()); json[QStringLiteral("text")] = request.getText(); - json[QStringLiteral("max_new_tokens")] = 50; - json[QStringLiteral("stream")] = false; + + if(request.getType() == Request::INFERENCE) + { + json[QStringLiteral("max_new_tokens")] = 50; + json[QStringLiteral("stream")] = false; + } QJsonDocument jsonDocument(json); QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact)); qDebug()<<__func__<<' '< #include #include +#include #include #include #include @@ -18,6 +19,7 @@ #include #include #include +#include #include #include @@ -62,12 +64,15 @@ void KateAiPlugin::readConfig() KConfigGroup config(KSharedConfig::openConfig(), "Ai"); m_serverUrl = QUrl(config.readEntry("Url", "ws://localhost:8642")); reconnect(); + + KateAiPluginView::setInstruct(config.readEntry("Instruct", false)); + KateAiPluginView::setSystemPrompt(config.readEntry("SystemPrompt", "You are an intelligent programming assistant.")); + KateAiPluginView::setMaxContext(config.readEntry("Context", 1024)); } -KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct) +KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow) : QObject(plugin) , m_mainWindow(mainwindow) - , m_useInstruct(instruct) { KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Kate Ai")); setXMLFile(QStringLiteral("ui.rc")); @@ -143,21 +148,28 @@ QStringList 
KateAiPluginView::getIncludePaths(const QString& text) return paths; } -QString KateAiPluginView::assembleContext(QPointer document, const KTextEditor::Cursor& cursor) +QString KateAiPluginView::assembleContext(QPointer document, + const KTextEditor::Cursor& cursor, const QString& instruction) { QString mime = document->mimeType(); QString context; - QString baseText; + QString documentText = document->text(KTextEditor::Range(KTextEditor::Cursor(0, 0), cursor)); + + if(m_useInstruct) + { + context.append(QStringLiteral("### System Prompt\n")); + context.append(m_systemPrompt); + context.push_back(u'\n'); + context.append(QStringLiteral("### User Message\n")); + context.append(instruction); + context.push_back(QStringLiteral("\n### Assistant\n")); + } - if(!m_useInstruct) - baseText = document->text(KTextEditor::Range(KTextEditor::Cursor(0, 0), cursor)); - else - baseText = document->text(); if(mime == QStringLiteral("text/x-c++src") || mime == QStringLiteral("text/x-csrc")) { QFileInfo documentFileInfo(document->url().path()); QString directory = documentFileInfo.absolutePath(); - QStringList paths = getIncludePaths(baseText); + QStringList paths = getIncludePaths(documentText); qDebug()<<__func__<<"Directory:"< docume } } - context.append(baseText); + context.append(documentText); return context; } @@ -185,15 +197,31 @@ void KateAiPluginView::generate() { KTextEditor::Cursor cursor = getCurrentCursor(); QPointer document = activeDocument(); - QString text = assembleContext(document, cursor); + QString instruction; - AiBackend::Request request(text); - m_ai->generate(request); - m_requests.insert(request.getId(), {cursor, document}); + if(m_useInstruct) + { + bool ok; + instruction = QInputDialog::getMultiLineText(m_mainWindow->activeView(), i18n("Input instruction"), + i18n("Instruction"), m_lastInstruct, &ok); + if(!ok) + return; + m_lastInstruct = instruction; /* remember the accepted instruction; it is the default text of the next dialog */ + } + + QString text = assembleContext(document, cursor, instruction); + + AiBackend::Request request(text, 
AiBackend::Request::INFERENCE); + bool ret = m_ai->generate(request); + if(!ret) + QMessageBox::warning(m_mainWindow->activeView(), i18n("Failure"), + i18n("The Ai backend was unable to process this request")); + else + m_requests.insert(request.getId(), {cursor, document}); } else { - QMessageBox box; + QMessageBox box(m_mainWindow->activeView()); box.setText(i18n("The AI server is not connected.")); box.setInformativeText(i18n("would you like to try and reconnect?")); box.setStandardButtons(QMessageBox::Yes | QMessageBox::No); @@ -214,5 +242,20 @@ void KateAiPluginView::setInstruct(bool instruct) m_useInstruct = instruct; } +void KateAiPluginView::setSystemPrompt(const QString& in) +{ + m_systemPrompt = in; +} + +void KateAiPluginView::setAi(QPointer ai) +{ + m_ai = ai; +} + +void KateAiPluginView::setMaxContext(int maxContext) +{ + m_maxContext = maxContext; +} + #include "kateai.moc" #include "moc_kateai.cpp" diff --git a/kateai.h b/kateai.h index 100ce92..29d417f 100644 --- a/kateai.h +++ b/kateai.h @@ -52,23 +52,28 @@ private: KTextEditor::MainWindow *m_mainWindow; inline static QPointer m_ai; - bool m_useInstruct = false; + inline static bool m_useInstruct = false; + inline static QString m_systemPrompt; + inline static int m_maxContext; + inline static QString m_lastInstruct; private: void generate(); void gotResponse(AiBackend::Response respons); static QStringList getIncludePaths(const QString& text); - QString assembleContext(QPointer document, const KTextEditor::Cursor& cursor); + QString assembleContext(QPointer document, const KTextEditor::Cursor& cursor, const QString& instruction = QString()); QPointer activeDocument() const; KTextEditor::Cursor getCurrentCursor() const; QHash m_requests; public: - KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct = false); + KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow); ~KateAiPluginView() override; - void setInstruct(bool instruct); - static 
void setAi(QPointer ai){m_ai = ai;} + static void setInstruct(bool instruct); + static void setSystemPrompt(const QString& in); + static void setMaxContext(int maxContext); + static void setAi(QPointer ai); Q_SIGNAL void reconnect(); }; diff --git a/kateaiconfigpage.cpp b/kateaiconfigpage.cpp index ed1d92a..df8b149 100644 --- a/kateaiconfigpage.cpp +++ b/kateaiconfigpage.cpp @@ -7,6 +7,9 @@ #include #include #include +#include + +#include "backend.h" KateAiConfigPage::KateAiConfigPage(QWidget *parent, KateAiPlugin *plugin) : KTextEditor::ConfigPage(parent) @@ -15,19 +18,42 @@ KateAiConfigPage::KateAiConfigPage(QWidget *parent, KateAiPlugin *plugin) QVBoxLayout* layout = new QVBoxLayout(this); layout->setContentsMargins(0, 0, 0, 0); + std::vector backends = AiBackend::getAvailableBackendNames(); + for(const QString& name : backends) + cmbxServerType.addItem(name); + + layout->addWidget(&cmbxServerType); + QHBoxLayout* lineLayout = new QHBoxLayout(this); - QLabel* lineEditLabel = new QLabel(i18n("Url for the WebSockets ExLlama Ai server:"), this); + QLabel* lineEditLabel = new QLabel(i18n("Url for the Ai server:"), this); lineEditLabel->setSizePolicy(QSizePolicy(QSizePolicy::Expanding, QSizePolicy::Fixed)); lineLayout->addWidget(lineEditLabel); lineLayout->addWidget(&lineUrl); layout->addLayout(lineLayout); + QHBoxLayout* contextSpinLayout = new QHBoxLayout(this); + QLabel* contextSpinLabel = new QLabel(i18n("Maximum context:"), this); + contextSpinLayout->addWidget(contextSpinLabel); + contextSpinBox.setMinimum(100); + contextSpinBox.setMaximum(32768); + contextSpinLayout->addWidget(&contextSpinBox); + layout->addLayout(contextSpinLayout); + btnCompletion.setText(i18n("Use the Ai to generate a completion")); - btnInstruct.setText(i18n("Use the Ai to insert a response to a instruction")); + btnInstruct.setText(i18n("Use the Ai to insert a response to an instruction")); layout->addWidget(&btnCompletion); layout->addWidget(&btnInstruct); + + QHBoxLayout* 
systemPromptLayout = new QHBoxLayout(this); + systemPromptLabel.setText(i18n("System Prompt:")); + systemPromptLabel.setSizePolicy(QSizePolicy(QSizePolicy::Expanding, QSizePolicy::Fixed)); + systemPromptLayout->addWidget(&systemPromptLabel); + systemPromptLayout->addWidget(&lineSystemPrompt); + layout->addLayout(systemPromptLayout); layout->addStretch(); + connect(&btnInstruct, &QRadioButton::toggled, this, &KateAiConfigPage::instructBtnToggeled); + reset(); } @@ -51,6 +77,9 @@ void KateAiConfigPage::apply() KConfigGroup config(KSharedConfig::openConfig(), "Ai"); config.writeEntry("Url", lineUrl.text()); config.writeEntry("Instruct", btnInstruct.isChecked()); + config.writeEntry("SystemPrompt", lineSystemPrompt.text()); + config.writeEntry("Context", contextSpinBox.value()); + config.writeEntry("Backend", cmbxServerType.currentText()); config.sync(); m_plugin->readConfig(); @@ -59,7 +88,19 @@ void KateAiConfigPage::apply() void KateAiConfigPage::reset() { KConfigGroup config(KSharedConfig::openConfig(), "Ai"); - lineUrl.setText(config.readEntry("Url", "ws://localhost:8642")); + lineSystemPrompt.setText(config.readEntry("SystemPrompt", "You are an intelligent programming assistant.")); + lineUrl.setText(config.readEntry("Url", "ws://localhost:8642")); /* fallback values mirror KateAiPlugin::readConfig() */ btnInstruct.setChecked(config.readEntry("Instruct", false)); btnCompletion.setChecked(!btnInstruct.isChecked()); + contextSpinBox.setValue(config.readEntry("Context", 1024)); + cmbxServerType.setCurrentText(config.readEntry("Backend", "KoboldAI")); + + lineSystemPrompt.setEnabled(btnInstruct.isChecked()); + systemPromptLabel.setEnabled(btnInstruct.isChecked()); +} + +void KateAiConfigPage::instructBtnToggeled(bool checked) +{ + lineSystemPrompt.setEnabled(checked); + systemPromptLabel.setEnabled(checked); } diff --git a/kateaiconfigpage.h b/kateaiconfigpage.h index ab179f4..1ce6e65 100644 --- a/kateaiconfigpage.h +++ b/kateaiconfigpage.h @@ -5,16 +5,27 @@ #include #include +#include +#include +#include class 
KateAiConfigPage : public KTextEditor::ConfigPage { Q_OBJECT private: QLineEdit lineUrl; + QLabel systemPromptLabel; + QLineEdit lineSystemPrompt; QRadioButton btnCompletion; QRadioButton btnInstruct; + QComboBox cmbxServerType; + QSpinBox contextSpinBox; KateAiPlugin* m_plugin; +private: + + void instructBtnToggeled(bool checked); + public: explicit KateAiConfigPage(QWidget *parent = nullptr, KateAiPlugin *plugin = nullptr); ~KateAiConfigPage() override