1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// Snapshot of a project's worktrees together with the time it was captured.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken (UTC).
    pub timestamp: DateTime<Utc>,
}
64
/// Error produced when a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` text).
    pub message: SharedString,
}
68
/// State kept once per project and shared by every session whose `project_id`
/// points at it.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Worktree and user-rules context, rebuilt by `maintain_project_context`.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `_maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context` for this project.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of the project's context servers (source of MCP prompts and tools).
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store and registry events.
    _subscriptions: Vec<Subscription>,
}
77
78/// Holds both the internal Thread and the AcpThread for a session
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for this session's project state.
    project_id: EntityId,
    /// Most recent save task started by `save_thread`; replaced on every save.
    pending_save: Task<()>,
    /// Subscriptions to the thread's title/token-usage events and notifications.
    _subscriptions: Vec<Subscription>,
}
88
/// Cache of the language models offered to ACP clients.
///
/// Built from every visible, authenticated provider and rebuilt by `refresh_list`.
pub struct LanguageModels {
    /// Models keyed by their ACP id (`<provider id>/<model id>`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers, started in `new`.
    _authenticate_all_providers_task: Task<()>,
}
98
99impl LanguageModels {
100 fn new(cx: &mut App) -> Self {
101 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
102
103 let mut this = Self {
104 models: HashMap::default(),
105 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
106 refresh_models_rx,
107 refresh_models_tx,
108 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
109 };
110 this.refresh_list(cx);
111 this
112 }
113
114 fn refresh_list(&mut self, cx: &App) {
115 let providers = LanguageModelRegistry::global(cx)
116 .read(cx)
117 .visible_providers()
118 .into_iter()
119 .filter(|provider| provider.is_authenticated(cx))
120 .collect::<Vec<_>>();
121
122 let mut language_model_list = IndexMap::default();
123 let mut recommended_models = HashSet::default();
124
125 let mut recommended = Vec::new();
126 for provider in &providers {
127 for model in provider.recommended_models(cx) {
128 recommended_models.insert((model.provider_id(), model.id()));
129 recommended.push(Self::map_language_model_to_info(&model, provider));
130 }
131 }
132 if !recommended.is_empty() {
133 language_model_list.insert(
134 acp_thread::AgentModelGroupName("Recommended".into()),
135 recommended,
136 );
137 }
138
139 let mut models = HashMap::default();
140 for provider in providers {
141 let mut provider_models = Vec::new();
142 for model in provider.provided_models(cx) {
143 let model_info = Self::map_language_model_to_info(&model, &provider);
144 let model_id = model_info.id.clone();
145 provider_models.push(model_info);
146 models.insert(model_id, model);
147 }
148 if !provider_models.is_empty() {
149 language_model_list.insert(
150 acp_thread::AgentModelGroupName(provider.name().0.clone()),
151 provider_models,
152 );
153 }
154 }
155
156 self.models = models;
157 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
158 self.refresh_models_tx.send(()).ok();
159 }
160
161 fn watch(&self) -> watch::Receiver<()> {
162 self.refresh_models_rx.clone()
163 }
164
165 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
166 self.models.get(model_id).cloned()
167 }
168
169 fn map_language_model_to_info(
170 model: &Arc<dyn LanguageModel>,
171 provider: &Arc<dyn LanguageModelProvider>,
172 ) -> acp_thread::AgentModelInfo {
173 acp_thread::AgentModelInfo {
174 id: Self::model_id(model),
175 name: model.name().0,
176 description: None,
177 icon: Some(match provider.icon() {
178 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
179 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
180 }),
181 is_latest: model.is_latest(),
182 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
183 }
184 }
185
186 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
187 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
188 }
189
190 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
191 let authenticate_all_providers = LanguageModelRegistry::global(cx)
192 .read(cx)
193 .visible_providers()
194 .iter()
195 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
196 .collect::<Vec<_>>();
197
198 cx.background_spawn(async move {
199 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
200 if let Err(err) = authenticate_task.await {
201 match err {
202 language_model::AuthenticateError::CredentialsNotFound => {
203 // Since we're authenticating these providers in the
204 // background for the purposes of populating the
205 // language selector, we don't care about providers
206 // where the credentials are not found.
207 }
208 language_model::AuthenticateError::ConnectionRefused => {
209 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
210 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
211 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
212 }
213 _ => {
214 // Some providers have noisy failure states that we
215 // don't want to spam the logs with every time the
216 // language model selector is initialized.
217 //
218 // Ideally these should have more clear failure modes
219 // that we know are safe to ignore here, like what we do
220 // with `CredentialsNotFound` above.
221 match provider_id.0.as_ref() {
222 "lmstudio" | "ollama" => {
223 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
224 //
225 // These fail noisily, so we don't log them.
226 }
227 "copilot_chat" => {
228 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
229 }
230 _ => {
231 log::error!(
232 "Failed to authenticate provider: {}: {err:#}",
233 provider_name.0
234 );
235 }
236 }
237 }
238 }
239 }
240 }
241 })
242 }
243}
244
/// The built-in agent: owns the open sessions, per-project state and the cached
/// language model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store of saved threads; reloaded after each thread is saved.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store whose default prompts are added to project context
    /// as user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language model registry and, when present, the prompt store.
    _subscriptions: Vec<Subscription>,
}
259
260impl NativeAgent {
261 pub fn new(
262 thread_store: Entity<ThreadStore>,
263 templates: Arc<Templates>,
264 prompt_store: Option<Entity<PromptStore>>,
265 fs: Arc<dyn Fs>,
266 cx: &mut App,
267 ) -> Entity<NativeAgent> {
268 log::debug!("Creating new NativeAgent");
269
270 cx.new(|cx| {
271 let mut subscriptions = vec![cx.subscribe(
272 &LanguageModelRegistry::global(cx),
273 Self::handle_models_updated_event,
274 )];
275 if let Some(prompt_store) = prompt_store.as_ref() {
276 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
277 }
278
279 Self {
280 sessions: HashMap::default(),
281 thread_store,
282 projects: HashMap::default(),
283 templates,
284 models: LanguageModels::new(cx),
285 prompt_store,
286 fs,
287 _subscriptions: subscriptions,
288 }
289 })
290 }
291
292 fn new_session(
293 &mut self,
294 project: Entity<Project>,
295 cx: &mut Context<Self>,
296 ) -> Entity<AcpThread> {
297 let project_id = self.get_or_create_project_state(&project, cx);
298 let project_state = &self.projects[&project_id];
299
300 let registry = LanguageModelRegistry::read_global(cx);
301 let available_count = registry.available_models(cx).count();
302 log::debug!("Total available models: {}", available_count);
303
304 let default_model = registry.default_model().and_then(|default_model| {
305 self.models
306 .model_from_id(&LanguageModels::model_id(&default_model.model))
307 });
308 let thread = cx.new(|cx| {
309 Thread::new(
310 project,
311 project_state.project_context.clone(),
312 project_state.context_server_registry.clone(),
313 self.templates.clone(),
314 default_model,
315 cx,
316 )
317 });
318
319 self.register_session(thread, project_id, cx)
320 }
321
    /// Wraps `thread_handle` in a new [`AcpThread`] and records both as a session
    /// belonging to `project_id`.
    ///
    /// The `AcpThread` is seeded from the thread's parent id, title, draft prompt,
    /// UI scroll position and latest token usage. The thread receives the
    /// registry's thread-summary model and the default tools. The agent then
    /// subscribes to the thread's title and token-usage events and saves the thread
    /// whenever it notifies. Finally the project's available commands are pushed to
    /// the new session.
    ///
    /// An existing session with the same id is replaced (`HashMap::insert`).
    /// Returns the newly created `AcpThread`.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything the AcpThread needs out of the thread before creating it.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The tool environment holds only weak handles to the AcpThread, the thread
        // and the agent.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Any notification from the thread triggers a save.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
398
399 pub fn models(&self) -> &LanguageModels {
400 &self.models
401 }
402
403 fn get_or_create_project_state(
404 &mut self,
405 project: &Entity<Project>,
406 cx: &mut Context<Self>,
407 ) -> EntityId {
408 let project_id = project.entity_id();
409 if self.projects.contains_key(&project_id) {
410 return project_id;
411 }
412
413 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
414 self.register_project_with_initial_context(project.clone(), project_context, cx);
415 if let Some(state) = self.projects.get_mut(&project_id) {
416 state.project_context_needs_refresh.send(()).ok();
417 }
418 project_id
419 }
420
    /// Creates and stores the [`ProjectState`] for `project`, using `project_context`
    /// as its initial context.
    ///
    /// This creates a context server registry for the project's context server
    /// store and subscribes to project, store and registry events. It also spawns
    /// `maintain_project_context`, which rebuilds the context whenever a refresh is
    /// requested on the state's watch channel. No refresh is requested here; callers
    /// that need one must send it themselves. Any existing state for the same
    /// project is overwritten.
    fn register_project_with_initial_context(
        &mut self,
        project: Entity<Project>,
        project_context: Entity<ProjectContext>,
        cx: &mut Context<Self>,
    ) {
        let project_id = project.entity_id();

        let context_server_store = project.read(cx).context_server_store();
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

        let subscriptions = vec![
            cx.subscribe(&project, Self::handle_project_event),
            cx.subscribe(
                &context_server_store,
                Self::handle_context_server_store_updated,
            ),
            cx.subscribe(
                &context_server_registry,
                Self::handle_context_server_registry_event,
            ),
        ];

        // The sender is stored in the state; the receiver drives the maintenance task.
        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
            watch::channel(());

        self.projects.insert(
            project_id,
            ProjectState {
                project,
                project_context,
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(
                        this,
                        project_id,
                        project_context_needs_refresh_rx,
                        cx,
                    )
                    .await
                }),
                context_server_registry,
                _subscriptions: subscriptions,
            },
        );
    }
