crates/edit_prediction/src/edit_prediction.rs 🔗
@@ -69,6 +69,7 @@ pub mod sweep_ai;
pub mod udiff;
mod capture_example;
+pub mod open_ai_compatible;
mod zed_edit_prediction_delegate;
pub mod zeta;
Ben Kunkle created
Closes #ISSUE
Before you mark this PR as ready for review, make sure that you have:
- [ ] Added solid test coverage and/or screenshots from manual
testing
- [ ] Done a self-review taking into account security and performance
aspects
- [ ] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
Release Notes:
- Added support for providing an API key to OpenAI-compatible edit
prediction providers
crates/edit_prediction/src/edit_prediction.rs | 1
crates/edit_prediction/src/fim.rs | 11
crates/edit_prediction/src/open_ai_compatible.rs | 133 ++++
crates/edit_prediction/src/zeta.rs | 78 -
crates/edit_prediction_ui/src/edit_prediction_button.rs | 8
crates/settings_ui/src/pages/edit_prediction_provider_setup.rs | 115 +-
crates/zed/src/zed/edit_prediction_registry.rs | 5
7 files changed, 230 insertions(+), 121 deletions(-)
@@ -69,6 +69,7 @@ pub mod sweep_ai;
pub mod udiff;
mod capture_example;
+pub mod open_ai_compatible;
mod zed_edit_prediction_delegate;
pub mod zeta;
@@ -1,6 +1,7 @@
use crate::{
- EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult,
- zeta,
+ EditPredictionId, EditPredictionModelInput, cursor_excerpt,
+ open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
+ prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result, anyhow};
use gpui::{App, AppContext as _, Entity, Task};
@@ -58,6 +59,8 @@ pub fn request_prediction(
return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
};
+ let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
+
let result = cx.background_spawn(async move {
let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
@@ -90,12 +93,14 @@ pub fn request_prediction(
let stop_tokens = get_fim_stop_tokens();
let max_tokens = settings.max_output_tokens;
- let (response_text, request_id) = zeta::send_custom_server_request(
+
+ let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
provider,
&settings,
prompt,
max_tokens,
stop_tokens,
+ api_key,
&http_client,
)
.await?;
@@ -0,0 +1,133 @@
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
+use futures::AsyncReadExt as _;
+use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client};
+use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
+use language_model::{ApiKeyState, EnvVar, env_var};
+use std::sync::Arc;
+
+pub fn open_ai_compatible_api_url(cx: &App) -> SharedString {
+ all_language_settings(None, cx)
+ .edit_predictions
+ .open_ai_compatible_api
+ .as_ref()
+ .map(|settings| settings.api_url.clone())
+ .unwrap_or_default()
+ .into()
+}
+
+pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token";
+pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> =
+ env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY");
+
+struct GlobalOpenAiCompatibleApiKey(Entity<ApiKeyState>);
+
+impl Global for GlobalOpenAiCompatibleApiKey {}
+
+pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
+ if let Some(global) = cx.try_global::<GlobalOpenAiCompatibleApiKey>() {
+ return global.0.clone();
+ }
+
+ let entity = cx.new(|cx| {
+ ApiKeyState::new(
+ open_ai_compatible_api_url(cx),
+ OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(),
+ )
+ });
+ cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone()));
+ entity
+}
+
+pub fn load_open_ai_compatible_api_token(
+ cx: &mut App,
+) -> Task<Result<(), language_model::AuthenticateError>> {
+ let api_url = open_ai_compatible_api_url(cx);
+ open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
+ key_state.load_if_needed(api_url, |s| s, cx)
+ })
+}
+
+pub fn load_open_ai_compatible_api_key_if_needed(
+ provider: settings::EditPredictionProvider,
+ cx: &mut App,
+) -> Option<Arc<str>> {
+ if provider != settings::EditPredictionProvider::OpenAiCompatibleApi {
+ return None;
+ }
+ _ = load_open_ai_compatible_api_token(cx);
+ let url = open_ai_compatible_api_url(cx);
+ return open_ai_compatible_api_token(cx).read(cx).key(&url);
+}
+
+pub(crate) async fn send_custom_server_request(
+ provider: settings::EditPredictionProvider,
+ settings: &OpenAiCompatibleEditPredictionSettings,
+ prompt: String,
+ max_tokens: u32,
+ stop_tokens: Vec<String>,
+ api_key: Option<Arc<str>>,
+ http_client: &Arc<dyn http_client::HttpClient>,
+) -> Result<(String, String)> {
+ match provider {
+ settings::EditPredictionProvider::Ollama => {
+ let response = crate::ollama::make_request(
+ settings.clone(),
+ prompt,
+ stop_tokens,
+ http_client.clone(),
+ )
+ .await?;
+ Ok((response.response, response.created_at))
+ }
+ _ => {
+ let request = RawCompletionRequest {
+ model: settings.model.clone(),
+ prompt,
+ max_tokens: Some(max_tokens),
+ temperature: None,
+ stop: stop_tokens
+ .into_iter()
+ .map(std::borrow::Cow::Owned)
+ .collect(),
+ environment: None,
+ };
+
+ let request_body = serde_json::to_string(&request)?;
+ let mut http_request_builder = http_client::Request::builder()
+ .method(http_client::Method::POST)
+ .uri(settings.api_url.as_ref())
+ .header("Content-Type", "application/json");
+
+ if let Some(api_key) = api_key {
+ http_request_builder =
+ http_request_builder.header("Authorization", format!("Bearer {}", api_key));
+ }
+
+ let http_request =
+ http_request_builder.body(http_client::AsyncBody::from(request_body))?;
+
+ let mut response = http_client.send(http_request).await?;
+ let status = response.status();
+
+ if !status.is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ anyhow::bail!("custom server error: {} - {}", status, body);
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let parsed: RawCompletionResponse =
+ serde_json::from_str(&body).context("Failed to parse completion response")?;
+ let text = parsed
+ .choices
+ .into_iter()
+ .next()
+ .map(|choice| choice.text)
+ .unwrap_or_default();
+ Ok((text, parsed.id))
+ }
+ }
+}
@@ -2,15 +2,14 @@ use crate::cursor_excerpt::compute_excerpt_ranges;
use crate::prediction::EditPredictionResult;
use crate::{
CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
- EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
+ EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
-use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
+use anyhow::Result;
+use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
use edit_prediction_types::PredictedCursorPosition;
-use futures::AsyncReadExt as _;
-use gpui::{App, AppContext as _, Task, http_client, prelude::*};
-use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
+use gpui::{App, AppContext as _, Task, prelude::*};
+use language::language_settings::all_language_settings;
use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
use release_channel::AppVersion;
use settings::EditPredictionPromptFormat;
@@ -25,6 +24,10 @@ use zeta_prompt::{
zeta1::{self, EDITABLE_REGION_END_MARKER},
};
+use crate::open_ai_compatible::{
+ load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
+};
+
pub fn request_prediction_with_zeta(
store: &mut EditPredictionStore,
EditPredictionModelInput {
@@ -56,6 +59,7 @@ pub fn request_prediction_with_zeta(
let buffer_snapshotted_at = Instant::now();
let raw_config = store.zeta2_raw_config().cloned();
let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
+ let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
let excerpt_path: Arc<Path> = snapshot
.file()
@@ -131,6 +135,7 @@ pub fn request_prediction_with_zeta(
prompt,
max_tokens,
stop_tokens,
+ open_ai_compatible_api_key.clone(),
&http_client,
)
.await?;
@@ -157,6 +162,7 @@ pub fn request_prediction_with_zeta(
prompt,
max_tokens,
vec![],
+ open_ai_compatible_api_key.clone(),
&http_client,
)
.await?;
@@ -400,66 +406,6 @@ pub fn zeta2_prompt_input(
(full_context_offset_range, prompt_input)
}
-pub(crate) async fn send_custom_server_request(
- provider: settings::EditPredictionProvider,
- settings: &OpenAiCompatibleEditPredictionSettings,
- prompt: String,
- max_tokens: u32,
- stop_tokens: Vec<String>,
- http_client: &Arc<dyn http_client::HttpClient>,
-) -> Result<(String, String)> {
- match provider {
- settings::EditPredictionProvider::Ollama => {
- let response =
- ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
- .await?;
- Ok((response.response, response.created_at))
- }
- _ => {
- let request = RawCompletionRequest {
- model: settings.model.clone(),
- prompt,
- max_tokens: Some(max_tokens),
- temperature: None,
- stop: stop_tokens
- .into_iter()
- .map(std::borrow::Cow::Owned)
- .collect(),
- environment: None,
- };
-
- let request_body = serde_json::to_string(&request)?;
- let http_request = http_client::Request::builder()
- .method(http_client::Method::POST)
- .uri(settings.api_url.as_ref())
- .header("Content-Type", "application/json")
- .body(http_client::AsyncBody::from(request_body))?;
-
- let mut response = http_client.send(http_request).await?;
- let status = response.status();
-
- if !status.is_success() {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!("custom server error: {} - {}", status, body);
- }
-
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- let parsed: RawCompletionResponse =
- serde_json::from_str(&body).context("Failed to parse completion response")?;
- let text = parsed
- .choices
- .into_iter()
- .next()
- .map(|choice| choice.text)
- .unwrap_or_default();
- Ok((text, parsed.id))
- }
- }
-}
-
pub(crate) fn edit_prediction_accepted(
store: &EditPredictionStore,
current_prediction: CurrentEditPrediction,
@@ -539,9 +539,15 @@ impl EditPredictionButton {
edit_prediction::ollama::ensure_authenticated(cx);
let sweep_api_token_task = edit_prediction::sweep_ai::load_sweep_api_token(cx);
let mercury_api_token_task = edit_prediction::mercury::load_mercury_api_token(cx);
+ let open_ai_compatible_api_token_task =
+ edit_prediction::open_ai_compatible::load_open_ai_compatible_api_token(cx);
cx.spawn(async move |this, cx| {
- _ = futures::join!(sweep_api_token_task, mercury_api_token_task);
+ _ = futures::join!(
+ sweep_api_token_task,
+ mercury_api_token_task,
+ open_ai_compatible_api_token_task
+ );
this.update(cx, |_, cx| {
cx.notify();
})
@@ -2,6 +2,7 @@ use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
use edit_prediction::{
ApiKeyState,
mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
+ open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
};
use edit_prediction_ui::{get_available_providers, set_completion_provider};
@@ -33,7 +34,9 @@ pub(crate) fn render_edit_prediction_setup_page(
render_api_key_provider(
IconName::Inception,
"Mercury",
- "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
+ ApiKeyDocs::Link {
+ dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
+ },
mercury_api_token(cx),
|_cx| MERCURY_CREDENTIALS_URL,
None,
@@ -46,7 +49,9 @@ pub(crate) fn render_edit_prediction_setup_page(
render_api_key_provider(
IconName::SweepAi,
"Sweep",
- "https://app.sweep.dev/".into(),
+ ApiKeyDocs::Link {
+ dashboard_url: "https://app.sweep.dev/".into(),
+ },
sweep_api_token(cx),
|_cx| SWEEP_CREDENTIALS_URL,
Some(
@@ -68,7 +73,9 @@ pub(crate) fn render_edit_prediction_setup_page(
render_api_key_provider(
IconName::AiMistral,
"Codestral",
- "https://console.mistral.ai/codestral".into(),
+ ApiKeyDocs::Link {
+ dashboard_url: "https://console.mistral.ai/codestral".into(),
+ },
codestral_api_key_state(cx),
|cx| codestral_api_url(cx),
Some(
@@ -87,7 +94,31 @@ pub(crate) fn render_edit_prediction_setup_page(
.into_any_element(),
),
Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
- Some(render_open_ai_compatible_provider(settings_window, window, cx).into_any_element()),
+ Some(
+ render_api_key_provider(
+ IconName::AiOpenAiCompat,
+ "OpenAI Compatible API",
+ ApiKeyDocs::Custom {
+ message: "Set an API key here. It will be sent as Authorization: Bearer {key}."
+ .into(),
+ },
+ open_ai_compatible_api_token(cx),
+ |cx| open_ai_compatible_api_url(cx),
+ Some(
+ settings_window
+ .render_sub_page_items_section(
+ open_ai_compatible_settings().iter().enumerate(),
+ true,
+ window,
+ cx,
+ )
+ .into_any_element(),
+ ),
+ window,
+ cx,
+ )
+ .into_any_element(),
+ ),
];
div()
@@ -162,10 +193,15 @@ fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
.into_any_element()
}
+enum ApiKeyDocs {
+ Link { dashboard_url: SharedString },
+ Custom { message: SharedString },
+}
+
fn render_api_key_provider(
icon: IconName,
title: &'static str,
- link: SharedString,
+ docs: ApiKeyDocs,
api_key_state: Entity<ApiKeyState>,
current_url: fn(&mut App) -> SharedString,
additional_fields: Option<AnyElement>,
@@ -209,25 +245,32 @@ fn render_api_key_provider(
.icon(icon)
.no_padding(true);
let button_link_label = format!("{} dashboard", title);
- let description = h_flex()
- .min_w_0()
- .gap_0p5()
- .child(
- Label::new("Visit the")
+ let description = match docs {
+ ApiKeyDocs::Custom { message } => h_flex().min_w_0().gap_0p5().child(
+ Label::new(message)
.size(LabelSize::Small)
.color(Color::Muted),
- )
- .child(
- ButtonLink::new(button_link_label, link)
- .no_icon(true)
- .label_size(LabelSize::Small)
- .label_color(Color::Muted),
- )
- .child(
- Label::new("to generate an API key.")
- .size(LabelSize::Small)
- .color(Color::Muted),
- );
+ ),
+ ApiKeyDocs::Link { dashboard_url } => h_flex()
+ .min_w_0()
+ .gap_0p5()
+ .child(
+ Label::new("Visit the")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ ButtonLink::new(button_link_label, dashboard_url)
+ .no_icon(true)
+ .label_size(LabelSize::Small)
+ .label_color(Color::Muted),
+ )
+ .child(
+ Label::new("to generate an API key.")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ };
let configured_card_label = if is_from_env_var {
"API Key Set in Environment Variable"
} else {
@@ -484,34 +527,6 @@ fn ollama_settings() -> Box<[SettingsPageItem]> {
])
}
-fn render_open_ai_compatible_provider(
- settings_window: &SettingsWindow,
- window: &mut Window,
- cx: &mut Context<SettingsWindow>,
-) -> impl IntoElement {
- let open_ai_compatible_settings = open_ai_compatible_settings();
- let additional_fields = settings_window
- .render_sub_page_items_section(
- open_ai_compatible_settings.iter().enumerate(),
- true,
- window,
- cx,
- )
- .into_any_element();
-
- v_flex()
- .id("open-ai-compatible")
- .min_w_0()
- .pt_8()
- .gap_1p5()
- .child(
- SettingsSectionHeader::new("OpenAI Compatible API")
- .icon(IconName::AiOpenAiCompat)
- .no_padding(true),
- )
- .child(div().px_neg_8().child(additional_fields))
-}
-
fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
Box::new([
SettingsPageItem::SettingItem(SettingItem {
@@ -154,7 +154,10 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option<EditPredicti
}
}
- if format == EditPredictionPromptFormat::Zeta {
+ if matches!(
+ format,
+ EditPredictionPromptFormat::Zeta | EditPredictionPromptFormat::Zeta2
+ ) {
Some(EditPredictionProviderConfig::Zed(EditPredictionModel::Zeta))
} else {
Some(EditPredictionProviderConfig::Zed(