still some todos but this shit works boiii
This commit is contained in:
parent 960b2190bf
commit 63f38f77ba
@@ -2,12 +2,17 @@ import { ok } from 'assert';
 import * as vscode from 'vscode';
 import commentPrefix from './comments.json';
 import {
-	LlamaData,
 	LlamaRequest,
 	createLlamacppRequest,
 	llamacppRequestEndpoint,
 	llamacppMakeRequest,
 } from './llamacpp-api';
+import {
+	OpenAICompletionRequest,
+	createOpenAIAPIRequest,
+	openAIAPIRequestEndpoint,
+	openAIMakeRequest,
+} from './openai-api';
 import {
 	FetchErrorCause,
 	ResponseData,
@@ -99,9 +104,20 @@ export function activate(context: vscode.ExtensionContext) {
 			doc_before = pfx + ' ' + fname + sfx + '\n' + doc_before;
 
 			// actually make the request
-			const request: LlamaRequest = createLlamacppRequest(config, doc_before, doc_after);
-			const endpoint: string = llamacppRequestEndpoint(config);
-			let data: ResponseData = await llamacppMakeRequest(request, endpoint);
+			let data: ResponseData = { content: '', tokens: 0, time: 0 };
+			if (config.get('llamaUseOpenAIAPI') === true) {
+				const request: OpenAICompletionRequest = createOpenAIAPIRequest(
+					config,
+					doc_before,
+					doc_after
+				);
+				const endpoint: string = openAIAPIRequestEndpoint(config);
+				data = await openAIMakeRequest(request, endpoint);
+			} else {
+				const request: LlamaRequest = createLlamacppRequest(config, doc_before, doc_after);
+				const endpoint: string = llamacppRequestEndpoint(config);
+				data = await llamacppMakeRequest(request, endpoint);
+			}
 
 			result.items.push({
 				insertText: data.content,
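
The branch above dispatches on the new llamaUseOpenAIAPI setting. A minimal sketch of how that flag might be read, assuming the extension's configuration lives under a 'dumbpilot' section (the section name does not appear in this diff):

import * as vscode from 'vscode';

// 'dumbpilot' is a hypothetical section name; only the key
// 'llamaUseOpenAIAPI' is visible in the diff itself.
const config = vscode.workspace.getConfiguration('dumbpilot');
if (config.get('llamaUseOpenAIAPI') === true) {
	// route completion requests through the OpenAI-compatible endpoint
}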
@@ -1,9 +1,15 @@
 import * as vscode from 'vscode';
+import {
+	FetchErrorCause,
+	ResponseData,
+	showMessageWithTimeout,
+	showPendingStatusBar,
+} from './common';
 
 // oobabooga/text-generation-webui OpenAI compatible API
 // https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API
 
-type OpenAICompletionRequest = {
+export type OpenAICompletionRequest = {
 	model?: string; // automatic
 	prompt: string;
 	best_of?: number; // 1
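
With the type now exported, callers outside the module can build requests directly. An illustrative literal, assuming prompt is the only required field (the visible part of the diff suggests this but does not show the full type):

// illustrative values only; most fields are optional with server-side defaults
const req: OpenAICompletionRequest = {
	prompt: '// print hello world\n',
	best_of: 1,
};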
@@ -57,11 +63,20 @@ type OpenAICompletionRequest = {
 
 type OpenAICompletionSuccessResponse = {
 	id: string;
-	choices: object[];
+	choices: {
+		finish_reason: string;
+		index: number;
+		logprobs: object | null;
+		text: string;
+	}[];
 	created?: number;
 	model: string;
 	object?: string;
-	usage: object;
+	usage: {
+		completion_tokens: number;
+		prompt_tokens: number;
+		total_tokens: number;
+	};
 };
 
 type OpenAICompletionFailureResponse = {
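
The widened choices and usage members now match the JSON an OpenAI-style completions endpoint returns. An in-module sketch of a payload shaped like OpenAICompletionSuccessResponse (field values are invented, not from a real server run):

// illustrative payload matching the reworked type
const example: OpenAICompletionSuccessResponse = {
	id: 'cmpl-000',
	choices: [{ finish_reason: 'stop', index: 0, logprobs: null, text: ' world' }],
	created: 1700000000,
	model: 'local-model',
	object: 'text_completion',
	usage: { completion_tokens: 2, prompt_tokens: 5, total_tokens: 7 },
};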
@@ -111,6 +126,66 @@ export function createOpenAIAPIRequest(
 }
 
 // for now only completions is implemented
-export function OpenAIAPIRequestEndpoint(config: vscode.WorkspaceConfiguration): string {
-	return '/v1/completions';
+export function openAIAPIRequestEndpoint(config: vscode.WorkspaceConfiguration): string {
+	return (config.get('llamaHost') as string) + '/v1/completions';
+}
+
+export async function openAIMakeRequest(
+	request: OpenAICompletionRequest,
+	endpoint: string
+): Promise<ResponseData> {
+	let ret: ResponseData = {
+		content: '',
+		tokens: 0,
+		time: 0,
+	};
+	let data: OpenAICompletionResponse;
+	// try to send the request to the running server
+	try {
+		const response_promise = fetch(endpoint, {
+			method: 'POST',
+			headers: {
+				'content-type': 'application/json; charset=UTF-8',
+			},
+			body: JSON.stringify(request),
+		});
+
+		showPendingStatusBar('dumbpilot waiting', response_promise);
+
+		// TODO: measure the time it takes the server to respond
+		let resp_time: number = 0;
+		const response = await response_promise;
+
+		if (response.ok === false) {
+			throw new Error('llama server request is not ok??');
+		}
+
+		data = (await response.json()) as OpenAICompletionResponse;
+
+		// check whether the remote gave back an error
+		if (Object.hasOwn(data, 'detail') === true) {
+			data = data as OpenAICompletionFailureResponse;
+			// TODO: why did it error?
+			throw new Error('OpenAI Endpoint Error');
+		}
+
+		// unpack the data
+		data = data as OpenAICompletionSuccessResponse;
+		// FIXME: why can the choices be multiple?
+		// TODO: display the multiple choices
+		ret.content = data.choices[0].text;
+		ret.tokens = data.usage.completion_tokens;
+		ret.time = resp_time;
+
+		showMessageWithTimeout(`predicted ${ret.tokens} tokens in ${ret.time} seconds`, 1500);
+	} catch (e: any) {
+		const err = e as TypeError;
+		const cause: FetchErrorCause = err.cause as FetchErrorCause;
+		const estr: string =
+			err.message + ' ' + cause.code + ' at ' + cause.address + ':' + cause.port;
+		// let the user know something went wrong
+		// TODO: maybe add a retry action or something
+		showMessageWithTimeout('dumbpilot error: ' + estr, 3000);
+	}
+	return ret;
 }
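
One way the resp_time TODO above could be filled in is by timing the request on the client; a sketch of the lines inside openAIMakeRequest, assuming wall-clock time around the await is close enough (the server may report its own timing, which this code does not read):

// sketch: client-side timing for the resp_time TODO
const t0 = Date.now();
const response = await response_promise;
const resp_time = (Date.now() - t0) / 1000; // seconds, to match the status message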