ep: API keys for OpenAI compatible (#50615)

Created by Ben Kunkle

Closes #ISSUE <!-- TODO: replace with the actual issue number, or remove this line if no issue applies -->

Before you mark this PR as ready for review, make sure that you have:
- [ ] Added a solid test coverage and/or screenshots from doing manual
testing
- [ ] Done a self-review taking into account security and performance
aspects
- [ ] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- Added support for providing an API key to OpenAI-compatible edit
prediction providers

Change summary

crates/edit_prediction/src/edit_prediction.rs                  |   1 
crates/edit_prediction/src/fim.rs                              |  11 
crates/edit_prediction/src/open_ai_compatible.rs               | 133 ++++
crates/edit_prediction/src/zeta.rs                             |  78 -
crates/edit_prediction_ui/src/edit_prediction_button.rs        |   8 
crates/settings_ui/src/pages/edit_prediction_provider_setup.rs | 115 +-
crates/zed/src/zed/edit_prediction_registry.rs                 |   5 
7 files changed, 230 insertions(+), 121 deletions(-)

Detailed changes

crates/edit_prediction/src/fim.rs 🔗

@@ -1,6 +1,7 @@
 use crate::{
-    EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult,
-    zeta,
+    EditPredictionId, EditPredictionModelInput, cursor_excerpt,
+    open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
+    prediction::EditPredictionResult,
 };
 use anyhow::{Context as _, Result, anyhow};
 use gpui::{App, AppContext as _, Entity, Task};
@@ -58,6 +59,8 @@ pub fn request_prediction(
         return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
     };
 
+    let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
+
     let result = cx.background_spawn(async move {
         let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
             cursor_point,
@@ -90,12 +93,14 @@ pub fn request_prediction(
         let stop_tokens = get_fim_stop_tokens();
 
         let max_tokens = settings.max_output_tokens;
-        let (response_text, request_id) = zeta::send_custom_server_request(
+
+        let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
             provider,
             &settings,
             prompt,
             max_tokens,
             stop_tokens,
+            api_key,
             &http_client,
         )
         .await?;

crates/edit_prediction/src/open_ai_compatible.rs 🔗

@@ -0,0 +1,133 @@
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
+use futures::AsyncReadExt as _;
+use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client};
+use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
+use language_model::{ApiKeyState, EnvVar, env_var};
+use std::sync::Arc;
+
+pub fn open_ai_compatible_api_url(cx: &App) -> SharedString {
+    all_language_settings(None, cx)
+        .edit_predictions
+        .open_ai_compatible_api
+        .as_ref()
+        .map(|settings| settings.api_url.clone())
+        .unwrap_or_default()
+        .into()
+}
+
+pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token";
+pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> =
+    env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY");
+
+struct GlobalOpenAiCompatibleApiKey(Entity<ApiKeyState>);
+
+impl Global for GlobalOpenAiCompatibleApiKey {}
+
+pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
+    if let Some(global) = cx.try_global::<GlobalOpenAiCompatibleApiKey>() {
+        return global.0.clone();
+    }
+
+    let entity = cx.new(|cx| {
+        ApiKeyState::new(
+            open_ai_compatible_api_url(cx),
+            OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(),
+        )
+    });
+    cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone()));
+    entity
+}
+
+pub fn load_open_ai_compatible_api_token(
+    cx: &mut App,
+) -> Task<Result<(), language_model::AuthenticateError>> {
+    let api_url = open_ai_compatible_api_url(cx);
+    open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
+        key_state.load_if_needed(api_url, |s| s, cx)
+    })
+}
+
+pub fn load_open_ai_compatible_api_key_if_needed(
+    provider: settings::EditPredictionProvider,
+    cx: &mut App,
+) -> Option<Arc<str>> {
+    if provider != settings::EditPredictionProvider::OpenAiCompatibleApi {
+        return None;
+    }
+    _ = load_open_ai_compatible_api_token(cx);
+    let url = open_ai_compatible_api_url(cx);
+    return open_ai_compatible_api_token(cx).read(cx).key(&url);
+}
+
+pub(crate) async fn send_custom_server_request(
+    provider: settings::EditPredictionProvider,
+    settings: &OpenAiCompatibleEditPredictionSettings,
+    prompt: String,
+    max_tokens: u32,
+    stop_tokens: Vec<String>,
+    api_key: Option<Arc<str>>,
+    http_client: &Arc<dyn http_client::HttpClient>,
+) -> Result<(String, String)> {
+    match provider {
+        settings::EditPredictionProvider::Ollama => {
+            let response = crate::ollama::make_request(
+                settings.clone(),
+                prompt,
+                stop_tokens,
+                http_client.clone(),
+            )
+            .await?;
+            Ok((response.response, response.created_at))
+        }
+        _ => {
+            let request = RawCompletionRequest {
+                model: settings.model.clone(),
+                prompt,
+                max_tokens: Some(max_tokens),
+                temperature: None,
+                stop: stop_tokens
+                    .into_iter()
+                    .map(std::borrow::Cow::Owned)
+                    .collect(),
+                environment: None,
+            };
+
+            let request_body = serde_json::to_string(&request)?;
+            let mut http_request_builder = http_client::Request::builder()
+                .method(http_client::Method::POST)
+                .uri(settings.api_url.as_ref())
+                .header("Content-Type", "application/json");
+
+            if let Some(api_key) = api_key {
+                http_request_builder =
+                    http_request_builder.header("Authorization", format!("Bearer {}", api_key));
+            }
+
+            let http_request =
+                http_request_builder.body(http_client::AsyncBody::from(request_body))?;
+
+            let mut response = http_client.send(http_request).await?;
+            let status = response.status();
+
+            if !status.is_success() {
+                let mut body = String::new();
+                response.body_mut().read_to_string(&mut body).await?;
+                anyhow::bail!("custom server error: {} - {}", status, body);
+            }
+
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+
+            let parsed: RawCompletionResponse =
+                serde_json::from_str(&body).context("Failed to parse completion response")?;
+            let text = parsed
+                .choices
+                .into_iter()
+                .next()
+                .map(|choice| choice.text)
+                .unwrap_or_default();
+            Ok((text, parsed.id))
+        }
+    }
+}