468
469 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
470 self.sessions
471 .get(session_id)
472 .and_then(|session| self.projects.get(&session.project_id))
473 }
474
    /// Rebuilds the context of project `project_id` every time `needs_refresh`
    /// reports a change.
    ///
    /// Each iteration builds a fresh [`ProjectContext`] via `build_project_context`
    /// and writes it into the project's existing `project_context` entity. If the
    /// project state has been removed by the time the build finishes, the result is
    /// discarded.
    ///
    /// Returns `Ok(())` once `needs_refresh.changed()` fails (presumably when the
    /// sender is dropped — confirm `watch` semantics). Returns an error if the agent
    /// has been released or the project state is missing when a build starts.
    async fn maintain_project_context(
        this: WeakEntity<Self>,
        project_id: EntityId,
        mut needs_refresh: watch::Receiver<()>,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        while needs_refresh.changed().await.is_ok() {
            let project_context = this
                .update(cx, |this, cx| {
                    let state = this
                        .projects
                        .get(&project_id)
                        .context("project state not found")?;
                    anyhow::Ok(Self::build_project_context(
                        &state.project,
                        this.prompt_store.as_ref(),
                        cx,
                    ))
                })??
                .await;
            this.update(cx, |this, cx| {
                if let Some(state) = this.projects.get(&project_id) {
                    // Replace the contents in place so existing holders of the
                    // entity observe the new context.
                    state
                        .project_context
                        .update(cx, |current_project_context, _cx| {
                            *current_project_context = project_context;
                        });
                }
            })?;
        }

        Ok(())
    }
