1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// A point-in-time snapshot of a project's worktrees.
///
/// Derives `Serialize`/`Deserialize`; presumably it is persisted alongside thread data —
/// NOTE(review): confirm where it is stored.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken (UTC).
    pub timestamp: DateTime<Utc>,
}
64
/// Error describing why a worktree rules file could not be loaded.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt`, which formats the
/// underlying error into `message`.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
68
/// Per-project state shared by every session opened against the same project.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// System-prompt context (worktrees, rules files, default user rules) for the project.
    project_context: Entity<ProjectContext>,
    /// Signalling this asks the maintenance task to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds `project_context` whenever a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of the tools/prompts exposed by the project's context servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store and registry events.
    _subscriptions: Vec<Subscription>,
}
77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session belongs to.
    project_id: EntityId,
    /// The most recent save started by `save_thread`. `close_session` returns it so the
    /// caller can wait for the final save.
    pending_save: Task<Result<()>>,
    /// Title/token-usage subscriptions and the observer that saves the thread on change.
    _subscriptions: Vec<Subscription>,
    /// Number of outstanding opens. `close_session` removes the session when it reaches zero.
    ref_count: usize,
}
89
/// A session whose thread is still being loaded from the database.
///
/// Concurrent `open_thread` calls for the same id await the same shared `task` instead of
/// loading the thread twice.
struct PendingSession {
    /// Shared load task. The error is wrapped in `Arc` so the result can be cloned to every waiter.
    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
    /// Number of `open_thread` callers waiting on `task`. It becomes the session's initial
    /// `ref_count` once loading finishes.
    ref_count: usize,
}
94
/// Cache of the language models offered by authenticated providers, in ACP form.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out by `watch()`; it observes every list refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of each list refresh.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every visible provider once at startup.
    _authenticate_all_providers_task: Task<()>,
}
104
105impl LanguageModels {
106 fn new(cx: &mut App) -> Self {
107 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
108
109 let mut this = Self {
110 models: HashMap::default(),
111 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
112 refresh_models_rx,
113 refresh_models_tx,
114 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
115 };
116 this.refresh_list(cx);
117 this
118 }
119
120 fn refresh_list(&mut self, cx: &App) {
121 let providers = LanguageModelRegistry::global(cx)
122 .read(cx)
123 .visible_providers()
124 .into_iter()
125 .filter(|provider| provider.is_authenticated(cx))
126 .collect::<Vec<_>>();
127
128 let mut language_model_list = IndexMap::default();
129 let mut recommended_models = HashSet::default();
130
131 let mut recommended = Vec::new();
132 for provider in &providers {
133 for model in provider.recommended_models(cx) {
134 recommended_models.insert((model.provider_id(), model.id()));
135 recommended.push(Self::map_language_model_to_info(&model, provider));
136 }
137 }
138 if !recommended.is_empty() {
139 language_model_list.insert(
140 acp_thread::AgentModelGroupName("Recommended".into()),
141 recommended,
142 );
143 }
144
145 let mut models = HashMap::default();
146 for provider in providers {
147 let mut provider_models = Vec::new();
148 for model in provider.provided_models(cx) {
149 let model_info = Self::map_language_model_to_info(&model, &provider);
150 let model_id = model_info.id.clone();
151 provider_models.push(model_info);
152 models.insert(model_id, model);
153 }
154 if !provider_models.is_empty() {
155 language_model_list.insert(
156 acp_thread::AgentModelGroupName(provider.name().0.clone()),
157 provider_models,
158 );
159 }
160 }
161
162 self.models = models;
163 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
164 self.refresh_models_tx.send(()).ok();
165 }
166
167 fn watch(&self) -> watch::Receiver<()> {
168 self.refresh_models_rx.clone()
169 }
170
171 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
172 self.models.get(model_id).cloned()
173 }
174
175 fn map_language_model_to_info(
176 model: &Arc<dyn LanguageModel>,
177 provider: &Arc<dyn LanguageModelProvider>,
178 ) -> acp_thread::AgentModelInfo {
179 acp_thread::AgentModelInfo {
180 id: Self::model_id(model),
181 name: model.name().0,
182 description: None,
183 icon: Some(match provider.icon() {
184 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
185 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
186 }),
187 is_latest: model.is_latest(),
188 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
189 }
190 }
191
192 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
193 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
194 }
195
196 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
197 let authenticate_all_providers = LanguageModelRegistry::global(cx)
198 .read(cx)
199 .visible_providers()
200 .iter()
201 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
202 .collect::<Vec<_>>();
203
204 cx.background_spawn(async move {
205 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
206 if let Err(err) = authenticate_task.await {
207 match err {
208 language_model::AuthenticateError::CredentialsNotFound => {
209 // Since we're authenticating these providers in the
210 // background for the purposes of populating the
211 // language selector, we don't care about providers
212 // where the credentials are not found.
213 }
214 language_model::AuthenticateError::ConnectionRefused => {
215 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
216 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
217 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
218 }
219 _ => {
220 // Some providers have noisy failure states that we
221 // don't want to spam the logs with every time the
222 // language model selector is initialized.
223 //
224 // Ideally these should have more clear failure modes
225 // that we know are safe to ignore here, like what we do
226 // with `CredentialsNotFound` above.
227 match provider_id.0.as_ref() {
228 "lmstudio" | "ollama" => {
229 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230 //
231 // These fail noisily, so we don't log them.
232 }
233 "copilot_chat" => {
234 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235 }
236 _ => {
237 log::error!(
238 "Failed to authenticate provider: {}: {err:#}",
239 provider_name.0
240 );
241 }
242 }
243 }
244 }
245 }
246 }
247 })
248 }
249}
250
/// Agent that runs threads natively in-process and exposes them as ACP sessions.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Sessions whose threads are still being loaded from the database, keyed by session id.
    pending_sessions: HashMap<acp::SessionId, PendingSession>,
    /// Thread store that is reloaded after every successful save attempt.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional source of default user rules; its update events trigger a context refresh.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle (its use is not visible in this section — NOTE(review)).
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language-model registry and (if present) the prompt store.
    _subscriptions: Vec<Subscription>,
}
266
267impl NativeAgent {
268 pub fn new(
269 thread_store: Entity<ThreadStore>,
270 templates: Arc<Templates>,
271 prompt_store: Option<Entity<PromptStore>>,
272 fs: Arc<dyn Fs>,
273 cx: &mut App,
274 ) -> Entity<NativeAgent> {
275 log::debug!("Creating new NativeAgent");
276
277 cx.new(|cx| {
278 let mut subscriptions = vec![cx.subscribe(
279 &LanguageModelRegistry::global(cx),
280 Self::handle_models_updated_event,
281 )];
282 if let Some(prompt_store) = prompt_store.as_ref() {
283 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
284 }
285
286 Self {
287 sessions: HashMap::default(),
288 pending_sessions: HashMap::default(),
289 thread_store,
290 projects: HashMap::default(),
291 templates,
292 models: LanguageModels::new(cx),
293 prompt_store,
294 fs,
295 _subscriptions: subscriptions,
296 }
297 })
298 }
299
300 fn new_session(
301 &mut self,
302 project: Entity<Project>,
303 cx: &mut Context<Self>,
304 ) -> Entity<AcpThread> {
305 let project_id = self.get_or_create_project_state(&project, cx);
306 let project_state = &self.projects[&project_id];
307
308 let registry = LanguageModelRegistry::read_global(cx);
309 let available_count = registry.available_models(cx).count();
310 log::debug!("Total available models: {}", available_count);
311
312 let default_model = registry.default_model().and_then(|default_model| {
313 self.models
314 .model_from_id(&LanguageModels::model_id(&default_model.model))
315 });
316 let thread = cx.new(|cx| {
317 Thread::new(
318 project,
319 project_state.project_context.clone(),
320 project_state.context_server_registry.clone(),
321 self.templates.clone(),
322 default_model,
323 cx,
324 )
325 });
326
327 self.register_session(thread, project_id, 1, cx)
328 }
329
    /// Wraps `thread_handle` in a new [`AcpThread`] and records both as a session.
    ///
    /// The ACP thread is seeded from the internal thread: parent id, title, draft prompt,
    /// scroll position and latest token usage. The internal thread then gets the registry's
    /// thread-summary model and its default tools, which receive a `NativeThreadEnvironment`
    /// with weak handles back to both threads and this agent. Title and token-usage events are
    /// forwarded to the ACP thread, and any change notification on the thread triggers
    /// `save_thread`. Finally the session is stored with the given `ref_count` and the project's
    /// available slash commands are pushed to its sessions.
    ///
    /// Returns the new ACP thread.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        ref_count: usize,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything needed from the thread up front, so the read borrow of `cx` ends
        // before entities are created below.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            acp_thread.set_draft_prompt(draft_prompt, cx);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // Weak handles avoid reference cycles between the tools' environment and its owners.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread every time it notifies observers of a change.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
                ref_count,
            },
        );

        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
