Allow codebase search to be turned on or off within the composer for assistant2 (#11315)

Kyle Kelley created

![image](https://github.com/zed-industries/zed/assets/836375/e03d2357-e2e4-49f1-86d6-7593bce13618)


![image](https://github.com/zed-industries/zed/assets/836375/3d769622-82e1-4e6f-bdec-4dce81e43423)


![image](https://github.com/zed-industries/zed/assets/836375/bf79a4ec-1660-47b1-8525-e741575dc5d4)

Release Notes:

- N/A

Change summary

crates/assistant2/src/assistant2.rs              |  14 +
crates/assistant2/src/completion_provider.rs     |   6 
crates/assistant2/src/tools/project_index.rs     |  99 +---------------
crates/assistant2/src/ui.rs                      |   2 
crates/assistant2/src/ui/composer.rs             |  24 ++-
crates/assistant2/src/ui/project_index_button.rs | 109 ++++++++++++++++++
crates/assistant_tooling/src/registry.rs         |  89 ++++++++++----
crates/assistant_tooling/src/tool.rs             |   4 
8 files changed, 209 insertions(+), 138 deletions(-)

Detailed changes

crates/assistant2/src/assistant2.rs 🔗

@@ -23,7 +23,7 @@ use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
 use serde::Deserialize;
 use settings::Settings;
 use std::sync::Arc;
-use ui::Composer;
+use ui::{Composer, ProjectIndexButton};
 use util::{paths::EMBEDDINGS_DIR, ResultExt};
 use workspace::{
     dock::{DockPosition, Panel, PanelEvent},
@@ -228,6 +228,7 @@ pub struct AssistantChat {
     list_state: ListState,
     language_registry: Arc<LanguageRegistry>,
     composer_editor: View<Editor>,
+    project_index_button: Option<View<ProjectIndexButton>>,
     user_store: Model<UserStore>,
     next_message_id: MessageId,
     collapsed_messages: HashMap<MessageId, bool>,
@@ -263,6 +264,10 @@ impl AssistantChat {
             },
         );
 
+        let project_index_button = project_index.clone().map(|project_index| {
+            cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
+        });
+
         Self {
             model,
             messages: Vec::new(),
@@ -275,6 +280,7 @@ impl AssistantChat {
             list_state,
             user_store,
             language_registry,
+            project_index_button,
             project_index,
             next_message_id: MessageId(0),
             editing_message: None,
@@ -397,7 +403,7 @@ impl AssistantChat {
                     {
                         this.tool_registry.definitions()
                     } else {
-                        &[]
+                        Vec::new()
                     };
                     call_count += 1;
 
@@ -590,7 +596,7 @@ impl AssistantChat {
                         element.child(Composer::new(
                             body.clone(),
                             self.user_store.read(cx).current_user(),
-                            self.tool_registry.clone(),
+                            self.project_index_button.clone(),
                             crate::ui::ModelSelector::new(
                                 cx.view().downgrade(),
                                 self.model.clone(),
@@ -768,7 +774,7 @@ impl Render for AssistantChat {
             .child(Composer::new(
                 self.composer_editor.clone(),
                 self.user_store.read(cx).current_user(),
-                self.tool_registry.clone(),
+                self.project_index_button.clone(),
                 crate::ui::ModelSelector::new(cx.view().downgrade(), self.model.clone())
                     .into_any_element(),
             ))

crates/assistant2/src/completion_provider.rs 🔗

@@ -33,7 +33,7 @@ impl CompletionProvider {
         messages: Vec<CompletionMessage>,
         stop: Vec<String>,
         temperature: f32,
-        tools: &[ToolFunctionDefinition],
+        tools: Vec<ToolFunctionDefinition>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
     {
         self.0.complete(model, messages, stop, temperature, tools)
@@ -51,7 +51,7 @@ pub trait CompletionProviderBackend: 'static {
         messages: Vec<CompletionMessage>,
         stop: Vec<String>,
         temperature: f32,
-        tools: &[ToolFunctionDefinition],
+        tools: Vec<ToolFunctionDefinition>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
 }
 
@@ -80,7 +80,7 @@ impl CompletionProviderBackend for CloudCompletionProvider {
         messages: Vec<CompletionMessage>,
         stop: Vec<String>,
         temperature: f32,
-        tools: &[ToolFunctionDefinition],
+        tools: Vec<ToolFunctionDefinition>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
     {
         let client = self.client.clone();

crates/assistant2/src/tools/project_index.rs 🔗

@@ -1,14 +1,17 @@
 use anyhow::Result;
-use assistant_tooling::LanguageModelTool;
-use gpui::{percentage, prelude::*, Animation, AnimationExt, AnyView, Model, Task, Transformation};
+use assistant_tooling::{
+    // assistant_tool_button::{AssistantToolButton, ToolStatus},
+    LanguageModelTool,
+};
+use gpui::{prelude::*, Model, Task};
 use project::Fs;
 use schemars::JsonSchema;
 use semantic_index::{ProjectIndex, Status};
 use serde::Deserialize;
-use std::{sync::Arc, time::Duration};
+use std::sync::Arc;
 use ui::{
-    div, prelude::*, ButtonLike, CollapsibleContainer, Color, Icon, IconName, Indicator, Label,
-    SharedString, Tooltip, WindowContext,
+    div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
+    WindowContext,
 };
 use util::ResultExt as _;
 
@@ -199,13 +202,6 @@ impl LanguageModelTool for ProjectIndexTool {
         cx.new_view(|_cx| ProjectIndexView { input, output })
     }
 
-    fn status_view(&self, cx: &mut WindowContext) -> Option<AnyView> {
-        Some(
-            cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx))
-                .into(),
-        )
-    }
-
     fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
         match &output {
             Ok(output) => {
@@ -236,82 +232,3 @@ impl LanguageModelTool for ProjectIndexTool {
         }
     }
 }
-
-struct ProjectIndexStatusView {
-    project_index: Model<ProjectIndex>,
-}
-
-impl ProjectIndexStatusView {
-    pub fn new(project_index: Model<ProjectIndex>, cx: &mut ViewContext<Self>) -> Self {
-        cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
-            cx.notify();
-        })
-        .detach();
-        Self { project_index }
-    }
-}
-
-impl Render for ProjectIndexStatusView {
-    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
-        let status = self.project_index.read(cx).status();
-
-        let is_enabled = match status {
-            Status::Idle => true,
-            _ => false,
-        };
-
-        let icon = match status {
-            Status::Idle => Icon::new(IconName::Code)
-                .size(IconSize::XSmall)
-                .color(Color::Default),
-            Status::Loading => Icon::new(IconName::Code)
-                .size(IconSize::XSmall)
-                .color(Color::Muted),
-            Status::Scanning { .. } => Icon::new(IconName::Code)
-                .size(IconSize::XSmall)
-                .color(Color::Muted),
-        };
-
-        let indicator = match status {
-            Status::Idle => Some(Indicator::dot().color(Color::Success)),
-            Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
-            Status::Loading => Some(Indicator::icon(
-                Icon::new(IconName::Spinner)
-                    .color(Color::Accent)
-                    .with_animation(
-                        "arrow-circle",
-                        Animation::new(Duration::from_secs(2)).repeat(),
-                        |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
-                    ),
-            )),
-        };
-
-        ButtonLike::new("project-index")
-            .disabled(!is_enabled)
-            .child(
-                ui::IconWithIndicator::new(icon, indicator)
-                    .indicator_border_color(Some(gpui::transparent_black())),
-            )
-            .tooltip({
-                move |cx| {
-                    let (tooltip, meta) = match status {
-                        Status::Idle => (
-                            "Project index ready".to_string(),
-                            Some("Click to disable".to_string()),
-                        ),
-                        Status::Loading => ("Project index loading...".to_string(), None),
-                        Status::Scanning { remaining_count } => (
-                            "Project index scanning...".to_string(),
-                            Some(format!("{} remaining...", remaining_count)),
-                        ),
-                    };
-
-                    if let Some(meta) = meta {
-                        Tooltip::with_meta(tooltip, None, meta, cx)
-                    } else {
-                        Tooltip::text(tooltip, cx)
-                    }
-                }
-            })
-    }
-}

crates/assistant2/src/ui.rs 🔗

@@ -1,6 +1,7 @@
 mod chat_message;
 mod chat_notice;
 mod composer;
+mod project_index_button;
 
 #[cfg(feature = "stories")]
 mod stories;
@@ -8,6 +9,7 @@ mod stories;
 pub use chat_message::*;
 pub use chat_notice::*;
 pub use composer::*;
+pub use project_index_button::*;
 
 #[cfg(feature = "stories")]
 pub use stories::*;

crates/assistant2/src/ui/composer.rs 🔗

@@ -1,4 +1,4 @@
-use assistant_tooling::ToolRegistry;
+use crate::{ui::ProjectIndexButton, AssistantChat, CompletionProvider};
 use client::User;
 use editor::{Editor, EditorElement, EditorStyle};
 use gpui::{AnyElement, FontStyle, FontWeight, TextStyle, View, WeakView, WhiteSpace};
@@ -7,13 +7,11 @@ use std::sync::Arc;
 use theme::ThemeSettings;
 use ui::{popover_menu, prelude::*, Avatar, ButtonLike, ContextMenu, Tooltip};
 
-use crate::{AssistantChat, CompletionProvider};
-
 #[derive(IntoElement)]
 pub struct Composer {
     editor: View<Editor>,
     player: Option<Arc<User>>,
-    tool_registry: Arc<ToolRegistry>,
+    project_index_button: Option<View<ProjectIndexButton>>,
     model_selector: AnyElement,
 }
 
@@ -21,20 +19,28 @@ impl Composer {
     pub fn new(
         editor: View<Editor>,
         player: Option<Arc<User>>,
-        tool_registry: Arc<ToolRegistry>,
+        project_index_button: Option<View<ProjectIndexButton>>,
         model_selector: AnyElement,
     ) -> Self {
         Self {
             editor,
             player,
-            tool_registry,
+            project_index_button,
             model_selector,
         }
     }
+
+    fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
+        h_flex().children(
+            self.project_index_button
+                .clone()
+                .map(|view| view.into_any_element()),
+        )
+    }
 }
 
 impl RenderOnce for Composer {
-    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+    fn render(mut self, cx: &mut WindowContext) -> impl IntoElement {
         let mut player_avatar = div().size(rems_from_px(20.)).into_any_element();
         if let Some(player) = self.player.clone() {
             player_avatar = Avatar::new(player.avatar_uri.clone())
@@ -95,9 +101,7 @@ impl RenderOnce for Composer {
                                         .gap_2()
                                         .justify_between()
                                         .w_full()
-                                        .child(h_flex().gap_1().children(
-                                            self.tool_registry.status_views().iter().cloned(),
-                                        ))
+                                        .child(h_flex().gap_1().child(self.render_tools(cx)))
                                         .child(h_flex().gap_1().child(self.model_selector)),
                                 ),
                         ),

crates/assistant2/src/ui/project_index_button.rs 🔗

@@ -0,0 +1,109 @@
+use assistant_tooling::ToolRegistry;
+use gpui::{percentage, prelude::*, Animation, AnimationExt, Model, Transformation};
+use semantic_index::{ProjectIndex, Status};
+use std::{sync::Arc, time::Duration};
+use ui::{prelude::*, ButtonLike, Color, Icon, IconName, Indicator, Tooltip};
+
+use crate::tools::ProjectIndexTool;
+
+pub struct ProjectIndexButton {
+    project_index: Model<ProjectIndex>,
+    tool_registry: Arc<ToolRegistry>,
+}
+
+impl ProjectIndexButton {
+    pub fn new(
+        project_index: Model<ProjectIndex>,
+        tool_registry: Arc<ToolRegistry>,
+        cx: &mut ViewContext<Self>,
+    ) -> Self {
+        cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
+            cx.notify();
+        })
+        .detach();
+        Self {
+            project_index,
+            tool_registry,
+        }
+    }
+
+    pub fn set_enabled(&mut self, enabled: bool) {
+        self.tool_registry
+            .set_tool_enabled::<ProjectIndexTool>(enabled);
+    }
+}
+
+impl Render for ProjectIndexButton {
+    // Expanded information on ToolView
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let status = self.project_index.read(cx).status();
+        let is_enabled = self.tool_registry.is_tool_enabled::<ProjectIndexTool>();
+
+        let icon = if is_enabled {
+            match status {
+                Status::Idle => Icon::new(IconName::Code)
+                    .size(IconSize::XSmall)
+                    .color(Color::Default),
+                Status::Loading => Icon::new(IconName::Code)
+                    .size(IconSize::XSmall)
+                    .color(Color::Muted),
+                Status::Scanning { .. } => Icon::new(IconName::Code)
+                    .size(IconSize::XSmall)
+                    .color(Color::Muted),
+            }
+        } else {
+            Icon::new(IconName::Code)
+                .size(IconSize::XSmall)
+                .color(Color::Disabled)
+        };
+
+        let indicator = if is_enabled {
+            match status {
+                Status::Idle => Some(Indicator::dot().color(Color::Success)),
+                Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
+                Status::Loading => Some(Indicator::icon(
+                    Icon::new(IconName::Spinner)
+                        .color(Color::Accent)
+                        .with_animation(
+                            "arrow-circle",
+                            Animation::new(Duration::from_secs(2)).repeat(),
+                            |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
+                        ),
+                )),
+            }
+        } else {
+            None
+        };
+
+        ButtonLike::new("project-index")
+            .child(
+                ui::IconWithIndicator::new(icon, indicator)
+                    .indicator_border_color(Some(gpui::transparent_black())),
+            )
+            .tooltip({
+                move |cx| {
+                    let (tooltip, meta) = match status {
+                        Status::Idle => (
+                            "Project index ready".to_string(),
+                            Some("Click to disable".to_string()),
+                        ),
+                        Status::Loading => ("Project index loading...".to_string(), None),
+                        Status::Scanning { remaining_count } => (
+                            "Project index scanning...".to_string(),
+                            Some(format!("{} remaining...", remaining_count)),
+                        ),
+                    };
+
+                    if let Some(meta) = meta {
+                        Tooltip::with_meta(tooltip, None, meta, cx)
+                    } else {
+                        Tooltip::text(tooltip, cx)
+                    }
+                }
+            })
+            .on_click(cx.listener(move |this, _, cx| {
+                this.set_enabled(!is_enabled);
+                cx.notify();
+            }))
+    }
+}

crates/assistant_tooling/src/registry.rs 🔗

@@ -1,48 +1,86 @@
 use anyhow::{anyhow, Result};
-use gpui::{AnyView, Task, WindowContext};
-use std::collections::HashMap;
+use gpui::{Task, WindowContext};
+use std::{
+    any::TypeId,
+    collections::HashMap,
+    sync::atomic::{AtomicBool, Ordering::SeqCst},
+};
 
 use crate::tool::{
     LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
 };
 
+// Internal Tool representation for the registry
+pub struct Tool {
+    enabled: AtomicBool,
+    type_id: TypeId,
+    call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+    definition: ToolFunctionDefinition,
+}
+
+impl Tool {
+    fn new(
+        type_id: TypeId,
+        call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+        definition: ToolFunctionDefinition,
+    ) -> Self {
+        Self {
+            enabled: AtomicBool::new(true),
+            type_id,
+            call,
+            definition,
+        }
+    }
+}
+
 pub struct ToolRegistry {
-    tools: HashMap<
-        String,
-        Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
-    >,
-    definitions: Vec<ToolFunctionDefinition>,
-    status_views: Vec<AnyView>,
+    tools: HashMap<String, Tool>,
 }
 
 impl ToolRegistry {
     pub fn new() -> Self {
         Self {
             tools: HashMap::new(),
-            definitions: Vec::new(),
-            status_views: Vec::new(),
         }
     }
 
-    pub fn definitions(&self) -> &[ToolFunctionDefinition] {
-        &self.definitions
+    pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
+        for tool in self.tools.values() {
+            if tool.type_id == TypeId::of::<T>() {
+                tool.enabled.store(is_enabled, SeqCst);
+                return;
+            }
+        }
+    }
+
+    pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
+        for tool in self.tools.values() {
+            if tool.type_id == TypeId::of::<T>() {
+                return tool.enabled.load(SeqCst);
+            }
+        }
+        false
+    }
+
+    pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
+        self.tools
+            .values()
+            .filter(|tool| tool.enabled.load(SeqCst))
+            .map(|tool| tool.definition.clone())
+            .collect()
     }
 
     pub fn register<T: 'static + LanguageModelTool>(
         &mut self,
         tool: T,
-        cx: &mut WindowContext,
+        _cx: &mut WindowContext,
     ) -> Result<()> {
-        self.definitions.push(tool.definition());
-
-        if let Some(tool_view) = tool.status_view(cx) {
-            self.status_views.push(tool_view);
-        }
+        let definition = tool.definition();
 
         let name = tool.name();
-        let previous = self.tools.insert(
-            name.clone(),
-            // registry.call(tool_call, cx)
+
+        let registered_tool = Tool::new(
+            TypeId::of::<T>(),
             Box::new(
                 move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
                     let name = tool_call.name.clone();
@@ -77,8 +115,11 @@ impl ToolRegistry {
                     })
                 },
             ),
+            definition,
         );
 
+        let previous = self.tools.insert(name.clone(), registered_tool);
+
         if previous.is_some() {
             return Err(anyhow!("already registered a tool with name {}", name));
         }
@@ -109,11 +150,7 @@ impl ToolRegistry {
             }
         };
 
-        tool(tool_call, cx)
-    }
-
-    pub fn status_views(&self) -> &[AnyView] {
-        &self.status_views
+        (tool.call)(tool_call, cx)
     }
 }
 

crates/assistant_tooling/src/tool.rs 🔗

@@ -104,8 +104,4 @@ pub trait LanguageModelTool {
         output: Result<Self::Output>,
         cx: &mut WindowContext,
     ) -> View<Self::View>;
-
-    fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
-        None
-    }
 }