// File: tsgames/src/games/ai/connection.ts (TypeScript, 352 lines, 12 KiB)
// Repository page counters captured alongside the file: 1, 0, Fork 0

import Lock from "@common/lock";
import SSE from "@common/sse";
import { delay, throttle } from "@common/utils";
/** Fields common to every connection kind (`IKoboldConnection` and `IHordeConnection` both extend this). */
interface IBaseConnection {
// Not read anywhere in this file.
// NOTE(review): presumably identifies the instruct/prompt template to use — confirm against callers.
instruct: string;
}
/**
 * Connection addressed by a base URL.
 * `isKoboldConnection` recognises it purely by the presence of a string `url` property.
 */
interface IKoboldConnection extends IBaseConnection {
// Base URL; the `Connection` functions append paths such as `/api/v1/model`
// and `/api/extra/generate/stream` to it.
url: string;
}
/**
 * Connection addressed by a model name.
 * `isHordeConnection` recognises it purely by the presence of a string `model` property.
 */
interface IHordeConnection extends IBaseConnection {
// Optional API key; `generateHorde` falls back to `HORDE_ANON_KEY` when it is missing or empty.
apiKey?: string;
// Key looked up in the map returned by `Connection.getHordeModels`; those keys are
// produced by `normalizeModel` applied to lowercased worker model names.
model: string;
}
/**
 * Type guard for `IKoboldConnection`.
 *
 * @param obj - Any value.
 * @returns `true` when `obj` is a non-null object that has a string `url` property.
 */
export const isKoboldConnection = (obj: unknown): obj is IKoboldConnection => {
    // Reject null/undefined and every non-object before probing properties.
    if (obj == null || typeof obj !== 'object') {
        return false;
    }
    if (!('url' in obj)) {
        return false;
    }
    return typeof obj.url === 'string';
};
/**
 * Type guard for `IHordeConnection`.
 *
 * @param obj - Any value.
 * @returns `true` when `obj` is a non-null object that has a string `model` property.
 */
export const isHordeConnection = (obj: unknown): obj is IHordeConnection => {
    // Reject null/undefined and every non-object before probing properties.
    if (obj == null || typeof obj !== 'object') {
        return false;
    }
    if (!('model' in obj)) {
        return false;
    }
    return typeof obj.model === 'string';
};
/**
 * Any supported connection. There is no discriminant field: callers tell the two
 * apart with `isKoboldConnection` / `isHordeConnection`, and the `Connection`
 * functions test the Kobold guard first.
 */
export type IConnection = IKoboldConnection | IHordeConnection;
/**
 * Shape assumed for each entry of the `/api/v2/workers?type=text` response, as
 * consumed by `requestHordeModels`. Only the fields read by this file are declared.
 */
interface IHordeWorker {
// Collected into `IHordeModel.workers` for every model the worker serves.
id: string;
// Model names served by this worker; each is lowercased and passed through `normalizeModel`.
models: string[];
// Workers with `flagged` set are skipped.
flagged: boolean;
// Workers that are not `online` are skipped.
online: boolean;
// Workers with `maintenance_mode` set are skipped.
maintenance_mode: boolean;
// Must be at least `MIN_WORKER_CONTEXT`; also lowers the model's `maxContext`.
max_context_length: number;
// Lowers the model's `maxLength`.
max_length: number;
// Parsed with `parseFloat` and compared against `MIN_PERFORMANCE`.
// NOTE(review): unit is not visible here — confirm against the API documentation.
performance: string;
}
/** Aggregated view of one normalized model across all accepted workers (built by `requestHordeModels`). */
export interface IHordeModel {
// Normalized model name; also the key of the models map.
name: string;
// Distinct original model names (as reported by workers) that normalize to `name`.
hordeNames: string[];
// Minimum of `MAX_HORDE_LENGTH` and every serving worker's `max_length`.
maxLength: number;
// Minimum of `MAX_HORDE_CONTEXT` and every serving worker's `max_context_length`.
maxContext: number;
// Distinct ids of the workers serving this model.
workers: string[];
}
/**
 * Shape assumed for the `/api/v2/generate/text/status/{id}` response polled by
 * `generateHorde`. The polling loop returns once `done` is set and exactly one
 * generation is present.
 */
interface IHordeResult {
// NOTE(review): presumably set when the request failed on the server side — confirm against the API documentation.
faulted: boolean;
// Set once the request has completed.
done: boolean;
// NOTE(review): presumably the number of finished generations — confirm against the API documentation.
finished: number;
// Generated outputs; may be absent in a response.
generations?: {
text: string;
}[];
}
/**
 * Baseline generation parameters. Both backends spread this object into the
 * request first and then spread the caller's `extraSettings` over it, so any
 * key here can be overridden per call. Its shape also defines
 * `IGenerationSettings` (via `Partial<typeof ...>`).
 *
 * NOTE(review): the meaning of the individual sampler parameters is defined by
 * the backend API, not by this file — consult the backend documentation before tuning.
 */
