add retrieve context button to inline assistant

KCaverly created

Change summary

Cargo.lock                              |  21 ----
crates/assistant/Cargo.toml             |   2 
crates/assistant/src/assistant_panel.rs |  89 ++++++++++++--------
crates/assistant/src/prompts.rs         | 112 +++++++++++++++++---------
4 files changed, 131 insertions(+), 93 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -108,7 +108,7 @@ dependencies = [
  "rusqlite",
  "serde",
  "serde_json",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "util",
 ]
 
@@ -327,7 +327,7 @@ dependencies = [
  "settings",
  "smol",
  "theme",
- "tiktoken-rs 0.4.5",
+ "tiktoken-rs",
  "util",
  "uuid 1.4.1",
  "workspace",
@@ -6798,7 +6798,7 @@ dependencies = [
  "smol",
  "tempdir",
  "theme",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "tree-sitter",
  "tree-sitter-cpp",
  "tree-sitter-elixir",
@@ -7875,21 +7875,6 @@ dependencies = [
  "weezl",
 ]
 
-[[package]]
-name = "tiktoken-rs"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
-dependencies = [
- "anyhow",
- "base64 0.21.4",
- "bstr",
- "fancy-regex",
- "lazy_static",
- "parking_lot 0.12.1",
- "rustc-hash",
-]
-
 [[package]]
 name = "tiktoken-rs"
 version = "0.5.4"

crates/assistant/Cargo.toml 🔗

@@ -38,7 +38,7 @@ schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
-tiktoken-rs = "0.4"
+tiktoken-rs = "0.5"
 
 [dev-dependencies]
 editor = { path = "../editor", features = ["test-support"] }

crates/assistant/src/assistant_panel.rs 🔗

@@ -437,8 +437,15 @@ impl AssistantPanel {
             InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation,
+                retrieve_context,
             } => {
-                self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
+                self.confirm_inline_assist(
+                    assist_id,
+                    prompt,
+                    *include_conversation,
+                    cx,
+                    *retrieve_context,
+                );
             }
             InlineAssistantEvent::Canceled => {
                 self.finish_inline_assist(assist_id, true, cx);
@@ -532,6 +539,7 @@ impl AssistantPanel {
         user_prompt: &str,
         include_conversation: bool,
         cx: &mut ViewContext<Self>,
+        retrieve_context: bool,
     ) {
         let conversation = if include_conversation {
             self.active_editor()
@@ -593,42 +601,49 @@ impl AssistantPanel {
         let codegen_kind = codegen.read(cx).kind().clone();
         let user_prompt = user_prompt.to_string();
 
-        let project = if let Some(workspace) = self.workspace.upgrade(cx) {
-            workspace.read(cx).project()
-        } else {
-            return;
-        };
+        let snippets = if retrieve_context {
+            let project = if let Some(workspace) = self.workspace.upgrade(cx) {
+                workspace.read(cx).project()
+            } else {
+                return;
+            };
 
-        let project = project.to_owned();
-        let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
-            let search_results = semantic_index.update(cx, |this, cx| {
-                this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
-            });
+            let project = project.to_owned();
+            let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
+                let search_results = semantic_index.update(cx, |this, cx| {
+                    this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
+                });
 
-            cx.background()
-                .spawn(async move { search_results.await.unwrap_or_default() })
+                cx.background()
+                    .spawn(async move { search_results.await.unwrap_or_default() })
+            } else {
+                Task::ready(Vec::new())
+            };
+
+            let snippets = cx.spawn(|_, cx| async move {
+                let mut snippets = Vec::new();
+                for result in search_results.await {
+                    snippets.push(result.buffer.read_with(&cx, |buffer, _| {
+                        buffer
+                            .snapshot()
+                            .text_for_range(result.range)
+                            .collect::<String>()
+                    }));
+                }
+                snippets
+            });
+            snippets
         } else {
             Task::ready(Vec::new())
         };
 
-        let snippets = cx.spawn(|_, cx| async move {
-            let mut snippets = Vec::new();
-            for result in search_results.await {
-                snippets.push(result.buffer.read_with(&cx, |buffer, _| {
-                    buffer
-                        .snapshot()
-                        .text_for_range(result.range)
-                        .collect::<String>()
-                }));
-            }
-            snippets
-        });
+        let mut model = settings::get::<AssistantSettings>(cx)
+            .default_open_ai_model
+            .clone();
+        let model_name = model.full_name();
 
         let prompt = cx.background().spawn(async move {
             let snippets = snippets.await;
-            for snippet in &snippets {
-                println!("SNIPPET: \n{:?}", snippet);
-            }
 
             let language_name = language_name.as_deref();
             generate_content_prompt(
@@ -638,13 +653,11 @@ impl AssistantPanel {
                 range,
                 codegen_kind,
                 snippets,
+                model_name,
             )
         });
 
         let mut messages = Vec::new();
-        let mut model = settings::get::<AssistantSettings>(cx)
-            .default_open_ai_model
-            .clone();
         if let Some(conversation) = conversation {
             let conversation = conversation.read(cx);
             let buffer = conversation.buffer.read(cx);
@@ -1557,12 +1570,14 @@ impl Conversation {
                         Role::Assistant => "assistant".into(),
                         Role::System => "system".into(),
                     },
-                    content: self
-                        .buffer
-                        .read(cx)
-                        .text_for_range(message.offset_range)
-                        .collect(),
+                    content: Some(
+                        self.buffer
+                            .read(cx)
+                            .text_for_range(message.offset_range)
+                            .collect(),
+                    ),
                     name: None,
+                    function_call: None,
                 })
             })
             .collect::<Vec<_>>();
@@ -2681,6 +2696,7 @@ enum InlineAssistantEvent {
     Confirmed {
         prompt: String,
         include_conversation: bool,
+        retrieve_context: bool,
     },
     Canceled,
     Dismissed,
@@ -2922,6 +2938,7 @@ impl InlineAssistant {
             cx.emit(InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation: self.include_conversation,
+                retrieve_context: self.retrieve_context,
             });
             self.confirmed = true;
             cx.notify();

crates/assistant/src/prompts.rs 🔗

@@ -1,8 +1,10 @@
 use crate::codegen::CodegenKind;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp;
+use std::fmt::Write;
+use std::iter;
 use std::ops::Range;
-use std::{fmt::Write, iter};
+use tiktoken_rs::ChatCompletionRequestMessage;
 
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
     #[derive(Debug)]
@@ -122,69 +124,103 @@ pub fn generate_content_prompt(
     range: Range<impl ToOffset>,
     kind: CodegenKind,
     search_results: Vec<String>,
+    model: &str,
 ) -> String {
-    let mut prompt = String::new();
+    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+
+    let mut prompts = Vec::new();
 
     // General Preamble
     if let Some(language_name) = language_name {
-        writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
+        prompts.push(format!("You're an expert {language_name} engineer.\n"));
     } else {
-        writeln!(prompt, "You're an expert engineer.\n").unwrap();
+        prompts.push("You're an expert engineer.\n".to_string());
     }
 
+    // Snippets
+    let mut snippet_position = prompts.len() - 1;
+
     let outline = summarize(buffer, range);
-    writeln!(
-        prompt,
-        "The file you are currently working on has the following outline:"
-    )
-    .unwrap();
+    prompts.push("The file you are currently working on has the following outline:".to_string());
     if let Some(language_name) = language_name {
         let language_name = language_name.to_lowercase();
-        writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
+        prompts.push(format!("```{language_name}\n{outline}\n```"));
     } else {
-        writeln!(prompt, "```\n{outline}\n```").unwrap();
+        prompts.push(format!("```\n{outline}\n```"));
     }
 
     match kind {
         CodegenKind::Generate { position: _ } => {
-            writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
-            writeln!(
-                prompt,
-                "Assume the cursor is located where the `<|START|` marker is."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
+            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
+            prompts
+                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
+            prompts.push(
                 "Text can't be replaced, so assume your answer will be inserted at the cursor."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
+                    .to_string(),
+            );
+            prompts.push(format!(
                 "Generate text based on the users prompt: {user_prompt}"
-            )
-            .unwrap();
+            ));
         }
         CodegenKind::Transform { range: _ } => {
-            writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
-            writeln!(
-                prompt,
+            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
+            prompts.push(format!(
                 "Modify the users code selected text based upon the users prompt: {user_prompt}"
-            )
-            .unwrap();
-            writeln!(
-                prompt,
-                "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
-            )
-            .unwrap();
+            ));
+            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
         }
     }
 
     if let Some(language_name) = language_name {
-        writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
+        prompts.push(format!("Your answer MUST always be valid {language_name}"));
+    }
+    prompts.push("Always wrap your response in a Markdown codeblock".to_string());
+    prompts.push("Never make remarks about the output.".to_string());
+
+    let current_messages = [ChatCompletionRequestMessage {
+        role: "user".to_string(),
+        content: Some(prompts.join("\n")),
+        function_call: None,
+        name: None,
+    }];
+
+    let remaining_token_count = if let Ok(current_token_count) =
+        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
+    {
+        let max_token_count = tiktoken_rs::model::get_context_size(model);
+        max_token_count - current_token_count
+    } else {
+        // If tiktoken fails to count token count, assume we have no space remaining.
+        0
+    };
+
+    // TODO:
+    //   - add repository name to snippet
+    //   - add file path
+    //   - add language
+    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
+        let template = "You are working inside a large repository, here are a few code snippets that may be useful";
+
+        for search_result in search_results {
+            let mut snippet_prompt = template.to_string();
+            writeln!(snippet_prompt, "```\n{search_result}\n```").unwrap();
+
+            let token_count = encoding
+                .encode_with_special_tokens(snippet_prompt.as_str())
+                .len();
+            if token_count <= remaining_token_count {
+                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+                    prompts.insert(snippet_position, snippet_prompt);
+                    snippet_position += 1;
+                }
+            } else {
+                break;
+            }
+        }
     }
-    writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
-    writeln!(prompt, "Never make remarks about the output.").unwrap();
 
+    let prompt = prompts.join("\n");
+    println!("PROMPT: {:?}", prompt);
     prompt
 }