agent_profile.rs

  1use std::sync::Arc;
  2
  3use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
  4use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
  5use collections::IndexMap;
  6use convert_case::{Case, Casing};
  7use fs::Fs;
  8use gpui::{App, Entity, SharedString};
  9use settings::{Settings, update_settings_file};
 10use util::ResultExt;
 11
 12#[derive(Clone, Debug, Eq, PartialEq)]
 13pub struct AgentProfile {
 14    id: AgentProfileId,
 15    tool_set: Entity<ToolWorkingSet>,
 16}
 17
 18pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
 19
 20impl AgentProfile {
 21    pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
 22        Self { id, tool_set }
 23    }
 24
 25    /// Saves a new profile to the settings.
 26    pub fn create(
 27        name: String,
 28        base_profile_id: Option<AgentProfileId>,
 29        fs: Arc<dyn Fs>,
 30        cx: &App,
 31    ) -> AgentProfileId {
 32        let id = AgentProfileId(name.to_case(Case::Kebab).into());
 33
 34        let base_profile =
 35            base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
 36
 37        let profile_settings = AgentProfileSettings {
 38            name: name.into(),
 39            tools: base_profile
 40                .as_ref()
 41                .map(|profile| profile.tools.clone())
 42                .unwrap_or_default(),
 43            enable_all_context_servers: base_profile
 44                .as_ref()
 45                .map(|profile| profile.enable_all_context_servers)
 46                .unwrap_or_default(),
 47            context_servers: base_profile
 48                .map(|profile| profile.context_servers)
 49                .unwrap_or_default(),
 50        };
 51
 52        update_settings_file::<AgentSettings>(fs, cx, {
 53            let id = id.clone();
 54            move |settings, _cx| {
 55                settings.create_profile(id, profile_settings).log_err();
 56            }
 57        });
 58
 59        id
 60    }
 61
 62    /// Returns a map of AgentProfileIds to their names
 63    pub fn available_profiles(cx: &App) -> AvailableProfiles {
 64        let mut profiles = AvailableProfiles::default();
 65        for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
 66            profiles.insert(id.clone(), profile.name.clone());
 67        }
 68        profiles
 69    }
 70
 71    pub fn id(&self) -> &AgentProfileId {
 72        &self.id
 73    }
 74
 75    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 76        let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
 77            return Vec::new();
 78        };
 79
 80        self.tool_set
 81            .read(cx)
 82            .tools(cx)
 83            .into_iter()
 84            .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
 85            .collect()
 86    }
 87
 88    fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
 89        match source {
 90            ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
 91            ToolSource::ContextServer { id } => {
 92                if settings.enable_all_context_servers {
 93                    return true;
 94                }
 95
 96                let Some(preset) = settings.context_servers.get(id.as_ref()) else {
 97                    return false;
 98                };
 99                *preset.tools.get(name.as_str()).unwrap_or(&false)
100            }
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use agent_settings::ContextServerPreset;
    use assistant_tool::ToolRegistry;
    use collections::IndexMap;
    use gpui::{AppContext, SharedString, TestAppContext};
    use http_client::FakeHttpClient;
    use project::Project;
    use settings::{Settings, SettingsStore};

    use super::*;

    #[gpui::test]
    async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId::default();
        let profile_settings = load_profile_settings(cx, &id);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = enabled_tool_names(cx, &profile);

        let mut expected_tools = enabled_native_tool_names(&profile_settings);
        // Plus all registered MCP tools
        expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    #[gpui::test]
    async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("custom_mcp".into());
        let profile_settings = load_profile_settings(cx, &id);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = enabled_tool_names(cx, &profile);

        let mut expected_tools = profile_settings.context_servers["mcp"]
            .tools
            .iter()
            .filter_map(|(key, enabled)| enabled.then(|| key.to_string()))
            .collect::<Vec<_>>();
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    #[gpui::test]
    async fn test_only_built_in(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("write_minus_mcp".into());
        let profile_settings = load_profile_settings(cx, &id);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = enabled_tool_names(cx, &profile);
        let expected_tools = enabled_native_tool_names(&profile_settings);

        assert_eq!(enabled_tools, expected_tools);
    }

    #[gpui::test]
    async fn test_available_profiles(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let profiles = cx.read(AgentProfile::available_profiles);

        assert_eq!(
            profiles.get(&AgentProfileId("custom_mcp".into())),
            Some(&SharedString::from("mcp"))
        );
        assert_eq!(
            profiles.get(&AgentProfileId("write_minus_mcp".into())),
            Some(&SharedString::from("write_minus_mcp"))
        );
    }

    /// Returns a clone of the settings stored for profile `id`; panics if it is missing.
    fn load_profile_settings(cx: &mut TestAppContext, id: &AgentProfileId) -> AgentProfileSettings {
        cx.read(|cx| {
            AgentSettings::get_global(cx)
                .profiles
                .get(id)
                .unwrap()
                .clone()
        })
    }

    /// Returns the sorted names of the tools `profile` currently enables.
    fn enabled_tool_names(cx: &mut TestAppContext, profile: &AgentProfile) -> Vec<String> {
        let mut names = cx
            .read(|cx| profile.enabled_tools(cx))
            .into_iter()
            .map(|tool| tool.name())
            .collect::<Vec<_>>();
        names.sort();
        names
    }

    /// Returns the sorted names of the native tools `settings` enables, skipping
    /// provider-dependent tools that are not registered in the test environment.
    fn enabled_native_tool_names(settings: &AgentProfileSettings) -> Vec<String> {
        let mut names = settings
            .tools
            .iter()
            .filter_map(|(tool, enabled)| enabled.then(|| tool.to_string()))
            // Provider dependent
            .filter(|tool| tool != "web_search")
            .collect::<Vec<_>>();
        names.sort();
        names
    }

    /// Installs test settings plus two extra profiles derived from the default one:
    /// `write_minus_mcp` (default tools, context servers off) and `custom_mcp` (no native
    /// tools, only the `mcp` server preset).
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            language_model::init_settings(cx);
            ToolRegistry::default_global(cx);
            assistant_tools::init(FakeHttpClient::with_404_response(), cx);
        });

        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.profiles.insert(
                AgentProfileId("write_minus_mcp".into()),
                AgentProfileSettings {
                    name: "write_minus_mcp".into(),
                    enable_all_context_servers: false,
                    ..agent_settings.profiles[&AgentProfileId::default()].clone()
                },
            );
            agent_settings.profiles.insert(
                AgentProfileId("custom_mcp".into()),
                AgentProfileSettings {
                    name: "mcp".into(),
                    tools: IndexMap::default(),
                    enable_all_context_servers: false,
                    context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
                },
            );
            AgentSettings::override_global(agent_settings, cx);
        })
    }

    /// Preset for the fake `mcp` server: one enabled and one disabled tool.
    fn context_server_preset() -> ContextServerPreset {
        ContextServerPreset {
            tools: IndexMap::from_iter([
                ("enabled_mcp_tool".into(), true),
                ("disabled_mcp_tool".into(), false),
            ]),
        }
    }

    /// A working set that contains only the two fake `mcp` context-server tools.
    fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
        cx.new(|_| {
            let mut tool_set = ToolWorkingSet::default();
            tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
            tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
            tool_set
        })
    }

    /// Minimal context-server tool. Only `name` and `source` are used by the tests.
    struct FakeTool {
        name: String,
        source: SharedString,
    }

    impl FakeTool {
        fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
            Self {
                name: name.into(),
                source: source.into(),
            }
        }
    }

    impl Tool for FakeTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn source(&self) -> ToolSource {
            ToolSource::ContextServer {
                id: self.source.clone(),
            }
        }

        fn description(&self) -> String {
            unimplemented!()
        }

        fn icon(&self) -> icons::IconName {
            unimplemented!()
        }

        fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
            unimplemented!()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            unimplemented!()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<language_model::LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<assistant_tool::ActionLog>,
            _model: Arc<dyn language_model::LanguageModel>,
            _window: Option<gpui::AnyWindowHandle>,
            _cx: &mut App,
        ) -> assistant_tool::ToolResult {
            unimplemented!()
        }

        fn may_perform_edits(&self) -> bool {
            unimplemented!()
        }
    }
}