const DEFAULT_GENERATION_SETTINGS = {
temperature: 0.8,
min_p: 0.1,
rep_pen: 1.08,
// The Horde path replaces this with `Math.min(model.maxContext, 4096)`.
rep_pen_range: -1,
rep_pen_slope: 0.7,
top_k: 100,
top_p: 0.92,
banned_tokens: ['anticipat'],
// Upper bound on generated length; the Horde path additionally caps it by the model's `maxLength`.
max_length: 300,
trim_stop: true,
stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
dry_allowed_length: 5,
dry_multiplier: 0.8,
dry_base: 1,
dry_sequence_breakers: ["\n", ":", "\"", "*"],
dry_penalty_last_n: 0
}
// Workers whose parsed `performance` value is below this are ignored by `requestHordeModels`.
const MIN_PERFORMANCE = 5.0;
// Workers whose `max_context_length` is below this are ignored by `requestHordeModels`.
const MIN_WORKER_CONTEXT = 8192;
// Starting (maximum) value of `IHordeModel.maxLength` before it is lowered by worker limits.
const MAX_HORDE_LENGTH = 512;
// Starting (maximum) value of `IHordeModel.maxContext` before it is lowered by worker limits.
const MAX_HORDE_CONTEXT = 32000;
// API key sent by `generateHorde` when the connection has no `apiKey`.
export const HORDE_ANON_KEY = '0000000000';
/**
 * Reduces a model identifier to a canonical base name so that differently
 * packaged builds of the same model compare equal.
 *
 * Steps:
 * 1. Keep only the last path segment (split on `/` or `\`) and drop anything
 *    after the first `::`.
 * 2. Repeatedly strip packaging suffixes (context length, quantisation name and
 *    size, gguf/ggml markers, bpw, float precision, leading `debug-`) until the
 *    string stops changing.
 * 3. Collapse runs of spaces/underscores/dashes and every run of two or more
 *    dots into a single `-`, then strip trailing separators.
 *
 * Letter case is preserved; callers that need case-insensitive keys must
 * lowercase the input themselves.
 *
 * @param model - Raw model name, optionally with a path prefix or `::` suffix.
 * @returns The normalized model name.
 */
export const normalizeModel = (model: string): string => {
    // `split` always yields at least one element; the `??` fallbacks only
    // satisfy strict null checking.
    const pathParts = model.split(/[\\\/]/);
    const baseName = pathParts[pathParts.length - 1] ?? model;
    let currentModel = baseName.split('::')[0] ?? baseName;

    let normalizedModel: string;
    do {
        normalizedModel = currentModel;
        // Each replace removes only its first match, so the whole chain is
        // re-run until a pass leaves the string unchanged.
        currentModel = currentModel
            .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
            .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
            .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
            .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
            .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
            .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '') // remove float precision (fp16, float32, ...)
            .replace(/^(debug-?)+/i, '') // remove leading debug- prefixes
            .trim();
    } while (normalizedModel !== currentModel);

    return normalizedModel
        .replace(/[ _-]+/ig, '-')
        .replace(/\.{2,}/g, '-') // global: collapse every run of dots, not just the first
        .replace(/[ ._-]+$/ig, '')
        .trim();
};
/**
 * Rough token estimate used when no backend tokenizer is available.
 *
 * Splits the prompt on runs of whitespace and multiplies the number of pieces
 * by 0.75, rounding to the nearest integer. Empty pieces produced by leading or
 * trailing whitespace are counted, so an empty prompt yields 1.
 *
 * NOTE(review): common rules of thumb put tokens at more than one per word;
 * confirm that multiplying (rather than dividing) by 0.75 is intended.
 *
 * @param prompt - Text to estimate.
 * @returns Estimated token count.
 */
export const approximateTokens = (prompt: string): number => {
    const pieceCount = prompt.split(/\s+/).length;
    return Math.round(pieceCount * 0.75);
};
/**
 * Per-call overrides for `DEFAULT_GENERATION_SETTINGS`: every key is optional
 * and, when present, replaces the default of the same name.
 */
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
/**
 * Backend-agnostic helpers for text generation, model discovery and token
 * counting. Each exported function accepts an `IConnection` and dispatches on
 * `isKoboldConnection` / `isHordeConnection`.
 */
export namespace Connection {
    /** Base URL of the AI Horde API. */
    const AIHORDE = 'https://aihorde.net';

