open_ai_compatible.rs

  1use anyhow::{Context as _, Result};
  2use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
  3use futures::AsyncReadExt as _;
  4use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client};
  5use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
  6use language_model::{ApiKeyState, EnvVar, env_var};
  7use std::sync::Arc;
  8use zed_credentials_provider::global as global_credentials_provider;
  9
 10pub fn open_ai_compatible_api_url(cx: &App) -> SharedString {
 11    all_language_settings(None, cx)
 12        .edit_predictions
 13        .open_ai_compatible_api
 14        .as_ref()
 15        .map(|settings| settings.api_url.clone())
 16        .unwrap_or_default()
 17        .into()
 18}
 19
 20pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token";
 21pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> =
 22    env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY");
 23
 24struct GlobalOpenAiCompatibleApiKey(Entity<ApiKeyState>);
 25
 26impl Global for GlobalOpenAiCompatibleApiKey {}
 27
 28pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
 29    if let Some(global) = cx.try_global::<GlobalOpenAiCompatibleApiKey>() {
 30        return global.0.clone();
 31    }
 32
 33    let entity = cx.new(|cx| {
 34        ApiKeyState::new(
 35            open_ai_compatible_api_url(cx),
 36            OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(),
 37        )
 38    });
 39    cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone()));
 40    entity
 41}
 42
 43pub fn load_open_ai_compatible_api_token(
 44    cx: &mut App,
 45) -> Task<Result<(), language_model::AuthenticateError>> {
 46    let credentials_provider = global_credentials_provider(cx);
 47    let api_url = open_ai_compatible_api_url(cx);
 48    open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
 49        key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
 50    })
 51}
 52
 53pub fn load_open_ai_compatible_api_key_if_needed(
 54    provider: settings::EditPredictionProvider,
 55    cx: &mut App,
 56) -> Option<Arc<str>> {
 57    if provider != settings::EditPredictionProvider::OpenAiCompatibleApi {
 58        return None;
 59    }
 60    _ = load_open_ai_compatible_api_token(cx);
 61    let url = open_ai_compatible_api_url(cx);
 62    return open_ai_compatible_api_token(cx).read(cx).key(&url);
 63}
 64
 65pub(crate) async fn send_custom_server_request(
 66    provider: settings::EditPredictionProvider,
 67    settings: &OpenAiCompatibleEditPredictionSettings,
 68    prompt: String,
 69    max_tokens: u32,
 70    stop_tokens: Vec<String>,
 71    api_key: Option<Arc<str>>,
 72    http_client: &Arc<dyn http_client::HttpClient>,
 73) -> Result<(String, String)> {
 74    match provider {
 75        settings::EditPredictionProvider::Ollama => {
 76            let response = crate::ollama::make_request(
 77                settings.clone(),
 78                prompt,
 79                stop_tokens,
 80                http_client.clone(),
 81            )
 82            .await?;
 83            Ok((response.response, response.created_at))
 84        }
 85        _ => {
 86            let request = RawCompletionRequest {
 87                model: settings.model.clone(),
 88                prompt,
 89                max_tokens: Some(max_tokens),
 90                temperature: None,
 91                stop: stop_tokens
 92                    .into_iter()
 93                    .map(std::borrow::Cow::Owned)
 94                    .collect(),
 95                environment: None,
 96            };
 97
 98            let request_body = serde_json::to_string(&request)?;
 99            let mut http_request_builder = http_client::Request::builder()
100                .method(http_client::Method::POST)
101                .uri(settings.api_url.as_ref())
102                .header("Content-Type", "application/json");
103
104            if let Some(api_key) = api_key {
105                http_request_builder =
106                    http_request_builder.header("Authorization", format!("Bearer {}", api_key));
107            }
108
109            let http_request =
110                http_request_builder.body(http_client::AsyncBody::from(request_body))?;
111
112            let mut response = http_client.send(http_request).await?;
113            let status = response.status();
114
115            if !status.is_success() {
116                let mut body = String::new();
117                response.body_mut().read_to_string(&mut body).await?;
118                anyhow::bail!("custom server error: {} - {}", status, body);
119            }
120
121            let mut body = String::new();
122            response.body_mut().read_to_string(&mut body).await?;
123
124            let parsed: RawCompletionResponse =
125                serde_json::from_str(&body).context("Failed to parse completion response")?;
126            let text = parsed
127                .choices
128                .into_iter()
129                .next()
130                .map(|choice| choice.text)
131                .unwrap_or_default();
132            Ok((text, parsed.id))
133        }
134    }
135}