general ui work

This commit is contained in:
uvos 2024-06-11 14:15:08 +02:00
parent 8a26f9e1e4
commit b111e15fd5
8 changed files with 272 additions and 47 deletions

View File

@ -1,8 +1,11 @@
#include "backend.h"
#include "exllama.h"
#include <qstringliteral.h>
AiBackend::Request::Request(const QString& textIn, void* userPtrIn):
// Creates a request and gives it the next id from the class-wide counter.
//
// textIn:    text payload of the request
// typeIn:    kind of work the request describes (see Request::type_t)
// userPtrIn: opaque pointer stored alongside the request
AiBackend::Request::Request(const QString& textIn, type_t typeIn, void* userPtrIn):
	text(textIn),
	userPtr(userPtrIn),
	type(typeIn)
{
	// idCounter is a static member, so every constructed request consumes one id.
	id = idCounter++;
}
@ -27,13 +30,21 @@ bool AiBackend::Request::operator==(const Request& in)
return id == in.id;
}
AiBackend::Response::Response(QString textIn, uint32_t idIn, bool finishedIn, void* userPtrIn):
AiBackend::Request::Request(textIn, userPtrIn),
finished(finishedIn)
// Returns the type value stored in this request.
AiBackend::Request::type_t AiBackend::Request::getType() const
{
	return this->type;
}
// Builds a response for an already existing request id.
//
// textIn:     text carried by the response
// idIn:       id this response is stored under
// finishedIn: value later reported by isFinished()
// type:       request type this response belongs to
// tokensIn:   token count stored with the response
// userPtrIn:  opaque pointer stored alongside the response
AiBackend::Response::Response(QString textIn,
                              uint32_t idIn,
                              bool finishedIn,
                              AiBackend::Request::type_t type,
                              int64_t tokensIn,
                              void* userPtrIn)
	: AiBackend::Request::Request(textIn, type, userPtrIn)
	, finished(finishedIn)
	, tokens(tokensIn)
{
	// Overwrite whatever id the base-class constructor set with the caller-supplied one.
	id = idIn;
}
bool AiBackend::Response::isFinished() const
{
return finished;
@ -44,10 +55,32 @@ void AiBackend::Response::setUserPtr(void* ptr)
userPtr = ptr;
}
void AiBackend::generate(const Request& request)
// Returns the backend names offered for selection: the literal "KoboldAI"
// followed by the name reported by ExLlama::backendNameStatic().
std::vector<QString> AiBackend::getAvailableBackendNames()
{
	return {QStringLiteral("KoboldAI"), ExLlama::backendNameStatic()};
}
std::shared_ptr<AiBackend> AiBackend::createBackend(const QString& name)
{
if(name == QStringLiteral("KoboldAI"))
return nullptr;
if(name == ExLlama::backendNameStatic())
return std::shared_ptr<ExLlama>(new ExLlama());
return nullptr;
}
// Hands a request to the concrete backend implementation.
//
// Requests of type UNKOWN are rejected without calling generateImpl().
// The request is recorded in m_requests only after generateImpl() reported
// success. Returns true when the request was accepted, false otherwise.
bool AiBackend::generate(const Request& request)
{
	if(request.getType() == Request::UNKOWN)
		return false;

	const bool accepted = generateImpl(request);
	if(!accepted)
		return false;

	m_requests.insert(request.getId(), request);
	return true;
}
bool AiBackend::isValidId(uint32_t id)

View File

