1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, Settings as _, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// A serializable, comparable record of a project's worktree telemetry
/// snapshots together with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp stored with the snapshot (presumably the capture time —
    /// confirm against the code that constructs this struct).
    pub timestamp: DateTime<Utc>,
}
64
/// Describes a failure to load a rules file.
///
/// `load_worktree_info_for_system_prompt` produces one of these when the task
/// that reads a worktree's rules file returns an error.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
68
/// Per-project state kept by `NativeAgent`, keyed by the project's `EntityId`.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Context entity handed to threads of this project; its contents are
    /// replaced by `maintain_project_context` whenever a refresh is requested.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context` for this project.
    _maintain_project_context: Task<Result<()>>,
    /// Registry created from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to the project, its context server store, and the
    /// context server registry.
    _subscriptions: Vec<Subscription>,
}
77
/// Holds both the internal `Thread` and the `AcpThread` for a session,
/// together with the bookkeeping `NativeAgent` needs to save and release it.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key of the owning entry in `NativeAgent::projects`.
    project_id: EntityId,
    /// The most recent save started by `save_thread`; `close_session` returns
    /// it so callers can wait for the final save.
    pending_save: Task<Result<()>>,
    /// Subscriptions and observers on `thread` created in `register_session`.
    _subscriptions: Vec<Subscription>,
    /// Number of outstanding opens of this session. `open_thread` increments
    /// it and `close_session` decrements it, removing the session at zero.
    ref_count: usize,
}
89
/// A session whose thread is still being loaded by `open_thread`.
struct PendingSession {
    /// Shared load task; every concurrent `open_thread` call for the same id
    /// awaits a clone of it.
    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
    /// Number of `open_thread` callers waiting on `task`. The value is handed
    /// to `register_session` once loading succeeds.
    ref_count: usize,
}
94
/// Cache of the language models offered by authenticated providers, rebuilt
/// by `refresh_list`.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver side of the refresh channel; cloned out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender side of the refresh channel; `refresh_list` sends on it after
    /// every rebuild.
    refresh_models_tx: watch::Sender<()>,
    /// Task started in `new` that authenticates all visible providers.
    _authenticate_all_providers_task: Task<()>,
}
104
105impl LanguageModels {
106 fn new(cx: &mut App) -> Self {
107 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
108
109 let mut this = Self {
110 models: HashMap::default(),
111 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
112 refresh_models_rx,
113 refresh_models_tx,
114 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
115 };
116 this.refresh_list(cx);
117 this
118 }
119
120 fn refresh_list(&mut self, cx: &App) {
121 let providers = LanguageModelRegistry::global(cx)
122 .read(cx)
123 .visible_providers()
124 .into_iter()
125 .filter(|provider| provider.is_authenticated(cx))
126 .collect::<Vec<_>>();
127
128 let mut language_model_list = IndexMap::default();
129 let mut recommended_models = HashSet::default();
130
131 let mut recommended = Vec::new();
132 for provider in &providers {
133 for model in provider.recommended_models(cx) {
134 recommended_models.insert((model.provider_id(), model.id()));
135 recommended.push(Self::map_language_model_to_info(&model, provider));
136 }
137 }
138 if !recommended.is_empty() {
139 language_model_list.insert(
140 acp_thread::AgentModelGroupName("Recommended".into()),
141 recommended,
142 );
143 }
144
145 let mut models = HashMap::default();
146 for provider in providers {
147 let mut provider_models = Vec::new();
148 for model in provider.provided_models(cx) {
149 let model_info = Self::map_language_model_to_info(&model, &provider);
150 let model_id = model_info.id.clone();
151 provider_models.push(model_info);
152 models.insert(model_id, model);
153 }
154 if !provider_models.is_empty() {
155 language_model_list.insert(
156 acp_thread::AgentModelGroupName(provider.name().0.clone()),
157 provider_models,
158 );
159 }
160 }
161
162 self.models = models;
163 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
164 self.refresh_models_tx.send(()).ok();
165 }
166
167 fn watch(&self) -> watch::Receiver<()> {
168 self.refresh_models_rx.clone()
169 }
170
171 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
172 self.models.get(model_id).cloned()
173 }
174
175 fn map_language_model_to_info(
176 model: &Arc<dyn LanguageModel>,
177 provider: &Arc<dyn LanguageModelProvider>,
178 ) -> acp_thread::AgentModelInfo {
179 acp_thread::AgentModelInfo {
180 id: Self::model_id(model),
181 name: model.name().0,
182 description: None,
183 icon: Some(match provider.icon() {
184 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
185 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
186 }),
187 is_latest: model.is_latest(),
188 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
189 }
190 }
191
192 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
193 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
194 }
195
196 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
197 let authenticate_all_providers = LanguageModelRegistry::global(cx)
198 .read(cx)
199 .visible_providers()
200 .iter()
201 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
202 .collect::<Vec<_>>();
203
204 cx.spawn(async move |cx| {
205 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
206 if let Err(err) = authenticate_task.await {
207 match err {
208 language_model::AuthenticateError::CredentialsNotFound => {
209 // Since we're authenticating these providers in the
210 // background for the purposes of populating the
211 // language selector, we don't care about providers
212 // where the credentials are not found.
213 }
214 language_model::AuthenticateError::ConnectionRefused => {
215 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
216 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
217 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
218 }
219 _ => {
220 // Some providers have noisy failure states that we
221 // don't want to spam the logs with every time the
222 // language model selector is initialized.
223 //
224 // Ideally these should have more clear failure modes
225 // that we know are safe to ignore here, like what we do
226 // with `CredentialsNotFound` above.
227 match provider_id.0.as_ref() {
228 "lmstudio" | "ollama" => {
229 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230 //
231 // These fail noisily, so we don't log them.
232 }
233 "copilot_chat" => {
234 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235 }
236 _ => {
237 log::error!(
238 "Failed to authenticate provider: {}: {err:#}",
239 provider_name.0
240 );
241 }
242 }
243 }
244 }
245 }
246 }
247
248 cx.update(language_models::update_environment_fallback_model);
249 })
250 }
251}
252
/// Owns the agent's sessions, per-project state, and the model cache.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Sessions whose threads are still being loaded by `open_thread`.
    pending_sessions: HashMap<acp::SessionId, PendingSession>,
    /// Store that is asked to reload after a thread has been saved.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store; when present its default prompts are loaded
    /// into each project's context as user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language model registry and, if present, the
    /// prompt store.
    _subscriptions: Vec<Subscription>,
}
268
269impl NativeAgent {
270 pub fn new(
271 thread_store: Entity<ThreadStore>,
272 templates: Arc<Templates>,
273 prompt_store: Option<Entity<PromptStore>>,
274 fs: Arc<dyn Fs>,
275 cx: &mut App,
276 ) -> Entity<NativeAgent> {
277 log::debug!("Creating new NativeAgent");
278
279 cx.new(|cx| {
280 let mut subscriptions = vec![cx.subscribe(
281 &LanguageModelRegistry::global(cx),
282 Self::handle_models_updated_event,
283 )];
284 if let Some(prompt_store) = prompt_store.as_ref() {
285 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
286 }
287
288 Self {
289 sessions: HashMap::default(),
290 pending_sessions: HashMap::default(),
291 thread_store,
292 projects: HashMap::default(),
293 templates,
294 models: LanguageModels::new(cx),
295 prompt_store,
296 fs,
297 _subscriptions: subscriptions,
298 }
299 })
300 }
301
302 fn new_session(
303 &mut self,
304 project: Entity<Project>,
305 cx: &mut Context<Self>,
306 ) -> Entity<AcpThread> {
307 let project_id = self.get_or_create_project_state(&project, cx);
308 let project_state = &self.projects[&project_id];
309
310 let registry = LanguageModelRegistry::read_global(cx);
311 let available_count = registry.available_models(cx).count();
312 log::debug!("Total available models: {}", available_count);
313
314 let default_model = registry.default_model().and_then(|default_model| {
315 self.models
316 .model_from_id(&LanguageModels::model_id(&default_model.model))
317 });
318 let thread = cx.new(|cx| {
319 Thread::new(
320 project,
321 project_state.project_context.clone(),
322 project_state.context_server_registry.clone(),
323 self.templates.clone(),
324 default_model,
325 cx,
326 )
327 });
328
329 self.register_session(thread, project_id, 1, cx)
330 }
331
    /// Wraps `thread_handle` in a new `AcpThread`, wires the two together,
    /// and stores them as a `Session` keyed by the thread's id.
    ///
    /// `ref_count` becomes the session's initial reference count. After the
    /// session is stored, the available commands of every session belonging
    /// to `project_id` are refreshed. Returns the new `AcpThread`.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        ref_count: usize,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything needed out of the thread up front so the read
        // borrow of `cx` does not overlap with creating the `AcpThread`.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Carry over the state captured from the thread above.
            acp_thread.set_draft_prompt(draft_prompt, cx);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);

        // The thread environment holds only weak handles to the agent, the
        // thread, and the ACP thread.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Save the thread every time it notifies observers.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
                ref_count,
            },
        );

        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
