// open_ai_compatible.rs

  1use anyhow::{Context as _, Result};
  2use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
  3use futures::AsyncReadExt as _;
  4use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client};
  5use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
  6use language_model::{ApiKeyState, EnvVar, env_var};
  7use std::sync::Arc;
  8
  9pub fn open_ai_compatible_api_url(cx: &App) -> SharedString {
 10    all_language_settings(None, cx)
 11        .edit_predictions
 12        .open_ai_compatible_api
 13        .as_ref()
 14        .map(|settings| settings.api_url.clone())
 15        .unwrap_or_default()
 16        .into()
 17}
 18
// NOTE(review): presumably the username key under which the API token is stored
// in the OS credential store — not referenced in this file; confirm at call sites.
pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token";
/// Environment variable that can supply the API key directly (see `ApiKeyState`
/// construction below, which is seeded with this variable).
pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> =
    env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY");
 22
// Process-wide wrapper caching the single shared `ApiKeyState` entity; installed
// and retrieved via `open_ai_compatible_api_token` below.
struct GlobalOpenAiCompatibleApiKey(Entity<ApiKeyState>);

impl Global for GlobalOpenAiCompatibleApiKey {}
 26
 27pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
 28    if let Some(global) = cx.try_global::<GlobalOpenAiCompatibleApiKey>() {
 29        return global.0.clone();
 30    }
 31
 32    let entity = cx.new(|cx| {
 33        ApiKeyState::new(
 34            open_ai_compatible_api_url(cx),
 35            OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(),
 36        )
 37    });
 38    cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone()));
 39    entity
 40}
 41
 42pub fn load_open_ai_compatible_api_token(
 43    cx: &mut App,
 44) -> Task<Result<(), language_model::AuthenticateError>> {
 45    let api_url = open_ai_compatible_api_url(cx);
 46    open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
 47        key_state.load_if_needed(api_url, |s| s, cx)
 48    })
 49}
 50
 51pub fn load_open_ai_compatible_api_key_if_needed(
 52    provider: settings::EditPredictionProvider,
 53    cx: &mut App,
 54) -> Option<Arc<str>> {
 55    if provider != settings::EditPredictionProvider::OpenAiCompatibleApi {
 56        return None;
 57    }
 58    _ = load_open_ai_compatible_api_token(cx);
 59    let url = open_ai_compatible_api_url(cx);
 60    return open_ai_compatible_api_token(cx).read(cx).key(&url);
 61}
 62
 63pub(crate) async fn send_custom_server_request(
 64    provider: settings::EditPredictionProvider,
 65    settings: &OpenAiCompatibleEditPredictionSettings,
 66    prompt: String,
 67    max_tokens: u32,
 68    stop_tokens: Vec<String>,
 69    api_key: Option<Arc<str>>,
 70    http_client: &Arc<dyn http_client::HttpClient>,
 71) -> Result<(String, String)> {
 72    match provider {
 73        settings::EditPredictionProvider::Ollama => {
 74            let response = crate::ollama::make_request(
 75                settings.clone(),
 76                prompt,
 77                stop_tokens,
 78                http_client.clone(),
 79            )
 80            .await?;
 81            Ok((response.response, response.created_at))
 82        }
 83        _ => {
 84            let request = RawCompletionRequest {
 85                model: settings.model.clone(),
 86                prompt,
 87                max_tokens: Some(max_tokens),
 88                temperature: None,
 89                stop: stop_tokens
 90                    .into_iter()
 91                    .map(std::borrow::Cow::Owned)
 92                    .collect(),
 93                environment: None,
 94            };
 95
 96            let request_body = serde_json::to_string(&request)?;
 97            let mut http_request_builder = http_client::Request::builder()
 98                .method(http_client::Method::POST)
 99                .uri(settings.api_url.as_ref())
100                .header("Content-Type", "application/json");
101
102            if let Some(api_key) = api_key {
103                http_request_builder =
104                    http_request_builder.header("Authorization", format!("Bearer {}", api_key));
105            }
106
107            let http_request =
108                http_request_builder.body(http_client::AsyncBody::from(request_body))?;
109
110            let mut response = http_client.send(http_request).await?;
111            let status = response.status();
112
113            if !status.is_success() {
114                let mut body = String::new();
115                response.body_mut().read_to_string(&mut body).await?;
116                anyhow::bail!("custom server error: {} - {}", status, body);
117            }
118
119            let mut body = String::new();
120            response.body_mut().read_to_string(&mut body).await?;
121
122            let parsed: RawCompletionResponse =
123                serde_json::from_str(&body).context("Failed to parse completion response")?;
124            let text = parsed
125                .choices
126                .into_iter()
127                .next()
128                .map(|choice| choice.text)
129                .unwrap_or_default();
130            Ok((text, parsed.id))
131        }
132    }
133}