1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol::schema as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, Settings as _, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// Snapshot of a project's worktrees paired with a UTC timestamp.
///
/// Derives `Serialize`/`Deserialize`, so the field names below are part of its
/// serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp for the snapshot (presumably the capture time — confirm with callers).
    pub timestamp: DateTime<Utc>,
}
64
/// Error describing why a worktree rules file could not be loaded.
///
/// `load_worktree_info_for_system_prompt` builds this from the load error's
/// display text.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
68
/// Per-project state shared by every session opened against the same project.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Current worktree/rules context; replaced in place by the maintenance task.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks the maintenance task to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Loop that rebuilds `project_context` on request; held so it lives as long
    /// as this state does.
    _maintain_project_context: Task<Result<()>>,
    /// Context servers available to this project's threads.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Keeps the project / context-server subscriptions alive.
    _subscriptions: Vec<Subscription>,
}
77
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session belongs to.
    project_id: EntityId,
    /// Most recent save started by `save_thread`; `close_session` hands it back
    /// to the caller so the final save can be awaited.
    pending_save: Task<Result<()>>,
    /// Keeps the thread event/observe subscriptions alive.
    _subscriptions: Vec<Subscription>,
    /// Number of outstanding opens; `close_session` removes the session when
    /// this reaches zero.
    ref_count: usize,
}
89
/// A session whose thread is still being loaded by `open_thread`.
struct PendingSession {
    /// Shared load task; every concurrent opener awaits a clone of it. The error
    /// is wrapped in `Arc` so the result can be cloned for each waiter.
    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
    /// Number of openers waiting on this load; becomes the session's initial
    /// ref count once the thread is registered.
    ref_count: usize,
}
94
/// Cache of language models exposed by authenticated providers, keyed by
/// `"{provider_id}/{model_id}"` and grouped for display.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch()`; signalled on every refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that tries to authenticate every visible provider.
    _authenticate_all_providers_task: Task<()>,
}
104
105impl LanguageModels {
106 fn new(cx: &mut App) -> Self {
107 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
108
109 let mut this = Self {
110 models: HashMap::default(),
111 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
112 refresh_models_rx,
113 refresh_models_tx,
114 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
115 };
116 this.refresh_list(cx);
117 this
118 }
119
120 fn refresh_list(&mut self, cx: &App) {
121 let providers = LanguageModelRegistry::global(cx)
122 .read(cx)
123 .visible_providers()
124 .into_iter()
125 .filter(|provider| provider.is_authenticated(cx))
126 .collect::<Vec<_>>();
127
128 let mut language_model_list = IndexMap::default();
129 let mut recommended_models = HashSet::default();
130
131 let mut recommended = Vec::new();
132 for provider in &providers {
133 for model in provider.recommended_models(cx) {
134 recommended_models.insert((model.provider_id(), model.id()));
135 recommended.push(Self::map_language_model_to_info(&model, provider));
136 }
137 }
138 if !recommended.is_empty() {
139 language_model_list.insert(
140 acp_thread::AgentModelGroupName("Recommended".into()),
141 recommended,
142 );
143 }
144
145 let mut models = HashMap::default();
146 for provider in providers {
147 let mut provider_models = Vec::new();
148 for model in provider.provided_models(cx) {
149 let model_info = Self::map_language_model_to_info(&model, &provider);
150 let model_id = model_info.id.clone();
151 provider_models.push(model_info);
152 models.insert(model_id, model);
153 }
154 if !provider_models.is_empty() {
155 language_model_list.insert(
156 acp_thread::AgentModelGroupName(provider.name().0.clone()),
157 provider_models,
158 );
159 }
160 }
161
162 self.models = models;
163 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
164 self.refresh_models_tx.send(()).ok();
165 }
166
167 fn watch(&self) -> watch::Receiver<()> {
168 self.refresh_models_rx.clone()
169 }
170
171 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
172 self.models.get(model_id).cloned()
173 }
174
175 fn map_language_model_to_info(
176 model: &Arc<dyn LanguageModel>,
177 provider: &Arc<dyn LanguageModelProvider>,
178 ) -> acp_thread::AgentModelInfo {
179 acp_thread::AgentModelInfo {
180 id: Self::model_id(model),
181 name: model.name().0,
182 description: None,
183 icon: Some(match provider.icon() {
184 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
185 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
186 }),
187 is_latest: model.is_latest(),
188 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
189 }
190 }
191
192 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
193 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
194 }
195
196 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
197 let authenticate_all_providers = LanguageModelRegistry::global(cx)
198 .read(cx)
199 .visible_providers()
200 .iter()
201 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
202 .collect::<Vec<_>>();
203
204 cx.spawn(async move |cx| {
205 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
206 if let Err(err) = authenticate_task.await {
207 match err {
208 language_model::AuthenticateError::CredentialsNotFound => {
209 // Since we're authenticating these providers in the
210 // background for the purposes of populating the
211 // language selector, we don't care about providers
212 // where the credentials are not found.
213 }
214 language_model::AuthenticateError::ConnectionRefused => {
215 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
216 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
217 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
218 }
219 _ => {
220 // Some providers have noisy failure states that we
221 // don't want to spam the logs with every time the
222 // language model selector is initialized.
223 //
224 // Ideally these should have more clear failure modes
225 // that we know are safe to ignore here, like what we do
226 // with `CredentialsNotFound` above.
227 match provider_id.0.as_ref() {
228 "lmstudio" | "ollama" => {
229 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230 //
231 // These fail noisily, so we don't log them.
232 }
233 "copilot_chat" => {
234 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235 }
236 _ => {
237 log::error!(
238 "Failed to authenticate provider: {}: {err:#}",
239 provider_name.0
240 );
241 }
242 }
243 }
244 }
245 }
246 }
247
248 cx.update(language_models::update_environment_fallback_model);
249 })
250 }
251}
252
/// The native agent: owns all open sessions, their per-project state, and the
/// language-model cache they share.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Sessions whose threads are still being loaded by `open_thread`.
    pending_sessions: HashMap<acp::SessionId, PendingSession>,
    /// Thread store reloaded after each thread save.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Source of default user rules; when present, its updates trigger a
    /// project-context refresh.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Keeps the registry / prompt-store subscriptions alive.
    _subscriptions: Vec<Subscription>,
}
268
269impl NativeAgent {
270 pub fn new(
271 thread_store: Entity<ThreadStore>,
272 templates: Arc<Templates>,
273 prompt_store: Option<Entity<PromptStore>>,
274 fs: Arc<dyn Fs>,
275 cx: &mut App,
276 ) -> Entity<NativeAgent> {
277 log::debug!("Creating new NativeAgent");
278
279 cx.new(|cx| {
280 let mut subscriptions = vec![cx.subscribe(
281 &LanguageModelRegistry::global(cx),
282 Self::handle_models_updated_event,
283 )];
284 if let Some(prompt_store) = prompt_store.as_ref() {
285 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
286 }
287
288 Self {
289 sessions: HashMap::default(),
290 pending_sessions: HashMap::default(),
291 thread_store,
292 projects: HashMap::default(),
293 templates,
294 models: LanguageModels::new(cx),
295 prompt_store,
296 fs,
297 _subscriptions: subscriptions,
298 }
299 })
300 }
301
302 fn new_session(
303 &mut self,
304 project: Entity<Project>,
305 cx: &mut Context<Self>,
306 ) -> Entity<AcpThread> {
307 let project_id = self.get_or_create_project_state(&project, cx);
308 let project_state = &self.projects[&project_id];
309
310 let registry = LanguageModelRegistry::read_global(cx);
311 let available_count = registry.available_models(cx).count();
312 log::debug!("Total available models: {}", available_count);
313
314 let default_model = registry.default_model().and_then(|default_model| {
315 self.models
316 .model_from_id(&LanguageModels::model_id(&default_model.model))
317 });
318 let thread = cx.new(|cx| {
319 Thread::new(
320 project,
321 project_state.project_context.clone(),
322 project_state.context_server_registry.clone(),
323 self.templates.clone(),
324 default_model,
325 cx,
326 )
327 });
328
329 self.register_session(thread, project_id, 1, cx)
330 }
331
    /// Creates a session around an existing [`Thread`] and returns the session's
    /// [`AcpThread`].
    ///
    /// The new `AcpThread` is seeded with the thread's parent id, title, draft
    /// prompt, UI scroll position and latest token usage. The thread gets the
    /// registry's thread-summary model and its default tools. The session is
    /// stored with `ref_count` initial references, and the project's available
    /// commands are then pushed to it.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        ref_count: usize,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything needed out of the thread up front; `cx` is borrowed
        // mutably by `cx.new` below.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            acp_thread.set_draft_prompt(draft_prompt, cx);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);

        // The tool environment only receives weak handles to the agent and both
        // threads (presumably to avoid reference cycles — confirm).
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread every time it notifies observers.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
                ref_count,
            },
        );

        // The session must already be in `self.sessions` so it receives the update.
        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