410
411 pub fn models(&self) -> &LanguageModels {
412 &self.models
413 }
414
415 fn get_or_create_project_state(
416 &mut self,
417 project: &Entity<Project>,
418 cx: &mut Context<Self>,
419 ) -> EntityId {
420 let project_id = project.entity_id();
421 if self.projects.contains_key(&project_id) {
422 return project_id;
423 }
424
425 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
426 self.register_project_with_initial_context(project.clone(), project_context, cx);
427 if let Some(state) = self.projects.get_mut(&project_id) {
428 state.project_context_needs_refresh.send(()).ok();
429 }
430 project_id
431 }
432
433 fn register_project_with_initial_context(
434 &mut self,
435 project: Entity<Project>,
436 project_context: Entity<ProjectContext>,
437 cx: &mut Context<Self>,
438 ) {
439 let project_id = project.entity_id();
440
441 let context_server_store = project.read(cx).context_server_store();
442 let context_server_registry =
443 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
444
445 let subscriptions = vec![
446 cx.subscribe(&project, Self::handle_project_event),
447 cx.subscribe(
448 &context_server_store,
449 Self::handle_context_server_store_updated,
450 ),
451 cx.subscribe(
452 &context_server_registry,
453 Self::handle_context_server_registry_event,
454 ),
455 ];
456
457 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
458 watch::channel(());
459
460 self.projects.insert(
461 project_id,
462 ProjectState {
463 project,
464 project_context,
465 project_context_needs_refresh: project_context_needs_refresh_tx,
466 _maintain_project_context: cx.spawn(async move |this, cx| {
467 Self::maintain_project_context(
468 this,
469 project_id,
470 project_context_needs_refresh_rx,
471 cx,
472 )
473 .await
474 }),
475 context_server_registry,
476 _subscriptions: subscriptions,
477 },
478 );
479 }
480
481 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
482 self.sessions
483 .get(session_id)
484 .and_then(|session| self.projects.get(&session.project_id))
485 }
486
487 async fn maintain_project_context(
488 this: WeakEntity<Self>,
489 project_id: EntityId,
490 mut needs_refresh: watch::Receiver<()>,
491 cx: &mut AsyncApp,
492 ) -> Result<()> {
493 while needs_refresh.changed().await.is_ok() {
494 let project_context = this
495 .update(cx, |this, cx| {
496 let state = this
497 .projects
498 .get(&project_id)
499 .context("project state not found")?;
500 anyhow::Ok(Self::build_project_context(
501 &state.project,
502 this.prompt_store.as_ref(),
503 cx,
504 ))
505 })??
506 .await;
507 this.update(cx, |this, cx| {
508 if let Some(state) = this.projects.get(&project_id) {
509 state
510 .project_context
511 .update(cx, |current_project_context, _cx| {
512 *current_project_context = project_context;
513 });
514 }
515 })?;
516 }
517
518 Ok(())
519 }
520
521 fn build_project_context(
522 project: &Entity<Project>,
523 prompt_store: Option<&Entity<PromptStore>>,
524 cx: &mut App,
525 ) -> Task<ProjectContext> {
526 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
527 let worktree_tasks = worktrees
528 .into_iter()
529 .map(|worktree| {
530 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
531 })
532 .collect::<Vec<_>>();
533 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
534 prompt_store.read_with(cx, |prompt_store, cx| {
535 let prompts = prompt_store.default_prompt_metadata();
536 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
537 let contents = prompt_store.load(prompt_metadata.id, cx);
538 async move { (contents.await, prompt_metadata) }
539 });
540 cx.background_spawn(future::join_all(load_tasks))
541 })
542 } else {
543 Task::ready(vec![])
544 };
545
546 cx.spawn(async move |_cx| {
547 let (worktrees, default_user_rules) =
548 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
549
550 let worktrees = worktrees
551 .into_iter()
552 .map(|(worktree, _rules_error)| {
553 // TODO: show error message
554 // if let Some(rules_error) = rules_error {
555 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
556 // }
557 worktree
558 })
559 .collect::<Vec<_>>();
560
561 let default_user_rules = default_user_rules
562 .into_iter()
563 .flat_map(|(contents, prompt_metadata)| match contents {
564 Ok(contents) => Some(UserRulesContext {
565 uuid: prompt_metadata.id.as_user()?,
566 title: prompt_metadata.title.map(|title| title.to_string()),
567 contents,
568 }),
569 Err(_err) => {
570 // TODO: show error message
571 // this.update(cx, |_, cx| {
572 // cx.emit(RulesLoadingError {
573 // message: format!("{err:?}").into(),
574 // });
575 // })
576 // .ok();
577 None
578 }
579 })
580 .collect::<Vec<_>>();
581
582 ProjectContext::new(worktrees, default_user_rules)
583 })
584 }
585
586 fn load_worktree_info_for_system_prompt(
587 worktree: Entity<Worktree>,
588 project: Entity<Project>,
589 cx: &mut App,
590 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
591 let tree = worktree.read(cx);
592 let root_name = tree.root_name_str().into();
593 let abs_path = tree.abs_path();
594 let scan_complete = tree.as_local().map(|local| local.scan_complete());
595
596 let mut context = WorktreeContext {
597 root_name,
598 abs_path,
599 rules_file: None,
600 };
601
602 cx.spawn(async move |cx| {
603 if let Some(scan_complete) = scan_complete {
604 scan_complete.await;
605 }
606
607 let rules_task = cx.update(|cx| Self::load_worktree_rules_file(worktree, project, cx));
608
609 let (rules_file, rules_file_error) = match rules_task {
610 Some(rules_task) => match rules_task.await {
611 Ok(rules_file) => (Some(rules_file), None),
612 Err(err) => (
613 None,
614 Some(RulesLoadingError {
615 message: format!("{err}").into(),
616 }),
617 ),
618 },
619 None => (None, None),
620 };
621 context.rules_file = rules_file;
622 (context, rules_file_error)
623 })
624 }
625
626 fn load_worktree_rules_file(
627 worktree: Entity<Worktree>,
628 project: Entity<Project>,
629 cx: &mut App,
630 ) -> Option<Task<Result<RulesFileContext>>> {
631 let worktree = worktree.read(cx);
632 let worktree_id = worktree.id();
633 let selected_rules_file = RULES_FILE_NAMES
634 .into_iter()
635 .filter_map(|name| {
636 worktree
637 .entry_for_path(RelPath::unix(name).unwrap())
638 .filter(|entry| entry.is_file())
639 .map(|entry| entry.path.clone())
640 })
641 .next();
642
643 // Note that Cline supports `.clinerules` being a directory, but that is not currently
644 // supported. This doesn't seem to occur often in GitHub repositories.
645 selected_rules_file.map(|path_in_worktree| {
646 let project_path = ProjectPath {
647 worktree_id,
648 path: path_in_worktree.clone(),
649 };
650 let buffer_task =
651 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
652 let rope_task = cx.spawn(async move |cx| {
653 let buffer = buffer_task.await?;
654 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
655 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
656 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
657 })?;
658 anyhow::Ok((project_entry_id, rope))
659 });
660 // Build a string from the rope on a background thread.
661 cx.background_spawn(async move {
662 let (project_entry_id, rope) = rope_task.await?;
663 anyhow::Ok(RulesFileContext {
664 path_in_worktree,
665 text: rope.to_string().trim().to_string(),
666 project_entry_id: project_entry_id.to_usize(),
667 })
668 })
669 })
670 }
671
672 fn handle_thread_title_updated(
673 &mut self,
674 thread: Entity<Thread>,
675 _: &TitleUpdated,
676 cx: &mut Context<Self>,
677 ) {
678 let session_id = thread.read(cx).id();
679 let Some(session) = self.sessions.get(session_id) else {
680 return;
681 };
682
683 let thread = thread.downgrade();
684 let acp_thread = session.acp_thread.downgrade();
685 cx.spawn(async move |_, cx| {
686 let title = thread.read_with(cx, |thread, _| thread.title())?;
687 if let Some(title) = title {
688 let task =
689 acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
690 task.await?;
691 }
692 anyhow::Ok(())
693 })
694 .detach_and_log_err(cx);
695 }
696
697 fn handle_thread_token_usage_updated(
698 &mut self,
699 thread: Entity<Thread>,
700 usage: &TokenUsageUpdated,
701 cx: &mut Context<Self>,
702 ) {
703 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
704 return;
705 };
706 session.acp_thread.update(cx, |acp_thread, cx| {
707 acp_thread.update_token_usage(usage.0.clone(), cx);
708 });
709 }
710
711 fn handle_project_event(
712 &mut self,
713 project: Entity<Project>,
714 event: &project::Event,
715 _cx: &mut Context<Self>,
716 ) {
717 let project_id = project.entity_id();
718 let Some(state) = self.projects.get_mut(&project_id) else {
719 return;
720 };
721 match event {
722 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
723 state.project_context_needs_refresh.send(()).ok();
724 }
725 project::Event::WorktreeUpdatedEntries(_, items) => {
726 if items.iter().any(|(path, _, _)| {
727 RULES_FILE_NAMES
728 .iter()
729 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
730 }) {
731 state.project_context_needs_refresh.send(()).ok();
732 }
733 }
734 _ => {}
735 }
736 }
737
738 fn handle_prompts_updated_event(
739 &mut self,
740 _prompt_store: Entity<PromptStore>,
741 _event: &prompt_store::PromptsUpdatedEvent,
742 _cx: &mut Context<Self>,
743 ) {
744 for state in self.projects.values_mut() {
745 state.project_context_needs_refresh.send(()).ok();
746 }
747 }
748
749 fn handle_models_updated_event(
750 &mut self,
751 _registry: Entity<LanguageModelRegistry>,
752 event: &language_model::Event,
753 cx: &mut Context<Self>,
754 ) {
755 self.models.refresh_list(cx);
756
757 let registry = LanguageModelRegistry::read_global(cx);
758 let default_model = registry.default_model().map(|m| m.model);
759 let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
760
761 for session in self.sessions.values_mut() {
762 session.thread.update(cx, |thread, cx| {
763 if thread.model().is_none()
764 && let Some(model) = default_model.clone()
765 {
766 thread.set_model(model, cx);
767 cx.notify();
768 }
769 if let Some(model) = summarization_model.clone() {
770 if thread.summarization_model().is_none()
771 || matches!(event, language_model::Event::ThreadSummaryModelChanged)
772 {
773 thread.set_summarization_model(Some(model), cx);
774 }
775 }
776 });
777 }
778 }
779
780 fn handle_context_server_store_updated(
781 &mut self,
782 store: Entity<project::context_server_store::ContextServerStore>,
783 _event: &project::context_server_store::ServerStatusChangedEvent,
784 cx: &mut Context<Self>,
785 ) {
786 let project_id = self.projects.iter().find_map(|(id, state)| {
787 if *state.context_server_registry.read(cx).server_store() == store {
788 Some(*id)
789 } else {
790 None
791 }
792 });
793 if let Some(project_id) = project_id {
794 self.update_available_commands_for_project(project_id, cx);
795 }
796 }
797
798 fn handle_context_server_registry_event(
799 &mut self,
800 registry: Entity<ContextServerRegistry>,
801 event: &ContextServerRegistryEvent,
802 cx: &mut Context<Self>,
803 ) {
804 match event {
805 ContextServerRegistryEvent::ToolsChanged => {}
806 ContextServerRegistryEvent::PromptsChanged => {
807 let project_id = self.projects.iter().find_map(|(id, state)| {
808 if state.context_server_registry == registry {
809 Some(*id)
810 } else {
811 None
812 }
813 });
814 if let Some(project_id) = project_id {
815 self.update_available_commands_for_project(project_id, cx);
816 }
817 }
818 }
819 }
820
821 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
822 let available_commands =
823 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
824 for session in self.sessions.values() {
825 if session.project_id != project_id {
826 continue;
827 }
828 session.acp_thread.update(cx, |thread, cx| {
829 thread
830 .handle_session_update(
831 acp::SessionUpdate::AvailableCommandsUpdate(
832 acp::AvailableCommandsUpdate::new(available_commands.clone()),
833 ),
834 cx,
835 )
836 .log_err();
837 });
838 }
839 }
840
841 fn build_available_commands_for_project(
842 project_state: Option<&ProjectState>,
843 cx: &App,
844 ) -> Vec<acp::AvailableCommand> {
845 let Some(state) = project_state else {
846 return vec![];
847 };
848 let registry = state.context_server_registry.read(cx);
849
850 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
851 for context_server_prompt in registry.prompts() {
852 *prompt_name_counts
853 .entry(context_server_prompt.prompt.name.as_str())
854 .or_insert(0) += 1;
855 }
856
857 registry
858 .prompts()
859 .flat_map(|context_server_prompt| {
860 let prompt = &context_server_prompt.prompt;
861
862 let should_prefix = prompt_name_counts
863 .get(prompt.name.as_str())
864 .copied()
865 .unwrap_or(0)
866 > 1;
867
868 let name = if should_prefix {
869 format!("{}.{}", context_server_prompt.server_id, prompt.name)
870 } else {
871 prompt.name.clone()
872 };
873
874 let mut command = acp::AvailableCommand::new(
875 name,
876 prompt.description.clone().unwrap_or_default(),
877 );
878
879 match prompt.arguments.as_deref() {
880 Some([arg]) => {
881 let hint = format!("<{}>", arg.name);
882
883 command = command.input(acp::AvailableCommandInput::Unstructured(
884 acp::UnstructuredCommandInput::new(hint),
885 ));
886 }
887 Some([]) | None => {}
888 Some(_) => {
889 // skip >1 argument commands since we don't support them yet
890 return None;
891 }
892 }
893
894 Some(command)
895 })
896 .collect()
897 }
898
899 pub fn load_thread(
900 &mut self,
901 id: acp::SessionId,
902 project: Entity<Project>,
903 cx: &mut Context<Self>,
904 ) -> Task<Result<Entity<Thread>>> {
905 let database_future = ThreadsDatabase::connect(cx);
906 cx.spawn(async move |this, cx| {
907 let database = database_future.await.map_err(|err| anyhow!(err))?;
908 let db_thread = database
909 .load_thread(id.clone())
910 .await?
911 .with_context(|| format!("no thread found with ID: {id:?}"))?;
912
913 this.update(cx, |this, cx| {
914 let project_id = this.get_or_create_project_state(&project, cx);
915 let project_state = this
916 .projects
917 .get(&project_id)
918 .context("project state not found")?;
919 let summarization_model = LanguageModelRegistry::read_global(cx)
920 .thread_summary_model(cx)
921 .map(|c| c.model);
922
923 Ok(cx.new(|cx| {
924 let mut thread = Thread::from_db(
925 id.clone(),
926 db_thread,
927 project_state.project.clone(),
928 project_state.project_context.clone(),
929 project_state.context_server_registry.clone(),
930 this.templates.clone(),
931 cx,
932 );
933 thread.set_summarization_model(summarization_model, cx);
934 thread
935 }))
936 })?
937 })
938 }
939
    /// Opens the session for thread `id`, loading it from the database when
    /// it is not already open, and takes one reference on it.
    ///
    /// Three cases are handled in order:
    /// 1. The session is already registered: its reference count is bumped
    ///    and its `AcpThread` is returned right away.
    /// 2. A load for this id is already in flight: the pending reference
    ///    count is bumped and the shared load task is awaited.
    /// 3. Otherwise a load is started and recorded in `pending_sessions`
    ///    with a reference count of one.
    ///
    /// # Errors
    ///
    /// The returned task fails when loading the thread fails, when the agent
    /// can no longer be updated, or when handling the thread's replayed
    /// events fails.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        if let Some(session) = self.sessions.get_mut(&id) {
            session.ref_count += 1;
            return Task::ready(Ok(session.acp_thread.clone()));
        }

        if let Some(pending) = self.pending_sessions.get_mut(&id) {
            pending.ref_count += 1;
            let task = pending.task.clone();
            return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
        }

        let task = self.load_thread(id.clone(), project.clone(), cx);
        let shared_task = cx
            .spawn({
                let id = id.clone();
                async move |this, cx| {
                    let thread = match task.await {
                        Ok(thread) => thread,
                        Err(err) => {
                            // Loading failed: forget the pending entry so a
                            // later `open_thread` call starts a fresh load.
                            this.update(cx, |this, _cx| {
                                this.pending_sessions.remove(&id);
                            })
                            .ok();
                            return Err(Arc::new(err));
                        }
                    };
                    let acp_thread = this
                        .update(cx, |this, cx| {
                            let project_id = this.get_or_create_project_state(&project, cx);
                            // Hand over the references collected while the
                            // load was pending; default to one if the pending
                            // entry is gone.
                            let ref_count = this
                                .pending_sessions
                                .remove(&id)
                                .map_or(1, |pending| pending.ref_count);
                            this.register_session(thread.clone(), project_id, ref_count, cx)
                        })
                        .map_err(Arc::new)?;
                    // Feed the events returned by `thread.replay` through the
                    // regular thread event handling for the new `AcpThread`.
                    let events = thread.update(cx, |thread, cx| thread.replay(cx));
                    cx.update(|cx| {
                        NativeAgentConnection::handle_thread_events(
                            events,
                            acp_thread.downgrade(),
                            cx,
                        )
                    })
                    .await
                    .map_err(Arc::new)?;
                    acp_thread.update(cx, |thread, cx| {
                        thread.snapshot_completed_plan(cx);
                    });
                    Ok(acp_thread)
                }
            })
            .shared();
        self.pending_sessions.insert(
            id,
            PendingSession {
                task: shared_task.clone(),
                ref_count: 1,
            },
        );

        cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
    }