    /**
     * Streams generated tokens from `${url}/api/extra/generate/stream`.
     *
     * Incoming SSE messages are queued by the event listeners and drained by the
     * generator loop; `messageLock` is used to park the loop until a listener
     * signals that something changed (new message, error, abort or close).
     *
     * The SSE connection is closed in a `finally` block, so it is released both
     * on normal completion and when the consumer stops iterating early.
     *
     * @param url - Base URL of the backend.
     * @param prompt - Prompt text to send.
     * @param extraSettings - Overrides applied on top of `DEFAULT_GENERATION_SETTINGS`.
     * @returns An async generator yielding each received token.
     */
    async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
        const sse = new SSE(`${url}/api/extra/generate/stream`, {
            payload: JSON.stringify({
                ...DEFAULT_GENERATION_SETTINGS,
                ...extraSettings,
                prompt,
            }),
        });

        const messages: string[] = [];
        const messageLock = new Lock();
        let end = false;

        sse.addEventListener('message', (e) => {
            if (e.data) {
                const { token, finish_reason } = JSON.parse(e.data);
                messages.push(token);
                // The backend may send the literal string 'null' instead of a real null.
                if (finish_reason && finish_reason !== 'null') {
                    end = true;
                }
            }
            // Wake the generator loop even for empty events.
            messageLock.release();
        });

        const handleEnd = () => {
            end = true;
            messageLock.release();
        };

        sse.addEventListener('error', handleEnd);
        sse.addEventListener('abort', handleEnd);
        sse.addEventListener('readystatechange', (e) => {
            if (e.readyState === SSE.CLOSED) handleEnd();
        });