410
411 pub fn models(&self) -> &LanguageModels {
412 &self.models
413 }
414
415 fn get_or_create_project_state(
416 &mut self,
417 project: &Entity<Project>,
418 cx: &mut Context<Self>,
419 ) -> EntityId {
420 let project_id = project.entity_id();
421 if self.projects.contains_key(&project_id) {
422 return project_id;
423 }
424
425 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
426 self.register_project_with_initial_context(project.clone(), project_context, cx);
427 if let Some(state) = self.projects.get_mut(&project_id) {
428 state.project_context_needs_refresh.send(()).ok();
429 }
430 project_id
431 }
432
433 fn register_project_with_initial_context(
434 &mut self,
435 project: Entity<Project>,
436 project_context: Entity<ProjectContext>,
437 cx: &mut Context<Self>,
438 ) {
439 let project_id = project.entity_id();
440
441 let context_server_store = project.read(cx).context_server_store();
442 let context_server_registry =
443 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
444
445 let subscriptions = vec![
446 cx.subscribe(&project, Self::handle_project_event),
447 cx.subscribe(
448 &context_server_store,
449 Self::handle_context_server_store_updated,
450 ),
451 cx.subscribe(
452 &context_server_registry,
453 Self::handle_context_server_registry_event,
454 ),
455 ];
456
457 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
458 watch::channel(());
459
460 self.projects.insert(
461 project_id,
462 ProjectState {
463 project,
464 project_context,
465 project_context_needs_refresh: project_context_needs_refresh_tx,
466 _maintain_project_context: cx.spawn(async move |this, cx| {
467 Self::maintain_project_context(
468 this,
469 project_id,
470 project_context_needs_refresh_rx,
471 cx,
472 )
473 .await
474 }),
475 context_server_registry,
476 _subscriptions: subscriptions,
477 },
478 );
479 }
480
481 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
482 self.sessions
483 .get(session_id)
484 .and_then(|session| self.projects.get(&session.project_id))
485 }
486
487 async fn maintain_project_context(
488 this: WeakEntity<Self>,
489 project_id: EntityId,
490 mut needs_refresh: watch::Receiver<()>,
491 cx: &mut AsyncApp,
492 ) -> Result<()> {
493 while needs_refresh.changed().await.is_ok() {
494 let project_context = this
495 .update(cx, |this, cx| {
496 let state = this
497 .projects
498 .get(&project_id)
499 .context("project state not found")?;
500 anyhow::Ok(Self::build_project_context(
501 &state.project,
502 this.prompt_store.as_ref(),
503 cx,
504 ))
505 })??
506 .await;
507 this.update(cx, |this, cx| {
508 if let Some(state) = this.projects.get(&project_id) {
509 state
510 .project_context
511 .update(cx, |current_project_context, _cx| {
512 *current_project_context = project_context;
513 });
514 }
515 })?;
516 }
517
518 Ok(())
519 }
520
521 fn build_project_context(
522 project: &Entity<Project>,
523 prompt_store: Option<&Entity<PromptStore>>,
524 cx: &mut App,
525 ) -> Task<ProjectContext> {
526 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
527 let worktree_tasks = worktrees
528 .into_iter()
529 .map(|worktree| {
530 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
531 })
532 .collect::<Vec<_>>();
533 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
534 prompt_store.read_with(cx, |prompt_store, cx| {
535 let prompts = prompt_store.default_prompt_metadata();
536 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
537 let contents = prompt_store.load(prompt_metadata.id, cx);
538 async move { (contents.await, prompt_metadata) }
539 });
540 cx.background_spawn(future::join_all(load_tasks))
541 })
542 } else {
543 Task::ready(vec![])
544 };
545
546 cx.spawn(async move |_cx| {
547 let (worktrees, default_user_rules) =
548 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
549
550 let worktrees = worktrees
551 .into_iter()
552 .map(|(worktree, _rules_error)| {
553 // TODO: show error message
554 // if let Some(rules_error) = rules_error {
555 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
556 // }
557 worktree
558 })
559 .collect::<Vec<_>>();
560
561 let default_user_rules = default_user_rules
562 .into_iter()
563 .flat_map(|(contents, prompt_metadata)| match contents {
564 Ok(contents) => Some(UserRulesContext {
565 uuid: prompt_metadata.id.as_user()?,
566 title: prompt_metadata.title.map(|title| title.to_string()),
567 contents,
568 }),
569 Err(_err) => {
570 // TODO: show error message
571 // this.update(cx, |_, cx| {
572 // cx.emit(RulesLoadingError {
573 // message: format!("{err:?}").into(),
574 // });
575 // })
576 // .ok();
577 None
578 }
579 })
580 .collect::<Vec<_>>();
581
582 ProjectContext::new(worktrees, default_user_rules)
583 })
584 }
585
586 fn load_worktree_info_for_system_prompt(
587 worktree: Entity<Worktree>,
588 project: Entity<Project>,
589 cx: &mut App,
590 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
591 let tree = worktree.read(cx);
592 let root_name = tree.root_name_str().into();
593 let abs_path = tree.abs_path();
594 let scan_complete = tree.as_local().map(|local| local.scan_complete());
595
596 let mut context = WorktreeContext {
597 root_name,
598 abs_path,
599 rules_file: None,
600 };
601
602 cx.spawn(async move |cx| {
603 if let Some(scan_complete) = scan_complete {
604 scan_complete.await;
605 }
606
607 let rules_task = cx.update(|cx| Self::load_worktree_rules_file(worktree, project, cx));
608
609 let (rules_file, rules_file_error) = match rules_task {
610 Some(rules_task) => match rules_task.await {
611 Ok(rules_file) => (Some(rules_file), None),
612 Err(err) => (
613 None,
614 Some(RulesLoadingError {
615 message: format!("{err}").into(),
616 }),
617 ),
618 },
619 None => (None, None),
620 };
621 context.rules_file = rules_file;
622 (context, rules_file_error)
623 })
624 }
625
626 fn load_worktree_rules_file(
627 worktree: Entity<Worktree>,
628 project: Entity<Project>,
629 cx: &mut App,
630 ) -> Option<Task<Result<RulesFileContext>>> {
631 let worktree = worktree.read(cx);
632 let worktree_id = worktree.id();
633 let selected_rules_file = RULES_FILE_NAMES
634 .into_iter()
635 .filter_map(|name| {
636 worktree
637 .entry_for_path(RelPath::unix(name).unwrap())
638 .filter(|entry| entry.is_file())
639 .map(|entry| entry.path.clone())
640 })
641 .next();
642
643 // Note that Cline supports `.clinerules` being a directory, but that is not currently
644 // supported. This doesn't seem to occur often in GitHub repositories.
645 selected_rules_file.map(|path_in_worktree| {
646 let project_path = ProjectPath {
647 worktree_id,
648 path: path_in_worktree.clone(),
649 };
650 let buffer_task =
651 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
652 let rope_task = cx.spawn(async move |cx| {
653 let buffer = buffer_task.await?;
654 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
655 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
656 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
657 })?;
658 anyhow::Ok((project_entry_id, rope))
659 });
660 // Build a string from the rope on a background thread.
661 cx.background_spawn(async move {
662 let (project_entry_id, rope) = rope_task.await?;
663 anyhow::Ok(RulesFileContext {
664 path_in_worktree,
665 text: rope.to_string().trim().to_string(),
666 project_entry_id: project_entry_id.to_usize(),
667 })
668 })
669 })
670 }
671
672 fn handle_thread_title_updated(
673 &mut self,
674 thread: Entity<Thread>,
675 _: &TitleUpdated,
676 cx: &mut Context<Self>,
677 ) {
678 let session_id = thread.read(cx).id();
679 let Some(session) = self.sessions.get(session_id) else {
680 return;
681 };
682
683 let thread = thread.downgrade();
684 let acp_thread = session.acp_thread.downgrade();
685 cx.spawn(async move |_, cx| {
686 let title = thread.read_with(cx, |thread, _| thread.title())?;
687 if let Some(title) = title {
688 let task =
689 acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
690 task.await?;
691 }
692 anyhow::Ok(())
693 })
694 .detach_and_log_err(cx);
695 }
696
697 fn handle_thread_token_usage_updated(
698 &mut self,
699 thread: Entity<Thread>,
700 usage: &TokenUsageUpdated,
701 cx: &mut Context<Self>,
702 ) {
703 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
704 return;
705 };
706 session.acp_thread.update(cx, |acp_thread, cx| {
707 acp_thread.update_token_usage(usage.0.clone(), cx);
708 });
709 }
710
711 fn handle_project_event(
712 &mut self,
713 project: Entity<Project>,
714 event: &project::Event,
715 _cx: &mut Context<Self>,
716 ) {
717 let project_id = project.entity_id();
718 let Some(state) = self.projects.get_mut(&project_id) else {
719 return;
720 };
721 match event {
722 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
723 state.project_context_needs_refresh.send(()).ok();
724 }
725 project::Event::WorktreeUpdatedEntries(_, items) => {
726 if items.iter().any(|(path, _, _)| {
727 RULES_FILE_NAMES
728 .iter()
729 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
730 }) {
731 state.project_context_needs_refresh.send(()).ok();
732 }
733 }
734 _ => {}
735 }
736 }
737
738 fn handle_prompts_updated_event(
739 &mut self,
740 _prompt_store: Entity<PromptStore>,
741 _event: &prompt_store::PromptsUpdatedEvent,
742 _cx: &mut Context<Self>,
743 ) {
744 for state in self.projects.values_mut() {
745 state.project_context_needs_refresh.send(()).ok();
746 }
747 }
748
749 fn handle_models_updated_event(
750 &mut self,
751 _registry: Entity<LanguageModelRegistry>,
752 event: &language_model::Event,
753 cx: &mut Context<Self>,
754 ) {
755 self.models.refresh_list(cx);
756
757 let registry = LanguageModelRegistry::read_global(cx);
758 let default_model = registry.default_model().map(|m| m.model);
759 let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
760
761 for session in self.sessions.values_mut() {
762 session.thread.update(cx, |thread, cx| {
763 if thread.model().is_none()
764 && let Some(model) = default_model.clone()
765 {
766 thread.set_model(model, cx);
767 cx.notify();
768 }
769 if let Some(model) = summarization_model.clone() {
770 if thread.summarization_model().is_none()
771 || matches!(event, language_model::Event::ThreadSummaryModelChanged)
772 {
773 thread.set_summarization_model(Some(model), cx);
774 }
775 }
776 });
777 }
778 }
779
780 fn handle_context_server_store_updated(
781 &mut self,
782 store: Entity<project::context_server_store::ContextServerStore>,
783 _event: &project::context_server_store::ServerStatusChangedEvent,
784 cx: &mut Context<Self>,
785 ) {
786 let project_id = self.projects.iter().find_map(|(id, state)| {
787 if *state.context_server_registry.read(cx).server_store() == store {
788 Some(*id)
789 } else {
790 None
791 }
792 });
793 if let Some(project_id) = project_id {
794 self.update_available_commands_for_project(project_id, cx);
795 }
796 }
797
798 fn handle_context_server_registry_event(
799 &mut self,
800 registry: Entity<ContextServerRegistry>,
801 event: &ContextServerRegistryEvent,
802 cx: &mut Context<Self>,
803 ) {
804 match event {
805 ContextServerRegistryEvent::ToolsChanged => {}
806 ContextServerRegistryEvent::PromptsChanged => {
807 let project_id = self.projects.iter().find_map(|(id, state)| {
808 if state.context_server_registry == registry {
809 Some(*id)
810 } else {
811 None
812 }
813 });
814 if let Some(project_id) = project_id {
815 self.update_available_commands_for_project(project_id, cx);
816 }
817 }
818 }
819 }
820
821 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
822 let available_commands =
823 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
824 for session in self.sessions.values() {
825 if session.project_id != project_id {
826 continue;
827 }
828 session.acp_thread.update(cx, |thread, cx| {
829 thread
830 .handle_session_update(
831 acp::SessionUpdate::AvailableCommandsUpdate(
832 acp::AvailableCommandsUpdate::new(available_commands.clone()),
833 ),
834 cx,
835 )
836 .log_err();
837 });
838 }
839 }
840
841 fn build_available_commands_for_project(
842 project_state: Option<&ProjectState>,
843 cx: &App,
844 ) -> Vec<acp::AvailableCommand> {
845 let Some(state) = project_state else {
846 return vec![];
847 };
848 let registry = state.context_server_registry.read(cx);
849
850 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
851 for context_server_prompt in registry.prompts() {
852 *prompt_name_counts
853 .entry(context_server_prompt.prompt.name.as_str())
854 .or_insert(0) += 1;
855 }
856
857 registry
858 .prompts()
859 .flat_map(|context_server_prompt| {
860 let prompt = &context_server_prompt.prompt;
861
862 let should_prefix = prompt_name_counts
863 .get(prompt.name.as_str())
864 .copied()
865 .unwrap_or(0)
866 > 1;
867
868 let name = if should_prefix {
869 format!("{}.{}", context_server_prompt.server_id, prompt.name)
870 } else {
871 prompt.name.clone()
872 };
873
874 let mut command = acp::AvailableCommand::new(
875 name,
876 prompt.description.clone().unwrap_or_default(),
877 );
878
879 match prompt.arguments.as_deref() {
880 Some([arg]) => {
881 let hint = format!("<{}>", arg.name);
882
883 command = command.input(acp::AvailableCommandInput::Unstructured(
884 acp::UnstructuredCommandInput::new(hint),
885 ));
886 }
887 Some([]) | None => {}
888 Some(_) => {
889 // skip >1 argument commands since we don't support them yet
890 return None;
891 }
892 }
893
894 Some(command)
895 })
896 .collect()
897 }
898
    /// Loads the thread `id` from the threads database and creates a [`Thread`]
    /// entity for it in `project`.
    ///
    /// Project state for `project` is created if needed, and the registry's
    /// current thread-summary model is applied. No session is registered here;
    /// `open_thread` does that.
    ///
    /// Errors if the database cannot be opened, no thread with `id` exists, or
    /// the agent entity has been released.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let project_id = this.get_or_create_project_state(&project, cx);
                let project_state = this
                    .projects
                    .get(&project_id)
                    .context("project state not found")?;
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model(cx)
                    .map(|c| c.model);

                Ok(cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        project_state.project.clone(),
                        project_state.project_context.clone(),
                        project_state.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                }))
            })?
        })
    }
