Detailed changes
@@ -59,6 +59,7 @@ dependencies = [
"assistant_slash_command",
"assistant_slash_commands",
"assistant_tool",
+ "assistant_tools",
"async-watch",
"audio",
"buffer_diff",
@@ -147,7 +148,6 @@ dependencies = [
"deepseek",
"fs",
"gpui",
- "indexmap",
"language_model",
"lmstudio",
"log",
@@ -2987,7 +2987,6 @@ dependencies = [
"anyhow",
"assistant_context_editor",
"assistant_slash_command",
- "assistant_tool",
"async-stripe",
"async-trait",
"async-tungstenite",
@@ -771,7 +771,6 @@
"tools": {
"copy_path": true,
"create_directory": true,
- "create_file": true,
"delete_path": true,
"diagnostics": true,
"edit_file": true,
@@ -102,6 +102,7 @@ zed_llm_client.workspace = true
zstd.workspace = true
[dev-dependencies]
+assistant_tools.workspace = true
buffer_diff = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
@@ -1144,6 +1144,10 @@ impl ActiveThread {
cx,
);
}
+ ThreadEvent::ProfileChanged => {
+ self.save_thread(cx);
+ cx.notify();
+ }
}
}
@@ -3,6 +3,7 @@ mod agent_configuration;
mod agent_diff;
mod agent_model_selector;
mod agent_panel;
+mod agent_profile;
mod buffer_codegen;
mod context;
mod context_picker;
@@ -2,25 +2,21 @@ mod profile_modal_header;
use std::sync::Arc;
-use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, builtin_profiles};
+use agent_settings::{AgentProfileId, AgentSettings, builtin_profiles};
use assistant_tool::ToolWorkingSet;
-use convert_case::{Case, Casing as _};
use editor::Editor;
use fs::Fs;
-use gpui::{
- DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, WeakEntity,
- prelude::*,
-};
-use settings::{Settings as _, update_settings_file};
+use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, prelude::*};
+use settings::Settings as _;
use ui::{
KeyBinding, ListItem, ListItemSpacing, ListSeparator, Navigable, NavigableEntry, prelude::*,
};
-use util::ResultExt as _;
use workspace::{ModalView, Workspace};
use crate::agent_configuration::manage_profiles_modal::profile_modal_header::ProfileModalHeader;
use crate::agent_configuration::tool_picker::{ToolPicker, ToolPickerDelegate};
-use crate::{AgentPanel, ManageProfiles, ThreadStore};
+use crate::agent_profile::AgentProfile;
+use crate::{AgentPanel, ManageProfiles};
use super::tool_picker::ToolPickerMode;
@@ -103,7 +99,6 @@ pub struct NewProfileMode {
pub struct ManageProfilesModal {
fs: Arc<dyn Fs>,
tools: Entity<ToolWorkingSet>,
- thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
mode: Mode,
}
@@ -119,9 +114,8 @@ impl ManageProfilesModal {
let fs = workspace.app_state().fs.clone();
let thread_store = panel.read(cx).thread_store();
let tools = thread_store.read(cx).tools();
- let thread_store = thread_store.downgrade();
workspace.toggle_modal(window, cx, |window, cx| {
- let mut this = Self::new(fs, tools, thread_store, window, cx);
+ let mut this = Self::new(fs, tools, window, cx);
if let Some(profile_id) = action.customize_tools.clone() {
this.configure_builtin_tools(profile_id, window, cx);
@@ -136,7 +130,6 @@ impl ManageProfilesModal {
pub fn new(
fs: Arc<dyn Fs>,
tools: Entity<ToolWorkingSet>,
- thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -145,7 +138,6 @@ impl ManageProfilesModal {
Self {
fs,
tools,
- thread_store,
focus_handle,
mode: Mode::choose_profile(window, cx),
}
@@ -206,7 +198,6 @@ impl ManageProfilesModal {
ToolPickerMode::McpTools,
self.fs.clone(),
self.tools.clone(),
- self.thread_store.clone(),
profile_id.clone(),
profile,
cx,
@@ -244,7 +235,6 @@ impl ManageProfilesModal {
ToolPickerMode::BuiltinTools,
self.fs.clone(),
self.tools.clone(),
- self.thread_store.clone(),
profile_id.clone(),
profile,
cx,
@@ -270,32 +260,10 @@ impl ManageProfilesModal {
match &self.mode {
Mode::ChooseProfile { .. } => {}
Mode::NewProfile(mode) => {
- let settings = AgentSettings::get_global(cx);
-
- let base_profile = mode
- .base_profile_id
- .as_ref()
- .and_then(|profile_id| settings.profiles.get(profile_id).cloned());
-
let name = mode.name_editor.read(cx).text(cx);
- let profile_id = AgentProfileId(name.to_case(Case::Kebab).into());
-
- let profile = AgentProfile {
- name: name.into(),
- tools: base_profile
- .as_ref()
- .map(|profile| profile.tools.clone())
- .unwrap_or_default(),
- enable_all_context_servers: base_profile
- .as_ref()
- .map(|profile| profile.enable_all_context_servers)
- .unwrap_or_default(),
- context_servers: base_profile
- .map(|profile| profile.context_servers)
- .unwrap_or_default(),
- };
-
- self.create_profile(profile_id.clone(), profile, cx);
+
+ let profile_id =
+ AgentProfile::create(name, mode.base_profile_id.clone(), self.fs.clone(), cx);
self.view_profile(profile_id, window, cx);
}
Mode::ViewProfile(_) => {}
@@ -325,19 +293,6 @@ impl ManageProfilesModal {
}
}
}
-
- fn create_profile(
- &self,
- profile_id: AgentProfileId,
- profile: AgentProfile,
- cx: &mut Context<Self>,
- ) {
- update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
- move |settings, _cx| {
- settings.create_profile(profile_id, profile).log_err();
- }
- });
- }
}
impl ModalView for ManageProfilesModal {}
@@ -520,14 +475,13 @@ impl ManageProfilesModal {
) -> impl IntoElement {
let settings = AgentSettings::get_global(cx);
- let profile_id = &settings.default_profile;
let profile_name = settings
.profiles
.get(&mode.profile_id)
.map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into());
- let icon = match profile_id.as_str() {
+ let icon = match mode.profile_id.as_str() {
"write" => IconName::Pencil,
"ask" => IconName::MessageBubbles,
_ => IconName::UserRoundPen,
@@ -1,19 +1,17 @@
use std::{collections::BTreeMap, sync::Arc};
use agent_settings::{
- AgentProfile, AgentProfileContent, AgentProfileId, AgentSettings, AgentSettingsContent,
+ AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent,
ContextServerPresetContent,
};
use assistant_tool::{ToolSource, ToolWorkingSet};
use fs::Fs;
use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
use picker::{Picker, PickerDelegate};
-use settings::{Settings as _, update_settings_file};
+use settings::update_settings_file;
use ui::{ListItem, ListItemSpacing, prelude::*};
use util::ResultExt as _;
-use crate::ThreadStore;
-
pub struct ToolPicker {
picker: Entity<Picker<ToolPickerDelegate>>,
}
@@ -71,11 +69,10 @@ pub enum PickerItem {
pub struct ToolPickerDelegate {
tool_picker: WeakEntity<ToolPicker>,
- thread_store: WeakEntity<ThreadStore>,
fs: Arc<dyn Fs>,
items: Arc<Vec<PickerItem>>,
profile_id: AgentProfileId,
- profile: AgentProfile,
+ profile_settings: AgentProfileSettings,
filtered_items: Vec<PickerItem>,
selected_index: usize,
mode: ToolPickerMode,
@@ -86,20 +83,18 @@ impl ToolPickerDelegate {
mode: ToolPickerMode,
fs: Arc<dyn Fs>,
tool_set: Entity<ToolWorkingSet>,
- thread_store: WeakEntity<ThreadStore>,
profile_id: AgentProfileId,
- profile: AgentProfile,
+ profile_settings: AgentProfileSettings,
cx: &mut Context<ToolPicker>,
) -> Self {
let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
Self {
tool_picker: cx.entity().downgrade(),
- thread_store,
fs,
items,
profile_id,
- profile,
+ profile_settings,
filtered_items: Vec::new(),
selected_index: 0,
mode,
@@ -249,28 +244,31 @@ impl PickerDelegate for ToolPickerDelegate {
};
let is_currently_enabled = if let Some(server_id) = server_id.clone() {
- let preset = self.profile.context_servers.entry(server_id).or_default();
+ let preset = self
+ .profile_settings
+ .context_servers
+ .entry(server_id)
+ .or_default();
let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
*preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
is_enabled
} else {
- let is_enabled = *self.profile.tools.entry(tool_name.clone()).or_default();
- *self.profile.tools.entry(tool_name.clone()).or_default() = !is_enabled;
+ let is_enabled = *self
+ .profile_settings
+ .tools
+ .entry(tool_name.clone())
+ .or_default();
+ *self
+ .profile_settings
+ .tools
+ .entry(tool_name.clone())
+ .or_default() = !is_enabled;
is_enabled
};
- let active_profile_id = &AgentSettings::get_global(cx).default_profile;
- if active_profile_id == &self.profile_id {
- self.thread_store
- .update(cx, |this, cx| {
- this.load_profile(self.profile.clone(), cx);
- })
- .log_err();
- }
-
update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
let profile_id = self.profile_id.clone();
- let default_profile = self.profile.clone();
+ let default_profile = self.profile_settings.clone();
let server_id = server_id.clone();
let tool_name = tool_name.clone();
move |settings: &mut AgentSettingsContent, _cx| {
@@ -348,14 +346,18 @@ impl PickerDelegate for ToolPickerDelegate {
),
PickerItem::Tool { name, server_id } => {
let is_enabled = if let Some(server_id) = server_id {
- self.profile
+ self.profile_settings
.context_servers
.get(server_id.as_ref())
.and_then(|preset| preset.tools.get(name))
.copied()
- .unwrap_or(self.profile.enable_all_context_servers)
+ .unwrap_or(self.profile_settings.enable_all_context_servers)
} else {
- self.profile.tools.get(name).copied().unwrap_or(false)
+ self.profile_settings
+ .tools
+ .get(name)
+ .copied()
+ .unwrap_or(false)
};
Some(
@@ -1378,7 +1378,8 @@ impl AgentDiff {
| ThreadEvent::CheckpointChanged
| ThreadEvent::ToolConfirmationNeeded
| ThreadEvent::ToolUseLimitReached
- | ThreadEvent::CancelEditing => {}
+ | ThreadEvent::CancelEditing
+ | ThreadEvent::ProfileChanged => {}
}
}
@@ -0,0 +1,334 @@
+use std::sync::Arc;
+
+use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
+use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
+use collections::IndexMap;
+use convert_case::{Case, Casing};
+use fs::Fs;
+use gpui::{App, Entity};
+use settings::{Settings, update_settings_file};
+use ui::SharedString;
+use util::ResultExt;
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct AgentProfile {
+ id: AgentProfileId,
+ tool_set: Entity<ToolWorkingSet>,
+}
+
+pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
+
+impl AgentProfile {
+ pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
+ Self { id, tool_set }
+ }
+
+ /// Saves a new profile to the settings.
+ pub fn create(
+ name: String,
+ base_profile_id: Option<AgentProfileId>,
+ fs: Arc<dyn Fs>,
+ cx: &App,
+ ) -> AgentProfileId {
+ let id = AgentProfileId(name.to_case(Case::Kebab).into());
+
+ let base_profile =
+ base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
+
+ let profile_settings = AgentProfileSettings {
+ name: name.into(),
+ tools: base_profile
+ .as_ref()
+ .map(|profile| profile.tools.clone())
+ .unwrap_or_default(),
+ enable_all_context_servers: base_profile
+ .as_ref()
+ .map(|profile| profile.enable_all_context_servers)
+ .unwrap_or_default(),
+ context_servers: base_profile
+ .map(|profile| profile.context_servers)
+ .unwrap_or_default(),
+ };
+
+ update_settings_file::<AgentSettings>(fs, cx, {
+ let id = id.clone();
+ move |settings, _cx| {
+ settings.create_profile(id, profile_settings).log_err();
+ }
+ });
+
+ id
+ }
+
+ /// Returns a map of AgentProfileIds to their names
+ pub fn available_profiles(cx: &App) -> AvailableProfiles {
+ let mut profiles = AvailableProfiles::default();
+ for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
+ profiles.insert(id.clone(), profile.name.clone());
+ }
+ profiles
+ }
+
+ pub fn id(&self) -> &AgentProfileId {
+ &self.id
+ }
+
+ pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
+ let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
+ return Vec::new();
+ };
+
+ self.tool_set
+ .read(cx)
+ .tools(cx)
+ .into_iter()
+ .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
+ .collect()
+ }
+
+ fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
+ match source {
+ ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
+ ToolSource::ContextServer { id } => {
+ if settings.enable_all_context_servers {
+ return true;
+ }
+
+ let Some(preset) = settings.context_servers.get(id.as_ref()) else {
+ return false;
+ };
+ *preset.tools.get(name.as_str()).unwrap_or(&false)
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use agent_settings::ContextServerPreset;
+ use assistant_tool::ToolRegistry;
+ use collections::IndexMap;
+ use gpui::{AppContext, TestAppContext};
+ use http_client::FakeHttpClient;
+ use project::Project;
+ use settings::{Settings, SettingsStore};
+ use ui::SharedString;
+
+ use super::*;
+
+ #[gpui::test]
+ async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let id = AgentProfileId::default();
+ let profile_settings = cx.read(|cx| {
+ AgentSettings::get_global(cx)
+ .profiles
+ .get(&id)
+ .unwrap()
+ .clone()
+ });
+ let tool_set = default_tool_set(cx);
+
+ let profile = AgentProfile::new(id.clone(), tool_set);
+
+ let mut enabled_tools = cx
+ .read(|cx| profile.enabled_tools(cx))
+ .into_iter()
+ .map(|tool| tool.name())
+ .collect::<Vec<_>>();
+ enabled_tools.sort();
+
+ let mut expected_tools = profile_settings
+ .tools
+ .into_iter()
+ .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string()))
+ // Provider dependent
+ .filter(|tool| tool != "web_search")
+ .collect::<Vec<_>>();
+ // Plus all registered MCP tools
+ expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]);
+ expected_tools.sort();
+
+ assert_eq!(enabled_tools, expected_tools);
+ }
+
+ #[gpui::test]
+ async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let id = AgentProfileId("custom_mcp".into());
+ let profile_settings = cx.read(|cx| {
+ AgentSettings::get_global(cx)
+ .profiles
+ .get(&id)
+ .unwrap()
+ .clone()
+ });
+ let tool_set = default_tool_set(cx);
+
+ let profile = AgentProfile::new(id.clone(), tool_set);
+
+ let mut enabled_tools = cx
+ .read(|cx| profile.enabled_tools(cx))
+ .into_iter()
+ .map(|tool| tool.name())
+ .collect::<Vec<_>>();
+ enabled_tools.sort();
+
+ let mut expected_tools = profile_settings.context_servers["mcp"]
+ .tools
+ .iter()
+ .filter_map(|(key, enabled)| enabled.then(|| key.to_string()))
+ .collect::<Vec<_>>();
+ expected_tools.sort();
+
+ assert_eq!(enabled_tools, expected_tools);
+ }
+
+ #[gpui::test]
+ async fn test_only_built_in(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let id = AgentProfileId("write_minus_mcp".into());
+ let profile_settings = cx.read(|cx| {
+ AgentSettings::get_global(cx)
+ .profiles
+ .get(&id)
+ .unwrap()
+ .clone()
+ });
+ let tool_set = default_tool_set(cx);
+
+ let profile = AgentProfile::new(id.clone(), tool_set);
+
+ let mut enabled_tools = cx
+ .read(|cx| profile.enabled_tools(cx))
+ .into_iter()
+ .map(|tool| tool.name())
+ .collect::<Vec<_>>();
+ enabled_tools.sort();
+
+ let mut expected_tools = profile_settings
+ .tools
+ .into_iter()
+ .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string()))
+ // Provider dependent
+ .filter(|tool| tool != "web_search")
+ .collect::<Vec<_>>();
+ expected_tools.sort();
+
+ assert_eq!(enabled_tools, expected_tools);
+ }
+
+ fn init_test_settings(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ AgentSettings::register(cx);
+ language_model::init_settings(cx);
+ ToolRegistry::default_global(cx);
+ assistant_tools::init(FakeHttpClient::with_404_response(), cx);
+ });
+
+ cx.update(|cx| {
+ let mut agent_settings = AgentSettings::get_global(cx).clone();
+ agent_settings.profiles.insert(
+ AgentProfileId("write_minus_mcp".into()),
+ AgentProfileSettings {
+ name: "write_minus_mcp".into(),
+ enable_all_context_servers: false,
+ ..agent_settings.profiles[&AgentProfileId::default()].clone()
+ },
+ );
+ agent_settings.profiles.insert(
+ AgentProfileId("custom_mcp".into()),
+ AgentProfileSettings {
+ name: "mcp".into(),
+ tools: IndexMap::default(),
+ enable_all_context_servers: false,
+ context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
+ },
+ );
+ AgentSettings::override_global(agent_settings, cx);
+ })
+ }
+
+ fn context_server_preset() -> ContextServerPreset {
+ ContextServerPreset {
+ tools: IndexMap::from_iter([
+ ("enabled_mcp_tool".into(), true),
+ ("disabled_mcp_tool".into(), false),
+ ]),
+ }
+ }
+
+ fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
+ cx.new(|_| {
+ let mut tool_set = ToolWorkingSet::default();
+ tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
+ tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
+ tool_set
+ })
+ }
+
+ struct FakeTool {
+ name: String,
+ source: SharedString,
+ }
+
+ impl FakeTool {
+ fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
+ Self {
+ name: name.into(),
+ source: source.into(),
+ }
+ }
+ }
+
+ impl Tool for FakeTool {
+ fn name(&self) -> String {
+ self.name.clone()
+ }
+
+ fn source(&self) -> ToolSource {
+ ToolSource::ContextServer {
+ id: self.source.clone(),
+ }
+ }
+
+ fn description(&self) -> String {
+ unimplemented!()
+ }
+
+ fn icon(&self) -> ui::IconName {
+ unimplemented!()
+ }
+
+ fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
+ unimplemented!()
+ }
+
+ fn ui_text(&self, _input: &serde_json::Value) -> String {
+ unimplemented!()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ _input: serde_json::Value,
+ _request: Arc<language_model::LanguageModelRequest>,
+ _project: Entity<Project>,
+ _action_log: Entity<assistant_tool::ActionLog>,
+ _model: Arc<dyn language_model::LanguageModel>,
+ _window: Option<gpui::AnyWindowHandle>,
+ _cx: &mut App,
+ ) -> assistant_tool::ToolResult {
+ unimplemented!()
+ }
+
+ fn may_perform_edits(&self) -> bool {
+ unimplemented!()
+ }
+ }
+}
@@ -175,8 +175,7 @@ impl MessageEditor {
)
});
- let incompatible_tools =
- cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx));
+ let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.clone(), cx));
let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
@@ -204,15 +203,8 @@ impl MessageEditor {
)
});
- let profile_selector = cx.new(|cx| {
- ProfileSelector::new(
- fs,
- thread.clone(),
- thread_store,
- editor.focus_handle(cx),
- cx,
- )
- });
+ let profile_selector =
+ cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx));
Self {
editor: editor.clone(),
@@ -1,26 +1,24 @@
use std::sync::Arc;
-use agent_settings::{
- AgentDockPosition, AgentProfile, AgentProfileId, AgentSettings, GroupedAgentProfiles,
- builtin_profiles,
-};
+use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles};
use fs::Fs;
-use gpui::{Action, Empty, Entity, FocusHandle, Subscription, WeakEntity, prelude::*};
+use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*};
use language_model::LanguageModelRegistry;
use settings::{Settings as _, SettingsStore, update_settings_file};
use ui::{
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip,
prelude::*,
};
-use util::ResultExt as _;
-use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector};
+use crate::{
+ ManageProfiles, Thread, ToggleProfileSelector,
+ agent_profile::{AgentProfile, AvailableProfiles},
+};
pub struct ProfileSelector {
- profiles: GroupedAgentProfiles,
+ profiles: AvailableProfiles,
fs: Arc<dyn Fs>,
thread: Entity<Thread>,
- thread_store: WeakEntity<ThreadStore>,
menu_handle: PopoverMenuHandle<ContextMenu>,
focus_handle: FocusHandle,
_subscriptions: Vec<Subscription>,
@@ -30,7 +28,6 @@ impl ProfileSelector {
pub fn new(
fs: Arc<dyn Fs>,
thread: Entity<Thread>,
- thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
cx: &mut Context<Self>,
) -> Self {
@@ -39,10 +36,9 @@ impl ProfileSelector {
});
Self {
- profiles: GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx)),
+ profiles: AgentProfile::available_profiles(cx),
fs,
thread,
- thread_store,
menu_handle: PopoverMenuHandle::default(),
focus_handle,
_subscriptions: vec![settings_subscription],
@@ -54,7 +50,7 @@ impl ProfileSelector {
}
fn refresh_profiles(&mut self, cx: &mut Context<Self>) {
- self.profiles = GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx));
+ self.profiles = AgentProfile::available_profiles(cx);
}
fn build_context_menu(
@@ -64,21 +60,30 @@ impl ProfileSelector {
) -> Entity<ContextMenu> {
ContextMenu::build(window, cx, |mut menu, _window, cx| {
let settings = AgentSettings::get_global(cx);
- for (profile_id, profile) in self.profiles.builtin.iter() {
+
+ let mut found_non_builtin = false;
+ for (profile_id, profile_name) in self.profiles.iter() {
+ if !builtin_profiles::is_builtin(profile_id) {
+ found_non_builtin = true;
+ continue;
+ }
menu = menu.item(self.menu_entry_for_profile(
profile_id.clone(),
- profile,
+ profile_name,
settings,
cx,
));
}
- if !self.profiles.custom.is_empty() {
+ if found_non_builtin {
menu = menu.separator().header("Custom Profiles");
- for (profile_id, profile) in self.profiles.custom.iter() {
+ for (profile_id, profile_name) in self.profiles.iter() {
+ if builtin_profiles::is_builtin(profile_id) {
+ continue;
+ }
menu = menu.item(self.menu_entry_for_profile(
profile_id.clone(),
- profile,
+ profile_name,
settings,
cx,
));
@@ -99,19 +104,20 @@ impl ProfileSelector {
fn menu_entry_for_profile(
&self,
profile_id: AgentProfileId,
- profile: &AgentProfile,
+ profile_name: &SharedString,
settings: &AgentSettings,
- _cx: &App,
+ cx: &App,
) -> ContextMenuEntry {
- let documentation = match profile.name.to_lowercase().as_str() {
+ let documentation = match profile_name.to_lowercase().as_str() {
builtin_profiles::WRITE => Some("Get help to write anything."),
builtin_profiles::ASK => Some("Chat about your codebase."),
builtin_profiles::MINIMAL => Some("Chat about anything with no tools."),
_ => None,
};
+ let thread_profile_id = self.thread.read(cx).profile().id();
- let entry = ContextMenuEntry::new(profile.name.clone())
- .toggleable(IconPosition::End, profile_id == settings.default_profile);
+ let entry = ContextMenuEntry::new(profile_name.clone())
+ .toggleable(IconPosition::End, &profile_id == thread_profile_id);
let entry = if let Some(doc_text) = documentation {
entry.documentation_aside(documentation_side(settings.dock), move |_| {
@@ -123,7 +129,7 @@ impl ProfileSelector {
entry.handler({
let fs = self.fs.clone();
- let thread_store = self.thread_store.clone();
+ let thread = self.thread.clone();
let profile_id = profile_id.clone();
move |_window, cx| {
update_settings_file::<AgentSettings>(fs.clone(), cx, {
@@ -133,11 +139,9 @@ impl ProfileSelector {
}
});
- thread_store
- .update(cx, |this, cx| {
- this.load_profile_by_id(profile_id.clone(), cx);
- })
- .log_err();
+ thread.update(cx, |this, cx| {
+ this.set_profile(profile_id.clone(), cx);
+ });
}
})
}
@@ -146,7 +150,7 @@ impl ProfileSelector {
impl Render for ProfileSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AgentSettings::get_global(cx);
- let profile_id = &settings.default_profile;
+ let profile_id = self.thread.read(cx).profile().id();
let profile = settings.profiles.get(profile_id);
let selected_profile = profile
@@ -4,7 +4,7 @@ use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;
-use agent_settings::{AgentSettings, CompletionMode};
+use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
@@ -41,6 +41,7 @@ use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
use crate::ThreadStore;
+use crate::agent_profile::AgentProfile;
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
use crate::thread_store::{
SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
@@ -360,6 +361,7 @@ pub struct Thread {
>,
remaining_turns: u32,
configured_model: Option<ConfiguredModel>,
+ profile: AgentProfile,
}
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -407,6 +409,7 @@ impl Thread {
) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
+ let profile_id = AgentSettings::get_global(cx).default_profile.clone();
Self {
id: ThreadId::new(),
@@ -449,6 +452,7 @@ impl Thread {
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
+ profile: AgentProfile::new(profile_id, tools),
}
}
@@ -495,6 +499,9 @@ impl Thread {
let completion_mode = serialized
.completion_mode
.unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
+ let profile_id = serialized
+ .profile
+ .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
Self {
id,
@@ -554,7 +561,7 @@ impl Thread {
pending_checkpoint: None,
project: project.clone(),
prompt_builder,
- tools,
+ tools: tools.clone(),
tool_use,
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
@@ -570,6 +577,7 @@ impl Thread {
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
+ profile: AgentProfile::new(profile_id, tools),
}
}
@@ -585,6 +593,17 @@ impl Thread {
&self.id
}
+ pub fn profile(&self) -> &AgentProfile {
+ &self.profile
+ }
+
+ pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
+ if &id != self.profile.id() {
+ self.profile = AgentProfile::new(id, self.tools.clone());
+ cx.emit(ThreadEvent::ProfileChanged);
+ }
+ }
+
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
@@ -919,8 +938,7 @@ impl Thread {
model: Arc<dyn LanguageModel>,
) -> Vec<LanguageModelRequestTool> {
if model.supports_tools() {
- self.tools()
- .read(cx)
+ self.profile
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
@@ -1180,6 +1198,7 @@ impl Thread {
}),
completion_mode: Some(this.completion_mode),
tool_use_limit_reached: this.tool_use_limit_reached,
+ profile: Some(this.profile.id().clone()),
})
})
}
@@ -2121,7 +2140,7 @@ impl Thread {
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) {
- let available_tools = self.tools.read(cx).enabled_tools(cx);
+ let available_tools = self.profile.enabled_tools(cx);
let tool_list = available_tools
.iter()
@@ -2213,19 +2232,15 @@ impl Thread {
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
- let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
- Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
- } else {
- tool.run(
- input,
- request,
- self.project.clone(),
- self.action_log.clone(),
- model,
- window,
- cx,
- )
- };
+ let tool_result = tool.run(
+ input,
+ request,
+ self.project.clone(),
+ self.action_log.clone(),
+ model,
+ window,
+ cx,
+ );
// Store the card separately if it exists
if let Some(card) = tool_result.card.clone() {
@@ -2344,8 +2359,7 @@ impl Thread {
let client = self.project.read(cx).client();
let enabled_tool_names: Vec<String> = self
- .tools()
- .read(cx)
+ .profile
.enabled_tools(cx)
.iter()
.map(|tool| tool.name())
@@ -2858,6 +2872,7 @@ pub enum ThreadEvent {
ToolUseLimitReached,
CancelEditing,
CompletionCanceled,
+ ProfileChanged,
}
impl EventEmitter<ThreadEvent> for Thread {}
@@ -2872,7 +2887,7 @@ struct PendingCompletion {
mod tests {
use super::*;
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
- use agent_settings::{AgentSettings, LanguageModelParameters};
+ use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
use assistant_tool::ToolRegistry;
use editor::EditorSettings;
use gpui::TestAppContext;
@@ -3285,6 +3300,71 @@ fn main() {{
);
}
+ #[gpui::test]
+ async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let project = create_test_project(
+ cx,
+ json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
+ )
+ .await;
+
+ let (_workspace, thread_store, thread, _context_store, _model) =
+ setup_test_environment(cx, project.clone()).await;
+
+ // Check that we are starting with the default profile
+ let profile = cx.read(|cx| thread.read(cx).profile.clone());
+ let tool_set = cx.read(|cx| thread_store.read(cx).tools());
+ assert_eq!(
+ profile,
+ AgentProfile::new(AgentProfileId::default(), tool_set)
+ );
+ }
+
+ #[gpui::test]
+ async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
+ init_test_settings(cx);
+
+ let project = create_test_project(
+ cx,
+ json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
+ )
+ .await;
+
+ let (_workspace, thread_store, thread, _context_store, _model) =
+ setup_test_environment(cx, project.clone()).await;
+
+ // Profile gets serialized with default values
+ let serialized = thread
+ .update(cx, |thread, cx| thread.serialize(cx))
+ .await
+ .unwrap();
+
+ assert_eq!(serialized.profile, Some(AgentProfileId::default()));
+
+ let deserialized = cx.update(|cx| {
+ thread.update(cx, |thread, cx| {
+ Thread::deserialize(
+ thread.id.clone(),
+ serialized,
+ thread.project.clone(),
+ thread.tools.clone(),
+ thread.prompt_builder.clone(),
+ thread.project_context.clone(),
+ None,
+ cx,
+ )
+ })
+ });
+ let tool_set = cx.read(|cx| thread_store.read(cx).tools());
+
+ assert_eq!(
+ deserialized.profile,
+ AgentProfile::new(AgentProfileId::default(), tool_set)
+ );
+ }
+
#[gpui::test]
async fn test_temperature_setting(cx: &mut TestAppContext) {
init_test_settings(cx);
@@ -3,9 +3,9 @@ use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::{Arc, Mutex};
-use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
+use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
-use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
+use assistant_tool::{ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::ContextServerId;
@@ -25,7 +25,6 @@ use prompt_store::{
UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
-use settings::{Settings as _, SettingsStore};
use ui::Window;
use util::ResultExt as _;
@@ -147,12 +146,7 @@ impl ThreadStore {
prompt_store: Option<Entity<PromptStore>>,
cx: &mut Context<Self>,
) -> (Self, oneshot::Receiver<()>) {
- let mut subscriptions = vec![
- cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
- this.load_default_profile(cx);
- }),
- cx.subscribe(&project, Self::handle_project_event),
- ];
+ let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(
@@ -200,7 +194,6 @@ impl ThreadStore {
_reload_system_prompt_task: reload_system_prompt_task,
_subscriptions: subscriptions,
};
- this.load_default_profile(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
(this, ready_rx)
@@ -520,86 +513,6 @@ impl ThreadStore {
})
}
- fn load_default_profile(&self, cx: &mut Context<Self>) {
- let assistant_settings = AgentSettings::get_global(cx);
-
- self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
- }
-
- pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
- let assistant_settings = AgentSettings::get_global(cx);
-
- if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
- self.load_profile(profile.clone(), cx);
- }
- }
-
- pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
- self.tools.update(cx, |tools, cx| {
- tools.disable_all_tools(cx);
- tools.enable(
- ToolSource::Native,
- &profile
- .tools
- .into_iter()
- .filter_map(|(tool, enabled)| enabled.then(|| tool))
- .collect::<Vec<_>>(),
- cx,
- );
- });
-
- if profile.enable_all_context_servers {
- for context_server_id in self
- .project
- .read(cx)
- .context_server_store()
- .read(cx)
- .all_server_ids()
- {
- self.tools.update(cx, |tools, cx| {
- tools.enable_source(
- ToolSource::ContextServer {
- id: context_server_id.0.into(),
- },
- cx,
- );
- });
- }
- // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
- for (context_server_id, preset) in profile.context_servers {
- self.tools.update(cx, |tools, cx| {
- tools.disable(
- ToolSource::ContextServer {
- id: context_server_id.into(),
- },
- &preset
- .tools
- .into_iter()
- .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
- .collect::<Vec<_>>(),
- cx,
- )
- })
- }
- } else {
- for (context_server_id, preset) in profile.context_servers {
- self.tools.update(cx, |tools, cx| {
- tools.enable(
- ToolSource::ContextServer {
- id: context_server_id.into(),
- },
- &preset
- .tools
- .into_iter()
- .filter_map(|(tool, enabled)| enabled.then(|| tool))
- .collect::<Vec<_>>(),
- cx,
- )
- })
- }
- }
- }
-
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe(
&self.project.read(cx).context_server_store(),
@@ -618,6 +531,7 @@ impl ThreadStore {
match event {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
+ ContextServerStatus::Starting => {}
ContextServerStatus::Running => {
if let Some(server) =
context_server_store.read(cx).get_running_server(server_id)
@@ -656,10 +570,9 @@ impl ThreadStore {
.log_err();
if let Some(tool_ids) = tool_ids {
- this.update(cx, |this, cx| {
+ this.update(cx, |this, _| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
- this.load_default_profile(cx);
})
.log_err();
}
@@ -675,10 +588,8 @@ impl ThreadStore {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
- self.load_default_profile(cx);
}
}
- _ => {}
}
}
}
@@ -714,6 +625,8 @@ pub struct SerializedThread {
pub completion_mode: Option<CompletionMode>,
#[serde(default)]
pub tool_use_limit_reached: bool,
+ #[serde(default)]
+ pub profile: Option<AgentProfileId>,
}
#[derive(Serialize, Deserialize, Debug)]
@@ -856,6 +769,7 @@ impl LegacySerializedThread {
model: None,
completion_mode: None,
tool_use_limit_reached: false,
+ profile: None,
}
}
}
@@ -1,30 +1,33 @@
use std::sync::Arc;
-use assistant_tool::{Tool, ToolSource, ToolWorkingSet, ToolWorkingSetEvent};
+use assistant_tool::{Tool, ToolSource};
use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
use ui::prelude::*;
+use crate::{Thread, ThreadEvent};
+
pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
- tool_working_set: Entity<ToolWorkingSet>,
- _tool_working_set_subscription: Subscription,
+ thread: Entity<Thread>,
+ _thread_subscription: Subscription,
}
impl IncompatibleToolsState {
- pub fn new(tool_working_set: Entity<ToolWorkingSet>, cx: &mut Context<Self>) -> Self {
+ pub fn new(thread: Entity<Thread>, cx: &mut Context<Self>) -> Self {
let _tool_working_set_subscription =
- cx.subscribe(&tool_working_set, |this, _, event, _| match event {
- ToolWorkingSetEvent::EnabledToolsChanged => {
+ cx.subscribe(&thread, |this, _, event, _| match event {
+ ThreadEvent::ProfileChanged => {
this.cache.clear();
}
+ _ => {}
});
Self {
cache: HashMap::default(),
- tool_working_set,
- _tool_working_set_subscription,
+ thread,
+ _thread_subscription: _tool_working_set_subscription,
}
}
@@ -36,8 +39,9 @@ impl IncompatibleToolsState {
self.cache
.entry(model.tool_input_format())
.or_insert_with(|| {
- self.tool_working_set
+ self.thread
.read(cx)
+ .profile()
.enabled_tools(cx)
.iter()
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
@@ -16,7 +16,6 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
collections.workspace = true
gpui.workspace = true
-indexmap.workspace = true
language_model.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
@@ -17,29 +17,6 @@ pub mod builtin_profiles {
}
}
-#[derive(Default)]
-pub struct GroupedAgentProfiles {
- pub builtin: IndexMap<AgentProfileId, AgentProfile>,
- pub custom: IndexMap<AgentProfileId, AgentProfile>,
-}
-
-impl GroupedAgentProfiles {
- pub fn from_settings(settings: &crate::AgentSettings) -> Self {
- let mut builtin = IndexMap::default();
- let mut custom = IndexMap::default();
-
- for (profile_id, profile) in settings.profiles.clone() {
- if builtin_profiles::is_builtin(&profile_id) {
- builtin.insert(profile_id, profile);
- } else {
- custom.insert(profile_id, profile);
- }
- }
-
- Self { builtin, custom }
- }
-}
-
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentProfileId(pub Arc<str>);
@@ -63,7 +40,7 @@ impl Default for AgentProfileId {
/// A profile for the Zed Agent that controls its behavior.
#[derive(Debug, Clone)]
-pub struct AgentProfile {
+pub struct AgentProfileSettings {
/// The name of the profile.
pub name: SharedString,
pub tools: IndexMap<Arc<str>, bool>,
@@ -102,7 +102,7 @@ pub struct AgentSettings {
pub using_outdated_settings_version: bool,
pub default_profile: AgentProfileId,
pub default_view: DefaultView,
- pub profiles: IndexMap<AgentProfileId, AgentProfile>,
+ pub profiles: IndexMap<AgentProfileId, AgentProfileSettings>,
pub always_allow_tool_actions: bool,
pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
pub play_sound_when_agent_done: bool,
@@ -531,7 +531,7 @@ impl AgentSettingsContent {
pub fn create_profile(
&mut self,
profile_id: AgentProfileId,
- profile: AgentProfile,
+ profile_settings: AgentProfileSettings,
) -> Result<()> {
self.v2_setting(|settings| {
let profiles = settings.profiles.get_or_insert_default();
@@ -542,10 +542,10 @@ impl AgentSettingsContent {
profiles.insert(
profile_id,
AgentProfileContent {
- name: profile.name.into(),
- tools: profile.tools,
- enable_all_context_servers: Some(profile.enable_all_context_servers),
- context_servers: profile
+ name: profile_settings.name.into(),
+ tools: profile_settings.tools,
+ enable_all_context_servers: Some(profile_settings.enable_all_context_servers),
+ context_servers: profile_settings
.context_servers
.into_iter()
.map(|(server_id, preset)| {
@@ -910,7 +910,7 @@ impl Settings for AgentSettings {
.extend(profiles.into_iter().map(|(id, profile)| {
(
id,
- AgentProfile {
+ AgentProfileSettings {
name: profile.name.into(),
tools: profile.tools,
enable_all_context_servers: profile
@@ -1,7 +1,7 @@
use std::sync::Arc;
-use collections::{HashMap, HashSet, IndexMap};
-use gpui::{App, Context, EventEmitter};
+use collections::{HashMap, IndexMap};
+use gpui::App;
use crate::{Tool, ToolRegistry, ToolSource};
@@ -13,17 +13,9 @@ pub struct ToolId(usize);
pub struct ToolWorkingSet {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
- enabled_sources: HashSet<ToolSource>,
- enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
next_tool_id: ToolId,
}
-pub enum ToolWorkingSetEvent {
- EnabledToolsChanged,
-}
-
-impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
-
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
self.context_server_tools_by_name
@@ -57,42 +49,6 @@ impl ToolWorkingSet {
tools_by_source
}
- pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
- let all_tools = self.tools(cx);
-
- all_tools
- .into_iter()
- .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
- .collect()
- }
-
- pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
- self.enabled_tools_by_source.clear();
- cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
- }
-
- pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
- self.enabled_sources.insert(source.clone());
-
- let tools_by_source = self.tools_by_source(cx);
- if let Some(tools) = tools_by_source.get(&source) {
- self.enabled_tools_by_source.insert(
- source,
- tools
- .into_iter()
- .map(|tool| tool.name().into())
- .collect::<HashSet<_>>(),
- );
- }
- cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
- }
-
- pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
- self.enabled_sources.remove(source);
- self.enabled_tools_by_source.remove(source);
- cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
- }
-
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
@@ -102,42 +58,6 @@ impl ToolWorkingSet {
tool_id
}
- pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
- self.enabled_tools_by_source
- .get(source)
- .map_or(false, |enabled_tools| enabled_tools.contains(name))
- }
-
- pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
- !self.is_enabled(source, name)
- }
-
- pub fn enable(
- &mut self,
- source: ToolSource,
- tools_to_enable: &[Arc<str>],
- cx: &mut Context<Self>,
- ) {
- self.enabled_tools_by_source
- .entry(source)
- .or_default()
- .extend(tools_to_enable.into_iter().cloned());
- cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
- }
-
- pub fn disable(
- &mut self,
- source: ToolSource,
- tools_to_disable: &[Arc<str>],
- cx: &mut Context<Self>,
- ) {
- self.enabled_tools_by_source
- .entry(source)
- .or_default()
- .retain(|name| !tools_to_disable.contains(name));
- cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
- }
-
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
@@ -80,7 +80,6 @@ zed_llm_client.workspace = true
agent_settings.workspace = true
assistant_context_editor.workspace = true
assistant_slash_command.workspace = true
-assistant_tool.workspace = true
async-trait.workspace = true
audio.workspace = true
buffer_diff.workspace = true
@@ -294,6 +294,7 @@ impl ExampleContext {
| ThreadEvent::MessageDeleted(_)
| ThreadEvent::SummaryChanged
| ThreadEvent::SummaryGenerated
+ | ThreadEvent::ProfileChanged
| ThreadEvent::ReceivedTextChunk
| ThreadEvent::StreamedToolUse { .. }
| ThreadEvent::CheckpointChanged
@@ -306,17 +306,19 @@ impl ExampleInstance {
let thread_store = thread_store.await?;
- let profile_id = meta.profile_id.clone();
- thread_store.update(cx, |thread_store, cx| thread_store.load_profile_by_id(profile_id, cx)).expect("Failed to load profile");
let thread =
thread_store.update(cx, |thread_store, cx| {
- if let Some(json) = &meta.existing_thread_json {
+ let thread = if let Some(json) = &meta.existing_thread_json {
let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread");
thread_store.create_thread_from_serialized(serialized, cx)
} else {
thread_store.create_thread(cx)
- }
+ };
+ thread.update(cx, |thread, cx| {
+ thread.set_profile(meta.profile_id.clone(), cx);
+ });
+ thread
})?;