        try {
            // Keep going until the stream has ended AND the queue is drained.
            while (!end || messages.length) {
                while (messages.length > 0) {
                    const message = messages.shift();
                    if (message != null) {
                        try {
                            yield message;
                        } catch {
                            // Errors thrown into the generator by the consumer are
                            // deliberately ignored so the remaining queue is still drained.
                        }
                    }
                }
                if (!end) {
                    await messageLock.wait();
                }
            }
        } finally {
            // Runs on normal completion and on early `return()` by the consumer,
            // so the underlying stream is never left open.
            sse.close();
        }
    }

    /**
     * Runs a single generation through the AI Horde asynchronous API: submits
     * the request, then polls its status every 2.5 seconds until it completes.
     *
     * @param connection - Horde connection data (`model` and optional `apiKey`).
     * @param prompt - Prompt text to send.
     * @param extraSettings - Overrides applied on top of `DEFAULT_GENERATION_SETTINGS`;
     *   `max_length` can only lower the effective limit.
     * @returns The generated text.
     * @throws Error when the model is not currently available, when submitting or
     *   polling returns a non-OK response, or when the request is reported as faulted.
     */
    async function generateHorde(connection: Omit<IHordeConnection, keyof IBaseConnection>, prompt: string, extraSettings: IGenerationSettings = {}): Promise<string> {
        const models = await getHordeModels();
        const model = models.get(connection.model);
        if (!model) {
            throw new Error(`Model ${connection.model} is offline`);
        }

        // Effective length: the smallest of the model limit, the default and the caller's override.
        let maxLength = Math.min(model.maxLength, DEFAULT_GENERATION_SETTINGS.max_length);
        if (extraSettings.max_length && extraSettings.max_length < maxLength) {
            maxLength = extraSettings.max_length;
        }

        const requestData = {
            prompt,
            params: {
                ...DEFAULT_GENERATION_SETTINGS,
                ...extraSettings,
                // The keys below intentionally come last so they cannot be overridden.
                n: 1,
                max_context_length: model.maxContext,
                max_length: maxLength,
                rep_pen_range: Math.min(model.maxContext, 4096),
            },
            models: model.hordeNames,
            workers: model.workers,
        };

        const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
            method: 'POST',
            body: JSON.stringify(requestData),
            headers: {
                'Content-Type': 'application/json',
                apikey: connection.apiKey || HORDE_ANON_KEY,
            },
        });
        // `Response.ok` is true exactly for 2xx statuses.
        if (!generateResponse.ok) {
            throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
        }

        const { id } = await generateResponse.json() as { id: string };
        const statusUrl = `${AIHORDE}/api/v2/generate/text/status/${id}`;

        // Best-effort cancellation; failures are only logged.
        const deleteRequest = () => fetch(statusUrl, { method: 'DELETE' })
            .catch(e => console.error('Error deleting request', e));

        while (true) {
            await delay(2500);

            const retrieveResponse = await fetch(statusUrl);
            if (!retrieveResponse.ok) {
                void deleteRequest();
                throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
            }

            const result: IHordeResult = await retrieveResponse.json();
            if (result.faulted) {
                // A faulted request will never complete; stop polling instead of looping forever.
                throw new Error(`Generation faulted for model ${connection.model}`);
            }
            if (result.done && result.generations?.length === 1) {
                const { text } = result.generations[0];
                return text;
            }
        }
    }

    /**
     * Generates text for `prompt` using whichever backend `connection` describes.
     *
     * Kobold connections stream token by token; Horde connections yield the whole
     * result as a single item. A connection matching neither guard yields nothing.
     *
     * @param connection - Target connection.
     * @param prompt - Prompt text to send.
     * @param extraSettings - Overrides applied on top of `DEFAULT_GENERATION_SETTINGS`.
     */
    export async function* generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) {
        if (isKoboldConnection(connection)) {
            yield* generateKobold(connection.url, prompt, extraSettings);
        } else if (isHordeConnection(connection)) {
            yield await generateHorde(connection, prompt, extraSettings);
        }
    }

    /**
     * Fetches the list of text workers and groups the acceptable ones by
     * normalized model name.
     *
     * A worker is accepted when it is online, not in maintenance mode, not
     * flagged, offers at least `MIN_WORKER_CONTEXT` context and reports a
     * performance of at least `MIN_PERFORMANCE`.
     *
     * @returns Map from normalized model name to its aggregated data; an empty
     *   map when the request fails or returns a non-OK response.
     */
    async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
        try {
            const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
            if (response.ok) {
                const workers: IHordeWorker[] = await response.json();
                const goodWorkers = workers.filter(w =>
                    w.online
                    && !w.maintenance_mode
                    && !w.flagged
                    && w.max_context_length >= MIN_WORKER_CONTEXT
                    && parseFloat(w.performance) >= MIN_PERFORMANCE
                );

                const models = new Map<string, IHordeModel>();
                for (const worker of goodWorkers) {
                    for (const modelName of worker.models) {
                        const normName = normalizeModel(modelName.toLowerCase());

                        let model = models.get(normName);
                        if (!model) {
                            // First worker for this model: start from the global caps.
                            model = {
                                hordeNames: [],
                                maxContext: MAX_HORDE_CONTEXT,
                                maxLength: MAX_HORDE_LENGTH,
                                name: normName,
                                workers: [],
                            };
                            models.set(normName, model);
                        }

                        if (!model.hordeNames.includes(modelName)) {
                            model.hordeNames.push(modelName);
                        }
                        if (!model.workers.includes(worker.id)) {
                            model.workers.push(worker.id);
                        }

                        // Limits are the minimum across all workers serving the model.
                        model.maxContext = Math.min(model.maxContext, worker.max_context_length);
                        model.maxLength = Math.min(model.maxLength, worker.max_length);
                    }
                }
                return models;
            }
        } catch (e) {
            console.error(e);
        }
        return new Map();
    }

    /**
     * `requestHordeModels` wrapped with `throttle(…, 10000)`.
     * NOTE(review): presumably limits refetching to once per 10 s and shares the
     * result between callers — confirm against `@common/utils`.
     */
    export const getHordeModels = throttle(requestHordeModels, 10000);

    /**
     * Resolves the name of the model behind `connection`.
     *
     * @param connection - Target connection.
     * @returns For Kobold, the `result` field of `/api/v1/model`; for Horde, the
     *   configured model name; an empty string when it cannot be determined.
     */
    export async function getModelName(connection: IConnection): Promise<string> {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/v1/model`);
                if (response.ok) {
                    const { result } = await response.json();
                    return result;
                }
            } catch (e) {
                console.log('Error getting model name', e);
            }
        } else if (isHordeConnection(connection)) {
            return connection.model;
        }
        return '';
    }

    /**
     * Resolves the maximum context length available through `connection`.
     *
     * @param connection - Target connection.
     * @returns For Kobold, the `value` field of `/api/extra/true_max_context_length`;
     *   for Horde, the model's aggregated `maxContext`; `0` when it cannot be determined.
     */
    export async function getContextLength(connection: IConnection): Promise<number> {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/extra/true_max_context_length`);
                if (response.ok) {
                    const { value } = await response.json();
                    return value;
                }
            } catch (e) {
                console.log('Error getting max tokens', e);
            }
        } else if (isHordeConnection(connection)) {
            const models = await getHordeModels();
            const model = models.get(connection.model);
            if (model) {
                return model.maxContext;
            }
        }
        return 0;
    }

    /**
     * Counts the tokens in `prompt`.
     *
     * Kobold connections ask the backend via `/api/extra/tokencount`; every other
     * case, and any failure, falls back to `approximateTokens`.
     *
     * @param connection - Target connection.
     * @param prompt - Text to count.
     * @returns The token count reported by the backend, or the approximation.
     */
    export async function countTokens(connection: IConnection, prompt: string) {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/extra/tokencount`, {
                    body: JSON.stringify({ prompt }),
                    headers: { 'Content-Type': 'application/json' },
                    method: 'POST',
                });
                if (response.ok) {
                    const { value } = await response.json();
                    return value;
                }
            } catch (e) {
                console.log('Error counting tokens', e);
            }
        }
        return approximateTokens(prompt);
    }
}