1use std::sync::Arc;
2
3use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
4use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
5use collections::IndexMap;
6use convert_case::{Case, Casing};
7use fs::Fs;
8use gpui::{App, Entity};
9use settings::{Settings, update_settings_file};
10use ui::SharedString;
11use util::ResultExt;
12
13#[derive(Clone, Debug, Eq, PartialEq)]
14pub struct AgentProfile {
15 id: AgentProfileId,
16 tool_set: Entity<ToolWorkingSet>,
17}
18
19pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
20
21impl AgentProfile {
22 pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
23 Self { id, tool_set }
24 }
25
26 /// Saves a new profile to the settings.
27 pub fn create(
28 name: String,
29 base_profile_id: Option<AgentProfileId>,
30 fs: Arc<dyn Fs>,
31 cx: &App,
32 ) -> AgentProfileId {
33 let id = AgentProfileId(name.to_case(Case::Kebab).into());
34
35 let base_profile =
36 base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
37
38 let profile_settings = AgentProfileSettings {
39 name: name.into(),
40 tools: base_profile
41 .as_ref()
42 .map(|profile| profile.tools.clone())
43 .unwrap_or_default(),
44 enable_all_context_servers: base_profile
45 .as_ref()
46 .map(|profile| profile.enable_all_context_servers)
47 .unwrap_or_default(),
48 context_servers: base_profile
49 .map(|profile| profile.context_servers)
50 .unwrap_or_default(),
51 };
52
53 update_settings_file::<AgentSettings>(fs, cx, {
54 let id = id.clone();
55 move |settings, _cx| {
56 settings.create_profile(id, profile_settings).log_err();
57 }
58 });
59
60 id
61 }
62
63 /// Returns a map of AgentProfileIds to their names
64 pub fn available_profiles(cx: &App) -> AvailableProfiles {
65 let mut profiles = AvailableProfiles::default();
66 for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
67 profiles.insert(id.clone(), profile.name.clone());
68 }
69 profiles
70 }
71
72 pub fn id(&self) -> &AgentProfileId {
73 &self.id
74 }
75
76 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
77 let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
78 return Vec::new();
79 };
80
81 self.tool_set
82 .read(cx)
83 .tools(cx)
84 .into_iter()
85 .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
86 .collect()
87 }
88
89 fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
90 match source {
91 ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
92 ToolSource::ContextServer { id } => {
93 if settings.enable_all_context_servers {
94 return true;
95 }
96
97 let Some(preset) = settings.context_servers.get(id.as_ref()) else {
98 return false;
99 };
100 *preset.tools.get(name.as_str()).unwrap_or(&false)
101 }
102 }
103 }
104}
105
#[cfg(test)]
mod tests {
    use agent_settings::ContextServerPreset;
    use assistant_tool::ToolRegistry;
    use collections::IndexMap;
    use gpui::{AppContext, TestAppContext};
    use http_client::FakeHttpClient;
    use project::Project;
    use settings::{Settings, SettingsStore};
    use ui::SharedString;

    use super::*;

    /// The default profile enables its switched-on native tools (excluding the
    /// provider-dependent `web_search`) plus both fake MCP tools.
    #[gpui::test]
    async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let profile_id = AgentProfileId::default();
        let settings = profile_settings(cx, &profile_id);
        let profile = AgentProfile::new(profile_id, default_tool_set(cx));

