@@ -345,21 +345,20 @@ impl Zeta {
new_snapshot,
..
} => {
- let path = new_snapshot.file().map(|f| f.path().clone());
+ let path = new_snapshot.file().map(|f| f.full_path(cx));
let old_path = old_snapshot.file().and_then(|f| {
- let old_path = f.path();
- if Some(old_path) != path.as_ref() {
- Some(old_path.clone())
+ let old_path = f.full_path(cx);
+ if Some(&old_path) != path.as_ref() {
+ Some(old_path)
} else {
None
}
});
predict_edits_v3::Event::BufferChange {
- old_path: old_path
- .map(|old_path| old_path.as_std_path().to_path_buf()),
- path: path.map(|path| path.as_std_path().to_path_buf()),
+ old_path,
+ path,
diff: language::unified_diff(
&old_snapshot.text(),
&new_snapshot.text(),
@@ -833,3 +832,316 @@ fn add_signature(
declaration_to_signature_index.insert(declaration_id, signature_index);
Some(signature_index)
}
+
+#[cfg(test)]
+mod tests {
+ use std::{
+ path::{Path, PathBuf},
+ sync::Arc,
+ };
+
+ use client::UserStore;
+ use clock::FakeSystemClock;
+ use cloud_llm_client::predict_edits_v3;
+ use futures::{
+ AsyncReadExt, StreamExt,
+ channel::{mpsc, oneshot},
+ };
+ use gpui::{
+ Entity, TestAppContext,
+ http_client::{FakeHttpClient, Response},
+ prelude::*,
+ };
+ use indoc::indoc;
+ use language::{LanguageServerId, OffsetRangeExt as _};
+ use project::{FakeFs, Project};
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+ use uuid::Uuid;
+
+ use crate::Zeta;
+
+ #[gpui::test]
+ async fn test_simple_request(cx: &mut TestAppContext) {
+ let (zeta, mut req_rx) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ let prediction_task = zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &buffer, position, cx)
+ });
+
+ let (request, respond_tx) = req_rx.next().await.unwrap();
+ assert_eq!(
+ request.excerpt_path.as_ref(),
+ Path::new(path!("root/foo.md"))
+ );
+ assert_eq!(request.cursor_offset, 10);
+
+ respond_tx
+ .send(predict_edits_v3::PredictEditsResponse {
+ request_id: Uuid::new_v4(),
+ edits: vec![predict_edits_v3::Edit {
+ path: Path::new(path!("root/foo.md")).into(),
+ range: 0..snapshot.len(),
+ content: "Hello!\nHow are you?\nBye".into(),
+ }],
+ debug_info: None,
+ })
+ .unwrap();
+
+ let prediction = prediction_task.await.unwrap().unwrap();
+
+ assert_eq!(prediction.edits.len(), 1);
+ assert_eq!(
+ prediction.edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 3)
+ );
+ assert_eq!(prediction.edits[0].1, " are you?");
+ }
+
+ #[gpui::test]
+ async fn test_request_events(cx: &mut TestAppContext) {
+ let (zeta, mut req_rx) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\n\nBye"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(&buffer, &project, cx);
+ });
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(vec![(7..7, "How")], None, cx);
+ });
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ let prediction_task = zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &buffer, position, cx)
+ });
+
+ let (request, respond_tx) = req_rx.next().await.unwrap();
+
+ assert_eq!(request.events.len(), 1);
+ assert_eq!(
+ request.events[0],
+ predict_edits_v3::Event::BufferChange {
+ path: Some(PathBuf::from(path!("root/foo.md"))),
+ old_path: None,
+ diff: indoc! {"
+ @@ -1,3 +1,3 @@
+ Hello!
+ -
+ +How
+ Bye
+ "}
+ .to_string(),
+ predicted: false
+ }
+ );
+
+ respond_tx
+ .send(predict_edits_v3::PredictEditsResponse {
+ request_id: Uuid::new_v4(),
+ edits: vec![predict_edits_v3::Edit {
+ path: Path::new(path!("root/foo.md")).into(),
+ range: 0..snapshot.len(),
+ content: "Hello!\nHow are you?\nBye".into(),
+ }],
+ debug_info: None,
+ })
+ .unwrap();
+
+ let prediction = prediction_task.await.unwrap().unwrap();
+
+ assert_eq!(prediction.edits.len(), 1);
+ assert_eq!(
+ prediction.edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 3)
+ );
+ assert_eq!(prediction.edits[0].1, " are you?");
+ }
+
+ #[gpui::test]
+ async fn test_request_diagnostics(cx: &mut TestAppContext) {
+ let (zeta, mut req_rx) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nBye"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
+ let diagnostic = lsp::Diagnostic {
+ range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+ severity: Some(lsp::DiagnosticSeverity::ERROR),
+ message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
+ ..Default::default()
+ };
+
+ project.update(cx, |project, cx| {
+ project.lsp_store().update(cx, |lsp_store, cx| {
+ // Create some diagnostics
+ lsp_store
+ .update_diagnostics(
+ LanguageServerId(0),
+ lsp::PublishDiagnosticsParams {
+ uri: path_to_buffer_uri.clone(),
+ diagnostics: vec![diagnostic],
+ version: None,
+ },
+ None,
+ language::DiagnosticSourceKind::Pushed,
+ &[],
+ cx,
+ )
+ .unwrap();
+ });
+ });
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(0, 0));
+
+ let _prediction_task = zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &buffer, position, cx)
+ });
+
+ let (request, _respond_tx) = req_rx.next().await.unwrap();
+
+ assert_eq!(request.diagnostic_groups.len(), 1);
+ let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
+ .unwrap();
+ // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
+ assert_eq!(
+ value,
+ json!({
+ "entries": [{
+ "range": {
+ "start": 8,
+ "end": 10
+ },
+ "diagnostic": {
+ "source": null,
+ "code": null,
+ "code_description": null,
+ "severity": 1,
+ "message": "\"Hello\" deprecated. Use \"Hi\" instead",
+ "markdown": null,
+ "group_id": 0,
+ "is_primary": true,
+ "is_disk_based": false,
+ "is_unnecessary": false,
+ "source_kind": "Pushed",
+ "data": null,
+ "underline": true
+ }
+ }],
+ "primary_ix": 0
+ })
+ );
+ }
+
+ fn init_test(
+ cx: &mut TestAppContext,
+ ) -> (
+ Entity<Zeta>,
+ mpsc::UnboundedReceiver<(
+ predict_edits_v3::PredictEditsRequest,
+ oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
+ )>,
+ ) {
+ cx.update(move |cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ language::init(cx);
+ Project::init_settings(cx);
+
+ let (req_tx, req_rx) = mpsc::unbounded();
+
+ let http_client = FakeHttpClient::create({
+ move |req| {
+ let uri = req.uri().path().to_string();
+ let mut body = req.into_body();
+ let req_tx = req_tx.clone();
+ async move {
+ let resp = match uri.as_str() {
+ "/client/llm_tokens" => serde_json::to_string(&json!({
+ "token": "test"
+ }))
+ .unwrap(),
+ "/predict_edits/v3" => {
+ let mut buf = Vec::new();
+ body.read_to_end(&mut buf).await.ok();
+ let req = serde_json::from_slice(&buf).unwrap();
+
+ let (res_tx, res_rx) = oneshot::channel();
+ req_tx.unbounded_send((req, res_tx)).unwrap();
+ serde_json::to_string(&res_rx.await.unwrap()).unwrap()
+ }
+ _ => {
+ panic!("Unexpected path: {}", uri)
+ }
+ };
+
+ Ok(Response::builder().body(resp.into()).unwrap())
+ }
+ }
+ });
+
+ let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
+ client.cloud_client().set_credentials(1, "test".into());
+
+ language_model::init(client.clone(), cx);
+
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ let zeta = Zeta::global(&client, &user_store, cx);
+ (zeta, req_rx)
+ })
+ }
+}