408
    /// Returns the agent's cached language-model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
412
413 fn get_or_create_project_state(
414 &mut self,
415 project: &Entity<Project>,
416 cx: &mut Context<Self>,
417 ) -> EntityId {
418 let project_id = project.entity_id();
419 if self.projects.contains_key(&project_id) {
420 return project_id;
421 }
422
423 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
424 self.register_project_with_initial_context(project.clone(), project_context, cx);
425 if let Some(state) = self.projects.get_mut(&project_id) {
426 state.project_context_needs_refresh.send(()).ok();
427 }
428 project_id
429 }
430
431 fn register_project_with_initial_context(
432 &mut self,
433 project: Entity<Project>,
434 project_context: Entity<ProjectContext>,
435 cx: &mut Context<Self>,
436 ) {
437 let project_id = project.entity_id();
438
439 let context_server_store = project.read(cx).context_server_store();
440 let context_server_registry =
441 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
442
443 let subscriptions = vec![
444 cx.subscribe(&project, Self::handle_project_event),
445 cx.subscribe(
446 &context_server_store,
447 Self::handle_context_server_store_updated,
448 ),
449 cx.subscribe(
450 &context_server_registry,
451 Self::handle_context_server_registry_event,
452 ),
453 ];
454
455 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
456 watch::channel(());
457
458 self.projects.insert(
459 project_id,
460 ProjectState {
461 project,
462 project_context,
463 project_context_needs_refresh: project_context_needs_refresh_tx,
464 _maintain_project_context: cx.spawn(async move |this, cx| {
465 Self::maintain_project_context(
466 this,
467 project_id,
468 project_context_needs_refresh_rx,
469 cx,
470 )
471 .await
472 }),
473 context_server_registry,
474 _subscriptions: subscriptions,
475 },
476 );
477 }
478
479 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
480 self.sessions
481 .get(session_id)
482 .and_then(|session| self.projects.get(&session.project_id))
483 }
484
485 async fn maintain_project_context(
486 this: WeakEntity<Self>,
487 project_id: EntityId,
488 mut needs_refresh: watch::Receiver<()>,
489 cx: &mut AsyncApp,
490 ) -> Result<()> {
491 while needs_refresh.changed().await.is_ok() {
492 let project_context = this
493 .update(cx, |this, cx| {
494 let state = this
495 .projects
496 .get(&project_id)
497 .context("project state not found")?;
498 anyhow::Ok(Self::build_project_context(
499 &state.project,
500 this.prompt_store.as_ref(),
501 cx,
502 ))
503 })??
504 .await;
505 this.update(cx, |this, cx| {
506 if let Some(state) = this.projects.get(&project_id) {
507 state
508 .project_context
509 .update(cx, |current_project_context, _cx| {
510 *current_project_context = project_context;
511 });
512 }
513 })?;
514 }
515
516 Ok(())
517 }
518
519 fn build_project_context(
520 project: &Entity<Project>,
521 prompt_store: Option<&Entity<PromptStore>>,
522 cx: &mut App,
523 ) -> Task<ProjectContext> {
524 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
525 let worktree_tasks = worktrees
526 .into_iter()
527 .map(|worktree| {
528 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
529 })
530 .collect::<Vec<_>>();
531 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
532 prompt_store.read_with(cx, |prompt_store, cx| {
533 let prompts = prompt_store.default_prompt_metadata();
534 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
535 let contents = prompt_store.load(prompt_metadata.id, cx);
536 async move { (contents.await, prompt_metadata) }
537 });
538 cx.background_spawn(future::join_all(load_tasks))
539 })
540 } else {
541 Task::ready(vec![])
542 };
543
544 cx.spawn(async move |_cx| {
545 let (worktrees, default_user_rules) =
546 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
547
548 let worktrees = worktrees
549 .into_iter()
550 .map(|(worktree, _rules_error)| {
551 // TODO: show error message
552 // if let Some(rules_error) = rules_error {
553 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
554 // }
555 worktree
556 })
557 .collect::<Vec<_>>();
558
559 let default_user_rules = default_user_rules
560 .into_iter()
561 .flat_map(|(contents, prompt_metadata)| match contents {
562 Ok(contents) => Some(UserRulesContext {
563 uuid: prompt_metadata.id.as_user()?,
564 title: prompt_metadata.title.map(|title| title.to_string()),
565 contents,
566 }),
567 Err(_err) => {
568 // TODO: show error message
569 // this.update(cx, |_, cx| {
570 // cx.emit(RulesLoadingError {
571 // message: format!("{err:?}").into(),
572 // });
573 // })
574 // .ok();
575 None
576 }
577 })
578 .collect::<Vec<_>>();
579
580 ProjectContext::new(worktrees, default_user_rules)
581 })
582 }
583
584 fn load_worktree_info_for_system_prompt(
585 worktree: Entity<Worktree>,
586 project: Entity<Project>,
587 cx: &mut App,
588 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
589 let tree = worktree.read(cx);
590 let root_name = tree.root_name_str().into();
591 let abs_path = tree.abs_path();
592
593 let mut context = WorktreeContext {
594 root_name,
595 abs_path,
596 rules_file: None,
597 };
598
599 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
600 let Some(rules_task) = rules_task else {
601 return Task::ready((context, None));
602 };
603
604 cx.spawn(async move |_| {
605 let (rules_file, rules_file_error) = match rules_task.await {
606 Ok(rules_file) => (Some(rules_file), None),
607 Err(err) => (
608 None,
609 Some(RulesLoadingError {
610 message: format!("{err}").into(),
611 }),
612 ),
613 };
614 context.rules_file = rules_file;
615 (context, rules_file_error)
616 })
617 }
618
619 fn load_worktree_rules_file(
620 worktree: Entity<Worktree>,
621 project: Entity<Project>,
622 cx: &mut App,
623 ) -> Option<Task<Result<RulesFileContext>>> {
624 let worktree = worktree.read(cx);
625 let worktree_id = worktree.id();
626 let selected_rules_file = RULES_FILE_NAMES
627 .into_iter()
628 .filter_map(|name| {
629 worktree
630 .entry_for_path(RelPath::unix(name).unwrap())
631 .filter(|entry| entry.is_file())
632 .map(|entry| entry.path.clone())
633 })
634 .next();
635
636 // Note that Cline supports `.clinerules` being a directory, but that is not currently
637 // supported. This doesn't seem to occur often in GitHub repositories.
638 selected_rules_file.map(|path_in_worktree| {
639 let project_path = ProjectPath {
640 worktree_id,
641 path: path_in_worktree.clone(),
642 };
643 let buffer_task =
644 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
645 let rope_task = cx.spawn(async move |cx| {
646 let buffer = buffer_task.await?;
647 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
648 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
649 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
650 })?;
651 anyhow::Ok((project_entry_id, rope))
652 });
653 // Build a string from the rope on a background thread.
654 cx.background_spawn(async move {
655 let (project_entry_id, rope) = rope_task.await?;
656 anyhow::Ok(RulesFileContext {
657 path_in_worktree,
658 text: rope.to_string().trim().to_string(),
659 project_entry_id: project_entry_id.to_usize(),
660 })
661 })
662 })
663 }
664
665 fn handle_thread_title_updated(
666 &mut self,
667 thread: Entity<Thread>,
668 _: &TitleUpdated,
669 cx: &mut Context<Self>,
670 ) {
671 let session_id = thread.read(cx).id();
672 let Some(session) = self.sessions.get(session_id) else {
673 return;
674 };
675
676 let thread = thread.downgrade();
677 let acp_thread = session.acp_thread.downgrade();
678 cx.spawn(async move |_, cx| {
679 let title = thread.read_with(cx, |thread, _| thread.title())?;
680 if let Some(title) = title {
681 let task =
682 acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
683 task.await?;
684 }
685 anyhow::Ok(())
686 })
687 .detach_and_log_err(cx);
688 }
689
690 fn handle_thread_token_usage_updated(
691 &mut self,
692 thread: Entity<Thread>,
693 usage: &TokenUsageUpdated,
694 cx: &mut Context<Self>,
695 ) {
696 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
697 return;
698 };
699 session.acp_thread.update(cx, |acp_thread, cx| {
700 acp_thread.update_token_usage(usage.0.clone(), cx);
701 });
702 }
703
704 fn handle_project_event(
705 &mut self,
706 project: Entity<Project>,
707 event: &project::Event,
708 _cx: &mut Context<Self>,
709 ) {
710 let project_id = project.entity_id();
711 let Some(state) = self.projects.get_mut(&project_id) else {
712 return;
713 };
714 match event {
715 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
716 state.project_context_needs_refresh.send(()).ok();
717 }
718 project::Event::WorktreeUpdatedEntries(_, items) => {
719 if items.iter().any(|(path, _, _)| {
720 RULES_FILE_NAMES
721 .iter()
722 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
723 }) {
724 state.project_context_needs_refresh.send(()).ok();
725 }
726 }
727 _ => {}
728 }
729 }
730
731 fn handle_prompts_updated_event(
732 &mut self,
733 _prompt_store: Entity<PromptStore>,
734 _event: &prompt_store::PromptsUpdatedEvent,
735 _cx: &mut Context<Self>,
736 ) {
737 for state in self.projects.values_mut() {
738 state.project_context_needs_refresh.send(()).ok();
739 }
740 }
741
742 fn handle_models_updated_event(
743 &mut self,
744 _registry: Entity<LanguageModelRegistry>,
745 event: &language_model::Event,
746 cx: &mut Context<Self>,
747 ) {
748 self.models.refresh_list(cx);
749
750 let registry = LanguageModelRegistry::read_global(cx);
751 let default_model = registry.default_model().map(|m| m.model);
752 let summarization_model = registry.thread_summary_model().map(|m| m.model);
753
754 for session in self.sessions.values_mut() {
755 session.thread.update(cx, |thread, cx| {
756 if thread.model().is_none()
757 && let Some(model) = default_model.clone()
758 {
759 thread.set_model(model, cx);
760 cx.notify();
761 }
762 if let Some(model) = summarization_model.clone() {
763 if thread.summarization_model().is_none()
764 || matches!(event, language_model::Event::ThreadSummaryModelChanged)
765 {
766 thread.set_summarization_model(Some(model), cx);
767 }
768 }
769 });
770 }
771 }
772
773 fn handle_context_server_store_updated(
774 &mut self,
775 store: Entity<project::context_server_store::ContextServerStore>,
776 _event: &project::context_server_store::ServerStatusChangedEvent,
777 cx: &mut Context<Self>,
778 ) {
779 let project_id = self.projects.iter().find_map(|(id, state)| {
780 if *state.context_server_registry.read(cx).server_store() == store {
781 Some(*id)
782 } else {
783 None
784 }
785 });
786 if let Some(project_id) = project_id {
787 self.update_available_commands_for_project(project_id, cx);
788 }
789 }
790
791 fn handle_context_server_registry_event(
792 &mut self,
793 registry: Entity<ContextServerRegistry>,
794 event: &ContextServerRegistryEvent,
795 cx: &mut Context<Self>,
796 ) {
797 match event {
798 ContextServerRegistryEvent::ToolsChanged => {}
799 ContextServerRegistryEvent::PromptsChanged => {
800 let project_id = self.projects.iter().find_map(|(id, state)| {
801 if state.context_server_registry == registry {
802 Some(*id)
803 } else {
804 None
805 }
806 });
807 if let Some(project_id) = project_id {
808 self.update_available_commands_for_project(project_id, cx);
809 }
810 }
811 }
812 }
813
814 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
815 let available_commands =
816 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
817 for session in self.sessions.values() {
818 if session.project_id != project_id {
819 continue;
820 }
821 session.acp_thread.update(cx, |thread, cx| {
822 thread
823 .handle_session_update(
824 acp::SessionUpdate::AvailableCommandsUpdate(
825 acp::AvailableCommandsUpdate::new(available_commands.clone()),
826 ),
827 cx,
828 )
829 .log_err();
830 });
831 }
832 }
833
834 fn build_available_commands_for_project(
835 project_state: Option<&ProjectState>,
836 cx: &App,
837 ) -> Vec<acp::AvailableCommand> {
838 let Some(state) = project_state else {
839 return vec![];
840 };
841 let registry = state.context_server_registry.read(cx);
842
843 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
844 for context_server_prompt in registry.prompts() {
845 *prompt_name_counts
846 .entry(context_server_prompt.prompt.name.as_str())
847 .or_insert(0) += 1;
848 }
849
850 registry
851 .prompts()
852 .flat_map(|context_server_prompt| {
853 let prompt = &context_server_prompt.prompt;
854
855 let should_prefix = prompt_name_counts
856 .get(prompt.name.as_str())
857 .copied()
858 .unwrap_or(0)
859 > 1;
860
861 let name = if should_prefix {
862 format!("{}.{}", context_server_prompt.server_id, prompt.name)
863 } else {
864 prompt.name.clone()
865 };
866
867 let mut command = acp::AvailableCommand::new(
868 name,
869 prompt.description.clone().unwrap_or_default(),
870 );
871
872 match prompt.arguments.as_deref() {
873 Some([arg]) => {
874 let hint = format!("<{}>", arg.name);
875
876 command = command.input(acp::AvailableCommandInput::Unstructured(
877 acp::UnstructuredCommandInput::new(hint),
878 ));
879 }
880 Some([]) | None => {}
881 Some(_) => {
882 // skip >1 argument commands since we don't support them yet
883 return None;
884 }
885 }
886
887 Some(command)
888 })
889 .collect()
890 }
891
892 pub fn load_thread(
893 &mut self,
894 id: acp::SessionId,
895 project: Entity<Project>,
896 cx: &mut Context<Self>,
897 ) -> Task<Result<Entity<Thread>>> {
898 let database_future = ThreadsDatabase::connect(cx);
899 cx.spawn(async move |this, cx| {
900 let database = database_future.await.map_err(|err| anyhow!(err))?;
901 let db_thread = database
902 .load_thread(id.clone())
903 .await?
904 .with_context(|| format!("no thread found with ID: {id:?}"))?;
905
906 this.update(cx, |this, cx| {
907 let project_id = this.get_or_create_project_state(&project, cx);
908 let project_state = this
909 .projects
910 .get(&project_id)
911 .context("project state not found")?;
912 let summarization_model = LanguageModelRegistry::read_global(cx)
913 .thread_summary_model()
914 .map(|c| c.model);
915
916 Ok(cx.new(|cx| {
917 let mut thread = Thread::from_db(
918 id.clone(),
919 db_thread,
920 project_state.project.clone(),
921 project_state.project_context.clone(),
922 project_state.context_server_registry.clone(),
923 this.templates.clone(),
924 cx,
925 );
926 thread.set_summarization_model(summarization_model, cx);
927 thread
928 }))
929 })?
930 })
931 }
932
    /// Opens session `id`, taking one reference on it.
    ///
    /// - If the session is already registered, its ref count is incremented and its ACP thread
    ///   is returned immediately.
    /// - If the session is currently loading, the caller joins the in-flight load and the
    ///   pending ref count is incremented.
    /// - Otherwise the thread is loaded from the database. When it loads, it is registered with
    ///   the accumulated pending ref count, its events are replayed into the new ACP thread, and
    ///   the completed plan is snapshotted.
    ///
    /// Every successful call should eventually be balanced by `close_session`.
    ///
    /// # Errors
    ///
    /// Fails if loading, registering, or replaying the thread fails.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        if let Some(session) = self.sessions.get_mut(&id) {
            session.ref_count += 1;
            return Task::ready(Ok(session.acp_thread.clone()));
        }

        if let Some(pending) = self.pending_sessions.get_mut(&id) {
            pending.ref_count += 1;
            let task = pending.task.clone();
            return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
        }

        let task = self.load_thread(id.clone(), project.clone(), cx);
        // The load is wrapped in a `Shared` future so concurrent openers can await the same
        // result. Errors are wrapped in `Arc` because `Shared` requires a cloneable output.
        let shared_task = cx
            .spawn({
                let id = id.clone();
                async move |this, cx| {
                    let thread = match task.await {
                        Ok(thread) => thread,
                        Err(err) => {
                            this.update(cx, |this, _cx| {
                                this.pending_sessions.remove(&id);
                            })
                            .ok();
                            return Err(Arc::new(err));
                        }
                    };
                    let acp_thread = this
                        .update(cx, |this, cx| {
                            let project_id = this.get_or_create_project_state(&project, cx);
                            // Every caller that joined while loading contributes a reference.
                            let ref_count = this
                                .pending_sessions
                                .remove(&id)
                                .map_or(1, |pending| pending.ref_count);
                            this.register_session(thread.clone(), project_id, ref_count, cx)
                        })
                        .map_err(Arc::new)?;
                    // NOTE(review): if replay below fails, the session stays registered with
                    // `ref_count` references while every waiter receives an error (and thus no
                    // handle to close), which looks like it leaks the session. Confirm.
                    let events = thread.update(cx, |thread, cx| thread.replay(cx));
                    cx.update(|cx| {
                        NativeAgentConnection::handle_thread_events(
                            events,
                            acp_thread.downgrade(),
                            cx,
                        )
                    })
                    .await
                    .map_err(Arc::new)?;
                    acp_thread.update(cx, |thread, cx| {
                        thread.snapshot_completed_plan(cx);
                    });
                    Ok(acp_thread)
                }
            })
            .shared();
        self.pending_sessions.insert(
            id,
            PendingSession {
                task: shared_task.clone(),
                ref_count: 1,
            },
        );

        cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
    }