1009
1010 pub fn thread_summary(
1011 &mut self,
1012 id: acp::SessionId,
1013 project: Entity<Project>,
1014 cx: &mut Context<Self>,
1015 ) -> Task<Result<SharedString>> {
1016 let thread = self.open_thread(id.clone(), project, cx);
1017 cx.spawn(async move |this, cx| {
1018 let acp_thread = thread.await?;
1019 let result = this
1020 .update(cx, |this, cx| {
1021 this.sessions
1022 .get(&id)
1023 .unwrap()
1024 .thread
1025 .update(cx, |thread, cx| thread.summary(cx))
1026 })?
1027 .await
1028 .context("Failed to generate summary")?;
1029
1030 this.update(cx, |this, cx| this.close_session(&id, cx))?
1031 .await?;
1032 drop(acp_thread);
1033 Ok(result)
1034 })
1035 }
1036
1037 fn close_session(
1038 &mut self,
1039 session_id: &acp::SessionId,
1040 cx: &mut Context<Self>,
1041 ) -> Task<Result<()>> {
1042 let Some(session) = self.sessions.get_mut(session_id) else {
1043 return Task::ready(Ok(()));
1044 };
1045
1046 session.ref_count -= 1;
1047 if session.ref_count > 0 {
1048 return Task::ready(Ok(()));
1049 }
1050
1051 let thread = session.thread.clone();
1052 self.save_thread(thread, cx);
1053 let Some(session) = self.sessions.remove(session_id) else {
1054 return Task::ready(Ok(()));
1055 };
1056 let project_id = session.project_id;
1057
1058 let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
1059 if !has_remaining {
1060 self.projects.remove(&project_id);
1061 }
1062
1063 session.pending_save
1064 }
1065
    /// Starts saving `thread` to the threads database and records the save
    /// task on the thread's session as `pending_save`.
    ///
    /// Nothing happens when the thread is empty, has no session, or its
    /// project state is missing. Database errors are logged rather than
    /// returned, so the stored task always resolves to `Ok(())`.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let project_id = session.project_id;
        let Some(state) = self.projects.get(&project_id) else {
            return;
        };

        // Absolute paths of the project's visible worktrees, stored with the
        // thread.
        let folder_paths = PathList::new(
            &state
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        // Copy the ACP thread's draft prompt onto the thread before it is
        // converted for storage.
        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            // A failed database connection is logged and the save is skipped.
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return Ok(());
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            // Ask the thread store to reload after the save attempt.
            thread_store.update(cx, |store, cx| store.reload(cx));
            Ok(())
        });
    }