508
509 fn build_project_context(
510 project: &Entity<Project>,
511 prompt_store: Option<&Entity<PromptStore>>,
512 cx: &mut App,
513 ) -> Task<ProjectContext> {
514 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
515 let worktree_tasks = worktrees
516 .into_iter()
517 .map(|worktree| {
518 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
519 })
520 .collect::<Vec<_>>();
521 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
522 prompt_store.read_with(cx, |prompt_store, cx| {
523 let prompts = prompt_store.default_prompt_metadata();
524 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
525 let contents = prompt_store.load(prompt_metadata.id, cx);
526 async move { (contents.await, prompt_metadata) }
527 });
528 cx.background_spawn(future::join_all(load_tasks))
529 })
530 } else {
531 Task::ready(vec![])
532 };
533
534 cx.spawn(async move |_cx| {
535 let (worktrees, default_user_rules) =
536 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
537
538 let worktrees = worktrees
539 .into_iter()
540 .map(|(worktree, _rules_error)| {
541 // TODO: show error message
542 // if let Some(rules_error) = rules_error {
543 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
544 // }
545 worktree
546 })
547 .collect::<Vec<_>>();
548
549 let default_user_rules = default_user_rules
550 .into_iter()
551 .flat_map(|(contents, prompt_metadata)| match contents {
552 Ok(contents) => Some(UserRulesContext {
553 uuid: prompt_metadata.id.as_user()?,
554 title: prompt_metadata.title.map(|title| title.to_string()),
555 contents,
556 }),
557 Err(_err) => {
558 // TODO: show error message
559 // this.update(cx, |_, cx| {
560 // cx.emit(RulesLoadingError {
561 // message: format!("{err:?}").into(),
562 // });
563 // })
564 // .ok();
565 None
566 }
567 })
568 .collect::<Vec<_>>();
569
570 ProjectContext::new(worktrees, default_user_rules)
571 })
572 }
573
574 fn load_worktree_info_for_system_prompt(
575 worktree: Entity<Worktree>,
576 project: Entity<Project>,
577 cx: &mut App,
578 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
579 let tree = worktree.read(cx);
580 let root_name = tree.root_name_str().into();
581 let abs_path = tree.abs_path();
582
583 let mut context = WorktreeContext {
584 root_name,
585 abs_path,
586 rules_file: None,
587 };
588
589 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
590 let Some(rules_task) = rules_task else {
591 return Task::ready((context, None));
592 };
593
594 cx.spawn(async move |_| {
595 let (rules_file, rules_file_error) = match rules_task.await {
596 Ok(rules_file) => (Some(rules_file), None),
597 Err(err) => (
598 None,
599 Some(RulesLoadingError {
600 message: format!("{err}").into(),
601 }),
602 ),
603 };
604 context.rules_file = rules_file;
605 (context, rules_file_error)
606 })
607 }
608
609 fn load_worktree_rules_file(
610 worktree: Entity<Worktree>,
611 project: Entity<Project>,
612 cx: &mut App,
613 ) -> Option<Task<Result<RulesFileContext>>> {
614 let worktree = worktree.read(cx);
615 let worktree_id = worktree.id();
616 let selected_rules_file = RULES_FILE_NAMES
617 .into_iter()
618 .filter_map(|name| {
619 worktree
620 .entry_for_path(RelPath::unix(name).unwrap())
621 .filter(|entry| entry.is_file())
622 .map(|entry| entry.path.clone())
623 })
624 .next();
625
626 // Note that Cline supports `.clinerules` being a directory, but that is not currently
627 // supported. This doesn't seem to occur often in GitHub repositories.
628 selected_rules_file.map(|path_in_worktree| {
629 let project_path = ProjectPath {
630 worktree_id,
631 path: path_in_worktree.clone(),
632 };
633 let buffer_task =
634 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
635 let rope_task = cx.spawn(async move |cx| {
636 let buffer = buffer_task.await?;
637 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
638 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
639 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
640 })?;
641 anyhow::Ok((project_entry_id, rope))
642 });
643 // Build a string from the rope on a background thread.
644 cx.background_spawn(async move {
645 let (project_entry_id, rope) = rope_task.await?;
646 anyhow::Ok(RulesFileContext {
647 path_in_worktree,
648 text: rope.to_string().trim().to_string(),
649 project_entry_id: project_entry_id.to_usize(),
650 })
651 })
652 })
653 }
654
655 fn handle_thread_title_updated(
656 &mut self,
657 thread: Entity<Thread>,
658 _: &TitleUpdated,
659 cx: &mut Context<Self>,
660 ) {
661 let session_id = thread.read(cx).id();
662 let Some(session) = self.sessions.get(session_id) else {
663 return;
664 };
665 let thread = thread.downgrade();
666 let acp_thread = session.acp_thread.downgrade();
667 cx.spawn(async move |_, cx| {
668 let title = thread.read_with(cx, |thread, _| thread.title())?;
669 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
670 task.await
671 })
672 .detach_and_log_err(cx);
673 }
674
675 fn handle_thread_token_usage_updated(
676 &mut self,
677 thread: Entity<Thread>,
678 usage: &TokenUsageUpdated,
679 cx: &mut Context<Self>,
680 ) {
681 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
682 return;
683 };
684 session.acp_thread.update(cx, |acp_thread, cx| {
685 acp_thread.update_token_usage(usage.0.clone(), cx);
686 });
687 }
688
689 fn handle_project_event(
690 &mut self,
691 project: Entity<Project>,
692 event: &project::Event,
693 _cx: &mut Context<Self>,
694 ) {
695 let project_id = project.entity_id();
696 let Some(state) = self.projects.get_mut(&project_id) else {
697 return;
698 };
699 match event {
700 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
701 state.project_context_needs_refresh.send(()).ok();
702 }
703 project::Event::WorktreeUpdatedEntries(_, items) => {
704 if items.iter().any(|(path, _, _)| {
705 RULES_FILE_NAMES
706 .iter()
707 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
708 }) {
709 state.project_context_needs_refresh.send(()).ok();
710 }
711 }
712 _ => {}
713 }
714 }
715
716 fn handle_prompts_updated_event(
717 &mut self,
718 _prompt_store: Entity<PromptStore>,
719 _event: &prompt_store::PromptsUpdatedEvent,
720 _cx: &mut Context<Self>,
721 ) {
722 for state in self.projects.values_mut() {
723 state.project_context_needs_refresh.send(()).ok();
724 }
725 }
726
727 fn handle_models_updated_event(
728 &mut self,
729 _registry: Entity<LanguageModelRegistry>,
730 _event: &language_model::Event,
731 cx: &mut Context<Self>,
732 ) {
733 self.models.refresh_list(cx);
734
735 let registry = LanguageModelRegistry::read_global(cx);
736 let default_model = registry.default_model().map(|m| m.model);
737 let summarization_model = registry.thread_summary_model().map(|m| m.model);
738
739 for session in self.sessions.values_mut() {
740 session.thread.update(cx, |thread, cx| {
741 if thread.model().is_none()
742 && let Some(model) = default_model.clone()
743 {
744 thread.set_model(model, cx);
745 cx.notify();
746 }
747 thread.set_summarization_model(summarization_model.clone(), cx);
748 });
749 }
750 }
751
752 fn handle_context_server_store_updated(
753 &mut self,
754 store: Entity<project::context_server_store::ContextServerStore>,
755 _event: &project::context_server_store::ServerStatusChangedEvent,
756 cx: &mut Context<Self>,
757 ) {
758 let project_id = self.projects.iter().find_map(|(id, state)| {
759 if *state.context_server_registry.read(cx).server_store() == store {
760 Some(*id)
761 } else {
762 None
763 }
764 });
765 if let Some(project_id) = project_id {
766 self.update_available_commands_for_project(project_id, cx);
767 }
768 }
769
770 fn handle_context_server_registry_event(
771 &mut self,
772 registry: Entity<ContextServerRegistry>,
773 event: &ContextServerRegistryEvent,
774 cx: &mut Context<Self>,
775 ) {
776 match event {
777 ContextServerRegistryEvent::ToolsChanged => {}
778 ContextServerRegistryEvent::PromptsChanged => {
779 let project_id = self.projects.iter().find_map(|(id, state)| {
780 if state.context_server_registry == registry {
781 Some(*id)
782 } else {
783 None
784 }
785 });
786 if let Some(project_id) = project_id {
787 self.update_available_commands_for_project(project_id, cx);
788 }
789 }
790 }
791 }
792
793 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
794 let available_commands =
795 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
796 for session in self.sessions.values() {
797 if session.project_id != project_id {
798 continue;
799 }
800 session.acp_thread.update(cx, |thread, cx| {
801 thread
802 .handle_session_update(
803 acp::SessionUpdate::AvailableCommandsUpdate(
804 acp::AvailableCommandsUpdate::new(available_commands.clone()),
805 ),
806 cx,
807 )
808 .log_err();
809 });
810 }
811 }
812
813 fn build_available_commands_for_project(
814 project_state: Option<&ProjectState>,
815 cx: &App,
816 ) -> Vec<acp::AvailableCommand> {
817 let Some(state) = project_state else {
818 return vec![];
819 };
820 let registry = state.context_server_registry.read(cx);
821
822 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
823 for context_server_prompt in registry.prompts() {
824 *prompt_name_counts
825 .entry(context_server_prompt.prompt.name.as_str())
826 .or_insert(0) += 1;
827 }
828
829 registry
830 .prompts()
831 .flat_map(|context_server_prompt| {
832 let prompt = &context_server_prompt.prompt;
833
834 let should_prefix = prompt_name_counts
835 .get(prompt.name.as_str())
836 .copied()
837 .unwrap_or(0)
838 > 1;
839
840 let name = if should_prefix {
841 format!("{}.{}", context_server_prompt.server_id, prompt.name)
842 } else {
843 prompt.name.clone()
844 };
845
846 let mut command = acp::AvailableCommand::new(
847 name,
848 prompt.description.clone().unwrap_or_default(),
849 );
850
851 match prompt.arguments.as_deref() {
852 Some([arg]) => {
853 let hint = format!("<{}>", arg.name);
854
855 command = command.input(acp::AvailableCommandInput::Unstructured(
856 acp::UnstructuredCommandInput::new(hint),
857 ));
858 }
859 Some([]) | None => {}
860 Some(_) => {
861 // skip >1 argument commands since we don't support them yet
862 return None;
863 }
864 }
865
866 Some(command)
867 })
868 .collect()
869 }
870
    /// Loads the persisted thread `id` from the thread database and builds a
    /// [`Thread`] entity for it in `project`, applying the registry's
    /// thread-summary model.
    ///
    /// This does not register a session; `open_thread` does that.
    ///
    /// # Errors
    ///
    /// Fails if the database cannot be opened, the load fails, no thread with `id`
    /// exists, the project state is missing, or the agent has been released.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let project_id = this.get_or_create_project_state(&project, cx);
                let project_state = this
                    .projects
                    .get(&project_id)
                    .context("project state not found")?;
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                Ok(cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        project_state.project.clone(),
                        project_state.project_context.clone(),
                        project_state.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                }))
            })?
        })
    }
911
912 pub fn open_thread(
913 &mut self,
914 id: acp::SessionId,
915 project: Entity<Project>,
916 cx: &mut Context<Self>,
917 ) -> Task<Result<Entity<AcpThread>>> {
918 if let Some(session) = self.sessions.get(&id) {
919 return Task::ready(Ok(session.acp_thread.clone()));
920 }
921
922 let task = self.load_thread(id, project.clone(), cx);
923 cx.spawn(async move |this, cx| {
924 let thread = task.await?;
925 let acp_thread = this.update(cx, |this, cx| {
926 let project_id = this.get_or_create_project_state(&project, cx);
927 this.register_session(thread.clone(), project_id, cx)
928 })?;
929 let events = thread.update(cx, |thread, cx| thread.replay(cx));
930 cx.update(|cx| {
931 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
932 })
933 .await?;
934 Ok(acp_thread)
935 })
936 }
937
938 pub fn thread_summary(
939 &mut self,
940 id: acp::SessionId,
941 project: Entity<Project>,
942 cx: &mut Context<Self>,
943 ) -> Task<Result<SharedString>> {
944 let thread = self.open_thread(id.clone(), project, cx);
945 cx.spawn(async move |this, cx| {
946 let acp_thread = thread.await?;
947 let result = this
948 .update(cx, |this, cx| {
949 this.sessions
950 .get(&id)
951 .unwrap()
952 .thread
953 .update(cx, |thread, cx| thread.summary(cx))
954 })?
955 .await
956 .context("Failed to generate summary")?;
957 drop(acp_thread);
958 Ok(result)
959 })
960 }
961
    /// Persists `thread` to the thread database in the background.
    ///
    /// Does nothing if the thread is empty, has no registered session, or its
    /// session's project state is missing. Before serializing, the session's current
    /// draft prompt (read from its `AcpThread`) is copied onto the thread. The
    /// absolute paths of the project's visible worktrees are saved alongside it.
    ///
    /// The save runs as `session.pending_save`, replacing any previous save task.
    /// Database errors are logged, and the thread store is reloaded afterwards even
    /// if the save itself failed.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let project_id = session.project_id;
        let Some(state) = self.projects.get(&project_id) else {
            return;
        };

        let folder_paths = PathList::new(
            &state
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return;
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            thread_store.update(cx, |store, cx| store.reload(cx));
        });
    }
