1use std::sync::Arc;
2
3use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
4use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
5use collections::IndexMap;
6use convert_case::{Case, Casing};
7use fs::Fs;
8use gpui::{App, Entity, SharedString};
9use settings::{Settings, update_settings_file};
10use util::ResultExt;
11
12#[derive(Clone, Debug, Eq, PartialEq)]
13pub struct AgentProfile {
14 id: AgentProfileId,
15 tool_set: Entity<ToolWorkingSet>,
16}
17
18pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
19
20impl AgentProfile {
21 pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
22 Self { id, tool_set }
23 }
24
25 /// Saves a new profile to the settings.
26 pub fn create(
27 name: String,
28 base_profile_id: Option<AgentProfileId>,
29 fs: Arc<dyn Fs>,
30 cx: &App,
31 ) -> AgentProfileId {
32 let id = AgentProfileId(name.to_case(Case::Kebab).into());
33
34 let base_profile =
35 base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
36
37 let profile_settings = AgentProfileSettings {
38 name: name.into(),
39 tools: base_profile
40 .as_ref()
41 .map(|profile| profile.tools.clone())
42 .unwrap_or_default(),
43 enable_all_context_servers: base_profile
44 .as_ref()
45 .map(|profile| profile.enable_all_context_servers)
46 .unwrap_or_default(),
47 context_servers: base_profile
48 .map(|profile| profile.context_servers)
49 .unwrap_or_default(),
50 };
51
52 update_settings_file::<AgentSettings>(fs, cx, {
53 let id = id.clone();
54 move |settings, _cx| {
55 settings.create_profile(id, profile_settings).log_err();
56 }
57 });
58
59 id
60 }
61
62 /// Returns a map of AgentProfileIds to their names
63 pub fn available_profiles(cx: &App) -> AvailableProfiles {
64 let mut profiles = AvailableProfiles::default();
65 for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
66 profiles.insert(id.clone(), profile.name.clone());
67 }
68 profiles
69 }
70
71 pub fn id(&self) -> &AgentProfileId {
72 &self.id
73 }
74
75 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
76 let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
77 return Vec::new();
78 };
79
80 self.tool_set
81 .read(cx)
82 .tools(cx)
83 .into_iter()
84 .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
85 .collect()
86 }
87
88 pub fn is_tool_enabled(&self, source: ToolSource, tool_name: String, cx: &App) -> bool {
89 let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
90 return false;
91 };
92
93 return Self::is_enabled(settings, source, tool_name);
94 }
95
96 fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
97 match source {
98 ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
99 ToolSource::ContextServer { id } => {
100 if settings.enable_all_context_servers {
101 return true;
102 }
103
104 let Some(preset) = settings.context_servers.get(id.as_ref()) else {
105 return false;
106 };
107 *preset.tools.get(name.as_str()).unwrap_or(&false)
108 }
109 }
110 }
111}
112
#[cfg(test)]
mod tests {
    use agent_settings::ContextServerPreset;
    use assistant_tool::ToolRegistry;
    use collections::IndexMap;
    use gpui::SharedString;
    use gpui::{AppContext, TestAppContext};
    use http_client::FakeHttpClient;
    use project::Project;
    use settings::{Settings, SettingsStore};

    use super::*;

    /// The default profile enables its configured built-in tools. The test also
    /// expects every registered MCP tool, including `disabled_mcp_tool`.
    #[gpui::test]
    async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId::default();
        let profile_settings = settings_for_profile(&id, cx);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = sorted_enabled_tool_names(&profile, cx);

        let mut expected_tools = enabled_native_tool_names(&profile_settings);
        // Plus all registered MCP tools
        expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    /// A profile with no built-in tools and a single `mcp` preset enables only
    /// the MCP tools that preset turns on.
    #[gpui::test]
    async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("custom_mcp".into());
        let profile_settings = settings_for_profile(&id, cx);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = sorted_enabled_tool_names(&profile, cx);

