1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, Settings as _, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// Serializable record pairing telemetry worktree snapshots with a timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp stored with the snapshots (presumably the capture time; confirm with the producer).
    pub timestamp: DateTime<Utc>,
}
64
/// Describes a failure to load a worktree rules file.
pub struct RulesLoadingError {
    /// Human-readable description of the underlying error.
    pub message: SharedString,
}
68
/// Per-project state owned by `NativeAgent`, keyed by the project's `EntityId`.
struct ProjectState {
    /// The project this state was created for.
    project: Entity<Project>,
    /// Context entity whose contents are replaced after every rebuild performed by
    /// `NativeAgent::maintain_project_context`.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel requests a rebuild of `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `NativeAgent::maintain_project_context` for this project.
    _maintain_project_context: Task<Result<()>>,
    /// Registry created from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to the project, its context server store, and the registry.
    _subscriptions: Vec<Subscription>,
}
77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key of the owning entry in `NativeAgent::projects`.
    project_id: EntityId,
    /// The most recent save task started by `NativeAgent::save_thread`.
    pending_save: Task<Result<()>>,
    /// Subscriptions to the thread's title and token-usage events, plus the
    /// observer that saves the thread on change.
    _subscriptions: Vec<Subscription>,
    /// Number of outstanding opens of this session; `NativeAgent::close_session`
    /// removes the session once this reaches zero.
    ref_count: usize,
}
89
/// A session whose thread is still being loaded by `NativeAgent::open_thread`.
struct PendingSession {
    /// Shared load task; every concurrent `open_thread` call for the same id awaits a clone of it.
    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
    /// Number of `open_thread` calls waiting on `task`; handed to the `Session`
    /// as its initial reference count once the session is registered.
    ref_count: usize,
}
94
/// Cache of the language models offered by authenticated, visible providers.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver side of the refresh channel; cloned by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender side of the refresh channel; signalled at the end of `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates all visible providers.
    _authenticate_all_providers_task: Task<()>,
}
104
105impl LanguageModels {
106 fn new(cx: &mut App) -> Self {
107 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
108
109 let mut this = Self {
110 models: HashMap::default(),
111 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
112 refresh_models_rx,
113 refresh_models_tx,
114 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
115 };
116 this.refresh_list(cx);
117 this
118 }
119
120 fn refresh_list(&mut self, cx: &App) {
121 let providers = LanguageModelRegistry::global(cx)
122 .read(cx)
123 .visible_providers()
124 .into_iter()
125 .filter(|provider| provider.is_authenticated(cx))
126 .collect::<Vec<_>>();
127
128 let mut language_model_list = IndexMap::default();
129 let mut recommended_models = HashSet::default();
130
131 let mut recommended = Vec::new();
132 for provider in &providers {
133 for model in provider.recommended_models(cx) {
134 recommended_models.insert((model.provider_id(), model.id()));
135 recommended.push(Self::map_language_model_to_info(&model, provider));
136 }
137 }
138 if !recommended.is_empty() {
139 language_model_list.insert(
140 acp_thread::AgentModelGroupName("Recommended".into()),
141 recommended,
142 );
143 }
144
145 let mut models = HashMap::default();
146 for provider in providers {
147 let mut provider_models = Vec::new();
148 for model in provider.provided_models(cx) {
149 let model_info = Self::map_language_model_to_info(&model, &provider);
150 let model_id = model_info.id.clone();
151 provider_models.push(model_info);
152 models.insert(model_id, model);
153 }
154 if !provider_models.is_empty() {
155 language_model_list.insert(
156 acp_thread::AgentModelGroupName(provider.name().0.clone()),
157 provider_models,
158 );
159 }
160 }
161
162 self.models = models;
163 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
164 self.refresh_models_tx.send(()).ok();
165 }
166
167 fn watch(&self) -> watch::Receiver<()> {
168 self.refresh_models_rx.clone()
169 }
170
171 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
172 self.models.get(model_id).cloned()
173 }
174
175 fn map_language_model_to_info(
176 model: &Arc<dyn LanguageModel>,
177 provider: &Arc<dyn LanguageModelProvider>,
178 ) -> acp_thread::AgentModelInfo {
179 acp_thread::AgentModelInfo {
180 id: Self::model_id(model),
181 name: model.name().0,
182 description: None,
183 icon: Some(match provider.icon() {
184 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
185 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
186 }),
187 is_latest: model.is_latest(),
188 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
189 }
190 }
191
192 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
193 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
194 }
195
196 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
197 let authenticate_all_providers = LanguageModelRegistry::global(cx)
198 .read(cx)
199 .visible_providers()
200 .iter()
201 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
202 .collect::<Vec<_>>();
203
204 cx.spawn(async move |cx| {
205 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
206 if let Err(err) = authenticate_task.await {
207 match err {
208 language_model::AuthenticateError::CredentialsNotFound => {
209 // Since we're authenticating these providers in the
210 // background for the purposes of populating the
211 // language selector, we don't care about providers
212 // where the credentials are not found.
213 }
214 language_model::AuthenticateError::ConnectionRefused => {
215 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
216 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
217 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
218 }
219 _ => {
220 // Some providers have noisy failure states that we
221 // don't want to spam the logs with every time the
222 // language model selector is initialized.
223 //
224 // Ideally these should have more clear failure modes
225 // that we know are safe to ignore here, like what we do
226 // with `CredentialsNotFound` above.
227 match provider_id.0.as_ref() {
228 "lmstudio" | "ollama" => {
229 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230 //
231 // These fail noisily, so we don't log them.
232 }
233 "copilot_chat" => {
234 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235 }
236 _ => {
237 log::error!(
238 "Failed to authenticate provider: {}: {err:#}",
239 provider_name.0
240 );
241 }
242 }
243 }
244 }
245 }
246 }
247
248 cx.update(language_models::update_environment_fallback_model);
249 })
250 }
251}
252
/// Owns all native agent sessions together with the per-project state,
/// templates, and model cache they share.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Sessions whose thread is still being loaded by `open_thread`.
    pending_sessions: HashMap<acp::SessionId, PendingSession>,
    /// Store that is reloaded after a thread has been saved.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store from which default user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language model registry and, if present, the prompt store.
    _subscriptions: Vec<Subscription>,
}
268
269impl NativeAgent {
270 pub fn new(
271 thread_store: Entity<ThreadStore>,
272 templates: Arc<Templates>,
273 prompt_store: Option<Entity<PromptStore>>,
274 fs: Arc<dyn Fs>,
275 cx: &mut App,
276 ) -> Entity<NativeAgent> {
277 log::debug!("Creating new NativeAgent");
278
279 cx.new(|cx| {
280 let mut subscriptions = vec![cx.subscribe(
281 &LanguageModelRegistry::global(cx),
282 Self::handle_models_updated_event,
283 )];
284 if let Some(prompt_store) = prompt_store.as_ref() {
285 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
286 }
287
288 Self {
289 sessions: HashMap::default(),
290 pending_sessions: HashMap::default(),
291 thread_store,
292 projects: HashMap::default(),
293 templates,
294 models: LanguageModels::new(cx),
295 prompt_store,
296 fs,
297 _subscriptions: subscriptions,
298 }
299 })
300 }
301
302 fn new_session(
303 &mut self,
304 project: Entity<Project>,
305 cx: &mut Context<Self>,
306 ) -> Entity<AcpThread> {
307 let project_id = self.get_or_create_project_state(&project, cx);
308 let project_state = &self.projects[&project_id];
309
310 let registry = LanguageModelRegistry::read_global(cx);
311 let available_count = registry.available_models(cx).count();
312 log::debug!("Total available models: {}", available_count);
313
314 let default_model = registry.default_model().and_then(|default_model| {
315 self.models
316 .model_from_id(&LanguageModels::model_id(&default_model.model))
317 });
318 let thread = cx.new(|cx| {
319 Thread::new(
320 project,
321 project_state.project_context.clone(),
322 project_state.context_server_registry.clone(),
323 self.templates.clone(),
324 default_model,
325 cx,
326 )
327 });
328
329 self.register_session(thread, project_id, 1, cx)
330 }
331
332 fn register_session(
333 &mut self,
334 thread_handle: Entity<Thread>,
335 project_id: EntityId,
336 ref_count: usize,
337 cx: &mut Context<Self>,
338 ) -> Entity<AcpThread> {
339 let connection = Rc::new(NativeAgentConnection(cx.entity()));
340
341 let thread = thread_handle.read(cx);
342 let session_id = thread.id().clone();
343 let parent_session_id = thread.parent_thread_id();
344 let title = thread.title();
345 let draft_prompt = thread.draft_prompt().map(Vec::from);
346 let scroll_position = thread.ui_scroll_position();
347 let token_usage = thread.latest_token_usage();
348 let project = thread.project.clone();
349 let action_log = thread.action_log.clone();
350 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
351 let acp_thread = cx.new(|cx| {
352 let mut acp_thread = acp_thread::AcpThread::new(
353 parent_session_id,
354 title,
355 None,
356 connection,
357 project.clone(),
358 action_log.clone(),
359 session_id.clone(),
360 prompt_capabilities_rx,
361 cx,
362 );
363 acp_thread.set_draft_prompt(draft_prompt, cx);
364 acp_thread.set_ui_scroll_position(scroll_position);
365 acp_thread.update_token_usage(token_usage, cx);
366 acp_thread
367 });
368
369 let registry = LanguageModelRegistry::read_global(cx);
370 let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
371
372 let weak = cx.weak_entity();
373 let weak_thread = thread_handle.downgrade();
374 thread_handle.update(cx, |thread, cx| {
375 thread.set_summarization_model(summarization_model, cx);
376 thread.add_default_tools(
377 Rc::new(NativeThreadEnvironment {
378 acp_thread: acp_thread.downgrade(),
379 thread: weak_thread,
380 agent: weak,
381 }) as _,
382 cx,
383 )
384 });
385
386 let subscriptions = vec![
387 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
388 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
389 cx.observe(&thread_handle, move |this, thread, cx| {
390 this.save_thread(thread, cx)
391 }),
392 ];
393
394 self.sessions.insert(
395 session_id,
396 Session {
397 thread: thread_handle,
398 acp_thread: acp_thread.clone(),
399 project_id,
400 _subscriptions: subscriptions,
401 pending_save: Task::ready(Ok(())),
402 ref_count,
403 },
404 );
405
406 self.update_available_commands_for_project(project_id, cx);
407
408 acp_thread
409 }
410
411 pub fn models(&self) -> &LanguageModels {
412 &self.models
413 }
414
415 fn get_or_create_project_state(
416 &mut self,
417 project: &Entity<Project>,
418 cx: &mut Context<Self>,
419 ) -> EntityId {
420 let project_id = project.entity_id();
421 if self.projects.contains_key(&project_id) {
422 return project_id;
423 }
424
425 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
426 self.register_project_with_initial_context(project.clone(), project_context, cx);
427 if let Some(state) = self.projects.get_mut(&project_id) {
428 state.project_context_needs_refresh.send(()).ok();
429 }
430 project_id
431 }
432
433 fn register_project_with_initial_context(
434 &mut self,
435 project: Entity<Project>,
436 project_context: Entity<ProjectContext>,
437 cx: &mut Context<Self>,
438 ) {
439 let project_id = project.entity_id();
440
441 let context_server_store = project.read(cx).context_server_store();
442 let context_server_registry =
443 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
444
445 let subscriptions = vec![
446 cx.subscribe(&project, Self::handle_project_event),
447 cx.subscribe(
448 &context_server_store,
449 Self::handle_context_server_store_updated,
450 ),
451 cx.subscribe(
452 &context_server_registry,
453 Self::handle_context_server_registry_event,
454 ),
455 ];
456
457 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
458 watch::channel(());
459
460 self.projects.insert(
461 project_id,
462 ProjectState {
463 project,
464 project_context,
465 project_context_needs_refresh: project_context_needs_refresh_tx,
466 _maintain_project_context: cx.spawn(async move |this, cx| {
467 Self::maintain_project_context(
468 this,
469 project_id,
470 project_context_needs_refresh_rx,
471 cx,
472 )
473 .await
474 }),
475 context_server_registry,
476 _subscriptions: subscriptions,
477 },
478 );
479 }
480
481 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
482 self.sessions
483 .get(session_id)
484 .and_then(|session| self.projects.get(&session.project_id))
485 }
486
487 async fn maintain_project_context(
488 this: WeakEntity<Self>,
489 project_id: EntityId,
490 mut needs_refresh: watch::Receiver<()>,
491 cx: &mut AsyncApp,
492 ) -> Result<()> {
493 while needs_refresh.changed().await.is_ok() {
494 let project_context = this
495 .update(cx, |this, cx| {
496 let state = this
497 .projects
498 .get(&project_id)
499 .context("project state not found")?;
500 anyhow::Ok(Self::build_project_context(
501 &state.project,
502 this.prompt_store.as_ref(),
503 cx,
504 ))
505 })??
506 .await;
507 this.update(cx, |this, cx| {
508 if let Some(state) = this.projects.get(&project_id) {
509 state
510 .project_context
511 .update(cx, |current_project_context, _cx| {
512 *current_project_context = project_context;
513 });
514 }
515 })?;
516 }
517
518 Ok(())
519 }
520
521 fn build_project_context(
522 project: &Entity<Project>,
523 prompt_store: Option<&Entity<PromptStore>>,
524 cx: &mut App,
525 ) -> Task<ProjectContext> {
526 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
527 let worktree_tasks = worktrees
528 .into_iter()
529 .map(|worktree| {
530 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
531 })
532 .collect::<Vec<_>>();
533 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
534 prompt_store.read_with(cx, |prompt_store, cx| {
535 let prompts = prompt_store.default_prompt_metadata();
536 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
537 let contents = prompt_store.load(prompt_metadata.id, cx);
538 async move { (contents.await, prompt_metadata) }
539 });
540 cx.background_spawn(future::join_all(load_tasks))
541 })
542 } else {
543 Task::ready(vec![])
544 };
545
546 cx.spawn(async move |_cx| {
547 let (worktrees, default_user_rules) =
548 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
549
550 let worktrees = worktrees
551 .into_iter()
552 .map(|(worktree, _rules_error)| {
553 // TODO: show error message
554 // if let Some(rules_error) = rules_error {
555 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
556 // }
557 worktree
558 })
559 .collect::<Vec<_>>();
560
561 let default_user_rules = default_user_rules
562 .into_iter()
563 .flat_map(|(contents, prompt_metadata)| match contents {
564 Ok(contents) => Some(UserRulesContext {
565 uuid: prompt_metadata.id.as_user()?,
566 title: prompt_metadata.title.map(|title| title.to_string()),
567 contents,
568 }),
569 Err(_err) => {
570 // TODO: show error message
571 // this.update(cx, |_, cx| {
572 // cx.emit(RulesLoadingError {
573 // message: format!("{err:?}").into(),
574 // });
575 // })
576 // .ok();
577 None
578 }
579 })
580 .collect::<Vec<_>>();
581
582 ProjectContext::new(worktrees, default_user_rules)
583 })
584 }
585
586 fn load_worktree_info_for_system_prompt(
587 worktree: Entity<Worktree>,
588 project: Entity<Project>,
589 cx: &mut App,
590 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
591 let tree = worktree.read(cx);
592 let root_name = tree.root_name_str().into();
593 let abs_path = tree.abs_path();
594
595 let mut context = WorktreeContext {
596 root_name,
597 abs_path,
598 rules_file: None,
599 };
600
601 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
602 let Some(rules_task) = rules_task else {
603 return Task::ready((context, None));
604 };
605
606 cx.spawn(async move |_| {
607 let (rules_file, rules_file_error) = match rules_task.await {
608 Ok(rules_file) => (Some(rules_file), None),
609 Err(err) => (
610 None,
611 Some(RulesLoadingError {
612 message: format!("{err}").into(),
613 }),
614 ),
615 };
616 context.rules_file = rules_file;
617 (context, rules_file_error)
618 })
619 }
620
621 fn load_worktree_rules_file(
622 worktree: Entity<Worktree>,
623 project: Entity<Project>,
624 cx: &mut App,
625 ) -> Option<Task<Result<RulesFileContext>>> {
626 let worktree = worktree.read(cx);
627 let worktree_id = worktree.id();
628 let selected_rules_file = RULES_FILE_NAMES
629 .into_iter()
630 .filter_map(|name| {
631 worktree
632 .entry_for_path(RelPath::unix(name).unwrap())
633 .filter(|entry| entry.is_file())
634 .map(|entry| entry.path.clone())
635 })
636 .next();
637
638 // Note that Cline supports `.clinerules` being a directory, but that is not currently
639 // supported. This doesn't seem to occur often in GitHub repositories.
640 selected_rules_file.map(|path_in_worktree| {
641 let project_path = ProjectPath {
642 worktree_id,
643 path: path_in_worktree.clone(),
644 };
645 let buffer_task =
646 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
647 let rope_task = cx.spawn(async move |cx| {
648 let buffer = buffer_task.await?;
649 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
650 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
651 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
652 })?;
653 anyhow::Ok((project_entry_id, rope))
654 });
655 // Build a string from the rope on a background thread.
656 cx.background_spawn(async move {
657 let (project_entry_id, rope) = rope_task.await?;
658 anyhow::Ok(RulesFileContext {
659 path_in_worktree,
660 text: rope.to_string().trim().to_string(),
661 project_entry_id: project_entry_id.to_usize(),
662 })
663 })
664 })
665 }
666
667 fn handle_thread_title_updated(
668 &mut self,
669 thread: Entity<Thread>,
670 _: &TitleUpdated,
671 cx: &mut Context<Self>,
672 ) {
673 let session_id = thread.read(cx).id();
674 let Some(session) = self.sessions.get(session_id) else {
675 return;
676 };
677
678 let thread = thread.downgrade();
679 let acp_thread = session.acp_thread.downgrade();
680 cx.spawn(async move |_, cx| {
681 let title = thread.read_with(cx, |thread, _| thread.title())?;
682 if let Some(title) = title {
683 let task =
684 acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
685 task.await?;
686 }
687 anyhow::Ok(())
688 })
689 .detach_and_log_err(cx);
690 }
691
692 fn handle_thread_token_usage_updated(
693 &mut self,
694 thread: Entity<Thread>,
695 usage: &TokenUsageUpdated,
696 cx: &mut Context<Self>,
697 ) {
698 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
699 return;
700 };
701 session.acp_thread.update(cx, |acp_thread, cx| {
702 acp_thread.update_token_usage(usage.0.clone(), cx);
703 });
704 }
705
706 fn handle_project_event(
707 &mut self,
708 project: Entity<Project>,
709 event: &project::Event,
710 _cx: &mut Context<Self>,
711 ) {
712 let project_id = project.entity_id();
713 let Some(state) = self.projects.get_mut(&project_id) else {
714 return;
715 };
716 match event {
717 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
718 state.project_context_needs_refresh.send(()).ok();
719 }
720 project::Event::WorktreeUpdatedEntries(_, items) => {
721 if items.iter().any(|(path, _, _)| {
722 RULES_FILE_NAMES
723 .iter()
724 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
725 }) {
726 state.project_context_needs_refresh.send(()).ok();
727 }
728 }
729 _ => {}
730 }
731 }
732
733 fn handle_prompts_updated_event(
734 &mut self,
735 _prompt_store: Entity<PromptStore>,
736 _event: &prompt_store::PromptsUpdatedEvent,
737 _cx: &mut Context<Self>,
738 ) {
739 for state in self.projects.values_mut() {
740 state.project_context_needs_refresh.send(()).ok();
741 }
742 }
743
744 fn handle_models_updated_event(
745 &mut self,
746 _registry: Entity<LanguageModelRegistry>,
747 event: &language_model::Event,
748 cx: &mut Context<Self>,
749 ) {
750 self.models.refresh_list(cx);
751
752 let registry = LanguageModelRegistry::read_global(cx);
753 let default_model = registry.default_model().map(|m| m.model);
754 let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
755
756 for session in self.sessions.values_mut() {
757 session.thread.update(cx, |thread, cx| {
758 let should_update_model = thread.model().is_none()
759 || (thread.is_empty()
760 && matches!(event, language_model::Event::DefaultModelChanged));
761 if should_update_model && let Some(model) = default_model.clone() {
762 thread.set_model(model, cx);
763 cx.notify();
764 }
765 if let Some(model) = summarization_model.clone() {
766 if thread.summarization_model().is_none()
767 || matches!(event, language_model::Event::ThreadSummaryModelChanged)
768 {
769 thread.set_summarization_model(Some(model), cx);
770 }
771 }
772 });
773 }
774 }
775
776 fn handle_context_server_store_updated(
777 &mut self,
778 store: Entity<project::context_server_store::ContextServerStore>,
779 _event: &project::context_server_store::ServerStatusChangedEvent,
780 cx: &mut Context<Self>,
781 ) {
782 let project_id = self.projects.iter().find_map(|(id, state)| {
783 if *state.context_server_registry.read(cx).server_store() == store {
784 Some(*id)
785 } else {
786 None
787 }
788 });
789 if let Some(project_id) = project_id {
790 self.update_available_commands_for_project(project_id, cx);
791 }
792 }
793
794 fn handle_context_server_registry_event(
795 &mut self,
796 registry: Entity<ContextServerRegistry>,
797 event: &ContextServerRegistryEvent,
798 cx: &mut Context<Self>,
799 ) {
800 match event {
801 ContextServerRegistryEvent::ToolsChanged => {}
802 ContextServerRegistryEvent::PromptsChanged => {
803 let project_id = self.projects.iter().find_map(|(id, state)| {
804 if state.context_server_registry == registry {
805 Some(*id)
806 } else {
807 None
808 }
809 });
810 if let Some(project_id) = project_id {
811 self.update_available_commands_for_project(project_id, cx);
812 }
813 }
814 }
815 }
816
817 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
818 let available_commands =
819 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
820 for session in self.sessions.values() {
821 if session.project_id != project_id {
822 continue;
823 }
824 session.acp_thread.update(cx, |thread, cx| {
825 thread
826 .handle_session_update(
827 acp::SessionUpdate::AvailableCommandsUpdate(
828 acp::AvailableCommandsUpdate::new(available_commands.clone()),
829 ),
830 cx,
831 )
832 .log_err();
833 });
834 }
835 }
836
837 fn build_available_commands_for_project(
838 project_state: Option<&ProjectState>,
839 cx: &App,
840 ) -> Vec<acp::AvailableCommand> {
841 let Some(state) = project_state else {
842 return vec![];
843 };
844 let registry = state.context_server_registry.read(cx);
845
846 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
847 for context_server_prompt in registry.prompts() {
848 *prompt_name_counts
849 .entry(context_server_prompt.prompt.name.as_str())
850 .or_insert(0) += 1;
851 }
852
853 registry
854 .prompts()
855 .flat_map(|context_server_prompt| {
856 let prompt = &context_server_prompt.prompt;
857
858 let should_prefix = prompt_name_counts
859 .get(prompt.name.as_str())
860 .copied()
861 .unwrap_or(0)
862 > 1;
863
864 let name = if should_prefix {
865 format!("{}.{}", context_server_prompt.server_id, prompt.name)
866 } else {
867 prompt.name.clone()
868 };
869
870 let mut command = acp::AvailableCommand::new(
871 name,
872 prompt.description.clone().unwrap_or_default(),
873 );
874
875 match prompt.arguments.as_deref() {
876 Some([arg]) => {
877 let hint = format!("<{}>", arg.name);
878
879 command = command.input(acp::AvailableCommandInput::Unstructured(
880 acp::UnstructuredCommandInput::new(hint),
881 ));
882 }
883 Some([]) | None => {}
884 Some(_) => {
885 // skip >1 argument commands since we don't support them yet
886 return None;
887 }
888 }
889
890 Some(command)
891 })
892 .collect()
893 }
894
895 pub fn load_thread(
896 &mut self,
897 id: acp::SessionId,
898 project: Entity<Project>,
899 cx: &mut Context<Self>,
900 ) -> Task<Result<Entity<Thread>>> {
901 let database_future = ThreadsDatabase::connect(cx);
902 cx.spawn(async move |this, cx| {
903 let database = database_future.await.map_err(|err| anyhow!(err))?;
904 let db_thread = database
905 .load_thread(id.clone())
906 .await?
907 .with_context(|| format!("no thread found with ID: {id:?}"))?;
908
909 this.update(cx, |this, cx| {
910 let project_id = this.get_or_create_project_state(&project, cx);
911 let project_state = this
912 .projects
913 .get(&project_id)
914 .context("project state not found")?;
915 let summarization_model = LanguageModelRegistry::read_global(cx)
916 .thread_summary_model(cx)
917 .map(|c| c.model);
918
919 Ok(cx.new(|cx| {
920 let mut thread = Thread::from_db(
921 id.clone(),
922 db_thread,
923 project_state.project.clone(),
924 project_state.project_context.clone(),
925 project_state.context_server_registry.clone(),
926 this.templates.clone(),
927 cx,
928 );
929 thread.set_summarization_model(summarization_model, cx);
930 thread
931 }))
932 })?
933 })
934 }
935
936 pub fn open_thread(
937 &mut self,
938 id: acp::SessionId,
939 project: Entity<Project>,
940 cx: &mut Context<Self>,
941 ) -> Task<Result<Entity<AcpThread>>> {
942 if let Some(session) = self.sessions.get_mut(&id) {
943 session.ref_count += 1;
944 return Task::ready(Ok(session.acp_thread.clone()));
945 }
946
947 if let Some(pending) = self.pending_sessions.get_mut(&id) {
948 pending.ref_count += 1;
949 let task = pending.task.clone();
950 return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
951 }
952
953 let task = self.load_thread(id.clone(), project.clone(), cx);
954 let shared_task = cx
955 .spawn({
956 let id = id.clone();
957 async move |this, cx| {
958 let thread = match task.await {
959 Ok(thread) => thread,
960 Err(err) => {
961 this.update(cx, |this, _cx| {
962 this.pending_sessions.remove(&id);
963 })
964 .ok();
965 return Err(Arc::new(err));
966 }
967 };
968 let acp_thread = this
969 .update(cx, |this, cx| {
970 let project_id = this.get_or_create_project_state(&project, cx);
971 let ref_count = this
972 .pending_sessions
973 .remove(&id)
974 .map_or(1, |pending| pending.ref_count);
975 this.register_session(thread.clone(), project_id, ref_count, cx)
976 })
977 .map_err(Arc::new)?;
978 let events = thread.update(cx, |thread, cx| thread.replay(cx));
979 cx.update(|cx| {
980 NativeAgentConnection::handle_thread_events(
981 events,
982 acp_thread.downgrade(),
983 cx,
984 )
985 })
986 .await
987 .map_err(Arc::new)?;
988 acp_thread.update(cx, |thread, cx| {
989 thread.snapshot_completed_plan(cx);
990 });
991 Ok(acp_thread)
992 }
993 })
994 .shared();
995 self.pending_sessions.insert(
996 id,
997 PendingSession {
998 task: shared_task.clone(),
999 ref_count: 1,
1000 },
1001 );
1002
1003 cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
1004 }
1005
1006 pub fn thread_summary(
1007 &mut self,
1008 id: acp::SessionId,
1009 project: Entity<Project>,
1010 cx: &mut Context<Self>,
1011 ) -> Task<Result<SharedString>> {
1012 let thread = self.open_thread(id.clone(), project, cx);
1013 cx.spawn(async move |this, cx| {
1014 let acp_thread = thread.await?;
1015 let result = this
1016 .update(cx, |this, cx| {
1017 this.sessions
1018 .get(&id)
1019 .unwrap()
1020 .thread
1021 .update(cx, |thread, cx| thread.summary(cx))
1022 })?
1023 .await
1024 .context("Failed to generate summary")?;
1025
1026 this.update(cx, |this, cx| this.close_session(&id, cx))?
1027 .await?;
1028 drop(acp_thread);
1029 Ok(result)
1030 })
1031 }
1032
1033 fn close_session(
1034 &mut self,
1035 session_id: &acp::SessionId,
1036 cx: &mut Context<Self>,
1037 ) -> Task<Result<()>> {
1038 let Some(session) = self.sessions.get_mut(session_id) else {
1039 return Task::ready(Ok(()));
1040 };
1041
1042 session.ref_count -= 1;
1043 if session.ref_count > 0 {
1044 return Task::ready(Ok(()));
1045 }
1046
1047 let thread = session.thread.clone();
1048 self.save_thread(thread, cx);
1049 let Some(session) = self.sessions.remove(session_id) else {
1050 return Task::ready(Ok(()));
1051 };
1052 let project_id = session.project_id;
1053
1054 let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
1055 if !has_remaining {
1056 self.projects.remove(&project_id);
1057 }
1058
1059 session.pending_save
1060 }
1061
1062 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
1063 if thread.read(cx).is_empty() {
1064 return;
1065 }
1066
1067 let id = thread.read(cx).id().clone();
1068 let Some(session) = self.sessions.get_mut(&id) else {
1069 return;
1070 };
1071
1072 let project_id = session.project_id;
1073 let Some(state) = self.projects.get(&project_id) else {
1074 return;
1075 };
1076
1077 let folder_paths = PathList::new(
1078 &state
1079 .project
1080 .read(cx)
1081 .visible_worktrees(cx)
1082 .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
1083 .collect::<Vec<_>>(),
1084 );
1085
1086 let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
1087 let database_future = ThreadsDatabase::connect(cx);
1088 let db_thread = thread.update(cx, |thread, cx| {
1089 thread.set_draft_prompt(draft_prompt);
1090 thread.to_db(cx)
1091 });
1092 let thread_store = self.thread_store.clone();
1093 session.pending_save = cx.spawn(async move |_, cx| {
1094 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1095 return Ok(());
1096 };
1097 let db_thread = db_thread.await;
1098 database
1099 .save_thread(id, db_thread, folder_paths)
1100 .await
1101 .log_err();
1102 thread_store.update(cx, |store, cx| store.reload(cx));
1103 Ok(())
1104 });
1105 }
1106
    /// Expands an MCP prompt from a context server and sends the result as a turn.
    ///
    /// The prompt named `prompt_name` is fetched from `server_id` with `arguments`. The
    /// blocks of `original_content` after the first one are pushed onto the thread as a user
    /// message with id `message_id`, and every message of the fetched prompt is then mirrored
    /// onto both the `AcpThread` and the `Thread`. The returned task resolves once the
    /// resulting event stream has been handled.
    ///
    /// Fails when the session has no project state, when fetching the prompt fails, or when
    /// the session disappears before the prompt is fetched.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            // Look the session up again: this runs after an await point.
            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Stays `true` when the prompt has no messages, so the turn is sent rather than
            // resumed.
            let mut last_is_user = true;

            // The first block is skipped.
            // NOTE(review): presumably it is the slash-command text itself; confirm against
            // the caller.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Each message is pushed to the `AcpThread` first and then to the `Thread`.
                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1209}
