kateai/backend.cpp
2024-06-11 14:15:08 +02:00

99 lines
1.9 KiB
C++

#include "backend.h"
#include "exllama.h"
#include <memory>
#include <qstringliteral.h>
// Constructs a request carrying the prompt text, its kind and an opaque
// caller-owned pointer, then stamps it with the next value of idCounter.
AiBackend::Request::Request(const QString& textIn, type_t typeIn, void* userPtrIn)
	: text(textIn)
	, userPtr(userPtrIn)
	, type(typeIn)
{
	// Post-increment: this request takes the current counter value.
	this->id = idCounter++;
}
// Returns a reference to the stored request text; valid while this object lives.
const QString& AiBackend::Request::getText() const
{
	return this->text;
}
uint32_t AiBackend::Request::getId() const
{
return id;
}
void* AiBackend::Request::getUserPtr() const
{
return userPtr;
}
// Two requests compare equal when their ids match; text, type and user
// pointer take no part in the comparison.
bool AiBackend::Request::operator==(const Request& in)
{
	const bool sameId = (this->id == in.id);
	return sameId;
}
// Returns the kind of request this object represents.
AiBackend::Request::type_t AiBackend::Request::getType() const
{
	return this->type;
}
// Constructs a (possibly partial) response to an earlier request.
//   textIn     - generated text carried by this response
//   idIn       - id of the request being answered
//   finishedIn - true when no further responses will follow for this id
//   type       - request type forwarded to the Request base
//   tokensIn   - token count reported by the backend
//   userPtrIn  - opaque user pointer forwarded to the Request base
AiBackend::Response::Response(QString textIn, uint32_t idIn, bool finishedIn,
	AiBackend::Request::type_t type, int64_t tokensIn, void* userPtrIn):
	// Name the base *type*; the previous `AiBackend::Request::Request` spelling
	// denotes the constructor itself and is not portable in a mem-initializer.
	AiBackend::Request(textIn, type, userPtrIn),
	finished(finishedIn),
	tokens(tokensIn)
{
	// The Request base constructor draws a fresh id from idCounter; a response
	// must instead carry the id of the request it answers, so overwrite it.
	// NOTE(review): each Response therefore still consumes one counter value.
	id = idIn;
}
bool AiBackend::Response::isFinished() const
{
return finished;
}
void AiBackend::Response::setUserPtr(void* ptr)
{
userPtr = ptr;
}
std::vector<QString> AiBackend::getAvailableBackendNames()
{
return {QStringLiteral("KoboldAI"), ExLlama::backendNameStatic()};
}
std::shared_ptr<AiBackend> AiBackend::createBackend(const QString& name)
{
if(name == QStringLiteral("KoboldAI"))
return nullptr;
if(name == ExLlama::backendNameStatic())
return std::shared_ptr<ExLlama>(new ExLlama());
return nullptr;
}
// Submits a request to the concrete backend via generateImpl().
// Requests of type UNKOWN are rejected without reaching the backend. On success
// the request is remembered in m_requests so later responses can be matched to it.
// Returns whatever generateImpl() reported (false for rejected requests).
bool AiBackend::generate(const Request& request)
{
	const bool typeKnown = request.getType() != Request::UNKOWN;
	if(!typeKnown)
		return false;

	const bool accepted = generateImpl(request);
	// NOTE(review): the request is recorded only after generateImpl() returns;
	// if an implementation feeds responses synchronously, confirm that is intended.
	if(accepted)
		m_requests.insert(request.getId(), request);
	return accepted;
}
// Returns true while a request with this id is outstanding, i.e. it was
// accepted by generate() and no finished response for it has been fed yet.
bool AiBackend::isValidId(uint32_t id)
{
	// m_requests is a Qt associative container (insert(key, value)/remove(key)
	// are used on it). contains() is a read-only lookup, unlike the non-const
	// find() previously used, which may detach an implicitly shared container.
	return m_requests.contains(id);
}
// Called by concrete backends to deliver a (partial) response.
// Restores the user pointer of the originating request onto the response,
// forgets the request once its final response arrives, then hands the
// response to gotResponse().
void AiBackend::feedResponse(Response response)
{
	const uint32_t requestId = response.getId();

	// value() does not insert, unlike operator[]. An unknown id therefore no
	// longer leaves a default-constructed phantom entry in m_requests, which
	// would otherwise linger and make isValidId() report it as valid.
	// For an unknown id the user pointer of a default-constructed Request is
	// used, matching the previous behaviour.
	response.setUserPtr(m_requests.value(requestId).getUserPtr());

	if(response.isFinished())
		m_requests.remove(requestId);

	gotResponse(response);
}