agent_profile.rs

  1use std::sync::Arc;
  2
  3use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
  4use assistant_tool::{Tool, ToolSource, ToolWorkingSet, UniqueToolName};
  5use collections::IndexMap;
  6use convert_case::{Case, Casing};
  7use fs::Fs;
  8use gpui::{App, Entity, SharedString};
  9use settings::{Settings, update_settings_file};
 10use util::ResultExt;
 11
 12#[derive(Clone, Debug, Eq, PartialEq)]
 13pub struct AgentProfile {
 14    id: AgentProfileId,
 15    tool_set: Entity<ToolWorkingSet>,
 16}
 17
 18pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
 19
 20impl AgentProfile {
 21    pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
 22        Self { id, tool_set }
 23    }
 24
 25    /// Saves a new profile to the settings.
 26    pub fn create(
 27        name: String,
 28        base_profile_id: Option<AgentProfileId>,
 29        fs: Arc<dyn Fs>,
 30        cx: &App,
 31    ) -> AgentProfileId {
 32        let id = AgentProfileId(name.to_case(Case::Kebab).into());
 33
 34        let base_profile =
 35            base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
 36
 37        let profile_settings = AgentProfileSettings {
 38            name: name.into(),
 39            tools: base_profile
 40                .as_ref()
 41                .map(|profile| profile.tools.clone())
 42                .unwrap_or_default(),
 43            enable_all_context_servers: base_profile
 44                .as_ref()
 45                .map(|profile| profile.enable_all_context_servers)
 46                .unwrap_or_default(),
 47            context_servers: base_profile
 48                .map(|profile| profile.context_servers)
 49                .unwrap_or_default(),
 50        };
 51
 52        update_settings_file(fs, cx, {
 53            let id = id.clone();
 54            move |settings, _cx| {
 55                profile_settings.save_to_settings(id, settings).log_err();
 56            }
 57        });
 58
 59        id
 60    }
 61
 62    /// Returns a map of AgentProfileIds to their names
 63    pub fn available_profiles(cx: &App) -> AvailableProfiles {
 64        let mut profiles = AvailableProfiles::default();
 65        for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
 66            profiles.insert(id.clone(), profile.name.clone());
 67        }
 68        profiles
 69    }
 70
 71    pub fn id(&self) -> &AgentProfileId {
 72        &self.id
 73    }
 74
 75    pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
 76        let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
 77            return Vec::new();
 78        };
 79
 80        self.tool_set
 81            .read(cx)
 82            .tools(cx)
 83            .into_iter()
 84            .filter(|(_, tool)| Self::is_enabled(settings, tool.source(), tool.name()))
 85            .collect()
 86    }
 87
 88    pub fn is_tool_enabled(&self, source: ToolSource, tool_name: String, cx: &App) -> bool {
 89        let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
 90            return false;
 91        };
 92
 93        Self::is_enabled(settings, source, tool_name)
 94    }
 95
 96    fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
 97        match source {
 98            ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
 99            ToolSource::ContextServer { id } => settings
100                .context_servers
101                .get(id.as_ref())
102                .and_then(|preset| preset.tools.get(name.as_str()).copied())
103                .unwrap_or(settings.enable_all_context_servers),
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use agent_settings::ContextServerPreset;
    use assistant_tool::ToolRegistry;
    use collections::IndexMap;
    use gpui::{AppContext, SharedString, TestAppContext};
    use http_client::FakeHttpClient;
    use project::Project;
    use settings::{Settings, SettingsStore};

    use super::*;

    /// Built-in tool whose availability depends on the language-model provider, so the
    /// expectations below leave it out.
    const PROVIDER_DEPENDENT_TOOL: &str = "web_search";

    #[gpui::test]
    async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId::default();
        let profile_settings = read_profile_settings(cx, &id);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let actual = enabled_tool_names(cx, &profile);

        let mut expected = enabled_built_in_tools(&profile_settings);
        // Plus all registered MCP tools.
        expected.push("enabled_mcp_tool".to_string());
        expected.push("disabled_mcp_tool".to_string());
        expected.sort();

        assert_eq!(actual, expected);
    }

    #[gpui::test]
    async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("custom_mcp".into());
        let profile_settings = read_profile_settings(cx, &id);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let actual = enabled_tool_names(cx, &profile);

        // Only the tools switched on in the "mcp" preset are expected.
        let mut expected: Vec<String> = profile_settings.context_servers["mcp"]
            .tools
            .iter()
            .filter(|(_, enabled)| **enabled)
            .map(|(tool_name, _)| tool_name.to_string())
            .collect();
        expected.sort();

        assert_eq!(actual, expected);
    }

    #[gpui::test]
    async fn test_only_built_in(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("write_minus_mcp".into());
        let profile_settings = read_profile_settings(cx, &id);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let actual = enabled_tool_names(cx, &profile);

        let mut expected = enabled_built_in_tools(&profile_settings);
        expected.sort();

        assert_eq!(actual, expected);
    }

    /// Clones the stored settings of profile `id`; panics if the profile is missing.
    fn read_profile_settings(
        cx: &mut TestAppContext,
        id: &AgentProfileId,
    ) -> AgentProfileSettings {
        cx.read(|cx| {
            AgentSettings::get_global(cx)
                .profiles
                .get(id)
                .unwrap()
                .clone()
        })
    }

    /// Sorted names of the tools `profile` currently enables.
    fn enabled_tool_names(cx: &mut TestAppContext, profile: &AgentProfile) -> Vec<String> {
        let mut names: Vec<String> = cx
            .read(|cx| profile.enabled_tools(cx))
            .into_iter()
            .map(|(_, tool)| tool.name())
            .collect();
        names.sort();
        names
    }

    /// Names of the built-in tools switched on in `settings`, excluding
    /// [`PROVIDER_DEPENDENT_TOOL`]. The result is not sorted.
    fn enabled_built_in_tools(settings: &AgentProfileSettings) -> Vec<String> {
        settings
            .tools
            .iter()
            .filter(|(_, enabled)| **enabled)
            .map(|(tool_name, _)| tool_name.to_string())
            .filter(|tool_name| tool_name.as_str() != PROVIDER_DEPENDENT_TOOL)
            .collect()
    }

    /// Installs a test settings store and the settings and tools these tests use. It then
    /// adds two profiles: `write_minus_mcp` (the default profile with context servers
    /// switched off) and `custom_mcp` (no built-in tools, one "mcp" server preset).
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            language_model::init_settings(cx);
            ToolRegistry::default_global(cx);
            assistant_tools::init(FakeHttpClient::with_404_response(), cx);
        });

        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            let default_profile = agent_settings.profiles[&AgentProfileId::default()].clone();

            let write_minus_mcp = AgentProfileSettings {
                name: "write_minus_mcp".into(),
                enable_all_context_servers: false,
                ..default_profile
            };
            let custom_mcp = AgentProfileSettings {
                name: "mcp".into(),
                tools: IndexMap::default(),
                enable_all_context_servers: false,
                context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
            };

            agent_settings
                .profiles
                .insert(AgentProfileId("write_minus_mcp".into()), write_minus_mcp);
            agent_settings
                .profiles
                .insert(AgentProfileId("custom_mcp".into()), custom_mcp);
            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// A preset enabling `enabled_mcp_tool` and disabling `disabled_mcp_tool`.
    fn context_server_preset() -> ContextServerPreset {
        let tools = IndexMap::from_iter([
            ("enabled_mcp_tool".into(), true),
            ("disabled_mcp_tool".into(), false),
        ]);
        ContextServerPreset { tools }
    }

    /// A tool set holding the two fake tools, both attributed to context server "mcp".
    fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
        cx.new(|cx| {
            let mut tool_set = ToolWorkingSet::default();
            for tool_name in ["enabled_mcp_tool", "disabled_mcp_tool"] {
                tool_set.insert(Arc::new(FakeTool::new(tool_name, "mcp")), cx);
            }
            tool_set
        })
    }

    /// A context-server tool stub. Only `name` and `source` are meaningful; every other
    /// `Tool` method panics via `unimplemented!()`.
    struct FakeTool {
        name: String,
        source: SharedString,
    }

    impl FakeTool {
        fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
            let name = name.into();
            let source = source.into();
            Self { name, source }
        }
    }

    impl Tool for FakeTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn source(&self) -> ToolSource {
            let id = self.source.clone();
            ToolSource::ContextServer { id }
        }

        fn description(&self) -> String {
            unimplemented!()
        }

        fn icon(&self) -> icons::IconName {
            unimplemented!()
        }

        fn needs_confirmation(
            &self,
            _input: &serde_json::Value,
            _project: &Entity<Project>,
            _cx: &App,
        ) -> bool {
            unimplemented!()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            unimplemented!()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<language_model::LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<action_log::ActionLog>,
            _model: Arc<dyn language_model::LanguageModel>,
            _window: Option<gpui::AnyWindowHandle>,
            _cx: &mut App,
        ) -> assistant_tool::ToolResult {
            unimplemented!()
        }

        fn may_perform_edits(&self) -> bool {
            unimplemented!()
        }
    }
}