1use std::sync::Arc;
2
3use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
4use assistant_tool::{Tool, ToolSource, ToolWorkingSet, UniqueToolName};
5use collections::IndexMap;
6use convert_case::{Case, Casing};
7use fs::Fs;
8use gpui::{App, Entity, SharedString};
9use settings::{Settings, update_settings_file};
10use util::ResultExt;
11
12#[derive(Clone, Debug, Eq, PartialEq)]
13pub struct AgentProfile {
14 id: AgentProfileId,
15 tool_set: Entity<ToolWorkingSet>,
16}
17
18pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
19
20impl AgentProfile {
21 pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
22 Self { id, tool_set }
23 }
24
25 /// Saves a new profile to the settings.
26 pub fn create(
27 name: String,
28 base_profile_id: Option<AgentProfileId>,
29 fs: Arc<dyn Fs>,
30 cx: &App,
31 ) -> AgentProfileId {
32 let id = AgentProfileId(name.to_case(Case::Kebab).into());
33
34 let base_profile =
35 base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
36
37 let profile_settings = AgentProfileSettings {
38 name: name.into(),
39 tools: base_profile
40 .as_ref()
41 .map(|profile| profile.tools.clone())
42 .unwrap_or_default(),
43 enable_all_context_servers: base_profile
44 .as_ref()
45 .map(|profile| profile.enable_all_context_servers)
46 .unwrap_or_default(),
47 context_servers: base_profile
48 .map(|profile| profile.context_servers)
49 .unwrap_or_default(),
50 };
51
52 update_settings_file(fs, cx, {
53 let id = id.clone();
54 move |settings, _cx| {
55 profile_settings.save_to_settings(id, settings).log_err();
56 }
57 });
58
59 id
60 }
61
62 /// Returns a map of AgentProfileIds to their names
63 pub fn available_profiles(cx: &App) -> AvailableProfiles {
64 let mut profiles = AvailableProfiles::default();
65 for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
66 profiles.insert(id.clone(), profile.name.clone());
67 }
68 profiles
69 }
70
71 pub fn id(&self) -> &AgentProfileId {
72 &self.id
73 }
74
75 pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
76 let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
77 return Vec::new();
78 };
79
80 self.tool_set
81 .read(cx)
82 .tools(cx)
83 .into_iter()
84 .filter(|(_, tool)| Self::is_enabled(settings, tool.source(), tool.name()))
85 .collect()
86 }
87
88 pub fn is_tool_enabled(&self, source: ToolSource, tool_name: String, cx: &App) -> bool {
89 let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
90 return false;
91 };
92
93 Self::is_enabled(settings, source, tool_name)
94 }
95
96 fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
97 match source {
98 ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
99 ToolSource::ContextServer { id } => settings
100 .context_servers
101 .get(id.as_ref())
102 .and_then(|preset| preset.tools.get(name.as_str()).copied())
103 .unwrap_or(settings.enable_all_context_servers),
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use agent_settings::ContextServerPreset;
    use assistant_tool::ToolRegistry;
    use collections::IndexMap;
    use gpui::{AppContext, SharedString, TestAppContext};
    use http_client::FakeHttpClient;
    use project::Project;
    use settings::{Settings, SettingsStore};

    use super::*;

    #[gpui::test]
    async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId::default();
        let profile_settings = settings_for_profile(&id, cx);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = enabled_tool_names(&profile, cx);

        let mut expected_tools = enabled_built_in_tool_names(&profile_settings);
        // Plus all registered MCP tools: the default profile is expected to enable every
        // context-server tool, including `disabled_mcp_tool`, which only the `custom_mcp`
        // profile's preset turns off.
        expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    #[gpui::test]
    async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("custom_mcp".into());
        let profile_settings = settings_for_profile(&id, cx);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = enabled_tool_names(&profile, cx);

        // This profile has no built-in tools, so only the tools the "mcp" preset
        // explicitly enables are expected.
        let mut expected_tools = profile_settings.context_servers["mcp"]
            .tools
            .iter()
            .filter_map(|(name, enabled)| enabled.then(|| name.to_string()))
            .collect::<Vec<_>>();
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    #[gpui::test]
    async fn test_only_built_in(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let id = AgentProfileId("write_minus_mcp".into());
        let profile_settings = settings_for_profile(&id, cx);
        let tool_set = default_tool_set(cx);
        let profile = AgentProfile::new(id, tool_set);

        let enabled_tools = enabled_tool_names(&profile, cx);

        // `enable_all_context_servers` is off for this profile, so only built-in tools
        // are expected.
        let mut expected_tools = enabled_built_in_tool_names(&profile_settings);
        expected_tools.sort();

        assert_eq!(enabled_tools, expected_tools);
    }

    /// Returns a clone of the settings registered for profile `id`.
    ///
    /// Panics if no such profile exists.
    fn settings_for_profile(id: &AgentProfileId, cx: &mut TestAppContext) -> AgentProfileSettings {
        cx.read(|cx| {
            AgentSettings::get_global(cx)
                .profiles
                .get(id)
                .cloned()
                .expect("profile should be registered by init_test_settings")
        })
    }

    /// Returns the sorted names of the tools `profile` currently enables.
    fn enabled_tool_names(profile: &AgentProfile, cx: &mut TestAppContext) -> Vec<String> {
        let mut names = cx
            .read(|cx| profile.enabled_tools(cx))
            .into_iter()
            .map(|(_, tool)| tool.name())
            .collect::<Vec<_>>();
        names.sort();
        names
    }

    /// Returns the names of the built-in tools that `settings` enables, excluding
    /// `web_search`, whose availability is provider dependent.
    fn enabled_built_in_tool_names(settings: &AgentProfileSettings) -> Vec<String> {
        settings
            .tools
            .iter()
            .filter_map(|(name, enabled)| enabled.then(|| name.to_string()))
            .filter(|name| name != "web_search")
            .collect()
    }

    /// Installs a test settings store and registers the tools.
    ///
    /// Adds two extra profiles:
    /// - `write_minus_mcp`: a copy of the default profile with `enable_all_context_servers`
    ///   turned off.
    /// - `custom_mcp`: no built-in tools, plus the preset from `context_server_preset` for
    ///   the "mcp" server.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            language_model::init_settings(cx);
            ToolRegistry::default_global(cx);
            assistant_tools::init(FakeHttpClient::with_404_response(), cx);
        });

        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.profiles.insert(
                AgentProfileId("write_minus_mcp".into()),
                AgentProfileSettings {
                    name: "write_minus_mcp".into(),
                    enable_all_context_servers: false,
                    ..agent_settings.profiles[&AgentProfileId::default()].clone()
                },
            );
            agent_settings.profiles.insert(
                AgentProfileId("custom_mcp".into()),
                AgentProfileSettings {
                    name: "mcp".into(),
                    tools: IndexMap::default(),
                    enable_all_context_servers: false,
                    context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
                },
            );
            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// Preset that enables `enabled_mcp_tool` and disables `disabled_mcp_tool`.
    fn context_server_preset() -> ContextServerPreset {
        ContextServerPreset {
            tools: IndexMap::from_iter([
                ("enabled_mcp_tool".into(), true),
                ("disabled_mcp_tool".into(), false),
            ]),
        }
    }

    /// Tool working set containing two fake tools from the "mcp" context server.
    fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
        cx.new(|cx| {
            let mut tool_set = ToolWorkingSet::default();
            tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")), cx);
            tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")), cx);
            tool_set
        })
    }

    /// Minimal context-server tool; only `name` and `source` are implemented.
    struct FakeTool {
        name: String,
        source: SharedString,
    }

    impl FakeTool {
        fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
            Self {
                name: name.into(),
                source: source.into(),
            }
        }
    }

    impl Tool for FakeTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn source(&self) -> ToolSource {
            ToolSource::ContextServer {
                id: self.source.clone(),
            }
        }

        fn description(&self) -> String {
            unimplemented!()
        }

        fn icon(&self) -> icons::IconName {
            unimplemented!()
        }

        fn needs_confirmation(
            &self,
            _input: &serde_json::Value,
            _project: &Entity<Project>,
            _cx: &App,
        ) -> bool {
            unimplemented!()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            unimplemented!()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<language_model::LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<action_log::ActionLog>,
            _model: Arc<dyn language_model::LanguageModel>,
            _window: Option<gpui::AnyWindowHandle>,
            _cx: &mut App,
        ) -> assistant_tool::ToolResult {
            unimplemented!()
        }

        fn may_perform_edits(&self) -> bool {
            unimplemented!()
        }
    }
}