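// Connection layer for two text-generation backends: a local KoboldCpp-compatible
// server (streamed over SSE) and the distributed AI Horde service (submit + poll).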
import Lock from "@common/lock";
import SSE from "@common/sse";
import { delay, throttle } from "@common/utils";

interface IBaseConnection {
  instruct: string;
}

interface IKoboldConnection extends IBaseConnection {
  url: string;
}

interface IHordeConnection extends IBaseConnection {
  apiKey?: string;
  model: string;
}

export const isKoboldConnection = (obj: unknown): obj is IKoboldConnection => (
  obj != null && typeof obj === 'object' && 'url' in obj && typeof obj.url === 'string'
);

export const isHordeConnection = (obj: unknown): obj is IHordeConnection => (
  obj != null && typeof obj === 'object' && 'model' in obj && typeof obj.model === 'string'
);

export type IConnection = IKoboldConnection | IHordeConnection;

interface IHordeWorker {
  id: string;
  models: string[];
  flagged: boolean;
  online: boolean;
  maintenance_mode: boolean;
  max_context_length: number;
  max_length: number;
  performance: string;
}

export interface IHordeModel {
  name: string;
  hordeNames: string[];
  maxLength: number;
  maxContext: number;
  workers: string[];
}

interface IHordeResult {
  faulted: boolean;
  done: boolean;
  finished: number;
  generations?: {
    text: string;
  }[];
}

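// Baseline sampler settings sent with every request; callers can override any of
// them through IGenerationSettings.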
const DEFAULT_GENERATION_SETTINGS = {
  temperature: 0.8,
  min_p: 0.1,
  rep_pen: 1.08,
  rep_pen_range: -1,
  rep_pen_slope: 0.7,
  top_k: 100,
  top_p: 0.92,
  banned_tokens: ['anticipat'],
  max_length: 300,
  trim_stop: true,
  stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
  dry_allowed_length: 5,
  dry_multiplier: 0.8,
  dry_base: 1,
  dry_sequence_breakers: ["\n", ":", "\"", "*"],
  dry_penalty_last_n: 0
};

const MIN_PERFORMANCE = 5.0;
const MIN_WORKER_CONTEXT = 8192;
const MAX_HORDE_LENGTH = 512;
const MAX_HORDE_CONTEXT = 32000;
export const HORDE_ANON_KEY = '0000000000';

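// Collapse a backend-reported model name to a canonical key: strip path prefixes,
// quantization/format suffixes (GGUF, EXL2, bpw, ...) and context-size tags so
// different builds of the same model compare equal.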
export const normalizeModel = (model: string) => {
  let currentModel = model.split(/[\\\/]/).at(-1);
  currentModel = currentModel.split('::').at(0);
  let normalizedModel: string;

  do {
    normalizedModel = currentModel;

    currentModel = currentModel
      .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
      .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
      .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
      .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
      .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
      .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '') // remove float precision suffix, i.e. fp16
      .replace(/^(debug-?)+/i, '') // remove debug prefix
      .trim();
  } while (normalizedModel !== currentModel);

  return normalizedModel
    .replace(/[ _-]+/ig, '-')
    .replace(/\.{2,}/, '-')
    .replace(/[ ._-]+$/ig, '')
    .trim();
};

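// Rough whitespace-based token estimate, used by countTokens as a fallback when
// the backend cannot report an exact count.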
export const approximateTokens = (prompt: string): number =>
  Math.round(prompt.split(/\s+/).length * 0.75);

export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;

export namespace Connection {
  const AIHORDE = 'https://aihorde.net';

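  // Stream tokens from a KoboldCpp-compatible /api/extra/generate/stream endpoint,
  // yielding each token as it arrives over SSE.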
  async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
    const sse = new SSE(`${url}/api/extra/generate/stream`, {
      payload: JSON.stringify({
        ...DEFAULT_GENERATION_SETTINGS,
        ...extraSettings,
        prompt,
      }),
    });

    const messages: string[] = [];
    const messageLock = new Lock();
    let end = false;

    sse.addEventListener('message', (e) => {
      if (e.data) {
        const { token, finish_reason } = JSON.parse(e.data);
        messages.push(token);

        if (finish_reason && finish_reason !== 'null') {
          end = true;
        }
      }
      messageLock.release();
    });

    const handleEnd = () => {
      end = true;
      messageLock.release();
    };

    sse.addEventListener('error', handleEnd);
    sse.addEventListener('abort', handleEnd);
    sse.addEventListener('readystatechange', (e) => {
      if (e.readyState === SSE.CLOSED) handleEnd();
    });

    while (!end || messages.length) {
      while (messages.length > 0) {
        const message = messages.shift();
        if (message != null) {
          try {
            yield message;
          } catch { }
        }
      }
      if (!end) {
        await messageLock.wait();
      }
    }

    sse.close();
  }

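  // Submit an asynchronous generation request to AI Horde, then poll the status
  // endpoint until the job completes and return the full text in one piece.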
  async function generateHorde(connection: Omit<IHordeConnection, keyof IBaseConnection>, prompt: string, extraSettings: IGenerationSettings = {}): Promise<string> {
    const models = await getHordeModels();
    const model = models.get(connection.model);
    if (model) {
      let maxLength = Math.min(model.maxLength, DEFAULT_GENERATION_SETTINGS.max_length);
      if (extraSettings.max_length && extraSettings.max_length < maxLength) {
        maxLength = extraSettings.max_length;
      }
      const requestData = {
        prompt,
        params: {
          ...DEFAULT_GENERATION_SETTINGS,
          ...extraSettings,
          n: 1,
          max_context_length: model.maxContext,
          max_length: maxLength,
          rep_pen_range: Math.min(model.maxContext, 4096),
        },
        models: model.hordeNames,
        workers: model.workers,
      };

      const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
        method: 'POST',
        body: JSON.stringify(requestData),
        headers: {
          'Content-Type': 'application/json',
          apikey: connection.apiKey || HORDE_ANON_KEY,
        },
      });

      if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
        throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
      }

      const { id } = await generateResponse.json() as { id: string };
      const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' })
        .catch(e => console.error('Error deleting request', e));

      while (true) {
        await delay(2500);

        const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
        if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
          deleteRequest();
          throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
        }

        const result: IHordeResult = await retrieveResponse.json();

        if (result.done && result.generations?.length === 1) {
          const { text } = result.generations[0];

          return text;
        }
      }
    }

    throw new Error(`Model ${connection.model} is offline`);
  }

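  // Unified entry point: streams tokens from a Kobold backend, or yields the single
  // completed Horde generation.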
  export async function* generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) {
    if (isKoboldConnection(connection)) {
      yield* generateKobold(connection.url, prompt, extraSettings);
    } else if (isHordeConnection(connection)) {
      yield await generateHorde(connection, prompt, extraSettings);
    }
  }

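  // Fetch the current AI Horde text workers, drop offline/flagged/slow ones, and
  // group the rest into models keyed by normalized name.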
  async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
    try {
      const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
      if (response.ok) {
        const workers: IHordeWorker[] = await response.json();
        const goodWorkers = workers.filter(w =>
          w.online
          && !w.maintenance_mode
          && !w.flagged
          && w.max_context_length >= MIN_WORKER_CONTEXT
          && parseFloat(w.performance) >= MIN_PERFORMANCE
        );

        const models = new Map<string, IHordeModel>();

        for (const worker of goodWorkers) {
          for (const modelName of worker.models) {
            const normName = normalizeModel(modelName.toLowerCase());
            let model = models.get(normName);
            if (!model) {
              model = {
                hordeNames: [],
                maxContext: MAX_HORDE_CONTEXT,
                maxLength: MAX_HORDE_LENGTH,
                name: normName,
                workers: []
              };
            }

            if (!model.hordeNames.includes(modelName)) {
              model.hordeNames.push(modelName);
            }
            if (!model.workers.includes(worker.id)) {
              model.workers.push(worker.id);
            }

            model.maxContext = Math.min(model.maxContext, worker.max_context_length);
            model.maxLength = Math.min(model.maxLength, worker.max_length);

            models.set(normName, model);
          }
        }

        return models;
      }
    } catch (e) {
      console.error(e);
    }

    return new Map();
  }

  export const getHordeModels = throttle(requestHordeModels, 10000);

  export async function getModelName(connection: IConnection): Promise<string> {
    if (isKoboldConnection(connection)) {
      try {
        const response = await fetch(`${connection.url}/api/v1/model`);
        if (response.ok) {
          const { result } = await response.json();
          return result;
        }
      } catch (e) {
        console.log('Error getting model name', e);
      }
    } else if (isHordeConnection(connection)) {
      return connection.model;
    }

    return '';
  }

  export async function getContextLength(connection: IConnection): Promise<number> {
    if (isKoboldConnection(connection)) {
      try {
        const response = await fetch(`${connection.url}/api/extra/true_max_context_length`);
        if (response.ok) {
          const { value } = await response.json();
          return value;
        }
      } catch (e) {
        console.log('Error getting context length', e);
      }
    } else if (isHordeConnection(connection)) {
      const models = await getHordeModels();
      const model = models.get(connection.model);
      if (model) {
        return model.maxContext;
      }
    }

    return 0;
  }

  export async function countTokens(connection: IConnection, prompt: string) {
    if (isKoboldConnection(connection)) {
      try {
        const response = await fetch(`${connection.url}/api/extra/tokencount`, {
          body: JSON.stringify({ prompt }),
          headers: { 'Content-Type': 'application/json' },
          method: 'POST',
        });
        if (response.ok) {
          const { value } = await response.json();
          return value;
        }
      } catch (e) {
        console.log('Error counting tokens', e);
      }
    }

    return approximateTokens(prompt);
  }
}
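// Usage sketch (illustrative only; the URL and instruct value below are placeholder
// assumptions, not defaults defined in this module):
//
//   const kobold: IConnection = { instruct: 'chatml', url: 'http://localhost:5001' };
//   let output = '';
//   for await (const token of Connection.generate(kobold, 'Once upon a time')) {
//     output += token;
//   }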