1110
1111 fn send_mcp_prompt(
1112 &self,
1113 message_id: UserMessageId,
1114 session_id: acp::SessionId,
1115 prompt_name: String,
1116 server_id: ContextServerId,
1117 arguments: HashMap<String, String>,
1118 original_content: Vec<acp::ContentBlock>,
1119 cx: &mut Context<Self>,
1120 ) -> Task<Result<acp::PromptResponse>> {
1121 let Some(state) = self.session_project_state(&session_id) else {
1122 return Task::ready(Err(anyhow!("Project state not found for session")));
1123 };
1124 let server_store = state
1125 .context_server_registry
1126 .read(cx)
1127 .server_store()
1128 .clone();
1129 let path_style = state.project.read(cx).path_style(cx);
1130
1131 cx.spawn(async move |this, cx| {
1132 let prompt =
1133 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
1134
1135 let (acp_thread, thread) = this.update(cx, |this, _cx| {
1136 let session = this
1137 .sessions
1138 .get(&session_id)
1139 .context("Failed to get session")?;
1140 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
1141 })??;
1142
1143 let mut last_is_user = true;
1144
1145 thread.update(cx, |thread, cx| {
1146 thread.push_acp_user_block(
1147 message_id,
1148 original_content.into_iter().skip(1),
1149 path_style,
1150 cx,
1151 );
1152 });
1153
1154 for message in prompt.messages {
1155 let context_server::types::PromptMessage { role, content } = message;
1156 let block = mcp_message_content_to_acp_content_block(content);
1157
1158 match role {
1159 context_server::types::Role::User => {
1160 let id = acp_thread::UserMessageId::new();
1161
1162 acp_thread.update(cx, |acp_thread, cx| {
1163 acp_thread.push_user_content_block_with_indent(
1164 Some(id.clone()),
1165 block.clone(),
1166 true,
1167 cx,
1168 );
1169 });
1170
1171 thread.update(cx, |thread, cx| {
1172 thread.push_acp_user_block(id, [block], path_style, cx);
1173 });
1174 }
1175 context_server::types::Role::Assistant => {
1176 acp_thread.update(cx, |acp_thread, cx| {
1177 acp_thread.push_assistant_content_block_with_indent(
1178 block.clone(),
1179 false,
1180 true,
1181 cx,
1182 );
1183 });
1184
1185 thread.update(cx, |thread, cx| {
1186 thread.push_acp_agent_block(block, cx);
1187 });
1188 }
1189 }
1190
1191 last_is_user = role == context_server::types::Role::User;
1192 }
1193
1194 let response_stream = thread.update(cx, |thread, cx| {
1195 if last_is_user {
1196 thread.send_existing(cx)
1197 } else {
1198 // Resume if MCP prompt did not end with a user message
1199 thread.resume(cx)
1200 }
1201 })?;
1202
1203 cx.update(|cx| {
1204 NativeAgentConnection::handle_thread_events(
1205 response_stream,
1206 acp_thread.downgrade(),
1207 cx,
1208 )
1209 })
1210 .await
1211 })
1212 }
1213}
1214
/// Wrapper struct that implements the AgentConnection trait
///
/// It holds the [`NativeAgent`] entity; the connection's methods look up
/// sessions on that entity and delegate to it.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1218
1219impl NativeAgentConnection {
1220 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1221 self.0
1222 .read(cx)
1223 .sessions
1224 .get(session_id)
1225 .map(|session| session.thread.clone())
1226 }
1227
1228 pub fn load_thread(
1229 &self,
1230 id: acp::SessionId,
1231 project: Entity<Project>,
1232 cx: &mut App,
1233 ) -> Task<Result<Entity<Thread>>> {
1234 self.0
1235 .update(cx, |this, cx| this.load_thread(id, project, cx))
1236 }
1237
1238 fn run_turn(
1239 &self,
1240 session_id: acp::SessionId,
1241 cx: &mut App,
1242 f: impl 'static
1243 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1244 ) -> Task<Result<acp::PromptResponse>> {
1245 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1246 agent
1247 .sessions
1248 .get_mut(&session_id)
1249 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1250 }) else {
1251 log::error!("Session not found in run_turn: {}", session_id);
1252 return Task::ready(Err(anyhow!("Session not found")));
1253 };
1254 log::debug!("Found session for: {}", session_id);
1255
1256 let response_stream = match f(thread, cx) {
1257 Ok(stream) => stream,
1258 Err(err) => return Task::ready(Err(err)),
1259 };
1260 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1261 }
1262
1263 fn handle_thread_events(
1264 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1265 acp_thread: WeakEntity<AcpThread>,
1266 cx: &App,
1267 ) -> Task<Result<acp::PromptResponse>> {
1268 cx.spawn(async move |cx| {
1269 // Handle response stream and forward to session.acp_thread
1270 while let Some(result) = events.next().await {
1271 match result {
1272 Ok(event) => {
1273 log::trace!("Received completion event: {:?}", event);
1274
1275 match event {
1276 ThreadEvent::UserMessage(message) => {
1277 acp_thread.update(cx, |thread, cx| {
1278 for content in message.content {
1279 thread.push_user_content_block(
1280 Some(message.id.clone()),
1281 content.into(),
1282 cx,
1283 );
1284 }
1285 })?;
1286 }
1287 ThreadEvent::AgentText(text) => {
1288 acp_thread.update(cx, |thread, cx| {
1289 thread.push_assistant_content_block(text.into(), false, cx)
1290 })?;
1291 }
1292 ThreadEvent::AgentThinking(text) => {
1293 acp_thread.update(cx, |thread, cx| {
1294 thread.push_assistant_content_block(text.into(), true, cx)
1295 })?;
1296 }
1297 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1298 tool_call,
1299 options,
1300 response,
1301 context: _,
1302 }) => {
1303 let outcome_task = acp_thread.update(cx, |thread, cx| {
1304 thread.request_tool_call_authorization(tool_call, options, cx)
1305 })??;
1306 cx.background_spawn(async move {
1307 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1308 outcome_task.await
1309 {
1310 response
1311 .send(outcome)
1312 .map(|_| anyhow!("authorization receiver was dropped"))
1313 .log_err();
1314 }
1315 })
1316 .detach();
1317 }
1318 ThreadEvent::ToolCall(tool_call) => {
1319 acp_thread.update(cx, |thread, cx| {
1320 thread.upsert_tool_call(tool_call, cx)
1321 })??;
1322 }
1323 ThreadEvent::ToolCallUpdate(update) => {
1324 acp_thread.update(cx, |thread, cx| {
1325 thread.update_tool_call(update, cx)
1326 })??;
1327 }
1328 ThreadEvent::Plan(plan) => {
1329 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1330 }
1331 ThreadEvent::SubagentSpawned(session_id) => {
1332 acp_thread.update(cx, |thread, cx| {
1333 thread.subagent_spawned(session_id, cx);
1334 })?;
1335 }
1336 ThreadEvent::Retry(status) => {
1337 acp_thread.update(cx, |thread, cx| {
1338 thread.update_retry_status(status, cx)
1339 })?;
1340 }
1341 ThreadEvent::Stop(stop_reason) => {
1342 log::debug!("Assistant message complete: {:?}", stop_reason);
1343 return Ok(acp::PromptResponse::new(stop_reason));
1344 }
1345 }
1346 }
1347 Err(e) => {
1348 log::error!("Error in model response stream: {:?}", e);
1349 return Err(e);
1350 }
1351 }
1352 }
1353
1354 log::debug!("Response stream completed");
1355 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1356 })
1357 }
1358}
1359
/// A slash command parsed from the first text block of a prompt.
///
/// All fields borrow from the prompt text they were parsed out of.
struct Command<'a> {
    /// Name of the prompt being invoked (the part after the optional
    /// `server.` prefix).
    prompt_name: &'a str,
    /// Text following the command name, or `""` when there is none.
    arg_value: &'a str,
    /// Server id given as a `server.` prefix, if any.
    explicit_server_id: Option<&'a str>,
}
1365
1366impl<'a> Command<'a> {
1367 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1368 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1369 return None;
1370 };
1371 let text = text_content.text.trim();
1372 let command = text.strip_prefix('/')?;
1373 let (command, arg_value) = command
1374 .split_once(char::is_whitespace)
1375 .unwrap_or((command, ""));
1376
1377 if let Some((server_id, prompt_name)) = command.split_once('.') {
1378 Some(Self {
1379 prompt_name,
1380 arg_value,
1381 explicit_server_id: Some(server_id),
1382 })
1383 } else {
1384 Some(Self {
1385 prompt_name: command,
1386 arg_value,
1387 explicit_server_id: None,
1388 })
1389 }
1390 }
1391}
1392
/// Model selector bound to one session of a [`NativeAgentConnection`].
struct NativeAgentModelSelector {
    /// Session whose thread the model is read from and applied to.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, models and `fs`.
    connection: NativeAgentConnection,
}
1397
1398impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1399 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1400 log::debug!("NativeAgentConnection::list_models called");
1401 let list = self.connection.0.read(cx).models.model_list.clone();
1402 Task::ready(if list.is_empty() {
1403 Err(anyhow::anyhow!("No models available"))
1404 } else {
1405 Ok(list)
1406 })
1407 }
1408
1409 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1410 log::debug!(
1411 "Setting model for session {}: {}",
1412 self.session_id,
1413 model_id
1414 );
1415 let Some(thread) = self
1416 .connection
1417 .0
1418 .read(cx)
1419 .sessions
1420 .get(&self.session_id)
1421 .map(|session| session.thread.clone())
1422 else {
1423 return Task::ready(Err(anyhow!("Session not found")));
1424 };
1425
1426 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1427 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1428 };
1429
1430 let favorite = agent_settings::AgentSettings::get_global(cx)
1431 .favorite_models
1432 .iter()
1433 .find(|favorite| {
1434 favorite.provider.0 == model.provider_id().0.as_ref()
1435 && favorite.model == model.id().0.as_ref()
1436 })
1437 .cloned();
1438
1439 let LanguageModelSelection {
1440 enable_thinking,
1441 effort,
1442 speed,
1443 ..
1444 } = agent_settings::language_model_to_selection(&model, favorite.as_ref());
1445
1446 thread.update(cx, |thread, cx| {
1447 thread.set_model(model.clone(), cx);
1448 thread.set_thinking_effort(effort.clone(), cx);
1449 thread.set_thinking_enabled(enable_thinking, cx);
1450 if let Some(speed) = speed {
1451 thread.set_speed(speed, cx);
1452 }
1453 });
1454
1455 update_settings_file(
1456 self.connection.0.read(cx).fs.clone(),
1457 cx,
1458 move |settings, cx| {
1459 let provider = model.provider_id().0.to_string();
1460 let model = model.id().0.to_string();
1461 let enable_thinking = thread.read(cx).thinking_enabled();
1462 let speed = thread.read(cx).speed();
1463 settings
1464 .agent
1465 .get_or_insert_default()
1466 .set_model(LanguageModelSelection {
1467 provider: provider.into(),
1468 model,
1469 enable_thinking,
1470 effort,
1471 speed,
1472 });
1473 },
1474 );
1475
1476 Task::ready(Ok(()))
1477 }
1478
1479 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1480 let Some(thread) = self
1481 .connection
1482 .0
1483 .read(cx)
1484 .sessions
1485 .get(&self.session_id)
1486 .map(|session| session.thread.clone())
1487 else {
1488 return Task::ready(Err(anyhow!("Session not found")));
1489 };
1490 let Some(model) = thread.read(cx).model() else {
1491 return Task::ready(Err(anyhow!("Model not found")));
1492 };
1493 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1494 else {
1495 return Task::ready(Err(anyhow!("Provider not found")));
1496 };
1497 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1498 model, &provider,
1499 )))
1500 }
1501
1502 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1503 Some(self.connection.0.read(cx).models.watch())
1504 }
1505
1506 fn should_render_footer(&self) -> bool {
1507 true
1508 }
1509}
1510
/// Agent id of the native agent, created on first use from the name
/// `"Zed Agent"`.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1512
1513impl acp_thread::AgentConnection for NativeAgentConnection {
1514 fn agent_id(&self) -> AgentId {
1515 ZED_AGENT_ID.clone()
1516 }
1517
1518 fn telemetry_id(&self) -> SharedString {
1519 "zed".into()
1520 }
1521
1522 fn new_session(
1523 self: Rc<Self>,
1524 project: Entity<Project>,
1525 work_dirs: PathList,
1526 cx: &mut App,
1527 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1528 log::debug!("Creating new thread for project at: {work_dirs:?}");
1529 Task::ready(Ok(self
1530 .0
1531 .update(cx, |agent, cx| agent.new_session(project, cx))))
1532 }
1533
1534 fn supports_load_session(&self) -> bool {
1535 true
1536 }
1537
1538 fn load_session(
1539 self: Rc<Self>,
1540 session_id: acp::SessionId,
1541 project: Entity<Project>,
1542 _work_dirs: PathList,
1543 _title: Option<SharedString>,
1544 cx: &mut App,
1545 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1546 self.0
1547 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1548 }
1549
1550 fn supports_close_session(&self) -> bool {
1551 true
1552 }
1553
1554 fn close_session(
1555 self: Rc<Self>,
1556 session_id: &acp::SessionId,
1557 cx: &mut App,
1558 ) -> Task<Result<()>> {
1559 self.0
1560 .update(cx, |agent, cx| agent.close_session(session_id, cx))
1561 }
1562
1563 fn auth_methods(&self) -> &[acp::AuthMethod] {
1564 &[] // No auth for in-process
1565 }
1566
1567 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1568 Task::ready(Ok(()))
1569 }
1570
1571 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1572 Some(Rc::new(NativeAgentModelSelector {
1573 session_id: session_id.clone(),
1574 connection: self.clone(),
1575 }) as Rc<dyn AgentModelSelector>)
1576 }
1577
1578 fn prompt(
1579 &self,
1580 id: acp_thread::UserMessageId,
1581 params: acp::PromptRequest,
1582 cx: &mut App,
1583 ) -> Task<Result<acp::PromptResponse>> {
1584 let session_id = params.session_id.clone();
1585 log::info!("Received prompt request for session: {}", session_id);
1586 log::debug!("Prompt blocks count: {}", params.prompt.len());
1587
1588 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1589 log::error!("Session not found in prompt: {}", session_id);
1590 if self.0.read(cx).sessions.contains_key(&session_id) {
1591 log::error!(
1592 "Session found in sessions map, but not in project state: {}",
1593 session_id
1594 );
1595 }
1596 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1597 };
1598
1599 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1600 let registry = project_state.context_server_registry.read(cx);
1601
1602 let explicit_server_id = parsed_command
1603 .explicit_server_id
1604 .map(|server_id| ContextServerId(server_id.into()));
1605
1606 if let Some(prompt) =
1607 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1608 {
1609 let arguments = if !parsed_command.arg_value.is_empty()
1610 && let Some(arg_name) = prompt
1611 .prompt
1612 .arguments
1613 .as_ref()
1614 .and_then(|args| args.first())
1615 .map(|arg| arg.name.clone())
1616 {
1617 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1618 } else {
1619 Default::default()
1620 };
1621
1622 let prompt_name = prompt.prompt.name.clone();
1623 let server_id = prompt.server_id.clone();
1624
1625 return self.0.update(cx, |agent, cx| {
1626 agent.send_mcp_prompt(
1627 id,
1628 session_id.clone(),
1629 prompt_name,
1630 server_id,
1631 arguments,
1632 params.prompt,
1633 cx,
1634 )
1635 });
1636 }
1637 };
1638
1639 let path_style = project_state.project.read(cx).path_style(cx);
1640
1641 self.run_turn(session_id, cx, move |thread, cx| {
1642 let content: Vec<UserMessageContent> = params
1643 .prompt
1644 .into_iter()
1645 .map(|block| UserMessageContent::from_content_block(block, path_style))
1646 .collect::<Vec<_>>();
1647 log::debug!("Converted prompt to message: {} chars", content.len());
1648 log::debug!("Message id: {:?}", id);
1649 log::debug!("Message content: {:?}", content);
1650
1651 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1652 })
1653 }
1654
1655 fn retry(
1656 &self,
1657 session_id: &acp::SessionId,
1658 _cx: &App,
1659 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1660 Some(Rc::new(NativeAgentSessionRetry {
1661 connection: self.clone(),
1662 session_id: session_id.clone(),
1663 }) as _)
1664 }
1665
1666 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1667 log::info!("Cancelling on session: {}", session_id);
1668 self.0.update(cx, |agent, cx| {
1669 if let Some(session) = agent.sessions.get(session_id) {
1670 session
1671 .thread
1672 .update(cx, |thread, cx| thread.cancel(cx))
1673 .detach();
1674 }
1675 });
1676 }
1677
1678 fn truncate(
1679 &self,
1680 session_id: &acp::SessionId,
1681 cx: &App,
1682 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1683 self.0.read_with(cx, |agent, _cx| {
1684 agent.sessions.get(session_id).map(|session| {
1685 Rc::new(NativeAgentSessionTruncate {
1686 thread: session.thread.clone(),
1687 acp_thread: session.acp_thread.downgrade(),
1688 }) as _
1689 })
1690 })
1691 }
1692
1693 fn set_title(
1694 &self,
1695 session_id: &acp::SessionId,
1696 cx: &App,
1697 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1698 self.0.read_with(cx, |agent, _cx| {
1699 agent
1700 .sessions
1701 .get(session_id)
1702 .filter(|s| !s.thread.read(cx).is_subagent())
1703 .map(|session| {
1704 Rc::new(NativeAgentSessionSetTitle {
1705 thread: session.thread.clone(),
1706 }) as _
1707 })
1708 })
1709 }
1710
1711 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1712 let thread_store = self.0.read(cx).thread_store.clone();
1713 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1714 }
1715
1716 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1717 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1718 }
1719
1720 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1721 self
1722 }
1723}
1724
1725impl acp_thread::AgentTelemetry for NativeAgentConnection {
1726 fn thread_data(
1727 &self,
1728 session_id: &acp::SessionId,
1729 cx: &mut App,
1730 ) -> Task<Result<serde_json::Value>> {
1731 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1732 return Task::ready(Err(anyhow!("Session not found")));
1733 };
1734
1735 let task = session.thread.read(cx).to_db(cx);
1736 cx.background_spawn(async move {
1737 serde_json::to_value(task.await).context("Failed to serialize thread")
1738 })
1739 }
1740}
1741
/// Session list backed by a [`ThreadStore`] entity.
pub struct NativeAgentSessionList {
    /// Store whose entries are listed and deleted through this type.
    thread_store: Entity<ThreadStore>,
    /// Sending half of the update channel; used by `notify_refresh`.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiving half of the update channel; clones are handed out by `watch`.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Subscription returned by observing the store; held for as long as the
    /// list exists.
    _subscription: Subscription,
}
1748
1749impl NativeAgentSessionList {
1750 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1751 let (tx, rx) = smol::channel::unbounded();
1752 let this_tx = tx.clone();
1753 let subscription = cx.observe(&thread_store, move |_, _| {
1754 this_tx
1755 .try_send(acp_thread::SessionListUpdate::Refresh)
1756 .ok();
1757 });
1758 Self {
1759 thread_store,
1760 updates_tx: tx,
1761 updates_rx: rx,
1762 _subscription: subscription,
1763 }
1764 }
1765
1766 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1767 &self.thread_store
1768 }
1769}
1770
1771impl AgentSessionList for NativeAgentSessionList {
1772 fn list_sessions(
1773 &self,
1774 _request: AgentSessionListRequest,
1775 cx: &mut App,
1776 ) -> Task<Result<AgentSessionListResponse>> {
1777 let sessions = self
1778 .thread_store
1779 .read(cx)
1780 .entries()
1781 .map(|entry| AgentSessionInfo::from(&entry))
1782 .collect();
1783 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1784 }
1785
1786 fn supports_delete(&self) -> bool {
1787 true
1788 }
1789
1790 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1791 self.thread_store
1792 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1793 }
1794
1795 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1796 self.thread_store
1797 .update(cx, |store, cx| store.delete_threads(cx))
1798 }
1799
1800 fn watch(
1801 &self,
1802 _cx: &mut App,
1803 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1804 Some(self.updates_rx.clone())
1805 }
1806
1807 fn notify_refresh(&self) {
1808 self.updates_tx
1809 .try_send(acp_thread::SessionListUpdate::Refresh)
1810 .ok();
1811 }
1812
1813 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1814 self
1815 }
1816}
1817
/// Truncate handle for one session.
struct NativeAgentSessionTruncate {
    /// Agent thread that gets truncated.
    thread: Entity<Thread>,
    /// Weak handle to the ACP thread whose token usage is refreshed afterwards.
    acp_thread: WeakEntity<AcpThread>,
}
1822
1823impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1824 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1825 match self.thread.update(cx, |thread, cx| {
1826 thread.truncate(message_id.clone(), cx)?;
1827 Ok(thread.latest_token_usage())
1828 }) {
1829 Ok(usage) => {
1830 self.acp_thread
1831 .update(cx, |thread, cx| {
1832 thread.update_token_usage(usage, cx);
1833 })
1834 .ok();
1835 Task::ready(Ok(()))
1836 }
1837 Err(error) => Task::ready(Err(error)),
1838 }
1839 }
1840}
1841
/// Retry handle for one session.
struct NativeAgentSessionRetry {
    /// Connection used to run the retried turn.
    connection: NativeAgentConnection,
    /// Session whose thread is resumed.
    session_id: acp::SessionId,
}
1846
1847impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1848 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1849 self.connection
1850 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1851 thread.update(cx, |thread, cx| thread.resume(cx))
1852 })
1853 }
1854}
1855
/// Set-title handle for one session.
struct NativeAgentSessionSetTitle {
    /// Agent thread whose title is set.
    thread: Entity<Thread>,
}
1859
1860impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1861 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1862 self.thread
1863 .update(cx, |thread, cx| thread.set_title(title, cx));
1864 Task::ready(Ok(()))
1865 }
1866}
1867
/// Environment that lets a thread create terminals and subagents.
///
/// All handles are weak, so the environment does not keep the agent or the
/// threads alive.
pub struct NativeThreadEnvironment {
    /// Agent on which subagent sessions are registered and looked up.
    agent: WeakEntity<NativeAgent>,
    /// Thread that acts as the parent of any subagent created here.
    thread: WeakEntity<Thread>,
    /// ACP thread used to create and release terminals.
    acp_thread: WeakEntity<AcpThread>,
}
1873
1874impl NativeThreadEnvironment {
1875 pub(crate) fn create_subagent_thread(
1876 &self,
1877 label: String,
1878 cx: &mut App,
1879 ) -> Result<Rc<dyn SubagentHandle>> {
1880 let Some(parent_thread_entity) = self.thread.upgrade() else {
1881 anyhow::bail!("Parent thread no longer exists".to_string());
1882 };
1883 let parent_thread = parent_thread_entity.read(cx);
1884 let current_depth = parent_thread.depth();
1885 let parent_session_id = parent_thread.id().clone();
1886
1887 if current_depth >= MAX_SUBAGENT_DEPTH {
1888 return Err(anyhow!(
1889 "Maximum subagent depth ({}) reached",
1890 MAX_SUBAGENT_DEPTH
1891 ));
1892 }
1893
1894 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1895 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1896 thread.set_title(label.into(), cx);
1897 thread
1898 });
1899
1900 let session_id = subagent_thread.read(cx).id().clone();
1901
1902 let acp_thread = self
1903 .agent
1904 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1905 let project_id = agent
1906 .sessions
1907 .get(&parent_session_id)
1908 .map(|s| s.project_id)
1909 .context("parent session not found")?;
1910 Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
1911 })??;
1912
1913 let depth = current_depth + 1;
1914
1915 telemetry::event!(
1916 "Subagent Started",
1917 session = parent_thread_entity.read(cx).id().to_string(),
1918 subagent_session = session_id.to_string(),
1919 depth,
1920 is_resumed = false,
1921 );
1922
1923 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1924 }
1925
1926 pub(crate) fn resume_subagent_thread(
1927 &self,
1928 session_id: acp::SessionId,
1929 cx: &mut App,
1930 ) -> Result<Rc<dyn SubagentHandle>> {
1931 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1932 let session = agent
1933 .sessions
1934 .get(&session_id)
1935 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1936 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1937 })??;
1938
1939 let depth = subagent_thread.read(cx).depth();
1940
1941 if let Some(parent_thread_entity) = self.thread.upgrade() {
1942 telemetry::event!(
1943 "Subagent Started",
1944 session = parent_thread_entity.read(cx).id().to_string(),
1945 subagent_session = session_id.to_string(),
1946 depth,
1947 is_resumed = true,
1948 );
1949 }
1950
1951 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1952 }
1953
1954 fn prompt_subagent(
1955 &self,
1956 session_id: acp::SessionId,
1957 subagent_thread: Entity<Thread>,
1958 acp_thread: Entity<acp_thread::AcpThread>,
1959 ) -> Result<Rc<dyn SubagentHandle>> {
1960 let Some(parent_thread_entity) = self.thread.upgrade() else {
1961 anyhow::bail!("Parent thread no longer exists".to_string());
1962 };
1963 Ok(Rc::new(NativeSubagentHandle::new(
1964 session_id,
1965 subagent_thread,
1966 acp_thread,
1967 parent_thread_entity,
1968 )) as _)
1969 }
1970}
1971
1972impl ThreadEnvironment for NativeThreadEnvironment {
1973 fn create_terminal(
1974 &self,
1975 command: String,
1976 cwd: Option<PathBuf>,
1977 output_byte_limit: Option<u64>,
1978 cx: &mut AsyncApp,
1979 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1980 let task = self.acp_thread.update(cx, |thread, cx| {
1981 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1982 });
1983
1984 let acp_thread = self.acp_thread.clone();
1985 cx.spawn(async move |cx| {
1986 let terminal = task?.await?;
1987
1988 let (drop_tx, drop_rx) = oneshot::channel();
1989 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1990
1991 cx.spawn(async move |cx| {
1992 drop_rx.await.ok();
1993 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1994 })
1995 .detach();
1996
1997 let handle = AcpTerminalHandle {
1998 terminal,
1999 _drop_tx: Some(drop_tx),
2000 };
2001
2002 Ok(Rc::new(handle) as _)
2003 })
2004 }
2005
2006 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
2007 self.create_subagent_thread(label, cx)
2008 }
2009
2010 fn resume_subagent(
2011 &self,
2012 session_id: acp::SessionId,
2013 cx: &mut App,
2014 ) -> Result<Rc<dyn SubagentHandle>> {
2015 self.resume_subagent_thread(session_id, cx)
2016 }
2017}
2018
/// Outcome of waiting on a prompt sent to a subagent.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The prompt finished with a stop reason that is not treated as an error.
    Completed,
    /// The prompt stopped with `StopReason::Cancelled`.
    Cancelled,
    /// The token-limit signal fired before the prompt finished.
    ContextWindowWarning,
    /// The prompt failed; the string is the message reported to the caller.
    Error(String),
}
2026
/// Handle used to message a subagent session.
pub struct NativeSubagentHandle {
    /// Session id of the subagent.
    session_id: acp::SessionId,
    /// Weak handle to the parent thread, on which the subagent is registered
    /// as running while a prompt is in flight.
    parent_thread: WeakEntity<Thread>,
    /// The subagent's agent thread.
    subagent_thread: Entity<Thread>,
    /// The subagent's ACP thread, through which messages are sent.
    acp_thread: Entity<acp_thread::AcpThread>,
}
2033
2034impl NativeSubagentHandle {
2035 fn new(
2036 session_id: acp::SessionId,
2037 subagent_thread: Entity<Thread>,
2038 acp_thread: Entity<acp_thread::AcpThread>,
2039 parent_thread_entity: Entity<Thread>,
2040 ) -> Self {
2041 NativeSubagentHandle {
2042 session_id,
2043 subagent_thread,
2044 parent_thread: parent_thread_entity.downgrade(),
2045 acp_thread,
2046 }
2047 }
2048}
2049
impl SubagentHandle for NativeSubagentHandle {
    /// Returns the subagent's session id.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Returns the number of entries currently in the subagent's ACP thread.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves with its reply.
    ///
    /// On completion the reply is the text parts of the subagent thread's last
    /// message joined with blank lines; a last message that is not an agent
    /// message, or that has no text, yields an error. Cancellation, error stop
    /// reasons and a missing response are reported as errors. If the subagent's
    /// token usage ratio moves from `Normal` to `Warning` while the prompt is
    /// running, the subagent is cancelled and an error explaining this is
    /// returned.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is a named binding, so it stays alive until the end
            // of this async block.
            let (task, _subscription) = cx.update(|cx| {
                // Usage ratio before the prompt, used as the baseline below.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                // Best effort: a parent that is already gone is ignored.
                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // The sender is wrapped in an `Option` so the signal is sent at
                // most once.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            // Only a Normal -> Warning transition relative to the
                            // pre-prompt ratio triggers the signal.
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the token-limit signal.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            // Only text content is returned; other content kinds
                            // are skipped.
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent before reporting the warning.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // Runs for every outcome; a missing parent is ignored.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
2170
/// Terminal handle backed by an ACP thread terminal entity.
pub struct AcpTerminalHandle {
    /// The underlying terminal.
    terminal: Entity<acp_thread::Terminal>,
    /// Oneshot sender that is never sent on by this type; it is only held.
    /// NOTE(review): looks like dropping it is what signals the release task
    /// created in `create_terminal` — confirm.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2175
2176impl TerminalHandle for AcpTerminalHandle {
2177 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2178 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2179 }
2180
2181 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2182 Ok(self
2183 .terminal
2184 .read_with(cx, |term, _cx| term.wait_for_exit()))
2185 }
2186
2187 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2188 Ok(self
2189 .terminal
2190 .read_with(cx, |term, cx| term.current_output(cx)))
2191 }
2192
2193 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2194 cx.update(|cx| {
2195 self.terminal.update(cx, |terminal, cx| {
2196 terminal.kill(cx);
2197 });
2198 });
2199 Ok(())
2200 }
2201
2202 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2203 Ok(self
2204 .terminal
2205 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2206 }
2207}
2208
2209#[cfg(test)]
2210mod internal_tests {
2211 use std::path::Path;
2212
2213 use super::*;
2214 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2215 use fs::FakeFs;
2216 use gpui::TestAppContext;
2217 use indoc::formatdoc;
2218 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2219 use language_model::{
2220 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2221 };
2222 use serde_json::json;
2223 use settings::SettingsStore;
2224 use util::{path, rel_path::rel_path};
2225
2226 #[gpui::test]
2227 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
2228 init_test(cx);
2229 let fs = FakeFs::new(cx.executor());
2230 fs.insert_tree(
2231 "/",
2232 json!({
2233 "a": {}
2234 }),
2235 )
2236 .await;
2237 let project = Project::test(fs.clone(), [], cx).await;
2238 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2239 let agent =
2240 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2241
2242 // Creating a session registers the project and triggers context building.
2243 let connection = NativeAgentConnection(agent.clone());
2244 let _acp_thread = cx
2245 .update(|cx| {
2246 Rc::new(connection).new_session(
2247 project.clone(),
2248 PathList::new(&[Path::new("/")]),
2249 cx,
2250 )
2251 })
2252 .await
2253 .unwrap();
2254 cx.run_until_parked();
2255
2256 let thread = agent.read_with(cx, |agent, _cx| {
2257 agent.sessions.values().next().unwrap().thread.clone()
2258 });
2259
2260 agent.read_with(cx, |agent, cx| {
2261 let project_id = project.entity_id();
2262 let state = agent.projects.get(&project_id).unwrap();
2263 assert_eq!(state.project_context.read(cx).worktrees, vec![]);
2264 assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
2265 });
2266
2267 let worktree = project
2268 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
2269 .await
2270 .unwrap();
2271 cx.run_until_parked();
2272 agent.read_with(cx, |agent, cx| {
2273 let project_id = project.entity_id();
2274 let state = agent.projects.get(&project_id).unwrap();
2275 let expected_worktrees = vec![WorktreeContext {
2276 root_name: "a".into(),
2277 abs_path: Path::new("/a").into(),
2278 rules_file: None,
2279 }];
2280 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2281 assert_eq!(
2282 thread.read(cx).project_context().read(cx).worktrees,
2283 expected_worktrees
2284 );
2285 });
2286
2287 // Creating `/a/.rules` updates the project context.
2288 fs.insert_file("/a/.rules", Vec::new()).await;
2289 cx.run_until_parked();
2290 agent.read_with(cx, |agent, cx| {
2291 let project_id = project.entity_id();
2292 let state = agent.projects.get(&project_id).unwrap();
2293 let rules_entry = worktree
2294 .read(cx)
2295 .entry_for_path(rel_path(".rules"))
2296 .unwrap();
2297 let expected_worktrees = vec![WorktreeContext {
2298 root_name: "a".into(),
2299 abs_path: Path::new("/a").into(),
2300 rules_file: Some(RulesFileContext {
2301 path_in_worktree: rel_path(".rules").into(),
2302 text: "".into(),
2303 project_entry_id: rules_entry.id.to_usize(),
2304 }),
2305 }];
2306 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2307 assert_eq!(
2308 thread.read(cx).project_context().read(cx).worktrees,
2309 expected_worktrees
2310 );
2311 });
2312 }
2313
2314 #[gpui::test]
2315 async fn test_listing_models(cx: &mut TestAppContext) {
2316 init_test(cx);
2317 let fs = FakeFs::new(cx.executor());
2318 fs.insert_tree("/", json!({ "a": {} })).await;
2319 let project = Project::test(fs.clone(), [], cx).await;
2320 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2321 let connection =
2322 NativeAgentConnection(cx.update(|cx| {
2323 NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2324 }));
2325
2326 // Create a thread/session
2327 let acp_thread = cx
2328 .update(|cx| {
2329 Rc::new(connection.clone()).new_session(
2330 project.clone(),
2331 PathList::new(&[Path::new("/a")]),
2332 cx,
2333 )
2334 })
2335 .await
2336 .unwrap();
2337
2338 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2339
2340 let models = cx
2341 .update(|cx| {
2342 connection
2343 .model_selector(&session_id)
2344 .unwrap()
2345 .list_models(cx)
2346 })
2347 .await
2348 .unwrap();
2349
2350 let acp_thread::AgentModelList::Grouped(models) = models else {
2351 panic!("Unexpected model group");
2352 };
2353 assert_eq!(
2354 models,
2355 IndexMap::from_iter([(
2356 AgentModelGroupName("Fake".into()),
2357 vec![AgentModelInfo {
2358 id: acp::ModelId::new("fake/fake"),
2359 name: "Fake".into(),
2360 description: None,
2361 icon: Some(acp_thread::AgentModelIcon::Named(
2362 ui::IconName::ZedAssistant
2363 )),
2364 is_latest: false,
2365 cost: None,
2366 }]
2367 )])
2368 );
2369 }
2370
2371 #[gpui::test]
2372 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2373 init_test(cx);
2374 let fs = FakeFs::new(cx.executor());
2375 fs.create_dir(paths::settings_file().parent().unwrap())
2376 .await
2377 .unwrap();
2378 fs.insert_file(
2379 paths::settings_file(),
2380 json!({
2381 "agent": {
2382 "default_model": {
2383 "provider": "foo",
2384 "model": "bar"
2385 }
2386 }
2387 })
2388 .to_string()
2389 .into_bytes(),
2390 )
2391 .await;
2392 let project = Project::test(fs.clone(), [], cx).await;
2393
2394 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2395
2396 // Create the agent and connection
2397 let agent =
2398 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2399 let connection = NativeAgentConnection(agent.clone());
2400
2401 // Create a thread/session
2402 let acp_thread = cx
2403 .update(|cx| {
2404 Rc::new(connection.clone()).new_session(
2405 project.clone(),
2406 PathList::new(&[Path::new("/a")]),
2407 cx,
2408 )
2409 })
2410 .await
2411 .unwrap();
2412
2413 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2414
2415 // Select a model
2416 let selector = connection.model_selector(&session_id).unwrap();
2417 let model_id = acp::ModelId::new("fake/fake");
2418 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2419 .await
2420 .unwrap();
2421
2422 // Verify the thread has the selected model
2423 agent.read_with(cx, |agent, _| {
2424 let session = agent.sessions.get(&session_id).unwrap();
2425 session.thread.read_with(cx, |thread, _| {
2426 assert_eq!(thread.model().unwrap().id().0, "fake");
2427 });
2428 });
2429
2430 cx.run_until_parked();
2431
2432 // Verify settings file was updated
2433 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2434 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2435
2436 // Check that the agent settings contain the selected model
2437 assert_eq!(
2438 settings_json["agent"]["default_model"]["model"],
2439 json!("fake")
2440 );
2441 assert_eq!(
2442 settings_json["agent"]["default_model"]["provider"],
2443 json!("fake")
2444 );
2445
2446 // Register a thinking model and select it.
2447 cx.update(|cx| {
2448 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2449 "fake-corp",
2450 "fake-thinking",
2451 "Fake Thinking",
2452 true,
2453 ));
2454 let thinking_provider = Arc::new(
2455 FakeLanguageModelProvider::new(
2456 LanguageModelProviderId::from("fake-corp".to_string()),
2457 LanguageModelProviderName::from("Fake Corp".to_string()),
2458 )
2459 .with_models(vec![thinking_model]),
2460 );
2461 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2462 registry.register_provider(thinking_provider, cx);
2463 });
2464 });
2465 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2466
2467 let selector = connection.model_selector(&session_id).unwrap();
2468 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2469 .await
2470 .unwrap();
2471 cx.run_until_parked();
2472
2473 // Verify enable_thinking was written to settings as true.
2474 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2475 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2476 assert_eq!(
2477 settings_json["agent"]["default_model"]["enable_thinking"],
2478 json!(true),
2479 "selecting a thinking model should persist enable_thinking: true to settings"
2480 );
2481 }
2482
2483 #[gpui::test]
2484 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2485 init_test(cx);
2486 let fs = FakeFs::new(cx.executor());
2487 fs.create_dir(paths::settings_file().parent().unwrap())
2488 .await
2489 .unwrap();
2490 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2491 let project = Project::test(fs.clone(), [], cx).await;
2492
2493 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2494 let agent =
2495 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2496 let connection = NativeAgentConnection(agent.clone());
2497
2498 let acp_thread = cx
2499 .update(|cx| {
2500 Rc::new(connection.clone()).new_session(
2501 project.clone(),
2502 PathList::new(&[Path::new("/a")]),
2503 cx,
2504 )
2505 })
2506 .await
2507 .unwrap();
2508 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2509
2510 // Register a second provider with a thinking model.
2511 cx.update(|cx| {
2512 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2513 "fake-corp",
2514 "fake-thinking",
2515 "Fake Thinking",
2516 true,
2517 ));
2518 let thinking_provider = Arc::new(
2519 FakeLanguageModelProvider::new(
2520 LanguageModelProviderId::from("fake-corp".to_string()),
2521 LanguageModelProviderName::from("Fake Corp".to_string()),
2522 )
2523 .with_models(vec![thinking_model]),
2524 );
2525 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2526 registry.register_provider(thinking_provider, cx);
2527 });
2528 });
2529 // Refresh the agent's model list so it picks up the new provider.
2530 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2531
2532 // Thread starts with thinking_enabled = false (the default).
2533 agent.read_with(cx, |agent, _| {
2534 let session = agent.sessions.get(&session_id).unwrap();
2535 session.thread.read_with(cx, |thread, _| {
2536 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2537 });
2538 });
2539
2540 // Select the thinking model via select_model.
2541 let selector = connection.model_selector(&session_id).unwrap();
2542 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2543 .await
2544 .unwrap();
2545
2546 // select_model should have enabled thinking based on the model's supports_thinking().
2547 agent.read_with(cx, |agent, _| {
2548 let session = agent.sessions.get(&session_id).unwrap();
2549 session.thread.read_with(cx, |thread, _| {
2550 assert!(
2551 thread.thinking_enabled(),
2552 "select_model should enable thinking when model supports it"
2553 );
2554 });
2555 });
2556
2557 // Switch back to the non-thinking model.
2558 let selector = connection.model_selector(&session_id).unwrap();
2559 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2560 .await
2561 .unwrap();
2562
2563 // select_model should have disabled thinking.
2564 agent.read_with(cx, |agent, _| {
2565 let session = agent.sessions.get(&session_id).unwrap();
2566 session.thread.read_with(cx, |thread, _| {
2567 assert!(
2568 !thread.thinking_enabled(),
2569 "select_model should disable thinking when model does not support it"
2570 );
2571 });
2572 });
2573 }
2574
2575 #[gpui::test]
2576 async fn test_summarization_model_survives_transient_registry_clearing(
2577 cx: &mut TestAppContext,
2578 ) {
2579 init_test(cx);
2580 let fs = FakeFs::new(cx.executor());
2581 fs.insert_tree("/", json!({ "a": {} })).await;
2582 let project = Project::test(fs.clone(), [], cx).await;
2583
2584 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2585 let agent =
2586 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2587 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2588
2589 let acp_thread = cx
2590 .update(|cx| {
2591 connection.clone().new_session(
2592 project.clone(),
2593 PathList::new(&[Path::new("/a")]),
2594 cx,
2595 )
2596 })
2597 .await
2598 .unwrap();
2599 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2600
2601 let thread = agent.read_with(cx, |agent, _| {
2602 agent.sessions.get(&session_id).unwrap().thread.clone()
2603 });
2604
2605 thread.read_with(cx, |thread, _| {
2606 assert!(
2607 thread.summarization_model().is_some(),
2608 "session should have a summarization model from the test registry"
2609 );
2610 });
2611
2612 // Simulate what happens during a provider blip:
2613 // update_active_language_model_from_settings calls set_default_model(None)
2614 // when it can't resolve the model, clearing all fallbacks.
2615 cx.update(|cx| {
2616 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2617 registry.set_default_model(None, cx);
2618 });
2619 });
2620 cx.run_until_parked();
2621
2622 thread.read_with(cx, |thread, _| {
2623 assert!(
2624 thread.summarization_model().is_some(),
2625 "summarization model should survive a transient default model clearing"
2626 );
2627 });
2628 }
2629
    // Checks that `thinking_enabled` is still true after a thread that used a
    // thinking model is closed and then reopened via `open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list after registering the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Feed the fake model's response and end the stream before awaiting the send.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles before asserting the agent holds no sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // Keep the reloaded AcpThread handle alive until the assertions above have run.
        drop(reloaded_acp_thread);
    }
