Detailed changes
@@ -23,7 +23,7 @@ use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
use serde::Deserialize;
use settings::Settings;
use std::sync::Arc;
-use ui::Composer;
+use ui::{Composer, ProjectIndexButton};
use util::{paths::EMBEDDINGS_DIR, ResultExt};
use workspace::{
dock::{DockPosition, Panel, PanelEvent},
@@ -228,6 +228,7 @@ pub struct AssistantChat {
list_state: ListState,
language_registry: Arc<LanguageRegistry>,
composer_editor: View<Editor>,
+ project_index_button: Option<View<ProjectIndexButton>>,
user_store: Model<UserStore>,
next_message_id: MessageId,
collapsed_messages: HashMap<MessageId, bool>,
@@ -263,6 +264,10 @@ impl AssistantChat {
},
);
+ let project_index_button = project_index.clone().map(|project_index| {
+ cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
+ });
+
Self {
model,
messages: Vec::new(),
@@ -275,6 +280,7 @@ impl AssistantChat {
list_state,
user_store,
language_registry,
+ project_index_button,
project_index,
next_message_id: MessageId(0),
editing_message: None,
@@ -397,7 +403,7 @@ impl AssistantChat {
{
this.tool_registry.definitions()
} else {
- &[]
+ Vec::new()
};
call_count += 1;
@@ -590,7 +596,7 @@ impl AssistantChat {
element.child(Composer::new(
body.clone(),
self.user_store.read(cx).current_user(),
- self.tool_registry.clone(),
+ self.project_index_button.clone(),
crate::ui::ModelSelector::new(
cx.view().downgrade(),
self.model.clone(),
@@ -768,7 +774,7 @@ impl Render for AssistantChat {
.child(Composer::new(
self.composer_editor.clone(),
self.user_store.read(cx).current_user(),
- self.tool_registry.clone(),
+ self.project_index_button.clone(),
crate::ui::ModelSelector::new(cx.view().downgrade(), self.model.clone())
.into_any_element(),
))
@@ -33,7 +33,7 @@ impl CompletionProvider {
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
- tools: &[ToolFunctionDefinition],
+ tools: Vec<ToolFunctionDefinition>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
{
self.0.complete(model, messages, stop, temperature, tools)
@@ -51,7 +51,7 @@ pub trait CompletionProviderBackend: 'static {
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
- tools: &[ToolFunctionDefinition],
+ tools: Vec<ToolFunctionDefinition>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
}
@@ -80,7 +80,7 @@ impl CompletionProviderBackend for CloudCompletionProvider {
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
- tools: &[ToolFunctionDefinition],
+ tools: Vec<ToolFunctionDefinition>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
{
let client = self.client.clone();
@@ -1,14 +1,17 @@
use anyhow::Result;
-use assistant_tooling::LanguageModelTool;
-use gpui::{percentage, prelude::*, Animation, AnimationExt, AnyView, Model, Task, Transformation};
+use assistant_tooling::{
+ // assistant_tool_button::{AssistantToolButton, ToolStatus},
+ LanguageModelTool,
+};
+use gpui::{prelude::*, Model, Task};
use project::Fs;
use schemars::JsonSchema;
use semantic_index::{ProjectIndex, Status};
use serde::Deserialize;
-use std::{sync::Arc, time::Duration};
+use std::sync::Arc;
use ui::{
- div, prelude::*, ButtonLike, CollapsibleContainer, Color, Icon, IconName, Indicator, Label,
- SharedString, Tooltip, WindowContext,
+ div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
+ WindowContext,
};
use util::ResultExt as _;
@@ -199,13 +202,6 @@ impl LanguageModelTool for ProjectIndexTool {
cx.new_view(|_cx| ProjectIndexView { input, output })
}
- fn status_view(&self, cx: &mut WindowContext) -> Option<AnyView> {
- Some(
- cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx))
- .into(),
- )
- }
-
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
match &output {
Ok(output) => {
@@ -236,82 +232,3 @@ impl LanguageModelTool for ProjectIndexTool {
}
}
}
-
-struct ProjectIndexStatusView {
- project_index: Model<ProjectIndex>,
-}
-
-impl ProjectIndexStatusView {
- pub fn new(project_index: Model<ProjectIndex>, cx: &mut ViewContext<Self>) -> Self {
- cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
- cx.notify();
- })
- .detach();
- Self { project_index }
- }
-}
-
-impl Render for ProjectIndexStatusView {
- fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- let status = self.project_index.read(cx).status();
-
- let is_enabled = match status {
- Status::Idle => true,
- _ => false,
- };
-
- let icon = match status {
- Status::Idle => Icon::new(IconName::Code)
- .size(IconSize::XSmall)
- .color(Color::Default),
- Status::Loading => Icon::new(IconName::Code)
- .size(IconSize::XSmall)
- .color(Color::Muted),
- Status::Scanning { .. } => Icon::new(IconName::Code)
- .size(IconSize::XSmall)
- .color(Color::Muted),
- };
-
- let indicator = match status {
- Status::Idle => Some(Indicator::dot().color(Color::Success)),
- Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
- Status::Loading => Some(Indicator::icon(
- Icon::new(IconName::Spinner)
- .color(Color::Accent)
- .with_animation(
- "arrow-circle",
- Animation::new(Duration::from_secs(2)).repeat(),
- |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
- ),
- )),
- };
-
- ButtonLike::new("project-index")
- .disabled(!is_enabled)
- .child(
- ui::IconWithIndicator::new(icon, indicator)
- .indicator_border_color(Some(gpui::transparent_black())),
- )
- .tooltip({
- move |cx| {
- let (tooltip, meta) = match status {
- Status::Idle => (
- "Project index ready".to_string(),
- Some("Click to disable".to_string()),
- ),
- Status::Loading => ("Project index loading...".to_string(), None),
- Status::Scanning { remaining_count } => (
- "Project index scanning...".to_string(),
- Some(format!("{} remaining...", remaining_count)),
- ),
- };
-
- if let Some(meta) = meta {
- Tooltip::with_meta(tooltip, None, meta, cx)
- } else {
- Tooltip::text(tooltip, cx)
- }
- }
- })
- }
-}
@@ -1,6 +1,7 @@
mod chat_message;
mod chat_notice;
mod composer;
+mod project_index_button;
#[cfg(feature = "stories")]
mod stories;
@@ -8,6 +9,7 @@ mod stories;
pub use chat_message::*;
pub use chat_notice::*;
pub use composer::*;
+pub use project_index_button::*;
#[cfg(feature = "stories")]
pub use stories::*;
@@ -1,4 +1,4 @@
-use assistant_tooling::ToolRegistry;
+use crate::{ui::ProjectIndexButton, AssistantChat, CompletionProvider};
use client::User;
use editor::{Editor, EditorElement, EditorStyle};
use gpui::{AnyElement, FontStyle, FontWeight, TextStyle, View, WeakView, WhiteSpace};
@@ -7,13 +7,11 @@ use std::sync::Arc;
use theme::ThemeSettings;
use ui::{popover_menu, prelude::*, Avatar, ButtonLike, ContextMenu, Tooltip};
-use crate::{AssistantChat, CompletionProvider};
-
#[derive(IntoElement)]
pub struct Composer {
editor: View<Editor>,
player: Option<Arc<User>>,
- tool_registry: Arc<ToolRegistry>,
+ project_index_button: Option<View<ProjectIndexButton>>,
model_selector: AnyElement,
}
@@ -21,20 +19,28 @@ impl Composer {
pub fn new(
editor: View<Editor>,
player: Option<Arc<User>>,
- tool_registry: Arc<ToolRegistry>,
+ project_index_button: Option<View<ProjectIndexButton>>,
model_selector: AnyElement,
) -> Self {
Self {
editor,
player,
- tool_registry,
+ project_index_button,
model_selector,
}
}
+
+ fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
+ h_flex().children(
+ self.project_index_button
+ .clone()
+ .map(|view| view.into_any_element()),
+ )
+ }
}
impl RenderOnce for Composer {
- fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+ fn render(mut self, cx: &mut WindowContext) -> impl IntoElement {
let mut player_avatar = div().size(rems_from_px(20.)).into_any_element();
if let Some(player) = self.player.clone() {
player_avatar = Avatar::new(player.avatar_uri.clone())
@@ -95,9 +101,7 @@ impl RenderOnce for Composer {
.gap_2()
.justify_between()
.w_full()
- .child(h_flex().gap_1().children(
- self.tool_registry.status_views().iter().cloned(),
- ))
+ .child(h_flex().gap_1().child(self.render_tools(cx)))
.child(h_flex().gap_1().child(self.model_selector)),
),
),
@@ -0,0 +1,109 @@
+use assistant_tooling::ToolRegistry;
+use gpui::{percentage, prelude::*, Animation, AnimationExt, Model, Transformation};
+use semantic_index::{ProjectIndex, Status};
+use std::{sync::Arc, time::Duration};
+use ui::{prelude::*, ButtonLike, Color, Icon, IconName, Indicator, Tooltip};
+
+use crate::tools::ProjectIndexTool;
+
+pub struct ProjectIndexButton {
+ project_index: Model<ProjectIndex>,
+ tool_registry: Arc<ToolRegistry>,
+}
+
+impl ProjectIndexButton {
+ pub fn new(
+ project_index: Model<ProjectIndex>,
+ tool_registry: Arc<ToolRegistry>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
+ cx.notify();
+ })
+ .detach();
+ Self {
+ project_index,
+ tool_registry,
+ }
+ }
+
+ pub fn set_enabled(&mut self, enabled: bool) {
+ self.tool_registry
+ .set_tool_enabled::<ProjectIndexTool>(enabled);
+ }
+}
+
+impl Render for ProjectIndexButton {
+ // Expanded information on ToolView
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let status = self.project_index.read(cx).status();
+ let is_enabled = self.tool_registry.is_tool_enabled::<ProjectIndexTool>();
+
+ let icon = if is_enabled {
+ match status {
+ Status::Idle => Icon::new(IconName::Code)
+ .size(IconSize::XSmall)
+ .color(Color::Default),
+ Status::Loading => Icon::new(IconName::Code)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ Status::Scanning { .. } => Icon::new(IconName::Code)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ }
+ } else {
+ Icon::new(IconName::Code)
+ .size(IconSize::XSmall)
+ .color(Color::Disabled)
+ };
+
+ let indicator = if is_enabled {
+ match status {
+ Status::Idle => Some(Indicator::dot().color(Color::Success)),
+ Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
+ Status::Loading => Some(Indicator::icon(
+ Icon::new(IconName::Spinner)
+ .color(Color::Accent)
+ .with_animation(
+ "arrow-circle",
+ Animation::new(Duration::from_secs(2)).repeat(),
+ |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
+ ),
+ )),
+ }
+ } else {
+ None
+ };
+
+ ButtonLike::new("project-index")
+ .child(
+ ui::IconWithIndicator::new(icon, indicator)
+ .indicator_border_color(Some(gpui::transparent_black())),
+ )
+ .tooltip({
+ move |cx| {
+ let (tooltip, meta) = match status {
+ Status::Idle => (
+ "Project index ready".to_string(),
+ Some("Click to disable".to_string()),
+ ),
+ Status::Loading => ("Project index loading...".to_string(), None),
+ Status::Scanning { remaining_count } => (
+ "Project index scanning...".to_string(),
+ Some(format!("{} remaining...", remaining_count)),
+ ),
+ };
+
+ if let Some(meta) = meta {
+ Tooltip::with_meta(tooltip, None, meta, cx)
+ } else {
+ Tooltip::text(tooltip, cx)
+ }
+ }
+ })
+ .on_click(cx.listener(move |this, _, cx| {
+ this.set_enabled(!is_enabled);
+ cx.notify();
+ }))
+ }
+}
@@ -1,48 +1,86 @@
use anyhow::{anyhow, Result};
-use gpui::{AnyView, Task, WindowContext};
-use std::collections::HashMap;
+use gpui::{Task, WindowContext};
+use std::{
+ any::TypeId,
+ collections::HashMap,
+ sync::atomic::{AtomicBool, Ordering::SeqCst},
+};
use crate::tool::{
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
};
+// Internal Tool representation for the registry
+pub struct Tool {
+ enabled: AtomicBool,
+ type_id: TypeId,
+ call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+ definition: ToolFunctionDefinition,
+}
+
+impl Tool {
+ fn new(
+ type_id: TypeId,
+ call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+ definition: ToolFunctionDefinition,
+ ) -> Self {
+ Self {
+ enabled: AtomicBool::new(true),
+ type_id,
+ call,
+ definition,
+ }
+ }
+}
+
pub struct ToolRegistry {
- tools: HashMap<
- String,
- Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
- >,
- definitions: Vec<ToolFunctionDefinition>,
- status_views: Vec<AnyView>,
+ tools: HashMap<String, Tool>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
- definitions: Vec::new(),
- status_views: Vec::new(),
}
}
- pub fn definitions(&self) -> &[ToolFunctionDefinition] {
- &self.definitions
+ pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
+ for tool in self.tools.values() {
+ if tool.type_id == TypeId::of::<T>() {
+ tool.enabled.store(is_enabled, SeqCst);
+ return;
+ }
+ }
+ }
+
+ pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
+ for tool in self.tools.values() {
+ if tool.type_id == TypeId::of::<T>() {
+ return tool.enabled.load(SeqCst);
+ }
+ }
+ false
+ }
+
+ pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
+ self.tools
+ .values()
+ .filter(|tool| tool.enabled.load(SeqCst))
+ .map(|tool| tool.definition.clone())
+ .collect()
}
pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
- cx: &mut WindowContext,
+ _cx: &mut WindowContext,
) -> Result<()> {
- self.definitions.push(tool.definition());
-
- if let Some(tool_view) = tool.status_view(cx) {
- self.status_views.push(tool_view);
- }
+ let definition = tool.definition();
let name = tool.name();
- let previous = self.tools.insert(
- name.clone(),
- // registry.call(tool_call, cx)
+
+ let registered_tool = Tool::new(
+ TypeId::of::<T>(),
Box::new(
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let name = tool_call.name.clone();
@@ -77,8 +115,11 @@ impl ToolRegistry {
})
},
),
+ definition,
);
+ let previous = self.tools.insert(name.clone(), registered_tool);
+
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
@@ -109,11 +150,7 @@ impl ToolRegistry {
}
};
- tool(tool_call, cx)
- }
-
- pub fn status_views(&self) -> &[AnyView] {
- &self.status_views
+ (tool.call)(tool_call, cx)
}
}
@@ -104,8 +104,4 @@ pub trait LanguageModelTool {
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View>;
-
- fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
- None
- }
}