1002
1003 pub fn thread_summary(
1004 &mut self,
1005 id: acp::SessionId,
1006 project: Entity<Project>,
1007 cx: &mut Context<Self>,
1008 ) -> Task<Result<SharedString>> {
1009 let thread = self.open_thread(id.clone(), project, cx);
1010 cx.spawn(async move |this, cx| {
1011 let acp_thread = thread.await?;
1012 let result = this
1013 .update(cx, |this, cx| {
1014 this.sessions
1015 .get(&id)
1016 .unwrap()
1017 .thread
1018 .update(cx, |thread, cx| thread.summary(cx))
1019 })?
1020 .await
1021 .context("Failed to generate summary")?;
1022
1023 this.update(cx, |this, cx| this.close_session(&id, cx))?
1024 .await?;
1025 drop(acp_thread);
1026 Ok(result)
1027 })
1028 }
1029
    /// Releases one reference to session `session_id`.
    ///
    /// While other references remain, this only decrements the count. When the last reference
    /// is released, the thread is saved and the session is removed. The project's state is
    /// also dropped if no other session uses that project. The returned task resolves when that
    /// final save completes. Unknown sessions resolve immediately with `Ok(())`.
    ///
    /// NOTE(review): sessions that are still in `pending_sessions` are not handled here, so a
    /// close issued while a thread is loading does not reduce the pending ref count. Confirm
    /// callers can never close before `open_thread` resolves.
    fn close_session(
        &mut self,
        session_id: &acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(session) = self.sessions.get_mut(session_id) else {
            return Task::ready(Ok(()));
        };

        session.ref_count -= 1;
        if session.ref_count > 0 {
            return Task::ready(Ok(()));
        }

        // Save while the session is still registered: `save_thread` looks the session up and
        // stores its save task in `pending_save`, which is returned below.
        let thread = session.thread.clone();
        self.save_thread(thread, cx);
        let Some(session) = self.sessions.remove(session_id) else {
            return Task::ready(Ok(()));
        };
        let project_id = session.project_id;

        let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
        if !has_remaining {
            self.projects.remove(&project_id);
        }

        session.pending_save
    }