@ -7,6 +7,7 @@
#include <stdexcept>
#include <string>
#include <QUrl>
#include <memory>
class AiBackend: public QObject
{
@ -21,17 +22,28 @@ public:
class Request
{
public:
typedef enum {
INFERENCE,
COUNT_TOKENS,
LEFT_TRIM,
UNKOWN
} type_t;
protected:
inline static uint32_t idCounter = 0;
QString text;
uint32_t id;
void* userPtr;
type_t type;
public:
Request() = default;
Request(const QString& text, void* userPtr = nullptr);
Request(const QString& text, type_t type = UNKOWN, void* userPtr = nullptr);
const QString& getText() const;
uint32_t getId() const;
type_t getType() const;
bool operator==(const Request& in);
void* getUserPtr() const;
};
@ -40,11 +52,14 @@ public:
{
private:
bool finished;
int64_t tokens = -1;
public:
Response() = default;
Response(QString text, uint32_t id, bool finished, void* userPtr = nullptr);
Response(QString text, uint32_t id, bool finished, type_t type = UNKOWN, int64_t tokens= -1, void* userPtr = nullptr);
bool isFinished() const;
int64_t getTokens() const;
void setUserPtr(void* ptr);
};
@ -53,12 +68,15 @@ protected:
void feedResponse(Response response);
bool isValidId(uint32_t id);
virtual void generateImpl(const Request& request) = 0;
virtual bool generateImpl(const Request& request) = 0;
public:
static std::vector<QString> getAvailableBackendNames();
static std::shared_ptr<AiBackend> createBackend(const QString& name);
virtual QString backendName() = 0;
virtual bool ready() = 0;
void generate(const Request& request);
bool generate(const Request& request);
virtual void abort(uint64_t id){(void)id;}
virtual void open(const QUrl& url){(void)url;};
Q_SIGNAL void gotResponse(Response response);
@ -68,3 +86,4 @@ public:
Q_DECLARE_METATYPE(AiBackend::Response);
Q_DECLARE_METATYPE(AiBackend::Request);

View File

@ -1,7 +1,9 @@
#include "exllama.h"
#include "backend.h"
#include <QJsonObject>
#include <QJsonDocument>
#include <cstdint>
ExLlama::ExLlama()
{
@ -17,11 +19,6 @@ void ExLlama::socketMessage(const QString& message)
{
QJsonDocument jsonDocument = QJsonDocument::fromJson(message.toUtf8());
QJsonValue idVal = jsonDocument[QStringLiteral("request_id")];
if(!idVal.isDouble())
{
qDebug()<<"Got invalid response on socket";
return;
}
int id = idVal.toInt();
if(!isValidId(id))
@ -30,29 +27,90 @@ void ExLlama::socketMessage(const QString& message)
return;
}
QJsonValue responseValue = jsonDocument[QStringLiteral("response")];
Response::type_t type = strToAction(jsonDocument[QStringLiteral("action")].toString());
QJsonValue responseValue;
int64_t tokens = -1;
if(type == Response::INFERENCE)
{
responseValue = jsonDocument[QStringLiteral("response")];
}
else if(type == Response::COUNT_TOKENS)
{
responseValue = QJsonValue(QStringLiteral(""));
QJsonValue tokensValue = jsonDocument[QStringLiteral("num_tokens")];
if(!tokensValue.isDouble())
{
qDebug()<<"Got invalid response on socket, num_tokens missing or not a number";
return;
}
tokens = tokensValue.toInt();
}
else if(type == Response::LEFT_TRIM)
{
responseValue = jsonDocument[QStringLiteral("trimmed_text")];
}
if(!responseValue.isString())
{
qDebug()<<"Got invalid response on socket";
return;
}
feedResponse(Response(responseValue.toString(), id, true));
feedResponse(Response(responseValue.toString(), id, true, type, tokens));
}
void ExLlama::generateImpl(const Request& request)
// Maps a request type to the action string placed in the outgoing JSON.
// Types without a mapping produce a null QString.
const QString ExLlama::actionToStr(AiBackend::Request::type_t type)
{
	if(type == Request::INFERENCE)
		return QStringLiteral("infer");
	if(type == Request::COUNT_TOKENS)
		return QStringLiteral("estimate_token");
	if(type == Request::LEFT_TRIM)
		return QStringLiteral("lefttrim_token");

	return QString();
}
// Maps an action string back to the request type it stands for.
// Any string that is not one of the three known actions yields Request::UNKOWN.
AiBackend::Request::type_t ExLlama::strToAction(const QString& str)
{
	if(str == QStringLiteral("infer"))
		return Request::INFERENCE;

	if(str == QStringLiteral("estimate_token"))
		return Request::COUNT_TOKENS;

	if(str == QStringLiteral("lefttrim_token"))
		return Request::LEFT_TRIM;

	return Request::UNKOWN;
}
// Serializes a request to compact JSON and sends it over the web socket.
//
// request: request to send; its type selects the "action" field
// Returns false when the request type has no action string, true once the
// message has been passed to the socket (the result of sendTextMessage is
// not inspected).
bool ExLlama::generateImpl(const Request& request)
{
	QString action = actionToStr(request.getType());
	if(action.isEmpty())
		return false;

	QJsonObject json;
	json[QStringLiteral("action")] = action;
	// The id is stored as a double inside the JSON object.
	json[QStringLiteral("request_id")] = static_cast<double>(request.getId());
	json[QStringLiteral("text")] = request.getText();

	// Generation parameters are attached to inference requests only.
	if(request.getType() == Request::INFERENCE)
	{
		json[QStringLiteral("max_new_tokens")] = 50;
		json[QStringLiteral("stream")] = false;
	}

	QJsonDocument jsonDocument(json);
	QString requestText = QString::fromUtf8(jsonDocument.toJson(QJsonDocument::JsonFormat::Compact));
	qDebug()<<__func__<<' '<<requestText;
	m_webSocket.sendTextMessage(requestText);
	return true;
}
void ExLlama::open(const QUrl& url)
@ -61,6 +119,16 @@ void ExLlama::open(const QUrl& url)
m_webSocket.open(url);
}
// Returns this backend's name; forwards to the static accessor.
QString ExLlama::backendName()
{
	return ExLlama::backendNameStatic();
}
// Returns the name of this backend without needing an instance.
QString ExLlama::backendNameStatic()
{
	QString name = QStringLiteral("ExLlama");
	return name;
}
ExLlama::~ExLlama()
{
m_webSocket.disconnect();

View File

@ -13,13 +13,19 @@ private:
void socketMessage(const QString& message);
static const QString actionToStr(Request::type_t type);
static Request::type_t strToAction(const QString& str);
protected:
virtual void generateImpl(const Request& request) override;
virtual bool generateImpl(const Request& request) override;
public:
ExLlama();
virtual bool ready() override;
virtual void open(const QUrl& url) override;
virtual QString backendName() override;
static QString backendNameStatic();
virtual ~ExLlama();
};