1005
    /// Runs MCP prompt `prompt_name` from context server `server_id` (with
    /// `arguments`) as a turn in session `session_id`.
    ///
    /// Steps:
    /// - Push the user's original content, minus its first block, to the thread as
    ///   user message `message_id`.
    /// - Append each message returned by the prompt to both the `AcpThread` and the
    ///   `Thread`, according to its role.
    /// - If the prompt's last message is from the user (or it returned no
    ///   messages), send the thread; otherwise resume it.
    /// - Forward the resulting events to the session's `AcpThread`.
    ///
    /// # Errors
    ///
    /// Fails if the session's project state or the session is missing, fetching the
    /// prompt fails, the agent has been released, or starting the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Tracks whether the most recent prompt message came from the user; stays
            // `true` when the prompt returns no messages.
            let mut last_is_user = true;

            // The first content block is skipped. NOTE(review): presumably it is the
            // slash-command invocation itself — confirm with the caller.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own id, shared by
                        // the AcpThread and the Thread.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1108}
1109
1110/// Wrapper struct that implements the AgentConnection trait
1111#[derive(Clone)]
1112pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1113
1114impl NativeAgentConnection {
1115 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1116 self.0
1117 .read(cx)
1118 .sessions
1119 .get(session_id)
1120 .map(|session| session.thread.clone())
1121 }
1122
1123 pub fn load_thread(
1124 &self,
1125 id: acp::SessionId,
1126 project: Entity<Project>,
1127 cx: &mut App,
1128 ) -> Task<Result<Entity<Thread>>> {
1129 self.0
1130 .update(cx, |this, cx| this.load_thread(id, project, cx))
1131 }
1132
1133 fn run_turn(
1134 &self,
1135 session_id: acp::SessionId,
1136 cx: &mut App,
1137 f: impl 'static
1138 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1139 ) -> Task<Result<acp::PromptResponse>> {
1140 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1141 agent
1142 .sessions
1143 .get_mut(&session_id)
1144 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1145 }) else {
1146 return Task::ready(Err(anyhow!("Session not found")));
1147 };
1148 log::debug!("Found session for: {}", session_id);
1149
1150 let response_stream = match f(thread, cx) {
1151 Ok(stream) => stream,
1152 Err(err) => return Task::ready(Err(err)),
1153 };
1154 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1155 }
1156
1157 fn handle_thread_events(
1158 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1159 acp_thread: WeakEntity<AcpThread>,
1160 cx: &App,
1161 ) -> Task<Result<acp::PromptResponse>> {
1162 cx.spawn(async move |cx| {
1163 // Handle response stream and forward to session.acp_thread
1164 while let Some(result) = events.next().await {
1165 match result {
1166 Ok(event) => {
1167 log::trace!("Received completion event: {:?}", event);
1168
1169 match event {
1170 ThreadEvent::UserMessage(message) => {
1171 acp_thread.update(cx, |thread, cx| {
1172 for content in message.content {
1173 thread.push_user_content_block(
1174 Some(message.id.clone()),
1175 content.into(),
1176 cx,
1177 );
1178 }
1179 })?;
1180 }
1181 ThreadEvent::AgentText(text) => {
1182 acp_thread.update(cx, |thread, cx| {
1183 thread.push_assistant_content_block(text.into(), false, cx)
1184 })?;
1185 }
1186 ThreadEvent::AgentThinking(text) => {
1187 acp_thread.update(cx, |thread, cx| {
1188 thread.push_assistant_content_block(text.into(), true, cx)
1189 })?;
1190 }
1191 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1192 tool_call,
1193 options,
1194 response,
1195 context: _,
1196 }) => {
1197 let outcome_task = acp_thread.update(cx, |thread, cx| {
1198 thread.request_tool_call_authorization(tool_call, options, cx)
1199 })??;
1200 cx.background_spawn(async move {
1201 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1202 outcome_task.await
1203 {
1204 response
1205 .send(outcome)
1206 .map(|_| anyhow!("authorization receiver was dropped"))
1207 .log_err();
1208 }
1209 })
1210 .detach();
1211 }
1212 ThreadEvent::ToolCall(tool_call) => {
1213 acp_thread.update(cx, |thread, cx| {
1214 thread.upsert_tool_call(tool_call, cx)
1215 })??;
1216 }
1217 ThreadEvent::ToolCallUpdate(update) => {
1218 acp_thread.update(cx, |thread, cx| {
1219 thread.update_tool_call(update, cx)
1220 })??;
1221 }
1222 ThreadEvent::Plan(plan) => {
1223 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1224 }
1225 ThreadEvent::SubagentSpawned(session_id) => {
1226 acp_thread.update(cx, |thread, cx| {
1227 thread.subagent_spawned(session_id, cx);
1228 })?;
1229 }
1230 ThreadEvent::Retry(status) => {
1231 acp_thread.update(cx, |thread, cx| {
1232 thread.update_retry_status(status, cx)
1233 })?;
1234 }
1235 ThreadEvent::Stop(stop_reason) => {
1236 log::debug!("Assistant message complete: {:?}", stop_reason);
1237 return Ok(acp::PromptResponse::new(stop_reason));
1238 }
1239 }
1240 }
1241 Err(e) => {
1242 log::error!("Error in model response stream: {:?}", e);
1243 return Err(e);
1244 }
1245 }
1246 }
1247
1248 log::debug!("Response stream completed");
1249 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1250 })
1251 }
1252}
1253
1254struct Command<'a> {
1255 prompt_name: &'a str,
1256 arg_value: &'a str,
1257 explicit_server_id: Option<&'a str>,
1258}
1259
1260impl<'a> Command<'a> {
1261 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1262 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1263 return None;
1264 };
1265 let text = text_content.text.trim();
1266 let command = text.strip_prefix('/')?;
1267 let (command, arg_value) = command
1268 .split_once(char::is_whitespace)
1269 .unwrap_or((command, ""));
1270
1271 if let Some((server_id, prompt_name)) = command.split_once('.') {
1272 Some(Self {
1273 prompt_name,
1274 arg_value,
1275 explicit_server_id: Some(server_id),
1276 })
1277 } else {
1278 Some(Self {
1279 prompt_name: command,
1280 arg_value,
1281 explicit_server_id: None,
1282 })
1283 }
1284 }
1285}
1286
1287struct NativeAgentModelSelector {
1288 session_id: acp::SessionId,
1289 connection: NativeAgentConnection,
1290}
1291
1292impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1293 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1294 log::debug!("NativeAgentConnection::list_models called");
1295 let list = self.connection.0.read(cx).models.model_list.clone();
1296 Task::ready(if list.is_empty() {
1297 Err(anyhow::anyhow!("No models available"))
1298 } else {
1299 Ok(list)
1300 })
1301 }
1302
1303 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1304 log::debug!(
1305 "Setting model for session {}: {}",
1306 self.session_id,
1307 model_id
1308 );
1309 let Some(thread) = self
1310 .connection
1311 .0
1312 .read(cx)
1313 .sessions
1314 .get(&self.session_id)
1315 .map(|session| session.thread.clone())
1316 else {
1317 return Task::ready(Err(anyhow!("Session not found")));
1318 };
1319
1320 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1321 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1322 };
1323
1324 // We want to reset the effort level when switching models, as the currently-selected effort level may
1325 // not be compatible.
1326 let effort = model
1327 .default_effort_level()
1328 .map(|effort_level| effort_level.value.to_string());
1329
1330 thread.update(cx, |thread, cx| {
1331 thread.set_model(model.clone(), cx);
1332 thread.set_thinking_effort(effort.clone(), cx);
1333 thread.set_thinking_enabled(model.supports_thinking(), cx);
1334 });
1335
1336 update_settings_file(
1337 self.connection.0.read(cx).fs.clone(),
1338 cx,
1339 move |settings, cx| {
1340 let provider = model.provider_id().0.to_string();
1341 let model = model.id().0.to_string();
1342 let enable_thinking = thread.read(cx).thinking_enabled();
1343 settings
1344 .agent
1345 .get_or_insert_default()
1346 .set_model(LanguageModelSelection {
1347 provider: provider.into(),
1348 model,
1349 enable_thinking,
1350 effort,
1351 });
1352 },
1353 );
1354
1355 Task::ready(Ok(()))
1356 }
1357
1358 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1359 let Some(thread) = self
1360 .connection
1361 .0
1362 .read(cx)
1363 .sessions
1364 .get(&self.session_id)
1365 .map(|session| session.thread.clone())
1366 else {
1367 return Task::ready(Err(anyhow!("Session not found")));
1368 };
1369 let Some(model) = thread.read(cx).model() else {
1370 return Task::ready(Err(anyhow!("Model not found")));
1371 };
1372 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1373 else {
1374 return Task::ready(Err(anyhow!("Provider not found")));
1375 };
1376 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1377 model, &provider,
1378 )))
1379 }
1380
1381 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1382 Some(self.connection.0.read(cx).models.watch())
1383 }
1384
1385 fn should_render_footer(&self) -> bool {
1386 true
1387 }
1388}
1389
/// Agent identifier returned by [`NativeAgentConnection`] from `AgentConnection::agent_id`.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1391
1392impl acp_thread::AgentConnection for NativeAgentConnection {
1393 fn agent_id(&self) -> AgentId {
1394 ZED_AGENT_ID.clone()
1395 }
1396
1397 fn telemetry_id(&self) -> SharedString {
1398 "zed".into()
1399 }
1400
1401 fn new_session(
1402 self: Rc<Self>,
1403 project: Entity<Project>,
1404 work_dirs: PathList,
1405 cx: &mut App,
1406 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1407 log::debug!("Creating new thread for project at: {work_dirs:?}");
1408 Task::ready(Ok(self
1409 .0
1410 .update(cx, |agent, cx| agent.new_session(project, cx))))
1411 }
1412
1413 fn supports_load_session(&self) -> bool {
1414 true
1415 }
1416
1417 fn load_session(
1418 self: Rc<Self>,
1419 session_id: acp::SessionId,
1420 project: Entity<Project>,
1421 _work_dirs: PathList,
1422 _title: Option<SharedString>,
1423 cx: &mut App,
1424 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1425 self.0
1426 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1427 }
1428
1429 fn supports_close_session(&self) -> bool {
1430 true
1431 }
1432
1433 fn close_session(
1434 self: Rc<Self>,
1435 session_id: &acp::SessionId,
1436 cx: &mut App,
1437 ) -> Task<Result<()>> {
1438 self.0.update(cx, |agent, cx| {
1439 let Some(session) = agent.sessions.remove(session_id) else {
1440 return;
1441 };
1442 let project_id = session.project_id;
1443 agent.save_thread(session.thread, cx);
1444
1445 let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1446 if !has_remaining {
1447 agent.projects.remove(&project_id);
1448 }
1449 });
1450 Task::ready(Ok(()))
1451 }
1452
1453 fn auth_methods(&self) -> &[acp::AuthMethod] {
1454 &[] // No auth for in-process
1455 }
1456
1457 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1458 Task::ready(Ok(()))
1459 }
1460
1461 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1462 Some(Rc::new(NativeAgentModelSelector {
1463 session_id: session_id.clone(),
1464 connection: self.clone(),
1465 }) as Rc<dyn AgentModelSelector>)
1466 }
1467
1468 fn prompt(
1469 &self,
1470 id: Option<acp_thread::UserMessageId>,
1471 params: acp::PromptRequest,
1472 cx: &mut App,
1473 ) -> Task<Result<acp::PromptResponse>> {
1474 let id = id.expect("UserMessageId is required");
1475 let session_id = params.session_id.clone();
1476 log::info!("Received prompt request for session: {}", session_id);
1477 log::debug!("Prompt blocks count: {}", params.prompt.len());
1478
1479 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1480 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1481 };
1482
1483 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1484 let registry = project_state.context_server_registry.read(cx);
1485
1486 let explicit_server_id = parsed_command
1487 .explicit_server_id
1488 .map(|server_id| ContextServerId(server_id.into()));
1489
1490 if let Some(prompt) =
1491 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1492 {
1493 let arguments = if !parsed_command.arg_value.is_empty()
1494 && let Some(arg_name) = prompt
1495 .prompt
1496 .arguments
1497 .as_ref()
1498 .and_then(|args| args.first())
1499 .map(|arg| arg.name.clone())
1500 {
1501 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1502 } else {
1503 Default::default()
1504 };
1505
1506 let prompt_name = prompt.prompt.name.clone();
1507 let server_id = prompt.server_id.clone();
1508
1509 return self.0.update(cx, |agent, cx| {
1510 agent.send_mcp_prompt(
1511 id,
1512 session_id.clone(),
1513 prompt_name,
1514 server_id,
1515 arguments,
1516 params.prompt,
1517 cx,
1518 )
1519 });
1520 }
1521 };
1522
1523 let path_style = project_state.project.read(cx).path_style(cx);
1524
1525 self.run_turn(session_id, cx, move |thread, cx| {
1526 let content: Vec<UserMessageContent> = params
1527 .prompt
1528 .into_iter()
1529 .map(|block| UserMessageContent::from_content_block(block, path_style))
1530 .collect::<Vec<_>>();
1531 log::debug!("Converted prompt to message: {} chars", content.len());
1532 log::debug!("Message id: {:?}", id);
1533 log::debug!("Message content: {:?}", content);
1534
1535 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1536 })
1537 }
1538
1539 fn retry(
1540 &self,
1541 session_id: &acp::SessionId,
1542 _cx: &App,
1543 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1544 Some(Rc::new(NativeAgentSessionRetry {
1545 connection: self.clone(),
1546 session_id: session_id.clone(),
1547 }) as _)
1548 }
1549
1550 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1551 log::info!("Cancelling on session: {}", session_id);
1552 self.0.update(cx, |agent, cx| {
1553 if let Some(session) = agent.sessions.get(session_id) {
1554 session
1555 .thread
1556 .update(cx, |thread, cx| thread.cancel(cx))
1557 .detach();
1558 }
1559 });
1560 }
1561
1562 fn truncate(
1563 &self,
1564 session_id: &acp::SessionId,
1565 cx: &App,
1566 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1567 self.0.read_with(cx, |agent, _cx| {
1568 agent.sessions.get(session_id).map(|session| {
1569 Rc::new(NativeAgentSessionTruncate {
1570 thread: session.thread.clone(),
1571 acp_thread: session.acp_thread.downgrade(),
1572 }) as _
1573 })
1574 })
1575 }
1576
1577 fn set_title(
1578 &self,
1579 session_id: &acp::SessionId,
1580 cx: &App,
1581 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1582 self.0.read_with(cx, |agent, _cx| {
1583 agent
1584 .sessions
1585 .get(session_id)
1586 .filter(|s| !s.thread.read(cx).is_subagent())
1587 .map(|session| {
1588 Rc::new(NativeAgentSessionSetTitle {
1589 thread: session.thread.clone(),
1590 }) as _
1591 })
1592 })
1593 }
1594
1595 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1596 let thread_store = self.0.read(cx).thread_store.clone();
1597 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1598 }
1599
1600 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1601 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1602 }
1603
1604 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1605 self
1606 }
1607}
1608
1609impl acp_thread::AgentTelemetry for NativeAgentConnection {
1610 fn thread_data(
1611 &self,
1612 session_id: &acp::SessionId,
1613 cx: &mut App,
1614 ) -> Task<Result<serde_json::Value>> {
1615 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1616 return Task::ready(Err(anyhow!("Session not found")));
1617 };
1618
1619 let task = session.thread.read(cx).to_db(cx);
1620 cx.background_spawn(async move {
1621 serde_json::to_value(task.await).context("Failed to serialize thread")
1622 })
1623 }
1624}
1625
1626pub struct NativeAgentSessionList {
1627 thread_store: Entity<ThreadStore>,
1628 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1629 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1630 _subscription: Subscription,
1631}
1632
1633impl NativeAgentSessionList {
1634 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1635 let (tx, rx) = smol::channel::unbounded();
1636 let this_tx = tx.clone();
1637 let subscription = cx.observe(&thread_store, move |_, _| {
1638 this_tx
1639 .try_send(acp_thread::SessionListUpdate::Refresh)
1640 .ok();
1641 });
1642 Self {
1643 thread_store,
1644 updates_tx: tx,
1645 updates_rx: rx,
1646 _subscription: subscription,
1647 }
1648 }
1649
1650 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1651 &self.thread_store
1652 }
1653}
1654
1655impl AgentSessionList for NativeAgentSessionList {
1656 fn list_sessions(
1657 &self,
1658 _request: AgentSessionListRequest,
1659 cx: &mut App,
1660 ) -> Task<Result<AgentSessionListResponse>> {
1661 let sessions = self
1662 .thread_store
1663 .read(cx)
1664 .entries()
1665 .map(|entry| AgentSessionInfo::from(&entry))
1666 .collect();
1667 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1668 }
1669
1670 fn supports_delete(&self) -> bool {
1671 true
1672 }
1673
1674 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1675 self.thread_store
1676 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1677 }
1678
1679 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1680 self.thread_store
1681 .update(cx, |store, cx| store.delete_threads(cx))
1682 }
1683
1684 fn watch(
1685 &self,
1686 _cx: &mut App,
1687 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1688 Some(self.updates_rx.clone())
1689 }
1690
1691 fn notify_refresh(&self) {
1692 self.updates_tx
1693 .try_send(acp_thread::SessionListUpdate::Refresh)
1694 .ok();
1695 }
1696
1697 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1698 self
1699 }
1700}
1701
1702struct NativeAgentSessionTruncate {
1703 thread: Entity<Thread>,
1704 acp_thread: WeakEntity<AcpThread>,
1705}
1706
1707impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1708 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1709 match self.thread.update(cx, |thread, cx| {
1710 thread.truncate(message_id.clone(), cx)?;
1711 Ok(thread.latest_token_usage())
1712 }) {
1713 Ok(usage) => {
1714 self.acp_thread
1715 .update(cx, |thread, cx| {
1716 thread.update_token_usage(usage, cx);
1717 })
1718 .ok();
1719 Task::ready(Ok(()))
1720 }
1721 Err(error) => Task::ready(Err(error)),
1722 }
1723 }
1724}
1725
1726struct NativeAgentSessionRetry {
1727 connection: NativeAgentConnection,
1728 session_id: acp::SessionId,
1729}
1730
1731impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1732 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1733 self.connection
1734 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1735 thread.update(cx, |thread, cx| thread.resume(cx))
1736 })
1737 }
1738}
1739
1740struct NativeAgentSessionSetTitle {
1741 thread: Entity<Thread>,
1742}
1743
1744impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1745 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1746 self.thread
1747 .update(cx, |thread, cx| thread.set_title(title, cx));
1748 Task::ready(Ok(()))
1749 }
1750}
1751
1752pub struct NativeThreadEnvironment {
1753 agent: WeakEntity<NativeAgent>,
1754 thread: WeakEntity<Thread>,
1755 acp_thread: WeakEntity<AcpThread>,
1756}
1757
1758impl NativeThreadEnvironment {
1759 pub(crate) fn create_subagent_thread(
1760 &self,
1761 label: String,
1762 cx: &mut App,
1763 ) -> Result<Rc<dyn SubagentHandle>> {
1764 let Some(parent_thread_entity) = self.thread.upgrade() else {
1765 anyhow::bail!("Parent thread no longer exists".to_string());
1766 };
1767 let parent_thread = parent_thread_entity.read(cx);
1768 let current_depth = parent_thread.depth();
1769 let parent_session_id = parent_thread.id().clone();
1770
1771 if current_depth >= MAX_SUBAGENT_DEPTH {
1772 return Err(anyhow!(
1773 "Maximum subagent depth ({}) reached",
1774 MAX_SUBAGENT_DEPTH
1775 ));
1776 }
1777
1778 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1779 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1780 thread.set_title(label.into(), cx);
1781 thread
1782 });
1783
1784 let session_id = subagent_thread.read(cx).id().clone();
1785
1786 let acp_thread = self
1787 .agent
1788 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1789 let project_id = agent
1790 .sessions
1791 .get(&parent_session_id)
1792 .map(|s| s.project_id)
1793 .context("parent session not found")?;
1794 Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1795 })??;
1796
1797 let depth = current_depth + 1;
1798
1799 telemetry::event!(
1800 "Subagent Started",
1801 session = parent_thread_entity.read(cx).id().to_string(),
1802 subagent_session = session_id.to_string(),
1803 depth,
1804 is_resumed = false,
1805 );
1806
1807 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1808 }
1809
1810 pub(crate) fn resume_subagent_thread(
1811 &self,
1812 session_id: acp::SessionId,
1813 cx: &mut App,
1814 ) -> Result<Rc<dyn SubagentHandle>> {
1815 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1816 let session = agent
1817 .sessions
1818 .get(&session_id)
1819 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1820 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1821 })??;
1822
1823 let depth = subagent_thread.read(cx).depth();
1824
1825 if let Some(parent_thread_entity) = self.thread.upgrade() {
1826 telemetry::event!(
1827 "Subagent Started",
1828 session = parent_thread_entity.read(cx).id().to_string(),
1829 subagent_session = session_id.to_string(),
1830 depth,
1831 is_resumed = true,
1832 );
1833 }
1834
1835 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1836 }
1837
1838 fn prompt_subagent(
1839 &self,
1840 session_id: acp::SessionId,
1841 subagent_thread: Entity<Thread>,
1842 acp_thread: Entity<acp_thread::AcpThread>,
1843 ) -> Result<Rc<dyn SubagentHandle>> {
1844 let Some(parent_thread_entity) = self.thread.upgrade() else {
1845 anyhow::bail!("Parent thread no longer exists".to_string());
1846 };
1847 Ok(Rc::new(NativeSubagentHandle::new(
1848 session_id,
1849 subagent_thread,
1850 acp_thread,
1851 parent_thread_entity,
1852 )) as _)
1853 }
1854}
1855
1856impl ThreadEnvironment for NativeThreadEnvironment {
1857 fn create_terminal(
1858 &self,
1859 command: String,
1860 cwd: Option<PathBuf>,
1861 output_byte_limit: Option<u64>,
1862 cx: &mut AsyncApp,
1863 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1864 let task = self.acp_thread.update(cx, |thread, cx| {
1865 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1866 });
1867
1868 let acp_thread = self.acp_thread.clone();
1869 cx.spawn(async move |cx| {
1870 let terminal = task?.await?;
1871
1872 let (drop_tx, drop_rx) = oneshot::channel();
1873 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1874
1875 cx.spawn(async move |cx| {
1876 drop_rx.await.ok();
1877 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1878 })
1879 .detach();
1880
1881 let handle = AcpTerminalHandle {
1882 terminal,
1883 _drop_tx: Some(drop_tx),
1884 };
1885
1886 Ok(Rc::new(handle) as _)
1887 })
1888 }
1889
1890 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1891 self.create_subagent_thread(label, cx)
1892 }
1893
1894 fn resume_subagent(
1895 &self,
1896 session_id: acp::SessionId,
1897 cx: &mut App,
1898 ) -> Result<Rc<dyn SubagentHandle>> {
1899 self.resume_subagent_thread(session_id, cx)
1900 }
1901}
1902
/// Outcome of a single subagent prompt, as computed in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn finished with `EndTurn` or any stop reason not mapped to another variant.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage went from the `Normal` ratio before the prompt to `Warning` during it.
    ContextWindowWarning,
    /// The turn failed; carries the message returned to the caller as the error.
    Error(String),
}
1910
1911pub struct NativeSubagentHandle {
1912 session_id: acp::SessionId,
1913 parent_thread: WeakEntity<Thread>,
1914 subagent_thread: Entity<Thread>,
1915 acp_thread: Entity<acp_thread::AcpThread>,
1916}
1917
1918impl NativeSubagentHandle {
1919 fn new(
1920 session_id: acp::SessionId,
1921 subagent_thread: Entity<Thread>,
1922 acp_thread: Entity<acp_thread::AcpThread>,
1923 parent_thread_entity: Entity<Thread>,
1924 ) -> Self {
1925 NativeSubagentHandle {
1926 session_id,
1927 subagent_thread,
1928 parent_thread: parent_thread_entity.downgrade(),
1929 acp_thread,
1930 }
1931 }
1932}
1933
impl SubagentHandle for NativeSubagentHandle {
    /// Returns the session id of the subagent's thread.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Returns the number of entries currently in the subagent's `AcpThread`.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves to the text of its reply.
    ///
    /// While the turn runs, the subagent is registered as running on the parent
    /// thread. Registration is best effort and is ignored if the parent is gone.
    /// The subagent's token usage is also watched. If usage was at the `Normal`
    /// ratio before this prompt and an update during the turn reports `Warning`,
    /// the subagent is cancelled and an error is returned.
    ///
    /// On normal completion, the text blocks of the subagent's last agent message
    /// are joined with blank lines. A missing, non-agent, or text-less last message
    /// is an error. After the result is computed, the subagent is unregistered from
    /// the parent.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is a named binding (not `_`), so the token-usage
            // subscription stays alive until this future finishes.
            let (task, _subscription) = cx.update(|cx| {
                // Baseline ratio; a missing usage is treated as `Normal` below.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // Fired at most once, the first time the ratio goes Normal -> Warning.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            // NOTE(review): only an exact `Warning` ratio triggers this;
                            // a jump straight to another ratio is not caught here.
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt's completion against the context-window warning.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Join the text blocks of the last agent message with blank lines.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                // Stop the subagent before reporting that it ran low on context.
                SubagentPromptResult::ContextWindowWarning => {
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // Runs on every result path above (none of them return early).
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
2054
2055pub struct AcpTerminalHandle {
2056 terminal: Entity<acp_thread::Terminal>,
2057 _drop_tx: Option<oneshot::Sender<()>>,
2058}
2059
2060impl TerminalHandle for AcpTerminalHandle {
2061 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2062 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2063 }
2064
2065 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2066 Ok(self
2067 .terminal
2068 .read_with(cx, |term, _cx| term.wait_for_exit()))
2069 }
2070
2071 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2072 Ok(self
2073 .terminal
2074 .read_with(cx, |term, cx| term.current_output(cx)))
2075 }
2076
2077 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2078 cx.update(|cx| {
2079 self.terminal.update(cx, |terminal, cx| {
2080 terminal.kill(cx);
2081 });
2082 });
2083 Ok(())
2084 }
2085
2086 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2087 Ok(self
2088 .terminal
2089 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2090 }
2091}
2092
2093#[cfg(test)]
2094mod internal_tests {
2095 use std::path::Path;
2096
2097 use super::*;
2098 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2099 use fs::FakeFs;
2100 use gpui::TestAppContext;
2101 use indoc::formatdoc;
2102 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2103 use language_model::{
2104 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2105 };
2106 use serde_json::json;
2107 use settings::SettingsStore;
2108 use util::{path, rel_path::rel_path};
2109
2110 #[gpui::test]
2111 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
2112 init_test(cx);
2113 let fs = FakeFs::new(cx.executor());
2114 fs.insert_tree(
2115 "/",
2116 json!({
2117 "a": {}
2118 }),
2119 )
2120 .await;
2121 let project = Project::test(fs.clone(), [], cx).await;
2122 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2123 let agent =
2124 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2125
2126 // Creating a session registers the project and triggers context building.
2127 let connection = NativeAgentConnection(agent.clone());
2128 let _acp_thread = cx
2129 .update(|cx| {
2130 Rc::new(connection).new_session(
2131 project.clone(),
2132 PathList::new(&[Path::new("/")]),
2133 cx,
2134 )
2135 })
2136 .await
2137 .unwrap();
2138 cx.run_until_parked();
2139
2140 let thread = agent.read_with(cx, |agent, _cx| {
2141 agent.sessions.values().next().unwrap().thread.clone()
2142 });
2143
2144 agent.read_with(cx, |agent, cx| {
2145 let project_id = project.entity_id();
2146 let state = agent.projects.get(&project_id).unwrap();
2147 assert_eq!(state.project_context.read(cx).worktrees, vec![]);
2148 assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
2149 });
2150
2151 let worktree = project
2152 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
2153 .await
2154 .unwrap();
2155 cx.run_until_parked();
2156 agent.read_with(cx, |agent, cx| {
2157 let project_id = project.entity_id();
2158 let state = agent.projects.get(&project_id).unwrap();
2159 let expected_worktrees = vec![WorktreeContext {
2160 root_name: "a".into(),
2161 abs_path: Path::new("/a").into(),
2162 rules_file: None,
2163 }];
2164 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2165 assert_eq!(
2166 thread.read(cx).project_context().read(cx).worktrees,
2167 expected_worktrees
2168 );
2169 });
2170
2171 // Creating `/a/.rules` updates the project context.
2172 fs.insert_file("/a/.rules", Vec::new()).await;
2173 cx.run_until_parked();
2174 agent.read_with(cx, |agent, cx| {
2175 let project_id = project.entity_id();
2176 let state = agent.projects.get(&project_id).unwrap();
2177 let rules_entry = worktree
2178 .read(cx)
2179 .entry_for_path(rel_path(".rules"))
2180 .unwrap();
2181 let expected_worktrees = vec![WorktreeContext {
2182 root_name: "a".into(),
2183 abs_path: Path::new("/a").into(),
2184 rules_file: Some(RulesFileContext {
2185 path_in_worktree: rel_path(".rules").into(),
2186 text: "".into(),
2187 project_entry_id: rules_entry.id.to_usize(),
2188 }),
2189 }];
2190 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2191 assert_eq!(
2192 thread.read(cx).project_context().read(cx).worktrees,
2193 expected_worktrees
2194 );
2195 });
2196 }
2197
2198 #[gpui::test]
2199 async fn test_listing_models(cx: &mut TestAppContext) {
2200 init_test(cx);
2201 let fs = FakeFs::new(cx.executor());
2202 fs.insert_tree("/", json!({ "a": {} })).await;
2203 let project = Project::test(fs.clone(), [], cx).await;
2204 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2205 let connection =
2206 NativeAgentConnection(cx.update(|cx| {
2207 NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2208 }));
2209
2210 // Create a thread/session
2211 let acp_thread = cx
2212 .update(|cx| {
2213 Rc::new(connection.clone()).new_session(
2214 project.clone(),
2215 PathList::new(&[Path::new("/a")]),
2216 cx,
2217 )
2218 })
2219 .await
2220 .unwrap();
2221
2222 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2223
2224 let models = cx
2225 .update(|cx| {
2226 connection
2227 .model_selector(&session_id)
2228 .unwrap()
2229 .list_models(cx)
2230 })
2231 .await
2232 .unwrap();
2233
2234 let acp_thread::AgentModelList::Grouped(models) = models else {
2235 panic!("Unexpected model group");
2236 };
2237 assert_eq!(
2238 models,
2239 IndexMap::from_iter([(
2240 AgentModelGroupName("Fake".into()),
2241 vec![AgentModelInfo {
2242 id: acp::ModelId::new("fake/fake"),
2243 name: "Fake".into(),
2244 description: None,
2245 icon: Some(acp_thread::AgentModelIcon::Named(
2246 ui::IconName::ZedAssistant
2247 )),
2248 is_latest: false,
2249 cost: None,
2250 }]
2251 )])
2252 );
2253 }
2254
2255 #[gpui::test]
2256 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2257 init_test(cx);
2258 let fs = FakeFs::new(cx.executor());
2259 fs.create_dir(paths::settings_file().parent().unwrap())
2260 .await
2261 .unwrap();
2262 fs.insert_file(
2263 paths::settings_file(),
2264 json!({
2265 "agent": {
2266 "default_model": {
2267 "provider": "foo",
2268 "model": "bar"
2269 }
2270 }
2271 })
2272 .to_string()
2273 .into_bytes(),
2274 )
2275 .await;
2276 let project = Project::test(fs.clone(), [], cx).await;
2277
2278 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2279
2280 // Create the agent and connection
2281 let agent =
2282 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2283 let connection = NativeAgentConnection(agent.clone());
2284
2285 // Create a thread/session
2286 let acp_thread = cx
2287 .update(|cx| {
2288 Rc::new(connection.clone()).new_session(
2289 project.clone(),
2290 PathList::new(&[Path::new("/a")]),
2291 cx,
2292 )
2293 })
2294 .await
2295 .unwrap();
2296
2297 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2298
2299 // Select a model
2300 let selector = connection.model_selector(&session_id).unwrap();
2301 let model_id = acp::ModelId::new("fake/fake");
2302 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2303 .await
2304 .unwrap();
2305
2306 // Verify the thread has the selected model
2307 agent.read_with(cx, |agent, _| {
2308 let session = agent.sessions.get(&session_id).unwrap();
2309 session.thread.read_with(cx, |thread, _| {
2310 assert_eq!(thread.model().unwrap().id().0, "fake");
2311 });
2312 });
2313
2314 cx.run_until_parked();
2315
2316 // Verify settings file was updated
2317 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2318 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2319
2320 // Check that the agent settings contain the selected model
2321 assert_eq!(
2322 settings_json["agent"]["default_model"]["model"],
2323 json!("fake")
2324 );
2325 assert_eq!(
2326 settings_json["agent"]["default_model"]["provider"],
2327 json!("fake")
2328 );
2329
2330 // Register a thinking model and select it.
2331 cx.update(|cx| {
2332 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2333 "fake-corp",
2334 "fake-thinking",
2335 "Fake Thinking",
2336 true,
2337 ));
2338 let thinking_provider = Arc::new(
2339 FakeLanguageModelProvider::new(
2340 LanguageModelProviderId::from("fake-corp".to_string()),
2341 LanguageModelProviderName::from("Fake Corp".to_string()),
2342 )
2343 .with_models(vec![thinking_model]),
2344 );
2345 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2346 registry.register_provider(thinking_provider, cx);
2347 });
2348 });
2349 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2350
2351 let selector = connection.model_selector(&session_id).unwrap();
2352 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2353 .await
2354 .unwrap();
2355 cx.run_until_parked();
2356
2357 // Verify enable_thinking was written to settings as true.
2358 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2359 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2360 assert_eq!(
2361 settings_json["agent"]["default_model"]["enable_thinking"],
2362 json!(true),
2363 "selecting a thinking model should persist enable_thinking: true to settings"
2364 );
2365 }
2366
2367 #[gpui::test]
2368 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2369 init_test(cx);
2370 let fs = FakeFs::new(cx.executor());
2371 fs.create_dir(paths::settings_file().parent().unwrap())
2372 .await
2373 .unwrap();
2374 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2375 let project = Project::test(fs.clone(), [], cx).await;
2376
2377 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2378 let agent =
2379 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2380 let connection = NativeAgentConnection(agent.clone());
2381
2382 let acp_thread = cx
2383 .update(|cx| {
2384 Rc::new(connection.clone()).new_session(
2385 project.clone(),
2386 PathList::new(&[Path::new("/a")]),
2387 cx,
2388 )
2389 })
2390 .await
2391 .unwrap();
2392 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2393
2394 // Register a second provider with a thinking model.
2395 cx.update(|cx| {
2396 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2397 "fake-corp",
2398 "fake-thinking",
2399 "Fake Thinking",
2400 true,
2401 ));
2402 let thinking_provider = Arc::new(
2403 FakeLanguageModelProvider::new(
2404 LanguageModelProviderId::from("fake-corp".to_string()),
2405 LanguageModelProviderName::from("Fake Corp".to_string()),
2406 )
2407 .with_models(vec![thinking_model]),
2408 );
2409 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2410 registry.register_provider(thinking_provider, cx);
2411 });
2412 });
2413 // Refresh the agent's model list so it picks up the new provider.
2414 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2415
2416 // Thread starts with thinking_enabled = false (the default).
2417 agent.read_with(cx, |agent, _| {
2418 let session = agent.sessions.get(&session_id).unwrap();
2419 session.thread.read_with(cx, |thread, _| {
2420 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2421 });
2422 });
2423
2424 // Select the thinking model via select_model.
2425 let selector = connection.model_selector(&session_id).unwrap();
2426 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2427 .await
2428 .unwrap();
2429
2430 // select_model should have enabled thinking based on the model's supports_thinking().
2431 agent.read_with(cx, |agent, _| {
2432 let session = agent.sessions.get(&session_id).unwrap();
2433 session.thread.read_with(cx, |thread, _| {
2434 assert!(
2435 thread.thinking_enabled(),
2436 "select_model should enable thinking when model supports it"
2437 );
2438 });
2439 });
2440
2441 // Switch back to the non-thinking model.
2442 let selector = connection.model_selector(&session_id).unwrap();
2443 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2444 .await
2445 .unwrap();
2446
2447 // select_model should have disabled thinking.
2448 agent.read_with(cx, |agent, _| {
2449 let session = agent.sessions.get(&session_id).unwrap();
2450 session.thread.read_with(cx, |thread, _| {
2451 assert!(
2452 !thread.thinking_enabled(),
2453 "select_model should disable thinking when model does not support it"
2454 );
2455 });
2456 });
2457 }
2458
    /// A thread whose `thinking_enabled` flag was turned on by selecting a
    /// thinking model should still have it on after the session is closed and
    /// the thread is reopened with `open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // A clone of the model is kept so the test can drive its completion stream below.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned and the executor parked so the request reaches the
        // fake model. Only then is the response streamed in and the stream ended.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // NOTE(review): the local handles are presumably dropped so nothing keeps
        // the closed session's entities alive — confirm.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2561
    /// A thread whose selected model has an id that differs from its display
    /// name should reopen with that same model (matched by id) after the
    /// session is closed. It should not fall back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // A clone of the model is kept so the test can drive its completion stream below.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Model selection uses the "<provider id>/<model id>" form, not the display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned and the executor parked so the request reaches the
        // fake model. Only then is the response streamed in and the stream ended.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // NOTE(review): the local handles are presumably dropped so nothing keeps
        // the closed session's entities alive — confirm.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2669
    /// End-to-end persistence round-trip for a native thread.
    ///
    /// - A thread that has no messages is not written to the thread store.
    /// - After one exchange, the thread is stored under the title produced by
    ///   the summarization model.
    /// - Reopening the thread restores its messages, draft prompt, token usage
    ///   and UI scroll position.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // NOTE(review): other tests in this module pass "/a" as the session's work
        // dir; the empty path here looks unintentional — confirm.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Both fake models are kept so the test can drive their completion streams.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a message that mentions `/a/b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream the assistant reply plus a usage update, then finish the stream.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summarization model's output is what the store's title is asserted
        // against after the reload below.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // NOTE(review): the explicit notify presumably makes observers pick up the
        // draft/scroll changes above so they get saved — confirm.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2851
