agent_profile.rs

  1use std::sync::Arc;
  2
  3use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
  4use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
  5use collections::IndexMap;
  6use convert_case::{Case, Casing};
  7use fs::Fs;
  8use gpui::{App, Entity};
  9use settings::{Settings, update_settings_file};
 10use ui::SharedString;
 11use util::ResultExt;
 12
 13#[derive(Clone, Debug, Eq, PartialEq)]
 14pub struct AgentProfile {
 15    id: AgentProfileId,
 16    tool_set: Entity<ToolWorkingSet>,
 17}
 18
 19pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
 20
 21impl AgentProfile {
 22    pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
 23        Self { id, tool_set }
 24    }
 25
 26    /// Saves a new profile to the settings.
 27    pub fn create(
 28        name: String,
 29        base_profile_id: Option<AgentProfileId>,
 30        fs: Arc<dyn Fs>,
 31        cx: &App,
 32    ) -> AgentProfileId {
 33        let id = AgentProfileId(name.to_case(Case::Kebab).into());
 34
 35        let base_profile =
 36            base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
 37
 38        let profile_settings = AgentProfileSettings {
 39            name: name.into(),
 40            tools: base_profile
 41                .as_ref()
 42                .map(|profile| profile.tools.clone())
 43                .unwrap_or_default(),
 44            enable_all_context_servers: base_profile
 45                .as_ref()
 46                .map(|profile| profile.enable_all_context_servers)
 47                .unwrap_or_default(),
 48            context_servers: base_profile
 49                .map(|profile| profile.context_servers)
 50                .unwrap_or_default(),
 51        };
 52
 53        update_settings_file::<AgentSettings>(fs, cx, {
 54            let id = id.clone();
 55            move |settings, _cx| {
 56                settings.create_profile(id, profile_settings).log_err();
 57            }
 58        });
 59
 60        id
 61    }
 62
 63    /// Returns a map of AgentProfileIds to their names
 64    pub fn available_profiles(cx: &App) -> AvailableProfiles {
 65        let mut profiles = AvailableProfiles::default();
 66        for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
 67            profiles.insert(id.clone(), profile.name.clone());
 68        }
 69        profiles
 70    }
 71
 72    pub fn id(&self) -> &AgentProfileId {
 73        &self.id
 74    }
 75
 76    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 77        let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
 78            return Vec::new();
 79        };
 80
 81        self.tool_set
 82            .read(cx)
 83            .tools(cx)
 84            .into_iter()
 85            .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
 86            .collect()
 87    }
 88
 89    fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
 90        match source {
 91            ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
 92            ToolSource::ContextServer { id } => {
 93                if settings.enable_all_context_servers {
 94                    return true;
 95                }
 96
 97                let Some(preset) = settings.context_servers.get(id.as_ref()) else {
 98                    return false;
 99                };
100                *preset.tools.get(name.as_str()).unwrap_or(&false)
101            }
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use agent_settings::ContextServerPreset;
    use assistant_tool::ToolRegistry;
    use collections::IndexMap;
    use gpui::{AppContext, TestAppContext};
    use http_client::FakeHttpClient;
    use project::Project;
    use settings::{Settings, SettingsStore};
    use ui::SharedString;

    use super::*;

    /// Built-in tool whose availability is provider dependent, so it is excluded
    /// from the expected tool lists.
    const PROVIDER_DEPENDENT_TOOL: &str = "web_search";

    /// The default profile enables its configured built-in tools plus every
    /// registered MCP tool, including ones no preset switches on.
    #[gpui::test]
    async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId::default();
        let profile_settings = read_profile_settings(cx, &id);
        let enabled_tools = sorted_enabled_tool_names(id, cx);

        let mut expected_tools = enabled_built_in_tool_names(&profile_settings);
        // Plus all registered MCP tools
        expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    /// A profile with no built-in tools enables exactly the MCP tools switched
    /// on in its `mcp` preset.
    #[gpui::test]
    async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("custom_mcp".into());
        let profile_settings = read_profile_settings(cx, &id);
        let enabled_tools = sorted_enabled_tool_names(id, cx);

        let mut expected_tools = enabled_names(&profile_settings.context_servers["mcp"].tools);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    /// A profile with `enable_all_context_servers` off and no presets enables
    /// only its built-in tools.
    #[gpui::test]
    async fn test_only_built_in(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("write_minus_mcp".into());
        let profile_settings = read_profile_settings(cx, &id);
        let enabled_tools = sorted_enabled_tool_names(id, cx);

        let mut expected_tools = enabled_built_in_tool_names(&profile_settings);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    /// Returns a copy of profile `id`'s settings, panicking if it is not configured.
    fn read_profile_settings(
        cx: &mut TestAppContext,
        id: &AgentProfileId,
    ) -> AgentProfileSettings {
        cx.read(|cx| {
            AgentSettings::get_global(cx)
                .profiles
                .get(id)
                .unwrap_or_else(|| panic!("profile {id:?} is not configured"))
                .clone()
        })
    }

    /// Builds an [`AgentProfile`] for `id` over [`default_tool_set`] and returns
    /// the names of the tools it enables, sorted.
    fn sorted_enabled_tool_names(id: AgentProfileId, cx: &mut TestAppContext) -> Vec<String> {
        let profile = AgentProfile::new(id, default_tool_set(cx));
        let mut names = cx
            .read(|cx| profile.enabled_tools(cx))
            .into_iter()
            .map(|tool| tool.name())
            .collect::<Vec<_>>();
        names.sort();
        names
    }

    /// Names of the built-in tools switched on in `settings.tools`, excluding
    /// [`PROVIDER_DEPENDENT_TOOL`].
    fn enabled_built_in_tool_names(settings: &AgentProfileSettings) -> Vec<String> {
        let mut names = enabled_names(&settings.tools);
        names.retain(|name| name != PROVIDER_DEPENDENT_TOOL);
        names
    }

    /// Names of the entries whose flag is `true`, in iteration order.
    fn enabled_names<'a, K>(entries: impl IntoIterator<Item = (&'a K, &'a bool)>) -> Vec<String>
    where
        K: std::fmt::Display + 'a,
    {
        entries
            .into_iter()
            .filter(|(_, enabled)| **enabled)
            .map(|(name, _)| name.to_string())
            .collect()
    }

    /// Installs test settings and tool registries, then adds two profiles:
    /// - `write_minus_mcp`: a copy of the default profile with
    ///   `enable_all_context_servers` turned off.
    /// - `custom_mcp`: no built-in tools, plus only the [`context_server_preset`]
    ///   for the `mcp` server.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            language_model::init_settings(cx);
            ToolRegistry::default_global(cx);
            assistant_tools::init(FakeHttpClient::with_404_response(), cx);
        });

        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.profiles.insert(
                AgentProfileId("write_minus_mcp".into()),
                AgentProfileSettings {
                    name: "write_minus_mcp".into(),
                    enable_all_context_servers: false,
                    ..agent_settings.profiles[&AgentProfileId::default()].clone()
                },
            );
            agent_settings.profiles.insert(
                AgentProfileId("custom_mcp".into()),
                AgentProfileSettings {
                    name: "mcp".into(),
                    tools: IndexMap::default(),
                    enable_all_context_servers: false,
                    context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
                },
            );
            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// Preset that enables `enabled_mcp_tool` and explicitly disables
    /// `disabled_mcp_tool`.
    fn context_server_preset() -> ContextServerPreset {
        ContextServerPreset {
            tools: IndexMap::from_iter([
                ("enabled_mcp_tool".into(), true),
                ("disabled_mcp_tool".into(), false),
            ]),
        }
    }

    /// Tool set with two fake tools from the `mcp` context server inserted.
    ///
    /// NOTE(review): the built-in tool expectations rely on
    /// `ToolWorkingSet::tools` also exposing the tools registered by
    /// `assistant_tools::init`; confirm against `ToolWorkingSet`.
    fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
        cx.new(|_| {
            let mut tool_set = ToolWorkingSet::default();
            tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
            tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
            tool_set
        })
    }

    /// Stub tool that reports a fixed name and a context-server source.
    ///
    /// Every other `Tool` method is `unimplemented!()` and must not be called
    /// by these tests.
    struct FakeTool {
        name: String,
        source: SharedString,
    }

    impl FakeTool {
        fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
            Self {
                name: name.into(),
                source: source.into(),
            }
        }
    }

    impl Tool for FakeTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn source(&self) -> ToolSource {
            ToolSource::ContextServer {
                id: self.source.clone(),
            }
        }

        fn description(&self) -> String {
            unimplemented!()
        }

        fn icon(&self) -> ui::IconName {
            unimplemented!()
        }

        fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
            unimplemented!()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            unimplemented!()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<language_model::LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<assistant_tool::ActionLog>,
            _model: Arc<dyn language_model::LanguageModel>,
            _window: Option<gpui::AnyWindowHandle>,
            _cx: &mut App,
        ) -> assistant_tool::ToolResult {
            unimplemented!()
        }

        fn may_perform_edits(&self) -> bool {
            unimplemented!()
        }
    }
}