diff --git a/Cargo.lock b/Cargo.lock index 93108f0b70128aebab68a1feb9bdddcbff8442cd..5441a1704061e11d0819997ee0ccd7b04b16fac9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20609,6 +20609,7 @@ dependencies = [ "arrayvec", "chrono", "client", + "clock", "cloud_llm_client", "cloud_zeta2_prompt", "edit_prediction", @@ -20619,9 +20620,12 @@ dependencies = [ "language", "language_model", "log", + "lsp", + "pretty_assertions", "project", "release_channel", "serde_json", + "settings", "thiserror 2.0.12", "util", "uuid", diff --git a/crates/cloud_llm_client/Cargo.toml b/crates/cloud_llm_client/Cargo.toml index 700893dd4030e2eb7b9eab2286319ec08df2f522..1ef978f0a7d112f4239215d43b2306631bafa64b 100644 --- a/crates/cloud_llm_client/Cargo.toml +++ b/crates/cloud_llm_client/Cargo.toml @@ -5,6 +5,9 @@ publish.workspace = true edition.workspace = true license = "Apache-2.0" +[features] +test-support = [] + [lints] workspace = true diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 9c5123fdb8e7aaddbda3bd7cd5d36b112de7538d..21f32d674443282b1793e27609f6b29221d7e966 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -50,6 +50,7 @@ pub enum PromptFormat { } #[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))] #[serde(tag = "event")] pub enum Event { BufferChange { diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index 8ef1c5a64f7b64a6af6a0ce984e2f76e14eb5e77..bce7e5987ccec635b335110a3a38298040c68e72 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -37,4 +37,12 @@ workspace.workspace = true worktree.workspace = true [dev-dependencies] +clock = { workspace = true, features = ["test-support"] } +cloud_llm_client = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } +lsp.workspace = true +indoc.workspace = true +language_model = { workspace = true, features = ["test-support"] } +pretty_assertions.workspace = true +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 5f621d6acf11b1f42e5c2334b8cf03f8e1176d0a..b15496a7558ce21de775cc7666382098431d2c21 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -345,21 +345,20 @@ impl Zeta { new_snapshot, .. } => { - let path = new_snapshot.file().map(|f| f.path().clone()); + let path = new_snapshot.file().map(|f| f.full_path(cx)); let old_path = old_snapshot.file().and_then(|f| { - let old_path = f.path(); - if Some(old_path) != path.as_ref() { - Some(old_path.clone()) + let old_path = f.full_path(cx); + if Some(&old_path) != path.as_ref() { + Some(old_path) } else { None } }); predict_edits_v3::Event::BufferChange { - old_path: old_path - .map(|old_path| old_path.as_std_path().to_path_buf()), - path: path.map(|path| path.as_std_path().to_path_buf()), + old_path, + path, diff: language::unified_diff( &old_snapshot.text(), &new_snapshot.text(), @@ -833,3 +832,316 @@ fn add_signature( declaration_to_signature_index.insert(declaration_id, signature_index); Some(signature_index) } + +#[cfg(test)] +mod tests { + use std::{ + path::{Path, PathBuf}, + sync::Arc, + }; + + use client::UserStore; + use clock::FakeSystemClock; + use cloud_llm_client::predict_edits_v3; + use futures::{ + AsyncReadExt, StreamExt, + channel::{mpsc, oneshot}, + }; + use gpui::{ + Entity, TestAppContext, + http_client::{FakeHttpClient, Response}, + prelude::*, + }; + use indoc::indoc; + use language::{LanguageServerId, OffsetRangeExt as _}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + use uuid::Uuid; + + use crate::Zeta; + + #[gpui::test] + async fn test_simple_request(cx: &mut TestAppContext) { + let (zeta, mut req_rx) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + let prediction_task = zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &buffer, position, cx) + }); + + let (request, respond_tx) = req_rx.next().await.unwrap(); + assert_eq!( + request.excerpt_path.as_ref(), + Path::new(path!("root/foo.md")) + ); + assert_eq!(request.cursor_offset, 10); + + respond_tx + .send(predict_edits_v3::PredictEditsResponse { + request_id: Uuid::new_v4(), + edits: vec![predict_edits_v3::Edit { + path: Path::new(path!("root/foo.md")).into(), + range: 0..snapshot.len(), + content: "Hello!\nHow are you?\nBye".into(), + }], + debug_info: None, + }) + .unwrap(); + + let prediction = prediction_task.await.unwrap().unwrap(); + + assert_eq!(prediction.edits.len(), 1); + assert_eq!( + prediction.edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 3) + ); + assert_eq!(prediction.edits[0].1, " are you?"); + } + + #[gpui::test] + async fn test_request_events(cx: &mut TestAppContext) { + let (zeta, mut req_rx) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\n\nBye" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(&buffer, &project, cx); + }); + + buffer.update(cx, |buffer, cx| { + buffer.edit(vec![(7..7, "How")], None, cx); + }); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + let prediction_task = zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &buffer, position, cx) + }); + + let (request, respond_tx) = req_rx.next().await.unwrap(); + + assert_eq!(request.events.len(), 1); + assert_eq!( + request.events[0], + predict_edits_v3::Event::BufferChange { + path: Some(PathBuf::from(path!("root/foo.md"))), + old_path: None, + diff: indoc! {" + @@ -1,3 +1,3 @@ + Hello! + - + +How + Bye + "} + .to_string(), + predicted: false + } + ); + + respond_tx + .send(predict_edits_v3::PredictEditsResponse { + request_id: Uuid::new_v4(), + edits: vec![predict_edits_v3::Edit { + path: Path::new(path!("root/foo.md")).into(), + range: 0..snapshot.len(), + content: "Hello!\nHow are you?\nBye".into(), + }], + debug_info: None, + }) + .unwrap(); + + let prediction = prediction_task.await.unwrap().unwrap(); + + assert_eq!(prediction.edits.len(), 1); + assert_eq!( + prediction.edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 3) + ); + assert_eq!(prediction.edits[0].1, " are you?"); + } + + #[gpui::test] + async fn test_request_diagnostics(cx: &mut TestAppContext) { + let (zeta, mut req_rx) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nBye" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap(); + let diagnostic = lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(), + ..Default::default() + }; + + project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + // Create some diagnostics + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: path_to_buffer_uri.clone(), + diagnostics: vec![diagnostic], + version: None, + }, + None, + language::DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + }); + }); + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(0, 0)); + + let _prediction_task = zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &buffer, position, cx) + }); + + let (request, _respond_tx) = req_rx.next().await.unwrap(); + + assert_eq!(request.diagnostic_groups.len(), 1); + let value = serde_json::from_str::(request.diagnostic_groups[0].0.get()) + .unwrap(); + // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3 + assert_eq!( + value, + json!({ + "entries": [{ + "range": { + "start": 8, + "end": 10 + }, + "diagnostic": { + "source": null, + "code": null, + "code_description": null, + "severity": 1, + "message": "\"Hello\" deprecated. Use \"Hi\" instead", + "markdown": null, + "group_id": 0, + "is_primary": true, + "is_disk_based": false, + "is_unnecessary": false, + "source_kind": "Pushed", + "data": null, + "underline": true + } + }], + "primary_ix": 0 + }) + ); + } + + fn init_test( + cx: &mut TestAppContext, + ) -> ( + Entity, + mpsc::UnboundedReceiver<( + predict_edits_v3::PredictEditsRequest, + oneshot::Sender, + )>, + ) { + cx.update(move |cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + + let (req_tx, req_rx) = mpsc::unbounded(); + + let http_client = FakeHttpClient::create({ + move |req| { + let uri = req.uri().path().to_string(); + let mut body = req.into_body(); + let req_tx = req_tx.clone(); + async move { + let resp = match uri.as_str() { + "/client/llm_tokens" => serde_json::to_string(&json!({ + "token": "test" + })) + .unwrap(), + "/predict_edits/v3" => { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.ok(); + let req = serde_json::from_slice(&buf).unwrap(); + + let (res_tx, res_rx) = oneshot::channel(); + req_tx.unbounded_send((req, res_tx)).unwrap(); + serde_json::to_string(&res_rx.await.unwrap()).unwrap() + } + _ => { + panic!("Unexpected path: {}", uri) + } + }; + + Ok(Response::builder().body(resp.into()).unwrap()) + } + } + }); + + let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); + client.cloud_client().set_credentials(1, "test".into()); + + language_model::init(client.clone(), cx); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + let zeta = Zeta::global(&client, &user_store, cx); + (zeta, req_rx) + }) + } +}