View File

@ -8,6 +8,7 @@
#include <qdebug.h>
#include <qhash.h>
#include <qjsonobject.h>
#include <qmessagebox.h>
#include <qnamespace.h>
#include <QString>
#include <KActionCollection>
@ -18,6 +19,7 @@
#include <QTextCodec>
#include <limits>
#include <QMessageBox>
#include <QInputDialog>
#include <KLocalizedString>
#include <KSharedConfig>
@ -62,12 +64,15 @@ void KateAiPlugin::readConfig()
KConfigGroup config(KSharedConfig::openConfig(), "Ai");
m_serverUrl = QUrl(config.readEntry("Url", "ws://localhost:8642"));
reconnect();
KateAiPluginView::setInstruct(config.readEntry("Instruct", false));
KateAiPluginView::setSystemPrompt(config.readEntry("SystemPrompt", "You are an intelligent programming assistant."));
KateAiPluginView::setMaxContext(config.readEntry("Context", 1024));
}
KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct)
KateAiPluginView::KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow)
: QObject(plugin)
, m_mainWindow(mainwindow)
, m_useInstruct(instruct)
{
KXMLGUIClient::setComponentName(QStringLiteral("kateaiplugin"), QStringLiteral("Kate Ai"));
setXMLFile(QStringLiteral("ui.rc"));
@ -143,21 +148,28 @@ QStringList KateAiPluginView::getIncludePaths(const QString& text)
return paths;
}
QString KateAiPluginView::assembleContext(QPointer<KTextEditor::Document> document, const KTextEditor::Cursor& cursor)
QString KateAiPluginView::assembleContext(QPointer<KTextEditor::Document> document,
const KTextEditor::Cursor& cursor, const QString& instruction)
{
QString mime = document->mimeType();
QString context;
QString baseText;
QString documentText = document->text(KTextEditor::Range(KTextEditor::Cursor(0, 0), cursor));
if(m_useInstruct)
{
context.append(QStringLiteral("### System Prompt\n"));
context.append(m_systemPrompt);
context.push_back(u'\n');
context.append(QStringLiteral("### User Message\n"));
context.append(instruction);
context.push_back(QStringLiteral("\n### Assistant\n"));
}
if(!m_useInstruct)
baseText = document->text(KTextEditor::Range(KTextEditor::Cursor(0, 0), cursor));
else
baseText = document->text();
if(mime == QStringLiteral("text/x-c++src") || mime == QStringLiteral("text/x-csrc"))
{
QFileInfo documentFileInfo(document->url().path());
QString directory = documentFileInfo.absolutePath();
QStringList paths = getIncludePaths(baseText);
QStringList paths = getIncludePaths(documentText);
qDebug()<<__func__<<"Directory:"<<directory<<"Paths:"<<paths;
for(QString& path : paths)
@ -173,7 +185,7 @@ QString KateAiPluginView::assembleContext(QPointer<KTextEditor::Document> docume
}
}
context.append(baseText);
context.append(documentText);
return context;
}
@ -185,15 +197,30 @@ void KateAiPluginView::generate()
{
KTextEditor::Cursor cursor = getCurrentCursor();
QPointer<KTextEditor::Document> document = activeDocument();
QString text = assembleContext(document, cursor);
QString instruction;
AiBackend::Request request(text);
m_ai->generate(request);
m_requests.insert(request.getId(), {cursor, document});
if(m_useInstruct)
{
bool ok;
instruction = QInputDialog::getMultiLineText(m_mainWindow->activeView(), i18n("Input instruction"),
i18n("Instruction"), m_lastInstruct, &ok);
if(!ok)
return;
}
QString text = assembleContext(document, cursor, instruction);
AiBackend::Request request(text, AiBackend::Request::INFERENCE);
bool ret = m_ai->generate(request);
if(!ret)
QMessageBox::warning(m_mainWindow->activeView(), i18n("Failure"),
i18n("The Ai backend was unable to process this request"));
else
m_requests.insert(request.getId(), {cursor, document});
}
else
{
QMessageBox box;
QMessageBox box(m_mainWindow->activeView());
box.setText(i18n("The AI server is not connected."));
box.setInformativeText(i18n("would you like to try and reconnect?"));
box.setStandardButtons(QMessageBox::Yes | QMessageBox::No);
@ -214,5 +241,20 @@ void KateAiPluginView::setInstruct(bool instruct)
m_useInstruct = instruct;
}
// Stores the system prompt in the static member m_systemPrompt.
void KateAiPluginView::setSystemPrompt(const QString& in)
{
	KateAiPluginView::m_systemPrompt = in;
}
// Stores the backend pointer in the static member m_ai.
void KateAiPluginView::setAi(QPointer<AiBackend> ai)
{
	KateAiPluginView::m_ai = ai;
}
// Stores the maximum context value in the static member m_maxContext.
void KateAiPluginView::setMaxContext(int maxContext)
{
	KateAiPluginView::m_maxContext = maxContext;
}
#include "kateai.moc"
#include "moc_kateai.cpp"