1210
1211/// Wrapper struct that implements the AgentConnection trait
1212#[derive(Clone)]
1213pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1214
1215impl NativeAgentConnection {
1216 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1217 self.0
1218 .read(cx)
1219 .sessions
1220 .get(session_id)
1221 .map(|session| session.thread.clone())
1222 }
1223
1224 pub fn load_thread(
1225 &self,
1226 id: acp::SessionId,
1227 project: Entity<Project>,
1228 cx: &mut App,
1229 ) -> Task<Result<Entity<Thread>>> {
1230 self.0
1231 .update(cx, |this, cx| this.load_thread(id, project, cx))
1232 }
1233
1234 fn run_turn(
1235 &self,
1236 session_id: acp::SessionId,
1237 cx: &mut App,
1238 f: impl 'static
1239 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1240 ) -> Task<Result<acp::PromptResponse>> {
1241 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1242 agent
1243 .sessions
1244 .get_mut(&session_id)
1245 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1246 }) else {
1247 log::error!("Session not found in run_turn: {}", session_id);
1248 return Task::ready(Err(anyhow!("Session not found")));
1249 };
1250 log::debug!("Found session for: {}", session_id);
1251
1252 let response_stream = match f(thread, cx) {
1253 Ok(stream) => stream,
1254 Err(err) => return Task::ready(Err(err)),
1255 };
1256 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1257 }
1258
1259 fn handle_thread_events(
1260 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1261 acp_thread: WeakEntity<AcpThread>,
1262 cx: &App,
1263 ) -> Task<Result<acp::PromptResponse>> {
1264 cx.spawn(async move |cx| {
1265 // Handle response stream and forward to session.acp_thread
1266 while let Some(result) = events.next().await {
1267 match result {
1268 Ok(event) => {
1269 log::trace!("Received completion event: {:?}", event);
1270
1271 match event {
1272 ThreadEvent::UserMessage(message) => {
1273 acp_thread.update(cx, |thread, cx| {
1274 for content in message.content {
1275 thread.push_user_content_block(
1276 Some(message.id.clone()),
1277 content.into(),
1278 cx,
1279 );
1280 }
1281 })?;
1282 }
1283 ThreadEvent::AgentText(text) => {
1284 acp_thread.update(cx, |thread, cx| {
1285 thread.push_assistant_content_block(text.into(), false, cx)
1286 })?;
1287 }
1288 ThreadEvent::AgentThinking(text) => {
1289 acp_thread.update(cx, |thread, cx| {
1290 thread.push_assistant_content_block(text.into(), true, cx)
1291 })?;
1292 }
1293 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1294 tool_call,
1295 options,
1296 response,
1297 context: _,
1298 }) => {
1299 let outcome_task = acp_thread.update(cx, |thread, cx| {
1300 thread.request_tool_call_authorization(tool_call, options, cx)
1301 })??;
1302 cx.background_spawn(async move {
1303 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1304 outcome_task.await
1305 {
1306 response
1307 .send(outcome)
1308 .map(|_| anyhow!("authorization receiver was dropped"))
1309 .log_err();
1310 }
1311 })
1312 .detach();
1313 }
1314 ThreadEvent::ToolCall(tool_call) => {
1315 acp_thread.update(cx, |thread, cx| {
1316 thread.upsert_tool_call(tool_call, cx)
1317 })??;
1318 }
1319 ThreadEvent::ToolCallUpdate(update) => {
1320 acp_thread.update(cx, |thread, cx| {
1321 thread.update_tool_call(update, cx)
1322 })??;
1323 }
1324 ThreadEvent::Plan(plan) => {
1325 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1326 }
1327 ThreadEvent::SubagentSpawned(session_id) => {
1328 acp_thread.update(cx, |thread, cx| {
1329 thread.subagent_spawned(session_id, cx);
1330 })?;
1331 }
1332 ThreadEvent::Retry(status) => {
1333 acp_thread.update(cx, |thread, cx| {
1334 thread.update_retry_status(status, cx)
1335 })?;
1336 }
1337 ThreadEvent::Stop(stop_reason) => {
1338 log::debug!("Assistant message complete: {:?}", stop_reason);
1339 return Ok(acp::PromptResponse::new(stop_reason));
1340 }
1341 }
1342 }
1343 Err(e) => {
1344 log::error!("Error in model response stream: {:?}", e);
1345 return Err(e);
1346 }
1347 }
1348 }
1349
1350 log::debug!("Response stream completed");
1351 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1352 })
1353 }
1354}
1355
1356struct Command<'a> {
1357 prompt_name: &'a str,
1358 arg_value: &'a str,
1359 explicit_server_id: Option<&'a str>,
1360}
1361
1362impl<'a> Command<'a> {
1363 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1364 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1365 return None;
1366 };
1367 let text = text_content.text.trim();
1368 let command = text.strip_prefix('/')?;
1369 let (command, arg_value) = command
1370 .split_once(char::is_whitespace)
1371 .unwrap_or((command, ""));
1372
1373 if let Some((server_id, prompt_name)) = command.split_once('.') {
1374 Some(Self {
1375 prompt_name,
1376 arg_value,
1377 explicit_server_id: Some(server_id),
1378 })
1379 } else {
1380 Some(Self {
1381 prompt_name: command,
1382 arg_value,
1383 explicit_server_id: None,
1384 })
1385 }
1386 }
1387}
1388
1389struct NativeAgentModelSelector {
1390 session_id: acp::SessionId,
1391 connection: NativeAgentConnection,
1392}
1393
1394impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1395 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1396 log::debug!("NativeAgentConnection::list_models called");
1397 let list = self.connection.0.read(cx).models.model_list.clone();
1398 Task::ready(if list.is_empty() {
1399 Err(anyhow::anyhow!("No models available"))
1400 } else {
1401 Ok(list)
1402 })
1403 }
1404
1405 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1406 log::debug!(
1407 "Setting model for session {}: {}",
1408 self.session_id,
1409 model_id
1410 );
1411 let Some(thread) = self
1412 .connection
1413 .0
1414 .read(cx)
1415 .sessions
1416 .get(&self.session_id)
1417 .map(|session| session.thread.clone())
1418 else {
1419 return Task::ready(Err(anyhow!("Session not found")));
1420 };
1421
1422 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1423 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1424 };
1425
1426 let favorite = agent_settings::AgentSettings::get_global(cx)
1427 .favorite_models
1428 .iter()
1429 .find(|favorite| {
1430 favorite.provider.0 == model.provider_id().0.as_ref()
1431 && favorite.model == model.id().0.as_ref()
1432 })
1433 .cloned();
1434
1435 let LanguageModelSelection {
1436 enable_thinking,
1437 effort,
1438 speed,
1439 ..
1440 } = agent_settings::language_model_to_selection(&model, favorite.as_ref());
1441
1442 thread.update(cx, |thread, cx| {
1443 thread.set_model(model.clone(), cx);
1444 thread.set_thinking_effort(effort.clone(), cx);
1445 thread.set_thinking_enabled(enable_thinking, cx);
1446 if let Some(speed) = speed {
1447 thread.set_speed(speed, cx);
1448 }
1449 });
1450
1451 update_settings_file(
1452 self.connection.0.read(cx).fs.clone(),
1453 cx,
1454 move |settings, cx| {
1455 let provider = model.provider_id().0.to_string();
1456 let model = model.id().0.to_string();
1457 let enable_thinking = thread.read(cx).thinking_enabled();
1458 let speed = thread.read(cx).speed();
1459 settings
1460 .agent
1461 .get_or_insert_default()
1462 .set_model(LanguageModelSelection {
1463 provider: provider.into(),
1464 model,
1465 enable_thinking,
1466 effort,
1467 speed,
1468 });
1469 },
1470 );
1471
1472 Task::ready(Ok(()))
1473 }
1474
1475 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1476 let Some(thread) = self
1477 .connection
1478 .0
1479 .read(cx)
1480 .sessions
1481 .get(&self.session_id)
1482 .map(|session| session.thread.clone())
1483 else {
1484 return Task::ready(Err(anyhow!("Session not found")));
1485 };
1486 let Some(model) = thread.read(cx).model() else {
1487 return Task::ready(Err(anyhow!("Model not found")));
1488 };
1489 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1490 else {
1491 return Task::ready(Err(anyhow!("Provider not found")));
1492 };
1493 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1494 model, &provider,
1495 )))
1496 }
1497
1498 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1499 Some(self.connection.0.read(cx).models.watch())
1500 }
1501
1502 fn should_render_footer(&self) -> bool {
1503 true
1504 }
1505}
1506
/// Agent id of the native agent (`"Zed Agent"`), built lazily on first access.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1508
1509impl acp_thread::AgentConnection for NativeAgentConnection {
1510 fn agent_id(&self) -> AgentId {
1511 ZED_AGENT_ID.clone()
1512 }
1513
1514 fn telemetry_id(&self) -> SharedString {
1515 "zed".into()
1516 }
1517
1518 fn new_session(
1519 self: Rc<Self>,
1520 project: Entity<Project>,
1521 work_dirs: PathList,
1522 cx: &mut App,
1523 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1524 log::debug!("Creating new thread for project at: {work_dirs:?}");
1525 Task::ready(Ok(self
1526 .0
1527 .update(cx, |agent, cx| agent.new_session(project, cx))))
1528 }
1529
1530 fn supports_load_session(&self) -> bool {
1531 true
1532 }
1533
1534 fn load_session(
1535 self: Rc<Self>,
1536 session_id: acp::SessionId,
1537 project: Entity<Project>,
1538 _work_dirs: PathList,
1539 _title: Option<SharedString>,
1540 cx: &mut App,
1541 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1542 self.0
1543 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1544 }
1545
1546 fn supports_close_session(&self) -> bool {
1547 true
1548 }
1549
1550 fn close_session(
1551 self: Rc<Self>,
1552 session_id: &acp::SessionId,
1553 cx: &mut App,
1554 ) -> Task<Result<()>> {
1555 self.0
1556 .update(cx, |agent, cx| agent.close_session(session_id, cx))
1557 }
1558
1559 fn auth_methods(&self) -> &[acp::AuthMethod] {
1560 &[] // No auth for in-process
1561 }
1562
1563 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1564 Task::ready(Ok(()))
1565 }
1566
1567 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1568 Some(Rc::new(NativeAgentModelSelector {
1569 session_id: session_id.clone(),
1570 connection: self.clone(),
1571 }) as Rc<dyn AgentModelSelector>)
1572 }
1573
1574 fn prompt(
1575 &self,
1576 id: acp_thread::UserMessageId,
1577 params: acp::PromptRequest,
1578 cx: &mut App,
1579 ) -> Task<Result<acp::PromptResponse>> {
1580 let session_id = params.session_id.clone();
1581 log::info!("Received prompt request for session: {}", session_id);
1582 log::debug!("Prompt blocks count: {}", params.prompt.len());
1583
1584 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1585 log::error!("Session not found in prompt: {}", session_id);
1586 if self.0.read(cx).sessions.contains_key(&session_id) {
1587 log::error!(
1588 "Session found in sessions map, but not in project state: {}",
1589 session_id
1590 );
1591 }
1592 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1593 };
1594
1595 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1596 let registry = project_state.context_server_registry.read(cx);
1597
1598 let explicit_server_id = parsed_command
1599 .explicit_server_id
1600 .map(|server_id| ContextServerId(server_id.into()));
1601
1602 if let Some(prompt) =
1603 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1604 {
1605 let arguments = if !parsed_command.arg_value.is_empty()
1606 && let Some(arg_name) = prompt
1607 .prompt
1608 .arguments
1609 .as_ref()
1610 .and_then(|args| args.first())
1611 .map(|arg| arg.name.clone())
1612 {
1613 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1614 } else {
1615 Default::default()
1616 };
1617
1618 let prompt_name = prompt.prompt.name.clone();
1619 let server_id = prompt.server_id.clone();
1620
1621 return self.0.update(cx, |agent, cx| {
1622 agent.send_mcp_prompt(
1623 id,
1624 session_id.clone(),
1625 prompt_name,
1626 server_id,
1627 arguments,
1628 params.prompt,
1629 cx,
1630 )
1631 });
1632 }
1633 };
1634
1635 let path_style = project_state.project.read(cx).path_style(cx);
1636
1637 self.run_turn(session_id, cx, move |thread, cx| {
1638 let content: Vec<UserMessageContent> = params
1639 .prompt
1640 .into_iter()
1641 .map(|block| UserMessageContent::from_content_block(block, path_style))
1642 .collect::<Vec<_>>();
1643 log::debug!("Converted prompt to message: {} chars", content.len());
1644 log::debug!("Message id: {:?}", id);
1645 log::debug!("Message content: {:?}", content);
1646
1647 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1648 })
1649 }
1650
1651 fn retry(
1652 &self,
1653 session_id: &acp::SessionId,
1654 _cx: &App,
1655 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1656 Some(Rc::new(NativeAgentSessionRetry {
1657 connection: self.clone(),
1658 session_id: session_id.clone(),
1659 }) as _)
1660 }
1661
1662 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1663 log::info!("Cancelling on session: {}", session_id);
1664 self.0.update(cx, |agent, cx| {
1665 if let Some(session) = agent.sessions.get(session_id) {
1666 session
1667 .thread
1668 .update(cx, |thread, cx| thread.cancel(cx))
1669 .detach();
1670 }
1671 });
1672 }
1673
1674 fn truncate(
1675 &self,
1676 session_id: &acp::SessionId,
1677 cx: &App,
1678 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1679 self.0.read_with(cx, |agent, _cx| {
1680 agent.sessions.get(session_id).map(|session| {
1681 Rc::new(NativeAgentSessionTruncate {
1682 thread: session.thread.clone(),
1683 acp_thread: session.acp_thread.downgrade(),
1684 }) as _
1685 })
1686 })
1687 }
1688
1689 fn set_title(
1690 &self,
1691 session_id: &acp::SessionId,
1692 cx: &App,
1693 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1694 self.0.read_with(cx, |agent, _cx| {
1695 agent
1696 .sessions
1697 .get(session_id)
1698 .filter(|s| !s.thread.read(cx).is_subagent())
1699 .map(|session| {
1700 Rc::new(NativeAgentSessionSetTitle {
1701 thread: session.thread.clone(),
1702 }) as _
1703 })
1704 })
1705 }
1706
1707 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1708 let thread_store = self.0.read(cx).thread_store.clone();
1709 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1710 }
1711
1712 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1713 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1714 }
1715
1716 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1717 self
1718 }
1719}
1720
1721impl acp_thread::AgentTelemetry for NativeAgentConnection {
1722 fn thread_data(
1723 &self,
1724 session_id: &acp::SessionId,
1725 cx: &mut App,
1726 ) -> Task<Result<serde_json::Value>> {
1727 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1728 return Task::ready(Err(anyhow!("Session not found")));
1729 };
1730
1731 let task = session.thread.read(cx).to_db(cx);
1732 cx.background_spawn(async move {
1733 serde_json::to_value(task.await).context("Failed to serialize thread")
1734 })
1735 }
1736}
1737
1738pub struct NativeAgentSessionList {
1739 thread_store: Entity<ThreadStore>,
1740 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1741 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1742 _subscription: Subscription,
1743}
1744
1745impl NativeAgentSessionList {
1746 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1747 let (tx, rx) = smol::channel::unbounded();
1748 let this_tx = tx.clone();
1749 let subscription = cx.observe(&thread_store, move |_, _| {
1750 this_tx
1751 .try_send(acp_thread::SessionListUpdate::Refresh)
1752 .ok();
1753 });
1754 Self {
1755 thread_store,
1756 updates_tx: tx,
1757 updates_rx: rx,
1758 _subscription: subscription,
1759 }
1760 }
1761
1762 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1763 &self.thread_store
1764 }
1765}
1766
1767impl AgentSessionList for NativeAgentSessionList {
1768 fn list_sessions(
1769 &self,
1770 _request: AgentSessionListRequest,
1771 cx: &mut App,
1772 ) -> Task<Result<AgentSessionListResponse>> {
1773 let sessions = self
1774 .thread_store
1775 .read(cx)
1776 .entries()
1777 .map(|entry| AgentSessionInfo::from(&entry))
1778 .collect();
1779 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1780 }
1781
1782 fn supports_delete(&self) -> bool {
1783 true
1784 }
1785
1786 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1787 self.thread_store
1788 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1789 }
1790
1791 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1792 self.thread_store
1793 .update(cx, |store, cx| store.delete_threads(cx))
1794 }
1795
1796 fn watch(
1797 &self,
1798 _cx: &mut App,
1799 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1800 Some(self.updates_rx.clone())
1801 }
1802
1803 fn notify_refresh(&self) {
1804 self.updates_tx
1805 .try_send(acp_thread::SessionListUpdate::Refresh)
1806 .ok();
1807 }
1808
1809 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1810 self
1811 }
1812}
1813
1814struct NativeAgentSessionTruncate {
1815 thread: Entity<Thread>,
1816 acp_thread: WeakEntity<AcpThread>,
1817}
1818
1819impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1820 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1821 match self.thread.update(cx, |thread, cx| {
1822 thread.truncate(message_id.clone(), cx)?;
1823 Ok(thread.latest_token_usage())
1824 }) {
1825 Ok(usage) => {
1826 self.acp_thread
1827 .update(cx, |thread, cx| {
1828 thread.update_token_usage(usage, cx);
1829 })
1830 .ok();
1831 Task::ready(Ok(()))
1832 }
1833 Err(error) => Task::ready(Err(error)),
1834 }
1835 }
1836}
1837
1838struct NativeAgentSessionRetry {
1839 connection: NativeAgentConnection,
1840 session_id: acp::SessionId,
1841}
1842
1843impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1844 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1845 self.connection
1846 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1847 thread.update(cx, |thread, cx| thread.resume(cx))
1848 })
1849 }
1850}
1851
1852struct NativeAgentSessionSetTitle {
1853 thread: Entity<Thread>,
1854}
1855
1856impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1857 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1858 self.thread
1859 .update(cx, |thread, cx| thread.set_title(title, cx));
1860 Task::ready(Ok(()))
1861 }
1862}
1863
1864pub struct NativeThreadEnvironment {
1865 agent: WeakEntity<NativeAgent>,
1866 thread: WeakEntity<Thread>,
1867 acp_thread: WeakEntity<AcpThread>,
1868}
1869
1870impl NativeThreadEnvironment {
1871 pub(crate) fn create_subagent_thread(
1872 &self,
1873 label: String,
1874 cx: &mut App,
1875 ) -> Result<Rc<dyn SubagentHandle>> {
1876 let Some(parent_thread_entity) = self.thread.upgrade() else {
1877 anyhow::bail!("Parent thread no longer exists".to_string());
1878 };
1879 let parent_thread = parent_thread_entity.read(cx);
1880 let current_depth = parent_thread.depth();
1881 let parent_session_id = parent_thread.id().clone();
1882
1883 if current_depth >= MAX_SUBAGENT_DEPTH {
1884 return Err(anyhow!(
1885 "Maximum subagent depth ({}) reached",
1886 MAX_SUBAGENT_DEPTH
1887 ));
1888 }
1889
1890 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1891 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1892 thread.set_title(label.into(), cx);
1893 thread
1894 });
1895
1896 let session_id = subagent_thread.read(cx).id().clone();
1897
1898 let acp_thread = self
1899 .agent
1900 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1901 let project_id = agent
1902 .sessions
1903 .get(&parent_session_id)
1904 .map(|s| s.project_id)
1905 .context("parent session not found")?;
1906 Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
1907 })??;
1908
1909 let depth = current_depth + 1;
1910
1911 telemetry::event!(
1912 "Subagent Started",
1913 session = parent_thread_entity.read(cx).id().to_string(),
1914 subagent_session = session_id.to_string(),
1915 depth,
1916 is_resumed = false,
1917 );
1918
1919 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1920 }
1921
1922 pub(crate) fn resume_subagent_thread(
1923 &self,
1924 session_id: acp::SessionId,
1925 cx: &mut App,
1926 ) -> Result<Rc<dyn SubagentHandle>> {
1927 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1928 let session = agent
1929 .sessions
1930 .get(&session_id)
1931 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1932 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1933 })??;
1934
1935 let depth = subagent_thread.read(cx).depth();
1936
1937 if let Some(parent_thread_entity) = self.thread.upgrade() {
1938 telemetry::event!(
1939 "Subagent Started",
1940 session = parent_thread_entity.read(cx).id().to_string(),
1941 subagent_session = session_id.to_string(),
1942 depth,
1943 is_resumed = true,
1944 );
1945 }
1946
1947 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1948 }
1949
1950 fn prompt_subagent(
1951 &self,
1952 session_id: acp::SessionId,
1953 subagent_thread: Entity<Thread>,
1954 acp_thread: Entity<acp_thread::AcpThread>,
1955 ) -> Result<Rc<dyn SubagentHandle>> {
1956 let Some(parent_thread_entity) = self.thread.upgrade() else {
1957 anyhow::bail!("Parent thread no longer exists".to_string());
1958 };
1959 Ok(Rc::new(NativeSubagentHandle::new(
1960 session_id,
1961 subagent_thread,
1962 acp_thread,
1963 parent_thread_entity,
1964 )) as _)
1965 }
1966}
1967
1968impl ThreadEnvironment for NativeThreadEnvironment {
1969 fn create_terminal(
1970 &self,
1971 command: String,
1972 cwd: Option<PathBuf>,
1973 output_byte_limit: Option<u64>,
1974 cx: &mut AsyncApp,
1975 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1976 let task = self.acp_thread.update(cx, |thread, cx| {
1977 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1978 });
1979
1980 let acp_thread = self.acp_thread.clone();
1981 cx.spawn(async move |cx| {
1982 let terminal = task?.await?;
1983
1984 let (drop_tx, drop_rx) = oneshot::channel();
1985 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1986
1987 cx.spawn(async move |cx| {
1988 drop_rx.await.ok();
1989 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1990 })
1991 .detach();
1992
1993 let handle = AcpTerminalHandle {
1994 terminal,
1995 _drop_tx: Some(drop_tx),
1996 };
1997
1998 Ok(Rc::new(handle) as _)
1999 })
2000 }
2001
2002 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
2003 self.create_subagent_thread(label, cx)
2004 }
2005
2006 fn resume_subagent(
2007 &self,
2008 session_id: acp::SessionId,
2009 cx: &mut App,
2010 ) -> Result<Rc<dyn SubagentHandle>> {
2011 self.resume_subagent_thread(session_id, cx)
2012 }
2013}
2014
/// Outcome of waiting for a subagent prompt to finish.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The prompt stopped with `EndTurn` or another stop reason that is not mapped below.
    Completed,
    /// The prompt stopped with `StopReason::Cancelled`.
    Cancelled,
    /// The token-limit signal fired before the prompt finished.
    ContextWindowWarning,
    /// The prompt failed; the payload is the message reported to the caller.
    Error(String),
}
2022
2023pub struct NativeSubagentHandle {
2024 session_id: acp::SessionId,
2025 parent_thread: WeakEntity<Thread>,
2026 subagent_thread: Entity<Thread>,
2027 acp_thread: Entity<acp_thread::AcpThread>,
2028}
2029
2030impl NativeSubagentHandle {
2031 fn new(
2032 session_id: acp::SessionId,
2033 subagent_thread: Entity<Thread>,
2034 acp_thread: Entity<acp_thread::AcpThread>,
2035 parent_thread_entity: Entity<Thread>,
2036 ) -> Self {
2037 NativeSubagentHandle {
2038 session_id,
2039 subagent_thread,
2040 parent_thread: parent_thread_entity.downgrade(),
2041 acp_thread,
2042 }
2043 }
2044}
2045
impl SubagentHandle for NativeSubagentHandle {
    /// Returns a clone of the subagent's session id.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Returns the number of entries in the subagent's `AcpThread`.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves with the text of its last agent
    /// message.
    ///
    /// The subagent is registered as running on the parent thread while the prompt is in
    /// flight and unregistered once a result is available. The result is an error when the
    /// prompt is cancelled, fails, produces no text, or is stopped because token usage went
    /// from `Normal` to `Warning`.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is bound here so the token-usage subscription is held until
            // the end of this task.
            let (task, _subscription) = cx.update(|cx| {
                // Usage ratio before this prompt; the warning below only fires on a
                // transition out of `Normal`.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                // A failed update (parent thread gone) is ignored.
                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // One-shot signal; the `Option` ensures it is sent at most once.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the token-limit signal.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // On completion, return the text parts of the last message joined with
                // blank lines; a missing or empty result is an error.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some(content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent before reporting the warning.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // Unregister for every outcome; a failed update (parent thread gone) is ignored.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
2166
2167pub struct AcpTerminalHandle {
2168 terminal: Entity<acp_thread::Terminal>,
2169 _drop_tx: Option<oneshot::Sender<()>>,
2170}
2171
2172impl TerminalHandle for AcpTerminalHandle {
2173 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2174 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2175 }
2176
2177 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2178 Ok(self
2179 .terminal
2180 .read_with(cx, |term, _cx| term.wait_for_exit()))
2181 }
2182
2183 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2184 Ok(self
2185 .terminal
2186 .read_with(cx, |term, cx| term.current_output(cx)))
2187 }
2188
2189 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2190 cx.update(|cx| {
2191 self.terminal.update(cx, |terminal, cx| {
2192 terminal.kill(cx);
2193 });
2194 });
2195 Ok(())
2196 }
2197
2198 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2199 Ok(self
2200 .terminal
2201 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2202 }
2203}
2204
2205#[cfg(test)]
2206mod internal_tests {
2207 use std::path::Path;
2208
2209 use super::*;
2210 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2211 use fs::FakeFs;
2212 use gpui::TestAppContext;
2213 use indoc::formatdoc;
2214 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2215 use language_model::{
2216 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2217 };
2218 use serde_json::json;
2219 use settings::SettingsStore;
2220 use util::{path, rel_path::rel_path};
2221
2222 #[gpui::test]
2223 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
2224 init_test(cx);
2225 let fs = FakeFs::new(cx.executor());
2226 fs.insert_tree(
2227 "/",
2228 json!({
2229 "a": {}
2230 }),
2231 )
2232 .await;
2233 let project = Project::test(fs.clone(), [], cx).await;
2234 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2235 let agent =
2236 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2237
2238 // Creating a session registers the project and triggers context building.
2239 let connection = NativeAgentConnection(agent.clone());
2240 let _acp_thread = cx
2241 .update(|cx| {
2242 Rc::new(connection).new_session(
2243 project.clone(),
2244 PathList::new(&[Path::new("/")]),
2245 cx,
2246 )
2247 })
2248 .await
2249 .unwrap();
2250 cx.run_until_parked();
2251
2252 let thread = agent.read_with(cx, |agent, _cx| {
2253 agent.sessions.values().next().unwrap().thread.clone()
2254 });
2255
2256 agent.read_with(cx, |agent, cx| {
2257 let project_id = project.entity_id();
2258 let state = agent.projects.get(&project_id).unwrap();
2259 assert_eq!(state.project_context.read(cx).worktrees, vec![]);
2260 assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
2261 });
2262
2263 let worktree = project
2264 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
2265 .await
2266 .unwrap();
2267 cx.run_until_parked();
2268 agent.read_with(cx, |agent, cx| {
2269 let project_id = project.entity_id();
2270 let state = agent.projects.get(&project_id).unwrap();
2271 let expected_worktrees = vec![WorktreeContext {
2272 root_name: "a".into(),
2273 abs_path: Path::new("/a").into(),
2274 rules_file: None,
2275 }];
2276 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2277 assert_eq!(
2278 thread.read(cx).project_context().read(cx).worktrees,
2279 expected_worktrees
2280 );
2281 });
2282
2283 // Creating `/a/.rules` updates the project context.
2284 fs.insert_file("/a/.rules", Vec::new()).await;
2285 cx.run_until_parked();
2286 agent.read_with(cx, |agent, cx| {
2287 let project_id = project.entity_id();
2288 let state = agent.projects.get(&project_id).unwrap();
2289 let rules_entry = worktree
2290 .read(cx)
2291 .entry_for_path(rel_path(".rules"))
2292 .unwrap();
2293 let expected_worktrees = vec![WorktreeContext {
2294 root_name: "a".into(),
2295 abs_path: Path::new("/a").into(),
2296 rules_file: Some(RulesFileContext {
2297 path_in_worktree: rel_path(".rules").into(),
2298 text: "".into(),
2299 project_entry_id: rules_entry.id.to_usize(),
2300 }),
2301 }];
2302 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2303 assert_eq!(
2304 thread.read(cx).project_context().read(cx).worktrees,
2305 expected_worktrees
2306 );
2307 });
2308 }
2309
2310 #[gpui::test]
2311 async fn test_listing_models(cx: &mut TestAppContext) {
2312 init_test(cx);
2313 let fs = FakeFs::new(cx.executor());
2314 fs.insert_tree("/", json!({ "a": {} })).await;
2315 let project = Project::test(fs.clone(), [], cx).await;
2316 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2317 let connection =
2318 NativeAgentConnection(cx.update(|cx| {
2319 NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2320 }));
2321
2322 // Create a thread/session
2323 let acp_thread = cx
2324 .update(|cx| {
2325 Rc::new(connection.clone()).new_session(
2326 project.clone(),
2327 PathList::new(&[Path::new("/a")]),
2328 cx,
2329 )
2330 })
2331 .await
2332 .unwrap();
2333
2334 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2335
2336 let models = cx
2337 .update(|cx| {
2338 connection
2339 .model_selector(&session_id)
2340 .unwrap()
2341 .list_models(cx)
2342 })
2343 .await
2344 .unwrap();
2345
2346 let acp_thread::AgentModelList::Grouped(models) = models else {
2347 panic!("Unexpected model group");
2348 };
2349 assert_eq!(
2350 models,
2351 IndexMap::from_iter([(
2352 AgentModelGroupName("Fake".into()),
2353 vec![AgentModelInfo {
2354 id: acp::ModelId::new("fake/fake"),
2355 name: "Fake".into(),
2356 description: None,
2357 icon: Some(acp_thread::AgentModelIcon::Named(
2358 ui::IconName::ZedAssistant
2359 )),
2360 is_latest: false,
2361 cost: None,
2362 }]
2363 )])
2364 );
2365 }
2366
2367 #[gpui::test]
2368 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2369 init_test(cx);
2370 let fs = FakeFs::new(cx.executor());
2371 fs.create_dir(paths::settings_file().parent().unwrap())
2372 .await
2373 .unwrap();
2374 fs.insert_file(
2375 paths::settings_file(),
2376 json!({
2377 "agent": {
2378 "default_model": {
2379 "provider": "foo",
2380 "model": "bar"
2381 }
2382 }
2383 })
2384 .to_string()
2385 .into_bytes(),
2386 )
2387 .await;
2388 let project = Project::test(fs.clone(), [], cx).await;
2389
2390 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2391
2392 // Create the agent and connection
2393 let agent =
2394 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2395 let connection = NativeAgentConnection(agent.clone());
2396
2397 // Create a thread/session
2398 let acp_thread = cx
2399 .update(|cx| {
2400 Rc::new(connection.clone()).new_session(
2401 project.clone(),
2402 PathList::new(&[Path::new("/a")]),
2403 cx,
2404 )
2405 })
2406 .await
2407 .unwrap();
2408
2409 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2410
2411 // Select a model
2412 let selector = connection.model_selector(&session_id).unwrap();
2413 let model_id = acp::ModelId::new("fake/fake");
2414 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2415 .await
2416 .unwrap();
2417
2418 // Verify the thread has the selected model
2419 agent.read_with(cx, |agent, _| {
2420 let session = agent.sessions.get(&session_id).unwrap();
2421 session.thread.read_with(cx, |thread, _| {
2422 assert_eq!(thread.model().unwrap().id().0, "fake");
2423 });
2424 });
2425
2426 cx.run_until_parked();
2427
2428 // Verify settings file was updated
2429 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2430 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2431
2432 // Check that the agent settings contain the selected model
2433 assert_eq!(
2434 settings_json["agent"]["default_model"]["model"],
2435 json!("fake")
2436 );
2437 assert_eq!(
2438 settings_json["agent"]["default_model"]["provider"],
2439 json!("fake")
2440 );
2441
2442 // Register a thinking model and select it.
2443 cx.update(|cx| {
2444 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2445 "fake-corp",
2446 "fake-thinking",
2447 "Fake Thinking",
2448 true,
2449 ));
2450 let thinking_provider = Arc::new(
2451 FakeLanguageModelProvider::new(
2452 LanguageModelProviderId::from("fake-corp".to_string()),
2453 LanguageModelProviderName::from("Fake Corp".to_string()),
2454 )
2455 .with_models(vec![thinking_model]),
2456 );
2457 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2458 registry.register_provider(thinking_provider, cx);
2459 });
2460 });
2461 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2462
2463 let selector = connection.model_selector(&session_id).unwrap();
2464 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2465 .await
2466 .unwrap();
2467 cx.run_until_parked();
2468
2469 // Verify enable_thinking was written to settings as true.
2470 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2471 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2472 assert_eq!(
2473 settings_json["agent"]["default_model"]["enable_thinking"],
2474 json!(true),
2475 "selecting a thinking model should persist enable_thinking: true to settings"
2476 );
2477 }
2478
2479 #[gpui::test]
2480 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2481 init_test(cx);
2482 let fs = FakeFs::new(cx.executor());
2483 fs.create_dir(paths::settings_file().parent().unwrap())
2484 .await
2485 .unwrap();
2486 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2487 let project = Project::test(fs.clone(), [], cx).await;
2488
2489 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2490 let agent =
2491 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2492 let connection = NativeAgentConnection(agent.clone());
2493
2494 let acp_thread = cx
2495 .update(|cx| {
2496 Rc::new(connection.clone()).new_session(
2497 project.clone(),
2498 PathList::new(&[Path::new("/a")]),
2499 cx,
2500 )
2501 })
2502 .await
2503 .unwrap();
2504 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2505
2506 // Register a second provider with a thinking model.
2507 cx.update(|cx| {
2508 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2509 "fake-corp",
2510 "fake-thinking",
2511 "Fake Thinking",
2512 true,
2513 ));
2514 let thinking_provider = Arc::new(
2515 FakeLanguageModelProvider::new(
2516 LanguageModelProviderId::from("fake-corp".to_string()),
2517 LanguageModelProviderName::from("Fake Corp".to_string()),
2518 )
2519 .with_models(vec![thinking_model]),
2520 );
2521 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2522 registry.register_provider(thinking_provider, cx);
2523 });
2524 });
2525 // Refresh the agent's model list so it picks up the new provider.
2526 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2527
2528 // Thread starts with thinking_enabled = false (the default).
2529 agent.read_with(cx, |agent, _| {
2530 let session = agent.sessions.get(&session_id).unwrap();
2531 session.thread.read_with(cx, |thread, _| {
2532 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2533 });
2534 });
2535
2536 // Select the thinking model via select_model.
2537 let selector = connection.model_selector(&session_id).unwrap();
2538 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2539 .await
2540 .unwrap();
2541
2542 // select_model should have enabled thinking based on the model's supports_thinking().
2543 agent.read_with(cx, |agent, _| {
2544 let session = agent.sessions.get(&session_id).unwrap();
2545 session.thread.read_with(cx, |thread, _| {
2546 assert!(
2547 thread.thinking_enabled(),
2548 "select_model should enable thinking when model supports it"
2549 );
2550 });
2551 });
2552
2553 // Switch back to the non-thinking model.
2554 let selector = connection.model_selector(&session_id).unwrap();
2555 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2556 .await
2557 .unwrap();
2558
2559 // select_model should have disabled thinking.
2560 agent.read_with(cx, |agent, _| {
2561 let session = agent.sessions.get(&session_id).unwrap();
2562 session.thread.read_with(cx, |thread, _| {
2563 assert!(
2564 !thread.thinking_enabled(),
2565 "select_model should disable thinking when model does not support it"
2566 );
2567 });
2568 });
2569 }
2570
2571 #[gpui::test]
2572 async fn test_summarization_model_survives_transient_registry_clearing(
2573 cx: &mut TestAppContext,
2574 ) {
2575 init_test(cx);
2576 let fs = FakeFs::new(cx.executor());
2577 fs.insert_tree("/", json!({ "a": {} })).await;
2578 let project = Project::test(fs.clone(), [], cx).await;
2579
2580 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2581 let agent =
2582 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2583 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2584
2585 let acp_thread = cx
2586 .update(|cx| {
2587 connection.clone().new_session(
2588 project.clone(),
2589 PathList::new(&[Path::new("/a")]),
2590 cx,
2591 )
2592 })
2593 .await
2594 .unwrap();
2595 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2596
2597 let thread = agent.read_with(cx, |agent, _| {
2598 agent.sessions.get(&session_id).unwrap().thread.clone()
2599 });
2600
2601 thread.read_with(cx, |thread, _| {
2602 assert!(
2603 thread.summarization_model().is_some(),
2604 "session should have a summarization model from the test registry"
2605 );
2606 });
2607
2608 // Simulate what happens during a provider blip:
2609 // update_active_language_model_from_settings calls set_default_model(None)
2610 // when it can't resolve the model, clearing all fallbacks.
2611 cx.update(|cx| {
2612 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2613 registry.set_default_model(None, cx);
2614 });
2615 });
2616 cx.run_until_parked();
2617
2618 thread.read_with(cx, |thread, _| {
2619 assert!(
2620 thread.summarization_model().is_some(),
2621 "summarization model should survive a transient default model clearing"
2622 );
2623 });
2624 }
2625
2626 #[gpui::test]
2627 async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2628 init_test(cx);
2629 let fs = FakeFs::new(cx.executor());
2630 fs.insert_tree("/", json!({ "a": {} })).await;
2631 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2632 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2633 let agent = cx.update(|cx| {
2634 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2635 });
2636 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2637
2638 // Register a thinking model.
2639 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2640 "fake-corp",
2641 "fake-thinking",
2642 "Fake Thinking",
2643 true,
2644 ));
2645 let thinking_provider = Arc::new(
2646 FakeLanguageModelProvider::new(
2647 LanguageModelProviderId::from("fake-corp".to_string()),
2648 LanguageModelProviderName::from("Fake Corp".to_string()),
2649 )
2650 .with_models(vec![thinking_model.clone()]),
2651 );
2652 cx.update(|cx| {
2653 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2654 registry.register_provider(thinking_provider, cx);
2655 });
2656 });
2657 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2658
2659 // Create a thread and select the thinking model.
2660 let acp_thread = cx
2661 .update(|cx| {
2662 connection.clone().new_session(
2663 project.clone(),
2664 PathList::new(&[Path::new("/a")]),
2665 cx,
2666 )
2667 })
2668 .await
2669 .unwrap();
2670 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2671
2672 let selector = connection.model_selector(&session_id).unwrap();
2673 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2674 .await
2675 .unwrap();
2676
2677 // Verify thinking is enabled after selecting the thinking model.
2678 let thread = agent.read_with(cx, |agent, _| {
2679 agent.sessions.get(&session_id).unwrap().thread.clone()
2680 });
2681 thread.read_with(cx, |thread, _| {
2682 assert!(
2683 thread.thinking_enabled(),
2684 "thinking should be enabled after selecting thinking model"
2685 );
2686 });
2687
2688 // Send a message so the thread gets persisted.
2689 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2690 let send = cx.foreground_executor().spawn(send);
2691 cx.run_until_parked();
2692
2693 thinking_model.send_last_completion_stream_text_chunk("Response.");
2694 thinking_model.end_last_completion_stream();
2695
2696 send.await.unwrap();
2697 cx.run_until_parked();
2698
2699 // Close the session so it can be reloaded from disk.
2700 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2701 .await
2702 .unwrap();
2703 drop(thread);
2704 drop(acp_thread);
2705 agent.read_with(cx, |agent, _| {
2706 assert!(agent.sessions.is_empty());
2707 });
2708
2709 // Reload the thread and verify thinking_enabled is still true.
2710 let reloaded_acp_thread = agent
2711 .update(cx, |agent, cx| {
2712 agent.open_thread(session_id.clone(), project.clone(), cx)
2713 })
2714 .await
2715 .unwrap();
2716 let reloaded_thread = agent.read_with(cx, |agent, _| {
2717 agent.sessions.get(&session_id).unwrap().thread.clone()
2718 });
2719 reloaded_thread.read_with(cx, |thread, _| {
2720 assert!(
2721 thread.thinking_enabled(),
2722 "thinking_enabled should be preserved when reloading a thread with a thinking model"
2723 );
2724 });
2725
2726 drop(reloaded_acp_thread);
2727 }
2728
2729 #[gpui::test]
2730 async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2731 init_test(cx);
2732 let fs = FakeFs::new(cx.executor());
2733 fs.insert_tree("/", json!({ "a": {} })).await;
2734 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2735 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2736 let agent = cx.update(|cx| {
2737 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2738 });
2739 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2740
2741 // Register a model where id() != name(), like real Anthropic models
2742 // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2743 let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2744 "fake-corp",
2745 "custom-model-id",
2746 "Custom Model Display Name",
2747 false,
2748 ));
2749 let provider = Arc::new(
2750 FakeLanguageModelProvider::new(
2751 LanguageModelProviderId::from("fake-corp".to_string()),
2752 LanguageModelProviderName::from("Fake Corp".to_string()),
2753 )
2754 .with_models(vec![model.clone()]),
2755 );
2756 cx.update(|cx| {
2757 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2758 registry.register_provider(provider, cx);
2759 });
2760 });
2761 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2762
2763 // Create a thread and select the model.
2764 let acp_thread = cx
2765 .update(|cx| {
2766 connection.clone().new_session(
2767 project.clone(),
2768 PathList::new(&[Path::new("/a")]),
2769 cx,
2770 )
2771 })
2772 .await
2773 .unwrap();
2774 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2775
2776 let selector = connection.model_selector(&session_id).unwrap();
2777 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2778 .await
2779 .unwrap();
2780
2781 let thread = agent.read_with(cx, |agent, _| {
2782 agent.sessions.get(&session_id).unwrap().thread.clone()
2783 });
2784 thread.read_with(cx, |thread, _| {
2785 assert_eq!(
2786 thread.model().unwrap().id().0.as_ref(),
2787 "custom-model-id",
2788 "model should be set before persisting"
2789 );
2790 });
2791
2792 // Send a message so the thread gets persisted.
2793 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2794 let send = cx.foreground_executor().spawn(send);
2795 cx.run_until_parked();
2796
2797 model.send_last_completion_stream_text_chunk("Response.");
2798 model.end_last_completion_stream();
2799
2800 send.await.unwrap();
2801 cx.run_until_parked();
2802
2803 // Close the session so it can be reloaded from disk.
2804 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2805 .await
2806 .unwrap();
2807 drop(thread);
2808 drop(acp_thread);
2809 agent.read_with(cx, |agent, _| {
2810 assert!(agent.sessions.is_empty());
2811 });
2812
2813 // Reload the thread and verify the model was preserved.
2814 let reloaded_acp_thread = agent
2815 .update(cx, |agent, cx| {
2816 agent.open_thread(session_id.clone(), project.clone(), cx)
2817 })
2818 .await
2819 .unwrap();
2820 let reloaded_thread = agent.read_with(cx, |agent, _| {
2821 agent.sessions.get(&session_id).unwrap().thread.clone()
2822 });
2823 reloaded_thread.read_with(cx, |thread, _| {
2824 let reloaded_model = thread
2825 .model()
2826 .expect("model should be present after reload");
2827 assert_eq!(
2828 reloaded_model.id().0.as_ref(),
2829 "custom-model-id",
2830 "reloaded thread should have the same model, not fall back to the default"
2831 );
2832 });
2833
2834 drop(reloaded_acp_thread);
2835 }
2836
    // Round-trips a thread through close and reopen. Checks that an empty thread is not
    // listed by the store, and that after one exchange the reopened thread has the same
    // markdown transcript, draft prompt, token usage and scroll position as before.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Grab the agent-side thread entity that belongs to the session just created.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message made of text plus a resource link that mentions `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake main model: one text chunk, one usage update, then end the stream.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // Drive the fake summarization model; its text is what the store entry assertion
        // further down expects as the second tuple element.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        // The transcript before closing; the same text is asserted again after reload.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
