agent_profile.rs

  1use std::sync::Arc;
  2
  3use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
  4use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
  5use collections::IndexMap;
  6use convert_case::{Case, Casing};
  7use fs::Fs;
  8use gpui::{App, Entity, SharedString};
  9use settings::{Settings, update_settings_file};
 10use util::ResultExt;
 11
 12#[derive(Clone, Debug, Eq, PartialEq)]
 13pub struct AgentProfile {
 14    id: AgentProfileId,
 15    tool_set: Entity<ToolWorkingSet>,
 16}
 17
 18pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
 19
 20impl AgentProfile {
 21    pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
 22        Self { id, tool_set }
 23    }
 24
 25    /// Saves a new profile to the settings.
 26    pub fn create(
 27        name: String,
 28        base_profile_id: Option<AgentProfileId>,
 29        fs: Arc<dyn Fs>,
 30        cx: &App,
 31    ) -> AgentProfileId {
 32        let id = AgentProfileId(name.to_case(Case::Kebab).into());
 33
 34        let base_profile =
 35            base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
 36
 37        let profile_settings = AgentProfileSettings {
 38            name: name.into(),
 39            tools: base_profile
 40                .as_ref()
 41                .map(|profile| profile.tools.clone())
 42                .unwrap_or_default(),
 43            enable_all_context_servers: base_profile
 44                .as_ref()
 45                .map(|profile| profile.enable_all_context_servers)
 46                .unwrap_or_default(),
 47            context_servers: base_profile
 48                .map(|profile| profile.context_servers)
 49                .unwrap_or_default(),
 50        };
 51
 52        update_settings_file::<AgentSettings>(fs, cx, {
 53            let id = id.clone();
 54            move |settings, _cx| {
 55                settings.create_profile(id, profile_settings).log_err();
 56            }
 57        });
 58
 59        id
 60    }
 61
 62    /// Returns a map of AgentProfileIds to their names
 63    pub fn available_profiles(cx: &App) -> AvailableProfiles {
 64        let mut profiles = AvailableProfiles::default();
 65        for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
 66            profiles.insert(id.clone(), profile.name.clone());
 67        }
 68        profiles
 69    }
 70
 71    pub fn id(&self) -> &AgentProfileId {
 72        &self.id
 73    }
 74
 75    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 76        let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
 77            return Vec::new();
 78        };
 79
 80        self.tool_set
 81            .read(cx)
 82            .tools(cx)
 83            .into_iter()
 84            .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
 85            .collect()
 86    }
 87
 88    pub fn is_tool_enabled(&self, source: ToolSource, tool_name: String, cx: &App) -> bool {
 89        let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
 90            return false;
 91        };
 92
 93        return Self::is_enabled(settings, source, tool_name);
 94    }
 95
 96    fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
 97        match source {
 98            ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
 99            ToolSource::ContextServer { id } => {
100                if settings.enable_all_context_servers {
101                    return true;
102                }
103
104                let Some(preset) = settings.context_servers.get(id.as_ref()) else {
105                    return false;
106                };
107                *preset.tools.get(name.as_str()).unwrap_or(&false)
108            }
109        }
110    }
111}
112
#[cfg(test)]
mod tests {
    use agent_settings::ContextServerPreset;
    use assistant_tool::ToolRegistry;
    use collections::IndexMap;
    use gpui::SharedString;
    use gpui::{AppContext, TestAppContext};
    use http_client::FakeHttpClient;
    use project::Project;
    use settings::{Settings, SettingsStore};

    use super::*;

    #[gpui::test]
    async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId::default();
        let profile_settings = settings_for_profile(&id, cx);
        let enabled_tools = enabled_tool_names(&id, cx);

        let mut expected_tools = enabled_native_tools(&profile_settings);
        // Plus all registered MCP tools
        expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    #[gpui::test]
    async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("custom_mcp".into());
        let profile_settings = settings_for_profile(&id, cx);
        let enabled_tools = enabled_tool_names(&id, cx);