939
    /// Opens the session for thread `id` and returns its [`AcpThread`], taking
    /// one reference that the caller must release with `close_session`.
    ///
    /// - If the session is already open, its ref count is bumped.
    /// - If it is currently loading, the caller joins the in-flight load and the
    ///   pending ref count is bumped.
    /// - Otherwise, the thread is loaded from the database and registered with
    ///   the ref count accumulated while it was pending. Its history is then
    ///   replayed into the ACP thread.
    ///
    /// If the load fails, the pending entry is removed and every waiter gets the
    /// error.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        if let Some(session) = self.sessions.get_mut(&id) {
            session.ref_count += 1;
            return Task::ready(Ok(session.acp_thread.clone()));
        }

        // Join a load that is already in flight instead of starting a second one.
        if let Some(pending) = self.pending_sessions.get_mut(&id) {
            pending.ref_count += 1;
            let task = pending.task.clone();
            return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
        }

        let task = self.load_thread(id.clone(), project.clone(), cx);
        let shared_task = cx
            .spawn({
                let id = id.clone();
                async move |this, cx| {
                    let thread = match task.await {
                        Ok(thread) => thread,
                        Err(err) => {
                            this.update(cx, |this, _cx| {
                                this.pending_sessions.remove(&id);
                            })
                            .ok();
                            return Err(Arc::new(err));
                        }
                    };
                    let acp_thread = this
                        .update(cx, |this, cx| {
                            let project_id = this.get_or_create_project_state(&project, cx);
                            // Hand over every reference taken while the load was pending.
                            let ref_count = this
                                .pending_sessions
                                .remove(&id)
                                .map_or(1, |pending| pending.ref_count);
                            this.register_session(thread.clone(), project_id, ref_count, cx)
                        })
                        .map_err(Arc::new)?;
                    // NOTE(review): if the replay below fails, the session stays registered
                    // while every waiter receives an error; confirm callers still release it.
                    let events = thread.update(cx, |thread, cx| thread.replay(cx));
                    cx.update(|cx| {
                        NativeAgentConnection::handle_thread_events(
                            events,
                            acp_thread.downgrade(),
                            cx,
                        )
                    })
                    .await
                    .map_err(Arc::new)?;
                    acp_thread.update(cx, |thread, cx| {
                        thread.snapshot_completed_plan(cx);
                    });
                    Ok(acp_thread)
                }
            })
            .shared();
        // NOTE(review): the entry is inserted after `spawn`; this relies on the spawned
        // task not running before this method returns — confirm with gpui's executor.
        self.pending_sessions.insert(
            id,
            PendingSession {
                task: shared_task.clone(),
                ref_count: 1,
            },
        );

        cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
    }
