You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
dumbpilot/src/openai-api.ts

214 lines
6.0 KiB

import * as vscode from 'vscode';
import {
FetchErrorCause,
ResponseData,
showMessageWithTimeout,
showPendingStatusBar,
} from './common';
import { config } from 'process';
// oobabooga/text-generation-webui OpenAI compatible API
// https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API
/**
 * Request body for the `/v1/completions` endpoint of the
 * text-generation-webui OpenAI-compatible API (documentation linked above).
 *
 * The fields up to `user` appear to mirror the OpenAI completions request;
 * the remaining ones are sampler/generation extensions of the webui server.
 * Every field except `prompt` is optional; fields left undefined are omitted
 * when the body is serialized with JSON.stringify.
 *
 * The trailing comment on a field records the default value the author
 * noted for it. NOTE(review): some of these (e.g. `min_p` and `top_k` listed
 * as 1) look doubtful; confirm them against the server documentation.
 */
export type OpenAICompletionRequest = {
model?: string; // automatic
prompt: string;
best_of?: number; // 1
echo?: boolean; // false
frequency_penalty?: number; // null
logit_bias?: object; // null
logprobs?: number; // 0
max_tokens?: number; // 16
n?: number; // 1
presence_penalty?: number; // 0
stop?: string; // NOTE(review): OpenAI-style APIs also accept a list of stop strings
stream?: boolean; // false
suffix?: string;
temperature?: number; // 1
top_p?: number; // 1
user?: string;
preset?: string;
min_p?: number; // 1
top_k?: number; // 1
repetition_penalty?: number; // 1
repetition_penalty_range?: number; // 1024
typical_p?: number; // 1
tfs?: number; // 1
top_a?: number; // 0
epsilon_cutoff?: number; // 0
eta_cutoff?: number; // 0
guidance_scale?: number; // 1
negative_prompt?: string; // ""
penalty_alpha?: number; // 0
mirostat_mode?: number; // 0
mirostat_tau?: number; // 5
mirostat_eta?: number; // 0.1
temperature_last?: boolean; // false
do_sample?: boolean; // true
seed?: number; // -1
encoder_repetition_penalty?: number; // 1
no_repeat_ngram_size?: number; // 0
min_length?: number; // 0
num_beams?: number; // 1
length_penalty?: number; // 1
early_stopping?: boolean; // false
truncation_length?: number; // 0
max_tokens_second?: number; // 0
custom_token_bans?: string; // ""
auto_max_new_tokens?: boolean; // false
ban_eos_token?: boolean; // false
add_bos_token?: boolean; // true
skip_special_tokens?: boolean; // true
grammar_string?: string; // ''
};
/**
 * Successful reply from the completions endpoint. openAIMakeRequest treats
 * each JSON document it receives as this shape and appends the text of the
 * first entry in `choices`.
 * NOTE(review): openAIMakeRequest reads `usage` with optional chaining, so
 * `usage` may be absent in practice (e.g. on streamed chunks) even though it
 * is declared required here; confirm against the server.
 */
type OpenAICompletionSuccessResponse = {
id: string;
choices: {
finish_reason: string;
index: number;
logprobs: object | null;
text: string; // the generated text for this choice
}[];
created?: number;
model: string;
object?: string;
usage: {
completion_tokens?: number;
prompt_tokens: number;
total_tokens: number;
};
};
/**
 * Error reply from the completions endpoint: a list of problems, each with
 * a location path (`loc`, presumably pointing at the offending request
 * field), a human-readable message (`msg`) and an error category (`type`).
 */
type OpenAICompletionFailureResponse = {
detail: {
loc: (string | number)[];
msg: string;
type: string;
}[];
};
/**
 * Any JSON document the completions endpoint may return. The two variants
 * are told apart by the presence of a `detail` property.
 */
