From 63f38f77ba206a9b6deb441255c2d396995edb46 Mon Sep 17 00:00:00 2001 From: Alessandro Mauri Date: Sat, 16 Dec 2023 19:19:05 +0100 Subject: [PATCH] still some todos but this shit works boiii --- src/extension.ts | 24 ++++++++++--- src/openai-api.ts | 85 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 100 insertions(+), 9 deletions(-) diff --git a/src/extension.ts b/src/extension.ts index 975b523..1cc5355 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -2,12 +2,17 @@ import { ok } from 'assert'; import * as vscode from 'vscode'; import commentPrefix from './comments.json'; import { - LlamaData, LlamaRequest, createLlamacppRequest, llamacppRequestEndpoint, llamacppMakeRequest, } from './llamacpp-api'; +import { + OpenAICompletionRequest, + createOpenAIAPIRequest, + openAIAPIRequestEndpoint, + openAIMakeRequest, +} from './openai-api'; import { FetchErrorCause, ResponseData, @@ -99,9 +104,20 @@ export function activate(context: vscode.ExtensionContext) { doc_before = pfx + ' ' + fname + sfx + '\n' + doc_before; // actially make the request - const request: LlamaRequest = createLlamacppRequest(config, doc_before, doc_after); - const endpoint: string = llamacppRequestEndpoint(config); - let data: ResponseData = await llamacppMakeRequest(request, endpoint); + let data: ResponseData = { content: '', tokens: 0, time: 0 }; + if (config.get('llamaUseOpenAIAPI') === true) { + const request: OpenAICompletionRequest = createOpenAIAPIRequest( + config, + doc_before, + doc_after + ); + const endpoint: string = openAIAPIRequestEndpoint(config); + data = await openAIMakeRequest(request, endpoint); + } else { + const request: LlamaRequest = createLlamacppRequest(config, doc_before, doc_after); + const endpoint: string = llamacppRequestEndpoint(config); + data = await llamacppMakeRequest(request, endpoint); + } result.items.push({ insertText: data.content, diff --git a/src/openai-api.ts b/src/openai-api.ts index fdaad9b..7722ad7 100644 --- a/src/openai-api.ts +++ b/src/openai-api.ts @@ -1,9 +1,15 @@ import * as vscode from 'vscode'; +import { + FetchErrorCause, + ResponseData, + showMessageWithTimeout, + showPendingStatusBar, +} from './common'; // oogabooga/text-generation-webui OpenAI compatible API // https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API -type OpenAICompletionRequest = { +export type OpenAICompletionRequest = { model?: string; // automatic prompt: string; best_of?: number; // 1 @@ -57,11 +63,20 @@ type OpenAICompletionRequest = { type OpenAICompletionSuccessResponse = { id: string; - choices: object[]; + choices: { + finish_reason: string; + index: number; + logprobs: object | null; + text: string; + }[]; created?: number; model: string; object?: string; - usage: object; + usage: { + completion_tokens: number; + prompt_tokens: number; + total_tokens: number; + }; }; type OpenAICompletionFailureResponse = { @@ -111,6 +126,66 @@ export function createOpenAIAPIRequest( } // for now only completions is implemented -export function OpenAIAPIRequestEndpoint(config: vscode.WorkspaceConfiguration): string { - return '/v1/completions'; +export function openAIAPIRequestEndpoint(config: vscode.WorkspaceConfiguration): string { + return (config.get('llamaHost') as string) + '/v1/completions'; +} + +export async function openAIMakeRequest( + request: OpenAICompletionRequest, + endpoint: string +): Promise { + let ret: ResponseData = { + content: '', + tokens: 0, + time: 0, + }; + let data: OpenAICompletionResponse; + // try to send the request to the running server + try { + const response_promise = fetch(endpoint, { + method: 'POST', + headers: { + 'content-type': 'application/json; charset=UTF-8', + }, + body: JSON.stringify(request), + }); + + showPendingStatusBar('dumbpilot waiting', response_promise); + + // TODO: measure the time it takes the server to respond + let resp_time: number = 0; + const response = await response_promise; + + if (response.ok === false) { + throw new Error('llama server request is not ok??'); + } + + data = (await response.json()) as OpenAICompletionResponse; + + // check wether the remote gave back an error + if (Object.hasOwn(data, 'detail') === true) { + data = data as OpenAICompletionFailureResponse; + // TODO: why did it error? + throw new Error('OpenAI Endpoint Error'); + } + + // unpack the data + data = data as OpenAICompletionSuccessResponse; + // FIXME: why the choices may be multiple? + // TODO: display the multiple choices + ret.content = data.choices[0].text; + ret.tokens = data.usage.completion_tokens; + ret.time = resp_time; + + showMessageWithTimeout(`predicted ${ret.tokens} tokens in ${ret.time} seconds`, 1500); + } catch (e: any) { + const err = e as TypeError; + const cause: FetchErrorCause = err.cause as FetchErrorCause; + const estr: string = + err.message + ' ' + cause.code + ' at ' + cause.address + ':' + cause.port; + // let the user know something went wrong + // TODO: maybe add a retry action or something + showMessageWithTimeout('dumbpilot error: ' + estr, 3000); + } + return ret; }