1009
1010 pub fn thread_summary(
1011 &mut self,
1012 id: acp::SessionId,
1013 project: Entity<Project>,
1014 cx: &mut Context<Self>,
1015 ) -> Task<Result<SharedString>> {
1016 let thread = self.open_thread(id.clone(), project, cx);
1017 cx.spawn(async move |this, cx| {
1018 let acp_thread = thread.await?;
1019 let result = this
1020 .update(cx, |this, cx| {
1021 this.sessions
1022 .get(&id)
1023 .unwrap()
1024 .thread
1025 .update(cx, |thread, cx| thread.summary(cx))
1026 })?
1027 .await
1028 .context("Failed to generate summary")?;
1029
1030 this.update(cx, |this, cx| this.close_session(&id, cx))?
1031 .await?;
1032 drop(acp_thread);
1033 Ok(result)
1034 })
1035 }
1036
1037 fn close_session(
1038 &mut self,
1039 session_id: &acp::SessionId,
1040 cx: &mut Context<Self>,
1041 ) -> Task<Result<()>> {
1042 let Some(session) = self.sessions.get_mut(session_id) else {
1043 return Task::ready(Ok(()));
1044 };
1045
1046 session.ref_count -= 1;
1047 if session.ref_count > 0 {
1048 return Task::ready(Ok(()));
1049 }
1050
1051 let thread = session.thread.clone();
1052 self.save_thread(thread, cx);
1053 let Some(session) = self.sessions.remove(session_id) else {
1054 return Task::ready(Ok(()));
1055 };
1056 let project_id = session.project_id;
1057
1058 let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
1059 if !has_remaining {
1060 self.projects.remove(&project_id);
1061 }
1062
1063 session.pending_save
1064 }
1065
    /// Persists `thread` to the threads database in the background.
    ///
    /// Does nothing for empty threads or threads without a registered session
    /// and project. The ACP thread's draft prompt is copied onto the thread
    /// before serialization. The thread is stored with the absolute paths of
    /// the project's visible worktrees. The new task replaces the session's
    /// `pending_save`; the previous task handle is dropped (NOTE(review):
    /// dropping a gpui `Task` presumably cancels it — confirm an in-flight save
    /// can be superseded safely). Database connection and save failures are
    /// logged, not returned, and the thread store is reloaded afterwards.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let project_id = session.project_id;
        let Some(state) = self.projects.get(&project_id) else {
            return;
        };

        let folder_paths = PathList::new(
            &state
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            // Saving is best-effort: failures are logged and the task still resolves to Ok.
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return Ok(());
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            thread_store.update(cx, |store, cx| store.reload(cx));
            Ok(())
        });
    }