2852 fn thread_entries(
2853 thread_store: &Entity<ThreadStore>,
2854 cx: &mut TestAppContext,
2855 ) -> Vec<(acp::SessionId, String)> {
2856 thread_store.read_with(cx, |store, _| {
2857 store
2858 .entries()
2859 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2860 .collect::<Vec<_>>()
2861 })
2862 }
2863
2864 fn init_test(cx: &mut TestAppContext) {
2865 env_logger::try_init().ok();
2866 cx.update(|cx| {
2867 let settings_store = SettingsStore::test(cx);
2868 cx.set_global(settings_store);
2869
2870 LanguageModelRegistry::test(cx);
2871 });
2872 }
2873}
2874
2875fn mcp_message_content_to_acp_content_block(
2876 content: context_server::types::MessageContent,
2877) -> acp::ContentBlock {
2878 match content {
2879 context_server::types::MessageContent::Text {
2880 text,
2881 annotations: _,
2882 } => text.into(),
2883 context_server::types::MessageContent::Image {
2884 data,
2885 mime_type,
2886 annotations: _,
2887 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2888 context_server::types::MessageContent::Audio {
2889 data,
2890 mime_type,
2891 annotations: _,
2892 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2893 context_server::types::MessageContent::Resource {
2894 resource,
2895 annotations: _,
2896 } => {
2897 let mut link =
2898 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2899 if let Some(mime_type) = resource.mime_type {
2900 link = link.mime_type(mime_type);
2901 }
2902 acp::ContentBlock::ResourceLink(link)
2903 }
2904 }
2905}