View File

@ -52,23 +52,28 @@ private:
KTextEditor::MainWindow *m_mainWindow;
inline static QPointer<AiBackend> m_ai;
bool m_useInstruct = false;
inline static bool m_useInstruct = false;
inline static QString m_systemPrompt;
inline static int m_maxContext;
inline static QString m_lastInstruct;
private:
void generate();
void gotResponse(AiBackend::Response respons);
static QStringList getIncludePaths(const QString& text);
QString assembleContext(QPointer<KTextEditor::Document> document, const KTextEditor::Cursor& cursor);
QString assembleContext(QPointer<KTextEditor::Document> document, const KTextEditor::Cursor& cursor, const QString& instruction = QString());
QPointer<KTextEditor::Document> activeDocument() const;
KTextEditor::Cursor getCurrentCursor() const;
QHash<uint32_t, Request> m_requests;
public:
KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow, bool instruct = false);
KateAiPluginView(KateAiPlugin *plugin, KTextEditor::MainWindow *mainwindow);
~KateAiPluginView() override;
void setInstruct(bool instruct);
static void setAi(QPointer<AiBackend> ai){m_ai = ai;}
static void setInstruct(bool instruct);
static void setSystemPrompt(const QString& in);
static void setMaxContext(int maxContext);
static void setAi(QPointer<AiBackend> ai);
Q_SIGNAL void reconnect();
};

View File

