import * as vscode from 'vscode';
import {
	FetchErrorCause,
	ResponseData,
	showMessageWithTimeout,
	updateStatusBarMessage,
} from './common';

// oobabooga/text-generation-webui OpenAI compatible API
// https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API

export type OpenAICompletionRequest = {
	model?: string; // automatic
	prompt: string;
	best_of?: number; // 1
	echo?: boolean; // false
	frequency_penalty?: number; // null
	logit_bias?: object; // null
	logprobs?: number; // 0
	max_tokens?: number; // 16
	n?: number; // 1
	presence_penalty?: number; // 0
	stop?: string;
	stream?: boolean; // false
	suffix?: string;
	temperature?: number; // 1
	top_p?: number; // 1
	user?: string;
	preset?: string;
	min_p?: number; // 1
	top_k?: number; // 1
	repetition_penalty?: number; // 1
	repetition_penalty_range?: number; // 1024
	typical_p?: number; // 1
	tfs?: number; // 1
	top_a?: number; // 0
	epsilon_cutoff?: number; // 0
	eta_cutoff?: number; // 0
	guidance_scale?: number; // 1
	negative_prompt?: string; // ""
	penalty_alpha?: number; // 0
	mirostat_mode?: number; // 0
	mirostat_tau?: number; // 5
	mirostat_eta?: number; // 0.1
	temperature_last?: boolean; // false
	do_sample?: boolean; // true
	seed?: number; // -1
	encoder_repetition_penalty?: number; // 1
	no_repeat_ngram_size?: number; // 0
	min_length?: number; // 0
	num_beams?: number; // 1
	length_penalty?: number; // 1
	early_stopping?: boolean; // false
	truncation_length?: number; // 0
	max_tokens_second?: number; // 0
	custom_token_bans?: string; // ""
	auto_max_new_tokens?: boolean; // false
	ban_eos_token?: boolean; // false
	add_bos_token?: boolean; // true
	skip_special_tokens?: boolean; // true
	grammar_string?: string; // ''
};

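// example (hypothetical values): a minimal non-streaming request; every field
// left out falls back to the server-side default noted above
//
// const example_request: OpenAICompletionRequest = {
// 	prompt: 'def fibonacci(n):',
// 	max_tokens: 32,
// 	temperature: 0.2,
// };
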
type OpenAICompletionSuccessResponse = {
	id: string;
	choices: {
		finish_reason: string;
		index: number;
		logprobs: object | null;
		text: string;
	}[];
	created?: number;
	model: string;
	object?: string;
	usage: {
		completion_tokens?: number;
		prompt_tokens: number;
		total_tokens: number;
	};
};

type OpenAICompletionFailureResponse = {
	detail: {
		loc: (string | number)[];
		msg: string;
		type: string;
	}[];
};

type OpenAICompletionResponse = OpenAICompletionSuccessResponse | OpenAICompletionFailureResponse;

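// a minimal sketch of a runtime type guard for this union; the parsing code
// below performs the same 'detail' check inline with Object.hasOwn
//
// function isFailureResponse(r: OpenAICompletionResponse): r is OpenAICompletionFailureResponse {
// 	return Object.hasOwn(r, 'detail');
// }
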
// build a completion request from the workspace configuration, using
// fill-in-the-middle (FIM) when enabled
export function createOpenAIAPIRequest(
	config: vscode.WorkspaceConfiguration,
	doc_before: string,
	doc_after: string
): OpenAICompletionRequest {
	let request: OpenAICompletionRequest = {
		prompt: '',
		max_tokens: config.get('parameters.MaxTokens') as number,
		mirostat_mode: config.get('parameters.Mirostat') as number,
		repetition_penalty: config.get('parameters.RepeatPenalty') as number,
		frequency_penalty: config.get('parameters.FrequencyPenalty') as number,
		presence_penalty: config.get('parameters.PresencePenalty') as number,
		repetition_penalty_range: config.get('parameters.RepeatCtx') as number,
		temperature: config.get('parameters.Temperature') as number,
		top_p: config.get('parameters.Top_p') as number,
		top_k: config.get('parameters.Top_k') as number,
		typical_p: config.get('parameters.Typical_p') as number,
		tfs: config.get('parameters.Tailfree_z') as number,
		seed: config.get('parameters.Seed') as number,
		stream: config.get('parameters.stream') as boolean,
	};

	const fim = config.get('fimEnabled') as boolean;

	if (fim === true) {
		// wrap the cursor context in the model's FIM special tokens
		const fim_beg = config.get('fimBeginString') as string;
		const fim_hole = config.get('fimHoleString') as string;
		const fim_end = config.get('fimEndString') as string;
		request.prompt = fim_beg + doc_before + fim_hole + doc_after + fim_end;
	} else {
		request.prompt = doc_before;
	}

	return request;
}

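// example: with fimBeginString = '<fim_prefix>', fimHoleString = '<fim_suffix>'
// and fimEndString = '<fim_middle>' (StarCoder-style tokens, used here purely
// as an illustration) the assembled prompt is:
//
// '<fim_prefix>' + doc_before + '<fim_suffix>' + doc_after + '<fim_middle>'
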
// for now only /v1/completions is implemented
// TODO: implement chat
export function openAIAPIRequestEndpoint(config: vscode.WorkspaceConfiguration): string {
	return (config.get('endpoint') as string) + '/v1/completions';
}

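// example: if the 'endpoint' setting is 'http://127.0.0.1:5000' (the value is
// user configured, this address is only an illustration) this returns
// 'http://127.0.0.1:5000/v1/completions'
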
// make a request and parse the incoming data
export async function openAIMakeRequest(
	request_body: OpenAICompletionRequest,
	endpoint: string
): Promise<ResponseData> {
	let ret: ResponseData = {
		content: '',
		tokens: 0,
		time: 0,
	};

	// format the request
	const request: RequestInit = {
		method: 'POST',
		headers: {
			'content-type': 'application/json; charset=UTF-8',
		},
		body: JSON.stringify(request_body),
	};

	// try to send the request to the running server
	try {
		const response = await fetch(endpoint, request);

		// read the body chunk by chunk using asynchronous iteration; this
		// handles both cases, since a non-streaming response is just a
		// single chunk
		if (response.body === null) {
			throw new Error('null response body');
		}

		// start a timer
		const timer_start = performance.now();

		let chunk_number: number = 1;
		for await (const chunk of response.body) {
			// each chunk of data arrives as a uint8 array
			const data_text = Buffer.from(chunk as Uint8Array).toString();

			// a streamed chunk contains one or more SSE events prefixed with
			// 'data: '; a non-streamed response is a single json document
			const data_chunks = data_text.split('data: ');
			let data: OpenAICompletionResponse;
			for (const data_string of data_chunks) {
				// strings are immutable, so keep the trimmed copy
				const data_json = data_string.trim();
				if (data_json.length < 2) {
					continue;
				}
				// the stream may end with a 'data: [DONE]' sentinel, which is not json
				if (data_json === '[DONE]') {
					continue;
				}

				data = JSON.parse(data_json);
				//console.log(JSON.stringify(data));

				if (Object.hasOwn(data, 'detail') === true) {
					data = data as OpenAICompletionFailureResponse;
					// TODO: why did it error?
					throw new Error('OpenAI Endpoint Error');
				}
				// unpack the data
				data = data as OpenAICompletionSuccessResponse;

				for (const choice of data.choices) {
					ret.content += choice.text;
					updateStatusBarMessage(chunk_number, choice.text);
					chunk_number++;
				}
				ret.tokens += data.usage?.completion_tokens || 0;
			}
		}

		// stop the timer
		const timer_end = performance.now();
		ret.time = (timer_end - timer_start) / 1000.0;
		// clear the status bar item
		updateStatusBarMessage(0, '');
	} catch (e: any) {
		console.error(e);
		const err = e as TypeError;
		// fetch failures carry the network details in err.cause; other
		// errors (e.g. a json parse failure) may not have a cause at all
		const cause: FetchErrorCause | undefined = err.cause as FetchErrorCause | undefined;
		const estr: string =
			cause !== undefined
				? err.message + ' ' + cause.code + ' at ' + cause.address + ':' + cause.port
				: err.message;
		// let the user know something went wrong
		// TODO: maybe add a retry action or something
		showMessageWithTimeout('dumbpilot error: ' + estr, 3000);
	}
	return ret;
}

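// usage sketch; 'dumbpilot' as the configuration section is an assumption,
// use whatever section name the extension actually registers
//
// const config = vscode.workspace.getConfiguration('dumbpilot');
// const body = createOpenAIAPIRequest(config, text_before_cursor, text_after_cursor);
// const endpoint = openAIAPIRequestEndpoint(config);
// const result = await openAIMakeRequest(body, endpoint);
// console.log(result.content, result.tokens, result.time);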