2732
    // Checks that a thread reopened via `open_thread` still uses the model that
    // was selected before it was closed, for a model whose id differs from its
    // display name.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list after registering the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the thread reports the selected model id before anything is saved.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Feed the fake model's response and end the stream before awaiting the send.
        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles before asserting the agent holds no sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Keep the reloaded AcpThread handle alive until the assertions above have run.
        drop(reloaded_acp_thread);
    }
2840
    // Round-trips a thread through close and `open_thread`, checking that the
    // transcript, the generated title, the draft prompt, token usage, and the
    // UI scroll position are all restored.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message made of text plus a resource link mentioning `b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the main model: one text chunk, a usage update, then end of stream.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // Drive the summarization model; its text is later expected as the
        // stored entry's title.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles before asserting the agent holds no sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        // The reloaded transcript must render exactly like the original one.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
3022
    // Checks that `close_session` itself saves the thread: a draft prompt set
    // immediately before closing must still be present after reopening.
    #[gpui::test]
    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Use a fake model so the test controls the completion stream.
        let model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Send a message so the thread is non-empty (empty threads aren't saved).
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
        // This means no observe-triggered save has run for this change.
        // The only way this data gets persisted is if close_session
        // itself performs the save.
        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
            "unsaved draft",
        ))];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });

        // Close the session immediately — no run_until_parked in between.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Reopen and verify the draft prompt was saved.
        let reloaded = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        reloaded.read_with(cx, |thread, _| {
            assert_eq!(
                thread.draft_prompt(),
                Some(draft_blocks.as_slice()),
                "close_session must save the thread; draft prompt was lost"
            );
        });
    }