@ -7,6 +7,9 @@
#include <QVBoxLayout>
#include <QLabel>
#include <qlabel.h>
#include <qradiobutton.h>
#include "backend.h"
KateAiConfigPage::KateAiConfigPage(QWidget *parent, KateAiPlugin *plugin)
: KTextEditor::ConfigPage(parent)
@ -15,19 +18,42 @@ KateAiConfigPage::KateAiConfigPage(QWidget *parent, KateAiPlugin *plugin)
QVBoxLayout* layout = new QVBoxLayout(this);
layout->setContentsMargins(0, 0, 0, 0);
std::vector<QString> backends = AiBackend::getAvailableBackendNames();
for(const QString& name : backends)
cmbxServerType.addItem(name);
layout->addWidget(&cmbxServerType);
QHBoxLayout* lineLayout = new QHBoxLayout(this);
QLabel* lineEditLabel = new QLabel(i18n("Url for the WebSockets ExLlama Ai server:"), this);
QLabel* lineEditLabel = new QLabel(i18n("Url for the Ai server:"), this);
lineEditLabel->setSizePolicy(QSizePolicy(QSizePolicy::Expanding, QSizePolicy::Fixed));
lineLayout->addWidget(lineEditLabel);
lineLayout->addWidget(&lineUrl);
layout->addLayout(lineLayout);
QHBoxLayout* contextSpinLayout = new QHBoxLayout(this);
QLabel* contextSpinLabel = new QLabel(i18n("Maximum context:"), this);
contextSpinLayout->addWidget(contextSpinLabel);
contextSpinBox.setMinimum(100);
contextSpinBox.setMaximum(32768);
contextSpinLayout->addWidget(&contextSpinBox);
layout->addLayout(contextSpinLayout);
btnCompletion.setText(i18n("Use the Ai to generate a completion"));
btnInstruct.setText(i18n("Use the Ai to insert a response to a instruction"));
btnInstruct.setText(i18n("Use the Ai to insert a response to an instruction"));
layout->addWidget(&btnCompletion);
layout->addWidget(&btnInstruct);
QHBoxLayout* systemPromptLayout = new QHBoxLayout(this);
systemPromptLabel.setText(i18n("System Prompt:"));
systemPromptLabel.setSizePolicy(QSizePolicy(QSizePolicy::Expanding, QSizePolicy::Fixed));
systemPromptLayout->addWidget(&systemPromptLabel);
systemPromptLayout->addWidget(&lineSystemPrompt);
layout->addLayout(systemPromptLayout);
layout->addStretch();
connect(&btnInstruct, &QRadioButton::toggled, this, &KateAiConfigPage::instructBtnToggeled);
reset();
}
@ -51,6 +77,9 @@ void KateAiConfigPage::apply()
KConfigGroup config(KSharedConfig::openConfig(), "Ai");
config.writeEntry("Url", lineUrl.text());
config.writeEntry("Instruct", btnInstruct.isChecked());
config.writeEntry("SystemPrompt", lineSystemPrompt.text());
config.writeEntry("Context", contextSpinBox.value());
config.writeEntry("Backend", cmbxServerType.currentText());
config.sync();
m_plugin->readConfig();
@ -59,7 +88,19 @@ void KateAiConfigPage::apply()
// Reloads every widget on the page from the "Ai" config group.
//
// The fallback values for "Url", "SystemPrompt", "Instruct" and "Context"
// are the same ones KateAiPlugin::readConfig() uses.
void KateAiConfigPage::reset()
{
	KConfigGroup config(KSharedConfig::openConfig(), "Ai");

	// Each field gets its own default: the URL field the server address,
	// the system prompt field the default prompt.
	lineUrl.setText(config.readEntry("Url", "ws://localhost:8642"));
	lineSystemPrompt.setText(config.readEntry("SystemPrompt", "You are an intelligent programming assistant."));

	btnInstruct.setChecked(config.readEntry("Instruct", false));
	btnCompletion.setChecked(!btnInstruct.isChecked());
	contextSpinBox.setValue(config.readEntry("Context", 1024));
	cmbxServerType.setCurrentText(config.readEntry("Backend", "KoboldAI"));

	// The system prompt widgets are only enabled while instruct mode is selected.
	const bool instructSelected = btnInstruct.isChecked();
	lineSystemPrompt.setEnabled(instructSelected);
	systemPromptLabel.setEnabled(instructSelected);
}
void KateAiConfigPage::instructBtnToggeled(bool checked)
{
lineSystemPrompt.setEnabled(checked);
systemPromptLabel.setEnabled(checked);
}

View File

@ -5,16 +5,27 @@
#include <QLineEdit>
#include <QRadioButton>
#include <QComboBox>
#include <QSpinBox>
#include <QLabel>
class KateAiConfigPage : public KTextEditor::ConfigPage
{
Q_OBJECT
private:
QLineEdit lineUrl;
QLabel systemPromptLabel;
QLineEdit lineSystemPrompt;
QRadioButton btnCompletion;
QRadioButton btnInstruct;
QComboBox cmbxServerType;
QSpinBox contextSpinBox;
KateAiPlugin* m_plugin;
private:
void instructBtnToggeled(bool checked);
public:
explicit KateAiConfigPage(QWidget *parent = nullptr, KateAiPlugin *plugin = nullptr);
~KateAiConfigPage() override