diff --git a/Cargo.lock b/Cargo.lock index c7173bcc0533006cdca7c726a0399c088660dbab..dbb77742643d791ab1a39bdbf50901c6a3097280 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3213,6 +3213,7 @@ name = "cloud_llm_client" version = "0.1.0" dependencies = [ "anyhow", + "chrono", "pretty_assertions", "serde", "serde_json", @@ -21656,6 +21657,7 @@ dependencies = [ "util", "workspace", "workspace-hack", + "worktree", ] [[package]] diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml index 6f090d3c6ea67d8bb189212fb9704b618554f671..700893dd4030e2eb7b9eab2286319ec08df2f522 100644 --- a/crates/cloud_llm_client/Cargo.toml +++ b/crates/cloud_llm_client/Cargo.toml @@ -13,6 +13,7 @@ path = "src/cloud_llm_client.rs" [dependencies] anyhow.workspace = true +chrono.workspace = true serde = { workspace = true, features = ["derive", "rc"] } serde_json.workspace = true strum = { workspace = true, features = ["derive"] } diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 320f73c2b7a47a67199af2b868f43dfb8872e3e1..995409ca2fd33a8f8ac24e89cd45d0dcec7d2d5b 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -1,3 +1,4 @@ +use chrono::Duration; use serde::{Deserialize, Serialize}; use std::{ops::Range, path::PathBuf}; use uuid::Uuid; @@ -9,6 +10,11 @@ use crate::PredictEditsGitInfo; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsRequest { pub excerpt: String, + pub excerpt_path: PathBuf, + /// Within file + pub excerpt_range: Range, + /// Within `excerpt` + pub cursor_offset: usize, /// Within `signatures` pub excerpt_parent: Option, pub signatures: Vec, @@ -21,10 +27,20 @@ pub struct PredictEditsRequest { /// Info about the git repository state, only present when can_collect_data is true. #[serde(skip_serializing_if = "Option::is_none", default)] pub git_info: Option, + #[serde(default)] + pub debug_info: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Event {} +#[serde(tag = "event")] +pub enum Event { + BufferChange { + path: Option, + old_path: Option, + diff: String, + predicted: bool, + }, +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Signature { @@ -36,8 +52,11 @@ pub struct Signature { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ReferencedDeclaration { + pub path: PathBuf, pub text: String, pub text_is_truncated: bool, + /// Range of `text` within file, potentially truncated according to `text_is_truncated` + pub range: Range, /// Range within `text` pub signature_range: Range, /// Index within `signatures`. @@ -79,6 +98,16 @@ pub struct DiagnosticGroup { pub struct PredictEditsResponse { pub request_id: Uuid, pub edits: Vec, + pub debug_info: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DebugInfo { + pub prompt: String, + pub prompt_planning_time: Duration, + pub model_response: String, + pub inference_time: Duration, + pub parsing_time: Duration, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs index 2f3b6f437146ff3d91d45aff44b666dde07580b3..653f810d439395a8825c99f4b007e05d881540ab 100644 --- a/crates/edit_prediction_context/src/declaration.rs +++ b/crates/edit_prediction_context/src/declaration.rs @@ -66,6 +66,13 @@ impl Declaration { } } + pub fn item_range(&self) -> Range { + match self { + Declaration::File { declaration, .. } => declaration.item_range_in_file.clone(), + Declaration::Buffer { declaration, .. } => declaration.item_range.clone(), + } + } + pub fn item_text(&self) -> (Cow<'_, str>, bool) { match self { Declaration::File { declaration, .. } => ( diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index a42699c5ce2f745a01d17400aa1895b04ff64cbd..71bb486bdc8f70257d205fc25313c65669b712b8 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -20,6 +20,7 @@ pub use syntax_index::*; pub struct EditPredictionContext { pub excerpt: EditPredictionExcerpt, pub excerpt_text: EditPredictionExcerptText, + pub cursor_offset_in_excerpt: usize, pub snippets: Vec, } @@ -57,17 +58,18 @@ impl EditPredictionContext { index_state, )?; let excerpt_text = excerpt.text(buffer); + let cursor_offset_in_file = cursor_point.to_offset(buffer); + let cursor_offset_in_excerpt = cursor_offset_in_file - excerpt.range.start; let snippets = if let Some(index_state) = index_state { let references = references_in_excerpt(&excerpt, &excerpt_text, buffer); - let cursor_offset = cursor_point.to_offset(buffer); scored_snippets( &index_state, &excerpt, &excerpt_text, references, - cursor_offset, + cursor_offset_in_file, buffer, ) } else { @@ -77,6 +79,7 @@ impl EditPredictionContext { Some(Self { excerpt, excerpt_text, + cursor_offset_in_excerpt, snippets, }) } diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index e8ffda51ad7a2ff2c63d1241cdcf3b626460396c..001a3b24e3d62cacae9608307b36c3a458cb223f 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -30,3 +30,4 @@ thiserror.workspace = true util.workspace = true workspace.workspace = true workspace-hack.workspace = true +worktree.workspace = true diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 3ec858115d11028d7c4ac94a01482be08f3f0e47..4ae9b36e596e37cf92b9403d3f04935e4258e2e1 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1,4 +1,4 @@ -use anyhow::{Context as _, Result}; +use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::predict_edits_v3::{self, Signature}; @@ -22,6 +22,7 @@ use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::Project; use release_channel::AppVersion; use std::collections::HashMap; +use std::path::PathBuf; use std::str::FromStr as _; use std::time::{Duration, Instant}; use std::{ops::Range, sync::Arc}; @@ -120,9 +121,17 @@ impl Zeta { }); let excerpt_options = self.excerpt_options.clone(); let snapshot = buffer.read(cx).snapshot(); + let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else { + return Task::ready(Err(anyhow!("No file path for excerpt"))); + }; let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); + let worktree_snapshots = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect::>(); let request_task = cx.background_spawn({ let snapshot = snapshot.clone(); @@ -135,6 +144,9 @@ impl Zeta { let cursor_point = position.to_point(&snapshot); + // TODO: make this only true if debug view is open + let debug_info = true; + let Some(request) = EditPredictionContext::gather_context( cursor_point, &snapshot, @@ -143,12 +155,15 @@ impl Zeta { ) .map(|context| { make_cloud_request( + excerpt_path.clone(), context, // TODO pass everything Vec::new(), false, Vec::new(), None, + debug_info, + &worktree_snapshots, index_state.as_deref(), ) }) else { @@ -263,7 +278,7 @@ impl Zeta { } else { request_builder.uri( http_client - .build_zed_llm_url("/predict_edits/v2", &[])? + .build_zed_llm_url("/predict_edits/v3", &[])? .as_ref(), ) }; @@ -585,11 +600,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } fn make_cloud_request( + excerpt_path: PathBuf, context: EditPredictionContext, events: Vec, can_collect_data: bool, diagnostic_groups: Vec, git_info: Option, + debug_info: bool, + worktrees: &Vec, index_state: Option<&SyntaxIndexState>, ) -> predict_edits_v3::PredictEditsRequest { let mut signatures = Vec::new(); @@ -597,6 +615,18 @@ fn make_cloud_request( let mut referenced_declarations = Vec::new(); for snippet in context.snippets { + let project_entry_id = snippet.declaration.project_entry_id(); + // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot. + // Note that currently full_path is currently being used for excerpt_path. + let Some(path) = worktrees.iter().find_map(|worktree| { + let abs_path = worktree.abs_path(); + worktree + .entry_for_id(project_entry_id) + .map(|e| abs_path.join(&e.path)) + }) else { + continue; + }; + let parent_index = index_state.and_then(|index_state| { snippet.declaration.parent().and_then(|parent| { add_signature( @@ -607,9 +637,12 @@ fn make_cloud_request( ) }) }); + let (text, text_is_truncated) = snippet.declaration.item_text(); referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { + path, text: text.into(), + range: snippet.declaration.item_range(), text_is_truncated, signature_range: snippet.declaration.signature_range_in_item_text(), parent_index, @@ -635,7 +668,10 @@ fn make_cloud_request( }); predict_edits_v3::PredictEditsRequest { + excerpt_path, excerpt: context.excerpt_text.body, + excerpt_range: context.excerpt.range, + cursor_offset: context.cursor_offset_in_excerpt, referenced_declarations, signatures, excerpt_parent, @@ -644,6 +680,7 @@ fn make_cloud_request( can_collect_data, diagnostic_groups, git_info, + debug_info, } }