3103
    // Checks that `thread_summary` on an already-open session leaves that
    // session's `ref_count` at 1, so a single later `close_session` unloads it.
    #[gpui::test]
    async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Separate fake models for the main completion and for summarization,
        // so each stream can be driven independently.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });

        // Complete one user/assistant exchange so the thread has content.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Request a summary while the session is still open.
        let summary = agent.update(cx, |agent, cx| {
            agent.thread_summary(session_id.clone(), project.clone(), cx)
        });
        cx.run_until_parked();

        // Answer the summarization request through the fake summary model.
        summary_model.send_last_completion_stream_text_chunk("summary");
        summary_model.end_last_completion_stream();

        assert_eq!(summary.await.unwrap(), "summary");
        cx.run_until_parked();

        // The session must still be loaded, with only the original reference left.
        agent.read_with(cx, |agent, _| {
            let session = agent
                .sessions
                .get(&session_id)
                .expect("thread_summary should not close the active session");
            assert_eq!(
                session.ref_count, 1,
                "thread_summary should release its temporary session reference"
            );
        });

        // Closing once must now be enough to unload the session.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.is_empty(),
                "closing the active session after thread_summary should unload it"
            );
        });
    }
3187
    // Exercises the lifetime of a session that is loaded more than once:
    // two loads of the same session id must resolve to a single `AcpThread`,
    // the first `close_session` must leave the shared session registered and
    // usable, and only the final close may remove it from `agent.sessions`.
    #[gpui::test]
    async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Phase 1: create a fresh session and record one user/assistant exchange in it.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Point the thread at the registry's default test model; `fake_model` is the
        // same model viewed through its fake interface so the test can script replies.
        let model = cx.update(|cx| {
            LanguageModelRegistry::read_global(cx)
                .default_model()
                .map(|default_model| default_model.model)
                .expect("default test model should be available")
        });
        let fake_model = model.as_fake();
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // The send is spawned and the executor parked before the fake model is fed,
        // so the completion stream exists by the time a chunk is pushed into it.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("world");
        fake_model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Phase 2: close the only open handle; the agent must forget the session.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the original thread entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Phase 3: request the same session twice. Both load requests are created
        // before either one is awaited.
        let first_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });
        let second_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });

        let first_loaded_thread = first_loaded_thread.await.unwrap();
        let second_loaded_thread = second_loaded_thread.await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            first_loaded_thread.entity_id(),
            second_loaded_thread.entity_id(),
            "concurrent loads for the same session should share one AcpThread"
        );

        // Phase 4: the first of the two closes must not tear down the shared session.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.contains_key(&session_id),
                "closing one loaded session should not drop shared session state"
            );
        });

        // The remaining handle must still be able to drive a full turn.
        let follow_up = second_loaded_thread.update(cx, |thread, cx| {
            thread.send(vec!["still there?".into()], cx)
        });
        let follow_up = cx.foreground_executor().spawn(follow_up);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("yes");
        fake_model.end_last_completion_stream();
        follow_up.await.unwrap();
        cx.run_until_parked();

        // The transcript must contain the exchange recorded before the reload
        // followed by the one made after it.
        second_loaded_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    hello

                    ## Assistant

                    world

                    ## User

                    still there?

                    ## Assistant

                    yes

                "}
            );
        });

        // Phase 5: the second close is the last one, so the session must be removed.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        cx.run_until_parked();

        drop(first_loaded_thread);
        drop(second_loaded_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });
    }