1110
1111 fn send_mcp_prompt(
1112 &self,
1113 message_id: UserMessageId,
1114 session_id: acp::SessionId,
1115 prompt_name: String,
1116 server_id: ContextServerId,
1117 arguments: HashMap<String, String>,
1118 original_content: Vec<acp::ContentBlock>,
1119 cx: &mut Context<Self>,
1120 ) -> Task<Result<acp::PromptResponse>> {
1121 let Some(state) = self.session_project_state(&session_id) else {
1122 return Task::ready(Err(anyhow!("Project state not found for session")));
1123 };
1124 let server_store = state
1125 .context_server_registry
1126 .read(cx)
1127 .server_store()
1128 .clone();
1129 let path_style = state.project.read(cx).path_style(cx);
1130
1131 cx.spawn(async move |this, cx| {
1132 let prompt =
1133 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
1134
1135 let (acp_thread, thread) = this.update(cx, |this, _cx| {
1136 let session = this
1137 .sessions
1138 .get(&session_id)
1139 .context("Failed to get session")?;
1140 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
1141 })??;
1142
1143 let mut last_is_user = true;
1144
1145 thread.update(cx, |thread, cx| {
1146 thread.push_acp_user_block(
1147 message_id,
1148 original_content.into_iter().skip(1),
1149 path_style,
1150 cx,
1151 );
1152 });
1153
1154 for message in prompt.messages {
1155 let context_server::types::PromptMessage { role, content } = message;
1156 let block = mcp_message_content_to_acp_content_block(content);
1157
1158 match role {
1159 context_server::types::Role::User => {
1160 let id = acp_thread::UserMessageId::new();
1161
1162 acp_thread.update(cx, |acp_thread, cx| {
1163 acp_thread.push_user_content_block_with_indent(
1164 Some(id.clone()),
1165 block.clone(),
1166 true,
1167 cx,
1168 );
1169 });
1170
1171 thread.update(cx, |thread, cx| {
1172 thread.push_acp_user_block(id, [block], path_style, cx);
1173 });
1174 }
1175 context_server::types::Role::Assistant => {
1176 acp_thread.update(cx, |acp_thread, cx| {
1177 acp_thread.push_assistant_content_block_with_indent(
1178 block.clone(),
1179 false,
1180 true,
1181 cx,
1182 );
1183 });
1184
1185 thread.update(cx, |thread, cx| {
1186 thread.push_acp_agent_block(block, cx);
1187 });
1188 }
1189 }
1190
1191 last_is_user = role == context_server::types::Role::User;
1192 }
1193
1194 let response_stream = thread.update(cx, |thread, cx| {
1195 if last_is_user {
1196 thread.send_existing(cx)
1197 } else {
1198 // Resume if MCP prompt did not end with a user message
1199 thread.resume(cx)
1200 }
1201 })?;
1202
1203 cx.update(|cx| {
1204 NativeAgentConnection::handle_thread_events(
1205 response_stream,
1206 acp_thread.downgrade(),
1207 cx,
1208 )
1209 })
1210 .await
1211 })
1212 }
1213}
1214
/// Wrapper struct that implements the AgentConnection trait
///
/// Wraps the shared [`NativeAgent`] entity so the in-process agent can be used
/// as an `acp_thread::AgentConnection`. `Clone` is derived, so every clone
/// refers to the same `Entity<NativeAgent>`.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1218
1219impl NativeAgentConnection {
1220 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1221 self.0
1222 .read(cx)
1223 .sessions
1224 .get(session_id)
1225 .map(|session| session.thread.clone())
1226 }
1227
1228 pub fn load_thread(
1229 &self,
1230 id: acp::SessionId,
1231 project: Entity<Project>,
1232 cx: &mut App,
1233 ) -> Task<Result<Entity<Thread>>> {
1234 self.0
1235 .update(cx, |this, cx| this.load_thread(id, project, cx))
1236 }
1237
1238 fn run_turn(
1239 &self,
1240 session_id: acp::SessionId,
1241 cx: &mut App,
1242 f: impl 'static
1243 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1244 ) -> Task<Result<acp::PromptResponse>> {
1245 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1246 agent
1247 .sessions
1248 .get_mut(&session_id)
1249 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1250 }) else {
1251 log::error!("Session not found in run_turn: {}", session_id);
1252 return Task::ready(Err(anyhow!("Session not found")));
1253 };
1254 log::debug!("Found session for: {}", session_id);
1255
1256 let response_stream = match f(thread, cx) {
1257 Ok(stream) => stream,
1258 Err(err) => return Task::ready(Err(err)),
1259 };
1260 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1261 }
1262
1263 fn handle_thread_events(
1264 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1265 acp_thread: WeakEntity<AcpThread>,
1266 cx: &App,
1267 ) -> Task<Result<acp::PromptResponse>> {
1268 cx.spawn(async move |cx| {
1269 // Handle response stream and forward to session.acp_thread
1270 while let Some(result) = events.next().await {
1271 match result {
1272 Ok(event) => {
1273 log::trace!("Received completion event: {:?}", event);
1274
1275 match event {
1276 ThreadEvent::UserMessage(message) => {
1277 acp_thread.update(cx, |thread, cx| {
1278 for content in message.content {
1279 thread.push_user_content_block(
1280 Some(message.id.clone()),
1281 content.into(),
1282 cx,
1283 );
1284 }
1285 })?;
1286 }
1287 ThreadEvent::AgentText(text) => {
1288 acp_thread.update(cx, |thread, cx| {
1289 thread.push_assistant_content_block(text.into(), false, cx)
1290 })?;
1291 }
1292 ThreadEvent::AgentThinking(text) => {
1293 acp_thread.update(cx, |thread, cx| {
1294 thread.push_assistant_content_block(text.into(), true, cx)
1295 })?;
1296 }
1297 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1298 tool_call,
1299 options,
1300 response,
1301 context: _,
1302 }) => {
1303 let outcome_task = acp_thread.update(cx, |thread, cx| {
1304 thread.request_tool_call_authorization(tool_call, options, cx)
1305 })??;
1306 cx.background_spawn(async move {
1307 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1308 outcome_task.await
1309 {
1310 response
1311 .send(outcome)
1312 .map_err(|_| {
1313 anyhow!("authorization receiver was dropped")
1314 })
1315 .log_err();
1316 }
1317 })
1318 .detach();
1319 }
1320 ThreadEvent::ToolCall(tool_call) => {
1321 acp_thread.update(cx, |thread, cx| {
1322 thread.upsert_tool_call(tool_call, cx)
1323 })??;
1324 }
1325 ThreadEvent::ToolCallUpdate(update) => {
1326 acp_thread.update(cx, |thread, cx| {
1327 thread.update_tool_call(update, cx)
1328 })??;
1329 }
1330 ThreadEvent::Plan(plan) => {
1331 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1332 }
1333 ThreadEvent::SubagentSpawned(session_id) => {
1334 acp_thread.update(cx, |thread, cx| {
1335 thread.subagent_spawned(session_id, cx);
1336 })?;
1337 }
1338 ThreadEvent::Retry(status) => {
1339 acp_thread.update(cx, |thread, cx| {
1340 thread.update_retry_status(status, cx)
1341 })?;
1342 }
1343 ThreadEvent::Stop(stop_reason) => {
1344 log::debug!("Assistant message complete: {:?}", stop_reason);
1345 return Ok(acp::PromptResponse::new(stop_reason));
1346 }
1347 }
1348 }
1349 Err(e) => {
1350 log::error!("Error in model response stream: {:?}", e);
1351 return Err(e);
1352 }
1353 }
1354 }
1355
1356 log::debug!("Response stream completed");
1357 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1358 })
1359 }
1360}
1361
/// A slash command parsed from the first block of a prompt: either
/// `/prompt_name args` or `/server_id.prompt_name args`.
struct Command<'a> {
    /// Name of the prompt (the part after the `.` when a server is given).
    prompt_name: &'a str,
    /// Argument text that follows the command name; empty when absent.
    arg_value: &'a str,
    /// Server id from a `/server_id.prompt_name` command, if one was given.
    explicit_server_id: Option<&'a str>,
}
1367
1368impl<'a> Command<'a> {
1369 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1370 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1371 return None;
1372 };
1373 let text = text_content.text.trim();
1374 let command = text.strip_prefix('/')?;
1375 let (command, arg_value) = command
1376 .split_once(char::is_whitespace)
1377 .unwrap_or((command, ""));
1378
1379 if let Some((server_id, prompt_name)) = command.split_once('.') {
1380 Some(Self {
1381 prompt_name,
1382 arg_value,
1383 explicit_server_id: Some(server_id),
1384 })
1385 } else {
1386 Some(Self {
1387 prompt_name: command,
1388 arg_value,
1389 explicit_server_id: None,
1390 })
1391 }
1392 }
1393}
1394
/// Model selector for a single session of the native agent.
struct NativeAgentModelSelector {
    /// Session whose thread model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, models and fs.
    connection: NativeAgentConnection,
}
1399
1400impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1401 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1402 log::debug!("NativeAgentConnection::list_models called");
1403 let list = self.connection.0.read(cx).models.model_list.clone();
1404 Task::ready(if list.is_empty() {
1405 Err(anyhow::anyhow!("No models available"))
1406 } else {
1407 Ok(list)
1408 })
1409 }
1410
1411 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1412 log::debug!(
1413 "Setting model for session {}: {}",
1414 self.session_id,
1415 model_id
1416 );
1417 let Some(thread) = self
1418 .connection
1419 .0
1420 .read(cx)
1421 .sessions
1422 .get(&self.session_id)
1423 .map(|session| session.thread.clone())
1424 else {
1425 return Task::ready(Err(anyhow!("Session not found")));
1426 };
1427
1428 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1429 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1430 };
1431
1432 let favorite = agent_settings::AgentSettings::get_global(cx)
1433 .favorite_models
1434 .iter()
1435 .find(|favorite| {
1436 favorite.provider.0 == model.provider_id().0.as_ref()
1437 && favorite.model == model.id().0.as_ref()
1438 })
1439 .cloned();
1440
1441 let LanguageModelSelection {
1442 enable_thinking,
1443 effort,
1444 speed,
1445 ..
1446 } = agent_settings::language_model_to_selection(&model, favorite.as_ref());
1447
1448 thread.update(cx, |thread, cx| {
1449 thread.set_model(model.clone(), cx);
1450 thread.set_thinking_effort(effort.clone(), cx);
1451 thread.set_thinking_enabled(enable_thinking, cx);
1452 if let Some(speed) = speed {
1453 thread.set_speed(speed, cx);
1454 }
1455 });
1456
1457 update_settings_file(
1458 self.connection.0.read(cx).fs.clone(),
1459 cx,
1460 move |settings, cx| {
1461 let provider = model.provider_id().0.to_string();
1462 let model = model.id().0.to_string();
1463 let enable_thinking = thread.read(cx).thinking_enabled();
1464 let speed = thread.read(cx).speed();
1465 settings
1466 .agent
1467 .get_or_insert_default()
1468 .set_model(LanguageModelSelection {
1469 provider: provider.into(),
1470 model,
1471 enable_thinking,
1472 effort,
1473 speed,
1474 });
1475 },
1476 );
1477
1478 Task::ready(Ok(()))
1479 }
1480
1481 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1482 let Some(thread) = self
1483 .connection
1484 .0
1485 .read(cx)
1486 .sessions
1487 .get(&self.session_id)
1488 .map(|session| session.thread.clone())
1489 else {
1490 return Task::ready(Err(anyhow!("Session not found")));
1491 };
1492 let Some(model) = thread.read(cx).model() else {
1493 return Task::ready(Err(anyhow!("Model not found")));
1494 };
1495 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1496 else {
1497 return Task::ready(Err(anyhow!("Provider not found")));
1498 };
1499 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1500 model, &provider,
1501 )))
1502 }
1503
1504 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1505 Some(self.connection.0.read(cx).models.watch())
1506 }
1507
1508 fn should_render_footer(&self) -> bool {
1509 true
1510 }
1511}
1512
/// Agent id reported by [`NativeAgentConnection::agent_id`] for the built-in agent.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1514
1515impl acp_thread::AgentConnection for NativeAgentConnection {
1516 fn agent_id(&self) -> AgentId {
1517 ZED_AGENT_ID.clone()
1518 }
1519
1520 fn telemetry_id(&self) -> SharedString {
1521 "zed".into()
1522 }
1523
1524 fn new_session(
1525 self: Rc<Self>,
1526 project: Entity<Project>,
1527 work_dirs: PathList,
1528 cx: &mut App,
1529 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1530 log::debug!("Creating new thread for project at: {work_dirs:?}");
1531 Task::ready(Ok(self
1532 .0
1533 .update(cx, |agent, cx| agent.new_session(project, cx))))
1534 }
1535
1536 fn supports_load_session(&self) -> bool {
1537 true
1538 }
1539
1540 fn load_session(
1541 self: Rc<Self>,
1542 session_id: acp::SessionId,
1543 project: Entity<Project>,
1544 _work_dirs: PathList,
1545 _title: Option<SharedString>,
1546 cx: &mut App,
1547 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1548 self.0
1549 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1550 }
1551
1552 fn supports_close_session(&self) -> bool {
1553 true
1554 }
1555
1556 fn close_session(
1557 self: Rc<Self>,
1558 session_id: &acp::SessionId,
1559 cx: &mut App,
1560 ) -> Task<Result<()>> {
1561 self.0
1562 .update(cx, |agent, cx| agent.close_session(session_id, cx))
1563 }
1564
1565 fn auth_methods(&self) -> &[acp::AuthMethod] {
1566 &[] // No auth for in-process
1567 }
1568
1569 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1570 Task::ready(Ok(()))
1571 }
1572
1573 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1574 Some(Rc::new(NativeAgentModelSelector {
1575 session_id: session_id.clone(),
1576 connection: self.clone(),
1577 }) as Rc<dyn AgentModelSelector>)
1578 }
1579
1580 fn prompt(
1581 &self,
1582 id: acp_thread::UserMessageId,
1583 params: acp::PromptRequest,
1584 cx: &mut App,
1585 ) -> Task<Result<acp::PromptResponse>> {
1586 let session_id = params.session_id.clone();
1587 log::info!("Received prompt request for session: {}", session_id);
1588 log::debug!("Prompt blocks count: {}", params.prompt.len());
1589
1590 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1591 log::error!("Session not found in prompt: {}", session_id);
1592 if self.0.read(cx).sessions.contains_key(&session_id) {
1593 log::error!(
1594 "Session found in sessions map, but not in project state: {}",
1595 session_id
1596 );
1597 }
1598 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1599 };
1600
1601 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1602 let registry = project_state.context_server_registry.read(cx);
1603
1604 let explicit_server_id = parsed_command
1605 .explicit_server_id
1606 .map(|server_id| ContextServerId(server_id.into()));
1607
1608 if let Some(prompt) =
1609 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1610 {
1611 let arguments = if !parsed_command.arg_value.is_empty()
1612 && let Some(arg_name) = prompt
1613 .prompt
1614 .arguments
1615 .as_ref()
1616 .and_then(|args| args.first())
1617 .map(|arg| arg.name.clone())
1618 {
1619 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1620 } else {
1621 Default::default()
1622 };
1623
1624 let prompt_name = prompt.prompt.name.clone();
1625 let server_id = prompt.server_id.clone();
1626
1627 return self.0.update(cx, |agent, cx| {
1628 agent.send_mcp_prompt(
1629 id,
1630 session_id.clone(),
1631 prompt_name,
1632 server_id,
1633 arguments,
1634 params.prompt,
1635 cx,
1636 )
1637 });
1638 }
1639 };
1640
1641 let path_style = project_state.project.read(cx).path_style(cx);
1642
1643 self.run_turn(session_id, cx, move |thread, cx| {
1644 let content: Vec<UserMessageContent> = params
1645 .prompt
1646 .into_iter()
1647 .map(|block| UserMessageContent::from_content_block(block, path_style))
1648 .collect::<Vec<_>>();
1649 log::debug!("Converted prompt to message: {} chars", content.len());
1650 log::debug!("Message id: {:?}", id);
1651 log::debug!("Message content: {:?}", content);
1652
1653 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1654 })
1655 }
1656
1657 fn retry(
1658 &self,
1659 session_id: &acp::SessionId,
1660 _cx: &App,
1661 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1662 Some(Rc::new(NativeAgentSessionRetry {
1663 connection: self.clone(),
1664 session_id: session_id.clone(),
1665 }) as _)
1666 }
1667
1668 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1669 log::info!("Cancelling on session: {}", session_id);
1670 self.0.update(cx, |agent, cx| {
1671 if let Some(session) = agent.sessions.get(session_id) {
1672 session
1673 .thread
1674 .update(cx, |thread, cx| thread.cancel(cx))
1675 .detach();
1676 }
1677 });
1678 }
1679
1680 fn truncate(
1681 &self,
1682 session_id: &acp::SessionId,
1683 cx: &App,
1684 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1685 self.0.read_with(cx, |agent, _cx| {
1686 agent.sessions.get(session_id).map(|session| {
1687 Rc::new(NativeAgentSessionTruncate {
1688 thread: session.thread.clone(),
1689 acp_thread: session.acp_thread.downgrade(),
1690 }) as _
1691 })
1692 })
1693 }
1694
1695 fn set_title(
1696 &self,
1697 session_id: &acp::SessionId,
1698 cx: &App,
1699 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1700 self.0.read_with(cx, |agent, _cx| {
1701 agent
1702 .sessions
1703 .get(session_id)
1704 .filter(|s| !s.thread.read(cx).is_subagent())
1705 .map(|session| {
1706 Rc::new(NativeAgentSessionSetTitle {
1707 thread: session.thread.clone(),
1708 }) as _
1709 })
1710 })
1711 }
1712
1713 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1714 let thread_store = self.0.read(cx).thread_store.clone();
1715 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1716 }
1717
1718 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1719 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1720 }
1721
1722 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1723 self
1724 }
1725}
1726
1727impl acp_thread::AgentTelemetry for NativeAgentConnection {
1728 fn thread_data(
1729 &self,
1730 session_id: &acp::SessionId,
1731 cx: &mut App,
1732 ) -> Task<Result<serde_json::Value>> {
1733 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1734 return Task::ready(Err(anyhow!("Session not found")));
1735 };
1736
1737 let task = session.thread.read(cx).to_db(cx);
1738 cx.background_spawn(async move {
1739 serde_json::to_value(task.await).context("Failed to serialize thread")
1740 })
1741 }
1742}
1743
/// Session list backed by the agent's [`ThreadStore`].
pub struct NativeAgentSessionList {
    /// Store whose entries are listed and deleted.
    thread_store: Entity<ThreadStore>,
    /// Sender used to push `Refresh` notifications to watchers.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiver handed out (cloned) by `watch`.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Observation of `thread_store`; sends `Refresh` on each notification.
    _subscription: Subscription,
}
1750
1751impl NativeAgentSessionList {
1752 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1753 let (tx, rx) = smol::channel::unbounded();
1754 let this_tx = tx.clone();
1755 let subscription = cx.observe(&thread_store, move |_, _| {
1756 this_tx
1757 .try_send(acp_thread::SessionListUpdate::Refresh)
1758 .ok();
1759 });
1760 Self {
1761 thread_store,
1762 updates_tx: tx,
1763 updates_rx: rx,
1764 _subscription: subscription,
1765 }
1766 }
1767
1768 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1769 &self.thread_store
1770 }
1771}
1772
1773impl AgentSessionList for NativeAgentSessionList {
1774 fn list_sessions(
1775 &self,
1776 _request: AgentSessionListRequest,
1777 cx: &mut App,
1778 ) -> Task<Result<AgentSessionListResponse>> {
1779 let sessions = self
1780 .thread_store
1781 .read(cx)
1782 .entries()
1783 .map(|entry| AgentSessionInfo::from(&entry))
1784 .collect();
1785 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1786 }
1787
1788 fn supports_delete(&self) -> bool {
1789 true
1790 }
1791
1792 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1793 self.thread_store
1794 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1795 }
1796
1797 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1798 self.thread_store
1799 .update(cx, |store, cx| store.delete_threads(cx))
1800 }
1801
1802 fn watch(
1803 &self,
1804 _cx: &mut App,
1805 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1806 Some(self.updates_rx.clone())
1807 }
1808
1809 fn notify_refresh(&self) {
1810 self.updates_tx
1811 .try_send(acp_thread::SessionListUpdate::Refresh)
1812 .ok();
1813 }
1814
1815 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1816 self
1817 }
1818}
1819
/// Truncation handle for one session.
struct NativeAgentSessionTruncate {
    /// Native thread that gets truncated.
    thread: Entity<Thread>,
    /// ACP thread whose token usage is refreshed afterwards (weak; may be gone).
    acp_thread: WeakEntity<AcpThread>,
}
1824
1825impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1826 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1827 match self.thread.update(cx, |thread, cx| {
1828 thread.truncate(message_id.clone(), cx)?;
1829 Ok(thread.latest_token_usage())
1830 }) {
1831 Ok(usage) => {
1832 self.acp_thread
1833 .update(cx, |thread, cx| {
1834 thread.update_token_usage(usage, cx);
1835 })
1836 .ok();
1837 Task::ready(Ok(()))
1838 }
1839 Err(error) => Task::ready(Err(error)),
1840 }
1841 }
1842}
1843
/// Retry handle for one session.
struct NativeAgentSessionRetry {
    /// Connection used to run the retried turn.
    connection: NativeAgentConnection,
    /// Session to retry.
    session_id: acp::SessionId,
}
1848
1849impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1850 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1851 self.connection
1852 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1853 thread.update(cx, |thread, cx| thread.resume(cx))
1854 })
1855 }
1856}
1857
/// Title-editing handle for one (non-subagent) session.
struct NativeAgentSessionSetTitle {
    /// Thread whose title is set.
    thread: Entity<Thread>,
}
1861
1862impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1863 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1864 self.thread
1865 .update(cx, |thread, cx| thread.set_title(title, cx));
1866 Task::ready(Ok(()))
1867 }
1868}
1869
/// Environment offered to a native thread's tools: terminals and subagents.
/// All references are weak.
pub struct NativeThreadEnvironment {
    /// Agent that owns the sessions (used to register and look up subagents).
    agent: WeakEntity<NativeAgent>,
    /// Thread this environment belongs to (the parent of any subagent).
    thread: WeakEntity<Thread>,
    /// ACP thread used to create and release terminals.
    acp_thread: WeakEntity<AcpThread>,
}
1875
1876impl NativeThreadEnvironment {
1877 pub(crate) fn create_subagent_thread(
1878 &self,
1879 label: String,
1880 cx: &mut App,
1881 ) -> Result<Rc<dyn SubagentHandle>> {
1882 let Some(parent_thread_entity) = self.thread.upgrade() else {
1883 anyhow::bail!("Parent thread no longer exists".to_string());
1884 };
1885 let parent_thread = parent_thread_entity.read(cx);
1886 let current_depth = parent_thread.depth();
1887 let parent_session_id = parent_thread.id().clone();
1888
1889 if current_depth >= MAX_SUBAGENT_DEPTH {
1890 return Err(anyhow!(
1891 "Maximum subagent depth ({}) reached",
1892 MAX_SUBAGENT_DEPTH
1893 ));
1894 }
1895
1896 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1897 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1898 thread.set_title(label.into(), cx);
1899 thread
1900 });
1901
1902 let session_id = subagent_thread.read(cx).id().clone();
1903
1904 let acp_thread = self
1905 .agent
1906 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1907 let project_id = agent
1908 .sessions
1909 .get(&parent_session_id)
1910 .map(|s| s.project_id)
1911 .context("parent session not found")?;
1912 Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
1913 })??;
1914
1915 let depth = current_depth + 1;
1916
1917 telemetry::event!(
1918 "Subagent Started",
1919 session = parent_thread_entity.read(cx).id().to_string(),
1920 subagent_session = session_id.to_string(),
1921 depth,
1922 is_resumed = false,
1923 );
1924
1925 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1926 }
1927
1928 pub(crate) fn resume_subagent_thread(
1929 &self,
1930 session_id: acp::SessionId,
1931 cx: &mut App,
1932 ) -> Result<Rc<dyn SubagentHandle>> {
1933 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1934 let session = agent
1935 .sessions
1936 .get(&session_id)
1937 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1938 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1939 })??;
1940
1941 let depth = subagent_thread.read(cx).depth();
1942
1943 if let Some(parent_thread_entity) = self.thread.upgrade() {
1944 telemetry::event!(
1945 "Subagent Started",
1946 session = parent_thread_entity.read(cx).id().to_string(),
1947 subagent_session = session_id.to_string(),
1948 depth,
1949 is_resumed = true,
1950 );
1951 }
1952
1953 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1954 }
1955
1956 fn prompt_subagent(
1957 &self,
1958 session_id: acp::SessionId,
1959 subagent_thread: Entity<Thread>,
1960 acp_thread: Entity<acp_thread::AcpThread>,
1961 ) -> Result<Rc<dyn SubagentHandle>> {
1962 let Some(parent_thread_entity) = self.thread.upgrade() else {
1963 anyhow::bail!("Parent thread no longer exists".to_string());
1964 };
1965 Ok(Rc::new(NativeSubagentHandle::new(
1966 session_id,
1967 subagent_thread,
1968 acp_thread,
1969 parent_thread_entity,
1970 )) as _)
1971 }
1972}
1973
1974impl ThreadEnvironment for NativeThreadEnvironment {
1975 fn create_terminal(
1976 &self,
1977 command: String,
1978 cwd: Option<PathBuf>,
1979 output_byte_limit: Option<u64>,
1980 cx: &mut AsyncApp,
1981 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1982 let task = self.acp_thread.update(cx, |thread, cx| {
1983 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1984 });
1985
1986 let acp_thread = self.acp_thread.clone();
1987 cx.spawn(async move |cx| {
1988 let terminal = task?.await?;
1989
1990 let (drop_tx, drop_rx) = oneshot::channel();
1991 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1992
1993 cx.spawn(async move |cx| {
1994 drop_rx.await.ok();
1995 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1996 })
1997 .detach();
1998
1999 let handle = AcpTerminalHandle {
2000 terminal,
2001 _drop_tx: Some(drop_tx),
2002 };
2003
2004 Ok(Rc::new(handle) as _)
2005 })
2006 }
2007
2008 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
2009 self.create_subagent_thread(label, cx)
2010 }
2011
2012 fn resume_subagent(
2013 &self,
2014 session_id: acp::SessionId,
2015 cx: &mut App,
2016 ) -> Result<Rc<dyn SubagentHandle>> {
2017 self.resume_subagent_thread(session_id, cx)
2018 }
2019}
2020
/// Outcome of a single subagent prompt, as computed in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn finished normally (any stop reason not mapped elsewhere).
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage went from `Normal` to `Warning` while the prompt ran.
    ContextWindowWarning,
    /// Any failure, carrying a message for the caller.
    Error(String),
}
2028
/// Handle through which a parent thread prompts one of its subagents.
pub struct NativeSubagentHandle {
    /// Session id of the subagent.
    session_id: acp::SessionId,
    /// Parent thread (weak), used to (un)register the running subagent.
    parent_thread: WeakEntity<Thread>,
    /// The subagent's native thread.
    subagent_thread: Entity<Thread>,
    /// The subagent's ACP thread, used to send messages and count entries.
    acp_thread: Entity<acp_thread::AcpThread>,
}
2035
2036impl NativeSubagentHandle {
2037 fn new(
2038 session_id: acp::SessionId,
2039 subagent_thread: Entity<Thread>,
2040 acp_thread: Entity<acp_thread::AcpThread>,
2041 parent_thread_entity: Entity<Thread>,
2042 ) -> Self {
2043 NativeSubagentHandle {
2044 session_id,
2045 subagent_thread,
2046 parent_thread: parent_thread_entity.downgrade(),
2047 acp_thread,
2048 }
2049 }
2050}
2051
impl SubagentHandle for NativeSubagentHandle {
    /// Session id of the subagent.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Number of entries currently in the subagent's ACP thread.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and waits for the turn to end.
    ///
    /// The subagent is registered as running on the parent thread (if the
    /// parent still exists) before sending. It is unregistered after the
    /// result is computed, on both success and failure paths.
    ///
    /// On a normal completion the text blocks of the subagent's last agent
    /// message are joined with blank lines and returned.
    ///
    /// # Errors
    /// - The stop reason is `Cancelled`, `MaxTokens`, `MaxTurnRequests` or
    ///   `Refusal`.
    /// - The send yields no response or fails.
    /// - The final message has no text.
    /// - Token usage was `Normal` (or unknown) before the prompt and an update
    ///   reports `Warning`. In that case the subagent thread is cancelled first.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is kept alive until this async block finishes so
            // token-usage updates keep flowing while the prompt runs.
            let (task, _subscription) = cx.update(|cx| {
                // Baseline ratio used to detect a Normal -> Warning transition.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // Fires at most once: the sender is taken on first use.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the turn against the context-window warning; whichever
                // resolves first decides the result.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Collect only the text blocks of the final agent message.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent before reporting the warning.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
2172
/// Tool-facing handle to a terminal owned by an ACP thread.
pub struct AcpTerminalHandle {
    /// The underlying ACP terminal entity.
    terminal: Entity<acp_thread::Terminal>,
    /// Dropping this sender signals the task spawned in `create_terminal`,
    /// which then releases the terminal from the ACP thread.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2177
2178impl TerminalHandle for AcpTerminalHandle {
2179 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2180 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2181 }
2182
2183 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2184 Ok(self
2185 .terminal
2186 .read_with(cx, |term, _cx| term.wait_for_exit()))
2187 }
2188
2189 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2190 Ok(self
2191 .terminal
2192 .read_with(cx, |term, cx| term.current_output(cx)))
2193 }
2194
2195 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2196 cx.update(|cx| {
2197 self.terminal.update(cx, |terminal, cx| {
2198 terminal.kill(cx);
2199 });
2200 });
2201 Ok(())
2202 }
2203
2204 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2205 Ok(self
2206 .terminal
2207 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2208 }
2209}
2210
2211#[cfg(test)]
2212mod internal_tests {
2213 use std::path::Path;
2214
2215 use super::*;
2216 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2217 use fs::FakeFs;
2218 use gpui::TestAppContext;
2219 use indoc::formatdoc;
2220 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2221 use language_model::{
2222 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2223 };
2224 use serde_json::json;
2225 use settings::SettingsStore;
2226 use util::{path, rel_path::rel_path};
2227
    /// Verifies that the agent's per-project context, and the session thread's
    /// view of it, follow worktree changes. The context starts with no
    /// worktrees, gains a `WorktreeContext` for `/a` once that worktree is
    /// created, and picks up `/a/.rules` after the file is written.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project is created with no worktrees; `/a` is added later.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Grab the (only) session's native thread so its project context can be
        // compared with the agent-level project state.
        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // No worktrees yet: both the agent's and the thread's context are empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding `/a` as a worktree surfaces it in both views, with no rules file yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            // The expected entry id comes from the worktree's own entry for `.rules`.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2315
    /// Verifies that a session's model selector lists models grouped by
    /// provider. In the test setup the only entry is the fake model
    /// (`fake/fake`) under the "Fake" group.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent reports models grouped by provider; a flat list is a failure.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2372
    /// Verifies that selecting a model through the session's model selector
    /// updates the thread and is written back to the user settings file:
    /// - an initial `agent.default_model` of `foo`/`bar` is overwritten with
    ///   `fake`/`fake`;
    /// - selecting a thinking-capable model also persists `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings file with a different default model so the
        // overwrite below is observable.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings write triggered by the selection finish before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the newly registered provider can be selected.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2484
    /// Verifies that `select_model` keeps the thread's `thinking_enabled` flag in
    /// sync with the chosen model. The flag starts false, becomes true when a
    /// thinking-capable model is selected, and goes back to false when the
    /// non-thinking `fake/fake` model is selected again.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2576
    /// Regression test: clearing the registry's default model, which is what
    /// happens when the configured model briefly cannot be resolved, must not
    /// remove the summarization model from an existing session's thread.
    #[gpui::test]
    async fn test_summarization_model_survives_transient_registry_clearing(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Precondition: the freshly created session already has a summarization model.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.summarization_model().is_some(),
                "session should have a summarization model from the test registry"
            );
        });

        // Simulate what happens during a provider blip:
        // update_active_language_model_from_settings calls set_default_model(None)
        // when it can't resolve the model, clearing all fallbacks.
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(None, cx);
            });
        });
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.summarization_model().is_some(),
                "summarization model should survive a transient default model clearing"
            );
        });
    }
