Detailed changes
@@ -0,0 +1,11 @@
+<svg width="28" height="28" viewBox="0 0 28 28" fill="none" id="svg1378540956_510">
+<g clip-path="url(#svg1378540956_510_clip0_1_1506)" transform="translate(4, 4) scale(0.857)">
+<path d="M17.0547 0.372066H8.52652L-0.00165176 8.90024V17.4284H8.52652V8.90024H17.0547V0.372066Z" fill="#1A1C20"></path>
+<path d="M10.1992 27.6279H18.7274L27.2556 19.0998V10.5716H18.7274V19.0998H10.1992V27.6279Z" fill="#1A1C20"></path>
+</g>
+<defs>
+<clipPath id="svg1378540956_510_clip0_1_1506">
+<rect width="27.2559" height="27.2559" fill="white" transform="translate(0 0.37207)"></rect>
+</clipPath>
+</defs>
+</svg>
@@ -0,0 +1,86 @@
+use language::{BufferSnapshot, Point};
+use std::ops::Range;
+
+/// Computes the editable region and its surrounding context for an edit
+/// prediction request at `position`.
+///
+/// The editable region starts from the largest enclosing syntax node that fits
+/// within `editable_region_token_limit`; any leftover budget is spent
+/// expanding it line by line. The context range then expands the editable
+/// range by up to `context_token_limit` additional tokens. Returns
+/// `(editable_range, context_range)`, where the context range contains the editable range.
+pub fn editable_and_context_ranges_for_cursor_position(
+    position: Point,
+    snapshot: &BufferSnapshot,
+    editable_region_token_limit: usize,
+    context_token_limit: usize,
+) -> (Range<Point>, Range<Point>) {
+ let mut scope_range = position..position;
+ let mut remaining_edit_tokens = editable_region_token_limit;
+
+ while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
+ let parent_tokens = guess_token_count(parent.byte_range().len());
+ let parent_point_range = Point::new(
+ parent.start_position().row as u32,
+ parent.start_position().column as u32,
+ )
+ ..Point::new(
+ parent.end_position().row as u32,
+ parent.end_position().column as u32,
+ );
+ if parent_point_range == scope_range {
+ break;
+ } else if parent_tokens <= editable_region_token_limit {
+ scope_range = parent_point_range;
+ remaining_edit_tokens = editable_region_token_limit - parent_tokens;
+ } else {
+ break;
+ }
+ }
+
+ let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
+ let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
+ (editable_range, context_range)
+}
+
+fn expand_range(
+ snapshot: &BufferSnapshot,
+ range: Range<Point>,
+ mut remaining_tokens: usize,
+) -> Range<Point> {
+ let mut expanded_range = range;
+ expanded_range.start.column = 0;
+ expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
+ loop {
+ let mut expanded = false;
+
+ if remaining_tokens > 0 && expanded_range.start.row > 0 {
+ expanded_range.start.row -= 1;
+ let line_tokens =
+ guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
+ remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+ expanded = true;
+ }
+
+ if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
+ expanded_range.end.row += 1;
+ expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
+ let line_tokens = guess_token_count(expanded_range.end.column as usize);
+ remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+ expanded = true;
+ }
+
+ if !expanded {
+ break;
+ }
+ }
+ expanded_range
+}
+
+/// Typical number of string bytes per token for the purposes of limiting model input. This is
+/// intentionally low to err on the side of underestimating limits.
+pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
+
+pub fn guess_token_count(bytes: usize) -> usize {
+ bytes / BYTES_PER_TOKEN_GUESS
+}
@@ -51,8 +51,11 @@ use thiserror::Error;
use util::{RangeExt as _, ResultExt as _};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+mod cursor_excerpt;
mod license_detection;
+pub mod mercury;
mod onboarding_modal;
+pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
pub mod udiff;
@@ -65,6 +68,7 @@ pub mod zeta2;
mod edit_prediction_tests;
use crate::license_detection::LicenseDetectionWatcher;
+use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
@@ -96,6 +100,12 @@ impl FeatureFlag for SweepFeatureFlag {
const NAME: &str = "sweep-ai";
}
+pub struct MercuryFeatureFlag;
+
+impl FeatureFlag for MercuryFeatureFlag {
+ const NAME: &str = "mercury";
+}
+
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
context: EditPredictionExcerptOptions {
max_bytes: 512,
@@ -157,6 +167,7 @@ pub struct EditPredictionStore {
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
pub sweep_ai: SweepAi,
+ pub mercury: Mercury,
data_collection_choice: DataCollectionChoice,
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
shown_predictions: VecDeque<EditPrediction>,
@@ -169,6 +180,7 @@ pub enum EditPredictionModel {
Zeta1,
Zeta2,
Sweep,
+ Mercury,
}
#[derive(Debug, Clone, PartialEq)]
@@ -474,6 +486,7 @@ impl EditPredictionStore {
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
sweep_ai: SweepAi::new(cx),
+ mercury: Mercury::new(cx),
data_collection_choice,
reject_predictions_tx: reject_tx,
rated_predictions: Default::default(),
@@ -509,6 +522,15 @@ impl EditPredictionStore {
.is_some()
}
+ pub fn has_mercury_api_token(&self) -> bool {
+ self.mercury
+ .api_token
+ .clone()
+ .now_or_never()
+ .flatten()
+ .is_some()
+ }
+
#[cfg(feature = "eval-support")]
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
self.eval_cache = Some(cache);
@@ -868,7 +890,7 @@ impl EditPredictionStore {
fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
match self.edit_prediction_model {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
- EditPredictionModel::Sweep => return,
+ EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
}
let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
@@ -1013,7 +1035,7 @@ impl EditPredictionStore {
) {
match self.edit_prediction_model {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
- EditPredictionModel::Sweep => return,
+ EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
}
self.reject_predictions_tx
@@ -1373,6 +1395,17 @@ impl EditPredictionStore {
diagnostic_search_range.clone(),
cx,
),
+ EditPredictionModel::Mercury => self.mercury.request_prediction(
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ &project_state.recent_paths,
+ related_files,
+ diagnostic_search_range.clone(),
+ cx,
+ ),
};
cx.spawn(async move |this, cx| {
@@ -1620,7 +1620,7 @@ async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut Te
buffer.edit(
[(
0..0,
- " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
+ " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
)],
None,
cx,
@@ -0,0 +1,340 @@
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::Event;
+use credentials_provider::CredentialsProvider;
+use edit_prediction_context::RelatedFile;
+use futures::{AsyncReadExt as _, FutureExt, future::Shared};
+use gpui::{
+ App, AppContext as _, Entity, Task,
+ http_client::{self, AsyncBody, Method},
+};
+use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
+use project::{Project, ProjectPath};
+use std::{
+ collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
+};
+
+use crate::{
+ EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
+ prediction::EditPredictionResult,
+};
+
+const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
+const MAX_CONTEXT_TOKENS: usize = 150;
+const MAX_REWRITE_TOKENS: usize = 350;
+
+pub struct Mercury {
+ pub api_token: Shared<Task<Option<String>>>,
+}
+
+impl Mercury {
+ pub fn new(cx: &App) -> Self {
+ Mercury {
+ api_token: load_api_token(cx).shared(),
+ }
+ }
+
+ pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
+ self.api_token = Task::ready(api_token.clone()).shared();
+ store_api_token_in_keychain(api_token, cx)
+ }
+
+ pub fn request_prediction(
+ &self,
+ _project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ snapshot: BufferSnapshot,
+ position: language::Anchor,
+ events: Vec<Arc<Event>>,
+ _recent_paths: &VecDeque<ProjectPath>,
+ related_files: Vec<RelatedFile>,
+ _diagnostic_search_range: Range<Point>,
+ cx: &mut App,
+ ) -> Task<Result<Option<EditPredictionResult>>> {
+ let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
+ return Task::ready(Ok(None));
+ };
+ let full_path: Arc<Path> = snapshot
+ .file()
+ .map(|file| file.full_path(cx))
+ .unwrap_or_else(|| "untitled".into())
+ .into();
+
+ let http_client = cx.http_client();
+ let cursor_point = position.to_point(&snapshot);
+ let buffer_snapshotted_at = Instant::now();
+
+ let result = cx.background_spawn(async move {
+            let (editable_range, context_range) =
+                crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
+                    cursor_point,
+                    &snapshot,
+                    MAX_REWRITE_TOKENS, // editable-region budget: the span the model may rewrite
+                    MAX_CONTEXT_TOKENS, // read-only context budget surrounding the editable region
+                );
+
+ let offset_range = editable_range.to_offset(&snapshot);
+ let prompt = build_prompt(
+ &events,
+ &related_files,
+ &snapshot,
+ full_path.as_ref(),
+ cursor_point,
+ editable_range,
+ context_range.clone(),
+ );
+
+            let inputs = EditPredictionInputs {
+                events,
+                included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
+                    path: full_path.clone(),
+                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
+                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
+                        start_line: cloud_llm_client::predict_edits_v3::Line(
+                            context_range.start.row,
+                        ),
+                        text: snapshot
+                            .text_for_range(context_range.clone())
+                            .collect::<String>()
+                            .into(),
+                    }],
+                }],
+                cursor_point: cloud_llm_client::predict_edits_v3::Point {
+                    column: cursor_point.column,
+                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
+                },
+                cursor_path: full_path.clone(),
+            };
+
+ let request_body = open_ai::Request {
+ model: "mercury-coder".into(),
+ messages: vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }],
+ stream: false,
+ max_completion_tokens: None,
+ stop: vec![],
+ temperature: None,
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: vec![],
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ let buf = serde_json::to_vec(&request_body)?;
+ let body: AsyncBody = buf.into();
+
+ let request = http_client::Request::builder()
+ .uri(MERCURY_API_URL)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_token))
+ .header("Connection", "keep-alive")
+ .method(Method::POST)
+ .body(body)
+ .context("Failed to create request")?;
+
+ let mut response = http_client
+ .send(request)
+ .await
+ .context("Failed to send request")?;
+
+ let mut body: Vec<u8> = Vec::new();
+ response
+ .body_mut()
+ .read_to_end(&mut body)
+ .await
+ .context("Failed to read response body")?;
+
+ let response_received_at = Instant::now();
+ if !response.status().is_success() {
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ String::from_utf8_lossy(&body),
+ );
+ };
+
+ let mut response: open_ai::Response =
+ serde_json::from_slice(&body).context("Failed to parse response")?;
+
+ let id = mem::take(&mut response.id);
+ let response_str = text_from_response(response).unwrap_or_default();
+
+ let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
+ let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
+
+ let mut edits = Vec::new();
+ const NO_PREDICTION_OUTPUT: &str = "None";
+
+ if response_str != NO_PREDICTION_OUTPUT {
+ let old_text = snapshot
+ .text_for_range(offset_range.clone())
+ .collect::<String>();
+ edits.extend(
+ language::text_diff(&old_text, &response_str)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(offset_range.start + range.start)
+ ..snapshot.anchor_before(offset_range.start + range.end),
+ text,
+ )
+ }),
+ );
+ }
+
+ anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
+ });
+
+ let buffer = active_buffer.clone();
+
+ cx.spawn(async move |cx| {
+ let (id, edits, old_snapshot, response_received_at, inputs) =
+ result.await.context("Mercury edit prediction failed")?;
+ anyhow::Ok(Some(
+ EditPredictionResult::new(
+ EditPredictionId(id.into()),
+ &buffer,
+ &old_snapshot,
+ edits.into(),
+ buffer_snapshotted_at,
+ response_received_at,
+ inputs,
+ cx,
+ )
+ .await,
+ ))
+ })
+ }
+}
+
+fn build_prompt(
+ events: &[Arc<Event>],
+ related_files: &[RelatedFile],
+ cursor_buffer: &BufferSnapshot,
+ cursor_buffer_path: &Path,
+ cursor_point: Point,
+ editable_range: Range<Point>,
+ context_range: Range<Point>,
+) -> String {
+ const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
+ const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
+ const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
+ const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
+ const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
+ const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
+ const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
+ const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
+ const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
+ const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
+ const CURSOR_TAG: &str = "<|cursor|>";
+ const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
+ const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
+
+ let mut prompt = String::new();
+
+ push_delimited(
+ &mut prompt,
+ RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
+ |prompt| {
+ for related_file in related_files {
+ for related_excerpt in &related_file.excerpts {
+ push_delimited(
+ prompt,
+ RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
+ |prompt| {
+ prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
+ prompt.push_str(related_file.path.path.as_unix_str());
+ prompt.push('\n');
+ prompt.push_str(&related_excerpt.text.to_string());
+ },
+ );
+ }
+ }
+ },
+ );
+
+ push_delimited(
+ &mut prompt,
+ CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
+ |prompt| {
+ prompt.push_str(CURRENT_FILE_PATH_PREFIX);
+ prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
+ prompt.push('\n');
+
+ let prefix_range = context_range.start..editable_range.start;
+ let suffix_range = editable_range.end..context_range.end;
+
+ prompt.extend(cursor_buffer.text_for_range(prefix_range));
+ push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
+ let range_before_cursor = editable_range.start..cursor_point;
+ let range_after_cursor = cursor_point..editable_range.end;
+ prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
+ prompt.push_str(CURSOR_TAG);
+ prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
+ });
+ prompt.extend(cursor_buffer.text_for_range(suffix_range));
+ },
+ );
+
+ push_delimited(
+ &mut prompt,
+ EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
+ |prompt| {
+ for event in events {
+ writeln!(prompt, "{event}").unwrap();
+ }
+ },
+ );
+
+ prompt
+}
+
+fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
+ prompt.push_str(delimiters.start);
+ cb(prompt);
+ prompt.push_str(delimiters.end);
+}
+
+pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
+pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
+
+pub fn load_api_token(cx: &App) -> Task<Option<String>> {
+ if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN")
+ .ok()
+ .filter(|value| !value.is_empty())
+ {
+ return Task::ready(Some(api_token));
+ }
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ cx.spawn(async move |cx| {
+ let (_, credentials) = credentials_provider
+ .read_credentials(MERCURY_CREDENTIALS_URL, &cx)
+ .await
+ .ok()??;
+ String::from_utf8(credentials).ok()
+ })
+}
+
+fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+ cx.spawn(async move |cx| {
+ if let Some(api_token) = api_token {
+ credentials_provider
+ .write_credentials(
+ MERCURY_CREDENTIALS_URL,
+ MERCURY_CREDENTIALS_USERNAME,
+ api_token.as_bytes(),
+ cx,
+ )
+ .await
+ .context("Failed to save Mercury API token to system keychain")
+ } else {
+ credentials_provider
+ .delete_credentials(MERCURY_CREDENTIALS_URL, cx)
+ .await
+ .context("Failed to delete Mercury API token from system keychain")
+ }
+ })
+}
@@ -0,0 +1,31 @@
+pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
+    let choice = res.choices.pop()?;
+    let output_text = match choice.message {
+        open_ai::RequestMessage::Assistant {
+            content: Some(open_ai::MessageContent::Plain(content)),
+            ..
+        } => content,
+        open_ai::RequestMessage::Assistant {
+            content: Some(open_ai::MessageContent::Multipart(mut content)),
+            ..
+        } => {
+            if content.is_empty() {
+                log::error!("No output in completion response");
+                return None;
+            }
+
+            match content.remove(0) {
+                open_ai::MessagePart::Text { text } => text,
+                open_ai::MessagePart::Image { .. } => {
+                    log::error!("Expected text, got an image");
+                    return None;
+                }
+            }
+        }
+        _ => {
+            log::error!("Invalid response message: {:?}", choice.message);
+            return None;
+        }
+    };
+    Some(output_text)
+}
@@ -1,9 +1,8 @@
-mod input_excerpt;
-
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
+ cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
prediction::{EditPredictionInputs, EditPredictionResult},
};
use anyhow::{Context as _, Result};
@@ -12,7 +11,6 @@ use cloud_llm_client::{
predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
-use input_excerpt::excerpt_for_cursor_position;
use language::{
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
};
@@ -495,10 +493,174 @@ pub fn format_event(event: &Event) -> String {
}
}
-/// Typical number of string bytes per token for the purposes of limiting model input. This is
-/// intentionally low to err on the side of underestimating limits.
-pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
+#[derive(Debug)]
+pub struct InputExcerpt {
+ pub context_range: Range<Point>,
+ pub editable_range: Range<Point>,
+ pub prompt: String,
+}
+
+pub fn excerpt_for_cursor_position(
+ position: Point,
+ path: &str,
+ snapshot: &BufferSnapshot,
+ editable_region_token_limit: usize,
+ context_token_limit: usize,
+) -> InputExcerpt {
+ let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
+ position,
+ snapshot,
+ editable_region_token_limit,
+ context_token_limit,
+ );
+
+ let mut prompt = String::new();
+
+ writeln!(&mut prompt, "```{path}").unwrap();
+ if context_range.start == Point::zero() {
+ writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
+ }
+
+ for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
+ prompt.push_str(chunk.text);
+ }
+
+ push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
+
+ for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
+ prompt.push_str(chunk.text);
+ }
+ write!(prompt, "\n```").unwrap();
+
+ InputExcerpt {
+ context_range,
+ editable_range,
+ prompt,
+ }
+}
+
+fn push_editable_range(
+ cursor_position: Point,
+ snapshot: &BufferSnapshot,
+ editable_range: Range<Point>,
+ prompt: &mut String,
+) {
+ writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
+ for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
+ prompt.push_str(chunk.text);
+ }
+ prompt.push_str(CURSOR_MARKER);
+ for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
+ prompt.push_str(chunk.text);
+ }
+ write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::{App, AppContext};
+ use indoc::indoc;
+ use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
+ use std::sync::Arc;
+
+ #[gpui::test]
+ fn test_excerpt_for_cursor_position(cx: &mut App) {
+ let text = indoc! {r#"
+ fn foo() {
+ let x = 42;
+ println!("Hello, world!");
+ }
+
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ return sum;
+ }
+
+ fn generate_random_numbers() -> Vec<i32> {
+ let mut rng = rand::thread_rng();
+ let mut numbers = Vec::new();
+ for _ in 0..5 {
+ numbers.push(rng.random_range(1..101));
+ }
+ numbers
+ }
+ "#};
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let snapshot = buffer.read(cx).snapshot();
+
+ // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
+ // when a larger scope doesn't fit the editable region.
+ let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
+ assert_eq!(
+ excerpt.prompt,
+ indoc! {r#"
+ ```main.rs
+ let x = 42;
+ println!("Hello, world!");
+ <|editable_region_start|>
+ }
+
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ r<|user_cursor_is_here|>eturn sum;
+ }
-fn guess_token_count(bytes: usize) -> usize {
- bytes / BYTES_PER_TOKEN_GUESS
+ fn generate_random_numbers() -> Vec<i32> {
+ <|editable_region_end|>
+ let mut rng = rand::thread_rng();
+ let mut numbers = Vec::new();
+ ```"#}
+ );
+
+ // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
+ let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
+ assert_eq!(
+ excerpt.prompt,
+ indoc! {r#"
+ ```main.rs
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ <|editable_region_start|>
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ r<|user_cursor_is_here|>eturn sum;
+ }
+
+ fn generate_random_numbers() -> Vec<i32> {
+ let mut rng = rand::thread_rng();
+ <|editable_region_end|>
+ let mut numbers = Vec::new();
+ for _ in 0..5 {
+ numbers.push(rng.random_range(1..101));
+ ```"#}
+ );
+ }
+
+ fn rust_lang() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::LANGUAGE.into()),
+ )
+ }
}
@@ -1,231 +0,0 @@
-use super::{
- CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
- guess_token_count,
-};
-use language::{BufferSnapshot, Point};
-use std::{fmt::Write, ops::Range};
-
-#[derive(Debug)]
-pub struct InputExcerpt {
- pub context_range: Range<Point>,
- pub editable_range: Range<Point>,
- pub prompt: String,
-}
-
-pub fn excerpt_for_cursor_position(
- position: Point,
- path: &str,
- snapshot: &BufferSnapshot,
- editable_region_token_limit: usize,
- context_token_limit: usize,
-) -> InputExcerpt {
- let mut scope_range = position..position;
- let mut remaining_edit_tokens = editable_region_token_limit;
-
- while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
- let parent_tokens = guess_token_count(parent.byte_range().len());
- let parent_point_range = Point::new(
- parent.start_position().row as u32,
- parent.start_position().column as u32,
- )
- ..Point::new(
- parent.end_position().row as u32,
- parent.end_position().column as u32,
- );
- if parent_point_range == scope_range {
- break;
- } else if parent_tokens <= editable_region_token_limit {
- scope_range = parent_point_range;
- remaining_edit_tokens = editable_region_token_limit - parent_tokens;
- } else {
- break;
- }
- }
-
- let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
- let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
-
- let mut prompt = String::new();
-
- writeln!(&mut prompt, "```{path}").unwrap();
- if context_range.start == Point::zero() {
- writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
- }
-
- for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
- prompt.push_str(chunk.text);
- }
-
- push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
-
- for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
- prompt.push_str(chunk.text);
- }
- write!(prompt, "\n```").unwrap();
-
- InputExcerpt {
- context_range,
- editable_range,
- prompt,
- }
-}
-
-fn push_editable_range(
- cursor_position: Point,
- snapshot: &BufferSnapshot,
- editable_range: Range<Point>,
- prompt: &mut String,
-) {
- writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
- for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
- prompt.push_str(chunk.text);
- }
- prompt.push_str(CURSOR_MARKER);
- for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
- prompt.push_str(chunk.text);
- }
- write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
-}
-
-fn expand_range(
- snapshot: &BufferSnapshot,
- range: Range<Point>,
- mut remaining_tokens: usize,
-) -> Range<Point> {
- let mut expanded_range = range;
- expanded_range.start.column = 0;
- expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
- loop {
- let mut expanded = false;
-
- if remaining_tokens > 0 && expanded_range.start.row > 0 {
- expanded_range.start.row -= 1;
- let line_tokens =
- guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
- remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
- expanded = true;
- }
-
- if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
- expanded_range.end.row += 1;
- expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
- let line_tokens = guess_token_count(expanded_range.end.column as usize);
- remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
- expanded = true;
- }
-
- if !expanded {
- break;
- }
- }
- expanded_range
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use gpui::{App, AppContext};
- use indoc::indoc;
- use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
- use std::sync::Arc;
-
- #[gpui::test]
- fn test_excerpt_for_cursor_position(cx: &mut App) {
- let text = indoc! {r#"
- fn foo() {
- let x = 42;
- println!("Hello, world!");
- }
-
- fn bar() {
- let x = 42;
- let mut sum = 0;
- for i in 0..x {
- sum += i;
- }
- println!("Sum: {}", sum);
- return sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- let mut rng = rand::thread_rng();
- let mut numbers = Vec::new();
- for _ in 0..5 {
- numbers.push(rng.random_range(1..101));
- }
- numbers
- }
- "#};
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
- let snapshot = buffer.read(cx).snapshot();
-
- // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
- // when a larger scope doesn't fit the editable region.
- let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
- assert_eq!(
- excerpt.prompt,
- indoc! {r#"
- ```main.rs
- let x = 42;
- println!("Hello, world!");
- <|editable_region_start|>
- }
-
- fn bar() {
- let x = 42;
- let mut sum = 0;
- for i in 0..x {
- sum += i;
- }
- println!("Sum: {}", sum);
- r<|user_cursor_is_here|>eturn sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- <|editable_region_end|>
- let mut rng = rand::thread_rng();
- let mut numbers = Vec::new();
- ```"#}
- );
-
- // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
- let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
- assert_eq!(
- excerpt.prompt,
- indoc! {r#"
- ```main.rs
- fn bar() {
- let x = 42;
- let mut sum = 0;
- <|editable_region_start|>
- for i in 0..x {
- sum += i;
- }
- println!("Sum: {}", sum);
- r<|user_cursor_is_here|>eturn sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- let mut rng = rand::thread_rng();
- <|editable_region_end|>
- let mut numbers = Vec::new();
- for _ in 0..5 {
- numbers.push(rng.random_range(1..101));
- ```"#}
- );
- }
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- }
-}
@@ -1,5 +1,6 @@
#[cfg(feature = "eval-support")]
use crate::EvalCacheEntryKind;
+use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
@@ -199,7 +200,7 @@ pub fn request_prediction_with_zeta2(
stream: false,
max_completion_tokens: None,
stop: generation_params.stop.unwrap_or_default(),
- temperature: generation_params.temperature.unwrap_or(0.7),
+ temperature: generation_params.temperature.or(Some(0.7)),
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
@@ -324,35 +325,3 @@ pub fn request_prediction_with_zeta2(
))
})
}
-
-pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
- let choice = res.choices.pop()?;
- let output_text = match choice.message {
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(content)),
- ..
- } => content,
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Multipart(mut content)),
- ..
- } => {
- if content.is_empty() {
- log::error!("No output from Baseten completion response");
- return None;
- }
-
- match content.remove(0) {
- open_ai::MessagePart::Text { text } => text,
- open_ai::MessagePart::Image { .. } => {
- log::error!("Expected text, got an image");
- return None;
- }
- }
- }
- _ => {
- log::error!("Invalid response message: {:?}", choice.message);
- return None;
- }
- };
- Some(output_text)
-}
@@ -198,8 +198,9 @@ pub async fn perform_predict(
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response = edit_prediction::zeta2::text_from_response(response)
- .unwrap_or_default();
+ let response =
+ edit_prediction::open_ai_response::text_from_response(response)
+ .unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
@@ -3,7 +3,7 @@ use client::{Client, UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
use codestral::CodestralEditPredictionDelegate;
use copilot::{Copilot, Status};
-use edit_prediction::{SweepFeatureFlag, Zeta2FeatureFlag};
+use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag};
use edit_prediction_types::EditPredictionDelegateHandle;
use editor::{
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
@@ -23,6 +23,7 @@ use language::{
use project::DisableAiSettings;
use regex::Regex;
use settings::{
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore,
update_settings_file,
@@ -44,7 +45,7 @@ use workspace::{
use zed_actions::OpenBrowser;
use crate::{
- RatePredictions, SweepApiKeyModal,
+ ExternalProviderApiKeyModal, RatePredictions,
rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag,
};
@@ -311,21 +312,31 @@ impl Render for EditPredictionButton {
provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
let enabled = self.editor_enabled.unwrap_or(true);
- let is_sweep = matches!(
- provider,
- EditPredictionProvider::Experimental(
- EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
- )
- );
-
- let sweep_missing_token = is_sweep
- && !edit_prediction::EditPredictionStore::try_global(cx)
- .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
+ let ep_icon;
+ let mut missing_token = false;
- let ep_icon = match (is_sweep, enabled) {
- (true, _) => IconName::SweepAi,
- (false, true) => IconName::ZedPredict,
- (false, false) => IconName::ZedPredictDisabled,
+ match provider {
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+ ) => {
+ ep_icon = IconName::SweepAi;
+ missing_token = edit_prediction::EditPredictionStore::try_global(cx)
+ .is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token());
+ }
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
+ ) => {
+ ep_icon = IconName::Inception;
+ missing_token = edit_prediction::EditPredictionStore::try_global(cx)
+ .is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token());
+ }
+ _ => {
+ ep_icon = if enabled {
+ IconName::ZedPredict
+ } else {
+ IconName::ZedPredictDisabled
+ };
+ }
};
if edit_prediction::should_show_upsell_modal() {
@@ -369,7 +380,7 @@ impl Render for EditPredictionButton {
let show_editor_predictions = self.editor_show_predictions;
let user = self.user_store.read(cx).current_user();
- let indicator_color = if sweep_missing_token {
+ let indicator_color = if missing_token {
Some(Color::Error)
} else if enabled && (!show_editor_predictions || over_limit) {
Some(if over_limit {
@@ -532,6 +543,12 @@ impl EditPredictionButton {
));
}
+ if cx.has_flag::<MercuryFeatureFlag>() {
+ providers.push(EditPredictionProvider::Experimental(
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
+ ));
+ }
+
if cx.has_flag::<Zeta2FeatureFlag>() {
providers.push(EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
@@ -628,7 +645,66 @@ impl EditPredictionButton {
if let Some(workspace) = window.root::<Workspace>().flatten() {
workspace.update(cx, |workspace, cx| {
workspace.toggle_modal(window, cx, |window, cx| {
- SweepApiKeyModal::new(window, cx)
+ ExternalProviderApiKeyModal::new(
+ window,
+ cx,
+ |api_key, store, cx| {
+ store
+ .sweep_ai
+ .set_api_token(api_key, cx)
+ .detach_and_log_err(cx);
+ },
+ )
+ });
+ });
+ };
+ } else {
+ set_completion_provider(fs.clone(), cx, provider);
+ }
+ });
+
+ menu.item(entry)
+ }
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
+ ) => {
+ let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
+ .map_or(false, |ep_store| ep_store.read(cx).has_mercury_api_token());
+
+ let should_open_modal = !has_api_token || is_current;
+
+ let entry = if has_api_token {
+ ContextMenuEntry::new("Mercury")
+ .toggleable(IconPosition::Start, is_current)
+ } else {
+ ContextMenuEntry::new("Mercury")
+ .icon(IconName::XCircle)
+ .icon_color(Color::Error)
+ .documentation_aside(
+ DocumentationSide::Left,
+ DocumentationEdge::Bottom,
+ |_| {
+ Label::new("Click to configure your Mercury API token")
+ .into_any_element()
+ },
+ )
+ };
+
+ let entry = entry.handler(move |window, cx| {
+ if should_open_modal {
+ if let Some(workspace) = window.root::<Workspace>().flatten() {
+ workspace.update(cx, |workspace, cx| {
+ workspace.toggle_modal(window, cx, |window, cx| {
+ ExternalProviderApiKeyModal::new(
+ window,
+ cx,
+ |api_key, store, cx| {
+ store
+ .mercury
+ .set_api_token(api_key, cx)
+ .detach_and_log_err(cx);
+ },
+ )
});
});
};
@@ -1,7 +1,7 @@
mod edit_prediction_button;
mod edit_prediction_context_view;
+mod external_provider_api_token_modal;
mod rate_prediction_modal;
-mod sweep_api_token_modal;
use std::any::{Any as _, TypeId};
@@ -17,7 +17,7 @@ use ui::{App, prelude::*};
use workspace::{SplitDirection, Workspace};
pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
-pub use sweep_api_token_modal::SweepApiKeyModal;
+pub use external_provider_api_token_modal::ExternalProviderApiKeyModal;
use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;
@@ -6,18 +6,24 @@ use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
use ui_input::InputField;
use workspace::ModalView;
-pub struct SweepApiKeyModal {
+pub struct ExternalProviderApiKeyModal {
api_key_input: Entity<InputField>,
focus_handle: FocusHandle,
+ on_confirm: Box<dyn Fn(Option<String>, &mut EditPredictionStore, &mut App)>,
}
-impl SweepApiKeyModal {
- pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
- let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token"));
+impl ExternalProviderApiKeyModal {
+ pub fn new(
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ on_confirm: impl Fn(Option<String>, &mut EditPredictionStore, &mut App) + 'static,
+ ) -> Self {
+ let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key"));
Self {
api_key_input,
focus_handle: cx.focus_handle(),
+ on_confirm: Box::new(on_confirm),
}
}
@@ -30,39 +36,34 @@ impl SweepApiKeyModal {
let api_key = (!api_key.trim().is_empty()).then_some(api_key);
if let Some(ep_store) = EditPredictionStore::try_global(cx) {
- ep_store.update(cx, |ep_store, cx| {
- ep_store
- .sweep_ai
- .set_api_token(api_key, cx)
- .detach_and_log_err(cx);
- });
+ ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx))
}
cx.emit(DismissEvent);
}
}
-impl EventEmitter<DismissEvent> for SweepApiKeyModal {}
+impl EventEmitter<DismissEvent> for ExternalProviderApiKeyModal {}
-impl ModalView for SweepApiKeyModal {}
+impl ModalView for ExternalProviderApiKeyModal {}
-impl Focusable for SweepApiKeyModal {
+impl Focusable for ExternalProviderApiKeyModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
-impl Render for SweepApiKeyModal {
+impl Render for ExternalProviderApiKeyModal {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
- .key_context("SweepApiKeyModal")
+ .key_context("ExternalApiKeyModal")
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::confirm))
.elevation_2(cx)
.w(px(400.))
.p_4()
.gap_3()
- .child(Headline::new("Sweep API Token").size(HeadlineSize::Small))
+ .child(Headline::new("API Token").size(HeadlineSize::Small))
.child(self.api_key_input.clone())
.child(
h_flex()
@@ -34,8 +34,8 @@ pub enum IconName {
ArrowRightLeft,
ArrowUp,
ArrowUpRight,
- Attach,
AtSign,
+ Attach,
AudioOff,
AudioOn,
Backspace,
@@ -45,8 +45,8 @@ pub enum IconName {
BellRing,
Binary,
Blocks,
- BoltOutlined,
BoltFilled,
+ BoltOutlined,
Book,
BookCopy,
CaseSensitive,
@@ -80,9 +80,9 @@ pub enum IconName {
Debug,
DebugBreakpoint,
DebugContinue,
+ DebugDetach,
DebugDisabledBreakpoint,
DebugDisabledLogBreakpoint,
- DebugDetach,
DebugIgnoreBreakpoints,
DebugLogBreakpoint,
DebugPause,
@@ -140,6 +140,7 @@ pub enum IconName {
Hash,
HistoryRerun,
Image,
+ Inception,
Indicator,
Info,
Json,
@@ -147,6 +148,7 @@ pub enum IconName {
Library,
LineHeight,
Link,
+ Linux,
ListCollapse,
ListFilter,
ListTodo,
@@ -172,8 +174,8 @@ pub enum IconName {
PencilUnavailable,
Person,
Pin,
- PlayOutlined,
PlayFilled,
+ PlayOutlined,
Plus,
Power,
Public,
@@ -259,15 +261,14 @@ pub enum IconName {
ZedAssistant,
ZedBurnMode,
ZedBurnModeOn,
- ZedSrcCustom,
- ZedSrcExtension,
ZedPredict,
ZedPredictDisabled,
ZedPredictDown,
ZedPredictError,
ZedPredictUp,
+ ZedSrcCustom,
+ ZedSrcExtension,
ZedXCopilot,
- Linux,
}
impl IconName {
@@ -438,7 +438,7 @@ pub fn into_open_ai(
messages,
stream,
stop: request.stop,
- temperature: request.temperature.unwrap_or(1.0),
+ temperature: request.temperature.or(Some(1.0)),
max_completion_tokens: max_output_tokens,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
@@ -266,7 +266,8 @@ pub struct Request {
pub max_completion_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
- pub temperature: f32,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub temperature: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Whether to enable parallel function calling during tool use.
@@ -81,6 +81,7 @@ pub enum EditPredictionProvider {
pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep";
pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2";
+pub const EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME: &str = "mercury";
impl<'de> Deserialize<'de> for EditPredictionProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
@@ -111,6 +112,13 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
)
}
+ Content::Experimental(name)
+ if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME =>
+ {
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
+ )
+ }
Content::Experimental(name)
if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME =>
{
@@ -9,6 +9,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::MistralLanguageModelProvider;
use settings::{
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore,
};
@@ -219,6 +220,10 @@ fn assign_edit_prediction_provider(
&& cx.has_flag::<Zeta2FeatureFlag>()
{
edit_prediction::EditPredictionModel::Zeta2
+ } else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME
+ && cx.has_flag::<Zeta2FeatureFlag>()
+ {
+ edit_prediction::EditPredictionModel::Mercury
} else {
return false;
}