diff --git a/Cargo.lock b/Cargo.lock index 64d683621530291c114f78737b45a32be8d60f14..0ce0621443ce998245474bded4f2d2296591cd1c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3133,13 +3133,13 @@ name = "cloud_llm_client" version = "0.1.0" dependencies = [ "anyhow", - "chrono", "indoc", "pretty_assertions", "serde", "serde_json", "strum 0.27.2", "uuid", + "zeta_prompt", ] [[package]] @@ -3247,7 +3247,7 @@ name = "codestral" version = "0.1.0" dependencies = [ "anyhow", - "edit_prediction_context", + "edit_prediction", "edit_prediction_types", "futures 0.3.31", "gpui", @@ -5336,7 +5336,6 @@ version = "0.1.0" dependencies = [ "anyhow", "clock", - "cloud_llm_client", "collections", "env_logger 0.11.8", "futures 0.3.31", diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml index c6a551a1fbd8a83e50f68fbcf47f26a6e96a1d24..0f0f2e77360dab0793f5740a24965711f4d80fda 100644 --- a/crates/cloud_llm_client/Cargo.toml +++ b/crates/cloud_llm_client/Cargo.toml @@ -16,11 +16,11 @@ path = "src/cloud_llm_client.rs" [dependencies] anyhow.workspace = true -chrono.workspace = true serde = { workspace = true, features = ["derive", "rc"] } serde_json.workspace = true strum = { workspace = true, features = ["derive"] } uuid = { workspace = true, features = ["serde"] } +zeta_prompt.workspace = true [dev-dependencies] pretty_assertions.workspace = true diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 1c7e4d79577475edb929724efc2f4d56945d7a4f..4b64813590d0dc00350d3fc5856540d378864eb9 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -1,219 +1,5 @@ -use chrono::Duration; use serde::{Deserialize, Serialize}; -use std::{ - borrow::Cow, - fmt::{Display, Write as _}, - ops::{Add, Range, Sub}, - path::Path, - sync::Arc, -}; -use strum::EnumIter; -use uuid::Uuid; - -use crate::{PredictEditsGitInfo, PredictEditsRequestTrigger}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PlanContextRetrievalRequest { - pub excerpt: String, - pub excerpt_path: Arc, - pub excerpt_line_range: Range, - pub cursor_file_max_row: Line, - pub events: Vec>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PredictEditsRequest { - pub excerpt: String, - pub excerpt_path: Arc, - /// Within file - pub excerpt_range: Range, - pub excerpt_line_range: Range, - pub cursor_point: Point, - /// Within `signatures` - pub excerpt_parent: Option, - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub related_files: Vec, - pub events: Vec>, - #[serde(default)] - pub can_collect_data: bool, - /// Info about the git repository state, only present when can_collect_data is true. - #[serde(skip_serializing_if = "Option::is_none", default)] - pub git_info: Option, - // Only available to staff - #[serde(default)] - pub debug_info: bool, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub prompt_max_bytes: Option, - #[serde(default)] - pub prompt_format: PromptFormat, - #[serde(default)] - pub trigger: PredictEditsRequestTrigger, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RelatedFile { - pub path: Arc, - pub max_row: Line, - pub excerpts: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Excerpt { - pub start_line: Line, - pub text: Arc, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)] -pub enum PromptFormat { - /// XML old_tex/new_text - OldTextNewText, - /// Prompt format intended for use via edit_prediction_cli - OnlySnippets, - /// One-sentence instructions used in fine-tuned models - Minimal, - /// One-sentence instructions + FIM-like template - MinimalQwen, - /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template - SeedCoder1120, -} - -impl PromptFormat { - pub const DEFAULT: PromptFormat = PromptFormat::Minimal; -} - -impl Default for PromptFormat { - fn default() -> Self { - Self::DEFAULT - } -} - -impl PromptFormat { - pub fn iter() -> impl Iterator { - ::iter() - } -} - -impl std::fmt::Display for PromptFormat { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - PromptFormat::OnlySnippets => write!(f, "Only Snippets"), - PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"), - PromptFormat::Minimal => write!(f, "Minimal"), - PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"), - PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))] -#[serde(tag = "event")] -pub enum Event { - BufferChange { - path: Arc, - old_path: Arc, - diff: String, - predicted: bool, - in_open_source_repo: bool, - }, -} - -impl Display for Event { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Event::BufferChange { - path, - old_path, - diff, - predicted, - .. - } => { - if *predicted { - write!( - f, - "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}", - DiffPathFmt(old_path), - DiffPathFmt(path) - ) - } else { - write!( - f, - "--- a/{}\n+++ b/{}\n{diff}", - DiffPathFmt(old_path), - DiffPathFmt(path) - ) - } - } - } - } -} - -/// always format the Path as a unix path with `/` as the path sep in Diffs -pub struct DiffPathFmt<'a>(pub &'a Path); - -impl<'a> std::fmt::Display for DiffPathFmt<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut is_first = true; - for component in self.0.components() { - if !is_first { - f.write_char('/')?; - } else { - is_first = false; - } - write!(f, "{}", component.as_os_str().display())?; - } - Ok(()) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PredictEditsResponse { - pub request_id: Uuid, - pub edits: Vec, - pub debug_info: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DebugInfo { - pub prompt: String, - pub prompt_planning_time: Duration, - pub model_response: String, - pub inference_time: Duration, - pub parsing_time: Duration, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Edit { - pub path: Arc, - pub range: Range, - pub content: String, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)] -pub struct Point { - pub line: Line, - pub column: u32, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)] -#[serde(transparent)] -pub struct Line(pub u32); - -impl Add for Line { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self(self.0 + rhs.0) - } -} - -impl Sub for Line { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self(self.0 - rhs.0) - } -} +use std::borrow::Cow; #[derive(Debug, Deserialize, Serialize)] pub struct RawCompletionRequest { @@ -226,6 +12,22 @@ pub struct RawCompletionRequest { pub stop: Vec>, } +#[derive(Debug, Serialize, Deserialize)] +pub struct PredictEditsV3Request { + #[serde(flatten)] + pub input: zeta_prompt::ZetaPromptInput, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(default)] + pub prompt_version: zeta_prompt::ZetaVersion, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct PredictEditsV3Response { + pub request_id: String, + pub output: String, +} + #[derive(Debug, Deserialize, Serialize)] pub struct RawCompletionResponse { pub id: String, @@ -248,86 +50,3 @@ pub struct RawCompletionUsage { pub completion_tokens: u32, pub total_tokens: u32, } - -#[cfg(test)] -mod tests { - use super::*; - use indoc::indoc; - use pretty_assertions::assert_eq; - - #[test] - fn test_event_display() { - let ev = Event::BufferChange { - path: Path::new("untitled").into(), - old_path: Path::new("untitled").into(), - diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(), - predicted: false, - in_open_source_repo: true, - }; - assert_eq!( - ev.to_string(), - indoc! {" - --- a/untitled - +++ b/untitled - @@ -1,2 +1,2 @@ - -a - -b - "} - ); - - let ev = Event::BufferChange { - path: Path::new("foo/bar.txt").into(), - old_path: Path::new("foo/bar.txt").into(), - diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(), - predicted: false, - in_open_source_repo: true, - }; - assert_eq!( - ev.to_string(), - indoc! {" - --- a/foo/bar.txt - +++ b/foo/bar.txt - @@ -1,2 +1,2 @@ - -a - -b - "} - ); - - let ev = Event::BufferChange { - path: Path::new("abc.txt").into(), - old_path: Path::new("123.txt").into(), - diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(), - predicted: false, - in_open_source_repo: true, - }; - assert_eq!( - ev.to_string(), - indoc! {" - --- a/123.txt - +++ b/abc.txt - @@ -1,2 +1,2 @@ - -a - -b - "} - ); - - let ev = Event::BufferChange { - path: Path::new("abc.txt").into(), - old_path: Path::new("123.txt").into(), - diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(), - predicted: true, - in_open_source_repo: true, - }; - assert_eq!( - ev.to_string(), - indoc! {" - // User accepted prediction: - --- a/123.txt - +++ b/abc.txt - @@ -1,2 +1,2 @@ - -a - -b - "} - ); - } -} diff --git a/crates/codestral/Cargo.toml b/crates/codestral/Cargo.toml index c5686795ea6c45cb3bc4d01341c7b99b77f68bfc..7da3bed75a9f175b49b3b11f00fd1d1583743f5e 100644 --- a/crates/codestral/Cargo.toml +++ b/crates/codestral/Cargo.toml @@ -11,7 +11,7 @@ path = "src/codestral.rs" [dependencies] anyhow.workspace = true edit_prediction_types.workspace = true -edit_prediction_context.workspace = true +edit_prediction.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs index 2fe2cdf63716be6341af1fb908d2f45feb9eca6a..afec79bef7f6d5f523b1ad2d110982e0a1dd467a 100644 --- a/crates/codestral/src/codestral.rs +++ b/crates/codestral/src/codestral.rs @@ -1,5 +1,5 @@ -use anyhow::{Context as _, Result}; -use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions}; +use anyhow::Result; +use edit_prediction::cursor_excerpt; use edit_prediction_types::{EditPrediction, EditPredictionDelegate}; use futures::AsyncReadExt; use gpui::{App, Context, Entity, Task}; @@ -15,16 +15,10 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use text::ToOffset; +use text::{OffsetRangeExt as _, ToOffset}; pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150); -const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions { - max_bytes: 1050, - min_bytes: 525, - target_before_cursor_over_total_bytes: 0.66, -}; - /// Represents a completion that has been received and processed from Codestral. /// This struct maintains the state needed to interpolate the completion as the user types. #[derive(Clone)] @@ -235,19 +229,27 @@ impl EditPredictionDelegate for CodestralEditPredictionDelegate { let cursor_offset = cursor_position.to_offset(&snapshot); let cursor_point = cursor_offset.to_point(&snapshot); - let excerpt = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &snapshot, - &EXCERPT_OPTIONS, - ) - .context("Line containing cursor doesn't fit in excerpt max bytes")?; - let excerpt_text = excerpt.text(&snapshot); + const MAX_CONTEXT_TOKENS: usize = 150; + const MAX_REWRITE_TOKENS: usize = 350; + + let (_, context_range) = + cursor_excerpt::editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + MAX_REWRITE_TOKENS, + MAX_CONTEXT_TOKENS, + ); + + let context_range = context_range.to_offset(&snapshot); + let excerpt_text = snapshot + .text_for_range(context_range.clone()) + .collect::(); let cursor_within_excerpt = cursor_offset - .saturating_sub(excerpt.range.start) - .min(excerpt_text.body.len()); - let prompt = excerpt_text.body[..cursor_within_excerpt].to_string(); - let suffix = excerpt_text.body[cursor_within_excerpt..].to_string(); + .saturating_sub(context_range.start) + .min(excerpt_text.len()); + let prompt = excerpt_text[..cursor_within_excerpt].to_string(); + let suffix = excerpt_text[cursor_within_excerpt..].to_string(); let completion_text = match Self::fetch_completion( http_client, diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 880ea9b2320194d227c28a67a07f879abbee74e4..1b1834b39f79ee5f71bcd240a8df54d249406fe4 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -2,7 +2,7 @@ use anyhow::Result; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::predict_edits_v3::{ - self, PromptFormat, RawCompletionRequest, RawCompletionResponse, + PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, }; use cloud_llm_client::{ EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection, @@ -12,7 +12,6 @@ use cloud_llm_client::{ use collections::{HashMap, HashSet}; use copilot::Copilot; use db::kvp::{Dismissable, KEY_VALUE_STORE}; -use edit_prediction_context::EditPredictionExcerptOptions; use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile}; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use futures::{ @@ -39,6 +38,7 @@ use settings::{EditPredictionProvider, SettingsStore, update_settings_file}; use std::collections::{VecDeque, hash_map}; use text::Edit; use workspace::Workspace; +use zeta_prompt::ZetaPromptInput; use zeta_prompt::ZetaVersion; use std::ops::Range; @@ -113,27 +113,8 @@ impl FeatureFlag for MercuryFeatureFlag { const NAME: &str = "mercury"; } -pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { - context: EditPredictionExcerptOptions { - max_bytes: 512, - min_bytes: 128, - target_before_cursor_over_total_bytes: 0.5, - }, - prompt_format: PromptFormat::DEFAULT, -}; - -static USE_OLLAMA: LazyLock = - LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty())); - -static EDIT_PREDICTIONS_MODEL_ID: LazyLock = LazyLock::new(|| { - match env::var("ZED_ZETA2_MODEL").as_deref() { - Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten - Ok(model) => model, - Err(_) if *USE_OLLAMA => "qwen3-coder:30b", - Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten - } - .to_string() -}); +static EDIT_PREDICTIONS_MODEL_ID: LazyLock> = + LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok()); pub struct Zeta2FeatureFlag; @@ -167,10 +148,7 @@ pub struct EditPredictionStore { _llm_token_subscription: Subscription, projects: HashMap, use_context: bool, - options: ZetaOptions, update_required: bool, - #[cfg(feature = "cli-support")] - eval_cache: Option>, edit_prediction_model: EditPredictionModel, pub sweep_ai: SweepAi, pub mercury: Mercury, @@ -206,12 +184,6 @@ pub struct EditPredictionModelInput { pub user_actions: Vec, } -#[derive(Debug, Clone, PartialEq)] -pub struct ZetaOptions { - pub context: EditPredictionExcerptOptions, - pub prompt_format: predict_edits_v3::PromptFormat, -} - #[derive(Debug)] pub enum DebugEvent { ContextRetrievalStarted(ContextRetrievalStartedDebugEvent), @@ -248,8 +220,6 @@ pub struct EditPredictionFinishedDebugEvent { pub model_output: Option, } -pub type RequestDebugInfo = predict_edits_v3::DebugInfo; - const USER_ACTION_HISTORY_SIZE: usize = 16; #[derive(Clone, Debug)] @@ -641,7 +611,6 @@ impl EditPredictionStore { projects: HashMap::default(), client, user_store, - options: DEFAULT_OPTIONS, use_context: false, llm_token, _llm_token_subscription: cx.subscribe( @@ -657,8 +626,6 @@ impl EditPredictionStore { }, ), update_required: false, - #[cfg(feature = "cli-support")] - eval_cache: None, edit_prediction_model: EditPredictionModel::Zeta2 { version: Default::default(), }, @@ -671,17 +638,7 @@ impl EditPredictionStore { shown_predictions: Default::default(), custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") { Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into), - Err(_) => { - if *USE_OLLAMA { - Some( - Url::parse("http://localhost:11434/v1/chat/completions") - .unwrap() - .into(), - ) - } else { - None - } - } + Err(_) => None, }, }; @@ -718,19 +675,6 @@ impl EditPredictionStore { self.mercury.api_token.read(cx).has_key() } - #[cfg(feature = "cli-support")] - pub fn with_eval_cache(&mut self, cache: Arc) { - self.eval_cache = Some(cache); - } - - pub fn options(&self) -> &ZetaOptions { - &self.options - } - - pub fn set_options(&mut self, options: ZetaOptions) { - self.options = options; - } - pub fn set_use_context(&mut self, use_context: bool) { self.use_context = use_context; } @@ -1946,8 +1890,6 @@ impl EditPredictionStore { custom_url: Option>, llm_token: LlmApiToken, app_version: Version, - #[cfg(feature = "cli-support")] eval_cache: Option>, - #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind, ) -> Result<(RawCompletionResponse, Option)> { let url = if let Some(custom_url) = custom_url { custom_url.as_ref().clone() @@ -1957,28 +1899,39 @@ impl EditPredictionStore { .build_zed_llm_url("/predict_edits/raw", &[])? }; - #[cfg(feature = "cli-support")] - let cache_key = if let Some(cache) = eval_cache { - use collections::FxHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = FxHasher::default(); - url.hash(&mut hasher); - let request_str = serde_json::to_string_pretty(&request)?; - request_str.hash(&mut hasher); - let hash = hasher.finish(); - - let key = (eval_cache_kind, hash); - if let Some(response_str) = cache.read(key) { - return Ok((serde_json::from_str(&response_str)?, None)); - } + Self::send_api_request( + |builder| { + let req = builder + .uri(url.as_ref()) + .body(serde_json::to_string(&request)?.into()); + Ok(req?) + }, + client, + llm_token, + app_version, + true, + ) + .await + } - Some((cache, request_str, key)) - } else { - None + pub(crate) async fn send_v3_request( + input: ZetaPromptInput, + prompt_version: ZetaVersion, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + ) -> Result<(PredictEditsV3Response, Option)> { + let url = client + .http_client() + .build_zed_llm_url("/predict_edits/v3", &[])?; + + let request = PredictEditsV3Request { + input, + model: EDIT_PREDICTIONS_MODEL_ID.clone(), + prompt_version, }; - let (response, usage) = Self::send_api_request( + Self::send_api_request( |builder| { let req = builder .uri(url.as_ref()) @@ -1990,14 +1943,7 @@ impl EditPredictionStore { app_version, true, ) - .await?; - - #[cfg(feature = "cli-support")] - if let Some((cache, request, key)) = cache_key { - cache.write(key, &request, &serde_json::to_string_pretty(&response)?); - } - - Ok((response, usage)) + .await } fn handle_api_response( @@ -2282,34 +2228,6 @@ pub struct ZedUpdateRequiredError { minimum_version: Version, } -#[cfg(feature = "cli-support")] -pub type EvalCacheKey = (EvalCacheEntryKind, u64); - -#[cfg(feature = "cli-support")] -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum EvalCacheEntryKind { - Context, - Search, - Prediction, -} - -#[cfg(feature = "cli-support")] -impl std::fmt::Display for EvalCacheEntryKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - EvalCacheEntryKind::Search => write!(f, "search"), - EvalCacheEntryKind::Context => write!(f, "context"), - EvalCacheEntryKind::Prediction => write!(f, "prediction"), - } - } -} - -#[cfg(feature = "cli-support")] -pub trait EvalCache: Send + Sync { - fn read(&self, key: EvalCacheKey) -> Option; - fn write(&self, key: EvalCacheKey, input: &str, value: &str); -} - #[derive(Debug, Clone, Copy)] pub enum DataCollectionChoice { NotAnswered, diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 7432ebb888a2ca8648388d55d0b6bf52b40fb153..1291d23a80896e53f2a4d2ceaa595fd26b39b949 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -6,9 +6,7 @@ use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use cloud_llm_client::{ EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse, RejectEditPredictionsBody, - predict_edits_v3::{ - RawCompletionChoice, RawCompletionRequest, RawCompletionResponse, RawCompletionUsage, - }, + predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response}, }; use futures::{ AsyncReadExt, StreamExt, @@ -72,7 +70,7 @@ async fn test_current_state(cx: &mut TestAppContext) { respond_tx .send(model_response( - request, + &request, indoc! {r" --- a/root/1.txt +++ b/root/1.txt @@ -129,7 +127,7 @@ async fn test_current_state(cx: &mut TestAppContext) { let (request, respond_tx) = requests.predict.next().await.unwrap(); respond_tx .send(model_response( - request, + &request, indoc! {r#" --- a/root/2.txt +++ b/root/2.txt @@ -213,7 +211,7 @@ async fn test_simple_request(cx: &mut TestAppContext) { respond_tx .send(model_response( - request, + &request, indoc! { r" --- a/root/foo.md +++ b/root/foo.md @@ -290,7 +288,7 @@ async fn test_request_events(cx: &mut TestAppContext) { respond_tx .send(model_response( - request, + &request, indoc! {r#" --- a/root/foo.md +++ b/root/foo.md @@ -622,8 +620,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) { }); let (request, respond_tx) = requests.predict.next().await.unwrap(); - let response = model_response(request, ""); - let id = response.id.clone(); + let response = model_response(&request, ""); + let id = response.request_id.clone(); respond_tx.send(response).unwrap(); cx.run_until_parked(); @@ -682,8 +680,8 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) { buffer.set_text("Hello!\nHow are you?\nBye", cx); }); - let response = model_response(request, SIMPLE_DIFF); - let id = response.id.clone(); + let response = model_response(&request, SIMPLE_DIFF); + let id = response.request_id.clone(); respond_tx.send(response).unwrap(); cx.run_until_parked(); @@ -747,8 +745,8 @@ async fn test_replace_current(cx: &mut TestAppContext) { }); let (request, respond_tx) = requests.predict.next().await.unwrap(); - let first_response = model_response(request, SIMPLE_DIFF); - let first_id = first_response.id.clone(); + let first_response = model_response(&request, SIMPLE_DIFF); + let first_id = first_response.request_id.clone(); respond_tx.send(first_response).unwrap(); cx.run_until_parked(); @@ -770,8 +768,8 @@ async fn test_replace_current(cx: &mut TestAppContext) { }); let (request, respond_tx) = requests.predict.next().await.unwrap(); - let second_response = model_response(request, SIMPLE_DIFF); - let second_id = second_response.id.clone(); + let second_response = model_response(&request, SIMPLE_DIFF); + let second_id = second_response.request_id.clone(); respond_tx.send(second_response).unwrap(); cx.run_until_parked(); @@ -829,8 +827,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) { }); let (request, respond_tx) = requests.predict.next().await.unwrap(); - let first_response = model_response(request, SIMPLE_DIFF); - let first_id = first_response.id.clone(); + let first_response = model_response(&request, SIMPLE_DIFF); + let first_id = first_response.request_id.clone(); respond_tx.send(first_response).unwrap(); cx.run_until_parked(); @@ -854,7 +852,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) { let (request, respond_tx) = requests.predict.next().await.unwrap(); // worse than current prediction let second_response = model_response( - request, + &request, indoc! { r" --- a/root/foo.md +++ b/root/foo.md @@ -865,7 +863,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) { Bye "}, ); - let second_id = second_response.id.clone(); + let second_id = second_response.request_id.clone(); respond_tx.send(second_response).unwrap(); cx.run_until_parked(); @@ -935,8 +933,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { cx.run_until_parked(); // second responds first - let second_response = model_response(request, SIMPLE_DIFF); - let second_id = second_response.id.clone(); + let second_response = model_response(&request, SIMPLE_DIFF); + let second_id = second_response.request_id.clone(); respond_second.send(second_response).unwrap(); cx.run_until_parked(); @@ -953,8 +951,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { ); }); - let first_response = model_response(request1, SIMPLE_DIFF); - let first_id = first_response.id.clone(); + let first_response = model_response(&request1, SIMPLE_DIFF); + let first_id = first_response.request_id.clone(); respond_first.send(first_response).unwrap(); cx.run_until_parked(); @@ -1046,8 +1044,8 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { let (request3, respond_third) = requests.predict.next().await.unwrap(); - let first_response = model_response(request1, SIMPLE_DIFF); - let first_id = first_response.id.clone(); + let first_response = model_response(&request1, SIMPLE_DIFF); + let first_id = first_response.request_id.clone(); respond_first.send(first_response).unwrap(); cx.run_until_parked(); @@ -1064,8 +1062,8 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { ); }); - let cancelled_response = model_response(request2, SIMPLE_DIFF); - let cancelled_id = cancelled_response.id.clone(); + let cancelled_response = model_response(&request2, SIMPLE_DIFF); + let cancelled_id = cancelled_response.request_id.clone(); respond_second.send(cancelled_response).unwrap(); cx.run_until_parked(); @@ -1082,8 +1080,8 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { ); }); - let third_response = model_response(request3, SIMPLE_DIFF); - let third_response_id = third_response.id.clone(); + let third_response = model_response(&request3, SIMPLE_DIFF); + let third_response_id = third_response.request_id.clone(); respond_third.send(third_response).unwrap(); cx.run_until_parked(); @@ -1327,50 +1325,26 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { // } // Generate a model response that would apply the given diff to the active file. -fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawCompletionResponse { - let prompt = &request.prompt; - - let current_marker = "<|fim_middle|>current\n"; - let updated_marker = "<|fim_middle|>updated\n"; - let suffix_marker = "<|fim_suffix|>\n"; - let cursor = "<|user_cursor|>"; - - let start_ix = current_marker.len() + prompt.find(current_marker).unwrap(); - let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap(); - let excerpt = prompt[start_ix..end_ix].replace(cursor, ""); - // In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content. - // Strip that out to get just the editable region. - let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) { - &excerpt[..suffix_pos] - } else { - &excerpt - }; - let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap(); - - RawCompletionResponse { - id: Uuid::new_v4().to_string(), - object: "text_completion".into(), - created: 0, - model: "model".into(), - choices: vec![RawCompletionChoice { - text: new_excerpt, - finish_reason: None, - }], - usage: RawCompletionUsage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, +fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response { + let excerpt = + request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string(); + let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap(); + + PredictEditsV3Response { + request_id: Uuid::new_v4().to_string(), + output: new_excerpt, } } -fn prompt_from_request(request: &RawCompletionRequest) -> &str { - &request.prompt +fn prompt_from_request(request: &PredictEditsV3Request) -> String { + zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version) } struct RequestChannels { - predict: - mpsc::UnboundedReceiver<(RawCompletionRequest, oneshot::Sender)>, + predict: mpsc::UnboundedReceiver<( + PredictEditsV3Request, + oneshot::Sender, + )>, reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>, } @@ -1397,7 +1371,7 @@ fn init_test_with_fake_client( "token": "test" })) .unwrap(), - "/predict_edits/raw" => { + "/predict_edits/v3" => { let mut buf = Vec::new(); body.read_to_end(&mut buf).await.ok(); let req = serde_json::from_slice(&buf).unwrap(); @@ -1677,20 +1651,9 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte // Model returns output WITH a trailing newline, even though the buffer doesn't have one. // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted. - let response = RawCompletionResponse { - id: Uuid::new_v4().to_string(), - object: "text_completion".into(), - created: 0, - model: "model".into(), - choices: vec![RawCompletionChoice { - text: "hello world\n".to_string(), - finish_reason: None, - }], - usage: RawCompletionUsage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, + let response = PredictEditsV3Response { + request_id: Uuid::new_v4().to_string(), + output: "hello world\n".to_string(), }; respond_tx.send(response).unwrap(); diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index 41b832c75a7478489103c3c9a99fbcd78ab8e0c0..8a6c4f92d49c663b8baeb17c5f2ebd99230f2c51 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -15,8 +15,8 @@ use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant}; use zeta_prompt::ZetaPromptInput; const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"; -const MAX_CONTEXT_TOKENS: usize = 150; -const MAX_REWRITE_TOKENS: usize = 350; +const MAX_REWRITE_TOKENS: usize = 150; +const MAX_CONTEXT_TOKENS: usize = 350; pub struct Mercury { pub api_token: Entity, diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 9a8b9767ceda0c311ce0779fe1c0ac948b9485ce..52eb18daff92bb852212028cd19fadcf4e0a9289 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "cli-support")] -use crate::EvalCacheEntryKind; use crate::prediction::EditPredictionResult; use crate::{ CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, @@ -22,8 +20,8 @@ pub const MAX_CONTEXT_TOKENS: usize = 350; pub fn max_editable_tokens(version: ZetaVersion) -> usize { match version { - ZetaVersion::V0112_MiddleAtEnd | ZetaVersion::V0113_Ordered => 150, - ZetaVersion::V0114_180EditableRegion => 180, + ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150, + ZetaVersion::V0114180EditableRegion => 180, } } @@ -42,7 +40,7 @@ pub fn request_prediction_with_zeta2( cx: &mut Context, ) -> Task>> { let buffer_snapshotted_at = Instant::now(); - let url = store.custom_predict_edits_url.clone(); + let custom_url = store.custom_predict_edits_url.clone(); let Some(excerpt_path) = snapshot .file() @@ -55,9 +53,6 @@ pub fn request_prediction_with_zeta2( let llm_token = store.llm_token.clone(); let app_version = AppVersion::global(cx); - #[cfg(feature = "cli-support")] - let eval_cache = store.eval_cache.clone(); - let request_task = cx.background_spawn({ async move { let cursor_offset = position.to_offset(&snapshot); @@ -70,49 +65,68 @@ pub fn request_prediction_with_zeta2( zeta_version, ); - let prompt = format_zeta_prompt(&prompt_input, zeta_version); - if let Some(debug_tx) = &debug_tx { + let prompt = format_zeta_prompt(&prompt_input, zeta_version); debug_tx .unbounded_send(DebugEvent::EditPredictionStarted( EditPredictionStartedDebugEvent { buffer: buffer.downgrade(), - prompt: Some(prompt.clone()), + prompt: Some(prompt), position, }, )) .ok(); } - let request = RawCompletionRequest { - model: EDIT_PREDICTIONS_MODEL_ID.clone(), - prompt, - temperature: None, - stop: vec![], - max_tokens: Some(2048), - }; - log::trace!("Sending edit prediction request"); - let response = EditPredictionStore::send_raw_llm_request( - request, - client, - url, - llm_token, - app_version, - #[cfg(feature = "cli-support")] - eval_cache, - #[cfg(feature = "cli-support")] - EvalCacheEntryKind::Prediction, - ) - .await; + let (request_id, output_text, usage) = if let Some(custom_url) = custom_url { + // Use raw endpoint with custom URL + let prompt = format_zeta_prompt(&prompt_input, zeta_version); + let request = RawCompletionRequest { + model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(), + prompt, + temperature: None, + stop: vec![], + max_tokens: Some(2048), + }; + + let (mut response, usage) = EditPredictionStore::send_raw_llm_request( + request, + client, + Some(custom_url), + llm_token, + app_version, + ) + .await?; + + let request_id = EditPredictionId(response.id.clone().into()); + let output_text = response.choices.pop().map(|choice| choice.text); + (request_id, output_text, usage) + } else { + let (response, usage) = EditPredictionStore::send_v3_request( + prompt_input.clone(), + zeta_version, + client, + llm_token, + app_version, + ) + .await?; + + let request_id = EditPredictionId(response.request_id.into()); + let output_text = if response.output.is_empty() { + None + } else { + Some(response.output) + }; + (request_id, output_text, usage) + }; + let received_response_at = Instant::now(); log::trace!("Got edit prediction response"); - let (mut res, usage) = response?; - let request_id = EditPredictionId(res.id.clone().into()); - let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else { + let Some(mut output_text) = output_text else { return Ok((Some((request_id, None)), usage)); }; diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index 38bb74733653257ee944dd11e037b2f63837b5e5..e1c1aed4e35f518258edcec8acd59dd9fcac7338 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -14,7 +14,6 @@ path = "src/edit_prediction_context.rs" [dependencies] anyhow.workspace = true clock.workspace = true -cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index f333a91533948c8b91b480996ac20a6ae2abda46..79bfdfa192a7d52d7f1189b93e164290380c71ea 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -20,12 +20,9 @@ use util::{RangeExt as _, ResultExt}; mod assemble_excerpts; #[cfg(test)] mod edit_prediction_context_tests; -mod excerpt; #[cfg(test)] mod fake_definition_lsp; -pub use cloud_llm_client::predict_edits_v3::Line; -pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; pub use zeta_prompt::{RelatedExcerpt, RelatedFile}; const IDENTIFIER_LINE_COUNT: u32 = 3; diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs deleted file mode 100644 index 3fc7eed4ace5a83992bf496aef3e364aea96e215..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/excerpt.rs +++ /dev/null @@ -1,556 +0,0 @@ -use cloud_llm_client::predict_edits_v3::Line; -use language::{BufferSnapshot, LanguageId, Point, ToOffset as _, ToPoint as _}; -use std::ops::Range; -use tree_sitter::{Node, TreeCursor}; -use util::RangeExt; - -// TODO: -// -// - Test parent signatures -// -// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt -// planning. -// -// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown -// paragraph). -// -// - Truncation of long lines. -// -// - Filter outer syntax layers that don't support edit prediction. - -#[derive(Debug, Clone, PartialEq)] -pub struct EditPredictionExcerptOptions { - /// Limit for the number of bytes in the window around the cursor. - pub max_bytes: usize, - /// Minimum number of bytes in the window around the cursor. When syntax tree selection results - /// in an excerpt smaller than this, it will fall back on line-based selection. - pub min_bytes: usize, - /// Target ratio of bytes before the cursor divided by total bytes in the window. - pub target_before_cursor_over_total_bytes: f32, -} - -#[derive(Debug, Clone)] -pub struct EditPredictionExcerpt { - pub range: Range, - pub line_range: Range, - pub size: usize, -} - -#[derive(Debug, Clone)] -pub struct EditPredictionExcerptText { - pub body: String, - pub language_id: Option, -} - -impl EditPredictionExcerpt { - pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText { - let body = buffer - .text_for_range(self.range.clone()) - .collect::(); - let language_id = buffer.language().map(|l| l.id()); - EditPredictionExcerptText { body, language_id } - } - - /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based - /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the - /// cursor. - /// - /// When `index` is provided, the excerpt will include the signatures of parent outline items. - /// - /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based - /// expansion. - /// - /// Returns `None` if the line around the cursor doesn't fit. - pub fn select_from_buffer( - query_point: Point, - buffer: &BufferSnapshot, - options: &EditPredictionExcerptOptions, - ) -> Option { - if buffer.len() <= options.max_bytes { - log::debug!( - "using entire file for excerpt since source length ({}) <= window max bytes ({})", - buffer.len(), - options.max_bytes - ); - let offset_range = 0..buffer.len(); - let line_range = Line(0)..Line(buffer.max_point().row); - return Some(EditPredictionExcerpt::new(offset_range, line_range)); - } - - let query_offset = query_point.to_offset(buffer); - let query_line_range = query_point.row..query_point.row + 1; - let query_range = Point::new(query_line_range.start, 0).to_offset(buffer) - ..Point::new(query_line_range.end, 0).to_offset(buffer); - if query_range.len() >= options.max_bytes { - return None; - } - - let excerpt_selector = ExcerptSelector { - query_offset, - query_range, - query_line_range: Line(query_line_range.start)..Line(query_line_range.end), - buffer, - options, - }; - - if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() { - if excerpt.size >= options.min_bytes { - return Some(excerpt); - } - log::debug!( - "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection", - excerpt.size, - options.min_bytes - ); - } else { - log::debug!( - "couldn't find excerpt via tree-sitter, falling back on line-based selection" - ); - } - - excerpt_selector.select_lines() - } - - fn new(range: Range, line_range: Range) -> Self { - Self { - size: range.len(), - range, - line_range, - } - } - - fn with_expanded_range(&self, new_range: Range, new_line_range: Range) -> Self { - if !new_range.contains_inclusive(&self.range) { - // this is an issue because parent_signature_ranges may be incorrect - log::error!("bug: with_expanded_range called with disjoint range"); - } - Self::new(new_range, new_line_range) - } - - fn parent_signatures_size(&self) -> usize { - self.size - self.range.len() - } -} - -struct ExcerptSelector<'a> { - query_offset: usize, - query_range: Range, - query_line_range: Range, - buffer: &'a BufferSnapshot, - options: &'a EditPredictionExcerptOptions, -} - -impl<'a> ExcerptSelector<'a> { - /// Finds the largest node that is smaller than the window size and contains `query_range`. - fn select_tree_sitter_nodes(&self) -> Option { - let selected_layer_root = self.select_syntax_layer()?; - let mut cursor = selected_layer_root.walk(); - - loop { - let line_start = node_line_start(cursor.node()); - let line_end = node_line_end(cursor.node()); - let line_range = Line(line_start.row)..Line(line_end.row); - let excerpt_range = - line_start.to_offset(&self.buffer)..line_end.to_offset(&self.buffer); - if excerpt_range.contains_inclusive(&self.query_range) { - let excerpt = self.make_excerpt(excerpt_range, line_range); - if excerpt.size <= self.options.max_bytes { - return Some(self.expand_to_siblings(&mut cursor, excerpt)); - } - } else { - // TODO: Should still be able to handle this case via AST nodes. For example, this - // can happen if the cursor is between two methods in a large class file. - return None; - } - - if cursor - .goto_first_child_for_byte(self.query_range.start) - .is_none() - { - return None; - } - } - } - - /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len. - fn select_syntax_layer(&self) -> Option> { - let mut smallest_exceeding_max_len: Option> = None; - let mut largest: Option> = None; - for layer in self - .buffer - .syntax_layers_for_range(self.query_range.start..self.query_range.start, true) - { - let layer_range = layer.node().byte_range(); - if !layer_range.contains_inclusive(&self.query_range) { - continue; - } - - if layer_range.len() > self.options.max_bytes { - match &smallest_exceeding_max_len { - None => smallest_exceeding_max_len = Some(layer.node()), - Some(existing) => { - if layer_range.len() < existing.byte_range().len() { - smallest_exceeding_max_len = Some(layer.node()); - } - } - } - } else { - match &largest { - None => largest = Some(layer.node()), - Some(existing) if layer_range.len() > existing.byte_range().len() => { - largest = Some(layer.node()) - } - _ => {} - } - } - } - - smallest_exceeding_max_len.or(largest) - } - - // motivation for this and `goto_previous_named_sibling` is to avoid including things like - // trailing unnamed "}" in body nodes - fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool { - while cursor.goto_next_sibling() { - if cursor.node().is_named() { - return true; - } - } - false - } - - fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool { - while cursor.goto_previous_sibling() { - if cursor.node().is_named() { - return true; - } - } - false - } - - fn expand_to_siblings( - &self, - cursor: &mut TreeCursor, - mut excerpt: EditPredictionExcerpt, - ) -> EditPredictionExcerpt { - let mut forward_cursor = cursor.clone(); - let backward_cursor = cursor; - let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor); - let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor); - loop { - if backward_done && forward_done { - break; - } - - let mut forward = None; - while !forward_done { - let new_end_point = node_line_end(forward_cursor.node()); - let new_end = new_end_point.to_offset(&self.buffer); - if new_end > excerpt.range.end { - let new_excerpt = excerpt.with_expanded_range( - excerpt.range.start..new_end, - excerpt.line_range.start..Line(new_end_point.row), - ); - if new_excerpt.size <= self.options.max_bytes { - forward = Some(new_excerpt); - break; - } else { - log::debug!("halting forward expansion, as it doesn't fit"); - forward_done = true; - break; - } - } - forward_done = !Self::goto_next_named_sibling(&mut forward_cursor); - } - - let mut backward = None; - while !backward_done { - let new_start_point = node_line_start(backward_cursor.node()); - let new_start = new_start_point.to_offset(&self.buffer); - if new_start < excerpt.range.start { - let new_excerpt = excerpt.with_expanded_range( - new_start..excerpt.range.end, - Line(new_start_point.row)..excerpt.line_range.end, - ); - if new_excerpt.size <= self.options.max_bytes { - backward = Some(new_excerpt); - break; - } else { - log::debug!("halting backward expansion, as it doesn't fit"); - backward_done = true; - break; - } - } - backward_done = !Self::goto_previous_named_sibling(backward_cursor); - } - - let go_forward = match (forward, backward) { - (Some(forward), Some(backward)) => { - let go_forward = self.is_better_excerpt(&forward, &backward); - if go_forward { - excerpt = forward; - } else { - excerpt = backward; - } - go_forward - } - (Some(forward), None) => { - log::debug!("expanding forward, since backward expansion has halted"); - excerpt = forward; - true - } - (None, Some(backward)) => { - log::debug!("expanding backward, since forward expansion has halted"); - excerpt = backward; - false - } - (None, None) => break, - }; - - if go_forward { - forward_done = !Self::goto_next_named_sibling(&mut forward_cursor); - } else { - backward_done = !Self::goto_previous_named_sibling(backward_cursor); - } - } - - excerpt - } - - fn select_lines(&self) -> Option { - // early return if line containing query_offset is already too large - let excerpt = self.make_excerpt(self.query_range.clone(), self.query_line_range.clone()); - if excerpt.size > self.options.max_bytes { - log::debug!( - "excerpt for cursor line is {} bytes, which exceeds the window", - excerpt.size - ); - return None; - } - let signatures_size = excerpt.parent_signatures_size(); - let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size); - - let before_bytes = - (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize; - - let start_line = { - let offset = self.query_offset.saturating_sub(before_bytes); - let point = offset.to_point(self.buffer); - Line(point.row + 1) - }; - let start_offset = Point::new(start_line.0, 0).to_offset(&self.buffer); - let end_line = { - let offset = start_offset + bytes_remaining; - let point = offset.to_point(self.buffer); - Line(point.row) - }; - let end_offset = Point::new(end_line.0, 0).to_offset(&self.buffer); - - // this could be expanded further since recalculated `signature_size` may be smaller, but - // skipping that for now for simplicity - // - // TODO: could also consider checking if lines immediately before / after fit. - let excerpt = self.make_excerpt(start_offset..end_offset, start_line..end_line); - if excerpt.size > self.options.max_bytes { - log::error!( - "bug: line-based excerpt selection has size {}, \ - which is {} bytes larger than the max size", - excerpt.size, - excerpt.size - self.options.max_bytes - ); - } - return Some(excerpt); - } - - fn make_excerpt(&self, range: Range, line_range: Range) -> EditPredictionExcerpt { - EditPredictionExcerpt::new(range, line_range) - } - - /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt. - fn is_better_excerpt( - &self, - forward: &EditPredictionExcerpt, - backward: &EditPredictionExcerpt, - ) -> bool { - let forward_ratio = self.excerpt_range_ratio(forward); - let backward_ratio = self.excerpt_range_ratio(backward); - let forward_delta = - (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs(); - let backward_delta = - (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs(); - let forward_is_better = forward_delta <= backward_delta; - if forward_is_better { - log::debug!( - "expanding forward since {} is closer than {} to {}", - forward_ratio, - backward_ratio, - self.options.target_before_cursor_over_total_bytes - ); - } else { - log::debug!( - "expanding backward since {} is closer than {} to {}", - backward_ratio, - forward_ratio, - self.options.target_before_cursor_over_total_bytes - ); - } - forward_is_better - } - - /// Returns the ratio of bytes before the cursor over bytes within the range. - fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 { - let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else { - log::error!("bug: edit prediction cursor offset is not outside the excerpt"); - return 0.0; - }; - bytes_before_cursor as f32 / excerpt.range.len() as f32 - } -} - -fn node_line_start(node: Node) -> Point { - Point::new(node.start_position().row as u32, 0) -} - -fn node_line_end(node: Node) -> Point { - Point::new(node.end_position().row as u32 + 1, 0) -} - -#[cfg(test)] -mod tests { - use super::*; - use gpui::{AppContext, TestAppContext}; - use language::Buffer; - use util::test::{generate_marked_text, marked_text_offsets_by}; - - fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot { - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx)); - buffer.read_with(cx, |buffer, _| buffer.snapshot()) - } - - fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range) { - let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']); - (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0]) - } - - fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) { - let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text); - - let buffer = create_buffer(&text, cx); - let cursor_point = cursor.to_point(&buffer); - - let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options) - .expect("Should select an excerpt"); - pretty_assertions::assert_eq!( - generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false), - generate_marked_text(&text, &[expected_excerpt], false) - ); - assert!(excerpt.size <= options.max_bytes); - assert!(excerpt.range.contains(&cursor)); - } - - #[gpui::test] - fn test_ast_based_selection_current_node(cx: &mut TestAppContext) { - zlog::init_test(); - let text = r#" -fn main() { - let x = 1; -« let ˇy = 2; -» let z = 3; -}"#; - - let options = EditPredictionExcerptOptions { - max_bytes: 20, - min_bytes: 10, - target_before_cursor_over_total_bytes: 0.5, - }; - - check_example(options, text, cx); - } - - #[gpui::test] - fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) { - zlog::init_test(); - let text = r#" -fn foo() {} - -«fn main() { - let x = 1; - let ˇy = 2; - let z = 3; -} -» -fn bar() {}"#; - - let options = EditPredictionExcerptOptions { - max_bytes: 65, - min_bytes: 10, - target_before_cursor_over_total_bytes: 0.5, - }; - - check_example(options, text, cx); - } - - #[gpui::test] - fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) { - zlog::init_test(); - let text = r#" -fn main() { -« let x = 1; - let ˇy = 2; - let z = 3; -»}"#; - - let options = EditPredictionExcerptOptions { - max_bytes: 50, - min_bytes: 10, - target_before_cursor_over_total_bytes: 0.5, - }; - - check_example(options, text, cx); - } - - #[gpui::test] - fn test_line_based_selection(cx: &mut TestAppContext) { - zlog::init_test(); - let text = r#" -fn main() { - let x = 1; -« if true { - let ˇy = 2; - } - let z = 3; -»}"#; - - let options = EditPredictionExcerptOptions { - max_bytes: 60, - min_bytes: 45, - target_before_cursor_over_total_bytes: 0.5, - }; - - check_example(options, text, cx); - } - - #[gpui::test] - fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) { - zlog::init_test(); - let text = r#" - fn main() { -« let a = 1; - let b = 2; - let c = 3; - let ˇd = 4; - let e = 5; - let f = 6; -» - let g = 7; - }"#; - - let options = EditPredictionExcerptOptions { - max_bytes: 120, - min_bytes: 10, - target_before_cursor_over_total_bytes: 0.6, - }; - - check_example(options, text, cx); - } -} diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 9fa672d85a814d5d089f0e2147d72c69af6a1da1..df5d36e64a4b72c002d1486c6f5345f5482d3981 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -33,10 +33,10 @@ pub struct ZetaPromptInput { )] #[allow(non_camel_case_types)] pub enum ZetaVersion { - V0112_MiddleAtEnd, - V0113_Ordered, + V0112MiddleAtEnd, + V0113Ordered, #[default] - V0114_180EditableRegion, + V0114180EditableRegion, } impl std::fmt::Display for ZetaVersion { @@ -134,10 +134,10 @@ pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> Stri write_edit_history_section(&mut prompt, input); match version { - ZetaVersion::V0112_MiddleAtEnd => { + ZetaVersion::V0112MiddleAtEnd => { v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input); } - ZetaVersion::V0113_Ordered | ZetaVersion::V0114_180EditableRegion => { + ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => { v0113_ordered::write_cursor_excerpt_section(&mut prompt, input) } }