2631
    /// Verifies that `thinking_enabled` survives a save/reload round trip.
    /// The test selects a thinking-capable model, completes one turn so the
    /// thread is persisted, closes the session, reopens it with `open_thread`,
    /// and checks that thinking is still enabled.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The test keeps a clone of the model so it can drive the completion stream below.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the pending model response so the send future can resolve.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Drop the test's own handles as well before checking the agent unloaded the session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2734
    /// Verifies that a reloaded thread restores the model it was saved with,
    /// matched by model id (which here differs from the display name), and
    /// does not fall back to the registry default.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the pending model response so the send future can resolve.
        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Drop the test's own handles as well before checking the agent unloaded the session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2842
    /// End-to-end persistence test for a native thread. It checks that:
    /// - an empty thread is not saved, even after its models are changed;
    /// - after one turn, the thread is saved under the title the summary model
    ///   generated;
    /// - on reload, the conversation markdown, the draft prompt (including a
    ///   resource link), token usage, and UI scroll position are all restored.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a prompt mixing plain text with a file mention of `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the main model's response, including a usage update that is
        // checked after reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // Drive the summarization model; its output becomes the saved thread's title.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
3024
    /// Verifies that `close_session` itself saves the thread. A draft prompt set
    /// immediately before closing, with no executor run in between, must still
    /// be present after the thread is reopened.
    #[gpui::test]
    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        let model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Send a message so the thread is non-empty (empty threads aren't saved).
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
        // This means no observe-triggered save has run for this change.
        // The only way this data gets persisted is if close_session
        // itself performs the save.
        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
            "unsaved draft",
        ))];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });

        // Close the session immediately — no run_until_parked in between.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Reopen and verify the draft prompt was saved.
        let reloaded = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        reloaded.read_with(cx, |thread, _| {
            assert_eq!(
                thread.draft_prompt(),
                Some(draft_blocks.as_slice()),
                "close_session must save the thread; draft prompt was lost"
            );
        });
    }