type OpenAICompletionResponse = OpenAICompletionSuccessResponse | OpenAICompletionFailureResponse;
/**
 * Build a `/v1/completions` request body from the extension settings.
 *
 * Sampler settings are read from the `llama*` keys of `config`. When the
 * `fimEnabled` setting is true, the prompt is assembled for fill-in-the-middle
 * as `fimBeginString + doc_before + fimHoleString + doc_after + fimEndString`.
 * Otherwise the prompt is `doc_before` alone.
 *
 * @param config     the extension's workspace configuration
 * @param doc_before text placed before the FIM hole (the whole prompt when FIM is off)
 * @param doc_after  text placed after the FIM hole (unused when FIM is off)
 * @returns the request body, ready to be sent by openAIMakeRequest
 */
export function createOpenAIAPIRequest(
	config: vscode.WorkspaceConfiguration,
	doc_before: string,
	doc_after: string
): OpenAICompletionRequest {
	// The keys must match the contributed setting names exactly. A stray
	// character makes `get` return undefined, and JSON.stringify then silently
	// drops the field from the request.
	const request: OpenAICompletionRequest = {
		prompt: doc_before,
		max_tokens: config.get('llamaMaxtokens') as number,
		mirostat_mode: config.get('llamaMirostat') as number,
		repetition_penalty: config.get('llamaRepeatPenalty') as number,
		frequency_penalty: config.get('llamaFrequencyPenalty') as number,
		presence_penalty: config.get('llamaPresencePenalty') as number,
		repetition_penalty_range: config.get('llamaRepeatCtx') as number,
		temperature: config.get('llamaTemperature') as number,
		top_p: config.get('llamaTop_p') as number,
		top_k: config.get('llamaTop_k') as number,
		typical_p: config.get('llamaTypical_p') as number,
		tfs: config.get('llamaTailfree_z') as number,
		seed: config.get('llamaSeed') as number,
		stream: config.get('llamaAPIStream'),
	};
	if (config.get('fimEnabled') === true) {
		// An unset marker becomes '' instead of the literal text "undefined".
		const fim_beg = (config.get('fimBeginString') as string | undefined) ?? '';
		const fim_hole = (config.get('fimHoleString') as string | undefined) ?? '';
		const fim_end = (config.get('fimEndString') as string | undefined) ?? '';
		request.prompt = fim_beg + doc_before + fim_hole + doc_after + fim_end;
	}
	return request;
}
// for now only the completions endpoint is implemented
/**
 * URL of the completions endpoint on the configured server.
 *
 * Trailing slashes on the `llamaHost` setting are stripped, so both
 * `http://host:5000` and `http://host:5000/` yield
 * `http://host:5000/v1/completions` rather than a `//v1/...` path.
 *
 * @param config the extension's workspace configuration (reads `llamaHost`)
 * @returns the absolute completions URL
 */
export function openAIAPIRequestEndpoint(config: vscode.WorkspaceConfiguration): string {
	const host = (config.get('llamaHost') as string | undefined) ?? '';
	return host.replace(/\/+$/, '') + '/v1/completions';
}
/**
 * Split a completions reply body into the JSON documents it carries.
 *
 * A non-streamed reply is a single JSON document. A streamed reply is a
 * server-sent-events stream: events are separated by blank lines and carry
 * their payload in `data:` fields (several `data:` lines in one event are
 * joined with '\n'). Other SSE fields, `:` comment lines and the `[DONE]`
 * end-of-stream sentinel used by OpenAI-style servers are ignored.
 */
function extractOpenAIPayloads(body_text: string): string[] {
	const lines = body_text.split(/\r\n|\r|\n/);
	if (!lines.some((line) => line.startsWith('data:'))) {
		const trimmed = body_text.trim();
		return trimmed === '' ? [] : [trimmed];
	}
	const payloads: string[] = [];
	let event_data: string[] = [];
	const flushEvent = (): void => {
		if (event_data.length > 0) {
			const payload = event_data.join('\n');
			if (payload.trim() !== '[DONE]') {
				payloads.push(payload);
			}
			event_data = [];
		}
	};
	for (const line of lines) {
		if (line === '') {
			flushEvent();
		} else if (line.startsWith('data:')) {
			// SSE strips exactly one space after the field colon
			event_data.push(line.startsWith('data: ') ? line.slice(6) : line.slice(5));
		}
	}
	flushEvent();
	return payloads;
}