1058
    /// Persists `thread` to the threads database and records the save as the session's
    /// `pending_save`.
    ///
    /// Empty threads are not saved. Nothing happens if the thread has no registered session or
    /// the session's project state is gone. The ACP thread's draft prompt is first copied into
    /// the thread. The thread is then serialized together with the project's visible worktree
    /// paths, and the thread store is reloaded afterwards. Database errors are logged, and the
    /// save task always resolves to `Ok(())`.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let project_id = session.project_id;
        let Some(state) = self.projects.get(&project_id) else {
            return;
        };

        // Absolute paths of the visible worktrees are stored with the thread.
        let folder_paths = PathList::new(
            &state
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        // NOTE(review): assigning here drops the previous `pending_save`. If dropping a gpui
        // `Task` cancels it, an in-flight save is abandoned in favour of this newer one.
        // Confirm this is intended.
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return Ok(());
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            thread_store.update(cx, |store, cx| store.reload(cx));
            Ok(())
        });
    }
1103
    /// Expands the MCP prompt `prompt_name` from `server_id` and runs it as a turn in
    /// `session_id`.
    ///
    /// The first block of `original_content` is skipped. The remaining blocks are
    /// pushed onto the native `Thread` as user message `message_id`. Each message
    /// returned by the MCP server is then appended to both the `AcpThread` (indented)
    /// and the native `Thread`, keeping its user/assistant role. If the last expanded
    /// message is from the user, or there were none, the thread sends the existing
    /// messages; otherwise it resumes. Events from that turn are forwarded to the
    /// `AcpThread` by `NativeAgentConnection::handle_thread_events`.
    ///
    /// # Errors
    /// The task fails if the session's project state or the session itself cannot be
    /// found, if fetching the prompt fails, or if starting the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // With no expanded messages this stays `true`, so the turn is sent rather
            // than resumed.
            let mut last_is_user = true;

            // Record whatever followed the command block as the user's own message.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Mirror each expanded message into both the AcpThread and the native
                // Thread.
                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1206}