3338
    #[gpui::test]
    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
        // Regression test: rapid title changes must not cause a propagation loop
        // between Thread and AcpThread via handle_thread_title_updated.
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();

        // Fetch the agent-side `Thread` that backs the `AcpThread` created above.
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Counts `TitleUpdated` events emitted by the thread. The handler asserts the
        // bound itself, so the test fails as soon as a third event arrives.
        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
        cx.update(|cx| {
            let count = title_updated_count.clone();
            cx.subscribe(
                &thread,
                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
                    // The mutable borrow is confined to this block, so it has been
                    // released by the time the assertion below runs.
                    let new_count = {
                        let mut count = count.borrow_mut();
                        *count += 1;
                        *count
                    };
                    assert!(
                        new_count <= 2,
                        "TitleUpdated fired {new_count} times; \
                         title updates are looping"
                    );
                },
            )
            .detach();
        });

        // Two back-to-back title changes with no executor turn in between.
        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));

        cx.run_until_parked();

        // Both sides must end up on the most recent title.
        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.title(), Some("second".into()));
        });
        acp_thread.read_with(cx, |acp_thread, _| {
            assert_eq!(acp_thread.title(), Some("second".into()));
        });

        // Exactly one event per `set_title` call above, and nothing extra.
        assert_eq!(*title_updated_count.borrow(), 2);
    }
