diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index e6e3a9abdf83deb785cd56d358b065973682b8cc..74988d65933b3bbbc2507077a74dfeb94089ab63 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -69,6 +69,7 @@ pub mod sweep_ai; pub mod udiff; mod capture_example; +pub mod open_ai_compatible; mod zed_edit_prediction_delegate; pub mod zeta; diff --git a/crates/edit_prediction/src/fim.rs b/crates/edit_prediction/src/fim.rs index 66f2e58a3b01b4fbf49b11864db4daec6b4dc1c2..d3e18f73acc665eec28d725530d11297cf4d69ea 100644 --- a/crates/edit_prediction/src/fim.rs +++ b/crates/edit_prediction/src/fim.rs @@ -1,6 +1,7 @@ use crate::{ - EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult, - zeta, + EditPredictionId, EditPredictionModelInput, cursor_excerpt, + open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed}, + prediction::EditPredictionResult, }; use anyhow::{Context as _, Result, anyhow}; use gpui::{App, AppContext as _, Entity, Task}; @@ -58,6 +59,8 @@ pub fn request_prediction( return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM"))); }; + let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx); + let result = cx.background_spawn(async move { let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position( cursor_point, @@ -90,12 +93,14 @@ pub fn request_prediction( let stop_tokens = get_fim_stop_tokens(); let max_tokens = settings.max_output_tokens; - let (response_text, request_id) = zeta::send_custom_server_request( + + let (response_text, request_id) = open_ai_compatible::send_custom_server_request( provider, &settings, prompt, max_tokens, stop_tokens, + api_key, &http_client, ) .await?; diff --git a/crates/edit_prediction/src/open_ai_compatible.rs b/crates/edit_prediction/src/open_ai_compatible.rs new file mode 100644 index 0000000000000000000000000000000000000000..ca378ba1fd0bc9bdbb3e85c7610e1b94c1be388f --- /dev/null +++ b/crates/edit_prediction/src/open_ai_compatible.rs @@ -0,0 +1,133 @@ +use anyhow::{Context as _, Result}; +use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse}; +use futures::AsyncReadExt as _; +use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client}; +use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings}; +use language_model::{ApiKeyState, EnvVar, env_var}; +use std::sync::Arc; + +pub fn open_ai_compatible_api_url(cx: &App) -> SharedString { + all_language_settings(None, cx) + .edit_predictions + .open_ai_compatible_api + .as_ref() + .map(|settings| settings.api_url.clone()) + .unwrap_or_default() + .into() +} + +pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token"; +pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock = + env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY"); + +struct GlobalOpenAiCompatibleApiKey(Entity); + +impl Global for GlobalOpenAiCompatibleApiKey {} + +pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity { + if let Some(global) = cx.try_global::() { + return global.0.clone(); + } + + let entity = cx.new(|cx| { + ApiKeyState::new( + open_ai_compatible_api_url(cx), + OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(), + ) + }); + cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone())); + entity +} + +pub fn load_open_ai_compatible_api_token( + cx: &mut App, +) -> Task> { + let api_url = open_ai_compatible_api_url(cx); + open_ai_compatible_api_token(cx).update(cx, |key_state, cx| { + key_state.load_if_needed(api_url, |s| s, cx) + }) +} + +pub fn load_open_ai_compatible_api_key_if_needed( + provider: settings::EditPredictionProvider, + cx: &mut App, +) -> Option> { + if provider != settings::EditPredictionProvider::OpenAiCompatibleApi { + return None; + } + _ = load_open_ai_compatible_api_token(cx); + let url = open_ai_compatible_api_url(cx); + return open_ai_compatible_api_token(cx).read(cx).key(&url); +} + +pub(crate) async fn send_custom_server_request( + provider: settings::EditPredictionProvider, + settings: &OpenAiCompatibleEditPredictionSettings, + prompt: String, + max_tokens: u32, + stop_tokens: Vec, + api_key: Option>, + http_client: &Arc, +) -> Result<(String, String)> { + match provider { + settings::EditPredictionProvider::Ollama => { + let response = crate::ollama::make_request( + settings.clone(), + prompt, + stop_tokens, + http_client.clone(), + ) + .await?; + Ok((response.response, response.created_at)) + } + _ => { + let request = RawCompletionRequest { + model: settings.model.clone(), + prompt, + max_tokens: Some(max_tokens), + temperature: None, + stop: stop_tokens + .into_iter() + .map(std::borrow::Cow::Owned) + .collect(), + environment: None, + }; + + let request_body = serde_json::to_string(&request)?; + let mut http_request_builder = http_client::Request::builder() + .method(http_client::Method::POST) + .uri(settings.api_url.as_ref()) + .header("Content-Type", "application/json"); + + if let Some(api_key) = api_key { + http_request_builder = + http_request_builder.header("Authorization", format!("Bearer {}", api_key)); + } + + let http_request = + http_request_builder.body(http_client::AsyncBody::from(request_body))?; + + let mut response = http_client.send(http_request).await?; + let status = response.status(); + + if !status.is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!("custom server error: {} - {}", status, body); + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let parsed: RawCompletionResponse = + serde_json::from_str(&body).context("Failed to parse completion response")?; + let text = parsed + .choices + .into_iter() + .next() + .map(|choice| choice.text) + .unwrap_or_default(); + Ok((text, parsed.id)) + } + } +} diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index f6a786572736908556535b9131c1cf7814a6126f..789ff6c0d7fcc269baf30b5e0fb0e849bc865859 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -2,15 +2,14 @@ use crate::cursor_excerpt::compute_excerpt_ranges; use crate::prediction::EditPredictionResult; use crate::{ CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, - EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama, + EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, }; -use anyhow::{Context as _, Result}; -use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse}; +use anyhow::Result; +use cloud_llm_client::predict_edits_v3::RawCompletionRequest; use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason}; use edit_prediction_types::PredictedCursorPosition; -use futures::AsyncReadExt as _; -use gpui::{App, AppContext as _, Task, http_client, prelude::*}; -use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings}; +use gpui::{App, AppContext as _, Task, prelude::*}; +use language::language_settings::all_language_settings; use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff}; use release_channel::AppVersion; use settings::EditPredictionPromptFormat; @@ -25,6 +24,10 @@ use zeta_prompt::{ zeta1::{self, EDITABLE_REGION_END_MARKER}, }; +use crate::open_ai_compatible::{ + load_open_ai_compatible_api_key_if_needed, send_custom_server_request, +}; + pub fn request_prediction_with_zeta( store: &mut EditPredictionStore, EditPredictionModelInput { @@ -56,6 +59,7 @@ pub fn request_prediction_with_zeta( let buffer_snapshotted_at = Instant::now(); let raw_config = store.zeta2_raw_config().cloned(); let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned()); + let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx); let excerpt_path: Arc = snapshot .file() @@ -131,6 +135,7 @@ pub fn request_prediction_with_zeta( prompt, max_tokens, stop_tokens, + open_ai_compatible_api_key.clone(), &http_client, ) .await?; @@ -157,6 +162,7 @@ pub fn request_prediction_with_zeta( prompt, max_tokens, vec![], + open_ai_compatible_api_key.clone(), &http_client, ) .await?; @@ -400,66 +406,6 @@ pub fn zeta2_prompt_input( (full_context_offset_range, prompt_input) } -pub(crate) async fn send_custom_server_request( - provider: settings::EditPredictionProvider, - settings: &OpenAiCompatibleEditPredictionSettings, - prompt: String, - max_tokens: u32, - stop_tokens: Vec, - http_client: &Arc, -) -> Result<(String, String)> { - match provider { - settings::EditPredictionProvider::Ollama => { - let response = - ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone()) - .await?; - Ok((response.response, response.created_at)) - } - _ => { - let request = RawCompletionRequest { - model: settings.model.clone(), - prompt, - max_tokens: Some(max_tokens), - temperature: None, - stop: stop_tokens - .into_iter() - .map(std::borrow::Cow::Owned) - .collect(), - environment: None, - }; - - let request_body = serde_json::to_string(&request)?; - let http_request = http_client::Request::builder() - .method(http_client::Method::POST) - .uri(settings.api_url.as_ref()) - .header("Content-Type", "application/json") - .body(http_client::AsyncBody::from(request_body))?; - - let mut response = http_client.send(http_request).await?; - let status = response.status(); - - if !status.is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!("custom server error: {} - {}", status, body); - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - let parsed: RawCompletionResponse = - serde_json::from_str(&body).context("Failed to parse completion response")?; - let text = parsed - .choices - .into_iter() - .next() - .map(|choice| choice.text) - .unwrap_or_default(); - Ok((text, parsed.id)) - } - } -} - pub(crate) fn edit_prediction_accepted( store: &EditPredictionStore, current_prediction: CurrentEditPrediction, diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index 6339c7d6cd9fa1cc40101cc1bf14650a6904b3c7..743256970f486b474405e7f034f18501505cb825 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -539,9 +539,15 @@ impl EditPredictionButton { edit_prediction::ollama::ensure_authenticated(cx); let sweep_api_token_task = edit_prediction::sweep_ai::load_sweep_api_token(cx); let mercury_api_token_task = edit_prediction::mercury::load_mercury_api_token(cx); + let open_ai_compatible_api_token_task = + edit_prediction::open_ai_compatible::load_open_ai_compatible_api_token(cx); cx.spawn(async move |this, cx| { - _ = futures::join!(sweep_api_token_task, mercury_api_token_task); + _ = futures::join!( + sweep_api_token_task, + mercury_api_token_task, + open_ai_compatible_api_token_task + ); this.update(cx, |_, cx| { cx.notify(); }) diff --git a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs index 338fe4de14f1f7e9060fafe865253f09f0bdc481..32c4bee84bd1f72263ed28bcd44d7e6349c4b24c 100644 --- a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs +++ b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs @@ -2,6 +2,7 @@ use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url}; use edit_prediction::{ ApiKeyState, mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token}, + open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url}, sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token}, }; use edit_prediction_ui::{get_available_providers, set_completion_provider}; @@ -33,7 +34,9 @@ pub(crate) fn render_edit_prediction_setup_page( render_api_key_provider( IconName::Inception, "Mercury", - "https://platform.inceptionlabs.ai/dashboard/api-keys".into(), + ApiKeyDocs::Link { + dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(), + }, mercury_api_token(cx), |_cx| MERCURY_CREDENTIALS_URL, None, @@ -46,7 +49,9 @@ pub(crate) fn render_edit_prediction_setup_page( render_api_key_provider( IconName::SweepAi, "Sweep", - "https://app.sweep.dev/".into(), + ApiKeyDocs::Link { + dashboard_url: "https://app.sweep.dev/".into(), + }, sweep_api_token(cx), |_cx| SWEEP_CREDENTIALS_URL, Some( @@ -68,7 +73,9 @@ pub(crate) fn render_edit_prediction_setup_page( render_api_key_provider( IconName::AiMistral, "Codestral", - "https://console.mistral.ai/codestral".into(), + ApiKeyDocs::Link { + dashboard_url: "https://console.mistral.ai/codestral".into(), + }, codestral_api_key_state(cx), |cx| codestral_api_url(cx), Some( @@ -87,7 +94,31 @@ pub(crate) fn render_edit_prediction_setup_page( .into_any_element(), ), Some(render_ollama_provider(settings_window, window, cx).into_any_element()), - Some(render_open_ai_compatible_provider(settings_window, window, cx).into_any_element()), + Some( + render_api_key_provider( + IconName::AiOpenAiCompat, + "OpenAI Compatible API", + ApiKeyDocs::Custom { + message: "Set an API key here. It will be sent as Authorization: Bearer {key}." + .into(), + }, + open_ai_compatible_api_token(cx), + |cx| open_ai_compatible_api_url(cx), + Some( + settings_window + .render_sub_page_items_section( + open_ai_compatible_settings().iter().enumerate(), + true, + window, + cx, + ) + .into_any_element(), + ), + window, + cx, + ) + .into_any_element(), + ), ]; div() @@ -162,10 +193,15 @@ fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement { .into_any_element() } +enum ApiKeyDocs { + Link { dashboard_url: SharedString }, + Custom { message: SharedString }, +} + fn render_api_key_provider( icon: IconName, title: &'static str, - link: SharedString, + docs: ApiKeyDocs, api_key_state: Entity, current_url: fn(&mut App) -> SharedString, additional_fields: Option, @@ -209,25 +245,32 @@ fn render_api_key_provider( .icon(icon) .no_padding(true); let button_link_label = format!("{} dashboard", title); - let description = h_flex() - .min_w_0() - .gap_0p5() - .child( - Label::new("Visit the") + let description = match docs { + ApiKeyDocs::Custom { message } => h_flex().min_w_0().gap_0p5().child( + Label::new(message) .size(LabelSize::Small) .color(Color::Muted), - ) - .child( - ButtonLink::new(button_link_label, link) - .no_icon(true) - .label_size(LabelSize::Small) - .label_color(Color::Muted), - ) - .child( - Label::new("to generate an API key.") - .size(LabelSize::Small) - .color(Color::Muted), - ); + ), + ApiKeyDocs::Link { dashboard_url } => h_flex() + .min_w_0() + .gap_0p5() + .child( + Label::new("Visit the") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + ButtonLink::new(button_link_label, dashboard_url) + .no_icon(true) + .label_size(LabelSize::Small) + .label_color(Color::Muted), + ) + .child( + Label::new("to generate an API key.") + .size(LabelSize::Small) + .color(Color::Muted), + ), + }; let configured_card_label = if is_from_env_var { "API Key Set in Environment Variable" } else { @@ -484,34 +527,6 @@ fn ollama_settings() -> Box<[SettingsPageItem]> { ]) } -fn render_open_ai_compatible_provider( - settings_window: &SettingsWindow, - window: &mut Window, - cx: &mut Context, -) -> impl IntoElement { - let open_ai_compatible_settings = open_ai_compatible_settings(); - let additional_fields = settings_window - .render_sub_page_items_section( - open_ai_compatible_settings.iter().enumerate(), - true, - window, - cx, - ) - .into_any_element(); - - v_flex() - .id("open-ai-compatible") - .min_w_0() - .pt_8() - .gap_1p5() - .child( - SettingsSectionHeader::new("OpenAI Compatible API") - .icon(IconName::AiOpenAiCompat) - .no_padding(true), - ) - .child(div().px_neg_8().child(additional_fields)) -} - fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> { Box::new([ SettingsPageItem::SettingItem(SettingItem { diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 67b0d26c88cf0bd254a776834de09fb89d6ea195..39eee233e02a782e2379849247448c8f8c1ea71a 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -154,7 +154,10 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option