3105
    /// Verifies that `NativeAgent::thread_summary` on an already-open session:
    /// - returns the summary model's output;
    /// - leaves the session loaded with `ref_count` back at 1, meaning it
    ///   released any temporary reference it took;
    /// - does not stop a single later `close_session` from unloading the session.
    #[gpui::test]
    async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });

        // Complete one turn so the thread has content to summarize.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Request a summary while the session is still open.
        let summary = agent.update(cx, |agent, cx| {
            agent.thread_summary(session_id.clone(), project.clone(), cx)
        });
        cx.run_until_parked();

        summary_model.send_last_completion_stream_text_chunk("summary");
        summary_model.end_last_completion_stream();

        assert_eq!(summary.await.unwrap(), "summary");
        cx.run_until_parked();

        agent.read_with(cx, |agent, _| {
            let session = agent
                .sessions
                .get(&session_id)
                .expect("thread_summary should not close the active session");
            assert_eq!(
                session.ref_count, 1,
                "thread_summary should release its temporary session reference"
            );
        });

        // One close should now be enough to unload the session.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.is_empty(),
                "closing the active session after thread_summary should unload it"
            );
        });
    }
3189
    // Checks that reloaded native sessions share state until the last close:
    // - after the original session is closed, two overlapping `load_session` calls for the same
    //   id must resolve to the same `AcpThread`;
    // - the session must stay in `agent.sessions` after only one of those loads is closed;
    // - the surviving handle must still run a new turn on top of the restored history;
    // - the second close (plus dropping the handles) must leave `agent.sessions` empty.
    #[gpui::test]
    async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Create a fresh session and grab the native `Thread` the agent stores for it.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Use the registry's default test model (accessed through `as_fake`) so completions can
        // be scripted by hand below.
        let model = cx.update(|cx| {
            LanguageModelRegistry::read_global(cx)
                .default_model()
                .map(|default_model| default_model.model)
                .expect("default test model should be available")
        });
        let fake_model = model.as_fake();
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // First turn ("hello" -> "world"). The send future is spawned and the executor is
        // parked so a pending completion exists before the fake model streams its reply.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("world");
        fake_model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Close the original session and release the local entity handles; after that the
        // agent must not be tracking any session.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Start both loads before awaiting either one, so the two loads overlap.
        let first_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });
        let second_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });

        let first_loaded_thread = first_loaded_thread.await.unwrap();
        let second_loaded_thread = second_loaded_thread.await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            first_loaded_thread.entity_id(),
            second_loaded_thread.entity_id(),
            "concurrent loads for the same session should share one AcpThread"
        );

        // One close for two loads: the session state must survive.
        // NOTE(review): presumably this relies on the per-session reference count the agent
        // keeps for each load; confirm against `NativeAgent::close_session`.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.contains_key(&session_id),
                "closing one loaded session should not drop shared session state"
            );
        });

        // The remaining handle must still be able to run a turn; the reply is scripted the
        // same way as the first turn.
        let follow_up = second_loaded_thread.update(cx, |thread, cx| {
            thread.send(vec!["still there?".into()], cx)
        });
        let follow_up = cx.foreground_executor().spawn(follow_up);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("yes");
        fake_model.end_last_completion_stream();
        follow_up.await.unwrap();
        cx.run_until_parked();

        // The transcript must contain the restored first turn followed by the new one.
        second_loaded_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    hello

                    ## Assistant

                    world

                    ## User

                    still there?

                    ## Assistant

                    yes

                    "}
            );
        });

        // Second (last) close: once the handles are dropped the agent must hold no sessions.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        cx.run_until_parked();

        drop(first_loaded_thread);
        drop(second_loaded_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });
    }