        let mut expected = enabled_native_tools(&settings);
        // Plus all registered MCP tools
        expected.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]);
        expected.sort();

        assert_eq!(sorted_enabled_tool_names(cx, &profile), expected);
    }

    /// A profile with a per-server preset enables exactly the MCP tools that the
    /// preset switches on.
    #[gpui::test]
    async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let profile_id = AgentProfileId("custom_mcp".into());
        let settings = profile_settings(cx, &profile_id);
        let profile = AgentProfile::new(profile_id, default_tool_set(cx));

        let mut expected: Vec<String> = settings.context_servers["mcp"]
            .tools
            .iter()
            .filter(|(_, enabled)| **enabled)
            .map(|(tool_name, _)| tool_name.to_string())
            .collect();
        expected.sort();

        assert_eq!(sorted_enabled_tool_names(cx, &profile), expected);
    }

    /// With `enable_all_context_servers` off and no presets, only the native
    /// tools switched on in the profile are enabled.
    #[gpui::test]
    async fn test_only_built_in(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let profile_id = AgentProfileId("write_minus_mcp".into());
        let settings = profile_settings(cx, &profile_id);
        let profile = AgentProfile::new(profile_id, default_tool_set(cx));

        let mut expected = enabled_native_tools(&settings);
        expected.sort();

        assert_eq!(sorted_enabled_tool_names(cx, &profile), expected);
    }

    /// Reads a copy of the settings for profile `id`; panics if it does not exist.
    fn profile_settings(cx: &mut TestAppContext, id: &AgentProfileId) -> AgentProfileSettings {
        cx.read(|cx| {
            AgentSettings::get_global(cx)
                .profiles
                .get(id)
                .unwrap()
                .clone()
        })
    }

    /// Names of the tools `profile` currently enables, sorted.
    fn sorted_enabled_tool_names(cx: &mut TestAppContext, profile: &AgentProfile) -> Vec<String> {
        let mut names = cx
            .read(|cx| profile.enabled_tools(cx))
            .into_iter()
            .map(|tool| tool.name())
            .collect::<Vec<_>>();
        names.sort();
        names
    }

    /// Names of the native tools switched on in `settings`, excluding the
    /// provider-dependent `web_search` tool. The result is not sorted.
    fn enabled_native_tools(settings: &AgentProfileSettings) -> Vec<String> {
        settings
            .tools
            .iter()
            .filter(|(_, enabled)| **enabled)
            .map(|(tool_name, _)| tool_name.to_string())
            // Provider dependent
            .filter(|tool_name| tool_name != "web_search")
            .collect()
    }

    /// Installs the test settings store and tool registry, then adds two extra
    /// profiles:
    /// - `write_minus_mcp`: the default profile with `enable_all_context_servers`
    ///   turned off.
    /// - `custom_mcp`: no native tools, with a single `mcp` context-server preset.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            language_model::init_settings(cx);
            ToolRegistry::default_global(cx);
            assistant_tools::init(FakeHttpClient::with_404_response(), cx);
        });

        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            let default_profile = agent_settings.profiles[&AgentProfileId::default()].clone();

            agent_settings.profiles.insert(
                AgentProfileId("write_minus_mcp".into()),
                AgentProfileSettings {
                    name: "write_minus_mcp".into(),
                    enable_all_context_servers: false,
                    ..default_profile
                },
            );
            agent_settings.profiles.insert(
                AgentProfileId("custom_mcp".into()),
                AgentProfileSettings {
                    name: "mcp".into(),
                    tools: IndexMap::default(),
                    enable_all_context_servers: false,
                    context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
                },
            );

            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// Preset for the `mcp` server: `enabled_mcp_tool` on, `disabled_mcp_tool` off.
    fn context_server_preset() -> ContextServerPreset {
        let tools = IndexMap::from_iter([
            ("enabled_mcp_tool".into(), true),
            ("disabled_mcp_tool".into(), false),
        ]);
        ContextServerPreset { tools }
    }

    /// A working set containing two fake tools that both belong to the `mcp` server.
    fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
        cx.new(|_| {
            let mut tool_set = ToolWorkingSet::default();
            for tool_name in ["enabled_mcp_tool", "disabled_mcp_tool"] {
                tool_set.insert(Arc::new(FakeTool::new(tool_name, "mcp")));
            }
            tool_set
        })
    }

    /// A context-server tool stub that only reports its name and source; every
    /// other trait method panics if called.
    struct FakeTool {
        name: String,
        source: SharedString,
    }

    impl FakeTool {
        fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
            let (name, source) = (name.into(), source.into());
            Self { name, source }
        }
    }

    impl Tool for FakeTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn source(&self) -> ToolSource {
            let id = self.source.clone();
            ToolSource::ContextServer { id }
        }

        fn description(&self) -> String {
            unimplemented!()
        }

        fn icon(&self) -> ui::IconName {
            unimplemented!()
        }

        fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
            unimplemented!()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            unimplemented!()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<language_model::LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<assistant_tool::ActionLog>,
            _model: Arc<dyn language_model::LanguageModel>,
            _window: Option<gpui::AnyWindowHandle>,
            _cx: &mut App,
        ) -> assistant_tool::ToolResult {
            unimplemented!()
        }

        fn may_perform_edits(&self) -> bool {
            unimplemented!()
        }
    }
}