From c9e3b32366df2de3bf24dc5d36b7739eb24bf417 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 22 Sep 2025 19:18:38 -0300 Subject: [PATCH] zeta2: Provider setup (#38676) Creates a new `EditPredictionProvider` for zeta2, that requests completions from a new cloud endpoint including context from the new `edit_prediction_context` crate. This is not ready for use, but it allows us to iterate. Release Notes: - N/A --------- Co-authored-by: Michael Sloan Co-authored-by: Bennet Co-authored-by: Bennet Bo Fenner --- Cargo.lock | 29 + Cargo.toml | 2 + crates/cloud_llm_client/Cargo.toml | 1 + .../cloud_llm_client/src/cloud_llm_client.rs | 2 + .../cloud_llm_client/src/predict_edits_v3.rs | 118 ++ crates/edit_prediction_context/Cargo.toml | 1 + .../src/declaration.rs | 31 + .../src/declaration_scoring.rs | 83 +- .../src/edit_prediction_context.rs | 104 +- crates/edit_prediction_context/src/excerpt.rs | 85 +- .../edit_prediction_context/src/reference.rs | 4 +- .../src/syntax_index.rs | 69 +- .../src/text_similarity.rs | 4 + .../src/edit_prediction_tools.rs | 16 +- .../settings/src/settings_content/language.rs | 11 + crates/zed/Cargo.toml | 1 + .../zed/src/zed/edit_prediction_registry.rs | 46 +- crates/zeta2/Cargo.toml | 37 + crates/zeta2/LICENSE-GPL | 1 + crates/zeta2/src/zeta2.rs | 1130 +++++++++++++++++ 20 files changed, 1589 insertions(+), 186 deletions(-) create mode 100644 crates/cloud_llm_client/src/predict_edits_v3.rs create mode 100644 crates/zeta2/Cargo.toml create mode 120000 crates/zeta2/LICENSE-GPL create mode 100644 crates/zeta2/src/zeta2.rs diff --git a/Cargo.lock b/Cargo.lock index 450d59e73b6ce48f86abe17d6bd27df98f1e7df4..5e704d56697b1460c9a3a705f852e3287185bdf2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3216,6 +3216,7 @@ name = "cloud_llm_client" version = "0.1.0" dependencies = [ "anyhow", + "chrono", "pretty_assertions", "serde", "serde_json", @@ -5177,6 +5178,7 @@ dependencies = [ "anyhow", "arrayvec", "clap", + "cloud_llm_client", "collections", "futures 0.3.31", "gpui", @@ -21370,6 +21372,7 @@ dependencies = [ "zed_actions", "zed_env_vars", "zeta", + "zeta2", "zlog", "zlog_settings", ] @@ -21647,6 +21650,32 @@ dependencies = [ "zlog", ] +[[package]] +name = "zeta2" +version = "0.1.0" +dependencies = [ + "anyhow", + "arrayvec", + "client", + "cloud_llm_client", + "edit_prediction", + "edit_prediction_context", + "futures 0.3.31", + "gpui", + "language", + "language_model", + "log", + "project", + "release_channel", + "serde_json", + "thiserror 2.0.12", + "util", + "uuid", + "workspace", + "workspace-hack", + "worktree", +] + [[package]] name = "zeta_cli" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index fd552c6e9d117bd03b251f231dee8294b02ba928..6e1950aaeea715dd85c98a443b6116a619b0e3f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -199,6 +199,7 @@ members = [ "crates/zed_actions", "crates/zed_env_vars", "crates/zeta", + "crates/zeta2", "crates/zeta_cli", "crates/zlog", "crates/zlog_settings", @@ -432,6 +433,7 @@ zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } zed_env_vars = { path = "crates/zed_env_vars" } zeta = { path = "crates/zeta" } +zeta2 = { path = "crates/zeta2" } zlog = { path = "crates/zlog" } zlog_settings = { path = "crates/zlog_settings" } diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml index 6f090d3c6ea67d8bb189212fb9704b618554f671..700893dd4030e2eb7b9eab2286319ec08df2f522 100644 --- a/crates/cloud_llm_client/Cargo.toml +++ b/crates/cloud_llm_client/Cargo.toml @@ -13,6 +13,7 @@ path = "src/cloud_llm_client.rs" [dependencies] anyhow.workspace = true +chrono.workspace = true serde = { workspace = true, features = ["derive", "rc"] } serde_json.workspace = true strum = { workspace = true, features = ["derive"] } diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 24923a318441afeaa2521064b4f433ab9ee1e55f..e0cc42af76156466c31ead17d6421f3634d3ad7c 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -1,3 +1,5 @@ +pub mod predict_edits_v3; + use std::str::FromStr; use std::sync::Arc; diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs new file mode 100644 index 0000000000000000000000000000000000000000..60621b1f14714439b0527078c07e2865799172f3 --- /dev/null +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -0,0 +1,118 @@ +use chrono::Duration; +use serde::{Deserialize, Serialize}; +use std::{ops::Range, path::PathBuf}; +use uuid::Uuid; + +use crate::PredictEditsGitInfo; + +// TODO: snippet ordering within file / relative to excerpt + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsRequest { + pub excerpt: String, + pub excerpt_path: PathBuf, + /// Within file + pub excerpt_range: Range, + /// Within `excerpt` + pub cursor_offset: usize, + /// Within `signatures` + pub excerpt_parent: Option, + pub signatures: Vec, + pub referenced_declarations: Vec, + pub events: Vec, + #[serde(default)] + pub can_collect_data: bool, + #[serde(skip_serializing_if = "Vec::is_empty", default)] + pub diagnostic_groups: Vec, + /// Info about the git repository state, only present when can_collect_data is true. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub git_info: Option, + #[serde(default)] + pub debug_info: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "event")] +pub enum Event { + BufferChange { + path: Option, + old_path: Option, + diff: String, + predicted: bool, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Signature { + pub text: String, + pub text_is_truncated: bool, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub parent_index: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReferencedDeclaration { + pub path: PathBuf, + pub text: String, + pub text_is_truncated: bool, + /// Range of `text` within file, potentially truncated according to `text_is_truncated` + pub range: Range, + /// Range within `text` + pub signature_range: Range, + /// Index within `signatures`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub parent_index: Option, + pub score_components: ScoreComponents, + pub signature_score: f32, + pub declaration_score: f32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScoreComponents { + pub is_same_file: bool, + pub is_referenced_nearby: bool, + pub is_referenced_in_breadcrumb: bool, + pub reference_count: usize, + pub same_file_declaration_count: usize, + pub declaration_count: usize, + pub reference_line_distance: u32, + pub declaration_line_distance: u32, + pub declaration_line_distance_rank: usize, + pub containing_range_vs_item_jaccard: f32, + pub containing_range_vs_signature_jaccard: f32, + pub adjacent_vs_item_jaccard: f32, + pub adjacent_vs_signature_jaccard: f32, + pub containing_range_vs_item_weighted_overlap: f32, + pub containing_range_vs_signature_weighted_overlap: f32, + pub adjacent_vs_item_weighted_overlap: f32, + pub adjacent_vs_signature_weighted_overlap: f32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiagnosticGroup { + pub language_server: String, + pub diagnostic_group: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsResponse { + pub request_id: Uuid, + pub edits: Vec, + pub debug_info: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DebugInfo { + pub prompt: String, + pub prompt_planning_time: Duration, + pub model_response: String, + pub inference_time: Duration, + pub parsing_time: Duration, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Edit { + pub path: PathBuf, + pub range: Range, + pub content: String, +} diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index 48f51da3912ea5bca589e7b559d5b665b9b762d6..75880cad5f3e2807e525908656931853efa19a92 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -14,6 +14,7 @@ path = "src/edit_prediction_context.rs" [dependencies] anyhow.workspace = true arrayvec.workspace = true +cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs index 8fba85367c70e2cb1211e343bba7a6675a5a5360..653f810d439395a8825c99f4b007e05d881540ab 100644 --- a/crates/edit_prediction_context/src/declaration.rs +++ b/crates/edit_prediction_context/src/declaration.rs @@ -41,6 +41,20 @@ impl Declaration { } } + pub fn parent(&self) -> Option { + match self { + Declaration::File { declaration, .. } => declaration.parent, + Declaration::Buffer { declaration, .. } => declaration.parent, + } + } + + pub fn as_buffer(&self) -> Option<&BufferDeclaration> { + match self { + Declaration::File { .. } => None, + Declaration::Buffer { declaration, .. } => Some(declaration), + } + } + pub fn project_entry_id(&self) -> ProjectEntryId { match self { Declaration::File { @@ -52,6 +66,13 @@ impl Declaration { } } + pub fn item_range(&self) -> Range { + match self { + Declaration::File { declaration, .. } => declaration.item_range_in_file.clone(), + Declaration::Buffer { declaration, .. } => declaration.item_range.clone(), + } + } + pub fn item_text(&self) -> (Cow<'_, str>, bool) { match self { Declaration::File { declaration, .. } => ( @@ -83,6 +104,16 @@ impl Declaration { ), } } + + pub fn signature_range_in_item_text(&self) -> Range { + match self { + Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(), + Declaration::Buffer { declaration, .. } => { + declaration.signature_range.start - declaration.item_range.start + ..declaration.signature_range.end - declaration.item_range.start + } + } + } } fn expand_range_to_line_boundaries_and_truncate( diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index 4cbc4e83c02e0cae912813a261d99f8fe8c41b55..1638723857d225d05efce512bb9025fa89fb38f3 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -1,10 +1,11 @@ +use cloud_llm_client::predict_edits_v3::ScoreComponents; use itertools::Itertools as _; use language::BufferSnapshot; use ordered_float::OrderedFloat; use serde::Serialize; use std::{collections::HashMap, ops::Range}; use strum::EnumIter; -use text::{OffsetRangeExt, Point, ToPoint}; +use text::{Point, ToPoint}; use crate::{ Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier, @@ -15,19 +16,14 @@ use crate::{ const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; -// TODO: -// -// * Consider adding declaration_file_count - #[derive(Clone, Debug)] pub struct ScoredSnippet { pub identifier: Identifier, pub declaration: Declaration, - pub score_components: ScoreInputs, + pub score_components: ScoreComponents, pub scores: Scores, } -// TODO: Consider having "Concise" style corresponding to `concise_text` #[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)] pub enum SnippetStyle { Signature, @@ -90,8 +86,8 @@ pub fn scored_snippets( let declaration_count = declarations.len(); declarations - .iter() - .filter_map(|declaration| match declaration { + .into_iter() + .filter_map(|(declaration_id, declaration)| match declaration { Declaration::Buffer { buffer_id, declaration: buffer_declaration, @@ -100,24 +96,29 @@ pub fn scored_snippets( let is_same_file = buffer_id == ¤t_buffer.remote_id(); if is_same_file { - range_intersection( - &buffer_declaration.item_range.to_offset(¤t_buffer), - &excerpt.range, - ) - .is_none() - .then(|| { + let overlaps_excerpt = + range_intersection(&buffer_declaration.item_range, &excerpt.range) + .is_some(); + if overlaps_excerpt + || excerpt + .parent_declarations + .iter() + .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id) + { + None + } else { let declaration_line = buffer_declaration .item_range .start .to_point(current_buffer) .row; - ( + Some(( true, (cursor_point.row as i32 - declaration_line as i32) .unsigned_abs(), declaration, - ) - }) + )) + } } else { Some((false, u32::MAX, declaration)) } @@ -238,7 +239,8 @@ fn score_snippet( let adjacent_vs_signature_weighted_overlap = weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences); - let score_components = ScoreInputs { + // TODO: Consider adding declaration_file_count + let score_components = ScoreComponents { is_same_file, is_referenced_nearby, is_referenced_in_breadcrumb, @@ -261,51 +263,30 @@ fn score_snippet( Some(ScoredSnippet { identifier: identifier.clone(), declaration: declaration, - scores: score_components.score(), + scores: Scores::score(&score_components), score_components, }) } -#[derive(Clone, Debug, Serialize)] -pub struct ScoreInputs { - pub is_same_file: bool, - pub is_referenced_nearby: bool, - pub is_referenced_in_breadcrumb: bool, - pub reference_count: usize, - pub same_file_declaration_count: usize, - pub declaration_count: usize, - pub reference_line_distance: u32, - pub declaration_line_distance: u32, - pub declaration_line_distance_rank: usize, - pub containing_range_vs_item_jaccard: f32, - pub containing_range_vs_signature_jaccard: f32, - pub adjacent_vs_item_jaccard: f32, - pub adjacent_vs_signature_jaccard: f32, - pub containing_range_vs_item_weighted_overlap: f32, - pub containing_range_vs_signature_weighted_overlap: f32, - pub adjacent_vs_item_weighted_overlap: f32, - pub adjacent_vs_signature_weighted_overlap: f32, -} - #[derive(Clone, Debug, Serialize)] pub struct Scores { pub signature: f32, pub declaration: f32, } -impl ScoreInputs { - fn score(&self) -> Scores { +impl Scores { + fn score(components: &ScoreComponents) -> Scores { // Score related to how likely this is the correct declaration, range 0 to 1 - let accuracy_score = if self.is_same_file { + let accuracy_score = if components.is_same_file { // TODO: use declaration_line_distance_rank - 1.0 / self.same_file_declaration_count as f32 + 1.0 / components.same_file_declaration_count as f32 } else { - 1.0 / self.declaration_count as f32 + 1.0 / components.declaration_count as f32 }; // Score related to the distance between the reference and cursor, range 0 to 1 - let distance_score = if self.is_referenced_nearby { - 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0) + let distance_score = if components.is_referenced_nearby { + 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0) } else { // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures 0.5 @@ -315,10 +296,12 @@ impl ScoreInputs { let combined_score = 10.0 * accuracy_score * distance_score; Scores { - signature: combined_score * self.containing_range_vs_signature_weighted_overlap, + signature: combined_score * components.containing_range_vs_signature_weighted_overlap, // declaration score gets boosted both by being multiplied by 2 and by there being more // weighted overlap. - declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap, + declaration: 2.0 + * combined_score + * components.containing_range_vs_item_weighted_overlap, } } } diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index aed2953777d82d65b7e9cb42229d78634d5e4a3d..aeda74811296b70fc48198b1c3f72a50cfd7c31e 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -6,62 +6,82 @@ mod reference; mod syntax_index; mod text_similarity; -use std::time::Instant; - -pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier}; -pub use declaration_scoring::SnippetStyle; -pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; - use gpui::{App, AppContext as _, Entity, Task}; use language::BufferSnapshot; -pub use reference::references_in_excerpt; -pub use syntax_index::SyntaxIndex; use text::{Point, ToOffset as _}; -use crate::declaration_scoring::{ScoredSnippet, scored_snippets}; +pub use declaration::*; +pub use declaration_scoring::*; +pub use excerpt::*; +pub use reference::*; +pub use syntax_index::*; #[derive(Debug)] pub struct EditPredictionContext { pub excerpt: EditPredictionExcerpt, pub excerpt_text: EditPredictionExcerptText, + pub cursor_offset_in_excerpt: usize, pub snippets: Vec, - pub retrieval_duration: std::time::Duration, } impl EditPredictionContext { - pub fn gather( + pub fn gather_context_in_background( cursor_point: Point, buffer: BufferSnapshot, excerpt_options: EditPredictionExcerptOptions, - syntax_index: Entity, + syntax_index: Option>, cx: &mut App, ) -> Task> { - let start = Instant::now(); - let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); - cx.background_spawn(async move { - let index_state = index_state.lock().await; - - let excerpt = - EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?; - let excerpt_text = excerpt.text(&buffer); - let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer); - let cursor_offset = cursor_point.to_offset(&buffer); - - let snippets = scored_snippets( + if let Some(syntax_index) = syntax_index { + let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); + cx.background_spawn(async move { + let index_state = index_state.lock().await; + Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state)) + }) + } else { + cx.background_spawn(async move { + Self::gather_context(cursor_point, &buffer, &excerpt_options, None) + }) + } + } + + pub fn gather_context( + cursor_point: Point, + buffer: &BufferSnapshot, + excerpt_options: &EditPredictionExcerptOptions, + index_state: Option<&SyntaxIndexState>, + ) -> Option { + let excerpt = EditPredictionExcerpt::select_from_buffer( + cursor_point, + buffer, + excerpt_options, + index_state, + )?; + let excerpt_text = excerpt.text(buffer); + let cursor_offset_in_file = cursor_point.to_offset(buffer); + // TODO fix this to not need saturating_sub + let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start); + + let snippets = if let Some(index_state) = index_state { + let references = references_in_excerpt(&excerpt, &excerpt_text, buffer); + + scored_snippets( &index_state, &excerpt, &excerpt_text, references, - cursor_offset, - &buffer, - ); - - Some(Self { - excerpt, - excerpt_text, - snippets, - retrieval_duration: start.elapsed(), - }) + cursor_offset_in_file, + buffer, + ) + } else { + vec![] + }; + + Some(Self { + excerpt, + excerpt_text, + cursor_offset_in_excerpt, + snippets, }) } } @@ -101,24 +121,28 @@ mod tests { let context = cx .update(|cx| { - EditPredictionContext::gather( + EditPredictionContext::gather_context_in_background( cursor_point, buffer_snapshot, EditPredictionExcerptOptions { - max_bytes: 40, + max_bytes: 60, min_bytes: 10, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }, - index, + Some(index), cx, ) }) .await .unwrap(); - assert_eq!(context.snippets.len(), 1); - assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data"); + let mut snippet_identifiers = context + .snippets + .iter() + .map(|snippet| snippet.identifier.name.as_ref()) + .collect::>(); + snippet_identifiers.sort(); + assert_eq!(snippet_identifiers, vec!["main", "process_data"]); drop(buffer); } diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index d27e75bd129147986d2359585d5e048704d8828f..764c9040f247561ba6058dfd2954f42085297a4c 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -1,9 +1,11 @@ use language::BufferSnapshot; use std::ops::Range; -use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _}; +use text::{Point, ToOffset as _, ToPoint as _}; use tree_sitter::{Node, TreeCursor}; use util::RangeExt; +use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState}; + // TODO: // // - Test parent signatures @@ -27,14 +29,12 @@ pub struct EditPredictionExcerptOptions { pub min_bytes: usize, /// Target ratio of bytes before the cursor divided by total bytes in the window. pub target_before_cursor_over_total_bytes: f32, - /// Whether to include parent signatures - pub include_parent_signatures: bool, } #[derive(Debug, Clone)] pub struct EditPredictionExcerpt { pub range: Range, - pub parent_signature_ranges: Vec>, + pub parent_declarations: Vec<(DeclarationId, Range)>, pub size: usize, } @@ -50,9 +50,9 @@ impl EditPredictionExcerpt { .text_for_range(self.range.clone()) .collect::(); let parent_signatures = self - .parent_signature_ranges + .parent_declarations .iter() - .map(|range| buffer.text_for_range(range.clone()).collect::()) + .map(|(_, range)| buffer.text_for_range(range.clone()).collect::()) .collect(); EditPredictionExcerptText { body, @@ -62,8 +62,9 @@ impl EditPredictionExcerpt { /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the - /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures - /// of parent outline items. + /// cursor. + /// + /// When `index` is provided, the excerpt will include the signatures of parent outline items. /// /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based /// expansion. @@ -73,6 +74,7 @@ impl EditPredictionExcerpt { query_point: Point, buffer: &BufferSnapshot, options: &EditPredictionExcerptOptions, + syntax_index: Option<&SyntaxIndexState>, ) -> Option { if buffer.len() <= options.max_bytes { log::debug!( @@ -90,17 +92,9 @@ impl EditPredictionExcerpt { return None; } - // TODO: Don't compute text / annotation_range / skip converting to and from anchors. - let outline_items = if options.include_parent_signatures { - buffer - .outline_items_containing(query_range.clone(), false, None) - .into_iter() - .flat_map(|item| { - Some(ExcerptOutlineItem { - item_range: item.range.to_offset(&buffer), - signature_range: item.signature_range?.to_offset(&buffer), - }) - }) + let parent_declarations = if let Some(syntax_index) = syntax_index { + syntax_index + .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone()) .collect() } else { Vec::new() @@ -109,7 +103,7 @@ impl EditPredictionExcerpt { let excerpt_selector = ExcerptSelector { query_offset, query_range, - outline_items: &outline_items, + parent_declarations: &parent_declarations, buffer, options, }; @@ -132,15 +126,15 @@ impl EditPredictionExcerpt { excerpt_selector.select_lines() } - fn new(range: Range, parent_signature_ranges: Vec>) -> Self { + fn new(range: Range, parent_declarations: Vec<(DeclarationId, Range)>) -> Self { let size = range.len() - + parent_signature_ranges + + parent_declarations .iter() - .map(|r| r.len()) + .map(|(_, range)| range.len()) .sum::(); Self { range, - parent_signature_ranges, + parent_declarations, size, } } @@ -150,20 +144,14 @@ impl EditPredictionExcerpt { // this is an issue because parent_signature_ranges may be incorrect log::error!("bug: with_expanded_range called with disjoint range"); } - let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len()); - let mut size = new_range.len(); - for range in &self.parent_signature_ranges { - if range.contains_inclusive(&new_range) { + let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len()); + for (declaration_id, range) in &self.parent_declarations { + if !range.contains_inclusive(&new_range) { break; } - parent_signature_ranges.push(range.clone()); - size += range.len(); - } - Self { - range: new_range, - parent_signature_ranges, - size, + parent_declarations.push((*declaration_id, range.clone())); } + Self::new(new_range, parent_declarations) } fn parent_signatures_size(&self) -> usize { @@ -174,16 +162,11 @@ impl EditPredictionExcerpt { struct ExcerptSelector<'a> { query_offset: usize, query_range: Range, - outline_items: &'a [ExcerptOutlineItem], + parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)], buffer: &'a BufferSnapshot, options: &'a EditPredictionExcerptOptions, } -struct ExcerptOutlineItem { - item_range: Range, - signature_range: Range, -} - impl<'a> ExcerptSelector<'a> { /// Finds the largest node that is smaller than the window size and contains `query_range`. fn select_tree_sitter_nodes(&self) -> Option { @@ -396,13 +379,13 @@ impl<'a> ExcerptSelector<'a> { } fn make_excerpt(&self, range: Range) -> EditPredictionExcerpt { - let parent_signature_ranges = self - .outline_items + let parent_declarations = self + .parent_declarations .iter() - .filter(|item| item.item_range.contains_inclusive(&range)) - .map(|item| item.signature_range.clone()) + .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range)) + .map(|(id, declaration)| (*id, declaration.signature_range.clone())) .collect(); - EditPredictionExcerpt::new(range, parent_signature_ranges) + EditPredictionExcerpt::new(range, parent_declarations) } /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt. @@ -493,8 +476,9 @@ mod tests { let buffer = create_buffer(&text, cx); let cursor_point = cursor.to_point(&buffer); - let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options) - .expect("Should select an excerpt"); + let excerpt = + EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None) + .expect("Should select an excerpt"); pretty_assertions::assert_eq!( generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false), generate_marked_text(&text, &[expected_excerpt], false) @@ -517,7 +501,6 @@ fn main() { max_bytes: 20, min_bytes: 10, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }; check_example(options, text, cx); @@ -541,7 +524,6 @@ fn bar() {}"#; max_bytes: 65, min_bytes: 10, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }; check_example(options, text, cx); @@ -561,7 +543,6 @@ fn main() { max_bytes: 50, min_bytes: 10, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }; check_example(options, text, cx); @@ -583,7 +564,6 @@ fn main() { max_bytes: 60, min_bytes: 45, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }; check_example(options, text, cx); @@ -608,7 +588,6 @@ fn main() { max_bytes: 120, min_bytes: 10, target_before_cursor_over_total_bytes: 0.6, - include_parent_signatures: false, }; check_example(options, text, cx); diff --git a/crates/edit_prediction_context/src/reference.rs b/crates/edit_prediction_context/src/reference.rs index abb7dd75dd569f7a14bea807bdc64441d0e64871..268f8c39ef84ba29593f502aff7e818e931cc873 100644 --- a/crates/edit_prediction_context/src/reference.rs +++ b/crates/edit_prediction_context/src/reference.rs @@ -33,8 +33,8 @@ pub fn references_in_excerpt( snapshot, ); - for (range, text) in excerpt - .parent_signature_ranges + for ((_, range), text) in excerpt + .parent_declarations .iter() .zip(excerpt_text.parent_signatures.iter()) { diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs index 64982f5805f08a3ba791578e28778f0c8399fde8..d234e975d504c145d7bc2fc0680569c388ba0d1c 100644 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ b/crates/edit_prediction_context/src/syntax_index.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use collections::{HashMap, HashSet}; use futures::lock::Mutex; use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity}; @@ -8,20 +6,17 @@ use project::buffer_store::{BufferStore, BufferStoreEvent}; use project::worktree_store::{WorktreeStore, WorktreeStoreEvent}; use project::{PathChange, Project, ProjectEntryId, ProjectPath}; use slotmap::SlotMap; +use std::iter; +use std::ops::Range; +use std::sync::Arc; use text::BufferId; -use util::{debug_panic, some_or_debug_panic}; +use util::{RangeExt as _, debug_panic, some_or_debug_panic}; use crate::declaration::{ BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier, }; use crate::outline::declarations_in_buffer; -// TODO: -// -// * Skip for remote projects -// -// * Consider making SyntaxIndex not an Entity. - // Potential future improvements: // // * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which @@ -40,7 +35,6 @@ use crate::outline::declarations_in_buffer; // * Concurrent slotmap // // * Use queue for parsing -// pub struct SyntaxIndex { state: Arc>, @@ -432,7 +426,7 @@ impl SyntaxIndexState { pub fn declarations_for_identifier( &self, identifier: &Identifier, - ) -> Vec { + ) -> Vec<(DeclarationId, &Declaration)> { // make sure to not have a large stack allocation assert!(N < 32); @@ -454,7 +448,7 @@ impl SyntaxIndexState { project_entry_id, .. } => { included_buffer_entry_ids.push(*project_entry_id); - result.push(declaration.clone()); + result.push((*declaration_id, declaration)); if result.len() == N { return Vec::new(); } @@ -463,19 +457,19 @@ impl SyntaxIndexState { project_entry_id, .. } => { if !included_buffer_entry_ids.contains(&project_entry_id) { - file_declarations.push(declaration.clone()); + file_declarations.push((*declaration_id, declaration)); } } } } - for declaration in file_declarations { + for (declaration_id, declaration) in file_declarations { match declaration { Declaration::File { project_entry_id, .. } => { if !included_buffer_entry_ids.contains(&project_entry_id) { - result.push(declaration); + result.push((declaration_id, declaration)); if result.len() == N { return Vec::new(); @@ -489,6 +483,35 @@ impl SyntaxIndexState { result } + pub fn buffer_declarations_containing_range( + &self, + buffer_id: BufferId, + range: Range, + ) -> impl Iterator { + let Some(buffer_state) = self.buffers.get(&buffer_id) else { + return itertools::Either::Left(iter::empty()); + }; + + let iter = buffer_state + .declarations + .iter() + .filter_map(move |declaration_id| { + let Some(declaration) = self + .declarations + .get(*declaration_id) + .and_then(|d| d.as_buffer()) + else { + log::error!("bug: missing buffer outline declaration"); + return None; + }; + if declaration.item_range.contains_inclusive(&range) { + return Some((*declaration_id, declaration)); + } + return None; + }); + itertools::Either::Right(iter) + } + pub fn file_declaration_count(&self, declaration: &Declaration) -> usize { match declaration { Declaration::File { @@ -553,11 +576,11 @@ mod tests { let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); - let decl = expect_file_decl("c.rs", &decls[0], &project, cx); + let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, main.clone()); assert_eq!(decl.item_range_in_file, 32..280); - let decl = expect_file_decl("a.rs", &decls[1], &project, cx); + let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx); assert_eq!(decl.identifier, main); assert_eq!(decl.item_range_in_file, 0..98); }); @@ -577,7 +600,7 @@ mod tests { let decls = index_state.declarations_for_identifier::<8>(&test_process_data); assert_eq!(decls.len(), 1); - let decl = expect_file_decl("c.rs", &decls[0], &project, cx); + let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, test_process_data); let parent_id = decl.parent.unwrap(); @@ -618,7 +641,7 @@ mod tests { let decls = index_state.declarations_for_identifier::<8>(&test_process_data); assert_eq!(decls.len(), 1); - let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx); + let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, test_process_data); let parent_id = decl.parent.unwrap(); @@ -676,11 +699,11 @@ mod tests { cx.update(|cx| { let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); - let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx); + let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, main); assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280); - expect_file_decl("a.rs", &decls[1], &project, cx); + expect_file_decl("a.rs", &decls[1].1, &project, cx); }); } @@ -695,8 +718,8 @@ mod tests { cx.update(|cx| { let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); - expect_file_decl("c.rs", &decls[0], &project, cx); - expect_file_decl("a.rs", &decls[1], &project, cx); + expect_file_decl("c.rs", &decls[0].1, &project, cx); + expect_file_decl("a.rs", &decls[1].1, &project, cx); }); } diff --git a/crates/edit_prediction_context/src/text_similarity.rs b/crates/edit_prediction_context/src/text_similarity.rs index f7a9822ecca01f1b6ff1dc04bdc12fbcddc5159b..2ace7bf10cc6fd13b8a5636212211a3274d3c259 100644 --- a/crates/edit_prediction_context/src/text_similarity.rs +++ b/crates/edit_prediction_context/src/text_similarity.rs @@ -9,8 +9,12 @@ use crate::reference::Reference; // That implementation could actually be more efficient - no need to track words in the window that // are not in the query. +// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the +// two in parallel. + static IDENTIFIER_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap()); +// TODO: use &str or Cow keys? #[derive(Debug)] pub struct IdentifierOccurrences { identifier_to_count: HashMap, diff --git a/crates/edit_prediction_tools/src/edit_prediction_tools.rs b/crates/edit_prediction_tools/src/edit_prediction_tools.rs index f00a16e026704f1d1da318956f41128a9783a54c..7e18b468bc6aac5ab53a9ce55f195644e19e9367 100644 --- a/crates/edit_prediction_tools/src/edit_prediction_tools.rs +++ b/crates/edit_prediction_tools/src/edit_prediction_tools.rs @@ -4,7 +4,7 @@ use std::{ path::{Path, PathBuf}, str::FromStr, sync::Arc, - time::Duration, + time::{Duration, Instant}, }; use collections::HashMap; @@ -195,6 +195,8 @@ impl EditPredictionTools { .timer(Duration::from_millis(50)) .await; + let mut start_time = None; + let Ok(task) = this.update(cx, |this, cx| { fn number_input_value( input: &Entity, @@ -216,15 +218,16 @@ impl EditPredictionTools { &this.cursor_context_ratio_input, cx, ), - // TODO Display and add to options - include_parent_signatures: false, }; - EditPredictionContext::gather( + start_time = Some(Instant::now()); + + // TODO use global zeta instead + EditPredictionContext::gather_context_in_background( cursor_position, current_buffer_snapshot, options, - this.syntax_index.clone(), + Some(this.syntax_index.clone()), cx, ) }) else { @@ -243,6 +246,7 @@ impl EditPredictionTools { .ok(); return; }; + let retrieval_duration = start_time.unwrap().elapsed(); let mut languages = HashMap::default(); for snippet in context.snippets.iter() { @@ -320,7 +324,7 @@ impl EditPredictionTools { this.last_context = Some(ContextState { context_editor, - retrieval_duration: context.retrieval_duration, + retrieval_duration, }); cx.notify(); }) diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs index 6052afee671edba49e05b56ddef147a01866e364..24a3de0c3b918e86488c42d9d12fedb8e16081c6 100644 --- a/crates/settings/src/settings_content/language.rs +++ b/crates/settings/src/settings_content/language.rs @@ -84,6 +84,17 @@ pub enum EditPredictionProvider { Zed, } +impl EditPredictionProvider { + pub fn is_zed(&self) -> bool { + match self { + EditPredictionProvider::Zed => true, + EditPredictionProvider::None + | EditPredictionProvider::Copilot + | EditPredictionProvider::Supermaven => false, + } + } +} + /// The contents of the edit prediction settings. #[skip_serializing_none] #[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)] diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 5b6cb3924610b89406a37230497fee8ffc511e34..3d7bbe8642c64d7cfff9bfb9169c3df17e18d854 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -163,6 +163,7 @@ workspace.workspace = true zed_actions.workspace = true zed_env_vars.workspace = true zeta.workspace = true +zeta2.workspace = true zlog.workspace = true zlog_settings.workspace = true diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index ae26427fc6547079b163235f5d1c3df26a489795..d0e8e26074296e5b54bccaa73de7e06e4aacf205 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -203,21 +203,43 @@ fn assign_edit_prediction_provider( } } - let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx); - - if let Some(buffer) = &singleton_buffer - && buffer.read(cx).file().is_some() - && let Some(project) = editor.project() - { - zeta.update(cx, |zeta, cx| { - zeta.register_buffer(buffer, project, cx); + if std::env::var("ZED_ZETA2").is_ok() { + let zeta = zeta2::Zeta::global(client, &user_store, cx); + let provider = cx.new(|cx| { + zeta2::ZetaEditPredictionProvider::new( + editor.project(), + &client, + &user_store, + cx, + ) }); - } - let provider = - cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer)); + if let Some(buffer) = &singleton_buffer + && buffer.read(cx).file().is_some() + && let Some(project) = editor.project() + { + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(buffer, project, cx); + }); + } - editor.set_edit_prediction_provider(Some(provider), window, cx); + editor.set_edit_prediction_provider(Some(provider), window, cx); + } else { + let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx); + + if let Some(buffer) = &singleton_buffer + && buffer.read(cx).file().is_some() + && let Some(project) = editor.project() + { + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(buffer, project, cx); + }); + } + + let provider = + cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer)); + editor.set_edit_prediction_provider(Some(provider), window, cx); + } } } } diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..61c560baab543baf9d57b034807fe60cb566b24f --- /dev/null +++ b/crates/zeta2/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "zeta2" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/zeta2.rs" + +[dependencies] +anyhow.workspace = true +arrayvec.workspace = true +client.workspace = true +cloud_llm_client.workspace = true +edit_prediction.workspace = true +edit_prediction_context.workspace = true +futures.workspace = true +gpui.workspace = true +language.workspace = true +language_model.workspace = true +log.workspace = true +project.workspace = true +release_channel.workspace = true +serde_json.workspace = true +thiserror.workspace = true +util.workspace = true +uuid.workspace = true +workspace.workspace = true +workspace-hack.workspace = true +worktree.workspace = true + +[dev-dependencies] +gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta2/LICENSE-GPL b/crates/zeta2/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/zeta2/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs new file mode 100644 index 0000000000000000000000000000000000000000..791273a9242dd7aa50588fcfe90e9258ff3724ea --- /dev/null +++ b/crates/zeta2/src/zeta2.rs @@ -0,0 +1,1130 @@ +use anyhow::{Context as _, Result, anyhow}; +use arrayvec::ArrayVec; +use client::{Client, EditPredictionUsage, UserStore}; +use cloud_llm_client::predict_edits_v3::{self, Signature}; +use cloud_llm_client::{ + EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME, +}; +use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; +use edit_prediction_context::{ + DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex, + SyntaxIndexState, +}; +use futures::AsyncReadExt as _; +use gpui::http_client::Method; +use gpui::{ + App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client, + prelude::*, +}; +use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint}; +use language::{BufferSnapshot, EditPreview}; +use language_model::{LlmApiToken, RefreshLlmTokenListener}; +use project::Project; +use release_channel::AppVersion; +use std::cmp; +use std::collections::{HashMap, VecDeque, hash_map}; +use std::path::PathBuf; +use std::str::FromStr as _; +use std::time::{Duration, Instant}; +use std::{ops::Range, sync::Arc}; +use thiserror::Error; +use util::ResultExt as _; +use uuid::Uuid; +use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; + +const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); + +/// Maximum number of events to track. +const MAX_EVENT_COUNT: usize = 16; + +#[derive(Clone)] +struct ZetaGlobal(Entity); + +impl Global for ZetaGlobal {} + +pub struct Zeta { + client: Arc, + user_store: Entity, + llm_token: LlmApiToken, + _llm_token_subscription: Subscription, + projects: HashMap, + excerpt_options: EditPredictionExcerptOptions, + update_required: bool, +} + +struct ZetaProject { + syntax_index: Entity, + events: VecDeque, + registered_buffers: HashMap, +} + +struct RegisteredBuffer { + snapshot: BufferSnapshot, + _subscriptions: [gpui::Subscription; 2], +} + +#[derive(Clone)] +pub enum Event { + BufferChange { + old_snapshot: BufferSnapshot, + new_snapshot: BufferSnapshot, + timestamp: Instant, + }, +} + +impl Zeta { + pub fn global( + client: &Arc, + user_store: &Entity, + cx: &mut App, + ) -> Entity { + cx.try_global::() + .map(|global| global.0.clone()) + .unwrap_or_else(|| { + let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx)); + cx.set_global(ZetaGlobal(zeta.clone())); + zeta + }) + } + + fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); + + Self { + projects: HashMap::new(), + client, + user_store, + excerpt_options: EditPredictionExcerptOptions { + max_bytes: 512, + min_bytes: 128, + target_before_cursor_over_total_bytes: 0.5, + }, + llm_token: LlmApiToken::default(), + _llm_token_subscription: cx.subscribe( + &refresh_llm_token_listener, + |this, _listener, _event, cx| { + let client = this.client.clone(); + let llm_token = this.llm_token.clone(); + cx.spawn(async move |_this, _cx| { + llm_token.refresh(&client).await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + }, + ), + update_required: false, + } + } + + pub fn usage(&self, cx: &App) -> Option { + self.user_store.read(cx).edit_prediction_usage() + } + + pub fn register_project(&mut self, project: &Entity, cx: &mut App) { + self.get_or_init_zeta_project(project, cx); + } + + pub fn register_buffer( + &mut self, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) { + let zeta_project = self.get_or_init_zeta_project(project, cx); + Self::register_buffer_impl(zeta_project, buffer, project, cx); + } + + fn get_or_init_zeta_project( + &mut self, + project: &Entity, + cx: &mut App, + ) -> &mut ZetaProject { + self.projects + .entry(project.entity_id()) + .or_insert_with(|| ZetaProject { + syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)), + events: VecDeque::new(), + registered_buffers: HashMap::new(), + }) + } + + fn register_buffer_impl<'a>( + zeta_project: &'a mut ZetaProject, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> &'a mut RegisteredBuffer { + let buffer_id = buffer.entity_id(); + match zeta_project.registered_buffers.entry(buffer_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let snapshot = buffer.read(cx).snapshot(); + let project_entity_id = project.entity_id(); + entry.insert(RegisteredBuffer { + snapshot, + _subscriptions: [ + cx.subscribe(buffer, { + let project = project.downgrade(); + move |this, buffer, event, cx| { + if let language::BufferEvent::Edited = event + && let Some(project) = project.upgrade() + { + this.report_changes_for_buffer(&buffer, &project, cx); + } + } + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + zeta_project.registered_buffers.remove(&buffer_id); + }), + ], + }) + } + } + } + + fn report_changes_for_buffer( + &mut self, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> BufferSnapshot { + let zeta_project = self.get_or_init_zeta_project(project, cx); + let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx); + + let new_snapshot = buffer.read(cx).snapshot(); + if new_snapshot.version != registered_buffer.snapshot.version { + let old_snapshot = + std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); + Self::push_event( + zeta_project, + Event::BufferChange { + old_snapshot, + new_snapshot: new_snapshot.clone(), + timestamp: Instant::now(), + }, + ); + } + + new_snapshot + } + + fn push_event(zeta_project: &mut ZetaProject, event: Event) { + let events = &mut zeta_project.events; + + if let Some(Event::BufferChange { + new_snapshot: last_new_snapshot, + timestamp: last_timestamp, + .. + }) = events.back_mut() + { + // Coalesce edits for the same buffer when they happen one after the other. + let Event::BufferChange { + old_snapshot, + new_snapshot, + timestamp, + } = &event; + + if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL + && old_snapshot.remote_id() == last_new_snapshot.remote_id() + && old_snapshot.version == last_new_snapshot.version + { + *last_new_snapshot = new_snapshot.clone(); + *last_timestamp = *timestamp; + return; + } + } + + if events.len() >= MAX_EVENT_COUNT { + // These are halved instead of popping to improve prompt caching. + events.drain(..MAX_EVENT_COUNT / 2); + } + + events.push_back(event); + } + + pub fn request_prediction( + &mut self, + project: &Entity, + buffer: &Entity, + position: language::Anchor, + cx: &mut Context, + ) -> Task>> { + let project_state = self.projects.get(&project.entity_id()); + + let index_state = project_state.map(|state| { + state + .syntax_index + .read_with(cx, |index, _cx| index.state().clone()) + }); + let excerpt_options = self.excerpt_options.clone(); + let snapshot = buffer.read(cx).snapshot(); + let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else { + return Task::ready(Err(anyhow!("No file path for excerpt"))); + }; + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let app_version = AppVersion::global(cx); + let worktree_snapshots = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect::>(); + + let request_task = cx.background_spawn({ + let snapshot = snapshot.clone(); + async move { + let index_state = if let Some(index_state) = index_state { + Some(index_state.lock_owned().await) + } else { + None + }; + + let cursor_point = position.to_point(&snapshot); + + // TODO: make this only true if debug view is open + let debug_info = true; + + let Some(request) = EditPredictionContext::gather_context( + cursor_point, + &snapshot, + &excerpt_options, + index_state.as_deref(), + ) + .map(|context| { + make_cloud_request( + excerpt_path.clone(), + context, + // TODO pass everything + Vec::new(), + false, + Vec::new(), + None, + debug_info, + &worktree_snapshots, + index_state.as_deref(), + ) + }) else { + return Ok(None); + }; + + anyhow::Ok(Some( + Self::perform_request(client, llm_token, app_version, request).await?, + )) + } + }); + + let buffer = buffer.clone(); + + cx.spawn(async move |this, cx| { + match request_task.await { + Ok(Some((response, usage))) => { + log::debug!("predicted edits: {:?}", &response.edits); + + if let Some(usage) = usage { + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); + }) + .ok(); + } + + // TODO telemetry: duration, etc + + // TODO produce smaller edits by diffing against snapshot first + // + // Cloud returns entire snippets/excerpts ranges as they were included + // in the request, but we should display smaller edits to the user. + // + // We can do this by computing a diff of each one against the snapshot. + // Similar to zeta::Zeta::compute_edits, but per edit. + let edits = response + .edits + .into_iter() + .map(|edit| { + // TODO edits to different files + ( + snapshot.anchor_before(edit.range.start) + ..snapshot.anchor_before(edit.range.end), + edit.content, + ) + }) + .collect::>() + .into(); + + let Some((edits, snapshot, edit_preview_task)) = + buffer.read_with(cx, |buffer, cx| { + let new_snapshot = buffer.snapshot(); + let edits: Arc<[_]> = + interpolate(&snapshot, &new_snapshot, edits)?.into(); + Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) + })? + else { + return Ok(None); + }; + + Ok(Some(EditPrediction { + id: EditPredictionId(response.request_id), + edits, + snapshot, + edit_preview: edit_preview_task.await, + })) + } + Ok(None) => Ok(None), + Err(err) => { + if err.is::() { + cx.update(|cx| { + this.update(cx, |this, _cx| { + this.update_required = true; + }) + .ok(); + + let error_message: SharedString = err.to_string().into(); + show_app_notification( + NotificationId::unique::(), + cx, + move |cx| { + cx.new(|cx| { + ErrorMessagePrompt::new(error_message.clone(), cx) + .with_link_button( + "Update Zed", + "https://zed.dev/releases", + ) + }) + }, + ); + }) + .ok(); + } + + Err(err) + } + } + }) + } + + async fn perform_request( + client: Arc, + llm_token: LlmApiToken, + app_version: SemanticVersion, + request: predict_edits_v3::PredictEditsRequest, + ) -> Result<( + predict_edits_v3::PredictEditsResponse, + Option, + )> { + let http_client = client.http_client(); + let mut token = llm_token.acquire(&client).await?; + let mut did_retry = false; + + loop { + let request_builder = http_client::Request::builder().method(Method::POST); + let request_builder = + if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { + request_builder.uri(predict_edits_url) + } else { + request_builder.uri( + http_client + .build_zed_llm_url("/predict_edits/v3", &[])? + .as_ref(), + ) + }; + let request = request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) + .body(serde_json::to_string(&request)?.into())?; + + let mut response = http_client.send(request).await?; + + if let Some(minimum_required_version) = response + .headers() + .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) + .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok()) + { + anyhow::ensure!( + app_version >= minimum_required_version, + ZedUpdateRequiredError { + minimum_version: minimum_required_version + } + ); + } + + if response.status().is_success() { + let usage = EditPredictionUsage::from_headers(response.headers()).ok(); + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + return Ok((serde_json::from_slice(&body)?, usage)); + } else if !did_retry + && response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + { + did_retry = true; + token = llm_token.refresh(&client).await?; + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!( + "error predicting edits.\nStatus: {:?}\nBody: {}", + response.status(), + body + ); + } + } + } +} + +#[derive(Error, Debug)] +#[error( + "You must update to Zed version {minimum_version} or higher to continue using edit predictions." +)] +pub struct ZedUpdateRequiredError { + minimum_version: SemanticVersion, +} + +pub struct ZetaEditPredictionProvider { + zeta: Entity, + current_prediction: Option, + next_pending_prediction_id: usize, + pending_predictions: ArrayVec, + last_request_timestamp: Instant, +} + +impl ZetaEditPredictionProvider { + pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + + pub fn new( + project: Option<&Entity>, + client: &Arc, + user_store: &Entity, + cx: &mut App, + ) -> Self { + let zeta = Zeta::global(client, user_store, cx); + if let Some(project) = project { + zeta.update(cx, |zeta, cx| { + zeta.register_project(project, cx); + }); + } + + Self { + zeta, + current_prediction: None, + next_pending_prediction_id: 0, + pending_predictions: ArrayVec::new(), + last_request_timestamp: Instant::now(), + } + } +} + +#[derive(Clone)] +struct CurrentEditPrediction { + buffer_id: EntityId, + prediction: EditPrediction, +} + +impl CurrentEditPrediction { + fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool { + if self.buffer_id != old_prediction.buffer_id { + return true; + } + + let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else { + return true; + }; + let Some(new_edits) = self.prediction.interpolate(snapshot) else { + return false; + }; + + if old_edits.len() == 1 && new_edits.len() == 1 { + let (old_range, old_text) = &old_edits[0]; + let (new_range, new_text) = &new_edits[0]; + new_range == old_range && new_text.starts_with(old_text) + } else { + true + } + } +} + +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] +pub struct EditPredictionId(Uuid); + +impl From for gpui::ElementId { + fn from(value: EditPredictionId) -> Self { + gpui::ElementId::Uuid(value.0) + } +} + +impl std::fmt::Display for EditPredictionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Clone)] +pub struct EditPrediction { + id: EditPredictionId, + edits: Arc<[(Range, String)]>, + snapshot: BufferSnapshot, + edit_preview: EditPreview, +} + +impl EditPrediction { + fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, String)>> { + interpolate(&self.snapshot, new_snapshot, self.edits.clone()) + } +} + +struct PendingPrediction { + id: usize, + _task: Task<()>, +} + +impl EditPredictionProvider for ZetaEditPredictionProvider { + fn name() -> &'static str { + "zed-predict2" + } + + fn display_name() -> &'static str { + "Zed's Edit Predictions 2" + } + + fn show_completions_in_menu() -> bool { + true + } + + fn show_tab_accept_marker() -> bool { + true + } + + fn data_collection_state(&self, _cx: &App) -> DataCollectionState { + // TODO [zeta2] + DataCollectionState::Unsupported + } + + fn toggle_data_collection(&mut self, _cx: &mut App) { + // TODO [zeta2] + } + + fn usage(&self, cx: &App) -> Option { + self.zeta.read(cx).usage(cx) + } + + fn is_enabled( + &self, + _buffer: &Entity, + _cursor_position: language::Anchor, + _cx: &App, + ) -> bool { + true + } + + fn is_refreshing(&self) -> bool { + !self.pending_predictions.is_empty() + } + + fn refresh( + &mut self, + project: Option>, + buffer: Entity, + cursor_position: language::Anchor, + _debounce: bool, + cx: &mut Context, + ) { + let Some(project) = project else { + return; + }; + + if self + .zeta + .read(cx) + .user_store + .read_with(cx, |user_store, _cx| { + user_store.account_too_young() || user_store.has_overdue_invoices() + }) + { + return; + } + + if let Some(current_prediction) = self.current_prediction.as_ref() { + let snapshot = buffer.read(cx).snapshot(); + if current_prediction + .prediction + .interpolate(&snapshot) + .is_some() + { + return; + } + } + + let pending_prediction_id = self.next_pending_prediction_id; + self.next_pending_prediction_id += 1; + let last_request_timestamp = self.last_request_timestamp; + + let task = cx.spawn(async move |this, cx| { + if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) + .checked_duration_since(Instant::now()) + { + cx.background_executor().timer(timeout).await; + } + + let prediction_request = this.update(cx, |this, cx| { + this.last_request_timestamp = Instant::now(); + this.zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &buffer, cursor_position, cx) + }) + }); + + let prediction = match prediction_request { + Ok(prediction_request) => { + let prediction_request = prediction_request.await; + prediction_request.map(|c| { + c.map(|prediction| CurrentEditPrediction { + buffer_id: buffer.entity_id(), + prediction, + }) + }) + } + Err(error) => Err(error), + }; + + this.update(cx, |this, cx| { + if this.pending_predictions[0].id == pending_prediction_id { + this.pending_predictions.remove(0); + } else { + this.pending_predictions.clear(); + } + + let Some(new_prediction) = prediction + .context("edit prediction failed") + .log_err() + .flatten() + else { + cx.notify(); + return; + }; + + if let Some(old_prediction) = this.current_prediction.as_ref() { + let snapshot = buffer.read(cx).snapshot(); + if new_prediction.should_replace_prediction(old_prediction, &snapshot) { + this.current_prediction = Some(new_prediction); + } + } else { + this.current_prediction = Some(new_prediction); + } + + cx.notify(); + }) + .ok(); + }); + + // We always maintain at most two pending predictions. When we already + // have two, we replace the newest one. + if self.pending_predictions.len() <= 1 { + self.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + _task: task, + }); + } else if self.pending_predictions.len() == 2 { + self.pending_predictions.pop(); + self.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + _task: task, + }); + } + + cx.notify(); + } + + fn cycle( + &mut self, + _buffer: Entity, + _cursor_position: language::Anchor, + _direction: Direction, + _cx: &mut Context, + ) { + } + + fn accept(&mut self, _cx: &mut Context) { + // TODO [zeta2] report accept + self.current_prediction.take(); + self.pending_predictions.clear(); + } + + fn discard(&mut self, _cx: &mut Context) { + self.pending_predictions.clear(); + self.current_prediction.take(); + } + + fn suggest( + &mut self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) -> Option { + let CurrentEditPrediction { + buffer_id, + prediction, + .. + } = self.current_prediction.as_mut()?; + + // Invalidate previous prediction if it was generated for a different buffer. + if *buffer_id != buffer.entity_id() { + self.current_prediction.take(); + return None; + } + + let buffer = buffer.read(cx); + let Some(edits) = prediction.interpolate(&buffer.snapshot()) else { + self.current_prediction.take(); + return None; + }; + + let cursor_row = cursor_position.to_point(buffer).row; + let (closest_edit_ix, (closest_edit_range, _)) = + edits.iter().enumerate().min_by_key(|(_, (range, _))| { + let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row); + let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row); + cmp::min(distance_from_start, distance_from_end) + })?; + + let mut edit_start_ix = closest_edit_ix; + for (range, _) in edits[..edit_start_ix].iter().rev() { + let distance_from_closest_edit = + closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row; + if distance_from_closest_edit <= 1 { + edit_start_ix -= 1; + } else { + break; + } + } + + let mut edit_end_ix = closest_edit_ix + 1; + for (range, _) in &edits[edit_end_ix..] { + let distance_from_closest_edit = + range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row; + if distance_from_closest_edit <= 1 { + edit_end_ix += 1; + } else { + break; + } + } + + Some(edit_prediction::EditPrediction { + id: Some(prediction.id.to_string().into()), + edits: edits[edit_start_ix..edit_end_ix].to_vec(), + edit_preview: Some(prediction.edit_preview.clone()), + }) + } +} + +fn make_cloud_request( + excerpt_path: PathBuf, + context: EditPredictionContext, + events: Vec, + can_collect_data: bool, + diagnostic_groups: Vec, + git_info: Option, + debug_info: bool, + worktrees: &Vec, + index_state: Option<&SyntaxIndexState>, +) -> predict_edits_v3::PredictEditsRequest { + let mut signatures = Vec::new(); + let mut declaration_to_signature_index = HashMap::default(); + let mut referenced_declarations = Vec::new(); + + for snippet in context.snippets { + let project_entry_id = snippet.declaration.project_entry_id(); + // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot. + // Note that currently full_path is currently being used for excerpt_path. + let Some(path) = worktrees.iter().find_map(|worktree| { + let abs_path = worktree.abs_path(); + worktree + .entry_for_id(project_entry_id) + .map(|e| abs_path.join(&e.path)) + }) else { + continue; + }; + + let parent_index = index_state.and_then(|index_state| { + snippet.declaration.parent().and_then(|parent| { + add_signature( + parent, + &mut declaration_to_signature_index, + &mut signatures, + index_state, + ) + }) + }); + + let (text, text_is_truncated) = snippet.declaration.item_text(); + referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { + path, + text: text.into(), + range: snippet.declaration.item_range(), + text_is_truncated, + signature_range: snippet.declaration.signature_range_in_item_text(), + parent_index, + score_components: snippet.score_components, + signature_score: snippet.scores.signature, + declaration_score: snippet.scores.declaration, + }); + } + + let excerpt_parent = index_state.and_then(|index_state| { + context + .excerpt + .parent_declarations + .last() + .and_then(|(parent, _)| { + add_signature( + *parent, + &mut declaration_to_signature_index, + &mut signatures, + index_state, + ) + }) + }); + + predict_edits_v3::PredictEditsRequest { + excerpt_path, + excerpt: context.excerpt_text.body, + excerpt_range: context.excerpt.range, + cursor_offset: context.cursor_offset_in_excerpt, + referenced_declarations, + signatures, + excerpt_parent, + events, + can_collect_data, + diagnostic_groups, + git_info, + debug_info, + } +} + +fn add_signature( + declaration_id: DeclarationId, + declaration_to_signature_index: &mut HashMap, + signatures: &mut Vec, + index: &SyntaxIndexState, +) -> Option { + if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) { + return Some(*signature_index); + } + let Some(parent_declaration) = index.declaration(declaration_id) else { + log::error!("bug: missing parent declaration"); + return None; + }; + let parent_index = parent_declaration.parent().and_then(|parent| { + add_signature(parent, declaration_to_signature_index, signatures, index) + }); + let (text, text_is_truncated) = parent_declaration.signature_text(); + let signature_index = signatures.len(); + signatures.push(Signature { + text: text.into(), + text_is_truncated, + parent_index, + }); + declaration_to_signature_index.insert(declaration_id, signature_index); + Some(signature_index) +} + +fn interpolate( + old_snapshot: &BufferSnapshot, + new_snapshot: &BufferSnapshot, + current_edits: Arc<[(Range, String)]>, +) -> Option, String)>> { + let mut edits = Vec::new(); + + let mut model_edits = current_edits.iter().peekable(); + for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { + while let Some((model_old_range, _)) = model_edits.peek() { + let model_old_range = model_old_range.to_offset(old_snapshot); + if model_old_range.end < user_edit.old.start { + let (model_old_range, model_new_text) = model_edits.next().unwrap(); + edits.push((model_old_range.clone(), model_new_text.clone())); + } else { + break; + } + } + + if let Some((model_old_range, model_new_text)) = model_edits.peek() { + let model_old_offset_range = model_old_range.to_offset(old_snapshot); + if user_edit.old == model_old_offset_range { + let user_new_text = new_snapshot + .text_for_range(user_edit.new.clone()) + .collect::(); + + if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { + if !model_suffix.is_empty() { + let anchor = old_snapshot.anchor_after(user_edit.old.end); + edits.push((anchor..anchor, model_suffix.to_string())); + } + + model_edits.next(); + continue; + } + } + } + + return None; + } + + edits.extend(model_edits.cloned()); + + if edits.is_empty() { None } else { Some(edits) } +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::TestAppContext; + use language::ToOffset as _; + + #[gpui::test] + async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); + let edits: Arc<[(Range, String)]> = cx.update(|cx| { + to_prediction_edits( + [(2..5, "REM".to_string()), (9..11, "".to_string())], + &buffer, + cx, + ) + .into() + }); + + let edit_preview = cx + .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) + .await; + + let prediction = EditPrediction { + id: EditPredictionId(Uuid::new_v4()), + edits, + snapshot: cx.read(|cx| buffer.read(cx).snapshot()), + edit_preview, + }; + + cx.update(|cx| { + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".to_string()), (9..11, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..2, "REM".to_string()), (6..8, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".to_string()), (9..11, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(3..3, "EM".to_string()), (7..9, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".to_string()), (8..10, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(9..11, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".to_string()), (8..10, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); + assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None); + }) + } + + fn to_prediction_edits( + iterator: impl IntoIterator, String)>, + buffer: &Entity, + cx: &App, + ) -> Vec<(Range, String)> { + let buffer = buffer.read(cx); + iterator + .into_iter() + .map(|(range, text)| { + ( + buffer.anchor_after(range.start)..buffer.anchor_before(range.end), + text, + ) + }) + .collect() + } + + fn from_prediction_edits( + editor_edits: &[(Range, String)], + buffer: &Entity, + cx: &App, + ) -> Vec<(Range, String)> { + let buffer = buffer.read(cx); + editor_edits + .iter() + .map(|(range, text)| { + ( + range.start.to_offset(buffer)..range.end.to_offset(buffer), + text.clone(), + ) + }) + .collect() + } +}