/** Join the messages of an endpoint error reply into one readable string. */
function describeOpenAIFailure(failure: OpenAICompletionFailureResponse): string {
	const detail: unknown = failure.detail;
	if (Array.isArray(detail)) {
		return detail.map((entry: { msg?: unknown }) => String(entry?.msg)).join('; ');
	}
	return String(detail);
}

/**
 * Render a caught error for the user. On connection failure, Node's fetch
 * rejects with a TypeError whose `cause` carries the socket error code,
 * address and port (the FetchErrorCause shape). Those are appended when a
 * cause object is present. Errors without a cause (HTTP status, malformed
 * body, endpoint error reply) report just their message.
 */
function describeFetchError(e: unknown): string {
	if (!(e instanceof Error)) {
		return String(e);
	}
	const cause: unknown = 'cause' in e ? e.cause : undefined;
	if (typeof cause !== 'object' || cause === null) {
		return e.message;
	}
	const fetch_cause = cause as FetchErrorCause;
	return e.message + ' ' + fetch_cause.code + ' at ' + fetch_cause.address + ':' + fetch_cause.port;
}

/**
 * POST a completion request to an OpenAI-compatible endpoint and collect
 * the generated text.
 *
 * The whole reply body is read before it is parsed. This way, a JSON
 * document or streamed event spanning several network chunks is handled,
 * several events arriving in one chunk are all kept, and UTF-8 sequences
 * split across chunks decode correctly. Both plain replies and streamed
 * (`stream: true`, server-sent events) replies are accepted.
 *
 * Errors are caught rather than propagated: on failure (network, non-2xx
 * status, malformed body, endpoint error reply) the error is logged, shown
 * to the user for 3 s, and whatever was collected so far is returned.
 *
 * @param request_body request body, e.g. built by createOpenAIAPIRequest
 * @param endpoint     full URL of the completions endpoint
 * @returns the concatenated text of the first choice of every reply document,
 *          the summed `usage.completion_tokens`, and the seconds spent
 *          receiving the body (left at 0 on failure)
 */
export async function openAIMakeRequest(
	request_body: OpenAICompletionRequest,
	endpoint: string
): Promise<ResponseData> {
	const ret: ResponseData = {
		content: '',
		tokens: 0,
		time: 0,
	};
	const request: RequestInit = {
		method: 'POST',
		headers: {
			'content-type': 'application/json; charset=UTF-8',
		},
		body: JSON.stringify(request_body),
	};
	try {
		const response = await fetch(endpoint, request);
		// time only the transfer of the body, as before
		const timer_start = performance.now();
		const body_text = await response.text();
		const timer_end = performance.now();
		if (!response.ok) {
			throw new Error('HTTP ' + response.status + ': ' + body_text.trim().slice(0, 200));
		}
		for (const payload of extractOpenAIPayloads(body_text)) {
			const data: unknown = JSON.parse(payload);
			if (typeof data !== 'object' || data === null) {
				throw new Error('OpenAI Endpoint Error: unexpected reply ' + payload.slice(0, 200));
			}
			if ('detail' in data) {
				throw new Error(
					'OpenAI Endpoint Error: ' +
						describeOpenAIFailure(data as OpenAICompletionFailureResponse)
				);
			}
			const success = data as OpenAICompletionSuccessResponse;
			// Only the first choice is used. TODO: display the other choices.
			ret.content += success.choices?.[0]?.text ?? '';
			// NOTE(review): summing assumes each streamed chunk reports only its
			// own token count, not a running total; verify with the server.
			ret.tokens += success.usage?.completion_tokens ?? 0;
		}
		ret.time = (timer_end - timer_start) / 1000.0;
	} catch (e: unknown) {
		console.error(e);
		// let the user know something went wrong
		// TODO: maybe add a retry action or something
		showMessageWithTimeout('dumbpilot error: ' + describeFetchError(e), 3000);
	}
	return ret;
}