diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 6cb53b026cd6029a82dd1f7179f1e717b04f1fe1..791f9488f160aa1d1f41862fcdca87f937629452 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -138,30 +138,44 @@ pub enum LanguageModelProvider { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsBody { - #[serde(skip_serializing_if = "Option::is_none", default)] - pub outline: Option, pub input_events: String, pub input_excerpt: String, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub speculated_output: Option, /// Whether the user provided consent for sampling this interaction. #[serde(default, alias = "data_collection_permission")] pub can_collect_data: bool, + /// Note that this is no longer sent, in favor of `PredictEditsAdditionalContext`. #[serde(skip_serializing_if = "Option::is_none", default)] pub diagnostic_groups: Option>, - /// Info about the git repository state, only present when can_collect_data is true. + /// Info about the git repository state, only present when can_collect_data is true. Note that + /// this is no longer sent, in favor of `PredictEditsAdditionalContext`. #[serde(skip_serializing_if = "Option::is_none", default)] pub git_info: Option, } +/// Additional context only stored when can_collect_data is true for the corresponding edit +/// predictions request. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PredictEditsGitInfo { +pub struct PredictEditsAdditionalContext { /// Path to the file in the repository that contains the input excerpt. - #[serde(skip_serializing_if = "Option::is_none", default)] - pub input_path: Option, + pub input_path: String, /// Cursor position within the file that contains the input excerpt. - #[serde(skip_serializing_if = "Option::is_none", default)] - pub cursor_point: Option, + pub cursor_point: Point, + /// Cursor offset in bytes within the file that contains the input excerpt. + pub cursor_offset: usize, + #[serde(flatten)] + pub git_info: PredictEditsGitInfo, + /// Diagnostic near the cursor position. + #[serde(skip_serializing_if = "Vec::is_empty", default)] + pub diagnostic_groups: Vec<(String, Box)>, + /// True if the diagnostics were truncated. + pub diagnostic_groups_truncated: bool, + /// Recently active files that may be within this repository. + #[serde(skip_serializing_if = "Vec::is_empty", default)] + pub recent_files: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsGitInfo { /// SHA of git HEAD commit at time of prediction. #[serde(skip_serializing_if = "Option::is_none", default)] pub head_sha: Option, @@ -171,9 +185,6 @@ pub struct PredictEditsGitInfo { /// URL of the remote called `upstream`. #[serde(skip_serializing_if = "Option::is_none", default)] pub remote_upstream_url: Option, - /// Recently active files that may be within this repository. - #[serde(skip_serializing_if = "Option::is_none", default)] - pub recent_files: Option>, } /// A zero-indexed point in a text buffer consisting of a row and column. diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 4ddc2b3018614f592beeb55aaa2cc9ed46b5522c..fdc51b0b1a77af7fcc7d1414719b587629800154 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -146,7 +146,7 @@ pub struct BufferSnapshot { pub text: text::BufferSnapshot, pub(crate) syntax: SyntaxSnapshot, file: Option>, - diagnostics: SmallVec<[(LanguageServerId, DiagnosticSet); 2]>, + pub diagnostics: SmallVec<[(LanguageServerId, DiagnosticSet); 2]>, remote_selections: TreeMap, language: Option>, non_text_state_update_count: usize, diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 79864fd21e7ed3030e9bd09722f851fb1b6244f8..e7d1efd3d40396c998fa7c1ecc94a1cb29c0d36f 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -20,8 +20,8 @@ use anyhow::{Context as _, Result, anyhow}; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse, - ZED_VERSION_HEADER_NAME, + PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, + PredictEditsResponse, ZED_VERSION_HEADER_NAME, }; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; @@ -68,7 +68,6 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch const MAX_CONTEXT_TOKENS: usize = 150; const MAX_REWRITE_TOKENS: usize = 350; const MAX_EVENT_TOKENS: usize = 500; -const MAX_DIAGNOSTIC_GROUPS: usize = 10; /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; @@ -76,12 +75,15 @@ const MAX_EVENT_COUNT: usize = 16; /// Maximum number of recent files to track. const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16; -/// Minimum number of milliseconds between recent project entries to keep them +/// Minimum number of milliseconds between recent file entries. const MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES: Duration = Duration::from_millis(100); /// Maximum file path length to include in recent files list. const MAX_RECENT_FILE_PATH_LENGTH: usize = 512; +/// Maximum number of JSON bytes for diagnostics in additional context. +const MAX_DIAGNOSTICS_BYTES: usize = 4096; + /// Maximum number of edit predictions to store for feedback. const MAX_SHOWN_COMPLETION_COUNT: usize = 50; @@ -151,7 +153,6 @@ pub struct EditPrediction { edits: Arc<[(Range, String)]>, snapshot: BufferSnapshot, edit_preview: EditPreview, - input_outline: Arc, input_events: Arc, input_excerpt: Arc, output_excerpt: Arc, @@ -407,7 +408,7 @@ impl Zeta { fn request_completion_impl( &mut self, workspace: Option>, - project: Option<&Entity>, + project: Option>, buffer: &Entity, cursor: language::Anchor, can_collect_data: CanCollectData, @@ -429,30 +430,21 @@ impl Zeta { let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let cursor_point = cursor.to_point(&snapshot); - let cursor_offset = cursor_point.to_offset(&snapshot); - let git_info = if matches!(can_collect_data, CanCollectData(true)) { - self.gather_git_info(cursor_point, &buffer_snapshotted_at, &snapshot, project, cx) - } else { - None - }; - let full_path: Arc = snapshot .file() .map(|f| Arc::from(f.full_path(cx).as_path())) .unwrap_or_else(|| Arc::from(Path::new("untitled"))); let full_path_str = full_path.to_string_lossy().to_string(); + let cursor_point = cursor.to_point(&snapshot); + let cursor_offset = cursor_point.to_offset(&snapshot); let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS); - let gather_task = gather_context( - project, + let gather_task = cx.background_spawn(gather_context( full_path_str, - &snapshot, + snapshot.clone(), cursor_point, make_events_prompt, can_collect_data, - git_info, - cx, - ); + )); cx.spawn(async move |this, cx| { let GatherContextOutput { @@ -461,13 +453,41 @@ impl Zeta { } = gather_task.await?; let done_gathering_context_at = Instant::now(); + let additional_context_task = if matches!(can_collect_data, CanCollectData(true)) + && let Some(file) = snapshot.file() + && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx)) + { + // This is async to reduce latency of the edit predictions request. The downside is + // that it will see a slightly later state than was used when gathering context. + let snapshot = snapshot.clone(); + let this = this.clone(); + Some(cx.spawn(async move |cx| { + if let Ok(Some(task)) = this.update(cx, |this, cx| { + this.gather_additional_context( + cursor_point, + cursor_offset, + snapshot, + &buffer_snapshotted_at, + project_path, + project.as_ref(), + cx, + ) + }) { + Some(task.await) + } else { + None + } + })) + } else { + None + }; + log::debug!( "Events:\n{}\nExcerpt:\n{:?}", body.input_events, body.input_excerpt ); - let input_outline = body.outline.clone().unwrap_or_default(); let input_events = body.input_events.clone(); let input_excerpt = body.input_excerpt.clone(); @@ -524,6 +544,7 @@ impl Zeta { .ok(); } + let request_id = response.request_id.clone(); let edit_prediction = Self::process_completion_response( response, buffer, @@ -531,7 +552,6 @@ impl Zeta { editable_range, cursor_offset, full_path, - input_outline, input_events, input_excerpt, buffer_snapshotted_at, @@ -555,6 +575,19 @@ impl Zeta { ); } + if let Some(additional_context_task) = additional_context_task { + cx.background_spawn(async move { + if let Some(additional_context) = additional_context_task.await { + telemetry::event!( + "Edit Prediction Additional Context", + request_id = request_id, + additional_context = additional_context + ); + } + }) + .detach(); + } + edit_prediction }) } @@ -572,13 +605,12 @@ impl Zeta { and then another "#}; - let project = None; let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx)); let position = buffer.read(cx).anchor_before(Point::new(1, 0)); let completion_tasks = vec![ self.fake_completion( - project, + None, &buffer, position, PredictEditsResponse { @@ -590,12 +622,12 @@ And maybe a short line Then a few lines and then another {EDITABLE_REGION_END_MARKER} - ", ), + "), }, cx, ), self.fake_completion( - project, + None, &buffer, position, PredictEditsResponse { @@ -612,7 +644,7 @@ and then another cx, ), self.fake_completion( - project, + None, &buffer, position, PredictEditsResponse { @@ -630,7 +662,7 @@ and then another cx, ), self.fake_completion( - project, + None, &buffer, position, PredictEditsResponse { @@ -648,7 +680,7 @@ and then another cx, ), self.fake_completion( - project, + None, &buffer, position, PredictEditsResponse { @@ -665,7 +697,7 @@ and then another cx, ), self.fake_completion( - project, + None, &buffer, position, PredictEditsResponse { @@ -681,7 +713,7 @@ and then another cx, ), self.fake_completion( - project, + None, &buffer, position, PredictEditsResponse { @@ -715,7 +747,7 @@ and then another #[cfg(any(test, feature = "test-support"))] pub fn fake_completion( &mut self, - project: Option<&Entity>, + project: Option>, buffer: &Entity, position: language::Anchor, response: PredictEditsResponse, @@ -736,7 +768,7 @@ and then another pub fn request_completion( &mut self, - project: Option<&Entity>, + project: Option>, buffer: &Entity, position: language::Anchor, can_collect_data: CanCollectData, @@ -904,7 +936,6 @@ and then another editable_range: Range, cursor_offset: usize, path: Arc, - input_outline: String, input_events: String, input_excerpt: String, buffer_snapshotted_at: Instant, @@ -949,7 +980,6 @@ and then another edits, edit_preview, snapshot, - input_outline: input_outline.into(), input_events: input_events.into(), input_excerpt: input_excerpt.into(), output_excerpt, @@ -1078,7 +1108,6 @@ and then another rating, input_events = completion.input_events, input_excerpt = completion.input_excerpt, - input_outline = completion.input_outline, output_excerpt = completion.output_excerpt, feedback ); @@ -1136,17 +1165,17 @@ and then another } } - fn gather_git_info( + fn gather_additional_context( &mut self, cursor_point: language::Point, + cursor_offset: usize, + snapshot: BufferSnapshot, buffer_snapshotted_at: &Instant, - snapshot: &BufferSnapshot, + project_path: ProjectPath, project: Option<&Entity>, cx: &mut Context, - ) -> Option { + ) -> Option> { let project = project?.read(cx); - let file = snapshot.file()?; - let project_path = ProjectPath::from_file(file.as_ref(), cx); let entry = project.entry_for_path(&project_path, cx)?; if !worktree_entry_eligible_for_collection(&entry) { return None; @@ -1155,7 +1184,25 @@ and then another let git_store = project.git_store().read(cx); let (repository, repo_path) = git_store.repository_and_path_for_project_path(&project_path, cx)?; - let repo_path_str = repo_path.to_str()?; + let repo_path_string = repo_path.to_str()?.to_string(); + + let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() { + snapshot + .diagnostics + .iter() + .filter_map(|(language_server_id, diagnostics)| { + let language_server = + local_lsp_store.running_language_server_for_id(*language_server_id)?; + Some(( + *language_server_id, + language_server.name(), + diagnostics.clone(), + )) + }) + .collect() + } else { + Vec::new() + }; repository.update(cx, |repository, cx| { let head_sha = repository.head_commit.as_ref()?.sha.to_string(); @@ -1163,14 +1210,61 @@ and then another let remote_upstream_url = repository.remote_upstream_url.clone(); let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx); - Some(PredictEditsGitInfo { - input_path: Some(repo_path_str.to_string()), - cursor_point: Some(to_cloud_llm_client_point(cursor_point)), - head_sha: Some(head_sha), - remote_origin_url, - remote_upstream_url, - recent_files: Some(recent_files), - }) + // group, resolve, and select diagnostics on a background thread + Some(cx.background_spawn(async move { + let mut diagnostic_groups_with_name = Vec::new(); + for (language_server_id, language_server_name, diagnostics) in + diagnostics.into_iter() + { + let mut groups = Vec::new(); + diagnostics.groups(language_server_id, &mut groups, &snapshot); + diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| { + ( + language_server_name.clone(), + group.resolve::(&snapshot), + ) + })); + } + + // sort by proximity to cursor + diagnostic_groups_with_name.sort_by_key(|(_, group)| { + let range = &group.entries[group.primary_ix].range; + if range.start >= cursor_offset { + range.start - cursor_offset + } else if cursor_offset >= range.end { + cursor_offset - range.end + } else { + (cursor_offset - range.start).min(range.end - cursor_offset) + } + }); + + let mut diagnostic_groups = Vec::new(); + let mut diagnostic_groups_truncated = false; + let mut diagnostics_byte_count = 0; + for (name, group) in diagnostic_groups_with_name { + let raw_value = serde_json::value::to_raw_value(&group).unwrap(); + diagnostics_byte_count += name.0.len() + raw_value.get().len(); + if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES { + diagnostic_groups_truncated = true; + break; + } + diagnostic_groups.push((name.to_string(), raw_value)); + } + + PredictEditsAdditionalContext { + input_path: repo_path_string, + cursor_point: to_cloud_llm_client_point(cursor_point), + cursor_offset: cursor_offset, + git_info: PredictEditsGitInfo { + head_sha: Some(head_sha), + remote_origin_url, + remote_upstream_url, + }, + diagnostic_groups, + diagnostic_groups_truncated, + recent_files, + } + })) }) } @@ -1319,74 +1413,34 @@ pub struct GatherContextOutput { pub editable_range: Range, } -pub fn gather_context( - project: Option<&Entity>, +pub async fn gather_context( full_path_str: String, - snapshot: &BufferSnapshot, + snapshot: BufferSnapshot, cursor_point: language::Point, make_events_prompt: impl FnOnce() -> String + Send + 'static, can_collect_data: CanCollectData, - git_info: Option, - cx: &App, -) -> Task> { - let local_lsp_store = - project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); - let diagnostic_groups: Vec<(String, serde_json::Value)> = - if matches!(can_collect_data, CanCollectData(true)) - && let Some(local_lsp_store) = local_lsp_store - { - snapshot - .diagnostic_groups(None) - .into_iter() - .filter_map(|(language_server_id, diagnostic_group)| { - let language_server = - local_lsp_store.running_language_server_for_id(language_server_id)?; - let diagnostic_group = diagnostic_group.resolve::(snapshot); - let language_server_name = language_server.name().to_string(); - let serialized = serde_json::to_value(diagnostic_group).unwrap(); - Some((language_server_name, serialized)) - }) - .collect::>() - } else { - Vec::new() - }; - - cx.background_spawn({ - let snapshot = snapshot.clone(); - async move { - let diagnostic_groups = if diagnostic_groups.is_empty() - || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS - { - None - } else { - Some(diagnostic_groups) - }; - - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &full_path_str, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let input_events = make_events_prompt(); - let editable_range = input_excerpt.editable_range.to_offset(&snapshot); - - let body = PredictEditsBody { - input_events, - input_excerpt: input_excerpt.prompt, - can_collect_data: can_collect_data.0, - diagnostic_groups, - git_info, - outline: None, - speculated_output: None, - }; - - Ok(GatherContextOutput { - body, - editable_range, - }) - } +) -> Result { + let input_excerpt = excerpt_for_cursor_position( + cursor_point, + &full_path_str, + &snapshot, + MAX_REWRITE_TOKENS, + MAX_CONTEXT_TOKENS, + ); + let input_events = make_events_prompt(); + let editable_range = input_excerpt.editable_range.to_offset(&snapshot); + + let body = PredictEditsBody { + input_events, + input_excerpt: input_excerpt.prompt, + can_collect_data: can_collect_data.0, + diagnostic_groups: None, + git_info: None, + }; + + Ok(GatherContextOutput { + body, + editable_range, }) } @@ -1763,13 +1817,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { let completion_request = this.update(cx, |this, cx| { this.last_request_timestamp = Instant::now(); this.zeta.update(cx, |zeta, cx| { - zeta.request_completion( - project.as_ref(), - &buffer, - position, - can_collect_data, - cx, - ) + zeta.request_completion(project, &buffer, position, can_collect_data, cx) }) }); @@ -1983,7 +2031,6 @@ mod tests { id: EditPredictionId(Uuid::new_v4()), excerpt_range: 0..0, cursor_offset: 0, - input_outline: "".into(), input_events: "".into(), input_excerpt: "".into(), output_excerpt: "".into(), diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 7ffbd688985bfd11bae61d7b4582b267ff3e7903..68aead90707fb77c9bbb5c1e1c22adf8cc80dc50 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -129,15 +129,15 @@ async fn get_context( return Err(anyhow!("Absolute paths are not supported in --cursor")); } - let (project, _lsp_open_handle, buffer) = if use_language_server { - let (project, lsp_open_handle, buffer) = + let (_lsp_open_handle, buffer) = if use_language_server { + let (_project, lsp_open_handle, buffer) = open_buffer_with_language_server(&worktree_path, &cursor.path, app_state, cx).await?; - (Some(project), Some(lsp_open_handle), buffer) + (Some(lsp_open_handle), buffer) } else { let abs_path = worktree_path.join(&cursor.path); let content = smol::fs::read_to_string(&abs_path).await?; let buffer = cx.new(|cx| Buffer::local(content, cx))?; - (None, None, buffer) + (None, buffer) }; let worktree_name = worktree_path @@ -172,21 +172,14 @@ async fn get_context( None => String::new(), }; // Enable gathering extra data not currently needed for edit predictions - let git_info = None; - let mut gather_context_output = cx - .update(|cx| { - gather_context( - project.as_ref(), - full_path_str, - &snapshot, - clipped_cursor, - move || events, - CanCollectData(true), - git_info, - cx, - ) - })? - .await; + let mut gather_context_output = gather_context( + full_path_str, + snapshot, + clipped_cursor, + move || events, + CanCollectData(true), + ) + .await; // Disable data collection for these requests, as this is currently just used for evals if let Ok(gather_context_output) = gather_context_output.as_mut() {