zeta2: Test prediction request (#38794)

Agus Zubiaga and Bennet Bo Fenner created

Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>

Change summary

Cargo.lock                                      |   4 
crates/cloud_llm_client/Cargo.toml              |   3 
crates/cloud_llm_client/src/predict_edits_v3.rs |   1 
crates/zeta2/Cargo.toml                         |   8 
crates/zeta2/src/zeta2.rs                       | 326 ++++++++++++++++++
5 files changed, 335 insertions(+), 7 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -20609,6 +20609,7 @@ dependencies = [
  "arrayvec",
  "chrono",
  "client",
+ "clock",
  "cloud_llm_client",
  "cloud_zeta2_prompt",
  "edit_prediction",
@@ -20619,9 +20620,12 @@ dependencies = [
  "language",
  "language_model",
  "log",
+ "lsp",
+ "pretty_assertions",
  "project",
  "release_channel",
  "serde_json",
+ "settings",
  "thiserror 2.0.12",
  "util",
  "uuid",

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -50,6 +50,7 @@ pub enum PromptFormat {
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
+#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq))]
 #[serde(tag = "event")]
 pub enum Event {
     BufferChange {

crates/zeta2/Cargo.toml 🔗

@@ -37,4 +37,12 @@ workspace.workspace = true
 worktree.workspace = true
 
 [dev-dependencies]
+clock = { workspace = true, features = ["test-support"] }
+cloud_llm_client = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
+lsp.workspace = true
+indoc.workspace = true
+language_model = { workspace = true, features = ["test-support"] }
+pretty_assertions.workspace = true
+project = { workspace = true, features = ["test-support"] }
+settings = { workspace = true, features = ["test-support"] }

crates/zeta2/src/zeta2.rs 🔗

@@ -345,21 +345,20 @@ impl Zeta {
                             new_snapshot,
                             ..
                         } => {
-                            let path = new_snapshot.file().map(|f| f.path().clone());
+                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 
                             let old_path = old_snapshot.file().and_then(|f| {
-                                let old_path = f.path();
-                                if Some(old_path) != path.as_ref() {
-                                    Some(old_path.clone())
+                                let old_path = f.full_path(cx);
+                                if Some(&old_path) != path.as_ref() {
+                                    Some(old_path)
                                 } else {
                                     None
                                 }
                             });
 
                             predict_edits_v3::Event::BufferChange {
-                                old_path: old_path
-                                    .map(|old_path| old_path.as_std_path().to_path_buf()),
-                                path: path.map(|path| path.as_std_path().to_path_buf()),
+                                old_path,
+                                path,
                                 diff: language::unified_diff(
                                     &old_snapshot.text(),
                                     &new_snapshot.text(),
@@ -833,3 +832,316 @@ fn add_signature(
     declaration_to_signature_index.insert(declaration_id, signature_index);
     Some(signature_index)
 }
+
+#[cfg(test)]
+mod tests {
+    use std::{
+        path::{Path, PathBuf},
+        sync::Arc,
+    };
+
+    use client::UserStore;
+    use clock::FakeSystemClock;
+    use cloud_llm_client::predict_edits_v3;
+    use futures::{
+        AsyncReadExt, StreamExt,
+        channel::{mpsc, oneshot},
+    };
+    use gpui::{
+        Entity, TestAppContext,
+        http_client::{FakeHttpClient, Response},
+        prelude::*,
+    };
+    use indoc::indoc;
+    use language::{LanguageServerId, OffsetRangeExt as _};
+    use project::{FakeFs, Project};
+    use serde_json::json;
+    use settings::SettingsStore;
+    use util::path;
+    use uuid::Uuid;
+
+    use crate::Zeta;
+
+    #[gpui::test]
+    async fn test_simple_request(cx: &mut TestAppContext) {
+        let (zeta, mut req_rx) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md":  "Hello!\nHow\nBye"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+        let prediction_task = zeta.update(cx, |zeta, cx| {
+            zeta.request_prediction(&project, &buffer, position, cx)
+        });
+
+        let (request, respond_tx) = req_rx.next().await.unwrap();
+        assert_eq!(
+            request.excerpt_path.as_ref(),
+            Path::new(path!("root/foo.md"))
+        );
+        assert_eq!(request.cursor_offset, 10);
+
+        respond_tx
+            .send(predict_edits_v3::PredictEditsResponse {
+                request_id: Uuid::new_v4(),
+                edits: vec![predict_edits_v3::Edit {
+                    path: Path::new(path!("root/foo.md")).into(),
+                    range: 0..snapshot.len(),
+                    content: "Hello!\nHow are you?\nBye".into(),
+                }],
+                debug_info: None,
+            })
+            .unwrap();
+
+        let prediction = prediction_task.await.unwrap().unwrap();
+
+        assert_eq!(prediction.edits.len(), 1);
+        assert_eq!(
+            prediction.edits[0].0.to_point(&snapshot).start,
+            language::Point::new(1, 3)
+        );
+        assert_eq!(prediction.edits[0].1, " are you?");
+    }
+
+    #[gpui::test]
+    async fn test_request_events(cx: &mut TestAppContext) {
+        let (zeta, mut req_rx) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md": "Hello!\n\nBye"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+
+        zeta.update(cx, |zeta, cx| {
+            zeta.register_buffer(&buffer, &project, cx);
+        });
+
+        buffer.update(cx, |buffer, cx| {
+            buffer.edit(vec![(7..7, "How")], None, cx);
+        });
+
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+        let prediction_task = zeta.update(cx, |zeta, cx| {
+            zeta.request_prediction(&project, &buffer, position, cx)
+        });
+
+        let (request, respond_tx) = req_rx.next().await.unwrap();
+
+        assert_eq!(request.events.len(), 1);
+        assert_eq!(
+            request.events[0],
+            predict_edits_v3::Event::BufferChange {
+                path: Some(PathBuf::from(path!("root/foo.md"))),
+                old_path: None,
+                diff: indoc! {"
+                        @@ -1,3 +1,3 @@
+                         Hello!
+                        -
+                        +How
+                         Bye
+                    "}
+                .to_string(),
+                predicted: false
+            }
+        );
+
+        respond_tx
+            .send(predict_edits_v3::PredictEditsResponse {
+                request_id: Uuid::new_v4(),
+                edits: vec![predict_edits_v3::Edit {
+                    path: Path::new(path!("root/foo.md")).into(),
+                    range: 0..snapshot.len(),
+                    content: "Hello!\nHow are you?\nBye".into(),
+                }],
+                debug_info: None,
+            })
+            .unwrap();
+
+        let prediction = prediction_task.await.unwrap().unwrap();
+
+        assert_eq!(prediction.edits.len(), 1);
+        assert_eq!(
+            prediction.edits[0].0.to_point(&snapshot).start,
+            language::Point::new(1, 3)
+        );
+        assert_eq!(prediction.edits[0].1, " are you?");
+    }
+
+    #[gpui::test]
+    async fn test_request_diagnostics(cx: &mut TestAppContext) {
+        let (zeta, mut req_rx) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "foo.md": "Hello!\nBye"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
+        let diagnostic = lsp::Diagnostic {
+            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+            severity: Some(lsp::DiagnosticSeverity::ERROR),
+            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
+            ..Default::default()
+        };
+
+        project.update(cx, |project, cx| {
+            project.lsp_store().update(cx, |lsp_store, cx| {
+                // Create some diagnostics
+                lsp_store
+                    .update_diagnostics(
+                        LanguageServerId(0),
+                        lsp::PublishDiagnosticsParams {
+                            uri: path_to_buffer_uri.clone(),
+                            diagnostics: vec![diagnostic],
+                            version: None,
+                        },
+                        None,
+                        language::DiagnosticSourceKind::Pushed,
+                        &[],
+                        cx,
+                    )
+                    .unwrap();
+            });
+        });
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot.anchor_before(language::Point::new(0, 0));
+
+        let _prediction_task = zeta.update(cx, |zeta, cx| {
+            zeta.request_prediction(&project, &buffer, position, cx)
+        });
+
+        let (request, _respond_tx) = req_rx.next().await.unwrap();
+
+        assert_eq!(request.diagnostic_groups.len(), 1);
+        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
+            .unwrap();
+        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
+        assert_eq!(
+            value,
+            json!({
+                "entries": [{
+                    "range": {
+                        "start": 8,
+                        "end": 10
+                    },
+                    "diagnostic": {
+                        "source": null,
+                        "code": null,
+                        "code_description": null,
+                        "severity": 1,
+                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
+                        "markdown": null,
+                        "group_id": 0,
+                        "is_primary": true,
+                        "is_disk_based": false,
+                        "is_unnecessary": false,
+                        "source_kind": "Pushed",
+                        "data": null,
+                        "underline": true
+                    }
+                }],
+                "primary_ix": 0
+            })
+        );
+    }
+
+    fn init_test(
+        cx: &mut TestAppContext,
+    ) -> (
+        Entity<Zeta>,
+        mpsc::UnboundedReceiver<(
+            predict_edits_v3::PredictEditsRequest,
+            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
+        )>,
+    ) {
+        cx.update(move |cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            language::init(cx);
+            Project::init_settings(cx);
+
+            let (req_tx, req_rx) = mpsc::unbounded();
+
+            let http_client = FakeHttpClient::create({
+                move |req| {
+                    let uri = req.uri().path().to_string();
+                    let mut body = req.into_body();
+                    let req_tx = req_tx.clone();
+                    async move {
+                        let resp = match uri.as_str() {
+                            "/client/llm_tokens" => serde_json::to_string(&json!({
+                                "token": "test"
+                            }))
+                            .unwrap(),
+                            "/predict_edits/v3" => {
+                                let mut buf = Vec::new();
+                                body.read_to_end(&mut buf).await.ok();
+                                let req = serde_json::from_slice(&buf).unwrap();
+
+                                let (res_tx, res_rx) = oneshot::channel();
+                                req_tx.unbounded_send((req, res_tx)).unwrap();
+                                serde_json::to_string(&res_rx.await.unwrap()).unwrap()
+                            }
+                            _ => {
+                                panic!("Unexpected path: {}", uri)
+                            }
+                        };
+
+                        Ok(Response::builder().body(resp.into()).unwrap())
+                    }
+                }
+            });
+
+            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
+            client.cloud_client().set_credentials(1, "test".into());
+
+            language_model::init(client.clone(), cx);
+
+            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+            let zeta = Zeta::global(&client, &user_store, cx);
+            (zeta, req_rx)
+        })
+    }
+}