        let mut expected_tools = profile_settings.context_servers["mcp"]
            .tools
            .iter()
            .filter_map(|(key, enabled)| enabled.then(|| key.to_string()))
            .collect::<Vec<_>>();
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    /// A copy of the default profile with `enable_all_context_servers` turned
    /// off enables only its built-in tools.
    #[gpui::test]
    async fn test_only_built_in(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("write_minus_mcp".into());
        let profile_settings = settings_for_profile(&id, cx);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = sorted_enabled_tool_names(&profile, cx);

        let mut expected_tools = enabled_native_tool_names(&profile_settings);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    /// Returns a copy of the settings registered for profile `id`.
    ///
    /// Panics if no such profile exists.
    fn settings_for_profile(
        id: &AgentProfileId,
        cx: &mut TestAppContext,
    ) -> AgentProfileSettings {
        cx.read(|cx| {
            AgentSettings::get_global(cx)
                .profiles
                .get(id)
                .unwrap()
                .clone()
        })
    }

    /// Names of the tools `profile` reports as enabled. The names are sorted so
    /// the comparison does not depend on registration order.
    fn sorted_enabled_tool_names(profile: &AgentProfile, cx: &mut TestAppContext) -> Vec<String> {
        let mut names = cx
            .read(|cx| profile.enabled_tools(cx))
            .into_iter()
            .map(|tool| tool.name())
            .collect::<Vec<_>>();
        names.sort();
        names
    }

    /// Names of the built-in tools that `settings` turns on, in map order
    /// (unsorted). `web_search` is excluded.
    fn enabled_native_tool_names(settings: &AgentProfileSettings) -> Vec<String> {
        settings
            .tools
            .iter()
            .filter_map(|(tool, enabled)| enabled.then(|| tool.to_string()))
            // Provider dependent
            .filter(|tool| tool != "web_search")
            .collect()
    }

    /// Installs test settings and registers the built-in tools. It then adds two
    /// extra profiles:
    /// - `write_minus_mcp`: the default profile with `enable_all_context_servers`
    ///   set to false.
    /// - `custom_mcp`: no built-in tools and a single `mcp` preset.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            language_model::init_settings(cx);
            ToolRegistry::default_global(cx);
            assistant_tools::init(FakeHttpClient::with_404_response(), cx);
        });

        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.profiles.insert(
                AgentProfileId("write_minus_mcp".into()),
                AgentProfileSettings {
                    name: "write_minus_mcp".into(),
                    enable_all_context_servers: false,
                    ..agent_settings.profiles[&AgentProfileId::default()].clone()
                },
            );
            agent_settings.profiles.insert(
                AgentProfileId("custom_mcp".into()),
                AgentProfileSettings {
                    name: "mcp".into(),
                    tools: IndexMap::default(),
                    enable_all_context_servers: false,
                    context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
                },
            );
            AgentSettings::override_global(agent_settings, cx);
        })
    }

    /// Preset that enables `enabled_mcp_tool` and disables `disabled_mcp_tool`.
    fn context_server_preset() -> ContextServerPreset {
        ContextServerPreset {
            tools: IndexMap::from_iter([
                ("enabled_mcp_tool".into(), true),
                ("disabled_mcp_tool".into(), false),
            ]),
        }
    }

    /// Tool working set containing two fake tools from the `mcp` context server.
    fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
        cx.new(|_| {
            let mut tool_set = ToolWorkingSet::default();
            tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
            tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
            tool_set
        })
    }

    /// Minimal context-server tool. Only `name` and `source` are implemented;
    /// the profile filtering under test reads only those two.
    struct FakeTool {
        name: String,
        source: SharedString,
    }

    impl FakeTool {
        fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
            Self {
                name: name.into(),
                source: source.into(),
            }
        }
    }

    impl Tool for FakeTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn source(&self) -> ToolSource {
            ToolSource::ContextServer {
                id: self.source.clone(),
            }
        }

        fn description(&self) -> String {
            unimplemented!()
        }

        fn icon(&self) -> icons::IconName {
            unimplemented!()
        }

        fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
            unimplemented!()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            unimplemented!()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<language_model::LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<assistant_tool::ActionLog>,
            _model: Arc<dyn language_model::LanguageModel>,
            _window: Option<gpui::AnyWindowHandle>,
            _cx: &mut App,
        ) -> assistant_tool::ToolResult {
            unimplemented!()
        }

        fn may_perform_edits(&self) -> bool {
            unimplemented!()
        }
    }
}