3018
3019 #[gpui::test]
3020 async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
3021 init_test(cx);
3022 let fs = FakeFs::new(cx.executor());
3023 fs.insert_tree(
3024 "/",
3025 json!({
3026 "a": {
3027 "file.txt": "hello"
3028 }
3029 }),
3030 )
3031 .await;
3032 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
3033 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3034 let agent = cx.update(|cx| {
3035 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3036 });
3037 let connection = Rc::new(NativeAgentConnection(agent.clone()));
3038
3039 let acp_thread = cx
3040 .update(|cx| {
3041 connection
3042 .clone()
3043 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3044 })
3045 .await
3046 .unwrap();
3047 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3048 let thread = agent.read_with(cx, |agent, _| {
3049 agent.sessions.get(&session_id).unwrap().thread.clone()
3050 });
3051
3052 let model = Arc::new(FakeLanguageModel::default());
3053 thread.update(cx, |thread, cx| {
3054 thread.set_model(model.clone(), cx);
3055 });
3056
3057 // Send a message so the thread is non-empty (empty threads aren't saved).
3058 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
3059 let send = cx.foreground_executor().spawn(send);
3060 cx.run_until_parked();
3061
3062 model.send_last_completion_stream_text_chunk("world");
3063 model.end_last_completion_stream();
3064 send.await.unwrap();
3065 cx.run_until_parked();
3066
3067 // Set a draft prompt WITHOUT calling run_until_parked afterwards.
3068 // This means no observe-triggered save has run for this change.
3069 // The only way this data gets persisted is if close_session
3070 // itself performs the save.
3071 let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
3072 "unsaved draft",
3073 ))];
3074 acp_thread.update(cx, |thread, cx| {
3075 thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
3076 });
3077
3078 // Close the session immediately — no run_until_parked in between.
3079 cx.update(|cx| connection.clone().close_session(&session_id, cx))
3080 .await
3081 .unwrap();
3082 cx.run_until_parked();
3083
3084 // Reopen and verify the draft prompt was saved.
3085 let reloaded = agent
3086 .update(cx, |agent, cx| {
3087 agent.open_thread(session_id.clone(), project.clone(), cx)
3088 })
3089 .await
3090 .unwrap();
3091 reloaded.read_with(cx, |thread, _| {
3092 assert_eq!(
3093 thread.draft_prompt(),
3094 Some(draft_blocks.as_slice()),
3095 "close_session must save the thread; draft prompt was lost"
3096 );
3097 });
3098 }
3099
3100 #[gpui::test]
3101 async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
3102 init_test(cx);
3103 let fs = FakeFs::new(cx.executor());
3104 fs.insert_tree(
3105 "/",
3106 json!({
3107 "a": {
3108 "file.txt": "hello"
3109 }
3110 }),
3111 )
3112 .await;
3113 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
3114 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3115 let agent = cx.update(|cx| {
3116 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3117 });
3118 let connection = Rc::new(NativeAgentConnection(agent.clone()));
3119
3120 let acp_thread = cx
3121 .update(|cx| {
3122 connection
3123 .clone()
3124 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3125 })
3126 .await
3127 .unwrap();
3128 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3129 let thread = agent.read_with(cx, |agent, _| {
3130 agent.sessions.get(&session_id).unwrap().thread.clone()
3131 });
3132
3133 let model = Arc::new(FakeLanguageModel::default());
3134 let summary_model = Arc::new(FakeLanguageModel::default());
3135 thread.update(cx, |thread, cx| {
3136 thread.set_model(model.clone(), cx);
3137 thread.set_summarization_model(Some(summary_model.clone()), cx);
3138 });
3139
3140 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
3141 let send = cx.foreground_executor().spawn(send);
3142 cx.run_until_parked();
3143
3144 model.send_last_completion_stream_text_chunk("world");
3145 model.end_last_completion_stream();
3146 send.await.unwrap();
3147 cx.run_until_parked();
3148
3149 let summary = agent.update(cx, |agent, cx| {
3150 agent.thread_summary(session_id.clone(), project.clone(), cx)
3151 });
3152 cx.run_until_parked();
3153
3154 summary_model.send_last_completion_stream_text_chunk("summary");
3155 summary_model.end_last_completion_stream();
3156
3157 assert_eq!(summary.await.unwrap(), "summary");
3158 cx.run_until_parked();
3159
3160 agent.read_with(cx, |agent, _| {
3161 let session = agent
3162 .sessions
3163 .get(&session_id)
3164 .expect("thread_summary should not close the active session");
3165 assert_eq!(
3166 session.ref_count, 1,
3167 "thread_summary should release its temporary session reference"
3168 );
3169 });
3170
3171 cx.update(|cx| connection.clone().close_session(&session_id, cx))
3172 .await
3173 .unwrap();
3174 cx.run_until_parked();
3175
3176 agent.read_with(cx, |agent, _| {
3177 assert!(
3178 agent.sessions.is_empty(),
3179 "closing the active session after thread_summary should unload it"
3180 );
3181 });
3182 }
3183
3184 #[gpui::test]
3185 async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
3186 init_test(cx);
3187 let fs = FakeFs::new(cx.executor());
3188 fs.insert_tree(
3189 "/",
3190 json!({
3191 "a": {
3192 "file.txt": "hello"
3193 }
3194 }),
3195 )
3196 .await;
3197 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
3198 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3199 let agent = cx.update(|cx| {
3200 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3201 });
3202 let connection = Rc::new(NativeAgentConnection(agent.clone()));
3203
3204 let acp_thread = cx
3205 .update(|cx| {
3206 connection
3207 .clone()
3208 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3209 })
3210 .await
3211 .unwrap();
3212 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3213 let thread = agent.read_with(cx, |agent, _| {
3214 agent.sessions.get(&session_id).unwrap().thread.clone()
3215 });
3216
3217 let model = cx.update(|cx| {
3218 LanguageModelRegistry::read_global(cx)
3219 .default_model()
3220 .map(|default_model| default_model.model)
3221 .expect("default test model should be available")
3222 });
3223 let fake_model = model.as_fake();
3224 thread.update(cx, |thread, cx| {
3225 thread.set_model(model.clone(), cx);
3226 });
3227
3228 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
3229 let send = cx.foreground_executor().spawn(send);
3230 cx.run_until_parked();
3231
3232 fake_model.send_last_completion_stream_text_chunk("world");
3233 fake_model.end_last_completion_stream();
3234 send.await.unwrap();
3235 cx.run_until_parked();
3236
3237 cx.update(|cx| connection.clone().close_session(&session_id, cx))
3238 .await
3239 .unwrap();
3240 drop(thread);
3241 drop(acp_thread);
3242 agent.read_with(cx, |agent, _| {
3243 assert!(agent.sessions.is_empty());
3244 });
3245
3246 let first_loaded_thread = cx.update(|cx| {
3247 connection.clone().load_session(
3248 session_id.clone(),
3249 project.clone(),
3250 PathList::new(&[Path::new("")]),
3251 None,
3252 cx,
3253 )
3254 });
3255 let second_loaded_thread = cx.update(|cx| {
3256 connection.clone().load_session(
3257 session_id.clone(),
3258 project.clone(),
3259 PathList::new(&[Path::new("")]),
3260 None,
3261 cx,
3262 )
3263 });
3264
3265 let first_loaded_thread = first_loaded_thread.await.unwrap();
3266 let second_loaded_thread = second_loaded_thread.await.unwrap();
3267
3268 cx.run_until_parked();
3269
3270 assert_eq!(
3271 first_loaded_thread.entity_id(),
3272 second_loaded_thread.entity_id(),
3273 "concurrent loads for the same session should share one AcpThread"
3274 );
3275
3276 cx.update(|cx| connection.clone().close_session(&session_id, cx))
3277 .await
3278 .unwrap();
3279
3280 agent.read_with(cx, |agent, _| {
3281 assert!(
3282 agent.sessions.contains_key(&session_id),
3283 "closing one loaded session should not drop shared session state"
3284 );
3285 });
3286
3287 let follow_up = second_loaded_thread.update(cx, |thread, cx| {
3288 thread.send(vec!["still there?".into()], cx)
3289 });
3290 let follow_up = cx.foreground_executor().spawn(follow_up);
3291 cx.run_until_parked();
3292
3293 fake_model.send_last_completion_stream_text_chunk("yes");
3294 fake_model.end_last_completion_stream();
3295 follow_up.await.unwrap();
3296 cx.run_until_parked();
3297
3298 second_loaded_thread.read_with(cx, |thread, cx| {
3299 assert_eq!(
3300 thread.to_markdown(cx),
3301 formatdoc! {"
3302 ## User
3303
3304 hello
3305
3306 ## Assistant
3307
3308 world
3309
3310 ## User
3311
3312 still there?
3313
3314 ## Assistant
3315
3316 yes
3317
3318 "}
3319 );
3320 });
3321
3322 cx.update(|cx| connection.clone().close_session(&session_id, cx))
3323 .await
3324 .unwrap();
3325
3326 cx.run_until_parked();
3327
3328 drop(first_loaded_thread);
3329 drop(second_loaded_thread);
3330 agent.read_with(cx, |agent, _| {
3331 assert!(agent.sessions.is_empty());
3332 });
3333 }
3334
3335 #[gpui::test]
3336 async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
3337 // Regression test: rapid title changes must not cause a propagation loop
3338 // between Thread and AcpThread via handle_thread_title_updated.
3339 init_test(cx);
3340 let fs = FakeFs::new(cx.executor());
3341 fs.insert_tree("/", json!({ "a": {} })).await;
3342 let project = Project::test(fs.clone(), [], cx).await;
3343 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3344 let agent = cx.update(|cx| {
3345 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3346 });
3347 let connection = Rc::new(NativeAgentConnection(agent.clone()));
3348
3349 let acp_thread = cx
3350 .update(|cx| {
3351 connection
3352 .clone()
3353 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3354 })
3355 .await
3356 .unwrap();
3357
3358 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3359 let thread = agent.read_with(cx, |agent, _| {
3360 agent.sessions.get(&session_id).unwrap().thread.clone()
3361 });
3362
3363 let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
3364 cx.update(|cx| {
3365 let count = title_updated_count.clone();
3366 cx.subscribe(
3367 &thread,
3368 move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
3369 let new_count = {
3370 let mut count = count.borrow_mut();
3371 *count += 1;
3372 *count
3373 };
3374 assert!(
3375 new_count <= 2,
3376 "TitleUpdated fired {new_count} times; \
3377 title updates are looping"
3378 );
3379 },
3380 )
3381 .detach();
3382 });
3383
3384 thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
3385 thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));
3386
3387 cx.run_until_parked();
3388
3389 thread.read_with(cx, |thread, _| {
3390 assert_eq!(thread.title(), Some("second".into()));
3391 });
3392 acp_thread.read_with(cx, |acp_thread, _| {
3393 assert_eq!(acp_thread.title(), Some("second".into()));
3394 });
3395
3396 assert_eq!(*title_updated_count.borrow(), 2);
3397 }
3398
3399 fn thread_entries(
3400 thread_store: &Entity<ThreadStore>,
3401 cx: &mut TestAppContext,
3402 ) -> Vec<(acp::SessionId, String)> {
3403 thread_store.read_with(cx, |store, _| {
3404 store
3405 .entries()
3406 .map(|entry| (entry.id.clone(), entry.title.to_string()))
3407 .collect::<Vec<_>>()
3408 })
3409 }
3410
3411 fn init_test(cx: &mut TestAppContext) {
3412 env_logger::try_init().ok();
3413 cx.update(|cx| {
3414 let settings_store = SettingsStore::test(cx);
3415 cx.set_global(settings_store);
3416
3417 LanguageModelRegistry::test(cx);
3418 });
3419 }
3420}
3421
3422fn mcp_message_content_to_acp_content_block(
3423 content: context_server::types::MessageContent,
3424) -> acp::ContentBlock {
3425 match content {
3426 context_server::types::MessageContent::Text {
3427 text,
3428 annotations: _,
3429 } => text.into(),
3430 context_server::types::MessageContent::Image {
3431 data,
3432 mime_type,
3433 annotations: _,
3434 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3435 context_server::types::MessageContent::Audio {
3436 data,
3437 mime_type,
3438 annotations: _,
3439 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3440 context_server::types::MessageContent::Resource {
3441 resource,
3442 annotations: _,
3443 } => {
3444 let mut link =
3445 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3446 if let Some(mime_type) = resource.mime_type {
3447 link = link.mime_type(mime_type);
3448 }
3449 acp::ContentBlock::ResourceLink(link)
3450 }
3451 }
3452}