diff --git a/src/games/ai-story/tools/connection.ts b/src/games/ai-story/tools/connection.ts index 996a33a..2c31b20 100644 --- a/src/games/ai-story/tools/connection.ts +++ b/src/games/ai-story/tools/connection.ts @@ -161,6 +161,12 @@ export namespace Connection { if (extraSettings.max_length && extraSettings.max_length < maxLength) { maxLength = extraSettings.max_length; } + const baseTemperature = extraSettings.temperature ?? DEFAULT_GENERATION_SETTINGS.temperature; + let currentTemperature = baseTemperature; + const MAX_TEMPERATURE = 2.0; + const TEMP_INCREMENT = 0.15; + const RECOVERY_LENGTH = 16; + const requestData = { prompt, params: { @@ -170,11 +176,13 @@ export namespace Connection { max_context_length: model.maxContext, max_length: maxLength, rep_pen_range: Math.min(model.maxContext, 4096), + temperature: currentTemperature, }, models: model.hordeNames, workers: model.workers, }; const bannedTokens = requestData.params.banned_tokens ?? []; + let recoveryMode = false; const { signal } = abortController; @@ -221,34 +229,54 @@ export namespace Connection { if (response?.text) { text = response.text; + let minStopIdx = text.length; for (const sequence of requestData.params.stop_sequence) { const stopIdx = text.indexOf(sequence); - if (stopIdx >= 0) { - text = text.slice(0, stopIdx); + if (stopIdx >= 0 && stopIdx < minStopIdx) { + minStopIdx = stopIdx; } } + if (minStopIdx < text.length) { + text = text.slice(0, minStopIdx); + } const locaseText = text.toLowerCase(); let unsloppedText = text; + let slopDetected = false; + let minSlopIdx = text.length; + let detectedBan = ''; for (const ban of bannedTokens) { const slopIdx = locaseText.indexOf(ban.toLowerCase()); - if (slopIdx >= 0) { - console.log(`[horde] slop '${ban}' detected at ${slopIdx}`); - unsloppedText = unsloppedText.slice(0, slopIdx).trimEnd(); + if (slopIdx >= 0 && slopIdx < minSlopIdx) { + minSlopIdx = slopIdx; + detectedBan = ban; + slopDetected = true; } } + if (slopDetected) { + console.log(`[horde] slop '${detectedBan}' detected at ${minSlopIdx}`); + unsloppedText = unsloppedText.slice(0, minSlopIdx).trimEnd(); + } yield { text: unsloppedText, cost: response.cost }; requestData.prompt += unsloppedText; - if (unsloppedText === text) { + if (slopDetected) { + recoveryMode = true; + requestData.params.max_length = RECOVERY_LENGTH; + currentTemperature = Math.min(MAX_TEMPERATURE, currentTemperature + TEMP_INCREMENT); + requestData.params.temperature = currentTemperature; + requestData.params.top_p = Math.min(0.98, 0.92 + (currentTemperature - baseTemperature) * 0.02); + } else if (recoveryMode) { + recoveryMode = false; + requestData.params.max_length = maxLength; + requestData.params.temperature = baseTemperature; + requestData.params.top_p = 0.92; + currentTemperature = baseTemperature; + } else { return; // we are finished } - - if (unsloppedText.length === 0) { - requestData.params.temperature += 0.05; - } } } catch (e) { if (!signal.aborted) {