crates/edit_prediction/src/zeta.rs 🔗

@@ -2,15 +2,14 @@ use crate::cursor_excerpt::compute_excerpt_ranges;
 use crate::prediction::EditPredictionResult;
 use crate::{
     CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
-    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
+    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
 };
-use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
+use anyhow::Result;
+use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 use edit_prediction_types::PredictedCursorPosition;
-use futures::AsyncReadExt as _;
-use gpui::{App, AppContext as _, Task, http_client, prelude::*};
-use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
+use gpui::{App, AppContext as _, Task, prelude::*};
+use language::language_settings::all_language_settings;
 use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
 use release_channel::AppVersion;
 use settings::EditPredictionPromptFormat;
@@ -25,6 +24,10 @@ use zeta_prompt::{
     zeta1::{self, EDITABLE_REGION_END_MARKER},
 };
 
+use crate::open_ai_compatible::{
+    load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
+};
+
 pub fn request_prediction_with_zeta(
     store: &mut EditPredictionStore,
     EditPredictionModelInput {
@@ -56,6 +59,7 @@ pub fn request_prediction_with_zeta(
     let buffer_snapshotted_at = Instant::now();
     let raw_config = store.zeta2_raw_config().cloned();
     let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
+    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
 
     let excerpt_path: Arc<Path> = snapshot
         .file()
@@ -131,6 +135,7 @@ pub fn request_prediction_with_zeta(
                                 prompt,
                                 max_tokens,
                                 stop_tokens,
+                                open_ai_compatible_api_key.clone(),
                                 &http_client,
                             )
                             .await?;
@@ -157,6 +162,7 @@ pub fn request_prediction_with_zeta(
                                 prompt,
                                 max_tokens,
                                 vec![],
+                                open_ai_compatible_api_key.clone(),
                                 &http_client,
                             )
                             .await?;
@@ -400,66 +406,6 @@ pub fn zeta2_prompt_input(
     (full_context_offset_range, prompt_input)
 }
 
-pub(crate) async fn send_custom_server_request(
-    provider: settings::EditPredictionProvider,
-    settings: &OpenAiCompatibleEditPredictionSettings,
-    prompt: String,
-    max_tokens: u32,
-    stop_tokens: Vec<String>,
-    http_client: &Arc<dyn http_client::HttpClient>,
-) -> Result<(String, String)> {
-    match provider {
-        settings::EditPredictionProvider::Ollama => {
-            let response =
-                ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
-                    .await?;
-            Ok((response.response, response.created_at))
-        }
-        _ => {
-            let request = RawCompletionRequest {
-                model: settings.model.clone(),
-                prompt,
-                max_tokens: Some(max_tokens),
-                temperature: None,
-                stop: stop_tokens
-                    .into_iter()
-                    .map(std::borrow::Cow::Owned)
-                    .collect(),
-                environment: None,
-            };
-
-            let request_body = serde_json::to_string(&request)?;
-            let http_request = http_client::Request::builder()
-                .method(http_client::Method::POST)
-                .uri(settings.api_url.as_ref())
-                .header("Content-Type", "application/json")
-                .body(http_client::AsyncBody::from(request_body))?;
-
-            let mut response = http_client.send(http_request).await?;
-            let status = response.status();
-
-            if !status.is_success() {
-                let mut body = String::new();
-                response.body_mut().read_to_string(&mut body).await?;
-                anyhow::bail!("custom server error: {} - {}", status, body);
-            }
-
-            let mut body = String::new();
-            response.body_mut().read_to_string(&mut body).await?;
-
-            let parsed: RawCompletionResponse =
-                serde_json::from_str(&body).context("Failed to parse completion response")?;
-            let text = parsed
-                .choices
-                .into_iter()
-                .next()
-                .map(|choice| choice.text)
-                .unwrap_or_default();
-            Ok((text, parsed.id))
-        }
-    }
-}
-
 pub(crate) fn edit_prediction_accepted(
     store: &EditPredictionStore,
     current_prediction: CurrentEditPrediction,

crates/edit_prediction_ui/src/edit_prediction_button.rs 🔗

@@ -539,9 +539,15 @@ impl EditPredictionButton {
         edit_prediction::ollama::ensure_authenticated(cx);
         let sweep_api_token_task = edit_prediction::sweep_ai::load_sweep_api_token(cx);
         let mercury_api_token_task = edit_prediction::mercury::load_mercury_api_token(cx);
+        let open_ai_compatible_api_token_task =
+            edit_prediction::open_ai_compatible::load_open_ai_compatible_api_token(cx);
 
         cx.spawn(async move |this, cx| {
-            _ = futures::join!(sweep_api_token_task, mercury_api_token_task);
+            _ = futures::join!(
+                sweep_api_token_task,
+                mercury_api_token_task,
+                open_ai_compatible_api_token_task
+            );
             this.update(cx, |_, cx| {
                 cx.notify();
             })

crates/settings_ui/src/pages/edit_prediction_provider_setup.rs 🔗

@@ -2,6 +2,7 @@ use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
 use edit_prediction::{
     ApiKeyState,
     mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
+    open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
     sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
 };
 use edit_prediction_ui::{get_available_providers, set_completion_provider};
@@ -33,7 +34,9 @@ pub(crate) fn render_edit_prediction_setup_page(
             render_api_key_provider(
                 IconName::Inception,
                 "Mercury",
-                "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
+                ApiKeyDocs::Link {
+                    dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
+                },
                 mercury_api_token(cx),
                 |_cx| MERCURY_CREDENTIALS_URL,
                 None,
@@ -46,7 +49,9 @@ pub(crate) fn render_edit_prediction_setup_page(
             render_api_key_provider(
                 IconName::SweepAi,
                 "Sweep",
-                "https://app.sweep.dev/".into(),
+                ApiKeyDocs::Link {
+                    dashboard_url: "https://app.sweep.dev/".into(),
+                },
                 sweep_api_token(cx),
                 |_cx| SWEEP_CREDENTIALS_URL,
                 Some(
@@ -68,7 +73,9 @@ pub(crate) fn render_edit_prediction_setup_page(
             render_api_key_provider(
                 IconName::AiMistral,
                 "Codestral",
-                "https://console.mistral.ai/codestral".into(),
+                ApiKeyDocs::Link {
+                    dashboard_url: "https://console.mistral.ai/codestral".into(),
+                },
                 codestral_api_key_state(cx),
                 |cx| codestral_api_url(cx),
                 Some(
@@ -87,7 +94,31 @@ pub(crate) fn render_edit_prediction_setup_page(
             .into_any_element(),
         ),
         Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
-        Some(render_open_ai_compatible_provider(settings_window, window, cx).into_any_element()),
+        Some(
+            render_api_key_provider(
+                IconName::AiOpenAiCompat,
+                "OpenAI Compatible API",
+                ApiKeyDocs::Custom {
+                    message: "Set an API key here. It will be sent as Authorization: Bearer {key}."
+                        .into(),
+                },
+                open_ai_compatible_api_token(cx),
+                |cx| open_ai_compatible_api_url(cx),
+                Some(
+                    settings_window
+                        .render_sub_page_items_section(
+                            open_ai_compatible_settings().iter().enumerate(),
+                            true,
+                            window,
+                            cx,
+                        )
+                        .into_any_element(),
+                ),
+                window,
+                cx,
+            )
+            .into_any_element(),
+        ),
     ];
 
     div()
@@ -162,10 +193,15 @@ fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
         .into_any_element()
 }
 
+enum ApiKeyDocs {
+    Link { dashboard_url: SharedString },
+    Custom { message: SharedString },
+}
+
 fn render_api_key_provider(
     icon: IconName,
     title: &'static str,
-    link: SharedString,
+    docs: ApiKeyDocs,
     api_key_state: Entity<ApiKeyState>,
     current_url: fn(&mut App) -> SharedString,
     additional_fields: Option<AnyElement>,
@@ -209,25 +245,32 @@ fn render_api_key_provider(
         .icon(icon)
         .no_padding(true);
     let button_link_label = format!("{} dashboard", title);
-    let description = h_flex()
-        .min_w_0()
-        .gap_0p5()
-        .child(
-            Label::new("Visit the")
+    let description = match docs {
+        ApiKeyDocs::Custom { message } => h_flex().min_w_0().gap_0p5().child(
+            Label::new(message)
                 .size(LabelSize::Small)
                 .color(Color::Muted),
-        )
-        .child(
-            ButtonLink::new(button_link_label, link)
-                .no_icon(true)
-                .label_size(LabelSize::Small)
-                .label_color(Color::Muted),
-        )
-        .child(
-            Label::new("to generate an API key.")
-                .size(LabelSize::Small)
-                .color(Color::Muted),
-        );
+        ),
+        ApiKeyDocs::Link { dashboard_url } => h_flex()
+            .min_w_0()
+            .gap_0p5()
+            .child(
+                Label::new("Visit the")
+                    .size(LabelSize::Small)
+                    .color(Color::Muted),
+            )
+            .child(
+                ButtonLink::new(button_link_label, dashboard_url)
+                    .no_icon(true)
+                    .label_size(LabelSize::Small)
+                    .label_color(Color::Muted),
+            )
+            .child(
+                Label::new("to generate an API key.")
+                    .size(LabelSize::Small)
+                    .color(Color::Muted),
+            ),
+    };
     let configured_card_label = if is_from_env_var {
         "API Key Set in Environment Variable"
     } else {
@@ -484,34 +527,6 @@ fn ollama_settings() -> Box<[SettingsPageItem]> {
     ])
 }
 
-fn render_open_ai_compatible_provider(
-    settings_window: &SettingsWindow,
-    window: &mut Window,
-    cx: &mut Context<SettingsWindow>,
-) -> impl IntoElement {
-    let open_ai_compatible_settings = open_ai_compatible_settings();
-    let additional_fields = settings_window
-        .render_sub_page_items_section(
-            open_ai_compatible_settings.iter().enumerate(),
-            true,
-            window,
-            cx,
-        )
-        .into_any_element();
-
-    v_flex()
-        .id("open-ai-compatible")
-        .min_w_0()
-        .pt_8()
-        .gap_1p5()
-        .child(
-            SettingsSectionHeader::new("OpenAI Compatible API")
-                .icon(IconName::AiOpenAiCompat)
-                .no_padding(true),
-        )
-        .child(div().px_neg_8().child(additional_fields))
-}
-
 fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
     Box::new([
         SettingsPageItem::SettingItem(SettingItem {

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -154,7 +154,10 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option<EditPredicti
                 }
             }
 
-            if format == EditPredictionPromptFormat::Zeta {
+            if matches!(
+                format,
+                EditPredictionPromptFormat::Zeta | EditPredictionPromptFormat::Zeta2
+            ) {
                 Some(EditPredictionProviderConfig::Zed(EditPredictionModel::Zeta))
             } else {
                 Some(EditPredictionProviderConfig::Zed(