import * as vscode from 'vscode';
import {
    FetchErrorCause,
    ResponseData,
    showMessageWithTimeout,
    showPendingStatusBar,
} from './common';

// llama.cpp server response format
export type LlamaData = {
    content: string;
    generation_settings: Record<string, unknown>;
    model: string;
    prompt: string;
    stopped_eos: boolean;
    stopped_limit: boolean;
    stopped_word: boolean;
    stopping_word: string;
    timings: {
        predicted_ms: number;
        predicted_n: number;
        predicted_per_second: number;
        predicted_per_token_ms: number;
        prompt_ms: number;
        prompt_n: number;
        prompt_per_second: number;
        prompt_per_token_ms: number;
    };
    tokens_cached: number;
    tokens_evaluated: number;
    tokens_predicted: number;
    truncated: boolean;
};

export type LlamaRequest = {
    n_predict: number;
    mirostat: number;
    repeat_penalty: number;
    frequency_penalty: number;
    presence_penalty: number;
    repeat_last_n: number;
    temperature: number;
    top_p: number;
    top_k: number;
    typical_p: number;
    tfs_z: number;
    seed: number;
    stream: boolean;
    cache_prompt: boolean;
    prompt?: string;
    input_prefix?: string;
    input_suffix?: string;
};

export function createLlamacppRequest(
    config: vscode.WorkspaceConfiguration,
    doc_before: string,
    doc_after: string
): LlamaRequest {
    const request: LlamaRequest = {
        n_predict: config.get('llamaMaxtokens') as number,
        mirostat: config.get('llamaMirostat') as number,
        repeat_penalty: config.get('llamaRepeatPenalty') as number,
        frequency_penalty: config.get('llamaFrequencyPenalty') as number,
        presence_penalty: config.get('llamaPresencePenalty') as number,
        repeat_last_n: config.get('llamaRepeatCtx') as number,
        temperature: config.get('llamaTemperature') as number,
        top_p: config.get('llamaTop_p') as number,
        top_k: config.get('llamaTop_k') as number,
        typical_p: config.get('llamaTypical_p') as number,
        tfs_z: config.get('llamaTailfree_z') as number,
        seed: config.get('llamaSeed') as number,
        stream: false,
        cache_prompt: config.get('llamaCachePrompt') as boolean,
    };

    const fim = config.get('fimEnabled') as boolean;
    const fimRequest = config.get('useFillInMiddleRequest') as boolean;

    if (fim) {
        if (fimRequest) {
            // let the server assemble the fill-in-middle prompt via /infill
            request.input_prefix = doc_before;
            request.input_suffix = doc_after;
        } else {
            // assemble the fill-in-middle prompt manually from the configured
            // special tokens and send it as a plain completion
            const fim_beg = config.get('fimBeginString') as string;
            const fim_hole = config.get('fimHoleString') as string;
            const fim_end = config.get('fimEndString') as string;
            request.prompt = fim_beg + doc_before + fim_hole + doc_after + fim_end;
        }
    } else {
        request.prompt = doc_before;
    }

    return request;
}
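
// Usage sketch (illustrative; not referenced elsewhere in this file): one way to
// derive the doc_before/doc_after arguments, by splitting the active document's
// text at the cursor. The 'dumbpilot' configuration section name is an assumption.
export function exampleRequestFromEditor(editor: vscode.TextEditor): LlamaRequest {
    const config = vscode.workspace.getConfiguration('dumbpilot');
    const text = editor.document.getText();
    const cursor = editor.document.offsetAt(editor.selection.active);
    return createLlamacppRequest(config, text.slice(0, cursor), text.slice(cursor));
}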

export function llamacppRequestEndpoint(config: vscode.WorkspaceConfiguration): string {
    const fim = config.get('fimEnabled') as boolean;
    const fimRequest = config.get('useFillInMiddleRequest') as boolean;
    let req_str: string = config.get('llamaHost') as string;

    // the llama.cpp server takes native fill-in-middle requests on /infill and
    // plain completion requests on /completion
    if (fim && fimRequest) {
        req_str += '/infill';
    } else {
        req_str += '/completion';
    }

    return req_str;
}

export async function llamacppMakeRequest(
    request: LlamaRequest,
    endpoint: string
): Promise<ResponseData> {
    const ret: ResponseData = {
        content: '',
        tokens: 0,
        time: 0,
    };

    // try to send the request to the running server
    try {
        const response = await fetch(endpoint, {
            method: 'POST',
            headers: {
                'content-type': 'application/json; charset=UTF-8',
            },
            body: JSON.stringify(request),
        });
        if (!response.ok) {
            throw new Error(`llama server responded with HTTP status ${response.status}`);
        }

        const data = (await response.json()) as LlamaData;
        ret.content = data.content;
        ret.tokens = data.tokens_predicted;
        ret.time = data.timings.predicted_ms / 1000;
    } catch (e) {
        // network-level fetch failures carry the server address in err.cause,
        // but other errors (such as a non-ok HTTP status) do not
        const err = e as TypeError;
        const cause = err.cause as FetchErrorCause | undefined;
        const estr = cause
            ? `${err.message} ${cause.code} at ${cause.address}:${cause.port}`
            : err.message;
        // let the user know something went wrong
        // TODO: maybe add a retry action or something
        showMessageWithTimeout('dumbpilot error: ' + estr, 3000);
    }

    return ret;
}
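
// End-to-end sketch (illustrative): wiring the three helpers together for one
// completion round trip. The 'dumbpilot' configuration section name is an
// assumption; a real caller might also report progress via showPendingStatusBar.
export async function exampleComplete(doc_before: string, doc_after: string): Promise<string> {
    const config = vscode.workspace.getConfiguration('dumbpilot');
    const request = createLlamacppRequest(config, doc_before, doc_after);
    const endpoint = llamacppRequestEndpoint(config);
    const result = await llamacppMakeRequest(request, endpoint);
    return result.content;
}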