1207
/// Wrapper struct that implements the [`acp_thread::AgentConnection`] trait on top of a
/// [`NativeAgent`] entity.
///
/// Clones wrap the same `Entity<NativeAgent>` handle.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1211
1212impl NativeAgentConnection {
1213 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1214 self.0
1215 .read(cx)
1216 .sessions
1217 .get(session_id)
1218 .map(|session| session.thread.clone())
1219 }
1220
1221 pub fn load_thread(
1222 &self,
1223 id: acp::SessionId,
1224 project: Entity<Project>,
1225 cx: &mut App,
1226 ) -> Task<Result<Entity<Thread>>> {
1227 self.0
1228 .update(cx, |this, cx| this.load_thread(id, project, cx))
1229 }
1230
1231 fn run_turn(
1232 &self,
1233 session_id: acp::SessionId,
1234 cx: &mut App,
1235 f: impl 'static
1236 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1237 ) -> Task<Result<acp::PromptResponse>> {
1238 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1239 agent
1240 .sessions
1241 .get_mut(&session_id)
1242 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1243 }) else {
1244 log::error!("Session not found in run_turn: {}", session_id);
1245 return Task::ready(Err(anyhow!("Session not found")));
1246 };
1247 log::debug!("Found session for: {}", session_id);
1248
1249 let response_stream = match f(thread, cx) {
1250 Ok(stream) => stream,
1251 Err(err) => return Task::ready(Err(err)),
1252 };
1253 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1254 }
1255
1256 fn handle_thread_events(
1257 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1258 acp_thread: WeakEntity<AcpThread>,
1259 cx: &App,
1260 ) -> Task<Result<acp::PromptResponse>> {
1261 cx.spawn(async move |cx| {
1262 // Handle response stream and forward to session.acp_thread
1263 while let Some(result) = events.next().await {
1264 match result {
1265 Ok(event) => {
1266 log::trace!("Received completion event: {:?}", event);
1267
1268 match event {
1269 ThreadEvent::UserMessage(message) => {
1270 acp_thread.update(cx, |thread, cx| {
1271 for content in message.content {
1272 thread.push_user_content_block(
1273 Some(message.id.clone()),
1274 content.into(),
1275 cx,
1276 );
1277 }
1278 })?;
1279 }
1280 ThreadEvent::AgentText(text) => {
1281 acp_thread.update(cx, |thread, cx| {
1282 thread.push_assistant_content_block(text.into(), false, cx)
1283 })?;
1284 }
1285 ThreadEvent::AgentThinking(text) => {
1286 acp_thread.update(cx, |thread, cx| {
1287 thread.push_assistant_content_block(text.into(), true, cx)
1288 })?;
1289 }
1290 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1291 tool_call,
1292 options,
1293 response,
1294 context: _,
1295 }) => {
1296 let outcome_task = acp_thread.update(cx, |thread, cx| {
1297 thread.request_tool_call_authorization(tool_call, options, cx)
1298 })??;
1299 cx.background_spawn(async move {
1300 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1301 outcome_task.await
1302 {
1303 response
1304 .send(outcome)
1305 .map(|_| anyhow!("authorization receiver was dropped"))
1306 .log_err();
1307 }
1308 })
1309 .detach();
1310 }
1311 ThreadEvent::ToolCall(tool_call) => {
1312 acp_thread.update(cx, |thread, cx| {
1313 thread.upsert_tool_call(tool_call, cx)
1314 })??;
1315 }
1316 ThreadEvent::ToolCallUpdate(update) => {
1317 acp_thread.update(cx, |thread, cx| {
1318 thread.update_tool_call(update, cx)
1319 })??;
1320 }
1321 ThreadEvent::Plan(plan) => {
1322 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1323 }
1324 ThreadEvent::SubagentSpawned(session_id) => {
1325 acp_thread.update(cx, |thread, cx| {
1326 thread.subagent_spawned(session_id, cx);
1327 })?;
1328 }
1329 ThreadEvent::Retry(status) => {
1330 acp_thread.update(cx, |thread, cx| {
1331 thread.update_retry_status(status, cx)
1332 })?;
1333 }
1334 ThreadEvent::Stop(stop_reason) => {
1335 log::debug!("Assistant message complete: {:?}", stop_reason);
1336 return Ok(acp::PromptResponse::new(stop_reason));
1337 }
1338 }
1339 }
1340 Err(e) => {
1341 log::error!("Error in model response stream: {:?}", e);
1342 return Err(e);
1343 }
1344 }
1345 }
1346
1347 log::debug!("Response stream completed");
1348 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1349 })
1350 }
1351}
1352
/// A slash command parsed from the first block of a prompt; see `Command::parse`.
struct Command<'a> {
    /// Prompt name: the command word after `/`, or the part after `.` in `/server.prompt`.
    prompt_name: &'a str,
    /// Text following the command word (empty when there is none).
    arg_value: &'a str,
    /// Server id given as `/server.prompt`, if the command word contained a `.`.
    explicit_server_id: Option<&'a str>,
}
1358
1359impl<'a> Command<'a> {
1360 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1361 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1362 return None;
1363 };
1364 let text = text_content.text.trim();
1365 let command = text.strip_prefix('/')?;
1366 let (command, arg_value) = command
1367 .split_once(char::is_whitespace)
1368 .unwrap_or((command, ""));
1369
1370 if let Some((server_id, prompt_name)) = command.split_once('.') {
1371 Some(Self {
1372 prompt_name,
1373 arg_value,
1374 explicit_server_id: Some(server_id),
1375 })
1376 } else {
1377 Some(Self {
1378 prompt_name: command,
1379 arg_value,
1380 explicit_server_id: None,
1381 })
1382 }
1383 }
1384}
1385
/// Model selector for a single session of a [`NativeAgentConnection`].
struct NativeAgentModelSelector {
    /// Session whose thread `select_model` / `selected_model` operate on.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model list, and filesystem.
    connection: NativeAgentConnection,
}
1390
1391impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1392 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1393 log::debug!("NativeAgentConnection::list_models called");
1394 let list = self.connection.0.read(cx).models.model_list.clone();
1395 Task::ready(if list.is_empty() {
1396 Err(anyhow::anyhow!("No models available"))
1397 } else {
1398 Ok(list)
1399 })
1400 }
1401
1402 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1403 log::debug!(
1404 "Setting model for session {}: {}",
1405 self.session_id,
1406 model_id
1407 );
1408 let Some(thread) = self
1409 .connection
1410 .0
1411 .read(cx)
1412 .sessions
1413 .get(&self.session_id)
1414 .map(|session| session.thread.clone())
1415 else {
1416 return Task::ready(Err(anyhow!("Session not found")));
1417 };
1418
1419 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1420 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1421 };
1422
1423 // We want to reset the effort level when switching models, as the currently-selected effort level may
1424 // not be compatible.
1425 let effort = model
1426 .default_effort_level()
1427 .map(|effort_level| effort_level.value.to_string());
1428
1429 thread.update(cx, |thread, cx| {
1430 thread.set_model(model.clone(), cx);
1431 thread.set_thinking_effort(effort.clone(), cx);
1432 thread.set_thinking_enabled(model.supports_thinking(), cx);
1433 });
1434
1435 update_settings_file(
1436 self.connection.0.read(cx).fs.clone(),
1437 cx,
1438 move |settings, cx| {
1439 let provider = model.provider_id().0.to_string();
1440 let model = model.id().0.to_string();
1441 let enable_thinking = thread.read(cx).thinking_enabled();
1442 let speed = thread.read(cx).speed();
1443 settings
1444 .agent
1445 .get_or_insert_default()
1446 .set_model(LanguageModelSelection {
1447 provider: provider.into(),
1448 model,
1449 enable_thinking,
1450 effort,
1451 speed,
1452 });
1453 },
1454 );
1455
1456 Task::ready(Ok(()))
1457 }
1458
1459 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1460 let Some(thread) = self
1461 .connection
1462 .0
1463 .read(cx)
1464 .sessions
1465 .get(&self.session_id)
1466 .map(|session| session.thread.clone())
1467 else {
1468 return Task::ready(Err(anyhow!("Session not found")));
1469 };
1470 let Some(model) = thread.read(cx).model() else {
1471 return Task::ready(Err(anyhow!("Model not found")));
1472 };
1473 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1474 else {
1475 return Task::ready(Err(anyhow!("Provider not found")));
1476 };
1477 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1478 model, &provider,
1479 )))
1480 }
1481
1482 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1483 Some(self.connection.0.read(cx).models.watch())
1484 }
1485
1486 fn should_render_footer(&self) -> bool {
1487 true
1488 }
1489}
1490
/// Agent id reported by [`NativeAgentConnection`]'s `agent_id`, created on first use.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1492
1493impl acp_thread::AgentConnection for NativeAgentConnection {
1494 fn agent_id(&self) -> AgentId {
1495 ZED_AGENT_ID.clone()
1496 }
1497
1498 fn telemetry_id(&self) -> SharedString {
1499 "zed".into()
1500 }
1501
1502 fn new_session(
1503 self: Rc<Self>,
1504 project: Entity<Project>,
1505 work_dirs: PathList,
1506 cx: &mut App,
1507 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1508 log::debug!("Creating new thread for project at: {work_dirs:?}");
1509 Task::ready(Ok(self
1510 .0
1511 .update(cx, |agent, cx| agent.new_session(project, cx))))
1512 }
1513
1514 fn supports_load_session(&self) -> bool {
1515 true
1516 }
1517
1518 fn load_session(
1519 self: Rc<Self>,
1520 session_id: acp::SessionId,
1521 project: Entity<Project>,
1522 _work_dirs: PathList,
1523 _title: Option<SharedString>,
1524 cx: &mut App,
1525 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1526 self.0
1527 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1528 }
1529
1530 fn supports_close_session(&self) -> bool {
1531 true
1532 }
1533
1534 fn close_session(
1535 self: Rc<Self>,
1536 session_id: &acp::SessionId,
1537 cx: &mut App,
1538 ) -> Task<Result<()>> {
1539 self.0
1540 .update(cx, |agent, cx| agent.close_session(session_id, cx))
1541 }
1542
1543 fn auth_methods(&self) -> &[acp::AuthMethod] {
1544 &[] // No auth for in-process
1545 }
1546
1547 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1548 Task::ready(Ok(()))
1549 }
1550
1551 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1552 Some(Rc::new(NativeAgentModelSelector {
1553 session_id: session_id.clone(),
1554 connection: self.clone(),
1555 }) as Rc<dyn AgentModelSelector>)
1556 }
1557
1558 fn prompt(
1559 &self,
1560 id: acp_thread::UserMessageId,
1561 params: acp::PromptRequest,
1562 cx: &mut App,
1563 ) -> Task<Result<acp::PromptResponse>> {
1564 let session_id = params.session_id.clone();
1565 log::info!("Received prompt request for session: {}", session_id);
1566 log::debug!("Prompt blocks count: {}", params.prompt.len());
1567
1568 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1569 log::error!("Session not found in prompt: {}", session_id);
1570 if self.0.read(cx).sessions.contains_key(&session_id) {
1571 log::error!(
1572 "Session found in sessions map, but not in project state: {}",
1573 session_id
1574 );
1575 }
1576 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1577 };
1578
1579 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1580 let registry = project_state.context_server_registry.read(cx);
1581
1582 let explicit_server_id = parsed_command
1583 .explicit_server_id
1584 .map(|server_id| ContextServerId(server_id.into()));
1585
1586 if let Some(prompt) =
1587 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1588 {
1589 let arguments = if !parsed_command.arg_value.is_empty()
1590 && let Some(arg_name) = prompt
1591 .prompt
1592 .arguments
1593 .as_ref()
1594 .and_then(|args| args.first())
1595 .map(|arg| arg.name.clone())
1596 {
1597 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1598 } else {
1599 Default::default()
1600 };
1601
1602 let prompt_name = prompt.prompt.name.clone();
1603 let server_id = prompt.server_id.clone();
1604
1605 return self.0.update(cx, |agent, cx| {
1606 agent.send_mcp_prompt(
1607 id,
1608 session_id.clone(),
1609 prompt_name,
1610 server_id,
1611 arguments,
1612 params.prompt,
1613 cx,
1614 )
1615 });
1616 }
1617 };
1618
1619 let path_style = project_state.project.read(cx).path_style(cx);
1620
1621 self.run_turn(session_id, cx, move |thread, cx| {
1622 let content: Vec<UserMessageContent> = params
1623 .prompt
1624 .into_iter()
1625 .map(|block| UserMessageContent::from_content_block(block, path_style))
1626 .collect::<Vec<_>>();
1627 log::debug!("Converted prompt to message: {} chars", content.len());
1628 log::debug!("Message id: {:?}", id);
1629 log::debug!("Message content: {:?}", content);
1630
1631 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1632 })
1633 }
1634
1635 fn retry(
1636 &self,
1637 session_id: &acp::SessionId,
1638 _cx: &App,
1639 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1640 Some(Rc::new(NativeAgentSessionRetry {
1641 connection: self.clone(),
1642 session_id: session_id.clone(),
1643 }) as _)
1644 }
1645
1646 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1647 log::info!("Cancelling on session: {}", session_id);
1648 self.0.update(cx, |agent, cx| {
1649 if let Some(session) = agent.sessions.get(session_id) {
1650 session
1651 .thread
1652 .update(cx, |thread, cx| thread.cancel(cx))
1653 .detach();
1654 }
1655 });
1656 }
1657
1658 fn truncate(
1659 &self,
1660 session_id: &acp::SessionId,
1661 cx: &App,
1662 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1663 self.0.read_with(cx, |agent, _cx| {
1664 agent.sessions.get(session_id).map(|session| {
1665 Rc::new(NativeAgentSessionTruncate {
1666 thread: session.thread.clone(),
1667 acp_thread: session.acp_thread.downgrade(),
1668 }) as _
1669 })
1670 })
1671 }
1672
1673 fn set_title(
1674 &self,
1675 session_id: &acp::SessionId,
1676 cx: &App,
1677 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1678 self.0.read_with(cx, |agent, _cx| {
1679 agent
1680 .sessions
1681 .get(session_id)
1682 .filter(|s| !s.thread.read(cx).is_subagent())
1683 .map(|session| {
1684 Rc::new(NativeAgentSessionSetTitle {
1685 thread: session.thread.clone(),
1686 }) as _
1687 })
1688 })
1689 }
1690
1691 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1692 let thread_store = self.0.read(cx).thread_store.clone();
1693 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1694 }
1695
1696 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1697 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1698 }
1699
1700 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1701 self
1702 }
1703}
1704
1705impl acp_thread::AgentTelemetry for NativeAgentConnection {
1706 fn thread_data(
1707 &self,
1708 session_id: &acp::SessionId,
1709 cx: &mut App,
1710 ) -> Task<Result<serde_json::Value>> {
1711 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1712 return Task::ready(Err(anyhow!("Session not found")));
1713 };
1714
1715 let task = session.thread.read(cx).to_db(cx);
1716 cx.background_spawn(async move {
1717 serde_json::to_value(task.await).context("Failed to serialize thread")
1718 })
1719 }
1720}
1721
/// Session list backed by the agent's [`ThreadStore`].
pub struct NativeAgentSessionList {
    /// Store whose entries are listed and deleted.
    thread_store: Entity<ThreadStore>,
    /// Sender used to push `Refresh` notifications to watchers.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiver handed out (cloned) by `watch`.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Observation of `thread_store` that sends `Refresh` whenever it notifies.
    _subscription: Subscription,
}
1728
1729impl NativeAgentSessionList {
1730 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1731 let (tx, rx) = smol::channel::unbounded();
1732 let this_tx = tx.clone();
1733 let subscription = cx.observe(&thread_store, move |_, _| {
1734 this_tx
1735 .try_send(acp_thread::SessionListUpdate::Refresh)
1736 .ok();
1737 });
1738 Self {
1739 thread_store,
1740 updates_tx: tx,
1741 updates_rx: rx,
1742 _subscription: subscription,
1743 }
1744 }
1745
1746 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1747 &self.thread_store
1748 }
1749}
1750
1751impl AgentSessionList for NativeAgentSessionList {
1752 fn list_sessions(
1753 &self,
1754 _request: AgentSessionListRequest,
1755 cx: &mut App,
1756 ) -> Task<Result<AgentSessionListResponse>> {
1757 let sessions = self
1758 .thread_store
1759 .read(cx)
1760 .entries()
1761 .map(|entry| AgentSessionInfo::from(&entry))
1762 .collect();
1763 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1764 }
1765
1766 fn supports_delete(&self) -> bool {
1767 true
1768 }
1769
1770 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1771 self.thread_store
1772 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1773 }
1774
1775 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1776 self.thread_store
1777 .update(cx, |store, cx| store.delete_threads(cx))
1778 }
1779
1780 fn watch(
1781 &self,
1782 _cx: &mut App,
1783 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1784 Some(self.updates_rx.clone())
1785 }
1786
1787 fn notify_refresh(&self) {
1788 self.updates_tx
1789 .try_send(acp_thread::SessionListUpdate::Refresh)
1790 .ok();
1791 }
1792
1793 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1794 self
1795 }
1796}
1797
/// Truncation handle for one session.
struct NativeAgentSessionTruncate {
    /// Native thread that is truncated.
    thread: Entity<Thread>,
    /// UI thread whose token usage is refreshed afterwards (skipped if released).
    acp_thread: WeakEntity<AcpThread>,
}
1802
1803impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1804 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1805 match self.thread.update(cx, |thread, cx| {
1806 thread.truncate(message_id.clone(), cx)?;
1807 Ok(thread.latest_token_usage())
1808 }) {
1809 Ok(usage) => {
1810 self.acp_thread
1811 .update(cx, |thread, cx| {
1812 thread.update_token_usage(usage, cx);
1813 })
1814 .ok();
1815 Task::ready(Ok(()))
1816 }
1817 Err(error) => Task::ready(Err(error)),
1818 }
1819 }
1820}
1821
/// Retry handle for one session: resumes the session's thread as a new turn.
struct NativeAgentSessionRetry {
    /// Connection used to run the resumed turn.
    connection: NativeAgentConnection,
    /// Session to resume.
    session_id: acp::SessionId,
}
1826
1827impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1828 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1829 self.connection
1830 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1831 thread.update(cx, |thread, cx| thread.resume(cx))
1832 })
1833 }
1834}
1835
/// Title-setting handle for one (non-subagent) session.
struct NativeAgentSessionSetTitle {
    /// Thread whose title is set.
    thread: Entity<Thread>,
}
1839
1840impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1841 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1842 self.thread
1843 .update(cx, |thread, cx| thread.set_title(title, cx));
1844 Task::ready(Ok(()))
1845 }
1846}
1847
/// [`ThreadEnvironment`] for a native thread: creates terminals and subagents for it.
pub struct NativeThreadEnvironment {
    /// Agent used to look up and register subagent sessions.
    agent: WeakEntity<NativeAgent>,
    /// The thread this environment serves; parent of any subagents it creates.
    thread: WeakEntity<Thread>,
    /// UI thread through which terminals are created and released.
    acp_thread: WeakEntity<AcpThread>,
}
1853
1854impl NativeThreadEnvironment {
1855 pub(crate) fn create_subagent_thread(
1856 &self,
1857 label: String,
1858 cx: &mut App,
1859 ) -> Result<Rc<dyn SubagentHandle>> {
1860 let Some(parent_thread_entity) = self.thread.upgrade() else {
1861 anyhow::bail!("Parent thread no longer exists".to_string());
1862 };
1863 let parent_thread = parent_thread_entity.read(cx);
1864 let current_depth = parent_thread.depth();
1865 let parent_session_id = parent_thread.id().clone();
1866
1867 if current_depth >= MAX_SUBAGENT_DEPTH {
1868 return Err(anyhow!(
1869 "Maximum subagent depth ({}) reached",
1870 MAX_SUBAGENT_DEPTH
1871 ));
1872 }
1873
1874 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1875 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1876 thread.set_title(label.into(), cx);
1877 thread
1878 });
1879
1880 let session_id = subagent_thread.read(cx).id().clone();
1881
1882 let acp_thread = self
1883 .agent
1884 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1885 let project_id = agent
1886 .sessions
1887 .get(&parent_session_id)
1888 .map(|s| s.project_id)
1889 .context("parent session not found")?;
1890 Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
1891 })??;
1892
1893 let depth = current_depth + 1;
1894
1895 telemetry::event!(
1896 "Subagent Started",
1897 session = parent_thread_entity.read(cx).id().to_string(),
1898 subagent_session = session_id.to_string(),
1899 depth,
1900 is_resumed = false,
1901 );
1902
1903 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1904 }
1905
1906 pub(crate) fn resume_subagent_thread(
1907 &self,
1908 session_id: acp::SessionId,
1909 cx: &mut App,
1910 ) -> Result<Rc<dyn SubagentHandle>> {
1911 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1912 let session = agent
1913 .sessions
1914 .get(&session_id)
1915 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1916 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1917 })??;
1918
1919 let depth = subagent_thread.read(cx).depth();
1920
1921 if let Some(parent_thread_entity) = self.thread.upgrade() {
1922 telemetry::event!(
1923 "Subagent Started",
1924 session = parent_thread_entity.read(cx).id().to_string(),
1925 subagent_session = session_id.to_string(),
1926 depth,
1927 is_resumed = true,
1928 );
1929 }
1930
1931 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1932 }
1933
1934 fn prompt_subagent(
1935 &self,
1936 session_id: acp::SessionId,
1937 subagent_thread: Entity<Thread>,
1938 acp_thread: Entity<acp_thread::AcpThread>,
1939 ) -> Result<Rc<dyn SubagentHandle>> {
1940 let Some(parent_thread_entity) = self.thread.upgrade() else {
1941 anyhow::bail!("Parent thread no longer exists".to_string());
1942 };
1943 Ok(Rc::new(NativeSubagentHandle::new(
1944 session_id,
1945 subagent_thread,
1946 acp_thread,
1947 parent_thread_entity,
1948 )) as _)
1949 }
1950}
1951
1952impl ThreadEnvironment for NativeThreadEnvironment {
1953 fn create_terminal(
1954 &self,
1955 command: String,
1956 cwd: Option<PathBuf>,
1957 output_byte_limit: Option<u64>,
1958 cx: &mut AsyncApp,
1959 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1960 let task = self.acp_thread.update(cx, |thread, cx| {
1961 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1962 });
1963
1964 let acp_thread = self.acp_thread.clone();
1965 cx.spawn(async move |cx| {
1966 let terminal = task?.await?;
1967
1968 let (drop_tx, drop_rx) = oneshot::channel();
1969 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1970
1971 cx.spawn(async move |cx| {
1972 drop_rx.await.ok();
1973 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1974 })
1975 .detach();
1976
1977 let handle = AcpTerminalHandle {
1978 terminal,
1979 _drop_tx: Some(drop_tx),
1980 };
1981
1982 Ok(Rc::new(handle) as _)
1983 })
1984 }
1985
1986 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1987 self.create_subagent_thread(label, cx)
1988 }
1989
1990 fn resume_subagent(
1991 &self,
1992 session_id: acp::SessionId,
1993 cx: &mut App,
1994 ) -> Result<Rc<dyn SubagentHandle>> {
1995 self.resume_subagent_thread(session_id, cx)
1996 }
1997}
1998
/// How a subagent turn ended, as computed in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended with `EndTurn` or any stop reason not mapped to another variant.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage moved from `Normal` to `Warning` before the turn finished.
    ContextWindowWarning,
    /// The turn failed or stopped for a reason reported back as this message.
    Error(String),
}
2006
/// Handle used by a parent thread to drive one subagent session.
pub struct NativeSubagentHandle {
    /// Session id of the subagent.
    session_id: acp::SessionId,
    /// Parent thread; weak so the handle does not keep the parent alive.
    parent_thread: WeakEntity<Thread>,
    /// Native thread of the subagent.
    subagent_thread: Entity<Thread>,
    /// UI thread of the subagent; messages are sent through it.
    acp_thread: Entity<acp_thread::AcpThread>,
}
2013
2014impl NativeSubagentHandle {
2015 fn new(
2016 session_id: acp::SessionId,
2017 subagent_thread: Entity<Thread>,
2018 acp_thread: Entity<acp_thread::AcpThread>,
2019 parent_thread_entity: Entity<Thread>,
2020 ) -> Self {
2021 NativeSubagentHandle {
2022 session_id,
2023 subagent_thread,
2024 parent_thread: parent_thread_entity.downgrade(),
2025 acp_thread,
2026 }
2027 }
2028}
2029
impl SubagentHandle for NativeSubagentHandle {
    /// Session id of the subagent thread this handle drives.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Number of entries currently in the subagent's `AcpThread`.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and waits for the turn to finish.
    ///
    /// While the turn runs, the subagent is registered as running on the parent thread
    /// (if the parent still exists). It is unregistered once a result has been
    /// computed, whatever the outcome.
    ///
    /// On success, resolves to the text blocks of the subagent's last agent message,
    /// joined with blank lines. Resolves to an error if:
    /// - the turn is cancelled,
    /// - it stops for max tokens, max turn requests, or refusal,
    /// - it fails or yields no response or no text,
    /// - or token usage moves from `Normal` (or unknown) before the prompt to `Warning`
    ///   during it. In this last case the subagent thread is cancelled first.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is held until the end of this task so the token-usage
            // watcher stays active while the turn is awaited.
            let (task, _subscription) = cx.update(|cx| {
                // Baseline ratio, captured before the new message is sent.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                // `Option` so the signal fires at most once.
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            // Only the Normal -> Warning transition triggers the stop.
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the turn against the context-window warning signal.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Collect the text blocks of the last agent message as the reply.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent before reporting, so it does not keep running.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // Every path above reaches here, so the registration is always undone.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
2150
/// [`TerminalHandle`] over a terminal owned by an `AcpThread`.
pub struct AcpTerminalHandle {
    /// The wrapped terminal entity.
    terminal: Entity<acp_thread::Terminal>,
    /// Dropped together with the handle. The task spawned in `create_terminal` waits on
    /// the matching receiver and then releases the terminal.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2155
2156impl TerminalHandle for AcpTerminalHandle {
2157 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2158 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2159 }
2160
2161 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2162 Ok(self
2163 .terminal
2164 .read_with(cx, |term, _cx| term.wait_for_exit()))
2165 }
2166
2167 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2168 Ok(self
2169 .terminal
2170 .read_with(cx, |term, cx| term.current_output(cx)))
2171 }
2172
2173 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2174 cx.update(|cx| {
2175 self.terminal.update(cx, |terminal, cx| {
2176 terminal.kill(cx);
2177 });
2178 });
2179 Ok(())
2180 }
2181
2182 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2183 Ok(self
2184 .terminal
2185 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2186 }
2187}
2188
2189#[cfg(test)]
2190mod internal_tests {
2191 use std::path::Path;
2192
2193 use super::*;
2194 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2195 use fs::FakeFs;
2196 use gpui::TestAppContext;
2197 use indoc::formatdoc;
2198 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2199 use language_model::{
2200 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2201 };
2202 use serde_json::json;
2203 use settings::SettingsStore;
2204 use util::{path, rel_path::rel_path};
2205
    // Verifies that the agent's per-project `ProjectContext` tracks worktree
    // changes and that a session's thread sees the same updates:
    // - it starts with no worktrees;
    // - it gains `/a` once that worktree is created;
    // - it picks up `/a/.rules` as that worktree's rules file once the file exists.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees; `/a` is added explicitly below.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Only one session was created, so its thread is the sole entry.
        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // With no worktrees yet, both the agent's and the thread's context are empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding a worktree must show up in both contexts (no rules file yet).
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            // The expected context references the worktree entry id of the new file.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    // The file was created with empty contents.
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2293
    // Verifies that a session's model selector returns a grouped model list.
    // After `init_test`, the list contains exactly one "Fake" group holding the
    // `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to group models by provider.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        // Model ids take the form `provider/model`.
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2350
    // Verifies what selecting a model through the session's model selector does:
    // - sets it as the thread's model;
    // - writes it to the settings file as `agent.default_model`, replacing the
    //   pre-seeded `foo`/`bar` selection;
    // - for a thinking-capable model, also persists `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a different default model so the test can observe it being overwritten.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending work (presumably including the settings-file write) finish
        // before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the selector can resolve the new provider's model.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2462
    // Verifies that `select_model` keeps the thread's `thinking_enabled` flag in
    // line with the selected model:
    // - off by default;
    // - on after selecting a thinking-capable model;
    // - off again after switching back to `fake/fake`.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Start from an empty settings object.
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2554
    // Regression test: a session's summarization model must stay set when the
    // registry's default model is cleared temporarily. Such clearing happens,
    // for example, during a provider blip.
    #[gpui::test]
    async fn test_summarization_model_survives_transient_registry_clearing(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Precondition: the session starts out with a summarization model.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.summarization_model().is_some(),
                "session should have a summarization model from the test registry"
            );
        });

        // Simulate what happens during a provider blip:
        // update_active_language_model_from_settings calls set_default_model(None)
        // when it can't resolve the model, clearing all fallbacks.
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(None, cx);
            });
        });
        // Let anything that reacts to the registry change run before checking.
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.summarization_model().is_some(),
                "summarization model should survive a transient default model clearing"
            );
        });
    }
