From f08fd732a7ecbfe191563e2498a61a7ae75d5b05 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Sat, 6 Dec 2025 07:08:44 -0300 Subject: [PATCH] Add experimental mercury edit prediction provider (#44256) Release Notes: - N/A --------- Co-authored-by: Ben Kunkle Co-authored-by: Max Brunsfeld --- assets/icons/inception.svg | 11 + crates/edit_prediction/src/cursor_excerpt.rs | 78 ++++ crates/edit_prediction/src/edit_prediction.rs | 37 +- .../src/edit_prediction_tests.rs | 2 +- crates/edit_prediction/src/mercury.rs | 340 ++++++++++++++++++ .../edit_prediction/src/open_ai_response.rs | 31 ++ crates/edit_prediction/src/zeta1.rs | 178 ++++++++- .../src/zeta1/input_excerpt.rs | 231 ------------ crates/edit_prediction/src/zeta2.rs | 35 +- crates/edit_prediction_cli/src/predict.rs | 5 +- .../src/edit_prediction_button.rs | 112 +++++- .../src/edit_prediction_ui.rs | 4 +- ...s => external_provider_api_token_modal.rs} | 33 +- crates/icons/src/icons.rs | 15 +- .../language_models/src/provider/open_ai.rs | 2 +- crates/open_ai/src/open_ai.rs | 3 +- .../settings/src/settings_content/language.rs | 8 + .../zed/src/zed/edit_prediction_registry.rs | 5 + 18 files changed, 808 insertions(+), 322 deletions(-) create mode 100644 assets/icons/inception.svg create mode 100644 crates/edit_prediction/src/cursor_excerpt.rs create mode 100644 crates/edit_prediction/src/mercury.rs create mode 100644 crates/edit_prediction/src/open_ai_response.rs delete mode 100644 crates/edit_prediction/src/zeta1/input_excerpt.rs rename crates/edit_prediction_ui/src/{sweep_api_token_modal.rs => external_provider_api_token_modal.rs} (72%) diff --git a/assets/icons/inception.svg b/assets/icons/inception.svg new file mode 100644 index 0000000000000000000000000000000000000000..77a96c0b390ab9f2fe89143c2a89ba916000fabc --- /dev/null +++ b/assets/icons/inception.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/crates/edit_prediction/src/cursor_excerpt.rs b/crates/edit_prediction/src/cursor_excerpt.rs new file mode 100644 index 0000000000000000000000000000000000000000..1f2f1d32ebcb2eaa151433bd49d275e0e2a3b817 --- /dev/null +++ b/crates/edit_prediction/src/cursor_excerpt.rs @@ -0,0 +1,78 @@ +use language::{BufferSnapshot, Point}; +use std::ops::Range; + +pub fn editable_and_context_ranges_for_cursor_position( + position: Point, + snapshot: &BufferSnapshot, + editable_region_token_limit: usize, + context_token_limit: usize, +) -> (Range, Range) { + let mut scope_range = position..position; + let mut remaining_edit_tokens = editable_region_token_limit; + + while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { + let parent_tokens = guess_token_count(parent.byte_range().len()); + let parent_point_range = Point::new( + parent.start_position().row as u32, + parent.start_position().column as u32, + ) + ..Point::new( + parent.end_position().row as u32, + parent.end_position().column as u32, + ); + if parent_point_range == scope_range { + break; + } else if parent_tokens <= editable_region_token_limit { + scope_range = parent_point_range; + remaining_edit_tokens = editable_region_token_limit - parent_tokens; + } else { + break; + } + } + + let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens); + let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); + (editable_range, context_range) +} + +fn expand_range( + snapshot: &BufferSnapshot, + range: Range, + mut remaining_tokens: usize, +) -> Range { + let mut expanded_range = range; + expanded_range.start.column = 0; + expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + loop { + let mut expanded = false; + + if remaining_tokens > 0 && expanded_range.start.row > 0 { + expanded_range.start.row -= 1; + let line_tokens = + guess_token_count(snapshot.line_len(expanded_range.start.row) as usize); + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + + if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { + expanded_range.end.row += 1; + expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + let line_tokens = guess_token_count(expanded_range.end.column as usize); + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + + if !expanded { + break; + } + } + expanded_range +} + +/// Typical number of string bytes per token for the purposes of limiting model input. This is +/// intentionally low to err on the side of underestimating limits. +pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3; + +pub fn guess_token_count(bytes: usize) -> usize { + bytes / BYTES_PER_TOKEN_GUESS +} diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index ea8f0af7e16dedd30a86284af5386829053d7fab..141fff3063b83d7e0003fddd6b4eba2d213d5fd5 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -51,8 +51,11 @@ use thiserror::Error; use util::{RangeExt as _, ResultExt as _}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; +mod cursor_excerpt; mod license_detection; +pub mod mercury; mod onboarding_modal; +pub mod open_ai_response; mod prediction; pub mod sweep_ai; pub mod udiff; @@ -65,6 +68,7 @@ pub mod zeta2; mod edit_prediction_tests; use crate::license_detection::LicenseDetectionWatcher; +use crate::mercury::Mercury; use crate::onboarding_modal::ZedPredictModal; pub use crate::prediction::EditPrediction; pub use crate::prediction::EditPredictionId; @@ -96,6 +100,12 @@ impl FeatureFlag for SweepFeatureFlag { const NAME: &str = "sweep-ai"; } +pub struct MercuryFeatureFlag; + +impl FeatureFlag for MercuryFeatureFlag { + const NAME: &str = "mercury"; +} + pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { context: EditPredictionExcerptOptions { max_bytes: 512, @@ -157,6 +167,7 @@ pub struct EditPredictionStore { eval_cache: Option>, edit_prediction_model: EditPredictionModel, pub sweep_ai: SweepAi, + pub mercury: Mercury, data_collection_choice: DataCollectionChoice, reject_predictions_tx: mpsc::UnboundedSender, shown_predictions: VecDeque, @@ -169,6 +180,7 @@ pub enum EditPredictionModel { Zeta1, Zeta2, Sweep, + Mercury, } #[derive(Debug, Clone, PartialEq)] @@ -474,6 +486,7 @@ impl EditPredictionStore { eval_cache: None, edit_prediction_model: EditPredictionModel::Zeta2, sweep_ai: SweepAi::new(cx), + mercury: Mercury::new(cx), data_collection_choice, reject_predictions_tx: reject_tx, rated_predictions: Default::default(), @@ -509,6 +522,15 @@ impl EditPredictionStore { .is_some() } + pub fn has_mercury_api_token(&self) -> bool { + self.mercury + .api_token + .clone() + .now_or_never() + .flatten() + .is_some() + } + #[cfg(feature = "eval-support")] pub fn with_eval_cache(&mut self, cache: Arc) { self.eval_cache = Some(cache); @@ -868,7 +890,7 @@ impl EditPredictionStore { fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { match self.edit_prediction_model { EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} - EditPredictionModel::Sweep => return, + EditPredictionModel::Sweep | EditPredictionModel::Mercury => return, } let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { @@ -1013,7 +1035,7 @@ impl EditPredictionStore { ) { match self.edit_prediction_model { EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} - EditPredictionModel::Sweep => return, + EditPredictionModel::Sweep | EditPredictionModel::Mercury => return, } self.reject_predictions_tx @@ -1373,6 +1395,17 @@ impl EditPredictionStore { diagnostic_search_range.clone(), cx, ), + EditPredictionModel::Mercury => self.mercury.request_prediction( + &project, + &active_buffer, + snapshot.clone(), + position, + events, + &project_state.recent_paths, + related_files, + diagnostic_search_range.clone(), + cx, + ), }; cx.spawn(async move |this, cx| { diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 8d5bad9ed8990769fd512ecfe523cf8d79aebca6..0b7e289bb32b5a10c32a4bd34f118d7cb6c7d43c 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1620,7 +1620,7 @@ async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut Te buffer.edit( [( 0..0, - " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS), + " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS), )], None, cx, diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs new file mode 100644 index 0000000000000000000000000000000000000000..40c0fdfac021f937df5172fd423d3b6bfc5f8146 --- /dev/null +++ b/crates/edit_prediction/src/mercury.rs @@ -0,0 +1,340 @@ +use anyhow::{Context as _, Result}; +use cloud_llm_client::predict_edits_v3::Event; +use credentials_provider::CredentialsProvider; +use edit_prediction_context::RelatedFile; +use futures::{AsyncReadExt as _, FutureExt, future::Shared}; +use gpui::{ + App, AppContext as _, Entity, Task, + http_client::{self, AsyncBody, Method}, +}; +use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _}; +use project::{Project, ProjectPath}; +use std::{ + collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant, +}; + +use crate::{ + EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response, + prediction::EditPredictionResult, +}; + +const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"; +const MAX_CONTEXT_TOKENS: usize = 150; +const MAX_REWRITE_TOKENS: usize = 350; + +pub struct Mercury { + pub api_token: Shared>>, +} + +impl Mercury { + pub fn new(cx: &App) -> Self { + Mercury { + api_token: load_api_token(cx).shared(), + } + } + + pub fn set_api_token(&mut self, api_token: Option, cx: &mut App) -> Task> { + self.api_token = Task::ready(api_token.clone()).shared(); + store_api_token_in_keychain(api_token, cx) + } + + pub fn request_prediction( + &self, + _project: &Entity, + active_buffer: &Entity, + snapshot: BufferSnapshot, + position: language::Anchor, + events: Vec>, + _recent_paths: &VecDeque, + related_files: Vec, + _diagnostic_search_range: Range, + cx: &mut App, + ) -> Task>> { + let Some(api_token) = self.api_token.clone().now_or_never().flatten() else { + return Task::ready(Ok(None)); + }; + let full_path: Arc = snapshot + .file() + .map(|file| file.full_path(cx)) + .unwrap_or_else(|| "untitled".into()) + .into(); + + let http_client = cx.http_client(); + let cursor_point = position.to_point(&snapshot); + let buffer_snapshotted_at = Instant::now(); + + let result = cx.background_spawn(async move { + let (editable_range, context_range) = + crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + MAX_CONTEXT_TOKENS, + MAX_REWRITE_TOKENS, + ); + + let offset_range = editable_range.to_offset(&snapshot); + let prompt = build_prompt( + &events, + &related_files, + &snapshot, + full_path.as_ref(), + cursor_point, + editable_range, + context_range.clone(), + ); + + let inputs = EditPredictionInputs { + events: events, + included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { + path: full_path.clone(), + max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), + excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { + start_line: cloud_llm_client::predict_edits_v3::Line( + context_range.start.row, + ), + text: snapshot + .text_for_range(context_range.clone()) + .collect::() + .into(), + }], + }], + cursor_point: cloud_llm_client::predict_edits_v3::Point { + column: cursor_point.column, + line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row), + }, + cursor_path: full_path.clone(), + }; + + let request_body = open_ai::Request { + model: "mercury-coder".into(), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: vec![], + temperature: None, + tool_choice: None, + parallel_tool_calls: None, + tools: vec![], + prompt_cache_key: None, + reasoning_effort: None, + }; + + let buf = serde_json::to_vec(&request_body)?; + let body: AsyncBody = buf.into(); + + let request = http_client::Request::builder() + .uri(MERCURY_API_URL) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_token)) + .header("Connection", "keep-alive") + .method(Method::POST) + .body(body) + .context("Failed to create request")?; + + let mut response = http_client + .send(request) + .await + .context("Failed to send request")?; + + let mut body: Vec = Vec::new(); + response + .body_mut() + .read_to_end(&mut body) + .await + .context("Failed to read response body")?; + + let response_received_at = Instant::now(); + if !response.status().is_success() { + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + String::from_utf8_lossy(&body), + ); + }; + + let mut response: open_ai::Response = + serde_json::from_slice(&body).context("Failed to parse response")?; + + let id = mem::take(&mut response.id); + let response_str = text_from_response(response).unwrap_or_default(); + + let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str); + let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str); + + let mut edits = Vec::new(); + const NO_PREDICTION_OUTPUT: &str = "None"; + + if response_str != NO_PREDICTION_OUTPUT { + let old_text = snapshot + .text_for_range(offset_range.clone()) + .collect::(); + edits.extend( + language::text_diff(&old_text, &response_str) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(offset_range.start + range.start) + ..snapshot.anchor_before(offset_range.start + range.end), + text, + ) + }), + ); + } + + anyhow::Ok((id, edits, snapshot, response_received_at, inputs)) + }); + + let buffer = active_buffer.clone(); + + cx.spawn(async move |cx| { + let (id, edits, old_snapshot, response_received_at, inputs) = + result.await.context("Mercury edit prediction failed")?; + anyhow::Ok(Some( + EditPredictionResult::new( + EditPredictionId(id.into()), + &buffer, + &old_snapshot, + edits.into(), + buffer_snapshotted_at, + response_received_at, + inputs, + cx, + ) + .await, + )) + }) + } +} + +fn build_prompt( + events: &[Arc], + related_files: &[RelatedFile], + cursor_buffer: &BufferSnapshot, + cursor_buffer_path: &Path, + cursor_point: Point, + editable_range: Range, + context_range: Range, +) -> String { + const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n"; + const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n"; + const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n"; + const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n"; + const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n"; + const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n"; + const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n"; + const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n"; + const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n"; + const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n"; + const CURSOR_TAG: &str = "<|cursor|>"; + const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: "; + const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: "; + + let mut prompt = String::new(); + + push_delimited( + &mut prompt, + RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END, + |prompt| { + for related_file in related_files { + for related_excerpt in &related_file.excerpts { + push_delimited( + prompt, + RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END, + |prompt| { + prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX); + prompt.push_str(related_file.path.path.as_unix_str()); + prompt.push('\n'); + prompt.push_str(&related_excerpt.text.to_string()); + }, + ); + } + } + }, + ); + + push_delimited( + &mut prompt, + CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END, + |prompt| { + prompt.push_str(CURRENT_FILE_PATH_PREFIX); + prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref()); + prompt.push('\n'); + + let prefix_range = context_range.start..editable_range.start; + let suffix_range = editable_range.end..context_range.end; + + prompt.extend(cursor_buffer.text_for_range(prefix_range)); + push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| { + let range_before_cursor = editable_range.start..cursor_point; + let range_after_cursor = cursor_point..editable_range.end; + prompt.extend(cursor_buffer.text_for_range(range_before_cursor)); + prompt.push_str(CURSOR_TAG); + prompt.extend(cursor_buffer.text_for_range(range_after_cursor)); + }); + prompt.extend(cursor_buffer.text_for_range(suffix_range)); + }, + ); + + push_delimited( + &mut prompt, + EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END, + |prompt| { + for event in events { + writeln!(prompt, "{event}").unwrap(); + } + }, + ); + + prompt +} + +fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) { + prompt.push_str(delimiters.start); + cb(prompt); + prompt.push_str(delimiters.end); +} + +pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"; +pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token"; + +pub fn load_api_token(cx: &App) -> Task> { + if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN") + .ok() + .filter(|value| !value.is_empty()) + { + return Task::ready(Some(api_token)); + } + let credentials_provider = ::global(cx); + cx.spawn(async move |cx| { + let (_, credentials) = credentials_provider + .read_credentials(MERCURY_CREDENTIALS_URL, &cx) + .await + .ok()??; + String::from_utf8(credentials).ok() + }) +} + +fn store_api_token_in_keychain(api_token: Option, cx: &App) -> Task> { + let credentials_provider = ::global(cx); + + cx.spawn(async move |cx| { + if let Some(api_token) = api_token { + credentials_provider + .write_credentials( + MERCURY_CREDENTIALS_URL, + MERCURY_CREDENTIALS_USERNAME, + api_token.as_bytes(), + cx, + ) + .await + .context("Failed to save Mercury API token to system keychain") + } else { + credentials_provider + .delete_credentials(MERCURY_CREDENTIALS_URL, cx) + .await + .context("Failed to delete Mercury API token from system keychain") + } + }) +} diff --git a/crates/edit_prediction/src/open_ai_response.rs b/crates/edit_prediction/src/open_ai_response.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7e3350936dd89c89849130ba279ad2914dd2bd8 --- /dev/null +++ b/crates/edit_prediction/src/open_ai_response.rs @@ -0,0 +1,31 @@ +pub fn text_from_response(mut res: open_ai::Response) -> Option { + let choice = res.choices.pop()?; + let output_text = match choice.message { + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(content)), + .. + } => content, + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Multipart(mut content)), + .. + } => { + if content.is_empty() { + log::error!("No output from Baseten completion response"); + return None; + } + + match content.remove(0) { + open_ai::MessagePart::Text { text } => text, + open_ai::MessagePart::Image { .. } => { + log::error!("Expected text, got an image"); + return None; + } + } + } + _ => { + log::error!("Invalid response message: {:?}", choice.message); + return None; + } + }; + Some(output_text) +} diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index 06248603464563db12fa55a90c9c0bccf153c5f4..20f70421810c6d1678f844d1ec4c968b1ca96678 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -1,9 +1,8 @@ -mod input_excerpt; - use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; use crate::{ EditPredictionId, EditPredictionStore, ZedUpdateRequiredError, + cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count}, prediction::{EditPredictionInputs, EditPredictionResult}, }; use anyhow::{Context as _, Result}; @@ -12,7 +11,6 @@ use cloud_llm_client::{ predict_edits_v3::Event, }; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task}; -use input_excerpt::excerpt_for_cursor_position; use language::{ Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff, }; @@ -495,10 +493,174 @@ pub fn format_event(event: &Event) -> String { } } -/// Typical number of string bytes per token for the purposes of limiting model input. This is -/// intentionally low to err on the side of underestimating limits. -pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3; +#[derive(Debug)] +pub struct InputExcerpt { + pub context_range: Range, + pub editable_range: Range, + pub prompt: String, +} + +pub fn excerpt_for_cursor_position( + position: Point, + path: &str, + snapshot: &BufferSnapshot, + editable_region_token_limit: usize, + context_token_limit: usize, +) -> InputExcerpt { + let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( + position, + snapshot, + editable_region_token_limit, + context_token_limit, + ); + + let mut prompt = String::new(); + + writeln!(&mut prompt, "```{path}").unwrap(); + if context_range.start == Point::zero() { + writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap(); + } + + for chunk in snapshot.chunks(context_range.start..editable_range.start, false) { + prompt.push_str(chunk.text); + } + + push_editable_range(position, snapshot, editable_range.clone(), &mut prompt); + + for chunk in snapshot.chunks(editable_range.end..context_range.end, false) { + prompt.push_str(chunk.text); + } + write!(prompt, "\n```").unwrap(); + + InputExcerpt { + context_range, + editable_range, + prompt, + } +} + +fn push_editable_range( + cursor_position: Point, + snapshot: &BufferSnapshot, + editable_range: Range, + prompt: &mut String, +) { + writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap(); + for chunk in snapshot.chunks(editable_range.start..cursor_position, false) { + prompt.push_str(chunk.text); + } + prompt.push_str(CURSOR_MARKER); + for chunk in snapshot.chunks(cursor_position..editable_range.end, false) { + prompt.push_str(chunk.text); + } + write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{App, AppContext}; + use indoc::indoc; + use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; + use std::sync::Arc; + + #[gpui::test] + fn test_excerpt_for_cursor_position(cx: &mut App) { + let text = indoc! {r#" + fn foo() { + let x = 42; + println!("Hello, world!"); + } + + fn bar() { + let x = 42; + let mut sum = 0; + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + return sum; + } + + fn generate_random_numbers() -> Vec { + let mut rng = rand::thread_rng(); + let mut numbers = Vec::new(); + for _ in 0..5 { + numbers.push(rng.random_range(1..101)); + } + numbers + } + "#}; + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let snapshot = buffer.read(cx).snapshot(); + + // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion + // when a larger scope doesn't fit the editable region. + let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32); + assert_eq!( + excerpt.prompt, + indoc! {r#" + ```main.rs + let x = 42; + println!("Hello, world!"); + <|editable_region_start|> + } + + fn bar() { + let x = 42; + let mut sum = 0; + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + r<|user_cursor_is_here|>eturn sum; + } -fn guess_token_count(bytes: usize) -> usize { - bytes / BYTES_PER_TOKEN_GUESS + fn generate_random_numbers() -> Vec { + <|editable_region_end|> + let mut rng = rand::thread_rng(); + let mut numbers = Vec::new(); + ```"#} + ); + + // The `bar` function won't fit within the editable region, so we resort to line-based expansion. + let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32); + assert_eq!( + excerpt.prompt, + indoc! {r#" + ```main.rs + fn bar() { + let x = 42; + let mut sum = 0; + <|editable_region_start|> + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + r<|user_cursor_is_here|>eturn sum; + } + + fn generate_random_numbers() -> Vec { + let mut rng = rand::thread_rng(); + <|editable_region_end|> + let mut numbers = Vec::new(); + for _ in 0..5 { + numbers.push(rng.random_range(1..101)); + ```"#} + ); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + } } diff --git a/crates/edit_prediction/src/zeta1/input_excerpt.rs b/crates/edit_prediction/src/zeta1/input_excerpt.rs deleted file mode 100644 index 853d74da463c19de4f1d3915cb703a53b6c43c61..0000000000000000000000000000000000000000 --- a/crates/edit_prediction/src/zeta1/input_excerpt.rs +++ /dev/null @@ -1,231 +0,0 @@ -use super::{ - CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER, - guess_token_count, -}; -use language::{BufferSnapshot, Point}; -use std::{fmt::Write, ops::Range}; - -#[derive(Debug)] -pub struct InputExcerpt { - pub context_range: Range, - pub editable_range: Range, - pub prompt: String, -} - -pub fn excerpt_for_cursor_position( - position: Point, - path: &str, - snapshot: &BufferSnapshot, - editable_region_token_limit: usize, - context_token_limit: usize, -) -> InputExcerpt { - let mut scope_range = position..position; - let mut remaining_edit_tokens = editable_region_token_limit; - - while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { - let parent_tokens = guess_token_count(parent.byte_range().len()); - let parent_point_range = Point::new( - parent.start_position().row as u32, - parent.start_position().column as u32, - ) - ..Point::new( - parent.end_position().row as u32, - parent.end_position().column as u32, - ); - if parent_point_range == scope_range { - break; - } else if parent_tokens <= editable_region_token_limit { - scope_range = parent_point_range; - remaining_edit_tokens = editable_region_token_limit - parent_tokens; - } else { - break; - } - } - - let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens); - let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); - - let mut prompt = String::new(); - - writeln!(&mut prompt, "```{path}").unwrap(); - if context_range.start == Point::zero() { - writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap(); - } - - for chunk in snapshot.chunks(context_range.start..editable_range.start, false) { - prompt.push_str(chunk.text); - } - - push_editable_range(position, snapshot, editable_range.clone(), &mut prompt); - - for chunk in snapshot.chunks(editable_range.end..context_range.end, false) { - prompt.push_str(chunk.text); - } - write!(prompt, "\n```").unwrap(); - - InputExcerpt { - context_range, - editable_range, - prompt, - } -} - -fn push_editable_range( - cursor_position: Point, - snapshot: &BufferSnapshot, - editable_range: Range, - prompt: &mut String, -) { - writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap(); - for chunk in snapshot.chunks(editable_range.start..cursor_position, false) { - prompt.push_str(chunk.text); - } - prompt.push_str(CURSOR_MARKER); - for chunk in snapshot.chunks(cursor_position..editable_range.end, false) { - prompt.push_str(chunk.text); - } - write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); -} - -fn expand_range( - snapshot: &BufferSnapshot, - range: Range, - mut remaining_tokens: usize, -) -> Range { - let mut expanded_range = range; - expanded_range.start.column = 0; - expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - loop { - let mut expanded = false; - - if remaining_tokens > 0 && expanded_range.start.row > 0 { - expanded_range.start.row -= 1; - let line_tokens = - guess_token_count(snapshot.line_len(expanded_range.start.row) as usize); - remaining_tokens = remaining_tokens.saturating_sub(line_tokens); - expanded = true; - } - - if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { - expanded_range.end.row += 1; - expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - let line_tokens = guess_token_count(expanded_range.end.column as usize); - remaining_tokens = remaining_tokens.saturating_sub(line_tokens); - expanded = true; - } - - if !expanded { - break; - } - } - expanded_range -} - -#[cfg(test)] -mod tests { - use super::*; - use gpui::{App, AppContext}; - use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; - use std::sync::Arc; - - #[gpui::test] - fn test_excerpt_for_cursor_position(cx: &mut App) { - let text = indoc! {r#" - fn foo() { - let x = 42; - println!("Hello, world!"); - } - - fn bar() { - let x = 42; - let mut sum = 0; - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - return sum; - } - - fn generate_random_numbers() -> Vec { - let mut rng = rand::thread_rng(); - let mut numbers = Vec::new(); - for _ in 0..5 { - numbers.push(rng.random_range(1..101)); - } - numbers - } - "#}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let snapshot = buffer.read(cx).snapshot(); - - // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion - // when a larger scope doesn't fit the editable region. - let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32); - assert_eq!( - excerpt.prompt, - indoc! {r#" - ```main.rs - let x = 42; - println!("Hello, world!"); - <|editable_region_start|> - } - - fn bar() { - let x = 42; - let mut sum = 0; - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - r<|user_cursor_is_here|>eturn sum; - } - - fn generate_random_numbers() -> Vec { - <|editable_region_end|> - let mut rng = rand::thread_rng(); - let mut numbers = Vec::new(); - ```"#} - ); - - // The `bar` function won't fit within the editable region, so we resort to line-based expansion. - let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32); - assert_eq!( - excerpt.prompt, - indoc! {r#" - ```main.rs - fn bar() { - let x = 42; - let mut sum = 0; - <|editable_region_start|> - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - r<|user_cursor_is_here|>eturn sum; - } - - fn generate_random_numbers() -> Vec { - let mut rng = rand::thread_rng(); - <|editable_region_end|> - let mut numbers = Vec::new(); - for _ in 0..5 { - numbers.push(rng.random_range(1..101)); - ```"#} - ); - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - } -} diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 4808f38fc529b1c34212dd0198d15fb03a0baddf..e542bc7e86e6e381766bbedac6a15f431e0693f1 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -1,5 +1,6 @@ #[cfg(feature = "eval-support")] use crate::EvalCacheEntryKind; +use crate::open_ai_response::text_from_response; use crate::prediction::EditPredictionResult; use crate::{ DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs, @@ -199,7 +200,7 @@ pub fn request_prediction_with_zeta2( stream: false, max_completion_tokens: None, stop: generation_params.stop.unwrap_or_default(), - temperature: generation_params.temperature.unwrap_or(0.7), + temperature: generation_params.temperature.or(Some(0.7)), tool_choice: None, parallel_tool_calls: None, tools: vec![], @@ -324,35 +325,3 @@ pub fn request_prediction_with_zeta2( )) }) } - -pub fn text_from_response(mut res: open_ai::Response) -> Option { - let choice = res.choices.pop()?; - let output_text = match choice.message { - open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain(content)), - .. - } => content, - open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Multipart(mut content)), - .. - } => { - if content.is_empty() { - log::error!("No output from Baseten completion response"); - return None; - } - - match content.remove(0) { - open_ai::MessagePart::Text { text } => text, - open_ai::MessagePart::Image { .. } => { - log::error!("Expected text, got an image"); - return None; - } - } - } - _ => { - log::error!("Invalid response message: {:?}", choice.message); - return None; - } - }; - Some(output_text) -} diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index db1fed70d82a1a19713dfe54dfd6cea2bfa03d5d..74e939b887ce15790993ec15f5973c7f5fd01866 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -198,8 +198,9 @@ pub async fn perform_predict( let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = edit_prediction::zeta2::text_from_response(response) - .unwrap_or_default(); + let response = + edit_prediction::open_ai_response::text_from_response(response) + .unwrap_or_default(); let prediction_finished_at = Instant::now(); fs::write(example_run_dir.join("prediction_response.md"), &response)?; diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index dd3ebab42029f5adb7570b71ae0cd662aff3328e..04c7614689c5fdc076ab0aa9c4b4fe7d68e2f582 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -3,7 +3,7 @@ use client::{Client, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; use codestral::CodestralEditPredictionDelegate; use copilot::{Copilot, Status}; -use edit_prediction::{SweepFeatureFlag, Zeta2FeatureFlag}; +use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag}; use edit_prediction_types::EditPredictionDelegateHandle; use editor::{ Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll, @@ -23,6 +23,7 @@ use language::{ use project::DisableAiSettings; use regex::Regex; use settings::{ + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file, @@ -44,7 +45,7 @@ use workspace::{ use zed_actions::OpenBrowser; use crate::{ - RatePredictions, SweepApiKeyModal, + ExternalProviderApiKeyModal, RatePredictions, rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag, }; @@ -311,21 +312,31 @@ impl Render for EditPredictionButton { provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => { let enabled = self.editor_enabled.unwrap_or(true); - let is_sweep = matches!( - provider, - EditPredictionProvider::Experimental( - EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME - ) - ); - - let sweep_missing_token = is_sweep - && !edit_prediction::EditPredictionStore::try_global(cx) - .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token()); + let ep_icon; + let mut missing_token = false; - let ep_icon = match (is_sweep, enabled) { - (true, _) => IconName::SweepAi, - (false, true) => IconName::ZedPredict, - (false, false) => IconName::ZedPredictDisabled, + match provider { + EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + ep_icon = IconName::SweepAi; + missing_token = edit_prediction::EditPredictionStore::try_global(cx) + .is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token()); + } + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + ep_icon = IconName::Inception; + missing_token = edit_prediction::EditPredictionStore::try_global(cx) + .is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token()); + } + _ => { + ep_icon = if enabled { + IconName::ZedPredict + } else { + IconName::ZedPredictDisabled + }; + } }; if edit_prediction::should_show_upsell_modal() { @@ -369,7 +380,7 @@ impl Render for EditPredictionButton { let show_editor_predictions = self.editor_show_predictions; let user = self.user_store.read(cx).current_user(); - let indicator_color = if sweep_missing_token { + let indicator_color = if missing_token { Some(Color::Error) } else if enabled && (!show_editor_predictions || over_limit) { Some(if over_limit { @@ -532,6 +543,12 @@ impl EditPredictionButton { )); } + if cx.has_flag::() { + providers.push(EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + )); + } + if cx.has_flag::() { providers.push(EditPredictionProvider::Experimental( EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, @@ -628,7 +645,66 @@ impl EditPredictionButton { if let Some(workspace) = window.root::().flatten() { workspace.update(cx, |workspace, cx| { workspace.toggle_modal(window, cx, |window, cx| { - SweepApiKeyModal::new(window, cx) + ExternalProviderApiKeyModal::new( + window, + cx, + |api_key, store, cx| { + store + .sweep_ai + .set_api_token(api_key, cx) + .detach_and_log_err(cx); + }, + ) + }); + }); + }; + } else { + set_completion_provider(fs.clone(), cx, provider); + } + }); + + menu.item(entry) + } + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + let has_api_token = edit_prediction::EditPredictionStore::try_global(cx) + .map_or(false, |ep_store| ep_store.read(cx).has_mercury_api_token()); + + let should_open_modal = !has_api_token || is_current; + + let entry = if has_api_token { + ContextMenuEntry::new("Mercury") + .toggleable(IconPosition::Start, is_current) + } else { + ContextMenuEntry::new("Mercury") + .icon(IconName::XCircle) + .icon_color(Color::Error) + .documentation_aside( + DocumentationSide::Left, + DocumentationEdge::Bottom, + |_| { + Label::new("Click to configure your Mercury API token") + .into_any_element() + }, + ) + }; + + let entry = entry.handler(move |window, cx| { + if should_open_modal { + if let Some(workspace) = window.root::().flatten() { + workspace.update(cx, |workspace, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + ExternalProviderApiKeyModal::new( + window, + cx, + |api_key, store, cx| { + store + .mercury + .set_api_token(api_key, cx) + .detach_and_log_err(cx); + }, + ) }); }); }; diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs index 51b491c6b3512968bca4ce2e7ed73a505bd73a00..c177b5233c33feb4f5ff82f60bf3fb6981cf3ee8 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_ui.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -1,7 +1,7 @@ mod edit_prediction_button; mod edit_prediction_context_view; +mod external_provider_api_token_modal; mod rate_prediction_modal; -mod sweep_api_token_modal; use std::any::{Any as _, TypeId}; @@ -17,7 +17,7 @@ use ui::{App, prelude::*}; use workspace::{SplitDirection, Workspace}; pub use edit_prediction_button::{EditPredictionButton, ToggleMenu}; -pub use sweep_api_token_modal::SweepApiKeyModal; +pub use external_provider_api_token_modal::ExternalProviderApiKeyModal; use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag; diff --git a/crates/edit_prediction_ui/src/sweep_api_token_modal.rs b/crates/edit_prediction_ui/src/external_provider_api_token_modal.rs similarity index 72% rename from crates/edit_prediction_ui/src/sweep_api_token_modal.rs rename to crates/edit_prediction_ui/src/external_provider_api_token_modal.rs index 80366fc2ac691f165d44e1e6a29a633522146984..bc312836e9fdd30237156ac532a055d1e23a2589 100644 --- a/crates/edit_prediction_ui/src/sweep_api_token_modal.rs +++ b/crates/edit_prediction_ui/src/external_provider_api_token_modal.rs @@ -6,18 +6,24 @@ use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*}; use ui_input::InputField; use workspace::ModalView; -pub struct SweepApiKeyModal { +pub struct ExternalProviderApiKeyModal { api_key_input: Entity, focus_handle: FocusHandle, + on_confirm: Box, &mut EditPredictionStore, &mut App)>, } -impl SweepApiKeyModal { - pub fn new(window: &mut Window, cx: &mut Context) -> Self { - let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token")); +impl ExternalProviderApiKeyModal { + pub fn new( + window: &mut Window, + cx: &mut Context, + on_confirm: impl Fn(Option, &mut EditPredictionStore, &mut App) + 'static, + ) -> Self { + let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key")); Self { api_key_input, focus_handle: cx.focus_handle(), + on_confirm: Box::new(on_confirm), } } @@ -30,39 +36,34 @@ impl SweepApiKeyModal { let api_key = (!api_key.trim().is_empty()).then_some(api_key); if let Some(ep_store) = EditPredictionStore::try_global(cx) { - ep_store.update(cx, |ep_store, cx| { - ep_store - .sweep_ai - .set_api_token(api_key, cx) - .detach_and_log_err(cx); - }); + ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx)) } cx.emit(DismissEvent); } } -impl EventEmitter for SweepApiKeyModal {} +impl EventEmitter for ExternalProviderApiKeyModal {} -impl ModalView for SweepApiKeyModal {} +impl ModalView for ExternalProviderApiKeyModal {} -impl Focusable for SweepApiKeyModal { +impl Focusable for ExternalProviderApiKeyModal { fn focus_handle(&self, _cx: &App) -> FocusHandle { self.focus_handle.clone() } } -impl Render for SweepApiKeyModal { +impl Render for ExternalProviderApiKeyModal { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() - .key_context("SweepApiKeyModal") + .key_context("ExternalApiKeyModal") .on_action(cx.listener(Self::cancel)) .on_action(cx.listener(Self::confirm)) .elevation_2(cx) .w(px(400.)) .p_4() .gap_3() - .child(Headline::new("Sweep API Token").size(HeadlineSize::Small)) + .child(Headline::new("API Token").size(HeadlineSize::Small)) .child(self.api_key_input.clone()) .child( h_flex() diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index d28e2c1030c3c2378aa7997f4799c503cee97105..d79660356f04fd42425d9e549764a4c202d29d43 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -34,8 +34,8 @@ pub enum IconName { ArrowRightLeft, ArrowUp, ArrowUpRight, - Attach, AtSign, + Attach, AudioOff, AudioOn, Backspace, @@ -45,8 +45,8 @@ pub enum IconName { BellRing, Binary, Blocks, - BoltOutlined, BoltFilled, + BoltOutlined, Book, BookCopy, CaseSensitive, @@ -80,9 +80,9 @@ pub enum IconName { Debug, DebugBreakpoint, DebugContinue, + DebugDetach, DebugDisabledBreakpoint, DebugDisabledLogBreakpoint, - DebugDetach, DebugIgnoreBreakpoints, DebugLogBreakpoint, DebugPause, @@ -140,6 +140,7 @@ pub enum IconName { Hash, HistoryRerun, Image, + Inception, Indicator, Info, Json, @@ -147,6 +148,7 @@ pub enum IconName { Library, LineHeight, Link, + Linux, ListCollapse, ListFilter, ListTodo, @@ -172,8 +174,8 @@ pub enum IconName { PencilUnavailable, Person, Pin, - PlayOutlined, PlayFilled, + PlayOutlined, Plus, Power, Public, @@ -259,15 +261,14 @@ pub enum IconName { ZedAssistant, ZedBurnMode, ZedBurnModeOn, - ZedSrcCustom, - ZedSrcExtension, ZedPredict, ZedPredictDisabled, ZedPredictDown, ZedPredictError, ZedPredictUp, + ZedSrcCustom, + ZedSrcExtension, ZedXCopilot, - Linux, } impl IconName { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 46cea34e3e01cb0f8ad0f859827881f3ec74cad7..32ee95ce9bd423bf7f66efc1bc7440455380ab5c 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -438,7 +438,7 @@ pub fn into_open_ai( messages, stream, stop: request.stop, - temperature: request.temperature.unwrap_or(1.0), + temperature: request.temperature.or(Some(1.0)), max_completion_tokens: max_output_tokens, parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn. diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 6fdb393c9a13c7ff6a6981f949b4d0c865b9bff8..8ed70c9dd514cb59f5c7a160169031cbc28428e6 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -266,7 +266,8 @@ pub struct Request { pub max_completion_tokens: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub stop: Vec, - pub temperature: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_choice: Option, /// Whether to enable parallel function calling during tool use. diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs index b466b4e0dd88bf41e0f77f67a38842305c11906f..25ff60e9f46cf797b815227222a3d27a6353c396 100644 --- a/crates/settings/src/settings_content/language.rs +++ b/crates/settings/src/settings_content/language.rs @@ -81,6 +81,7 @@ pub enum EditPredictionProvider { pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep"; pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2"; +pub const EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME: &str = "mercury"; impl<'de> Deserialize<'de> for EditPredictionProvider { fn deserialize(deserializer: D) -> Result @@ -111,6 +112,13 @@ impl<'de> Deserialize<'de> for EditPredictionProvider { EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, ) } + Content::Experimental(name) + if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME => + { + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) + } Content::Experimental(name) if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME => { diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 2d5746b87ab20de5d0aca47a4d5da60b9ec33d2a..77a1f71596f9cf1d2f4e32137580d0e3648359f5 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -9,6 +9,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use language::language_settings::{EditPredictionProvider, all_language_settings}; use language_models::MistralLanguageModelProvider; use settings::{ + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore, }; @@ -219,6 +220,10 @@ fn assign_edit_prediction_provider( && cx.has_flag::() { edit_prediction::EditPredictionModel::Zeta2 + } else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME + && cx.has_flag::() + { + edit_prediction::EditPredictionModel::Mercury } else { return false; }