3340
3341 #[gpui::test]
3342 async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
3343 // Regression test: rapid title changes must not cause a propagation loop
3344 // between Thread and AcpThread via handle_thread_title_updated.
3345 init_test(cx);
3346 let fs = FakeFs::new(cx.executor());
3347 fs.insert_tree("/", json!({ "a": {} })).await;
3348 let project = Project::test(fs.clone(), [], cx).await;
3349 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3350 let agent = cx.update(|cx| {
3351 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3352 });
3353 let connection = Rc::new(NativeAgentConnection(agent.clone()));
3354
3355 let acp_thread = cx
3356 .update(|cx| {
3357 connection
3358 .clone()
3359 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3360 })
3361 .await
3362 .unwrap();
3363
3364 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3365 let thread = agent.read_with(cx, |agent, _| {
3366 agent.sessions.get(&session_id).unwrap().thread.clone()
3367 });
3368
3369 let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
3370 cx.update(|cx| {
3371 let count = title_updated_count.clone();
3372 cx.subscribe(
3373 &thread,
3374 move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
3375 let new_count = {
3376 let mut count = count.borrow_mut();
3377 *count += 1;
3378 *count
3379 };
3380 assert!(
3381 new_count <= 2,
3382 "TitleUpdated fired {new_count} times; \
3383 title updates are looping"
3384 );
3385 },
3386 )
3387 .detach();
3388 });
3389
3390 thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
3391 thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));
3392
3393 cx.run_until_parked();
3394
3395 thread.read_with(cx, |thread, _| {
3396 assert_eq!(thread.title(), Some("second".into()));
3397 });
3398 acp_thread.read_with(cx, |acp_thread, _| {
3399 assert_eq!(acp_thread.title(), Some("second".into()));
3400 });
3401
3402 assert_eq!(*title_updated_count.borrow(), 2);
3403 }
3404
3405 fn thread_entries(
3406 thread_store: &Entity<ThreadStore>,
3407 cx: &mut TestAppContext,
3408 ) -> Vec<(acp::SessionId, String)> {
3409 thread_store.read_with(cx, |store, _| {
3410 store
3411 .entries()
3412 .map(|entry| (entry.id.clone(), entry.title.to_string()))
3413 .collect::<Vec<_>>()
3414 })
3415 }
3416
3417 fn init_test(cx: &mut TestAppContext) {
3418 env_logger::try_init().ok();
3419 cx.update(|cx| {
3420 let settings_store = SettingsStore::test(cx);
3421 cx.set_global(settings_store);
3422
3423 LanguageModelRegistry::test(cx);
3424 });
3425 }
3426}
3427
3428fn mcp_message_content_to_acp_content_block(
3429 content: context_server::types::MessageContent,
3430) -> acp::ContentBlock {
3431 match content {
3432 context_server::types::MessageContent::Text {
3433 text,
3434 annotations: _,
3435 } => text.into(),
3436 context_server::types::MessageContent::Image {
3437 data,
3438 mime_type,
3439 annotations: _,
3440 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3441 context_server::types::MessageContent::Audio {
3442 data,
3443 mime_type,
3444 annotations: _,
3445 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3446 context_server::types::MessageContent::Resource {
3447 resource,
3448 annotations: _,
3449 } => {
3450 let mut link =
3451 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3452 if let Some(mime_type) = resource.mime_type {
3453 link = link.mime_type(mime_type);
3454 }
3455 acp::ContentBlock::ResourceLink(link)
3456 }
3457 }
3458}