2609
    // Verifies that `thinking_enabled` is saved with a thread. After the session
    // is closed and the thread is reopened via `open_thread`, the flag must still
    // be set.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The test keeps its own handle (the provider gets a clone) so it can
        // drive the model's completion stream later.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // Spawn the send so it can make progress while the test feeds the fake
        // model's completion stream below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old entities, then confirm the
        // agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The binding kept the reloaded AcpThread alive through the checks above.
        drop(reloaded_acp_thread);
    }
2712
    // Verifies that a thread's selected model is restored after the session is
    // closed and the thread is reopened via `open_thread`. The model's `id()`
    // differs from its `name()`. This lets the test tell two outcomes apart:
    // restoring the exact model, or falling back to the default.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector takes the `provider/model-id` form, not the display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // Spawn the send so it can make progress while the test feeds the fake
        // model's completion stream below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old entities, then confirm the
        // agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The binding kept the reloaded AcpThread alive through the checks above.
        drop(reloaded_acp_thread);
    }
2820
    // End-to-end persistence round-trip:
    // - a thread with no messages is not saved, even after its models are changed;
    // - after one exchange, closing the session leaves a thread-store entry whose
    //   text is the summarization model's output;
    // - reopening the thread restores the message history, the draft prompt
    //   (with rich content blocks), the token usage, and the UI scroll position.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models let the test drive the chat response and the
        // summary independently.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a message that mentions `/a/b.md` through a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so it can make progress while the test feeds the fake
        // models' completion streams below.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream the assistant reply plus a usage update that must survive the
        // round-trip (asserted at the end of the test).
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        // Let the summary request reach the summarization model before answering it.
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles, then confirm the agent holds no sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The stored entry's text is the summarization model's output.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        // The reloaded message history must match what was rendered before closing.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
3002
    // Verifies that `close_session` saves the thread itself rather than relying
    // on an earlier observer-triggered save. A draft prompt set immediately
    // before closing, with no `run_until_parked` in between, must still be there
    // after reopening.
    #[gpui::test]
    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        let model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Send a message so the thread is non-empty (empty threads aren't saved).
        // Spawn the send so it can make progress while the test feeds the fake
        // model's completion stream below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
        // This means no observe-triggered save has run for this change.
        // The only way this data gets persisted is if close_session
        // itself performs the save.
        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
            "unsaved draft",
        ))];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });

        // Close the session immediately — no run_until_parked in between.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Reopen and verify the draft prompt was saved.
        let reloaded = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        reloaded.read_with(cx, |thread, _| {
            assert_eq!(
                thread.draft_prompt(),
                Some(draft_blocks.as_slice()),
                "close_session must save the thread; draft prompt was lost"
            );
        });
    }