3402
3403 fn thread_entries(
3404 thread_store: &Entity<ThreadStore>,
3405 cx: &mut TestAppContext,
3406 ) -> Vec<(acp::SessionId, String)> {
3407 thread_store.read_with(cx, |store, _| {
3408 store
3409 .entries()
3410 .map(|entry| (entry.id.clone(), entry.title.to_string()))
3411 .collect::<Vec<_>>()
3412 })
3413 }
3414
3415 fn init_test(cx: &mut TestAppContext) {
3416 env_logger::try_init().ok();
3417 cx.update(|cx| {
3418 let settings_store = SettingsStore::test(cx);
3419 cx.set_global(settings_store);
3420
3421 LanguageModelRegistry::test(cx);
3422 });
3423 }
3424}
3425
3426fn mcp_message_content_to_acp_content_block(
3427 content: context_server::types::MessageContent,
3428) -> acp::ContentBlock {
3429 match content {
3430 context_server::types::MessageContent::Text {
3431 text,
3432 annotations: _,
3433 } => text.into(),
3434 context_server::types::MessageContent::Image {
3435 data,
3436 mime_type,
3437 annotations: _,
3438 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3439 context_server::types::MessageContent::Audio {
3440 data,
3441 mime_type,
3442 annotations: _,
3443 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3444 context_server::types::MessageContent::Resource {
3445 resource,
3446 annotations: _,
3447 } => {
3448 let mut link =
3449 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3450 if let Some(mime_type) = resource.mime_type {
3451 link = link.mime_type(mime_type);
3452 }
3453 acp::ContentBlock::ResourceLink(link)
3454 }
3455 }
3456}