Detailed changes
@@ -3216,6 +3216,7 @@ name = "cloud_llm_client"
version = "0.1.0"
dependencies = [
"anyhow",
+ "chrono",
"pretty_assertions",
"serde",
"serde_json",
@@ -5177,6 +5178,7 @@ dependencies = [
"anyhow",
"arrayvec",
"clap",
+ "cloud_llm_client",
"collections",
"futures 0.3.31",
"gpui",
@@ -21370,6 +21372,7 @@ dependencies = [
"zed_actions",
"zed_env_vars",
"zeta",
+ "zeta2",
"zlog",
"zlog_settings",
]
@@ -21647,6 +21650,32 @@ dependencies = [
"zlog",
]
+[[package]]
+name = "zeta2"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "arrayvec",
+ "client",
+ "cloud_llm_client",
+ "edit_prediction",
+ "edit_prediction_context",
+ "futures 0.3.31",
+ "gpui",
+ "language",
+ "language_model",
+ "log",
+ "project",
+ "release_channel",
+ "serde_json",
+ "thiserror 2.0.12",
+ "util",
+ "uuid",
+ "workspace",
+ "workspace-hack",
+ "worktree",
+]
+
[[package]]
name = "zeta_cli"
version = "0.1.0"
@@ -199,6 +199,7 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/zeta",
+ "crates/zeta2",
"crates/zeta_cli",
"crates/zlog",
"crates/zlog_settings",
@@ -432,6 +433,7 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
zeta = { path = "crates/zeta" }
+zeta2 = { path = "crates/zeta2" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
@@ -13,6 +13,7 @@ path = "src/cloud_llm_client.rs"
[dependencies]
anyhow.workspace = true
+chrono.workspace = true
serde = { workspace = true, features = ["derive", "rc"] }
serde_json.workspace = true
strum = { workspace = true, features = ["derive"] }
@@ -1,3 +1,5 @@
+pub mod predict_edits_v3;
+
use std::str::FromStr;
use std::sync::Arc;
@@ -0,0 +1,118 @@
+use chrono::Duration;
+use serde::{Deserialize, Serialize};
+use std::{ops::Range, path::PathBuf};
+use uuid::Uuid;
+
+use crate::PredictEditsGitInfo;
+
+// TODO: snippet ordering within file / relative to excerpt
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsRequest {
+ pub excerpt: String,
+ pub excerpt_path: PathBuf,
+ /// Within file
+ pub excerpt_range: Range<usize>,
+ /// Within `excerpt`
+ pub cursor_offset: usize,
+ /// Within `signatures`
+ pub excerpt_parent: Option<usize>,
+ pub signatures: Vec<Signature>,
+ pub referenced_declarations: Vec<ReferencedDeclaration>,
+ pub events: Vec<Event>,
+ #[serde(default)]
+ pub can_collect_data: bool,
+ #[serde(skip_serializing_if = "Vec::is_empty", default)]
+ pub diagnostic_groups: Vec<DiagnosticGroup>,
+ /// Info about the git repository state, only present when can_collect_data is true.
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub git_info: Option<PredictEditsGitInfo>,
+ #[serde(default)]
+ pub debug_info: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "event")]
+pub enum Event {
+ BufferChange {
+ path: Option<PathBuf>,
+ old_path: Option<PathBuf>,
+ diff: String,
+ predicted: bool,
+ },
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Signature {
+ pub text: String,
+ pub text_is_truncated: bool,
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub parent_index: Option<usize>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ReferencedDeclaration {
+ pub path: PathBuf,
+ pub text: String,
+ pub text_is_truncated: bool,
+ /// Range of `text` within file, potentially truncated according to `text_is_truncated`
+ pub range: Range<usize>,
+ /// Range within `text`
+ pub signature_range: Range<usize>,
+ /// Index within `signatures`.
+ #[serde(skip_serializing_if = "Option::is_none", default)]
+ pub parent_index: Option<usize>,
+ pub score_components: ScoreComponents,
+ pub signature_score: f32,
+ pub declaration_score: f32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ScoreComponents {
+ pub is_same_file: bool,
+ pub is_referenced_nearby: bool,
+ pub is_referenced_in_breadcrumb: bool,
+ pub reference_count: usize,
+ pub same_file_declaration_count: usize,
+ pub declaration_count: usize,
+ pub reference_line_distance: u32,
+ pub declaration_line_distance: u32,
+ pub declaration_line_distance_rank: usize,
+ pub containing_range_vs_item_jaccard: f32,
+ pub containing_range_vs_signature_jaccard: f32,
+ pub adjacent_vs_item_jaccard: f32,
+ pub adjacent_vs_signature_jaccard: f32,
+ pub containing_range_vs_item_weighted_overlap: f32,
+ pub containing_range_vs_signature_weighted_overlap: f32,
+ pub adjacent_vs_item_weighted_overlap: f32,
+ pub adjacent_vs_signature_weighted_overlap: f32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DiagnosticGroup {
+ pub language_server: String,
+ pub diagnostic_group: serde_json::Value,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsResponse {
+ pub request_id: Uuid,
+ pub edits: Vec<Edit>,
+ pub debug_info: Option<DebugInfo>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DebugInfo {
+ pub prompt: String,
+ pub prompt_planning_time: Duration,
+ pub model_response: String,
+ pub inference_time: Duration,
+ pub parsing_time: Duration,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Edit {
+ pub path: PathBuf,
+ pub range: Range<usize>,
+ pub content: String,
+}
@@ -14,6 +14,7 @@ path = "src/edit_prediction_context.rs"
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
+cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -41,6 +41,20 @@ impl Declaration {
}
}
+ pub fn parent(&self) -> Option<DeclarationId> {
+ match self {
+ Declaration::File { declaration, .. } => declaration.parent,
+ Declaration::Buffer { declaration, .. } => declaration.parent,
+ }
+ }
+
+ pub fn as_buffer(&self) -> Option<&BufferDeclaration> {
+ match self {
+ Declaration::File { .. } => None,
+ Declaration::Buffer { declaration, .. } => Some(declaration),
+ }
+ }
+
pub fn project_entry_id(&self) -> ProjectEntryId {
match self {
Declaration::File {
@@ -52,6 +66,13 @@ impl Declaration {
}
}
+ pub fn item_range(&self) -> Range<usize> {
+ match self {
+ Declaration::File { declaration, .. } => declaration.item_range_in_file.clone(),
+ Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
+ }
+ }
+
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
match self {
Declaration::File { declaration, .. } => (
@@ -83,6 +104,16 @@ impl Declaration {
),
}
}
+
+ pub fn signature_range_in_item_text(&self) -> Range<usize> {
+ match self {
+ Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(),
+ Declaration::Buffer { declaration, .. } => {
+ declaration.signature_range.start - declaration.item_range.start
+ ..declaration.signature_range.end - declaration.item_range.start
+ }
+ }
+ }
}
fn expand_range_to_line_boundaries_and_truncate(
@@ -1,10 +1,11 @@
+use cloud_llm_client::predict_edits_v3::ScoreComponents;
use itertools::Itertools as _;
use language::BufferSnapshot;
use ordered_float::OrderedFloat;
use serde::Serialize;
use std::{collections::HashMap, ops::Range};
use strum::EnumIter;
-use text::{OffsetRangeExt, Point, ToPoint};
+use text::{Point, ToPoint};
use crate::{
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
@@ -15,19 +16,14 @@ use crate::{
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
-// TODO:
-//
-// * Consider adding declaration_file_count
-
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
pub identifier: Identifier,
pub declaration: Declaration,
- pub score_components: ScoreInputs,
+ pub score_components: ScoreComponents,
pub scores: Scores,
}
-// TODO: Consider having "Concise" style corresponding to `concise_text`
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
Signature,
@@ -90,8 +86,8 @@ pub fn scored_snippets(
let declaration_count = declarations.len();
declarations
- .iter()
- .filter_map(|declaration| match declaration {
+ .into_iter()
+ .filter_map(|(declaration_id, declaration)| match declaration {
Declaration::Buffer {
buffer_id,
declaration: buffer_declaration,
@@ -100,24 +96,29 @@ pub fn scored_snippets(
let is_same_file = buffer_id == &current_buffer.remote_id();
if is_same_file {
- range_intersection(
- &buffer_declaration.item_range.to_offset(&current_buffer),
- &excerpt.range,
- )
- .is_none()
- .then(|| {
+ let overlaps_excerpt =
+ range_intersection(&buffer_declaration.item_range, &excerpt.range)
+ .is_some();
+ if overlaps_excerpt
+ || excerpt
+ .parent_declarations
+ .iter()
+ .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
+ {
+ None
+ } else {
let declaration_line = buffer_declaration
.item_range
.start
.to_point(current_buffer)
.row;
- (
+ Some((
true,
(cursor_point.row as i32 - declaration_line as i32)
.unsigned_abs(),
declaration,
- )
- })
+ ))
+ }
} else {
Some((false, u32::MAX, declaration))
}
@@ -238,7 +239,8 @@ fn score_snippet(
let adjacent_vs_signature_weighted_overlap =
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
- let score_components = ScoreInputs {
+ // TODO: Consider adding declaration_file_count
+ let score_components = ScoreComponents {
is_same_file,
is_referenced_nearby,
is_referenced_in_breadcrumb,
@@ -261,51 +263,30 @@ fn score_snippet(
Some(ScoredSnippet {
identifier: identifier.clone(),
declaration: declaration,
- scores: score_components.score(),
+ scores: Scores::score(&score_components),
score_components,
})
}
-#[derive(Clone, Debug, Serialize)]
-pub struct ScoreInputs {
- pub is_same_file: bool,
- pub is_referenced_nearby: bool,
- pub is_referenced_in_breadcrumb: bool,
- pub reference_count: usize,
- pub same_file_declaration_count: usize,
- pub declaration_count: usize,
- pub reference_line_distance: u32,
- pub declaration_line_distance: u32,
- pub declaration_line_distance_rank: usize,
- pub containing_range_vs_item_jaccard: f32,
- pub containing_range_vs_signature_jaccard: f32,
- pub adjacent_vs_item_jaccard: f32,
- pub adjacent_vs_signature_jaccard: f32,
- pub containing_range_vs_item_weighted_overlap: f32,
- pub containing_range_vs_signature_weighted_overlap: f32,
- pub adjacent_vs_item_weighted_overlap: f32,
- pub adjacent_vs_signature_weighted_overlap: f32,
-}
-
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
pub signature: f32,
pub declaration: f32,
}
-impl ScoreInputs {
- fn score(&self) -> Scores {
+impl Scores {
+ fn score(components: &ScoreComponents) -> Scores {
// Score related to how likely this is the correct declaration, range 0 to 1
- let accuracy_score = if self.is_same_file {
+ let accuracy_score = if components.is_same_file {
// TODO: use declaration_line_distance_rank
- 1.0 / self.same_file_declaration_count as f32
+ 1.0 / components.same_file_declaration_count as f32
} else {
- 1.0 / self.declaration_count as f32
+ 1.0 / components.declaration_count as f32
};
// Score related to the distance between the reference and cursor, range 0 to 1
- let distance_score = if self.is_referenced_nearby {
- 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
+ let distance_score = if components.is_referenced_nearby {
+ 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
} else {
// same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
0.5
@@ -315,10 +296,12 @@ impl ScoreInputs {
let combined_score = 10.0 * accuracy_score * distance_score;
Scores {
- signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
+ signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
// declaration score gets boosted both by being multiplied by 2 and by there being more
// weighted overlap.
- declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
+ declaration: 2.0
+ * combined_score
+ * components.containing_range_vs_item_weighted_overlap,
}
}
}
@@ -6,62 +6,82 @@ mod reference;
mod syntax_index;
mod text_similarity;
-use std::time::Instant;
-
-pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
-pub use declaration_scoring::SnippetStyle;
-pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
-
use gpui::{App, AppContext as _, Entity, Task};
use language::BufferSnapshot;
-pub use reference::references_in_excerpt;
-pub use syntax_index::SyntaxIndex;
use text::{Point, ToOffset as _};
-use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
+pub use declaration::*;
+pub use declaration_scoring::*;
+pub use excerpt::*;
+pub use reference::*;
+pub use syntax_index::*;
#[derive(Debug)]
pub struct EditPredictionContext {
pub excerpt: EditPredictionExcerpt,
pub excerpt_text: EditPredictionExcerptText,
+ pub cursor_offset_in_excerpt: usize,
pub snippets: Vec<ScoredSnippet>,
- pub retrieval_duration: std::time::Duration,
}
impl EditPredictionContext {
- pub fn gather(
+ pub fn gather_context_in_background(
cursor_point: Point,
buffer: BufferSnapshot,
excerpt_options: EditPredictionExcerptOptions,
- syntax_index: Entity<SyntaxIndex>,
+ syntax_index: Option<Entity<SyntaxIndex>>,
cx: &mut App,
) -> Task<Option<Self>> {
- let start = Instant::now();
- let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
- cx.background_spawn(async move {
- let index_state = index_state.lock().await;
-
- let excerpt =
- EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?;
- let excerpt_text = excerpt.text(&buffer);
- let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
- let cursor_offset = cursor_point.to_offset(&buffer);
-
- let snippets = scored_snippets(
+ if let Some(syntax_index) = syntax_index {
+ let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
+ cx.background_spawn(async move {
+ let index_state = index_state.lock().await;
+ Self::gather_context(cursor_point, &buffer, &excerpt_options, Some(&index_state))
+ })
+ } else {
+ cx.background_spawn(async move {
+ Self::gather_context(cursor_point, &buffer, &excerpt_options, None)
+ })
+ }
+ }
+
+ pub fn gather_context(
+ cursor_point: Point,
+ buffer: &BufferSnapshot,
+ excerpt_options: &EditPredictionExcerptOptions,
+ index_state: Option<&SyntaxIndexState>,
+ ) -> Option<Self> {
+ let excerpt = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ buffer,
+ excerpt_options,
+ index_state,
+ )?;
+ let excerpt_text = excerpt.text(buffer);
+ let cursor_offset_in_file = cursor_point.to_offset(buffer);
+ // TODO fix this to not need saturating_sub
+ let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
+
+ let snippets = if let Some(index_state) = index_state {
+ let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
+
+ scored_snippets(
&index_state,
&excerpt,
&excerpt_text,
references,
- cursor_offset,
- &buffer,
- );
-
- Some(Self {
- excerpt,
- excerpt_text,
- snippets,
- retrieval_duration: start.elapsed(),
- })
+ cursor_offset_in_file,
+ buffer,
+ )
+ } else {
+ vec![]
+ };
+
+ Some(Self {
+ excerpt,
+ excerpt_text,
+ cursor_offset_in_excerpt,
+ snippets,
})
}
}
@@ -101,24 +121,28 @@ mod tests {
let context = cx
.update(|cx| {
- EditPredictionContext::gather(
+ EditPredictionContext::gather_context_in_background(
cursor_point,
buffer_snapshot,
EditPredictionExcerptOptions {
- max_bytes: 40,
+ max_bytes: 60,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
},
- index,
+ Some(index),
cx,
)
})
.await
.unwrap();
- assert_eq!(context.snippets.len(), 1);
- assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
+ let mut snippet_identifiers = context
+ .snippets
+ .iter()
+ .map(|snippet| snippet.identifier.name.as_ref())
+ .collect::<Vec<_>>();
+ snippet_identifiers.sort();
+ assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
drop(buffer);
}
@@ -1,9 +1,11 @@
use language::BufferSnapshot;
use std::ops::Range;
-use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _};
+use text::{Point, ToOffset as _, ToPoint as _};
use tree_sitter::{Node, TreeCursor};
use util::RangeExt;
+use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
+
// TODO:
//
// - Test parent signatures
@@ -27,14 +29,12 @@ pub struct EditPredictionExcerptOptions {
pub min_bytes: usize,
/// Target ratio of bytes before the cursor divided by total bytes in the window.
pub target_before_cursor_over_total_bytes: f32,
- /// Whether to include parent signatures
- pub include_parent_signatures: bool,
}
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
- pub parent_signature_ranges: Vec<Range<usize>>,
+ pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
pub size: usize,
}
@@ -50,9 +50,9 @@ impl EditPredictionExcerpt {
.text_for_range(self.range.clone())
.collect::<String>();
let parent_signatures = self
- .parent_signature_ranges
+ .parent_declarations
.iter()
- .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
+ .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
.collect();
EditPredictionExcerptText {
body,
@@ -62,8 +62,9 @@ impl EditPredictionExcerpt {
/// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
/// on TreeSitter structure and approximately targeting a goal ratio of bytes before vs after the
- /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures
- /// of parent outline items.
+ /// cursor.
+ ///
+ /// When `index` is provided, the excerpt will include the signatures of parent outline items.
///
/// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
/// expansion.
@@ -73,6 +74,7 @@ impl EditPredictionExcerpt {
query_point: Point,
buffer: &BufferSnapshot,
options: &EditPredictionExcerptOptions,
+ syntax_index: Option<&SyntaxIndexState>,
) -> Option<Self> {
if buffer.len() <= options.max_bytes {
log::debug!(
@@ -90,17 +92,9 @@ impl EditPredictionExcerpt {
return None;
}
- // TODO: Don't compute text / annotation_range / skip converting to and from anchors.
- let outline_items = if options.include_parent_signatures {
- buffer
- .outline_items_containing(query_range.clone(), false, None)
- .into_iter()
- .flat_map(|item| {
- Some(ExcerptOutlineItem {
- item_range: item.range.to_offset(&buffer),
- signature_range: item.signature_range?.to_offset(&buffer),
- })
- })
+ let parent_declarations = if let Some(syntax_index) = syntax_index {
+ syntax_index
+ .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
.collect()
} else {
Vec::new()
@@ -109,7 +103,7 @@ impl EditPredictionExcerpt {
let excerpt_selector = ExcerptSelector {
query_offset,
query_range,
- outline_items: &outline_items,
+ parent_declarations: &parent_declarations,
buffer,
options,
};
@@ -132,15 +126,15 @@ impl EditPredictionExcerpt {
excerpt_selector.select_lines()
}
- fn new(range: Range<usize>, parent_signature_ranges: Vec<Range<usize>>) -> Self {
+ fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
let size = range.len()
- + parent_signature_ranges
+ + parent_declarations
.iter()
- .map(|r| r.len())
+ .map(|(_, range)| range.len())
.sum::<usize>();
Self {
range,
- parent_signature_ranges,
+ parent_declarations,
size,
}
}
@@ -150,20 +144,14 @@ impl EditPredictionExcerpt {
// this is an issue because parent_signature_ranges may be incorrect
log::error!("bug: with_expanded_range called with disjoint range");
}
- let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len());
- let mut size = new_range.len();
- for range in &self.parent_signature_ranges {
- if range.contains_inclusive(&new_range) {
+ let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
+ for (declaration_id, range) in &self.parent_declarations {
+ if !range.contains_inclusive(&new_range) {
break;
}
- parent_signature_ranges.push(range.clone());
- size += range.len();
- }
- Self {
- range: new_range,
- parent_signature_ranges,
- size,
+ parent_declarations.push((*declaration_id, range.clone()));
}
+ Self::new(new_range, parent_declarations)
}
fn parent_signatures_size(&self) -> usize {
@@ -174,16 +162,11 @@ impl EditPredictionExcerpt {
struct ExcerptSelector<'a> {
query_offset: usize,
query_range: Range<usize>,
- outline_items: &'a [ExcerptOutlineItem],
+ parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
buffer: &'a BufferSnapshot,
options: &'a EditPredictionExcerptOptions,
}
-struct ExcerptOutlineItem {
- item_range: Range<usize>,
- signature_range: Range<usize>,
-}
-
impl<'a> ExcerptSelector<'a> {
/// Finds the largest node that is smaller than the window size and contains `query_range`.
fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
@@ -396,13 +379,13 @@ impl<'a> ExcerptSelector<'a> {
}
fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
- let parent_signature_ranges = self
- .outline_items
+ let parent_declarations = self
+ .parent_declarations
.iter()
- .filter(|item| item.item_range.contains_inclusive(&range))
- .map(|item| item.signature_range.clone())
+ .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
+ .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
.collect();
- EditPredictionExcerpt::new(range, parent_signature_ranges)
+ EditPredictionExcerpt::new(range, parent_declarations)
}
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
@@ -493,8 +476,9 @@ mod tests {
let buffer = create_buffer(&text, cx);
let cursor_point = cursor.to_point(&buffer);
- let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
- .expect("Should select an excerpt");
+ let excerpt =
+ EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
+ .expect("Should select an excerpt");
pretty_assertions::assert_eq!(
generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
generate_marked_text(&text, &[expected_excerpt], false)
@@ -517,7 +501,6 @@ fn main() {
max_bytes: 20,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -541,7 +524,6 @@ fn bar() {}"#;
max_bytes: 65,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -561,7 +543,6 @@ fn main() {
max_bytes: 50,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -583,7 +564,6 @@ fn main() {
max_bytes: 60,
min_bytes: 45,
target_before_cursor_over_total_bytes: 0.5,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -608,7 +588,6 @@ fn main() {
max_bytes: 120,
min_bytes: 10,
target_before_cursor_over_total_bytes: 0.6,
- include_parent_signatures: false,
};
check_example(options, text, cx);
@@ -33,8 +33,8 @@ pub fn references_in_excerpt(
snapshot,
);
- for (range, text) in excerpt
- .parent_signature_ranges
+ for ((_, range), text) in excerpt
+ .parent_declarations
.iter()
.zip(excerpt_text.parent_signatures.iter())
{
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use collections::{HashMap, HashSet};
use futures::lock::Mutex;
use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
@@ -8,20 +6,17 @@ use project::buffer_store::{BufferStore, BufferStoreEvent};
use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
use project::{PathChange, Project, ProjectEntryId, ProjectPath};
use slotmap::SlotMap;
+use std::iter;
+use std::ops::Range;
+use std::sync::Arc;
use text::BufferId;
-use util::{debug_panic, some_or_debug_panic};
+use util::{RangeExt as _, debug_panic, some_or_debug_panic};
use crate::declaration::{
BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
};
use crate::outline::declarations_in_buffer;
-// TODO:
-//
-// * Skip for remote projects
-//
-// * Consider making SyntaxIndex not an Entity.
-
// Potential future improvements:
//
// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which
@@ -40,7 +35,6 @@ use crate::outline::declarations_in_buffer;
// * Concurrent slotmap
//
// * Use queue for parsing
-//
pub struct SyntaxIndex {
state: Arc<Mutex<SyntaxIndexState>>,
@@ -432,7 +426,7 @@ impl SyntaxIndexState {
pub fn declarations_for_identifier<const N: usize>(
&self,
identifier: &Identifier,
- ) -> Vec<Declaration> {
+ ) -> Vec<(DeclarationId, &Declaration)> {
// make sure to not have a large stack allocation
assert!(N < 32);
@@ -454,7 +448,7 @@ impl SyntaxIndexState {
project_entry_id, ..
} => {
included_buffer_entry_ids.push(*project_entry_id);
- result.push(declaration.clone());
+ result.push((*declaration_id, declaration));
if result.len() == N {
return Vec::new();
}
@@ -463,19 +457,19 @@ impl SyntaxIndexState {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
- file_declarations.push(declaration.clone());
+ file_declarations.push((*declaration_id, declaration));
}
}
}
}
- for declaration in file_declarations {
+ for (declaration_id, declaration) in file_declarations {
match declaration {
Declaration::File {
project_entry_id, ..
} => {
if !included_buffer_entry_ids.contains(&project_entry_id) {
- result.push(declaration);
+ result.push((declaration_id, declaration));
if result.len() == N {
return Vec::new();
@@ -489,6 +483,35 @@ impl SyntaxIndexState {
result
}
+ pub fn buffer_declarations_containing_range(
+ &self,
+ buffer_id: BufferId,
+ range: Range<usize>,
+ ) -> impl Iterator<Item = (DeclarationId, &BufferDeclaration)> {
+ let Some(buffer_state) = self.buffers.get(&buffer_id) else {
+ return itertools::Either::Left(iter::empty());
+ };
+
+ let iter = buffer_state
+ .declarations
+ .iter()
+ .filter_map(move |declaration_id| {
+ let Some(declaration) = self
+ .declarations
+ .get(*declaration_id)
+ .and_then(|d| d.as_buffer())
+ else {
+ log::error!("bug: missing buffer outline declaration");
+ return None;
+ };
+ if declaration.item_range.contains_inclusive(&range) {
+ return Some((*declaration_id, declaration));
+ }
+ return None;
+ });
+ itertools::Either::Right(iter)
+ }
+
pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
match declaration {
Declaration::File {
@@ -553,11 +576,11 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
- let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
+ let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, main.clone());
assert_eq!(decl.item_range_in_file, 32..280);
- let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
+ let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range_in_file, 0..98);
});
@@ -577,7 +600,7 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
- let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
+ let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
@@ -618,7 +641,7 @@ mod tests {
let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
- let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
+ let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
@@ -676,11 +699,11 @@ mod tests {
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
- let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
+ let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, main);
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280);
- expect_file_decl("a.rs", &decls[1], &project, cx);
+ expect_file_decl("a.rs", &decls[1].1, &project, cx);
});
}
@@ -695,8 +718,8 @@ mod tests {
cx.update(|cx| {
let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
- expect_file_decl("c.rs", &decls[0], &project, cx);
- expect_file_decl("a.rs", &decls[1], &project, cx);
+ expect_file_decl("c.rs", &decls[0].1, &project, cx);
+ expect_file_decl("a.rs", &decls[1].1, &project, cx);
});
}
@@ -9,8 +9,12 @@ use crate::reference::Reference;
// That implementation could actually be more efficient - no need to track words in the window that
// are not in the query.
+// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
+// two in parallel.
+
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
+// TODO: use &str or Cow<str> keys?
#[derive(Debug)]
pub struct IdentifierOccurrences {
identifier_to_count: HashMap<String, usize>,
@@ -4,7 +4,7 @@ use std::{
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
- time::Duration,
+ time::{Duration, Instant},
};
use collections::HashMap;
@@ -195,6 +195,8 @@ impl EditPredictionTools {
.timer(Duration::from_millis(50))
.await;
+ let mut start_time = None;
+
let Ok(task) = this.update(cx, |this, cx| {
fn number_input_value<T: FromStr + Default>(
input: &Entity<SingleLineInput>,
@@ -216,15 +218,16 @@ impl EditPredictionTools {
&this.cursor_context_ratio_input,
cx,
),
- // TODO Display and add to options
- include_parent_signatures: false,
};
- EditPredictionContext::gather(
+ start_time = Some(Instant::now());
+
+ // TODO use global zeta instead
+ EditPredictionContext::gather_context_in_background(
cursor_position,
current_buffer_snapshot,
options,
- this.syntax_index.clone(),
+ Some(this.syntax_index.clone()),
cx,
)
}) else {
@@ -243,6 +246,7 @@ impl EditPredictionTools {
.ok();
return;
};
+ let retrieval_duration = start_time.unwrap().elapsed();
let mut languages = HashMap::default();
for snippet in context.snippets.iter() {
@@ -320,7 +324,7 @@ impl EditPredictionTools {
this.last_context = Some(ContextState {
context_editor,
- retrieval_duration: context.retrieval_duration,
+ retrieval_duration,
});
cx.notify();
})
@@ -84,6 +84,17 @@ pub enum EditPredictionProvider {
Zed,
}
+impl EditPredictionProvider {
+ pub fn is_zed(&self) -> bool {
+ match self {
+ EditPredictionProvider::Zed => true,
+ EditPredictionProvider::None
+ | EditPredictionProvider::Copilot
+ | EditPredictionProvider::Supermaven => false,
+ }
+ }
+}
+
/// The contents of the edit prediction settings.
#[skip_serializing_none]
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)]
@@ -163,6 +163,7 @@ workspace.workspace = true
zed_actions.workspace = true
zed_env_vars.workspace = true
zeta.workspace = true
+zeta2.workspace = true
zlog.workspace = true
zlog_settings.workspace = true
@@ -203,21 +203,43 @@ fn assign_edit_prediction_provider(
}
}
- let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
-
- if let Some(buffer) = &singleton_buffer
- && buffer.read(cx).file().is_some()
- && let Some(project) = editor.project()
- {
- zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(buffer, project, cx);
+ if std::env::var("ZED_ZETA2").is_ok() {
+ let zeta = zeta2::Zeta::global(client, &user_store, cx);
+ let provider = cx.new(|cx| {
+ zeta2::ZetaEditPredictionProvider::new(
+ editor.project(),
+ &client,
+ &user_store,
+ cx,
+ )
});
- }
- let provider =
- cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
+ if let Some(buffer) = &singleton_buffer
+ && buffer.read(cx).file().is_some()
+ && let Some(project) = editor.project()
+ {
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(buffer, project, cx);
+ });
+ }
- editor.set_edit_prediction_provider(Some(provider), window, cx);
+ editor.set_edit_prediction_provider(Some(provider), window, cx);
+ } else {
+ let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
+
+ if let Some(buffer) = &singleton_buffer
+ && buffer.read(cx).file().is_some()
+ && let Some(project) = editor.project()
+ {
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(buffer, project, cx);
+ });
+ }
+
+ let provider =
+ cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
+ editor.set_edit_prediction_provider(Some(provider), window, cx);
+ }
}
}
}
@@ -0,0 +1,37 @@
+[package]
+name = "zeta2"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/zeta2.rs"
+
+[dependencies]
+anyhow.workspace = true
+arrayvec.workspace = true
+client.workspace = true
+cloud_llm_client.workspace = true
+edit_prediction.workspace = true
+edit_prediction_context.workspace = true
+futures.workspace = true
+gpui.workspace = true
+language.workspace = true
+language_model.workspace = true
+log.workspace = true
+project.workspace = true
+release_channel.workspace = true
+serde_json.workspace = true
+thiserror.workspace = true
+util.workspace = true
+uuid.workspace = true
+workspace.workspace = true
+workspace-hack.workspace = true
+worktree.workspace = true
+
+[dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,1130 @@
+use anyhow::{Context as _, Result, anyhow};
+use arrayvec::ArrayVec;
+use client::{Client, EditPredictionUsage, UserStore};
+use cloud_llm_client::predict_edits_v3::{self, Signature};
+use cloud_llm_client::{
+ EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
+};
+use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
+use edit_prediction_context::{
+ DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
+ SyntaxIndexState,
+};
+use futures::AsyncReadExt as _;
+use gpui::http_client::Method;
+use gpui::{
+ App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
+ prelude::*,
+};
+use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
+use language::{BufferSnapshot, EditPreview};
+use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use project::Project;
+use release_channel::AppVersion;
+use std::cmp;
+use std::collections::{HashMap, VecDeque, hash_map};
+use std::path::PathBuf;
+use std::str::FromStr as _;
+use std::time::{Duration, Instant};
+use std::{ops::Range, sync::Arc};
+use thiserror::Error;
+use util::ResultExt as _;
+use uuid::Uuid;
+use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+
+/// Window within which consecutive edits to the same buffer are coalesced
+/// into one `Event::BufferChange` (see `Zeta::push_event`).
+const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
+
+/// Maximum number of events to track.
+const MAX_EVENT_COUNT: usize = 16;
+
+/// Process-wide handle to the single [`Zeta`] entity; installed by
+/// [`Zeta::global`] on first access.
+#[derive(Clone)]
+struct ZetaGlobal(Entity<Zeta>);
+
+impl Global for ZetaGlobal {}
+
+/// Shared zeta2 edit-prediction state: cloud client, LLM token, and
+/// per-project syntax indices / buffer-event histories.
+pub struct Zeta {
+    client: Arc<Client>,
+    user_store: Entity<UserStore>,
+    // Bearer token for the LLM endpoint; refreshed via `_llm_token_subscription`.
+    llm_token: LlmApiToken,
+    _llm_token_subscription: Subscription,
+    // Keyed by `Project` entity id.
+    projects: HashMap<EntityId, ZetaProject>,
+    // Controls how much text around the cursor is sent as the excerpt.
+    excerpt_options: EditPredictionExcerptOptions,
+    // Set once the server reports this client version is too old.
+    update_required: bool,
+}
+
+/// Per-project prediction state.
+struct ZetaProject {
+    syntax_index: Entity<SyntaxIndex>,
+    // Recent buffer-change events, oldest first; capped at `MAX_EVENT_COUNT`.
+    events: VecDeque<Event>,
+    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+}
+
+/// Last-seen snapshot of a tracked buffer plus the edit/release
+/// subscriptions that keep it current.
+struct RegisteredBuffer {
+    snapshot: BufferSnapshot,
+    _subscriptions: [gpui::Subscription; 2],
+}
+
+/// A recorded change, kept as history to send alongside prediction requests.
+#[derive(Clone)]
+pub enum Event {
+    BufferChange {
+        // Buffer state before and after the (possibly coalesced) edit.
+        old_snapshot: BufferSnapshot,
+        new_snapshot: BufferSnapshot,
+        timestamp: Instant,
+    },
+}
+
+impl Zeta {
+    /// Returns the process-wide `Zeta` entity, creating it and registering it
+    /// as a GPUI global on first use.
+    pub fn global(
+        client: &Arc<Client>,
+        user_store: &Entity<UserStore>,
+        cx: &mut App,
+    ) -> Entity<Self> {
+        cx.try_global::<ZetaGlobal>()
+            .map(|global| global.0.clone())
+            .unwrap_or_else(|| {
+                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
+                cx.set_global(ZetaGlobal(zeta.clone()));
+                zeta
+            })
+    }
+
+    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+
+        Self {
+            projects: HashMap::new(),
+            client,
+            user_store,
+            excerpt_options: EditPredictionExcerptOptions {
+                max_bytes: 512,
+                min_bytes: 128,
+                target_before_cursor_over_total_bytes: 0.5,
+            },
+            llm_token: LlmApiToken::default(),
+            // Re-fetch the LLM token whenever the listener signals a refresh
+            // (e.g. after expiry).
+            _llm_token_subscription: cx.subscribe(
+                &refresh_llm_token_listener,
+                |this, _listener, _event, cx| {
+                    let client = this.client.clone();
+                    let llm_token = this.llm_token.clone();
+                    cx.spawn(async move |_this, _cx| {
+                        llm_token.refresh(&client).await?;
+                        anyhow::Ok(())
+                    })
+                    .detach_and_log_err(cx);
+                },
+            ),
+            update_required: false,
+        }
+    }
+
+    /// Current edit-prediction usage for the signed-in user, if known.
+    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
+        self.user_store.read(cx).edit_prediction_usage()
+    }
+
+    /// Ensures per-project state (syntax index, event history) exists for `project`.
+    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
+        self.get_or_init_zeta_project(project, cx);
+    }
+
+    /// Starts tracking `buffer` so its edits are recorded as events for `project`.
+    pub fn register_buffer(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) {
+        let zeta_project = self.get_or_init_zeta_project(project, cx);
+        Self::register_buffer_impl(zeta_project, buffer, project, cx);
+    }
+
+    fn get_or_init_zeta_project(
+        &mut self,
+        project: &Entity<Project>,
+        cx: &mut App,
+    ) -> &mut ZetaProject {
+        self.projects
+            .entry(project.entity_id())
+            .or_insert_with(|| ZetaProject {
+                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
+                events: VecDeque::new(),
+                registered_buffers: HashMap::new(),
+            })
+    }
+
+    /// Subscribes to buffer edits and release; idempotent per buffer
+    /// (occupied entries are returned untouched).
+    fn register_buffer_impl<'a>(
+        zeta_project: &'a mut ZetaProject,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> &'a mut RegisteredBuffer {
+        let buffer_id = buffer.entity_id();
+        match zeta_project.registered_buffers.entry(buffer_id) {
+            hash_map::Entry::Occupied(entry) => entry.into_mut(),
+            hash_map::Entry::Vacant(entry) => {
+                let snapshot = buffer.read(cx).snapshot();
+                let project_entity_id = project.entity_id();
+                entry.insert(RegisteredBuffer {
+                    snapshot,
+                    _subscriptions: [
+                        cx.subscribe(buffer, {
+                            // Hold a weak project handle so the subscription
+                            // doesn't keep the project alive.
+                            let project = project.downgrade();
+                            move |this, buffer, event, cx| {
+                                if let language::BufferEvent::Edited = event
+                                    && let Some(project) = project.upgrade()
+                                {
+                                    this.report_changes_for_buffer(&buffer, &project, cx);
+                                }
+                            }
+                        }),
+                        // Drop per-buffer state when the buffer entity is released.
+                        cx.observe_release(buffer, move |this, _buffer, _cx| {
+                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
+                            else {
+                                return;
+                            };
+                            zeta_project.registered_buffers.remove(&buffer_id);
+                        }),
+                    ],
+                })
+            }
+        }
+    }
+
+    /// Records a `BufferChange` event if the buffer's version advanced since
+    /// the last snapshot we saw, and returns the latest snapshot.
+    fn report_changes_for_buffer(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> BufferSnapshot {
+        let zeta_project = self.get_or_init_zeta_project(project, cx);
+        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
+
+        let new_snapshot = buffer.read(cx).snapshot();
+        if new_snapshot.version != registered_buffer.snapshot.version {
+            let old_snapshot =
+                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+            Self::push_event(
+                zeta_project,
+                Event::BufferChange {
+                    old_snapshot,
+                    new_snapshot: new_snapshot.clone(),
+                    timestamp: Instant::now(),
+                },
+            );
+        }
+
+        new_snapshot
+    }
+
+    /// Appends `event` to the project's history, coalescing with the previous
+    /// event when it continues the same buffer's edit stream within
+    /// `BUFFER_CHANGE_GROUPING_INTERVAL`.
+    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+        let events = &mut zeta_project.events;
+
+        if let Some(Event::BufferChange {
+            new_snapshot: last_new_snapshot,
+            timestamp: last_timestamp,
+            ..
+        }) = events.back_mut()
+        {
+            // Coalesce edits for the same buffer when they happen one after the other.
+            let Event::BufferChange {
+                old_snapshot,
+                new_snapshot,
+                timestamp,
+            } = &event;
+
+            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
+                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
+                && old_snapshot.version == last_new_snapshot.version
+            {
+                *last_new_snapshot = new_snapshot.clone();
+                *last_timestamp = *timestamp;
+                return;
+            }
+        }
+
+        if events.len() >= MAX_EVENT_COUNT {
+            // These are halved instead of popping to improve prompt caching.
+            events.drain(..MAX_EVENT_COUNT / 2);
+        }
+
+        events.push_back(event);
+    }
+
+    /// Gathers context around `position`, sends a v3 prediction request to the
+    /// cloud endpoint, and re-maps the returned edits onto the buffer's
+    /// current state. Resolves to `Ok(None)` when no context could be
+    /// gathered or the prediction no longer applies.
+    pub fn request_prediction(
+        &mut self,
+        project: &Entity<Project>,
+        buffer: &Entity<Buffer>,
+        position: language::Anchor,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<Option<EditPrediction>>> {
+        let project_state = self.projects.get(&project.entity_id());
+
+        let index_state = project_state.map(|state| {
+            state
+                .syntax_index
+                .read_with(cx, |index, _cx| index.state().clone())
+        });
+        let excerpt_options = self.excerpt_options.clone();
+        let snapshot = buffer.read(cx).snapshot();
+        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
+            return Task::ready(Err(anyhow!("No file path for excerpt")));
+        };
+        let client = self.client.clone();
+        let llm_token = self.llm_token.clone();
+        let app_version = AppVersion::global(cx);
+        let worktree_snapshots = project
+            .read(cx)
+            .worktrees(cx)
+            .map(|worktree| worktree.read(cx).snapshot())
+            .collect::<Vec<_>>();
+
+        // Context gathering + request building happens off the main thread.
+        let request_task = cx.background_spawn({
+            let snapshot = snapshot.clone();
+            async move {
+                let index_state = if let Some(index_state) = index_state {
+                    Some(index_state.lock_owned().await)
+                } else {
+                    None
+                };
+
+                let cursor_point = position.to_point(&snapshot);
+
+                // TODO: make this only true if debug view is open
+                let debug_info = true;
+
+                let Some(request) = EditPredictionContext::gather_context(
+                    cursor_point,
+                    &snapshot,
+                    &excerpt_options,
+                    index_state.as_deref(),
+                )
+                .map(|context| {
+                    make_cloud_request(
+                        excerpt_path.clone(),
+                        context,
+                        // TODO pass everything
+                        Vec::new(),
+                        false,
+                        Vec::new(),
+                        None,
+                        debug_info,
+                        &worktree_snapshots,
+                        index_state.as_deref(),
+                    )
+                }) else {
+                    return Ok(None);
+                };
+
+                anyhow::Ok(Some(
+                    Self::perform_request(client, llm_token, app_version, request).await?,
+                ))
+            }
+        });
+
+        let buffer = buffer.clone();
+
+        cx.spawn(async move |this, cx| {
+            match request_task.await {
+                Ok(Some((response, usage))) => {
+                    log::debug!("predicted edits: {:?}", &response.edits);
+
+                    if let Some(usage) = usage {
+                        this.update(cx, |this, cx| {
+                            this.user_store.update(cx, |user_store, cx| {
+                                user_store.update_edit_prediction_usage(usage, cx);
+                            });
+                        })
+                        .ok();
+                    }
+
+                    // TODO telemetry: duration, etc
+
+                    // TODO produce smaller edits by diffing against snapshot first
+                    //
+                    // Cloud returns entire snippets/excerpts ranges as they were included
+                    // in the request, but we should display smaller edits to the user.
+                    //
+                    // We can do this by computing a diff of each one against the snapshot.
+                    // Similar to zeta::Zeta::compute_edits, but per edit.
+                    let edits = response
+                        .edits
+                        .into_iter()
+                        .map(|edit| {
+                            // TODO edits to different files
+                            (
+                                snapshot.anchor_before(edit.range.start)
+                                    ..snapshot.anchor_before(edit.range.end),
+                                edit.content,
+                            )
+                        })
+                        .collect::<Vec<_>>()
+                        .into();
+
+                    // Re-validate against the buffer's current contents: the
+                    // user may have typed while the request was in flight.
+                    let Some((edits, snapshot, edit_preview_task)) =
+                        buffer.read_with(cx, |buffer, cx| {
+                            let new_snapshot = buffer.snapshot();
+                            let edits: Arc<[_]> =
+                                interpolate(&snapshot, &new_snapshot, edits)?.into();
+                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
+                        })?
+                    else {
+                        return Ok(None);
+                    };
+
+                    Ok(Some(EditPrediction {
+                        id: EditPredictionId(response.request_id),
+                        edits,
+                        snapshot,
+                        edit_preview: edit_preview_task.await,
+                    }))
+                }
+                Ok(None) => Ok(None),
+                Err(err) => {
+                    // Surface a "must update Zed" notification when the server
+                    // rejects this client version; the error still propagates.
+                    if err.is::<ZedUpdateRequiredError>() {
+                        cx.update(|cx| {
+                            this.update(cx, |this, _cx| {
+                                this.update_required = true;
+                            })
+                            .ok();
+
+                            let error_message: SharedString = err.to_string().into();
+                            show_app_notification(
+                                NotificationId::unique::<ZedUpdateRequiredError>(),
+                                cx,
+                                move |cx| {
+                                    cx.new(|cx| {
+                                        ErrorMessagePrompt::new(error_message.clone(), cx)
+                                            .with_link_button(
+                                                "Update Zed",
+                                                "https://zed.dev/releases",
+                                            )
+                                    })
+                                },
+                            );
+                        })
+                        .ok();
+                    }
+
+                    Err(err)
+                }
+            }
+        })
+    }
+
+    /// Sends the request to the LLM endpoint, refreshing an expired token at
+    /// most once and enforcing the server's minimum-required-version header.
+    async fn perform_request(
+        client: Arc<Client>,
+        llm_token: LlmApiToken,
+        app_version: SemanticVersion,
+        request: predict_edits_v3::PredictEditsRequest,
+    ) -> Result<(
+        predict_edits_v3::PredictEditsResponse,
+        Option<EditPredictionUsage>,
+    )> {
+        let http_client = client.http_client();
+        let mut token = llm_token.acquire(&client).await?;
+        let mut did_retry = false;
+
+        loop {
+            let request_builder = http_client::Request::builder().method(Method::POST);
+            // ZED_PREDICT_EDITS_URL overrides the default endpoint (useful in development).
+            let request_builder =
+                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
+                    request_builder.uri(predict_edits_url)
+                } else {
+                    request_builder.uri(
+                        http_client
+                            .build_zed_llm_url("/predict_edits/v3", &[])?
+                            .as_ref(),
+                    )
+                };
+            let request = request_builder
+                .header("Content-Type", "application/json")
+                .header("Authorization", format!("Bearer {}", token))
+                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
+                .body(serde_json::to_string(&request)?.into())?;
+
+            let mut response = http_client.send(request).await?;
+
+            // The server may announce a minimum client version on any response.
+            if let Some(minimum_required_version) = response
+                .headers()
+                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
+            {
+                anyhow::ensure!(
+                    app_version >= minimum_required_version,
+                    ZedUpdateRequiredError {
+                        minimum_version: minimum_required_version
+                    }
+                );
+            }
+
+            if response.status().is_success() {
+                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
+
+                let mut body = Vec::new();
+                response.body_mut().read_to_end(&mut body).await?;
+                return Ok((serde_json::from_slice(&body)?, usage));
+            } else if !did_retry
+                && response
+                    .headers()
+                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                    .is_some()
+            {
+                // Expired token: refresh once and retry.
+                did_retry = true;
+                token = llm_token.refresh(&client).await?;
+            } else {
+                let mut body = String::new();
+                response.body_mut().read_to_string(&mut body).await?;
+                anyhow::bail!(
+                    "error predicting edits.\nStatus: {:?}\nBody: {}",
+                    response.status(),
+                    body
+                );
+            }
+        }
+    }
+}
+
+/// Raised when the server's minimum-required-version header exceeds this
+/// client's version; triggers the "Update Zed" notification.
+#[derive(Error, Debug)]
+#[error(
+    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
+)]
+pub struct ZedUpdateRequiredError {
+    minimum_version: SemanticVersion,
+}
+
+/// `EditPredictionProvider` implementation backed by the global [`Zeta`] entity.
+pub struct ZetaEditPredictionProvider {
+    zeta: Entity<Zeta>,
+    current_prediction: Option<CurrentEditPrediction>,
+    // Monotonic id used to match completed request tasks in `refresh`.
+    next_pending_prediction_id: usize,
+    // At most two in-flight requests; see `refresh` for replacement policy.
+    pending_predictions: ArrayVec<PendingPrediction, 2>,
+    last_request_timestamp: Instant,
+}
+
+impl ZetaEditPredictionProvider {
+    /// Minimum spacing between consecutive prediction requests.
+    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+
+    /// Creates a provider bound to the global `Zeta`, registering `project`
+    /// with it when one is supplied.
+    pub fn new(
+        project: Option<&Entity<Project>>,
+        client: &Arc<Client>,
+        user_store: &Entity<UserStore>,
+        cx: &mut App,
+    ) -> Self {
+        let zeta = Zeta::global(client, user_store, cx);
+        if let Some(project) = project {
+            zeta.update(cx, |zeta, cx| {
+                zeta.register_project(project, cx);
+            });
+        }
+
+        Self {
+            zeta,
+            current_prediction: None,
+            next_pending_prediction_id: 0,
+            pending_predictions: ArrayVec::new(),
+            last_request_timestamp: Instant::now(),
+        }
+    }
+}
+
+/// The prediction currently being offered, tagged with the buffer it targets.
+#[derive(Clone)]
+struct CurrentEditPrediction {
+    buffer_id: EntityId,
+    prediction: EditPrediction,
+}
+
+impl CurrentEditPrediction {
+    /// Whether `self` (a newer prediction) should displace `old_prediction`:
+    /// yes if it targets a different buffer, if the old one no longer
+    /// interpolates onto `snapshot`, or if both are multi-edit; for a pair of
+    /// single edits, only when the new edit extends the old one in place. A
+    /// new prediction that itself fails to interpolate never wins.
+    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
+        if self.buffer_id != old_prediction.buffer_id {
+            return true;
+        }
+
+        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
+            return true;
+        };
+        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
+            return false;
+        };
+
+        if old_edits.len() == 1 && new_edits.len() == 1 {
+            let (old_range, old_text) = &old_edits[0];
+            let (new_range, new_text) = &new_edits[0];
+            new_range == old_range && new_text.starts_with(old_text)
+        } else {
+            true
+        }
+    }
+}
+
+/// Identifier of a prediction — the UUID the server assigned to the request.
+#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub struct EditPredictionId(Uuid);
+
+impl From<EditPredictionId> for gpui::ElementId {
+    fn from(value: EditPredictionId) -> Self {
+        gpui::ElementId::Uuid(value.0)
+    }
+}
+
+impl std::fmt::Display for EditPredictionId {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+/// A prediction received from the server, anchored to the snapshot it was
+/// computed against.
+#[derive(Clone)]
+pub struct EditPrediction {
+    id: EditPredictionId,
+    edits: Arc<[(Range<Anchor>, String)]>,
+    snapshot: BufferSnapshot,
+    edit_preview: EditPreview,
+}
+
+impl EditPrediction {
+    /// Re-maps `edits` onto `new_snapshot`; `None` means user edits have
+    /// invalidated the prediction (see the free function `interpolate`).
+    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
+        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
+    }
+}
+
+/// An in-flight prediction request; dropping `_task` cancels it.
+struct PendingPrediction {
+    id: usize,
+    _task: Task<()>,
+}
+
+impl EditPredictionProvider for ZetaEditPredictionProvider {
+    fn name() -> &'static str {
+        "zed-predict2"
+    }
+
+    fn display_name() -> &'static str {
+        "Zed's Edit Predictions 2"
+    }
+
+    fn show_completions_in_menu() -> bool {
+        true
+    }
+
+    fn show_tab_accept_marker() -> bool {
+        true
+    }
+
+    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
+        // TODO [zeta2]
+        DataCollectionState::Unsupported
+    }
+
+    fn toggle_data_collection(&mut self, _cx: &mut App) {
+        // TODO [zeta2]
+    }
+
+    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
+        self.zeta.read(cx).usage(cx)
+    }
+
+    fn is_enabled(
+        &self,
+        _buffer: &Entity<language::Buffer>,
+        _cursor_position: language::Anchor,
+        _cx: &App,
+    ) -> bool {
+        true
+    }
+
+    fn is_refreshing(&self) -> bool {
+        !self.pending_predictions.is_empty()
+    }
+
+    /// Kicks off a throttled prediction request for `cursor_position`,
+    /// skipping entirely when the account can't use predictions or the
+    /// current prediction still applies to the buffer.
+    fn refresh(
+        &mut self,
+        project: Option<Entity<project::Project>>,
+        buffer: Entity<language::Buffer>,
+        cursor_position: language::Anchor,
+        _debounce: bool,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(project) = project else {
+            return;
+        };
+
+        // Accounts that are too young or have overdue invoices get no requests.
+        if self
+            .zeta
+            .read(cx)
+            .user_store
+            .read_with(cx, |user_store, _cx| {
+                user_store.account_too_young() || user_store.has_overdue_invoices()
+            })
+        {
+            return;
+        }
+
+        // If the current prediction still interpolates onto the buffer,
+        // keep it rather than requesting a new one.
+        if let Some(current_prediction) = self.current_prediction.as_ref() {
+            let snapshot = buffer.read(cx).snapshot();
+            if current_prediction
+                .prediction
+                .interpolate(&snapshot)
+                .is_some()
+            {
+                return;
+            }
+        }
+
+        let pending_prediction_id = self.next_pending_prediction_id;
+        self.next_pending_prediction_id += 1;
+        let last_request_timestamp = self.last_request_timestamp;
+
+        let task = cx.spawn(async move |this, cx| {
+            // Throttle: wait out the remainder of THROTTLE_TIMEOUT since the
+            // previous request, if any.
+            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
+                .checked_duration_since(Instant::now())
+            {
+                cx.background_executor().timer(timeout).await;
+            }
+
+            let prediction_request = this.update(cx, |this, cx| {
+                this.last_request_timestamp = Instant::now();
+                this.zeta.update(cx, |zeta, cx| {
+                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
+                })
+            });
+
+            let prediction = match prediction_request {
+                Ok(prediction_request) => {
+                    let prediction_request = prediction_request.await;
+                    prediction_request.map(|c| {
+                        c.map(|prediction| CurrentEditPrediction {
+                            buffer_id: buffer.entity_id(),
+                            prediction,
+                        })
+                    })
+                }
+                Err(error) => Err(error),
+            };
+
+            this.update(cx, |this, cx| {
+                // Either this task is the oldest pending entry (remove just
+                // it) or a newer request superseded the rest (clear all).
+                // NOTE(review): indexing `pending_predictions[0]` assumes this
+                // task's entry is still in the list when the closure runs —
+                // dropping a `Task` cancels it, so this should hold; confirm
+                // against gpui's Task drop semantics.
+                if this.pending_predictions[0].id == pending_prediction_id {
+                    this.pending_predictions.remove(0);
+                } else {
+                    this.pending_predictions.clear();
+                }
+
+                let Some(new_prediction) = prediction
+                    .context("edit prediction failed")
+                    .log_err()
+                    .flatten()
+                else {
+                    cx.notify();
+                    return;
+                };
+
+                if let Some(old_prediction) = this.current_prediction.as_ref() {
+                    let snapshot = buffer.read(cx).snapshot();
+                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
+                        this.current_prediction = Some(new_prediction);
+                    }
+                } else {
+                    this.current_prediction = Some(new_prediction);
+                }
+
+                cx.notify();
+            })
+            .ok();
+        });
+
+        // We always maintain at most two pending predictions. When we already
+        // have two, we replace the newest one.
+        if self.pending_predictions.len() <= 1 {
+            self.pending_predictions.push(PendingPrediction {
+                id: pending_prediction_id,
+                _task: task,
+            });
+        } else if self.pending_predictions.len() == 2 {
+            self.pending_predictions.pop();
+            self.pending_predictions.push(PendingPrediction {
+                id: pending_prediction_id,
+                _task: task,
+            });
+        }
+
+        cx.notify();
+    }
+
+    fn cycle(
+        &mut self,
+        _buffer: Entity<language::Buffer>,
+        _cursor_position: language::Anchor,
+        _direction: Direction,
+        _cx: &mut Context<Self>,
+    ) {
+    }
+
+    fn accept(&mut self, _cx: &mut Context<Self>) {
+        // TODO [zeta2] report accept
+        self.current_prediction.take();
+        self.pending_predictions.clear();
+    }
+
+    fn discard(&mut self, _cx: &mut Context<Self>) {
+        self.pending_predictions.clear();
+        self.current_prediction.take();
+    }
+
+    /// Returns the subset of the current prediction's edits nearest the
+    /// cursor (grouped with any edits on adjacent rows), or `None` —
+    /// invalidating the prediction — when it targets another buffer or no
+    /// longer interpolates.
+    fn suggest(
+        &mut self,
+        buffer: &Entity<language::Buffer>,
+        cursor_position: language::Anchor,
+        cx: &mut Context<Self>,
+    ) -> Option<edit_prediction::EditPrediction> {
+        let CurrentEditPrediction {
+            buffer_id,
+            prediction,
+            ..
+        } = self.current_prediction.as_mut()?;
+
+        // Invalidate previous prediction if it was generated for a different buffer.
+        if *buffer_id != buffer.entity_id() {
+            self.current_prediction.take();
+            return None;
+        }
+
+        let buffer = buffer.read(cx);
+        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
+            self.current_prediction.take();
+            return None;
+        };
+
+        // Pick the edit whose row span is closest to the cursor row.
+        let cursor_row = cursor_position.to_point(buffer).row;
+        let (closest_edit_ix, (closest_edit_range, _)) =
+            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
+                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
+                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
+                cmp::min(distance_from_start, distance_from_end)
+            })?;
+
+        // Extend the selection backwards over edits within one row of each other.
+        let mut edit_start_ix = closest_edit_ix;
+        for (range, _) in edits[..edit_start_ix].iter().rev() {
+            let distance_from_closest_edit =
+                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
+            if distance_from_closest_edit <= 1 {
+                edit_start_ix -= 1;
+            } else {
+                break;
+            }
+        }
+
+        // ...and forwards likewise.
+        let mut edit_end_ix = closest_edit_ix + 1;
+        for (range, _) in &edits[edit_end_ix..] {
+            let distance_from_closest_edit =
+                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
+            if distance_from_closest_edit <= 1 {
+                edit_end_ix += 1;
+            } else {
+                break;
+            }
+        }
+
+        Some(edit_prediction::EditPrediction {
+            id: Some(prediction.id.to_string().into()),
+            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
+            edit_preview: Some(prediction.edit_preview.clone()),
+        })
+    }
+}
+
+/// Builds the v3 request payload from the gathered context: resolves each
+/// snippet's path via the worktrees, deduplicates parent signatures into a
+/// flat `signatures` table, and copies the excerpt fields over.
+// NOTE(review): `worktrees: &Vec<_>` would be more idiomatic as `&[_]`.
+fn make_cloud_request(
+    excerpt_path: PathBuf,
+    context: EditPredictionContext,
+    events: Vec<predict_edits_v3::Event>,
+    can_collect_data: bool,
+    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
+    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
+    debug_info: bool,
+    worktrees: &Vec<worktree::Snapshot>,
+    index_state: Option<&SyntaxIndexState>,
+) -> predict_edits_v3::PredictEditsRequest {
+    let mut signatures = Vec::new();
+    let mut declaration_to_signature_index = HashMap::default();
+    let mut referenced_declarations = Vec::new();
+
+    for snippet in context.snippets {
+        let project_entry_id = snippet.declaration.project_entry_id();
+        // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
+        // Note that currently full_path is currently being used for excerpt_path.
+        let Some(path) = worktrees.iter().find_map(|worktree| {
+            let abs_path = worktree.abs_path();
+            worktree
+                .entry_for_id(project_entry_id)
+                .map(|e| abs_path.join(&e.path))
+        }) else {
+            // Snippets whose entry can't be resolved to a path are skipped.
+            continue;
+        };
+
+        let parent_index = index_state.and_then(|index_state| {
+            snippet.declaration.parent().and_then(|parent| {
+                add_signature(
+                    parent,
+                    &mut declaration_to_signature_index,
+                    &mut signatures,
+                    index_state,
+                )
+            })
+        });
+
+        let (text, text_is_truncated) = snippet.declaration.item_text();
+        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
+            path,
+            text: text.into(),
+            range: snippet.declaration.item_range(),
+            text_is_truncated,
+            signature_range: snippet.declaration.signature_range_in_item_text(),
+            parent_index,
+            score_components: snippet.score_components,
+            signature_score: snippet.scores.signature,
+            declaration_score: snippet.scores.declaration,
+        });
+    }
+
+    // Innermost enclosing declaration of the excerpt, if indexed.
+    let excerpt_parent = index_state.and_then(|index_state| {
+        context
+            .excerpt
+            .parent_declarations
+            .last()
+            .and_then(|(parent, _)| {
+                add_signature(
+                    *parent,
+                    &mut declaration_to_signature_index,
+                    &mut signatures,
+                    index_state,
+                )
+            })
+    });
+
+    predict_edits_v3::PredictEditsRequest {
+        excerpt_path,
+        excerpt: context.excerpt_text.body,
+        excerpt_range: context.excerpt.range,
+        cursor_offset: context.cursor_offset_in_excerpt,
+        referenced_declarations,
+        signatures,
+        excerpt_parent,
+        events,
+        can_collect_data,
+        diagnostic_groups,
+        git_info,
+        debug_info,
+    }
+}
+
+/// Recursively appends the signature chain for `declaration_id` (parents
+/// first) to `signatures`, memoized via `declaration_to_signature_index`.
+/// Returns the declaration's index into `signatures`, or `None` when the
+/// declaration is missing from the index.
+fn add_signature(
+    declaration_id: DeclarationId,
+    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
+    signatures: &mut Vec<Signature>,
+    index: &SyntaxIndexState,
+) -> Option<usize> {
+    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
+        return Some(*signature_index);
+    }
+    let Some(parent_declaration) = index.declaration(declaration_id) else {
+        log::error!("bug: missing parent declaration");
+        return None;
+    };
+    let parent_index = parent_declaration.parent().and_then(|parent| {
+        add_signature(parent, declaration_to_signature_index, signatures, index)
+    });
+    let (text, text_is_truncated) = parent_declaration.signature_text();
+    let signature_index = signatures.len();
+    signatures.push(Signature {
+        text: text.into(),
+        text_is_truncated,
+        parent_index,
+    });
+    declaration_to_signature_index.insert(declaration_id, signature_index);
+    Some(signature_index)
+}
+
+/// Re-maps model edits made against `old_snapshot` onto `new_snapshot`,
+/// given the user edits applied in between. A user edit that exactly matches
+/// (or is a typed-out prefix of) the model's edit at the same range consumes
+/// it — only the untyped suffix is kept. Returns `None` when any user edit
+/// conflicts with the model's edits, or when no edits remain.
+fn interpolate(
+    old_snapshot: &BufferSnapshot,
+    new_snapshot: &BufferSnapshot,
+    current_edits: Arc<[(Range<Anchor>, String)]>,
+) -> Option<Vec<(Range<Anchor>, String)>> {
+    let mut edits = Vec::new();
+
+    let mut model_edits = current_edits.iter().peekable();
+    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
+        // Model edits strictly before this user edit pass through unchanged.
+        while let Some((model_old_range, _)) = model_edits.peek() {
+            let model_old_range = model_old_range.to_offset(old_snapshot);
+            if model_old_range.end < user_edit.old.start {
+                let (model_old_range, model_new_text) = model_edits.next().unwrap();
+                edits.push((model_old_range.clone(), model_new_text.clone()));
+            } else {
+                break;
+            }
+        }
+
+        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
+            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
+            if user_edit.old == model_old_offset_range {
+                let user_new_text = new_snapshot
+                    .text_for_range(user_edit.new.clone())
+                    .collect::<String>();
+
+                // The user typed a prefix of the predicted text: keep only
+                // the remaining suffix as an insertion at the edit's end.
+                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
+                    if !model_suffix.is_empty() {
+                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
+                        edits.push((anchor..anchor, model_suffix.to_string()));
+                    }
+
+                    model_edits.next();
+                    continue;
+                }
+            }
+        }
+
+        // Any other user edit conflicts with the prediction — invalidate it.
+        return None;
+    }
+
+    // Remaining model edits after the last user edit pass through unchanged.
+    edits.extend(model_edits.cloned());
+
+    if edits.is_empty() { None } else { Some(edits) }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::TestAppContext;
+    use language::ToOffset as _;
+
+    /// Exercises `EditPrediction::interpolate` as the user types toward,
+    /// away from, and through the predicted edits.
+    #[gpui::test]
+    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
+        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
+        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
+            to_prediction_edits(
+                [(2..5, "REM".to_string()), (9..11, "".to_string())],
+                &buffer,
+                cx,
+            )
+            .into()
+        });
+
+        let edit_preview = cx
+            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
+            .await;
+
+        let prediction = EditPrediction {
+            id: EditPredictionId(Uuid::new_v4()),
+            edits,
+            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
+            edit_preview,
+        };
+
+        cx.update(|cx| {
+            // Unchanged buffer: edits come back verbatim.
+            assert_eq!(
+                from_prediction_edits(
+                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                    &buffer,
+                    cx
+                ),
+                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
+            );
+
+            // Unrelated deletion shifts the edits' offsets.
+            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
+            assert_eq!(
+                from_prediction_edits(
+                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                    &buffer,
+                    cx
+                ),
+                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
+            );
+
+            // Undo restores the original mapping.
+            buffer.update(cx, |buffer, cx| buffer.undo(cx));
+            assert_eq!(
+                from_prediction_edits(
+                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                    &buffer,
+                    cx
+                ),
+                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
+            );
+
+            // Typing a prefix of the predicted text shrinks the first edit.
+            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
+            assert_eq!(
+                from_prediction_edits(
+                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                    &buffer,
+                    cx
+                ),
+                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
+            );
+
+            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
+            assert_eq!(
+                from_prediction_edits(
+                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                    &buffer,
+                    cx
+                ),
+                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
+            );
+
+            // Fully typed-out edit disappears from the interpolation.
+            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
+            assert_eq!(
+                from_prediction_edits(
+                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                    &buffer,
+                    cx
+                ),
+                vec![(9..11, "".to_string())]
+            );
+
+            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
+            assert_eq!(
+                from_prediction_edits(
+                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                    &buffer,
+                    cx
+                ),
+                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
+            );
+
+            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
+            assert_eq!(
+                from_prediction_edits(
+                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                    &buffer,
+                    cx
+                ),
+                vec![(4..4, "M".to_string())]
+            );
+
+            // Conflicting edit invalidates the prediction entirely.
+            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
+            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
+        })
+    }
+
+    /// Converts offset-range edits to anchor-range edits on `buffer`.
+    fn to_prediction_edits(
+        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
+        buffer: &Entity<Buffer>,
+        cx: &App,
+    ) -> Vec<(Range<Anchor>, String)> {
+        let buffer = buffer.read(cx);
+        iterator
+            .into_iter()
+            .map(|(range, text)| {
+                (
+                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
+                    text,
+                )
+            })
+            .collect()
+    }
+
+    /// Converts anchor-range edits back to offset ranges for assertions.
+    fn from_prediction_edits(
+        editor_edits: &[(Range<Anchor>, String)],
+        buffer: &Entity<Buffer>,
+        cx: &App,
+    ) -> Vec<(Range<usize>, String)> {
+        let buffer = buffer.read(cx);
+        editor_edits
+            .iter()
+            .map(|(range, text)| {
+                (
+                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
+                    text.clone(),
+                )
+            })
+            .collect()
+    }
+}