3083
    // Verifies that `thread_summary` on an already-open session returns the
    // summary without leaking a session reference. Afterwards the session must
    // still be loaded with `ref_count == 1`, and a single `close_session` must
    // unload it.
    #[gpui::test]
    async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Separate fake models let the test drive the chat response and the
        // summary independently.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });

        // Complete one exchange so the thread has content to summarize.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        let summary = agent.update(cx, |agent, cx| {
            agent.thread_summary(session_id.clone(), project.clone(), cx)
        });
        // Let the summary request reach the summarization model before answering it.
        cx.run_until_parked();

        summary_model.send_last_completion_stream_text_chunk("summary");
        summary_model.end_last_completion_stream();

        assert_eq!(summary.await.unwrap(), "summary");
        cx.run_until_parked();

        // Only the original `new_session` reference should remain.
        agent.read_with(cx, |agent, _| {
            let session = agent
                .sessions
                .get(&session_id)
                .expect("thread_summary should not close the active session");
            assert_eq!(
                session.ref_count, 1,
                "thread_summary should release its temporary session reference"
            );
        });

        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // One close must be enough to unload the session.
        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.is_empty(),
                "closing the active session after thread_summary should unload it"
            );
        });
    }
3167
    /// Checks that a persisted native-agent session can be loaded twice
    /// concurrently, and that its shared state outlives a single close.
    ///
    /// Flow:
    /// 1. Create a session, run one "hello" -> "world" turn, close it, and
    ///    confirm `NativeAgent::sessions` is empty.
    /// 2. Issue two `load_session` calls for the same id before awaiting
    ///    either, and expect both to resolve to the same `AcpThread` entity.
    /// 3. Close once: the session entry must still be present, and a
    ///    follow-up turn on the second handle must extend the original history.
    /// 4. Close again: the test expects `sessions` to be empty after this
    ///    second close (and after dropping the test's handles).
    #[gpui::test]
    async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Start a fresh session and grab the agent-side `Thread` that backs it.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Point the thread at the registry's default model, which is a fake in
        // tests, so the completion stream can be scripted below.
        let model = cx.update(|cx| {
            LanguageModelRegistry::read_global(cx)
                .default_model()
                .map(|default_model| default_model.model)
                .expect("default test model should be available")
        });
        let fake_model = model.as_fake();
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Spawn the send so it stays pending while the fake model streams its
        // reply; only then await its completion.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("world");
        fake_model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Close the only open reference and release the test's own handles;
        // the session must be fully unloaded before it is reloaded below.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Kick off both loads before awaiting either, so they overlap.
        let first_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });
        let second_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });

        let first_loaded_thread = first_loaded_thread.await.unwrap();
        let second_loaded_thread = second_loaded_thread.await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            first_loaded_thread.entity_id(),
            second_loaded_thread.entity_id(),
            "concurrent loads for the same session should share one AcpThread"
        );

        // One close for two loads: the shared session state must survive it.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.contains_key(&session_id),
                "closing one loaded session should not drop shared session state"
            );
        });

        // The surviving handle is still usable, and the new turn is appended
        // to the history restored from before the reload.
        let follow_up = second_loaded_thread.update(cx, |thread, cx| {
            thread.send(vec!["still there?".into()], cx)
        });
        let follow_up = cx.foreground_executor().spawn(follow_up);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("yes");
        fake_model.end_last_completion_stream();
        follow_up.await.unwrap();
        cx.run_until_parked();

        second_loaded_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    hello

                    ## Assistant

                    world

                    ## User

                    still there?

                    ## Assistant

                    yes

                "}
            );
        });

        // The second close matches the second load and releases the session.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        cx.run_until_parked();

        drop(first_loaded_thread);
        drop(second_loaded_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });
    }
