Introduce Context Retrieval in Inline Assistant (#3097)

Created by Kyle Caverly

This PR introduces a new Inline Assistant feature, "Retrieve Context", which
dynamically fills the generation prompt with relevant results returned by
running a semantic search for the prompt.
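
At a high level: the confirmed prompt doubles as the semantic-search query, each search result is rendered as a fenced code snippet, and the snippets are spliced into the generation prompt ahead of the user's instruction. Below is a minimal, self-contained sketch of that assembly; the types and function names are simplified stand-ins rather than Zed's real signatures, with only the snippet format string taken verbatim from this PR.

```rust
// Sketch only: `Snippet` stands in for a semantic_index::SearchResult that
// has already been resolved against its buffer (what PromptCodeSnippet does).
struct Snippet {
    path: String,
    language: String,
    content: String,
}

// Mirrors the PromptCodeSnippet::to_string implementation added below.
fn render_snippet(s: &Snippet) -> String {
    format!(
        "The below code snippet may be relevant from file: {}\n```{}\n{}\n```",
        s.path, s.language, s.content
    )
}

// Simplified stand-in for generate_content_prompt: preamble, then whatever
// retrieved snippets fit the token budget, then the user's instruction.
fn build_prompt(user_prompt: &str, snippets: &[Snippet]) -> String {
    let mut parts = vec!["You're an expert engineer.".to_string()];
    parts.extend(snippets.iter().map(render_snippet));
    parts.push(format!("Generate text based on the user's prompt: {user_prompt}"));
    parts.join("\n")
}
```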

Release Notes:

- Introduced a "Retrieve Context" button in the Inline Assistant

Change summary

Cargo.lock                              |  22 
assets/icons/update.svg                 |   4 
crates/ai/src/embedding.rs              |  20 
crates/assistant/Cargo.toml             |   7 
crates/assistant/src/assistant_panel.rs | 555 ++++++++++++++++++++++++--
crates/assistant/src/prompts.rs         | 178 ++++++-
crates/theme/src/theme.rs               |   9 
styles/src/style_tree/assistant.ts      |  74 +++
8 files changed, 730 insertions(+), 139 deletions(-)

Detailed changes

Cargo.lock

@@ -103,7 +103,7 @@ dependencies = [
  "rusqlite",
  "serde",
  "serde_json",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "util",
 ]
 
@@ -316,12 +316,13 @@ dependencies = [
  "regex",
  "schemars",
  "search",
+ "semantic_index",
  "serde",
  "serde_json",
  "settings",
  "smol",
  "theme",
- "tiktoken-rs 0.4.5",
+ "tiktoken-rs",
  "util",
  "uuid 1.4.1",
  "workspace",
@@ -6942,7 +6943,7 @@ dependencies = [
  "smol",
  "tempdir",
  "theme",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "tree-sitter",
  "tree-sitter-cpp",
  "tree-sitter-elixir",
@@ -8117,21 +8118,6 @@ dependencies = [
  "weezl",
 ]
 
-[[package]]
-name = "tiktoken-rs"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
-dependencies = [
- "anyhow",
- "base64 0.21.4",
- "bstr",
- "fancy-regex",
- "lazy_static",
- "parking_lot 0.12.1",
- "rustc-hash",
-]
-
 [[package]]
 name = "tiktoken-rs"
 version = "0.5.4"

assets/icons/update.svg

@@ -0,0 +1,8 @@
+<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
+  <path
+    fill-rule="evenodd"
+    clip-rule="evenodd"

crates/ai/src/embedding.rs

@@ -85,25 +85,6 @@ impl Embedding {
     }
 }
 
-// impl FromSql for Embedding {
-//     fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-//         let bytes = value.as_blob()?;
-//         let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
-//         if embedding.is_err() {
-//             return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
-//         }
-//         Ok(Embedding(embedding.unwrap()))
-//     }
-// }
-
-// impl ToSql for Embedding {
-//     fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
-//         let bytes = bincode::serialize(&self.0)
-//             .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
-//         Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
-//     }
-// }
-
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
@@ -300,6 +281,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                     request_timeout,
                 )
                 .await?;
+
             request_number += 1;
 
             match response.status() {

crates/assistant/Cargo.toml

@@ -22,8 +22,11 @@ settings = { path = "../settings" }
 theme = { path = "../theme" }
 util = { path = "../util" }
 workspace = { path = "../workspace" }
-uuid.workspace = true
+semantic_index = { path = "../semantic_index" }
+project = { path = "../project" }
 
+uuid.workspace = true
+log.workspace = true
 anyhow.workspace = true
 chrono = { version = "0.4", features = ["serde"] }
 futures.workspace = true
@@ -36,7 +39,7 @@ schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
-tiktoken-rs = "0.4"
+tiktoken-rs = "0.5"
 
 [dev-dependencies]
 editor = { path = "../editor", features = ["test-support"] }

crates/assistant/src/assistant_panel.rs

@@ -1,7 +1,7 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     codegen::{self, Codegen, CodegenKind},
-    prompts::generate_content_prompt,
+    prompts::{generate_content_prompt, PromptCodeSnippet},
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
@@ -29,13 +29,15 @@ use gpui::{
     },
     fonts::HighlightStyle,
     geometry::vector::{vec2f, Vector2F},
-    platform::{CursorStyle, MouseButton},
+    platform::{CursorStyle, MouseButton, PromptLevel},
     Action, AnyElement, AppContext, AsyncAppContext, ClipboardItem, Element, Entity, ModelContext,
-    ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
-    WindowContext,
+    ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle,
+    WeakModelHandle, WeakViewHandle, WindowContext,
 };
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
+use project::Project;
 use search::BufferSearchBar;
+use semantic_index::{SemanticIndex, SemanticIndexStatus};
 use settings::SettingsStore;
 use std::{
     cell::{Cell, RefCell},
@@ -46,7 +48,7 @@ use std::{
     path::{Path, PathBuf},
     rc::Rc,
     sync::Arc,
-    time::Duration,
+    time::{Duration, Instant},
 };
 use theme::{
     components::{action_button::Button, ComponentExt},
@@ -72,6 +74,7 @@ actions!(
         ResetKey,
         InlineAssist,
         ToggleIncludeConversation,
+        ToggleRetrieveContext,
     ]
 );
 
@@ -108,6 +111,7 @@ pub fn init(cx: &mut AppContext) {
     cx.add_action(InlineAssistant::confirm);
     cx.add_action(InlineAssistant::cancel);
     cx.add_action(InlineAssistant::toggle_include_conversation);
+    cx.add_action(InlineAssistant::toggle_retrieve_context);
     cx.add_action(InlineAssistant::move_up);
     cx.add_action(InlineAssistant::move_down);
 }
@@ -145,6 +149,8 @@ pub struct AssistantPanel {
     include_conversation_in_next_inline_assist: bool,
     inline_prompt_history: VecDeque<String>,
     _watch_saved_conversations: Task<Result<()>>,
+    semantic_index: Option<ModelHandle<SemanticIndex>>,
+    retrieve_context_in_next_inline_assist: bool,
 }
 
 impl AssistantPanel {
@@ -191,6 +197,9 @@ impl AssistantPanel {
                         toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
                         toolbar
                     });
+
+                    let semantic_index = SemanticIndex::global(cx);
+
                     let mut this = Self {
                         workspace: workspace_handle,
                         active_editor_index: Default::default(),
@@ -215,6 +224,8 @@ impl AssistantPanel {
                         include_conversation_in_next_inline_assist: false,
                         inline_prompt_history: Default::default(),
                         _watch_saved_conversations,
+                        semantic_index,
+                        retrieve_context_in_next_inline_assist: false,
                     };
 
                     let mut old_dock_position = this.position(cx);
@@ -262,12 +273,19 @@ impl AssistantPanel {
             return;
         };
 
+        let project = workspace.project();
+
         this.update(cx, |assistant, cx| {
-            assistant.new_inline_assist(&active_editor, cx)
+            assistant.new_inline_assist(&active_editor, cx, project)
         });
     }
 
-    fn new_inline_assist(&mut self, editor: &ViewHandle<Editor>, cx: &mut ViewContext<Self>) {
+    fn new_inline_assist(
+        &mut self,
+        editor: &ViewHandle<Editor>,
+        cx: &mut ViewContext<Self>,
+        project: &ModelHandle<Project>,
+    ) {
         let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
             api_key
         } else {
@@ -312,6 +330,27 @@ impl AssistantPanel {
             Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
         });
 
+        if let Some(semantic_index) = self.semantic_index.clone() {
+            let project = project.clone();
+            cx.spawn(|_, mut cx| async move {
+                let previously_indexed = semantic_index
+                    .update(&mut cx, |index, cx| {
+                        index.project_previously_indexed(&project, cx)
+                    })
+                    .await
+                    .unwrap_or(false);
+                if previously_indexed {
+                    let _ = semantic_index
+                        .update(&mut cx, |index, cx| {
+                            index.index_project(project.clone(), cx)
+                        })
+                        .await;
+                }
+                anyhow::Ok(())
+            })
+            .detach_and_log_err(cx);
+        }
+
         let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
         let inline_assistant = cx.add_view(|cx| {
             let assistant = InlineAssistant::new(
@@ -322,6 +361,9 @@ impl AssistantPanel {
                 codegen.clone(),
                 self.workspace.clone(),
                 cx,
+                self.retrieve_context_in_next_inline_assist,
+                self.semantic_index.clone(),
+                project.clone(),
             );
             cx.focus_self();
             assistant
@@ -362,6 +404,7 @@ impl AssistantPanel {
                 editor: editor.downgrade(),
                 inline_assistant: Some((block_id, inline_assistant.clone())),
                 codegen: codegen.clone(),
+                project: project.downgrade(),
                 _subscriptions: vec![
                     cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
                     cx.subscribe(editor, {
@@ -440,8 +483,15 @@ impl AssistantPanel {
             InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation,
+                retrieve_context,
             } => {
-                self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
+                self.confirm_inline_assist(
+                    assist_id,
+                    prompt,
+                    *include_conversation,
+                    cx,
+                    *retrieve_context,
+                );
             }
             InlineAssistantEvent::Canceled => {
                 self.finish_inline_assist(assist_id, true, cx);
@@ -454,6 +504,9 @@ impl AssistantPanel {
             } => {
                 self.include_conversation_in_next_inline_assist = *include_conversation;
             }
+            InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => {
+                self.retrieve_context_in_next_inline_assist = *retrieve_context
+            }
         }
     }
 
@@ -532,6 +585,7 @@ impl AssistantPanel {
         user_prompt: &str,
         include_conversation: bool,
         cx: &mut ViewContext<Self>,
+        retrieve_context: bool,
     ) {
         let conversation = if include_conversation {
             self.active_editor()
@@ -553,6 +607,8 @@ impl AssistantPanel {
             return;
         };
 
+        let project = pending_assist.project.clone();
+
         self.inline_prompt_history
             .retain(|prompt| prompt != user_prompt);
         self.inline_prompt_history.push_back(user_prompt.into());
@@ -593,10 +649,62 @@ impl AssistantPanel {
         let codegen_kind = codegen.read(cx).kind().clone();
         let user_prompt = user_prompt.to_string();
 
-        let mut messages = Vec::new();
+        let snippets = if retrieve_context {
+            let Some(project) = project.upgrade(cx) else {
+                return;
+            };
+
+            let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
+                let search_results = semantic_index.update(cx, |this, cx| {
+                    this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
+                });
+
+                cx.background()
+                    .spawn(async move { search_results.await.unwrap_or_default() })
+            } else {
+                Task::ready(Vec::new())
+            };
+
+            let snippets = cx.spawn(|_, cx| async move {
+                let mut snippets = Vec::new();
+                for result in search_results.await {
+                    snippets.push(PromptCodeSnippet::new(result, &cx));
+
+                    // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
+                    //     buffer
+                    //         .snapshot()
+                    //         .text_for_range(result.range)
+                    //         .collect::<String>()
+                    // }));
+                }
+                snippets
+            });
+            snippets
+        } else {
+            Task::ready(Vec::new())
+        };
+
         let mut model = settings::get::<AssistantSettings>(cx)
             .default_open_ai_model
             .clone();
+        let model_name = model.full_name();
+
+        let prompt = cx.background().spawn(async move {
+            let snippets = snippets.await;
+
+            let language_name = language_name.as_deref();
+            generate_content_prompt(
+                user_prompt,
+                language_name,
+                &buffer,
+                range,
+                codegen_kind,
+                snippets,
+                model_name,
+            )
+        });
+
+        let mut messages = Vec::new();
         if let Some(conversation) = conversation {
             let conversation = conversation.read(cx);
             let buffer = conversation.buffer.read(cx);
@@ -608,11 +716,6 @@ impl AssistantPanel {
             model = conversation.model.clone();
         }
 
-        let prompt = cx.background().spawn(async move {
-            let language_name = language_name.as_deref();
-            generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind)
-        });
-
         cx.spawn(|_, mut cx| async move {
             let prompt = prompt.await;
 
@@ -1514,12 +1617,14 @@ impl Conversation {
                         Role::Assistant => "assistant".into(),
                         Role::System => "system".into(),
                     },
-                    content: self
-                        .buffer
-                        .read(cx)
-                        .text_for_range(message.offset_range)
-                        .collect(),
+                    content: Some(
+                        self.buffer
+                            .read(cx)
+                            .text_for_range(message.offset_range)
+                            .collect(),
+                    ),
                     name: None,
+                    function_call: None,
                 })
             })
             .collect::<Vec<_>>();
@@ -2638,12 +2743,16 @@ enum InlineAssistantEvent {
     Confirmed {
         prompt: String,
         include_conversation: bool,
+        retrieve_context: bool,
     },
     Canceled,
     Dismissed,
     IncludeConversationToggled {
         include_conversation: bool,
     },
+    RetrieveContextToggled {
+        retrieve_context: bool,
+    },
 }
 
 struct InlineAssistant {
@@ -2659,6 +2768,11 @@ struct InlineAssistant {
     pending_prompt: String,
     codegen: ModelHandle<Codegen>,
     _subscriptions: Vec<Subscription>,
+    retrieve_context: bool,
+    semantic_index: Option<ModelHandle<SemanticIndex>>,
+    semantic_permissioned: Option<bool>,
+    project: WeakModelHandle<Project>,
+    maintain_rate_limit: Option<Task<()>>,
 }
 
 impl Entity for InlineAssistant {
@@ -2675,51 +2789,65 @@ impl View for InlineAssistant {
         let theme = theme::current(cx);
 
         Flex::row()
-            .with_child(
-                Flex::row()
-                    .with_child(
-                        Button::action(ToggleIncludeConversation)
-                            .with_tooltip("Include Conversation", theme.tooltip.clone())
+            .with_children([Flex::row()
+                .with_child(
+                    Button::action(ToggleIncludeConversation)
+                        .with_tooltip("Include Conversation", theme.tooltip.clone())
+                        .with_id(self.id)
+                        .with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
+                        .toggleable(self.include_conversation)
+                        .with_style(theme.assistant.inline.include_conversation.clone())
+                        .element()
+                        .aligned(),
+                )
+                .with_children(if SemanticIndex::enabled(cx) {
+                    Some(
+                        Button::action(ToggleRetrieveContext)
+                            .with_tooltip("Retrieve Context", theme.tooltip.clone())
                             .with_id(self.id)
-                            .with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
-                            .toggleable(self.include_conversation)
-                            .with_style(theme.assistant.inline.include_conversation.clone())
+                            .with_contents(theme::components::svg::Svg::new(
+                                "icons/magnifying_glass.svg",
+                            ))
+                            .toggleable(self.retrieve_context)
+                            .with_style(theme.assistant.inline.retrieve_context.clone())
                             .element()
                             .aligned(),
                     )
-                    .with_children(if let Some(error) = self.codegen.read(cx).error() {
-                        Some(
-                            Svg::new("icons/error.svg")
-                                .with_color(theme.assistant.error_icon.color)
-                                .constrained()
-                                .with_width(theme.assistant.error_icon.width)
-                                .contained()
-                                .with_style(theme.assistant.error_icon.container)
-                                .with_tooltip::<ErrorIcon>(
-                                    self.id,
-                                    error.to_string(),
-                                    None,
-                                    theme.tooltip.clone(),
-                                    cx,
-                                )
-                                .aligned(),
-                        )
-                    } else {
-                        None
-                    })
-                    .aligned()
-                    .constrained()
-                    .dynamically({
-                        let measurements = self.measurements.clone();
-                        move |constraint, _, _| {
-                            let measurements = measurements.get();
-                            SizeConstraint {
-                                min: vec2f(measurements.gutter_width, constraint.min.y()),
-                                max: vec2f(measurements.gutter_width, constraint.max.y()),
-                            }
+                } else {
+                    None
+                })
+                .with_children(if let Some(error) = self.codegen.read(cx).error() {
+                    Some(
+                        Svg::new("icons/error.svg")
+                            .with_color(theme.assistant.error_icon.color)
+                            .constrained()
+                            .with_width(theme.assistant.error_icon.width)
+                            .contained()
+                            .with_style(theme.assistant.error_icon.container)
+                            .with_tooltip::<ErrorIcon>(
+                                self.id,
+                                error.to_string(),
+                                None,
+                                theme.tooltip.clone(),
+                                cx,
+                            )
+                            .aligned(),
+                    )
+                } else {
+                    None
+                })
+                .aligned()
+                .constrained()
+                .dynamically({
+                    let measurements = self.measurements.clone();
+                    move |constraint, _, _| {
+                        let measurements = measurements.get();
+                        SizeConstraint {
+                            min: vec2f(measurements.gutter_width, constraint.min.y()),
+                            max: vec2f(measurements.gutter_width, constraint.max.y()),
                         }
-                    }),
-            )
+                    }
+                })])
             .with_child(Empty::new().constrained().dynamically({
                 let measurements = self.measurements.clone();
                 move |constraint, _, _| {
@@ -2742,6 +2870,16 @@ impl View for InlineAssistant {
                     .left()
                     .flex(1., true),
             )
+            .with_children(if self.retrieve_context {
+                Some(
+                    Flex::row()
+                        .with_children(self.retrieve_context_status(cx))
+                        .flex(1., true)
+                        .aligned(),
+                )
+            } else {
+                None
+            })
             .contained()
             .with_style(theme.assistant.inline.container)
             .into_any()
@@ -2767,6 +2905,9 @@ impl InlineAssistant {
         codegen: ModelHandle<Codegen>,
         workspace: WeakViewHandle<Workspace>,
         cx: &mut ViewContext<Self>,
+        retrieve_context: bool,
+        semantic_index: Option<ModelHandle<SemanticIndex>>,
+        project: ModelHandle<Project>,
     ) -> Self {
         let prompt_editor = cx.add_view(|cx| {
             let mut editor = Editor::single_line(
@@ -2780,11 +2921,16 @@ impl InlineAssistant {
             editor.set_placeholder_text(placeholder, cx);
             editor
         });
-        let subscriptions = vec![
+        let mut subscriptions = vec![
             cx.observe(&codegen, Self::handle_codegen_changed),
             cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
         ];
-        Self {
+
+        if let Some(semantic_index) = semantic_index.clone() {
+            subscriptions.push(cx.observe(&semantic_index, Self::semantic_index_changed));
+        }
+
+        let assistant = Self {
             id,
             prompt_editor,
             workspace,
@@ -2797,7 +2943,33 @@ impl InlineAssistant {
             pending_prompt: String::new(),
             codegen,
             _subscriptions: subscriptions,
+            retrieve_context,
+            semantic_permissioned: None,
+            semantic_index,
+            project: project.downgrade(),
+            maintain_rate_limit: None,
+        };
+
+        assistant.index_project(cx).log_err();
+
+        assistant
+    }
+
+    fn semantic_permissioned(&self, cx: &mut ViewContext<Self>) -> Task<Result<bool>> {
+        if let Some(value) = self.semantic_permissioned {
+            return Task::ready(Ok(value));
         }
+
+        let Some(project) = self.project.upgrade(cx) else {
+            return Task::ready(Err(anyhow!("project was dropped")));
+        };
+
+        self.semantic_index
+            .as_ref()
+            .map(|semantic| {
+                semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
+            })
+            .unwrap_or(Task::ready(Ok(false)))
     }
 
     fn handle_prompt_editor_events(
@@ -2812,6 +2984,37 @@ impl InlineAssistant {
         }
     }
 
+    fn semantic_index_changed(
+        &mut self,
+        semantic_index: ModelHandle<SemanticIndex>,
+        cx: &mut ViewContext<Self>,
+    ) {
+        let Some(project) = self.project.upgrade(cx) else {
+            return;
+        };
+
+        let status = semantic_index.read(cx).status(&project);
+        match status {
+            SemanticIndexStatus::Indexing {
+                rate_limit_expiry: Some(_),
+                ..
+            } => {
+                if self.maintain_rate_limit.is_none() {
+                    self.maintain_rate_limit = Some(cx.spawn(|this, mut cx| async move {
+                        loop {
+                            cx.background().timer(Duration::from_secs(1)).await;
+                            this.update(&mut cx, |_, cx| cx.notify()).log_err();
+                        }
+                    }));
+                }
+                return;
+            }
+            _ => {
+                self.maintain_rate_limit = None;
+            }
+        }
+    }
+
     fn handle_codegen_changed(&mut self, _: ModelHandle<Codegen>, cx: &mut ViewContext<Self>) {
         let is_read_only = !self.codegen.read(cx).idle();
         self.prompt_editor.update(cx, |editor, cx| {
@@ -2861,12 +3064,241 @@ impl InlineAssistant {
             cx.emit(InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation: self.include_conversation,
+                retrieve_context: self.retrieve_context,
             });
             self.confirmed = true;
             cx.notify();
         }
     }
 
+    fn toggle_retrieve_context(&mut self, _: &ToggleRetrieveContext, cx: &mut ViewContext<Self>) {
+        let semantic_permissioned = self.semantic_permissioned(cx);
+
+        let Some(project) = self.project.upgrade(cx) else {
+            return;
+        };
+
+        let project_name = project
+            .read(cx)
+            .worktree_root_names(cx)
+            .collect::<Vec<&str>>()
+            .join("/");
+        let is_plural = project_name.chars().filter(|letter| *letter == '/').count() > 0;
+        let prompt_text = format!(
+            "Would you like to index the '{}' project{} for context retrieval? This requires sending code to the OpenAI API",
+            project_name,
+            if is_plural { "s" } else { "" },
+        );
+
+        cx.spawn(|this, mut cx| async move {
+            // If necessary, prompt the user
+            if !semantic_permissioned.await.unwrap_or(false) {
+                let mut answer = this.update(&mut cx, |_, cx| {
+                    cx.prompt(
+                        PromptLevel::Info,
+                        prompt_text.as_str(),
+                        &["Continue", "Cancel"],
+                    )
+                })?;
+
+                if answer.next().await == Some(0) {
+                    this.update(&mut cx, |this, _| {
+                        this.semantic_permissioned = Some(true);
+                    })?;
+                } else {
+                    return anyhow::Ok(());
+                }
+            }
+
+            // If permissioned, update context appropriately
+            this.update(&mut cx, |this, cx| {
+                this.retrieve_context = !this.retrieve_context;
+
+                cx.emit(InlineAssistantEvent::RetrieveContextToggled {
+                    retrieve_context: this.retrieve_context,
+                });
+
+                if this.retrieve_context {
+                    this.index_project(cx).log_err();
+                }
+
+                cx.notify();
+            })?;
+
+            anyhow::Ok(())
+        })
+        .detach_and_log_err(cx);
+    }
+
+    fn index_project(&self, cx: &mut ViewContext<Self>) -> anyhow::Result<()> {
+        let Some(project) = self.project.upgrade(cx) else {
+            return Err(anyhow!("project was dropped!"));
+        };
+
+        let semantic_permissioned = self.semantic_permissioned(cx);
+        if let Some(semantic_index) = SemanticIndex::global(cx) {
+            cx.spawn(|_, mut cx| async move {
+                // This has to be updated to accommodate semantic permissions
+                if semantic_permissioned.await.unwrap_or(false) {
+                    semantic_index
+                        .update(&mut cx, |index, cx| index.index_project(project, cx))
+                        .await
+                } else {
+                    Err(anyhow!("project is not permissioned for semantic indexing"))
+                }
+            })
+            .detach_and_log_err(cx);
+        }
+
+        anyhow::Ok(())
+    }
+
+    fn retrieve_context_status(
+        &self,
+        cx: &mut ViewContext<Self>,
+    ) -> Option<AnyElement<InlineAssistant>> {
+        enum ContextStatusIcon {}
+
+        let Some(project) = self.project.upgrade(cx) else {
+            return None;
+        };
+
+        if let Some(semantic_index) = SemanticIndex::global(cx) {
+            let status = semantic_index.update(cx, |index, _| index.status(&project));
+            let theme = theme::current(cx);
+            match status {
+                SemanticIndexStatus::NotAuthenticated {} => Some(
+                    Svg::new("icons/error.svg")
+                        .with_color(theme.assistant.error_icon.color)
+                        .constrained()
+                        .with_width(theme.assistant.error_icon.width)
+                        .contained()
+                        .with_style(theme.assistant.error_icon.container)
+                        .with_tooltip::<ContextStatusIcon>(
+                            self.id,
+                            "Not Authenticated. Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.",
+                            None,
+                            theme.tooltip.clone(),
+                            cx,
+                        )
+                        .aligned()
+                        .into_any(),
+                ),
+                SemanticIndexStatus::NotIndexed {} => Some(
+                    Svg::new("icons/error.svg")
+                        .with_color(theme.assistant.inline.context_status.error_icon.color)
+                        .constrained()
+                        .with_width(theme.assistant.inline.context_status.error_icon.width)
+                        .contained()
+                        .with_style(theme.assistant.inline.context_status.error_icon.container)
+                        .with_tooltip::<ContextStatusIcon>(
+                            self.id,
+                            "Not Indexed",
+                            None,
+                            theme.tooltip.clone(),
+                            cx,
+                        )
+                        .aligned()
+                        .into_any(),
+                ),
+                SemanticIndexStatus::Indexing {
+                    remaining_files,
+                    rate_limit_expiry,
+                } => {
+                    let mut status_text = if remaining_files == 0 {
+                        "Indexing...".to_string()
+                    } else {
+                        format!("Remaining files to index: {remaining_files}")
+                    };
+
+                    if let Some(rate_limit_expiry) = rate_limit_expiry {
+                        let remaining_seconds = rate_limit_expiry.duration_since(Instant::now());
+                        if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 {
+                            write!(
+                                status_text,
+                                " (rate limit expires in {}s)",
+                                remaining_seconds.as_secs()
+                            )
+                            .unwrap();
+                        }
+                    }
+                    Some(
+                        Svg::new("icons/update.svg")
+                            .with_color(theme.assistant.inline.context_status.in_progress_icon.color)
+                            .constrained()
+                            .with_width(theme.assistant.inline.context_status.in_progress_icon.width)
+                            .contained()
+                            .with_style(theme.assistant.inline.context_status.in_progress_icon.container)
+                            .with_tooltip::<ContextStatusIcon>(
+                                self.id,
+                                status_text,
+                                None,
+                                theme.tooltip.clone(),
+                                cx,
+                            )
+                            .aligned()
+                            .into_any(),
+                    )
+                }
+                SemanticIndexStatus::Indexed {} => Some(
+                    Svg::new("icons/check.svg")
+                        .with_color(theme.assistant.inline.context_status.complete_icon.color)
+                        .constrained()
+                        .with_width(theme.assistant.inline.context_status.complete_icon.width)
+                        .contained()
+                        .with_style(theme.assistant.inline.context_status.complete_icon.container)
+                        .with_tooltip::<ContextStatusIcon>(
+                            self.id,
+                            "Index up to date",
+                            None,
+                            theme.tooltip.clone(),
+                            cx,
+                        )
+                        .aligned()
+                        .into_any(),
+                ),
+            }
+        } else {
+            None
+        }
+    }
+
+    // fn retrieve_context_status(&self, cx: &mut ViewContext<Self>) -> String {
+    //     let project = self.project.clone();
+    //     if let Some(semantic_index) = self.semantic_index.clone() {
+    //         let status = semantic_index.update(cx, |index, cx| index.status(&project));
+    //         return match status {
+    //             // This theoretically shouldn't be a valid code path,
+    //             // as the inline assistant can't be launched without an API key.
+    //             // We keep it here for safety
+    //             semantic_index::SemanticIndexStatus::NotAuthenticated => {
+    //                 "Not Authenticated!\nPlease ensure you have an `OPENAI_API_KEY` in your environment variables.".to_string()
+    //             }
+    //             semantic_index::SemanticIndexStatus::Indexed => {
+    //                 "Indexing Complete!".to_string()
+    //             }
+    //             semantic_index::SemanticIndexStatus::Indexing { remaining_files, rate_limit_expiry } => {
+
+    //                 let mut status = format!("Remaining files to index for Context Retrieval: {remaining_files}");
+
+    //                 if let Some(rate_limit_expiry) = rate_limit_expiry {
+    //                     let remaining_seconds =
+    //                             rate_limit_expiry.duration_since(Instant::now());
+    //                     if remaining_seconds > Duration::from_secs(0) {
+    //                         write!(status, " (rate limit resets in {}s)", remaining_seconds.as_secs()).unwrap();
+    //                     }
+    //                 }
+    //                 status
+    //             }
+    //             semantic_index::SemanticIndexStatus::NotIndexed => {
+    //                 "Not Indexed for Context Retrieval".to_string()
+    //             }
+    //         };
+    //     }
+
+    //     "".to_string()
+    // }
+
     fn toggle_include_conversation(
         &mut self,
         _: &ToggleIncludeConversation,
@@ -2929,6 +3361,7 @@ struct PendingInlineAssist {
     inline_assistant: Option<(BlockId, ViewHandle<InlineAssistant>)>,
     codegen: ModelHandle<Codegen>,
     _subscriptions: Vec<Subscription>,
+    project: WeakModelHandle<Project>,
 }
 
 fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {

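One behavioral detail from the panel changes worth spelling out: while the index reports `Indexing { rate_limit_expiry: Some(_) }`, the inline assistant holds a `maintain_rate_limit` task that ticks once per second and calls `cx.notify()` so the countdown in the status tooltip stays fresh, and it drops the task once the status changes. A rough standalone sketch of that loop using `smol` (which this crate already depends on), with `println!` standing in for the re-render:

```rust
use std::time::Duration;

fn main() {
    smol::block_on(async {
        // Stand-in for the expiry reported by SemanticIndexStatus::Indexing.
        let mut remaining_secs = 3u64;
        loop {
            // Wake once per second, like the cx.background().timer(...) call above.
            smol::Timer::after(Duration::from_secs(1)).await;
            // In the PR this is `this.update(&mut cx, |_, cx| cx.notify())`,
            // which re-renders the inline assistant.
            println!("refresh UI: rate limit expires in {remaining_secs}s");
            remaining_secs -= 1;
            if remaining_secs == 0 {
                // Zed instead drops `maintain_rate_limit` when the status
                // leaves the rate-limited Indexing state.
                break;
            }
        }
    });
}
```
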
crates/assistant/src/prompts.rs

@@ -1,8 +1,60 @@
 use crate::codegen::CodegenKind;
+use gpui::AsyncAppContext;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
+use semantic_index::SearchResult;
 use std::cmp::{self, Reverse};
 use std::fmt::Write;
 use std::ops::Range;
+use std::path::PathBuf;
+use tiktoken_rs::ChatCompletionRequestMessage;
+
+pub struct PromptCodeSnippet {
+    path: Option<PathBuf>,
+    language_name: Option<String>,
+    content: String,
+}
+
+impl PromptCodeSnippet {
+    pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
+        let (content, language_name, file_path) =
+            search_result.buffer.read_with(cx, |buffer, _| {
+                let snapshot = buffer.snapshot();
+                let content = snapshot
+                    .text_for_range(search_result.range.clone())
+                    .collect::<String>();
+
+                let language_name = buffer
+                    .language()
+                    .and_then(|language| Some(language.name().to_string()));
+
+                let file_path = buffer
+                    .file()
+                    .and_then(|file| Some(file.path().to_path_buf()));
+
+                (content, language_name, file_path)
+            });
+
+        PromptCodeSnippet {
+            path: file_path,
+            language_name,
+            content,
+        }
+    }
+}
+
+impl ToString for PromptCodeSnippet {
+    fn to_string(&self) -> String {
+        let path = self
+            .path
+            .as_ref()
+            .and_then(|path| Some(path.to_string_lossy().to_string()))
+            .unwrap_or("".to_string());
+        let language_name = self.language_name.clone().unwrap_or("".to_string());
+        let content = self.content.clone();
+
+        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+    }
+}
 
 #[allow(dead_code)]
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
@@ -121,17 +173,25 @@ pub fn generate_content_prompt(
     buffer: &BufferSnapshot,
     range: Range<impl ToOffset>,
     kind: CodegenKind,
+    search_results: Vec<PromptCodeSnippet>,
+    model: &str,
 ) -> String {
+    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+    const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
+
+    let mut prompts = Vec::new();
     let range = range.to_offset(buffer);
-    let mut prompt = String::new();
 
     // General Preamble
     if let Some(language_name) = language_name {
-        writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
+        prompts.push(format!("You're an expert {language_name} engineer.\n"));
     } else {
-        writeln!(prompt, "You're an expert engineer.\n").unwrap();
+        prompts.push("You're an expert engineer.\n".to_string());
     }
 
+    // Snippets
+    let mut snippet_position = prompts.len() - 1;
+
     let mut content = String::new();
     content.extend(buffer.text_for_range(0..range.start));
     if range.start == range.end {
@@ -145,59 +205,99 @@ pub fn generate_content_prompt(
     }
     content.extend(buffer.text_for_range(range.end..buffer.len()));
 
-    writeln!(
-        prompt,
-        "The file you are currently working on has the following content:"
-    )
-    .unwrap();
+    prompts.push("The file you are currently working on has the following content:\n".to_string());
+
     if let Some(language_name) = language_name {
         let language_name = language_name.to_lowercase();
-        writeln!(prompt, "```{language_name}\n{content}\n```").unwrap();
+        prompts.push(format!("```{language_name}\n{content}\n```"));
     } else {
-        writeln!(prompt, "```\n{content}\n```").unwrap();
+        prompts.push(format!("```\n{content}\n```"));
     }
 
     match kind {
         CodegenKind::Generate { position: _ } => {
-            writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
-            writeln!(
-                prompt,
-                "Assume the cursor is located where the `<|START|` marker is."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
+            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
+            prompts
+                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
+            prompts.push(
                 "Text can't be replaced, so assume your answer will be inserted at the cursor."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
+                    .to_string(),
+            );
+            prompts.push(format!(
                 "Generate text based on the users prompt: {user_prompt}"
-            )
-            .unwrap();
+            ));
         }
         CodegenKind::Transform { range: _ } => {
-            writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
-            writeln!(
-                prompt,
-                "Modify the users code selected text based upon the users prompt: {user_prompt}"
-            )
-            .unwrap();
-            writeln!(
-                prompt,
-                "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
-            )
-            .unwrap();
+            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
+            prompts.push(format!(
+                "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
+            ));
+            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
         }
     }
 
     if let Some(language_name) = language_name {
-        writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
+        prompts.push(format!(
+            "Your answer MUST always and only be valid {language_name}"
+        ));
+    }
+    prompts.push("Never make remarks about the output.".to_string());
+    prompts.push("Do not return any text, except the generated code.".to_string());
+    prompts.push("Do not wrap your text in a Markdown block".to_string());
+
+    let current_messages = [ChatCompletionRequestMessage {
+        role: "user".to_string(),
+        content: Some(prompts.join("\n")),
+        function_call: None,
+        name: None,
+    }];
+
+    let mut remaining_token_count = if let Ok(current_token_count) =
+        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
+    {
+        let max_token_count = tiktoken_rs::model::get_context_size(model);
+        let intermediate_token_count = max_token_count - current_token_count;
+
+        if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
+            0
+        } else {
+            intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
+        }
+    } else {
+        // If tiktoken fails to count the tokens, assume we have no space remaining.
+        0
+    };
+
+    // TODO:
+    //   - add repository name to snippet
+    //   - add file path
+    //   - add language
+    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
+        let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
+
+        for search_result in search_results {
+            let mut snippet_prompt = template.to_string();
+            let snippet = search_result.to_string();
+            writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
+
+            let token_count = encoding
+                .encode_with_special_tokens(snippet_prompt.as_str())
+                .len();
+            if token_count <= remaining_token_count {
+                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+                    prompts.insert(snippet_position, snippet_prompt);
+                    snippet_position += 1;
+                    remaining_token_count -= token_count;
+                    // The template preamble has been emitted; clear it so later snippets skip it.
+                    template = "";
+                }
+            } else {
+                break;
+            }
+        }
     }
-    writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
-    writeln!(prompt, "Never make remarks about the output.").unwrap();
 
-    prompt
+    prompts.join("\n")
 }
 
 #[cfg(test)]

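The token-budget arithmetic in `generate_content_prompt` is the heart of the snippet fitting: measure the base prompt, subtract it plus a generation reserve from the model's context size, and charge each snippet against what remains. A runnable sketch using the same tiktoken-rs 0.5 calls this PR adds (the model name and sample strings are illustrative assumptions):

```rust
use tiktoken_rs::{
    get_bpe_from_model, model::get_context_size, num_tokens_from_messages,
    ChatCompletionRequestMessage,
};

fn main() -> anyhow::Result<()> {
    const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
    let model = "gpt-4"; // assumption; Zed reads this from AssistantSettings

    // The base prompt, before any retrieved snippets are inserted.
    let messages = [ChatCompletionRequestMessage {
        role: "user".to_string(),
        content: Some("You're an expert engineer.\n...".to_string()),
        name: None,
        function_call: None,
    }];

    // Budget = context window - base prompt - room reserved for the reply.
    let used = num_tokens_from_messages(model, &messages)?;
    let mut budget = get_context_size(model)
        .saturating_sub(used)
        .saturating_sub(RESERVED_TOKENS_FOR_GENERATION);

    // Each snippet is measured with the model's BPE and only inserted while
    // it fits, mirroring the loop in generate_content_prompt above.
    let bpe = get_bpe_from_model(model)?;
    let snippet_prompt = "The below code snippet may be relevant...\nfn main() {}";
    let cost = bpe.encode_with_special_tokens(snippet_prompt).len();
    if cost <= budget {
        budget -= cost;
    }
    println!("remaining snippet budget: {budget} tokens");
    Ok(())
}
```
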
crates/theme/src/theme.rs

@@ -1199,6 +1199,15 @@ pub struct InlineAssistantStyle {
     pub disabled_editor: FieldEditor,
     pub pending_edit_background: Color,
     pub include_conversation: ToggleIconButtonStyle,
+    pub retrieve_context: ToggleIconButtonStyle,
+    pub context_status: ContextStatusStyle,
+}
+
+#[derive(Clone, Deserialize, Default, JsonSchema)]
+pub struct ContextStatusStyle {
+    pub error_icon: Icon,
+    pub in_progress_icon: Icon,
+    pub complete_icon: Icon,
 }
 
 #[derive(Clone, Deserialize, Default, JsonSchema)]

styles/src/style_tree/assistant.ts

@@ -79,6 +79,80 @@ export default function assistant(): any {
                 },
             },
             pending_edit_background: background(theme.highest, "positive"),
+            context_status: {
+                error_icon: {
+                    margin: { left: 8, right: 18 },
+                    color: foreground(theme.highest, "negative"),
+                    width: 12,
+                },
+                in_progress_icon: {
+                    margin: { left: 8, right: 18 },
+                    color: foreground(theme.highest, "positive"),
+                    width: 12,
+                },
+                complete_icon: {
+                    margin: { left: 8, right: 18 },
+                    color: foreground(theme.highest, "positive"),
+                    width: 12,
+                }
+            },
+            retrieve_context: toggleable({
+                base: interactive({
+                    base: {
+                        icon_size: 12,
+                        color: foreground(theme.highest, "variant"),
+
+                        button_width: 12,
+                        background: background(theme.highest, "on"),
+                        corner_radius: 2,
+                        border: {
+                            width: 1., color: background(theme.highest, "on")
+                        },
+                        margin: { left: 2 },
+                        padding: {
+                            left: 4,
+                            right: 4,
+                            top: 4,
+                            bottom: 4,
+                        },
+                    },
+                    state: {
+                        hovered: {
+                            ...text(theme.highest, "mono", "variant", "hovered"),
+                            background: background(theme.highest, "on", "hovered"),
+                            border: {
+                                width: 1., color: background(theme.highest, "on", "hovered")
+                            },
+                        },
+                        clicked: {
+                            ...text(theme.highest, "mono", "variant", "pressed"),
+                            background: background(theme.highest, "on", "pressed"),
+                            border: {
+                                width: 1., color: background(theme.highest, "on", "pressed")
+                            },
+                        },
+                    },
+                }),
+                state: {
+                    active: {
+                        default: {
+                            icon_size: 12,
+                            button_width: 12,
+                            color: foreground(theme.highest, "variant"),
+                            background: background(theme.highest, "accent"),
+                            border: border(theme.highest, "accent"),
+                        },
+                        hovered: {
+                            background: background(theme.highest, "accent", "hovered"),
+                            border: border(theme.highest, "accent", "hovered"),
+                        },
+                        clicked: {
+                            background: background(theme.highest, "accent", "pressed"),
+                            border: border(theme.highest, "accent", "pressed"),
+                        },
+                    },
+                },
+            }),
             include_conversation: toggleable({
                 base: interactive({
                     base: {