Add a failing test for codegen autoindent

Antonio Scandurra created

Change summary

Cargo.lock               |   1 
crates/ai/Cargo.toml     |   1 
crates/ai/src/ai.rs      |   2 
crates/ai/src/codegen.rs | 113 +++++++++++++++++++++++++++++++++++++++++
4 files changed, 115 insertions(+), 2 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -114,6 +114,7 @@ dependencies = [
  "log",
  "menu",
  "ordered-float",
+ "parking_lot 0.11.2",
  "project",
  "rand 0.8.5",
  "regex",

crates/ai/Cargo.toml 🔗

@@ -27,6 +27,7 @@ futures.workspace = true
 indoc.workspace = true
 isahc.workspace = true
 ordered-float.workspace = true
+parking_lot.workspace = true
 regex.workspace = true
 schemars.workspace = true
 serde.workspace = true

crates/ai/src/ai.rs 🔗

@@ -27,7 +27,7 @@ use util::paths::CONVERSATIONS_DIR;
 const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 
 // Data types for chat completion requests
-#[derive(Debug, Serialize)]
+#[derive(Debug, Default, Serialize)]
 pub struct OpenAIRequest {
     model: String,
     messages: Vec<RequestMessage>,

crates/ai/src/codegen.rs 🔗

@@ -406,9 +406,68 @@ fn strip_markdown_codeblock(
 
 #[cfg(test)]
 mod tests {
+    use super::*;
     use futures::stream;
+    use gpui::{executor::Deterministic, TestAppContext};
+    use indoc::indoc;
+    use language::{tree_sitter_rust, Buffer, Language, LanguageConfig};
+    use parking_lot::Mutex;
+    use rand::prelude::*;
+
+    #[gpui::test(iterations = 10)]
+    async fn test_autoindent(
+        cx: &mut TestAppContext,
+        mut rng: StdRng,
+        deterministic: Arc<Deterministic>,
+    ) {
+        let text = indoc! {"
+            fn main() {
+                let x = 0;
+                for _ in 0..10 {
+                    x += 1;
+                }
+            }
+        "};
+        let buffer =
+            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
+        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
+        let range = buffer.read_with(cx, |buffer, cx| {
+            let snapshot = buffer.snapshot(cx);
+            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
+        });
+        let provider = Arc::new(TestCompletionProvider::new());
+        let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx));
+        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let mut new_text = indoc! {"
+                   let mut x = 0;
+            while x < 10 {
+                           x += 1;
+               }
+        "};
+        while !new_text.is_empty() {
+            let max_len = cmp::min(new_text.len(), 10);
+            let len = rng.gen_range(1..=max_len);
+            let (chunk, suffix) = new_text.split_at(len);
+            provider.send_completion(chunk);
+            new_text = suffix;
+            deterministic.run_until_parked();
+        }
+        provider.finish_completion();
+        deterministic.run_until_parked();
 
-    use super::*;
+        assert_eq!(
+            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+            indoc! {"
+                fn main() {
+                    let mut x = 0;
+                    while x < 10 {
+                        x += 1;
+                    }
+                }
+            "}
+        );
+    }
 
     #[gpui::test]
     async fn test_strip_markdown_codeblock() {
@@ -465,4 +524,56 @@ mod tests {
             )
         }
     }
+
+    struct TestCompletionProvider {
+        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
+    }
+
+    impl TestCompletionProvider {
+        fn new() -> Self {
+            Self {
+                last_completion_tx: Mutex::new(None),
+            }
+        }
+
+        fn send_completion(&self, completion: impl Into<String>) {
+            let mut tx = self.last_completion_tx.lock();
+            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+        }
+
+        fn finish_completion(&self) {
+            self.last_completion_tx.lock().take().unwrap();
+        }
+    }
+
+    impl CompletionProvider for TestCompletionProvider {
+        fn complete(
+            &self,
+            _prompt: OpenAIRequest,
+        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+            let (tx, rx) = mpsc::channel(1);
+            *self.last_completion_tx.lock() = Some(tx);
+            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
+        }
+    }
+
+    fn rust_lang() -> Language {
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                path_suffixes: vec!["rs".to_string()],
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::language()),
+        )
+        .with_indents_query(
+            r#"
+            (call_expression) @indent
+            (field_expression) @indent
+            (_ "(" ")" @end) @indent
+            (_ "{" "}" @end) @indent
+            "#,
+        )
+        .unwrap()
+    }
 }