3318
    /// Sets two titles back to back on a native `Thread`. Expects both the
    /// `Thread` and its `AcpThread` to end on the last title, with exactly one
    /// `TitleUpdated` event per `set_title` call.
    #[gpui::test]
    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
        // Regression test: rapid title changes must not cause a propagation loop
        // between Thread and AcpThread via handle_thread_title_updated.
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();

        // Look up the agent-side `Thread` backing the new ACP session.
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Count every `TitleUpdated` the Thread emits. Two `set_title` calls
        // should yield exactly two events. The in-callback assert fails as soon
        // as a third event arrives, rather than letting a loop keep running
        // until `run_until_parked` returns.
        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
        cx.update(|cx| {
            let count = title_updated_count.clone();
            cx.subscribe(
                &thread,
                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
                    let new_count = {
                        let mut count = count.borrow_mut();
                        *count += 1;
                        *count
                    };
                    assert!(
                        new_count <= 2,
                        "TitleUpdated fired {new_count} times; \
                         title updates are looping"
                    );
                },
            )
            .detach();
        });

        // Two updates issued without yielding in between: the "rapid" case.
        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));

        cx.run_until_parked();

        // Both sides must settle on the most recent title.
        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.title(), Some("second".into()));
        });
        acp_thread.read_with(cx, |acp_thread, _| {
            assert_eq!(acp_thread.title(), Some("second".into()));
        });

        assert_eq!(*title_updated_count.borrow(), 2);
    }
3382
3383 fn thread_entries(
3384 thread_store: &Entity<ThreadStore>,
3385 cx: &mut TestAppContext,
3386 ) -> Vec<(acp::SessionId, String)> {
3387 thread_store.read_with(cx, |store, _| {
3388 store
3389 .entries()
3390 .map(|entry| (entry.id.clone(), entry.title.to_string()))
3391 .collect::<Vec<_>>()
3392 })
3393 }
3394
3395 fn init_test(cx: &mut TestAppContext) {
3396 env_logger::try_init().ok();
3397 cx.update(|cx| {
3398 let settings_store = SettingsStore::test(cx);
3399 cx.set_global(settings_store);
3400
3401 LanguageModelRegistry::test(cx);
3402 });
3403 }
3404}
3405
3406fn mcp_message_content_to_acp_content_block(
3407 content: context_server::types::MessageContent,
3408) -> acp::ContentBlock {
3409 match content {
3410 context_server::types::MessageContent::Text {
3411 text,
3412 annotations: _,
3413 } => text.into(),
3414 context_server::types::MessageContent::Image {
3415 data,
3416 mime_type,
3417 annotations: _,
3418 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3419 context_server::types::MessageContent::Audio {
3420 data,
3421 mime_type,
3422 annotations: _,
3423 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3424 context_server::types::MessageContent::Resource {
3425 resource,
3426 annotations: _,
3427 } => {
3428 let mut link =
3429 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3430 if let Some(mime_type) = resource.mime_type {
3431 link = link.mime_type(mime_type);
3432 }
3433 acp::ContentBlock::ResourceLink(link)
3434 }
3435 }
3436}