        let mut expected_tools = profile_settings.context_servers["mcp"]
            .tools
            .iter()
            .filter_map(|(key, enabled)| enabled.then(|| key.to_string()))
            .collect::<Vec<_>>();
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    #[gpui::test]
    async fn test_only_built_in(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("write_minus_mcp".into());
        let profile_settings = settings_for_profile(&id, cx);
        let enabled_tools = enabled_tool_names(&id, cx);

        let mut expected_tools = enabled_native_tools(&profile_settings);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    /// Looks up the settings of profile `id`, which the test setup must have registered.
    fn settings_for_profile(id: &AgentProfileId, cx: &mut TestAppContext) -> AgentProfileSettings {
        cx.read(|cx| {
            AgentSettings::get_global(cx)
                .profiles
                .get(id)
                .cloned()
                .expect("profile should be registered by init_test_settings")
        })
    }

    /// Returns the sorted names of the tools that profile `id` enables out of
    /// the fake tool set built by `default_tool_set`.
    fn enabled_tool_names(id: &AgentProfileId, cx: &mut TestAppContext) -> Vec<String> {
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id.clone(), tool_set);

        let mut names = cx
            .read(|cx| profile.enabled_tools(cx))
            .into_iter()
            .map(|tool| tool.name())
            .collect::<Vec<_>>();
        names.sort();
        names
    }

    /// Names of the native tools that `settings` marks as enabled (unsorted).
    fn enabled_native_tools(settings: &AgentProfileSettings) -> Vec<String> {
        settings
            .tools
            .iter()
            .filter_map(|(tool, enabled)| enabled.then(|| tool.to_string()))
            // Provider dependent
            .filter(|tool| tool != "web_search")
            .collect()
    }

    /// Registers the settings and tools the tests need, then adds two extra
    /// profiles: `write_minus_mcp` (the default profile with context servers
    /// not globally enabled) and `custom_mcp` (no native tools, only the
    /// `mcp` context-server preset).
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            language_model::init_settings(cx);
            ToolRegistry::default_global(cx);
            assistant_tools::init(FakeHttpClient::with_404_response(), cx);
        });

        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.profiles.insert(
                AgentProfileId("write_minus_mcp".into()),
                AgentProfileSettings {
                    name: "write_minus_mcp".into(),
                    enable_all_context_servers: false,
                    ..agent_settings.profiles[&AgentProfileId::default()].clone()
                },
            );
            agent_settings.profiles.insert(
                AgentProfileId("custom_mcp".into()),
                AgentProfileSettings {
                    name: "mcp".into(),
                    tools: IndexMap::default(),
                    enable_all_context_servers: false,
                    context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
                },
            );
            AgentSettings::override_global(agent_settings, cx);
        })
    }

    /// Preset for the fake `mcp` server: one tool enabled, one disabled.
    fn context_server_preset() -> ContextServerPreset {
        ContextServerPreset {
            tools: IndexMap::from_iter([
                ("enabled_mcp_tool".into(), true),
                ("disabled_mcp_tool".into(), false),
            ]),
        }
    }

    /// A working set containing only the two fake tools served by `mcp`.
    fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
        cx.new(|_| {
            let mut tool_set = ToolWorkingSet::default();
            tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
            tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
            tool_set
        })
    }

    /// A context-server tool that only reports its name and source; every other
    /// `Tool` method is unimplemented because the filtering logic never calls it.
    struct FakeTool {
        name: String,
        source: SharedString,
    }

    impl FakeTool {
        fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
            Self {
                name: name.into(),
                source: source.into(),
            }
        }
    }

    impl Tool for FakeTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn source(&self) -> ToolSource {
            ToolSource::ContextServer {
                id: self.source.clone(),
            }
        }

        fn description(&self) -> String {
            unimplemented!()
        }

        fn icon(&self) -> icons::IconName {
            unimplemented!()
        }

        fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
            unimplemented!()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            unimplemented!()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<language_model::LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<assistant_tool::ActionLog>,
            _model: Arc<dyn language_model::LanguageModel>,
            _window: Option<gpui::AnyWindowHandle>,
            _cx: &mut App,
        ) -> assistant_tool::ToolResult {
            unimplemented!()
        }

        fn may_perform_edits(&self) -> bool {
            unimplemented!()
        }
    }
}