1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// A serializable snapshot of a project's worktrees paired with a timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC time associated with this snapshot (presumably the capture time;
    /// confirm at the construction site).
    pub timestamp: DateTime<Utc>,
}
64
/// Describes a failure to load a worktree's rules file.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt` when the
/// rules-file loading task resolves to an error.
pub struct RulesLoadingError {
    /// Display-formatted text of the underlying error.
    pub message: SharedString,
}
68
/// Per-project state owned by `NativeAgent`, keyed by the project's `EntityId`.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Project context entity; overwritten by `maintain_project_context`
    /// each time a refresh completes.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel wakes `maintain_project_context`, which then
    /// rebuilds `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context` for this project; stored here
    /// alongside the rest of the project's state.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to the project, its context server store, and the registry.
    _subscriptions: Vec<Subscription>,
}
77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session belongs to.
    project_id: EntityId,
    /// The most recent save task spawned by `NativeAgent::save_thread`;
    /// replaced every time a new save is started.
    pending_save: Task<Result<()>>,
    /// Subscriptions/observations on `thread` registered in `register_session`.
    _subscriptions: Vec<Subscription>,
}
88
/// Cache of the language models exposed by authenticated, visible providers.
///
/// Rebuilt by `refresh_list`.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; cloned out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; `refresh_list` sends on it after
    /// every rebuild.
    refresh_models_tx: watch::Sender<()>,
    /// Background task, started in `new`, that authenticates every visible provider.
    _authenticate_all_providers_task: Task<()>,
}
98
99impl LanguageModels {
100 fn new(cx: &mut App) -> Self {
101 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
102
103 let mut this = Self {
104 models: HashMap::default(),
105 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
106 refresh_models_rx,
107 refresh_models_tx,
108 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
109 };
110 this.refresh_list(cx);
111 this
112 }
113
114 fn refresh_list(&mut self, cx: &App) {
115 let providers = LanguageModelRegistry::global(cx)
116 .read(cx)
117 .visible_providers()
118 .into_iter()
119 .filter(|provider| provider.is_authenticated(cx))
120 .collect::<Vec<_>>();
121
122 let mut language_model_list = IndexMap::default();
123 let mut recommended_models = HashSet::default();
124
125 let mut recommended = Vec::new();
126 for provider in &providers {
127 for model in provider.recommended_models(cx) {
128 recommended_models.insert((model.provider_id(), model.id()));
129 recommended.push(Self::map_language_model_to_info(&model, provider));
130 }
131 }
132 if !recommended.is_empty() {
133 language_model_list.insert(
134 acp_thread::AgentModelGroupName("Recommended".into()),
135 recommended,
136 );
137 }
138
139 let mut models = HashMap::default();
140 for provider in providers {
141 let mut provider_models = Vec::new();
142 for model in provider.provided_models(cx) {
143 let model_info = Self::map_language_model_to_info(&model, &provider);
144 let model_id = model_info.id.clone();
145 provider_models.push(model_info);
146 models.insert(model_id, model);
147 }
148 if !provider_models.is_empty() {
149 language_model_list.insert(
150 acp_thread::AgentModelGroupName(provider.name().0.clone()),
151 provider_models,
152 );
153 }
154 }
155
156 self.models = models;
157 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
158 self.refresh_models_tx.send(()).ok();
159 }
160
    /// Returns a clone of the receiver for the refresh channel that
    /// `refresh_list` sends on after rebuilding the cache.
    fn watch(&self) -> watch::Receiver<()> {
        self.refresh_models_rx.clone()
    }
164
165 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
166 self.models.get(model_id).cloned()
167 }
168
169 fn map_language_model_to_info(
170 model: &Arc<dyn LanguageModel>,
171 provider: &Arc<dyn LanguageModelProvider>,
172 ) -> acp_thread::AgentModelInfo {
173 acp_thread::AgentModelInfo {
174 id: Self::model_id(model),
175 name: model.name().0,
176 description: None,
177 icon: Some(match provider.icon() {
178 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
179 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
180 }),
181 is_latest: model.is_latest(),
182 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
183 }
184 }
185
186 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
187 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
188 }
189
190 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
191 let authenticate_all_providers = LanguageModelRegistry::global(cx)
192 .read(cx)
193 .visible_providers()
194 .iter()
195 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
196 .collect::<Vec<_>>();
197
198 cx.background_spawn(async move {
199 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
200 if let Err(err) = authenticate_task.await {
201 match err {
202 language_model::AuthenticateError::CredentialsNotFound => {
203 // Since we're authenticating these providers in the
204 // background for the purposes of populating the
205 // language selector, we don't care about providers
206 // where the credentials are not found.
207 }
208 language_model::AuthenticateError::ConnectionRefused => {
209 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
210 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
211 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
212 }
213 _ => {
214 // Some providers have noisy failure states that we
215 // don't want to spam the logs with every time the
216 // language model selector is initialized.
217 //
218 // Ideally these should have more clear failure modes
219 // that we know are safe to ignore here, like what we do
220 // with `CredentialsNotFound` above.
221 match provider_id.0.as_ref() {
222 "lmstudio" | "ollama" => {
223 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
224 //
225 // These fail noisily, so we don't log them.
226 }
227 "copilot_chat" => {
228 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
229 }
230 _ => {
231 log::error!(
232 "Failed to authenticate provider: {}: {err:#}",
233 provider_name.0
234 );
235 }
236 }
237 }
238 }
239 }
240 }
241 })
242 }
243}
244
/// The native agent: owns every active session together with the
/// per-project state, model cache, and templates the sessions share.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store that is asked to reload after a thread has been saved.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store; when present, its default prompts are loaded as
    /// user rules while building project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language model registry and, if present, the
    /// prompt store.
    _subscriptions: Vec<Subscription>,
}
259
260impl NativeAgent {
261 pub fn new(
262 thread_store: Entity<ThreadStore>,
263 templates: Arc<Templates>,
264 prompt_store: Option<Entity<PromptStore>>,
265 fs: Arc<dyn Fs>,
266 cx: &mut App,
267 ) -> Entity<NativeAgent> {
268 log::debug!("Creating new NativeAgent");
269
270 cx.new(|cx| {
271 let mut subscriptions = vec![cx.subscribe(
272 &LanguageModelRegistry::global(cx),
273 Self::handle_models_updated_event,
274 )];
275 if let Some(prompt_store) = prompt_store.as_ref() {
276 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
277 }
278
279 Self {
280 sessions: HashMap::default(),
281 thread_store,
282 projects: HashMap::default(),
283 templates,
284 models: LanguageModels::new(cx),
285 prompt_store,
286 fs,
287 _subscriptions: subscriptions,
288 }
289 })
290 }
291
292 fn new_session(
293 &mut self,
294 project: Entity<Project>,
295 cx: &mut Context<Self>,
296 ) -> Entity<AcpThread> {
297 let project_id = self.get_or_create_project_state(&project, cx);
298 let project_state = &self.projects[&project_id];
299
300 let registry = LanguageModelRegistry::read_global(cx);
301 let available_count = registry.available_models(cx).count();
302 log::debug!("Total available models: {}", available_count);
303
304 let default_model = registry.default_model().and_then(|default_model| {
305 self.models
306 .model_from_id(&LanguageModels::model_id(&default_model.model))
307 });
308 let thread = cx.new(|cx| {
309 Thread::new(
310 project,
311 project_state.project_context.clone(),
312 project_state.context_server_registry.clone(),
313 self.templates.clone(),
314 default_model,
315 cx,
316 )
317 });
318
319 self.register_session(thread, project_id, cx)
320 }
321
    /// Wraps an internal thread in a new `AcpThread`, wires the two together,
    /// and stores the pair as a session keyed by the thread's id.
    ///
    /// Returns the newly created `AcpThread`. Any session already stored under
    /// the same id is replaced by the insert.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything needed out of the thread up front so the shared
        // borrow of `cx` ends before the ACP thread entity is created.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Carry the thread's draft prompt, scroll position and token
            // usage over to the ACP thread.
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The environment handed to the tools only holds weak handles to the
        // agent and both threads.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        // Forward title and token-usage events, and save the thread whenever
        // it notifies observers.
        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
            },
        );

        // Push the project's available commands to its sessions, including
        // the one just inserted.
        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
398
    /// Returns the agent's cached language model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
402
403 fn get_or_create_project_state(
404 &mut self,
405 project: &Entity<Project>,
406 cx: &mut Context<Self>,
407 ) -> EntityId {
408 let project_id = project.entity_id();
409 if self.projects.contains_key(&project_id) {
410 return project_id;
411 }
412
413 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
414 self.register_project_with_initial_context(project.clone(), project_context, cx);
415 if let Some(state) = self.projects.get_mut(&project_id) {
416 state.project_context_needs_refresh.send(()).ok();
417 }
418 project_id
419 }
420
421 fn register_project_with_initial_context(
422 &mut self,
423 project: Entity<Project>,
424 project_context: Entity<ProjectContext>,
425 cx: &mut Context<Self>,
426 ) {
427 let project_id = project.entity_id();
428
429 let context_server_store = project.read(cx).context_server_store();
430 let context_server_registry =
431 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
432
433 let subscriptions = vec![
434 cx.subscribe(&project, Self::handle_project_event),
435 cx.subscribe(
436 &context_server_store,
437 Self::handle_context_server_store_updated,
438 ),
439 cx.subscribe(
440 &context_server_registry,
441 Self::handle_context_server_registry_event,
442 ),
443 ];
444
445 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
446 watch::channel(());
447
448 self.projects.insert(
449 project_id,
450 ProjectState {
451 project,
452 project_context,
453 project_context_needs_refresh: project_context_needs_refresh_tx,
454 _maintain_project_context: cx.spawn(async move |this, cx| {
455 Self::maintain_project_context(
456 this,
457 project_id,
458 project_context_needs_refresh_rx,
459 cx,
460 )
461 .await
462 }),
463 context_server_registry,
464 _subscriptions: subscriptions,
465 },
466 );
467 }
468
469 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
470 self.sessions
471 .get(session_id)
472 .and_then(|session| self.projects.get(&session.project_id))
473 }
474
475 async fn maintain_project_context(
476 this: WeakEntity<Self>,
477 project_id: EntityId,
478 mut needs_refresh: watch::Receiver<()>,
479 cx: &mut AsyncApp,
480 ) -> Result<()> {
481 while needs_refresh.changed().await.is_ok() {
482 let project_context = this
483 .update(cx, |this, cx| {
484 let state = this
485 .projects
486 .get(&project_id)
487 .context("project state not found")?;
488 anyhow::Ok(Self::build_project_context(
489 &state.project,
490 this.prompt_store.as_ref(),
491 cx,
492 ))
493 })??
494 .await;
495 this.update(cx, |this, cx| {
496 if let Some(state) = this.projects.get(&project_id) {
497 state
498 .project_context
499 .update(cx, |current_project_context, _cx| {
500 *current_project_context = project_context;
501 });
502 }
503 })?;
504 }
505
506 Ok(())
507 }
508
509 fn build_project_context(
510 project: &Entity<Project>,
511 prompt_store: Option<&Entity<PromptStore>>,
512 cx: &mut App,
513 ) -> Task<ProjectContext> {
514 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
515 let worktree_tasks = worktrees
516 .into_iter()
517 .map(|worktree| {
518 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
519 })
520 .collect::<Vec<_>>();
521 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
522 prompt_store.read_with(cx, |prompt_store, cx| {
523 let prompts = prompt_store.default_prompt_metadata();
524 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
525 let contents = prompt_store.load(prompt_metadata.id, cx);
526 async move { (contents.await, prompt_metadata) }
527 });
528 cx.background_spawn(future::join_all(load_tasks))
529 })
530 } else {
531 Task::ready(vec![])
532 };
533
534 cx.spawn(async move |_cx| {
535 let (worktrees, default_user_rules) =
536 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
537
538 let worktrees = worktrees
539 .into_iter()
540 .map(|(worktree, _rules_error)| {
541 // TODO: show error message
542 // if let Some(rules_error) = rules_error {
543 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
544 // }
545 worktree
546 })
547 .collect::<Vec<_>>();
548
549 let default_user_rules = default_user_rules
550 .into_iter()
551 .flat_map(|(contents, prompt_metadata)| match contents {
552 Ok(contents) => Some(UserRulesContext {
553 uuid: prompt_metadata.id.as_user()?,
554 title: prompt_metadata.title.map(|title| title.to_string()),
555 contents,
556 }),
557 Err(_err) => {
558 // TODO: show error message
559 // this.update(cx, |_, cx| {
560 // cx.emit(RulesLoadingError {
561 // message: format!("{err:?}").into(),
562 // });
563 // })
564 // .ok();
565 None
566 }
567 })
568 .collect::<Vec<_>>();
569
570 ProjectContext::new(worktrees, default_user_rules)
571 })
572 }
573
574 fn load_worktree_info_for_system_prompt(
575 worktree: Entity<Worktree>,
576 project: Entity<Project>,
577 cx: &mut App,
578 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
579 let tree = worktree.read(cx);
580 let root_name = tree.root_name_str().into();
581 let abs_path = tree.abs_path();
582
583 let mut context = WorktreeContext {
584 root_name,
585 abs_path,
586 rules_file: None,
587 };
588
589 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
590 let Some(rules_task) = rules_task else {
591 return Task::ready((context, None));
592 };
593
594 cx.spawn(async move |_| {
595 let (rules_file, rules_file_error) = match rules_task.await {
596 Ok(rules_file) => (Some(rules_file), None),
597 Err(err) => (
598 None,
599 Some(RulesLoadingError {
600 message: format!("{err}").into(),
601 }),
602 ),
603 };
604 context.rules_file = rules_file;
605 (context, rules_file_error)
606 })
607 }
608
609 fn load_worktree_rules_file(
610 worktree: Entity<Worktree>,
611 project: Entity<Project>,
612 cx: &mut App,
613 ) -> Option<Task<Result<RulesFileContext>>> {
614 let worktree = worktree.read(cx);
615 let worktree_id = worktree.id();
616 let selected_rules_file = RULES_FILE_NAMES
617 .into_iter()
618 .filter_map(|name| {
619 worktree
620 .entry_for_path(RelPath::unix(name).unwrap())
621 .filter(|entry| entry.is_file())
622 .map(|entry| entry.path.clone())
623 })
624 .next();
625
626 // Note that Cline supports `.clinerules` being a directory, but that is not currently
627 // supported. This doesn't seem to occur often in GitHub repositories.
628 selected_rules_file.map(|path_in_worktree| {
629 let project_path = ProjectPath {
630 worktree_id,
631 path: path_in_worktree.clone(),
632 };
633 let buffer_task =
634 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
635 let rope_task = cx.spawn(async move |cx| {
636 let buffer = buffer_task.await?;
637 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
638 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
639 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
640 })?;
641 anyhow::Ok((project_entry_id, rope))
642 });
643 // Build a string from the rope on a background thread.
644 cx.background_spawn(async move {
645 let (project_entry_id, rope) = rope_task.await?;
646 anyhow::Ok(RulesFileContext {
647 path_in_worktree,
648 text: rope.to_string().trim().to_string(),
649 project_entry_id: project_entry_id.to_usize(),
650 })
651 })
652 })
653 }
654
655 fn handle_thread_title_updated(
656 &mut self,
657 thread: Entity<Thread>,
658 _: &TitleUpdated,
659 cx: &mut Context<Self>,
660 ) {
661 let session_id = thread.read(cx).id();
662 let Some(session) = self.sessions.get(session_id) else {
663 return;
664 };
665
666 let thread = thread.downgrade();
667 let acp_thread = session.acp_thread.downgrade();
668 cx.spawn(async move |_, cx| {
669 let title = thread.read_with(cx, |thread, _| thread.title())?;
670 if let Some(title) = title {
671 let task =
672 acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
673 task.await?;
674 }
675 anyhow::Ok(())
676 })
677 .detach_and_log_err(cx);
678 }
679
680 fn handle_thread_token_usage_updated(
681 &mut self,
682 thread: Entity<Thread>,
683 usage: &TokenUsageUpdated,
684 cx: &mut Context<Self>,
685 ) {
686 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
687 return;
688 };
689 session.acp_thread.update(cx, |acp_thread, cx| {
690 acp_thread.update_token_usage(usage.0.clone(), cx);
691 });
692 }
693
694 fn handle_project_event(
695 &mut self,
696 project: Entity<Project>,
697 event: &project::Event,
698 _cx: &mut Context<Self>,
699 ) {
700 let project_id = project.entity_id();
701 let Some(state) = self.projects.get_mut(&project_id) else {
702 return;
703 };
704 match event {
705 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
706 state.project_context_needs_refresh.send(()).ok();
707 }
708 project::Event::WorktreeUpdatedEntries(_, items) => {
709 if items.iter().any(|(path, _, _)| {
710 RULES_FILE_NAMES
711 .iter()
712 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
713 }) {
714 state.project_context_needs_refresh.send(()).ok();
715 }
716 }
717 _ => {}
718 }
719 }
720
721 fn handle_prompts_updated_event(
722 &mut self,
723 _prompt_store: Entity<PromptStore>,
724 _event: &prompt_store::PromptsUpdatedEvent,
725 _cx: &mut Context<Self>,
726 ) {
727 for state in self.projects.values_mut() {
728 state.project_context_needs_refresh.send(()).ok();
729 }
730 }
731
732 fn handle_models_updated_event(
733 &mut self,
734 _registry: Entity<LanguageModelRegistry>,
735 event: &language_model::Event,
736 cx: &mut Context<Self>,
737 ) {
738 self.models.refresh_list(cx);
739
740 let registry = LanguageModelRegistry::read_global(cx);
741 let default_model = registry.default_model().map(|m| m.model);
742 let summarization_model = registry.thread_summary_model().map(|m| m.model);
743
744 for session in self.sessions.values_mut() {
745 session.thread.update(cx, |thread, cx| {
746 if thread.model().is_none()
747 && let Some(model) = default_model.clone()
748 {
749 thread.set_model(model, cx);
750 cx.notify();
751 }
752 if let Some(model) = summarization_model.clone() {
753 if thread.summarization_model().is_none()
754 || matches!(event, language_model::Event::ThreadSummaryModelChanged)
755 {
756 thread.set_summarization_model(Some(model), cx);
757 }
758 }
759 });
760 }
761 }
762
763 fn handle_context_server_store_updated(
764 &mut self,
765 store: Entity<project::context_server_store::ContextServerStore>,
766 _event: &project::context_server_store::ServerStatusChangedEvent,
767 cx: &mut Context<Self>,
768 ) {
769 let project_id = self.projects.iter().find_map(|(id, state)| {
770 if *state.context_server_registry.read(cx).server_store() == store {
771 Some(*id)
772 } else {
773 None
774 }
775 });
776 if let Some(project_id) = project_id {
777 self.update_available_commands_for_project(project_id, cx);
778 }
779 }
780
781 fn handle_context_server_registry_event(
782 &mut self,
783 registry: Entity<ContextServerRegistry>,
784 event: &ContextServerRegistryEvent,
785 cx: &mut Context<Self>,
786 ) {
787 match event {
788 ContextServerRegistryEvent::ToolsChanged => {}
789 ContextServerRegistryEvent::PromptsChanged => {
790 let project_id = self.projects.iter().find_map(|(id, state)| {
791 if state.context_server_registry == registry {
792 Some(*id)
793 } else {
794 None
795 }
796 });
797 if let Some(project_id) = project_id {
798 self.update_available_commands_for_project(project_id, cx);
799 }
800 }
801 }
802 }
803
804 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
805 let available_commands =
806 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
807 for session in self.sessions.values() {
808 if session.project_id != project_id {
809 continue;
810 }
811 session.acp_thread.update(cx, |thread, cx| {
812 thread
813 .handle_session_update(
814 acp::SessionUpdate::AvailableCommandsUpdate(
815 acp::AvailableCommandsUpdate::new(available_commands.clone()),
816 ),
817 cx,
818 )
819 .log_err();
820 });
821 }
822 }
823
824 fn build_available_commands_for_project(
825 project_state: Option<&ProjectState>,
826 cx: &App,
827 ) -> Vec<acp::AvailableCommand> {
828 let Some(state) = project_state else {
829 return vec![];
830 };
831 let registry = state.context_server_registry.read(cx);
832
833 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
834 for context_server_prompt in registry.prompts() {
835 *prompt_name_counts
836 .entry(context_server_prompt.prompt.name.as_str())
837 .or_insert(0) += 1;
838 }
839
840 registry
841 .prompts()
842 .flat_map(|context_server_prompt| {
843 let prompt = &context_server_prompt.prompt;
844
845 let should_prefix = prompt_name_counts
846 .get(prompt.name.as_str())
847 .copied()
848 .unwrap_or(0)
849 > 1;
850
851 let name = if should_prefix {
852 format!("{}.{}", context_server_prompt.server_id, prompt.name)
853 } else {
854 prompt.name.clone()
855 };
856
857 let mut command = acp::AvailableCommand::new(
858 name,
859 prompt.description.clone().unwrap_or_default(),
860 );
861
862 match prompt.arguments.as_deref() {
863 Some([arg]) => {
864 let hint = format!("<{}>", arg.name);
865
866 command = command.input(acp::AvailableCommandInput::Unstructured(
867 acp::UnstructuredCommandInput::new(hint),
868 ));
869 }
870 Some([]) | None => {}
871 Some(_) => {
872 // skip >1 argument commands since we don't support them yet
873 return None;
874 }
875 }
876
877 Some(command)
878 })
879 .collect()
880 }
881
882 pub fn load_thread(
883 &mut self,
884 id: acp::SessionId,
885 project: Entity<Project>,
886 cx: &mut Context<Self>,
887 ) -> Task<Result<Entity<Thread>>> {
888 let database_future = ThreadsDatabase::connect(cx);
889 cx.spawn(async move |this, cx| {
890 let database = database_future.await.map_err(|err| anyhow!(err))?;
891 let db_thread = database
892 .load_thread(id.clone())
893 .await?
894 .with_context(|| format!("no thread found with ID: {id:?}"))?;
895
896 this.update(cx, |this, cx| {
897 let project_id = this.get_or_create_project_state(&project, cx);
898 let project_state = this
899 .projects
900 .get(&project_id)
901 .context("project state not found")?;
902 let summarization_model = LanguageModelRegistry::read_global(cx)
903 .thread_summary_model()
904 .map(|c| c.model);
905
906 Ok(cx.new(|cx| {
907 let mut thread = Thread::from_db(
908 id.clone(),
909 db_thread,
910 project_state.project.clone(),
911 project_state.project_context.clone(),
912 project_state.context_server_registry.clone(),
913 this.templates.clone(),
914 cx,
915 );
916 thread.set_summarization_model(summarization_model, cx);
917 thread
918 }))
919 })?
920 })
921 }
922
923 pub fn open_thread(
924 &mut self,
925 id: acp::SessionId,
926 project: Entity<Project>,
927 cx: &mut Context<Self>,
928 ) -> Task<Result<Entity<AcpThread>>> {
929 if let Some(session) = self.sessions.get(&id) {
930 return Task::ready(Ok(session.acp_thread.clone()));
931 }
932
933 let task = self.load_thread(id, project.clone(), cx);
934 cx.spawn(async move |this, cx| {
935 let thread = task.await?;
936 let acp_thread = this.update(cx, |this, cx| {
937 let project_id = this.get_or_create_project_state(&project, cx);
938 this.register_session(thread.clone(), project_id, cx)
939 })?;
940 let events = thread.update(cx, |thread, cx| thread.replay(cx));
941 cx.update(|cx| {
942 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
943 })
944 .await?;
945 acp_thread.update(cx, |thread, cx| {
946 thread.snapshot_completed_plan(cx);
947 });
948 Ok(acp_thread)
949 })
950 }
951
952 pub fn thread_summary(
953 &mut self,
954 id: acp::SessionId,
955 project: Entity<Project>,
956 cx: &mut Context<Self>,
957 ) -> Task<Result<SharedString>> {
958 let thread = self.open_thread(id.clone(), project, cx);
959 cx.spawn(async move |this, cx| {
960 let acp_thread = thread.await?;
961 let result = this
962 .update(cx, |this, cx| {
963 this.sessions
964 .get(&id)
965 .unwrap()
966 .thread
967 .update(cx, |thread, cx| thread.summary(cx))
968 })?
969 .await
970 .context("Failed to generate summary")?;
971 drop(acp_thread);
972 Ok(result)
973 })
974 }
975
976 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
977 let id = thread.read(cx).id().clone();
978 let Some(session) = self.sessions.get_mut(&id) else {
979 return;
980 };
981
982 let has_draft_prompt = session
983 .acp_thread
984 .read(cx)
985 .draft_prompt()
986 .is_some_and(|p| !p.is_empty());
987 if thread.read(cx).is_empty() && !has_draft_prompt {
988 return;
989 }
990
991 let project_id = session.project_id;
992 let Some(state) = self.projects.get(&project_id) else {
993 return;
994 };
995
996 let folder_paths = PathList::new(
997 &state
998 .project
999 .read(cx)
1000 .visible_worktrees(cx)
1001 .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
1002 .collect::<Vec<_>>(),
1003 );
1004
1005 let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
1006 let database_future = ThreadsDatabase::connect(cx);
1007 let db_thread = thread.update(cx, |thread, cx| {
1008 thread.set_draft_prompt(draft_prompt);
1009 thread.to_db(cx)
1010 });
1011 let thread_store = self.thread_store.clone();
1012 session.pending_save = cx.spawn(async move |_, cx| {
1013 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1014 return Ok(());
1015 };
1016 let db_thread = db_thread.await;
1017 database
1018 .save_thread(id, db_thread, folder_paths)
1019 .await
1020 .log_err();
1021 thread_store.update(cx, |store, cx| store.reload(cx));
1022 Ok(())
1023 });
1024 }
1025
    /// Fetches a prompt from a context server, appends its messages to the
    /// session's threads, and then runs the thread.
    ///
    /// Each prompt message is mirrored into both the ACP thread and the
    /// internal thread. The internal thread is then started with
    /// `send_existing` when the last prompt message came from the user (or the
    /// prompt had no messages), and with `resume` otherwise.
    ///
    /// The task fails immediately when the session has no project state, and
    /// later when the prompt cannot be fetched, the session has disappeared,
    /// or the thread cannot be started.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            // Look the session up again: it may have been removed while the
            // prompt was being fetched.
            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Starts as `true` so that a prompt without messages takes the
            // `send_existing` path below.
            let mut last_is_user = true;

            // Record the original user content minus its first block.
            // NOTE(review): the first block is presumably the command
            // invocation itself; confirm against the caller.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Every user message from the prompt gets its own message id.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            // Forward the thread's events to the ACP thread; the result of
            // that forwarding is the task's output.
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1128}
1129
1130/// Wrapper struct that implements the AgentConnection trait
1131#[derive(Clone)]
1132pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1133
1134impl NativeAgentConnection {
1135 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1136 self.0
1137 .read(cx)
1138 .sessions
1139 .get(session_id)
1140 .map(|session| session.thread.clone())
1141 }
1142
1143 pub fn load_thread(
1144 &self,
1145 id: acp::SessionId,
1146 project: Entity<Project>,
1147 cx: &mut App,
1148 ) -> Task<Result<Entity<Thread>>> {
1149 self.0
1150 .update(cx, |this, cx| this.load_thread(id, project, cx))
1151 }
1152
1153 fn run_turn(
1154 &self,
1155 session_id: acp::SessionId,
1156 cx: &mut App,
1157 f: impl 'static
1158 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1159 ) -> Task<Result<acp::PromptResponse>> {
1160 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1161 agent
1162 .sessions
1163 .get_mut(&session_id)
1164 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1165 }) else {
1166 return Task::ready(Err(anyhow!("Session not found")));
1167 };
1168 log::debug!("Found session for: {}", session_id);
1169
1170 let response_stream = match f(thread, cx) {
1171 Ok(stream) => stream,
1172 Err(err) => return Task::ready(Err(err)),
1173 };
1174 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1175 }
1176
1177 fn handle_thread_events(
1178 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1179 acp_thread: WeakEntity<AcpThread>,
1180 cx: &App,
1181 ) -> Task<Result<acp::PromptResponse>> {
1182 cx.spawn(async move |cx| {
1183 // Handle response stream and forward to session.acp_thread
1184 while let Some(result) = events.next().await {
1185 match result {
1186 Ok(event) => {
1187 log::trace!("Received completion event: {:?}", event);
1188
1189 match event {
1190 ThreadEvent::UserMessage(message) => {
1191 acp_thread.update(cx, |thread, cx| {
1192 for content in message.content {
1193 thread.push_user_content_block(
1194 Some(message.id.clone()),
1195 content.into(),
1196 cx,
1197 );
1198 }
1199 })?;
1200 }
1201 ThreadEvent::AgentText(text) => {
1202 acp_thread.update(cx, |thread, cx| {
1203 thread.push_assistant_content_block(text.into(), false, cx)
1204 })?;
1205 }
1206 ThreadEvent::AgentThinking(text) => {
1207 acp_thread.update(cx, |thread, cx| {
1208 thread.push_assistant_content_block(text.into(), true, cx)
1209 })?;
1210 }
1211 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1212 tool_call,
1213 options,
1214 response,
1215 context: _,
1216 }) => {
1217 let outcome_task = acp_thread.update(cx, |thread, cx| {
1218 thread.request_tool_call_authorization(tool_call, options, cx)
1219 })??;
1220 cx.background_spawn(async move {
1221 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1222 outcome_task.await
1223 {
1224 response
1225 .send(outcome)
1226 .map(|_| anyhow!("authorization receiver was dropped"))
1227 .log_err();
1228 }
1229 })
1230 .detach();
1231 }
1232 ThreadEvent::ToolCall(tool_call) => {
1233 acp_thread.update(cx, |thread, cx| {
1234 thread.upsert_tool_call(tool_call, cx)
1235 })??;
1236 }
1237 ThreadEvent::ToolCallUpdate(update) => {
1238 acp_thread.update(cx, |thread, cx| {
1239 thread.update_tool_call(update, cx)
1240 })??;
1241 }
1242 ThreadEvent::Plan(plan) => {
1243 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1244 }
1245 ThreadEvent::SubagentSpawned(session_id) => {
1246 acp_thread.update(cx, |thread, cx| {
1247 thread.subagent_spawned(session_id, cx);
1248 })?;
1249 }
1250 ThreadEvent::Retry(status) => {
1251 acp_thread.update(cx, |thread, cx| {
1252 thread.update_retry_status(status, cx)
1253 })?;
1254 }
1255 ThreadEvent::Stop(stop_reason) => {
1256 log::debug!("Assistant message complete: {:?}", stop_reason);
1257 return Ok(acp::PromptResponse::new(stop_reason));
1258 }
1259 }
1260 }
1261 Err(e) => {
1262 log::error!("Error in model response stream: {:?}", e);
1263 return Err(e);
1264 }
1265 }
1266 }
1267
1268 log::debug!("Response stream completed");
1269 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1270 })
1271 }
1272}
1273
1274struct Command<'a> {
1275 prompt_name: &'a str,
1276 arg_value: &'a str,
1277 explicit_server_id: Option<&'a str>,
1278}
1279
1280impl<'a> Command<'a> {
1281 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1282 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1283 return None;
1284 };
1285 let text = text_content.text.trim();
1286 let command = text.strip_prefix('/')?;
1287 let (command, arg_value) = command
1288 .split_once(char::is_whitespace)
1289 .unwrap_or((command, ""));
1290
1291 if let Some((server_id, prompt_name)) = command.split_once('.') {
1292 Some(Self {
1293 prompt_name,
1294 arg_value,
1295 explicit_server_id: Some(server_id),
1296 })
1297 } else {
1298 Some(Self {
1299 prompt_name: command,
1300 arg_value,
1301 explicit_server_id: None,
1302 })
1303 }
1304 }
1305}
1306
1307struct NativeAgentModelSelector {
1308 session_id: acp::SessionId,
1309 connection: NativeAgentConnection,
1310}
1311
1312impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1313 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1314 log::debug!("NativeAgentConnection::list_models called");
1315 let list = self.connection.0.read(cx).models.model_list.clone();
1316 Task::ready(if list.is_empty() {
1317 Err(anyhow::anyhow!("No models available"))
1318 } else {
1319 Ok(list)
1320 })
1321 }
1322
1323 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1324 log::debug!(
1325 "Setting model for session {}: {}",
1326 self.session_id,
1327 model_id
1328 );
1329 let Some(thread) = self
1330 .connection
1331 .0
1332 .read(cx)
1333 .sessions
1334 .get(&self.session_id)
1335 .map(|session| session.thread.clone())
1336 else {
1337 return Task::ready(Err(anyhow!("Session not found")));
1338 };
1339
1340 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1341 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1342 };
1343
1344 // We want to reset the effort level when switching models, as the currently-selected effort level may
1345 // not be compatible.
1346 let effort = model
1347 .default_effort_level()
1348 .map(|effort_level| effort_level.value.to_string());
1349
1350 thread.update(cx, |thread, cx| {
1351 thread.set_model(model.clone(), cx);
1352 thread.set_thinking_effort(effort.clone(), cx);
1353 thread.set_thinking_enabled(model.supports_thinking(), cx);
1354 });
1355
1356 update_settings_file(
1357 self.connection.0.read(cx).fs.clone(),
1358 cx,
1359 move |settings, cx| {
1360 let provider = model.provider_id().0.to_string();
1361 let model = model.id().0.to_string();
1362 let enable_thinking = thread.read(cx).thinking_enabled();
1363 settings
1364 .agent
1365 .get_or_insert_default()
1366 .set_model(LanguageModelSelection {
1367 provider: provider.into(),
1368 model,
1369 enable_thinking,
1370 effort,
1371 });
1372 },
1373 );
1374
1375 Task::ready(Ok(()))
1376 }
1377
1378 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1379 let Some(thread) = self
1380 .connection
1381 .0
1382 .read(cx)
1383 .sessions
1384 .get(&self.session_id)
1385 .map(|session| session.thread.clone())
1386 else {
1387 return Task::ready(Err(anyhow!("Session not found")));
1388 };
1389 let Some(model) = thread.read(cx).model() else {
1390 return Task::ready(Err(anyhow!("Model not found")));
1391 };
1392 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1393 else {
1394 return Task::ready(Err(anyhow!("Provider not found")));
1395 };
1396 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1397 model, &provider,
1398 )))
1399 }
1400
1401 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1402 Some(self.connection.0.read(cx).models.watch())
1403 }
1404
1405 fn should_render_footer(&self) -> bool {
1406 true
1407 }
1408}
1409
/// Agent id of the built-in agent ("Zed Agent"), created on first access.
/// `NativeAgentConnection::agent_id` returns a clone of this value.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1411
1412impl acp_thread::AgentConnection for NativeAgentConnection {
1413 fn agent_id(&self) -> AgentId {
1414 ZED_AGENT_ID.clone()
1415 }
1416
1417 fn telemetry_id(&self) -> SharedString {
1418 "zed".into()
1419 }
1420
1421 fn new_session(
1422 self: Rc<Self>,
1423 project: Entity<Project>,
1424 work_dirs: PathList,
1425 cx: &mut App,
1426 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1427 log::debug!("Creating new thread for project at: {work_dirs:?}");
1428 Task::ready(Ok(self
1429 .0
1430 .update(cx, |agent, cx| agent.new_session(project, cx))))
1431 }
1432
1433 fn supports_load_session(&self) -> bool {
1434 true
1435 }
1436
1437 fn load_session(
1438 self: Rc<Self>,
1439 session_id: acp::SessionId,
1440 project: Entity<Project>,
1441 _work_dirs: PathList,
1442 _title: Option<SharedString>,
1443 cx: &mut App,
1444 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1445 self.0
1446 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1447 }
1448
1449 fn supports_close_session(&self) -> bool {
1450 true
1451 }
1452
1453 fn close_session(
1454 self: Rc<Self>,
1455 session_id: &acp::SessionId,
1456 cx: &mut App,
1457 ) -> Task<Result<()>> {
1458 self.0.update(cx, |agent, cx| {
1459 let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
1460 if let Some(thread) = thread {
1461 agent.save_thread(thread, cx);
1462 }
1463
1464 let Some(session) = agent.sessions.remove(session_id) else {
1465 return Task::ready(Ok(()));
1466 };
1467 let project_id = session.project_id;
1468
1469 let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1470 if !has_remaining {
1471 agent.projects.remove(&project_id);
1472 }
1473
1474 session.pending_save
1475 })
1476 }
1477
1478 fn auth_methods(&self) -> &[acp::AuthMethod] {
1479 &[] // No auth for in-process
1480 }
1481
1482 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1483 Task::ready(Ok(()))
1484 }
1485
1486 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1487 Some(Rc::new(NativeAgentModelSelector {
1488 session_id: session_id.clone(),
1489 connection: self.clone(),
1490 }) as Rc<dyn AgentModelSelector>)
1491 }
1492
1493 fn prompt(
1494 &self,
1495 id: Option<acp_thread::UserMessageId>,
1496 params: acp::PromptRequest,
1497 cx: &mut App,
1498 ) -> Task<Result<acp::PromptResponse>> {
1499 let id = id.expect("UserMessageId is required");
1500 let session_id = params.session_id.clone();
1501 log::info!("Received prompt request for session: {}", session_id);
1502 log::debug!("Prompt blocks count: {}", params.prompt.len());
1503
1504 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1505 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1506 };
1507
1508 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1509 let registry = project_state.context_server_registry.read(cx);
1510
1511 let explicit_server_id = parsed_command
1512 .explicit_server_id
1513 .map(|server_id| ContextServerId(server_id.into()));
1514
1515 if let Some(prompt) =
1516 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1517 {
1518 let arguments = if !parsed_command.arg_value.is_empty()
1519 && let Some(arg_name) = prompt
1520 .prompt
1521 .arguments
1522 .as_ref()
1523 .and_then(|args| args.first())
1524 .map(|arg| arg.name.clone())
1525 {
1526 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1527 } else {
1528 Default::default()
1529 };
1530
1531 let prompt_name = prompt.prompt.name.clone();
1532 let server_id = prompt.server_id.clone();
1533
1534 return self.0.update(cx, |agent, cx| {
1535 agent.send_mcp_prompt(
1536 id,
1537 session_id.clone(),
1538 prompt_name,
1539 server_id,
1540 arguments,
1541 params.prompt,
1542 cx,
1543 )
1544 });
1545 }
1546 };
1547
1548 let path_style = project_state.project.read(cx).path_style(cx);
1549
1550 self.run_turn(session_id, cx, move |thread, cx| {
1551 let content: Vec<UserMessageContent> = params
1552 .prompt
1553 .into_iter()
1554 .map(|block| UserMessageContent::from_content_block(block, path_style))
1555 .collect::<Vec<_>>();
1556 log::debug!("Converted prompt to message: {} chars", content.len());
1557 log::debug!("Message id: {:?}", id);
1558 log::debug!("Message content: {:?}", content);
1559
1560 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1561 })
1562 }
1563
1564 fn retry(
1565 &self,
1566 session_id: &acp::SessionId,
1567 _cx: &App,
1568 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1569 Some(Rc::new(NativeAgentSessionRetry {
1570 connection: self.clone(),
1571 session_id: session_id.clone(),
1572 }) as _)
1573 }
1574
1575 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1576 log::info!("Cancelling on session: {}", session_id);
1577 self.0.update(cx, |agent, cx| {
1578 if let Some(session) = agent.sessions.get(session_id) {
1579 session
1580 .thread
1581 .update(cx, |thread, cx| thread.cancel(cx))
1582 .detach();
1583 }
1584 });
1585 }
1586
1587 fn truncate(
1588 &self,
1589 session_id: &acp::SessionId,
1590 cx: &App,
1591 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1592 self.0.read_with(cx, |agent, _cx| {
1593 agent.sessions.get(session_id).map(|session| {
1594 Rc::new(NativeAgentSessionTruncate {
1595 thread: session.thread.clone(),
1596 acp_thread: session.acp_thread.downgrade(),
1597 }) as _
1598 })
1599 })
1600 }
1601
1602 fn set_title(
1603 &self,
1604 session_id: &acp::SessionId,
1605 cx: &App,
1606 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1607 self.0.read_with(cx, |agent, _cx| {
1608 agent
1609 .sessions
1610 .get(session_id)
1611 .filter(|s| !s.thread.read(cx).is_subagent())
1612 .map(|session| {
1613 Rc::new(NativeAgentSessionSetTitle {
1614 thread: session.thread.clone(),
1615 }) as _
1616 })
1617 })
1618 }
1619
1620 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1621 let thread_store = self.0.read(cx).thread_store.clone();
1622 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1623 }
1624
1625 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1626 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1627 }
1628
1629 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1630 self
1631 }
1632}
1633
1634impl acp_thread::AgentTelemetry for NativeAgentConnection {
1635 fn thread_data(
1636 &self,
1637 session_id: &acp::SessionId,
1638 cx: &mut App,
1639 ) -> Task<Result<serde_json::Value>> {
1640 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1641 return Task::ready(Err(anyhow!("Session not found")));
1642 };
1643
1644 let task = session.thread.read(cx).to_db(cx);
1645 cx.background_spawn(async move {
1646 serde_json::to_value(task.await).context("Failed to serialize thread")
1647 })
1648 }
1649}
1650
1651pub struct NativeAgentSessionList {
1652 thread_store: Entity<ThreadStore>,
1653 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1654 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1655 _subscription: Subscription,
1656}
1657
1658impl NativeAgentSessionList {
1659 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1660 let (tx, rx) = smol::channel::unbounded();
1661 let this_tx = tx.clone();
1662 let subscription = cx.observe(&thread_store, move |_, _| {
1663 this_tx
1664 .try_send(acp_thread::SessionListUpdate::Refresh)
1665 .ok();
1666 });
1667 Self {
1668 thread_store,
1669 updates_tx: tx,
1670 updates_rx: rx,
1671 _subscription: subscription,
1672 }
1673 }
1674
1675 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1676 &self.thread_store
1677 }
1678}
1679
1680impl AgentSessionList for NativeAgentSessionList {
1681 fn list_sessions(
1682 &self,
1683 _request: AgentSessionListRequest,
1684 cx: &mut App,
1685 ) -> Task<Result<AgentSessionListResponse>> {
1686 let sessions = self
1687 .thread_store
1688 .read(cx)
1689 .entries()
1690 .map(|entry| AgentSessionInfo::from(&entry))
1691 .collect();
1692 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1693 }
1694
1695 fn supports_delete(&self) -> bool {
1696 true
1697 }
1698
1699 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1700 self.thread_store
1701 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1702 }
1703
1704 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1705 self.thread_store
1706 .update(cx, |store, cx| store.delete_threads(cx))
1707 }
1708
1709 fn watch(
1710 &self,
1711 _cx: &mut App,
1712 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1713 Some(self.updates_rx.clone())
1714 }
1715
1716 fn notify_refresh(&self) {
1717 self.updates_tx
1718 .try_send(acp_thread::SessionListUpdate::Refresh)
1719 .ok();
1720 }
1721
1722 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1723 self
1724 }
1725}
1726
1727struct NativeAgentSessionTruncate {
1728 thread: Entity<Thread>,
1729 acp_thread: WeakEntity<AcpThread>,
1730}
1731
1732impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1733 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1734 match self.thread.update(cx, |thread, cx| {
1735 thread.truncate(message_id.clone(), cx)?;
1736 Ok(thread.latest_token_usage())
1737 }) {
1738 Ok(usage) => {
1739 self.acp_thread
1740 .update(cx, |thread, cx| {
1741 thread.update_token_usage(usage, cx);
1742 })
1743 .ok();
1744 Task::ready(Ok(()))
1745 }
1746 Err(error) => Task::ready(Err(error)),
1747 }
1748 }
1749}
1750
1751struct NativeAgentSessionRetry {
1752 connection: NativeAgentConnection,
1753 session_id: acp::SessionId,
1754}
1755
1756impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1757 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1758 self.connection
1759 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1760 thread.update(cx, |thread, cx| thread.resume(cx))
1761 })
1762 }
1763}
1764
1765struct NativeAgentSessionSetTitle {
1766 thread: Entity<Thread>,
1767}
1768
1769impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1770 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1771 self.thread
1772 .update(cx, |thread, cx| thread.set_title(title, cx));
1773 Task::ready(Ok(()))
1774 }
1775}
1776
1777pub struct NativeThreadEnvironment {
1778 agent: WeakEntity<NativeAgent>,
1779 thread: WeakEntity<Thread>,
1780 acp_thread: WeakEntity<AcpThread>,
1781}
1782
1783impl NativeThreadEnvironment {
1784 pub(crate) fn create_subagent_thread(
1785 &self,
1786 label: String,
1787 cx: &mut App,
1788 ) -> Result<Rc<dyn SubagentHandle>> {
1789 let Some(parent_thread_entity) = self.thread.upgrade() else {
1790 anyhow::bail!("Parent thread no longer exists".to_string());
1791 };
1792 let parent_thread = parent_thread_entity.read(cx);
1793 let current_depth = parent_thread.depth();
1794 let parent_session_id = parent_thread.id().clone();
1795
1796 if current_depth >= MAX_SUBAGENT_DEPTH {
1797 return Err(anyhow!(
1798 "Maximum subagent depth ({}) reached",
1799 MAX_SUBAGENT_DEPTH
1800 ));
1801 }
1802
1803 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1804 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1805 thread.set_title(label.into(), cx);
1806 thread
1807 });
1808
1809 let session_id = subagent_thread.read(cx).id().clone();
1810
1811 let acp_thread = self
1812 .agent
1813 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1814 let project_id = agent
1815 .sessions
1816 .get(&parent_session_id)
1817 .map(|s| s.project_id)
1818 .context("parent session not found")?;
1819 Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1820 })??;
1821
1822 let depth = current_depth + 1;
1823
1824 telemetry::event!(
1825 "Subagent Started",
1826 session = parent_thread_entity.read(cx).id().to_string(),
1827 subagent_session = session_id.to_string(),
1828 depth,
1829 is_resumed = false,
1830 );
1831
1832 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1833 }
1834
1835 pub(crate) fn resume_subagent_thread(
1836 &self,
1837 session_id: acp::SessionId,
1838 cx: &mut App,
1839 ) -> Result<Rc<dyn SubagentHandle>> {
1840 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1841 let session = agent
1842 .sessions
1843 .get(&session_id)
1844 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1845 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1846 })??;
1847
1848 let depth = subagent_thread.read(cx).depth();
1849
1850 if let Some(parent_thread_entity) = self.thread.upgrade() {
1851 telemetry::event!(
1852 "Subagent Started",
1853 session = parent_thread_entity.read(cx).id().to_string(),
1854 subagent_session = session_id.to_string(),
1855 depth,
1856 is_resumed = true,
1857 );
1858 }
1859
1860 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1861 }
1862
1863 fn prompt_subagent(
1864 &self,
1865 session_id: acp::SessionId,
1866 subagent_thread: Entity<Thread>,
1867 acp_thread: Entity<acp_thread::AcpThread>,
1868 ) -> Result<Rc<dyn SubagentHandle>> {
1869 let Some(parent_thread_entity) = self.thread.upgrade() else {
1870 anyhow::bail!("Parent thread no longer exists".to_string());
1871 };
1872 Ok(Rc::new(NativeSubagentHandle::new(
1873 session_id,
1874 subagent_thread,
1875 acp_thread,
1876 parent_thread_entity,
1877 )) as _)
1878 }
1879}
1880
1881impl ThreadEnvironment for NativeThreadEnvironment {
1882 fn create_terminal(
1883 &self,
1884 command: String,
1885 cwd: Option<PathBuf>,
1886 output_byte_limit: Option<u64>,
1887 cx: &mut AsyncApp,
1888 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1889 let task = self.acp_thread.update(cx, |thread, cx| {
1890 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1891 });
1892
1893 let acp_thread = self.acp_thread.clone();
1894 cx.spawn(async move |cx| {
1895 let terminal = task?.await?;
1896
1897 let (drop_tx, drop_rx) = oneshot::channel();
1898 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1899
1900 cx.spawn(async move |cx| {
1901 drop_rx.await.ok();
1902 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1903 })
1904 .detach();
1905
1906 let handle = AcpTerminalHandle {
1907 terminal,
1908 _drop_tx: Some(drop_tx),
1909 };
1910
1911 Ok(Rc::new(handle) as _)
1912 })
1913 }
1914
1915 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1916 self.create_subagent_thread(label, cx)
1917 }
1918
1919 fn resume_subagent(
1920 &self,
1921 session_id: acp::SessionId,
1922 cx: &mut App,
1923 ) -> Result<Rc<dyn SubagentHandle>> {
1924 self.resume_subagent_thread(session_id, cx)
1925 }
1926}
1927
/// Outcome of waiting on a single subagent prompt in
/// `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The prompt finished with `EndTurn` or another stop reason that is not
    /// handled separately.
    Completed,
    /// The prompt stopped with `StopReason::Cancelled`.
    Cancelled,
    /// The token-limit signal fired before the prompt finished.
    ContextWindowWarning,
    /// The prompt failed or stopped abnormally; carries the message that is
    /// turned into the error returned to the caller.
    Error(String),
}
1935
1936pub struct NativeSubagentHandle {
1937 session_id: acp::SessionId,
1938 parent_thread: WeakEntity<Thread>,
1939 subagent_thread: Entity<Thread>,
1940 acp_thread: Entity<acp_thread::AcpThread>,
1941}
1942
1943impl NativeSubagentHandle {
1944 fn new(
1945 session_id: acp::SessionId,
1946 subagent_thread: Entity<Thread>,
1947 acp_thread: Entity<acp_thread::AcpThread>,
1948 parent_thread_entity: Entity<Thread>,
1949 ) -> Self {
1950 NativeSubagentHandle {
1951 session_id,
1952 subagent_thread,
1953 parent_thread: parent_thread_entity.downgrade(),
1954 acp_thread,
1955 }
1956 }
1957}
1958
impl SubagentHandle for NativeSubagentHandle {
    /// Session id of the subagent.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Number of entries currently in the subagent's `AcpThread`.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves to the text of its last
    /// agent message.
    ///
    /// The subagent is registered as running on the parent thread for the
    /// duration of the prompt. The prompt is raced against a token-usage
    /// signal: if usage moves from `Normal` to `Warning` while the prompt is
    /// running, the subagent is cancelled and an error is returned.
    ///
    /// Errors are also returned when the prompt is cancelled, stops for
    /// `MaxTokens`, `MaxTurnRequests` or `Refusal`, produces no response, or
    /// when the last message contains no agent text.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is kept in scope so the token-usage subscription
            // stays alive until this task finishes.
            let (task, _subscription) = cx.update(|cx| {
                // Usage ratio before this prompt; compared with later updates
                // to detect a Normal -> Warning transition.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                // Best effort: the parent thread may already be gone.
                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                // Wrapped in `Option` so the signal is sent at most once.
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Whichever finishes first decides the result: the prompt
                // itself or the token-limit signal.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // On success, return the text parts of the last message joined
                // by blank lines; an empty result counts as "no response".
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent before reporting the warning.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // Unregister regardless of the outcome; ignored if the parent is gone.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
2079
2080pub struct AcpTerminalHandle {
2081 terminal: Entity<acp_thread::Terminal>,
2082 _drop_tx: Option<oneshot::Sender<()>>,
2083}
2084
2085impl TerminalHandle for AcpTerminalHandle {
2086 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2087 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2088 }
2089
2090 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2091 Ok(self
2092 .terminal
2093 .read_with(cx, |term, _cx| term.wait_for_exit()))
2094 }
2095
2096 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2097 Ok(self
2098 .terminal
2099 .read_with(cx, |term, cx| term.current_output(cx)))
2100 }
2101
2102 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2103 cx.update(|cx| {
2104 self.terminal.update(cx, |terminal, cx| {
2105 terminal.kill(cx);
2106 });
2107 });
2108 Ok(())
2109 }
2110
2111 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2112 Ok(self
2113 .terminal
2114 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2115 }
2116}
2117
2118#[cfg(test)]
2119mod internal_tests {
2120 use std::path::Path;
2121
2122 use super::*;
2123 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2124 use fs::FakeFs;
2125 use gpui::TestAppContext;
2126 use indoc::formatdoc;
2127 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2128 use language_model::{
2129 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2130 };
2131 use serde_json::json;
2132 use settings::SettingsStore;
2133 use util::{path, rel_path::rel_path};
2134
    /// Checks that the agent's per-project context, and the context seen by a
    /// session's thread, follow changes to the project: first no worktrees,
    /// then a new worktree, then a `.rules` file inside that worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts without any worktrees.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Only one session exists, so the first entry is the one just created.
        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // With no worktrees, both views of the project context are empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding a worktree for `/a` shows up in the project context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            // The file was created empty, so the expected rules text is "".
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2222
2223 #[gpui::test]
2224 async fn test_listing_models(cx: &mut TestAppContext) {
2225 init_test(cx);
2226 let fs = FakeFs::new(cx.executor());
2227 fs.insert_tree("/", json!({ "a": {} })).await;
2228 let project = Project::test(fs.clone(), [], cx).await;
2229 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2230 let connection =
2231 NativeAgentConnection(cx.update(|cx| {
2232 NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2233 }));
2234
2235 // Create a thread/session
2236 let acp_thread = cx
2237 .update(|cx| {
2238 Rc::new(connection.clone()).new_session(
2239 project.clone(),
2240 PathList::new(&[Path::new("/a")]),
2241 cx,
2242 )
2243 })
2244 .await
2245 .unwrap();
2246
2247 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2248
2249 let models = cx
2250 .update(|cx| {
2251 connection
2252 .model_selector(&session_id)
2253 .unwrap()
2254 .list_models(cx)
2255 })
2256 .await
2257 .unwrap();
2258
2259 let acp_thread::AgentModelList::Grouped(models) = models else {
2260 panic!("Unexpected model group");
2261 };
2262 assert_eq!(
2263 models,
2264 IndexMap::from_iter([(
2265 AgentModelGroupName("Fake".into()),
2266 vec![AgentModelInfo {
2267 id: acp::ModelId::new("fake/fake"),
2268 name: "Fake".into(),
2269 description: None,
2270 icon: Some(acp_thread::AgentModelIcon::Named(
2271 ui::IconName::ZedAssistant
2272 )),
2273 is_latest: false,
2274 cost: None,
2275 }]
2276 )])
2277 );
2278 }
2279
2280 #[gpui::test]
2281 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2282 init_test(cx);
2283 let fs = FakeFs::new(cx.executor());
2284 fs.create_dir(paths::settings_file().parent().unwrap())
2285 .await
2286 .unwrap();
2287 fs.insert_file(
2288 paths::settings_file(),
2289 json!({
2290 "agent": {
2291 "default_model": {
2292 "provider": "foo",
2293 "model": "bar"
2294 }
2295 }
2296 })
2297 .to_string()
2298 .into_bytes(),
2299 )
2300 .await;
2301 let project = Project::test(fs.clone(), [], cx).await;
2302
2303 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2304
2305 // Create the agent and connection
2306 let agent =
2307 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2308 let connection = NativeAgentConnection(agent.clone());
2309
2310 // Create a thread/session
2311 let acp_thread = cx
2312 .update(|cx| {
2313 Rc::new(connection.clone()).new_session(
2314 project.clone(),
2315 PathList::new(&[Path::new("/a")]),
2316 cx,
2317 )
2318 })
2319 .await
2320 .unwrap();
2321
2322 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2323
2324 // Select a model
2325 let selector = connection.model_selector(&session_id).unwrap();
2326 let model_id = acp::ModelId::new("fake/fake");
2327 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2328 .await
2329 .unwrap();
2330
2331 // Verify the thread has the selected model
2332 agent.read_with(cx, |agent, _| {
2333 let session = agent.sessions.get(&session_id).unwrap();
2334 session.thread.read_with(cx, |thread, _| {
2335 assert_eq!(thread.model().unwrap().id().0, "fake");
2336 });
2337 });
2338
2339 cx.run_until_parked();
2340
2341 // Verify settings file was updated
2342 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2343 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2344
2345 // Check that the agent settings contain the selected model
2346 assert_eq!(
2347 settings_json["agent"]["default_model"]["model"],
2348 json!("fake")
2349 );
2350 assert_eq!(
2351 settings_json["agent"]["default_model"]["provider"],
2352 json!("fake")
2353 );
2354
2355 // Register a thinking model and select it.
2356 cx.update(|cx| {
2357 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2358 "fake-corp",
2359 "fake-thinking",
2360 "Fake Thinking",
2361 true,
2362 ));
2363 let thinking_provider = Arc::new(
2364 FakeLanguageModelProvider::new(
2365 LanguageModelProviderId::from("fake-corp".to_string()),
2366 LanguageModelProviderName::from("Fake Corp".to_string()),
2367 )
2368 .with_models(vec![thinking_model]),
2369 );
2370 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2371 registry.register_provider(thinking_provider, cx);
2372 });
2373 });
2374 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2375
2376 let selector = connection.model_selector(&session_id).unwrap();
2377 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2378 .await
2379 .unwrap();
2380 cx.run_until_parked();
2381
2382 // Verify enable_thinking was written to settings as true.
2383 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2384 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2385 assert_eq!(
2386 settings_json["agent"]["default_model"]["enable_thinking"],
2387 json!(true),
2388 "selecting a thinking model should persist enable_thinking: true to settings"
2389 );
2390 }
2391
2392 #[gpui::test]
2393 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2394 init_test(cx);
2395 let fs = FakeFs::new(cx.executor());
2396 fs.create_dir(paths::settings_file().parent().unwrap())
2397 .await
2398 .unwrap();
2399 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2400 let project = Project::test(fs.clone(), [], cx).await;
2401
2402 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2403 let agent =
2404 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2405 let connection = NativeAgentConnection(agent.clone());
2406
2407 let acp_thread = cx
2408 .update(|cx| {
2409 Rc::new(connection.clone()).new_session(
2410 project.clone(),
2411 PathList::new(&[Path::new("/a")]),
2412 cx,
2413 )
2414 })
2415 .await
2416 .unwrap();
2417 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2418
2419 // Register a second provider with a thinking model.
2420 cx.update(|cx| {
2421 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2422 "fake-corp",
2423 "fake-thinking",
2424 "Fake Thinking",
2425 true,
2426 ));
2427 let thinking_provider = Arc::new(
2428 FakeLanguageModelProvider::new(
2429 LanguageModelProviderId::from("fake-corp".to_string()),
2430 LanguageModelProviderName::from("Fake Corp".to_string()),
2431 )
2432 .with_models(vec![thinking_model]),
2433 );
2434 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2435 registry.register_provider(thinking_provider, cx);
2436 });
2437 });
2438 // Refresh the agent's model list so it picks up the new provider.
2439 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2440
2441 // Thread starts with thinking_enabled = false (the default).
2442 agent.read_with(cx, |agent, _| {
2443 let session = agent.sessions.get(&session_id).unwrap();
2444 session.thread.read_with(cx, |thread, _| {
2445 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2446 });
2447 });
2448
2449 // Select the thinking model via select_model.
2450 let selector = connection.model_selector(&session_id).unwrap();
2451 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2452 .await
2453 .unwrap();
2454
2455 // select_model should have enabled thinking based on the model's supports_thinking().
2456 agent.read_with(cx, |agent, _| {
2457 let session = agent.sessions.get(&session_id).unwrap();
2458 session.thread.read_with(cx, |thread, _| {
2459 assert!(
2460 thread.thinking_enabled(),
2461 "select_model should enable thinking when model supports it"
2462 );
2463 });
2464 });
2465
2466 // Switch back to the non-thinking model.
2467 let selector = connection.model_selector(&session_id).unwrap();
2468 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2469 .await
2470 .unwrap();
2471
2472 // select_model should have disabled thinking.
2473 agent.read_with(cx, |agent, _| {
2474 let session = agent.sessions.get(&session_id).unwrap();
2475 session.thread.read_with(cx, |thread, _| {
2476 assert!(
2477 !thread.thinking_enabled(),
2478 "select_model should disable thinking when model does not support it"
2479 );
2480 });
2481 });
2482 }
2483
2484 #[gpui::test]
2485 async fn test_summarization_model_survives_transient_registry_clearing(
2486 cx: &mut TestAppContext,
2487 ) {
2488 init_test(cx);
2489 let fs = FakeFs::new(cx.executor());
2490 fs.insert_tree("/", json!({ "a": {} })).await;
2491 let project = Project::test(fs.clone(), [], cx).await;
2492
2493 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2494 let agent =
2495 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2496 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2497
2498 let acp_thread = cx
2499 .update(|cx| {
2500 connection.clone().new_session(
2501 project.clone(),
2502 PathList::new(&[Path::new("/a")]),
2503 cx,
2504 )
2505 })
2506 .await
2507 .unwrap();
2508 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2509
2510 let thread = agent.read_with(cx, |agent, _| {
2511 agent.sessions.get(&session_id).unwrap().thread.clone()
2512 });
2513
2514 thread.read_with(cx, |thread, _| {
2515 assert!(
2516 thread.summarization_model().is_some(),
2517 "session should have a summarization model from the test registry"
2518 );
2519 });
2520
2521 // Simulate what happens during a provider blip:
2522 // update_active_language_model_from_settings calls set_default_model(None)
2523 // when it can't resolve the model, clearing all fallbacks.
2524 cx.update(|cx| {
2525 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2526 registry.set_default_model(None, cx);
2527 });
2528 });
2529 cx.run_until_parked();
2530
2531 thread.read_with(cx, |thread, _| {
2532 assert!(
2533 thread.summarization_model().is_some(),
2534 "summarization model should survive a transient default model clearing"
2535 );
2536 });
2537 }
2538
2539 #[gpui::test]
2540 async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2541 init_test(cx);
2542 let fs = FakeFs::new(cx.executor());
2543 fs.insert_tree("/", json!({ "a": {} })).await;
2544 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2545 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2546 let agent = cx.update(|cx| {
2547 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2548 });
2549 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2550
2551 // Register a thinking model.
2552 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2553 "fake-corp",
2554 "fake-thinking",
2555 "Fake Thinking",
2556 true,
2557 ));
2558 let thinking_provider = Arc::new(
2559 FakeLanguageModelProvider::new(
2560 LanguageModelProviderId::from("fake-corp".to_string()),
2561 LanguageModelProviderName::from("Fake Corp".to_string()),
2562 )
2563 .with_models(vec![thinking_model.clone()]),
2564 );
2565 cx.update(|cx| {
2566 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2567 registry.register_provider(thinking_provider, cx);
2568 });
2569 });
2570 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2571
2572 // Create a thread and select the thinking model.
2573 let acp_thread = cx
2574 .update(|cx| {
2575 connection.clone().new_session(
2576 project.clone(),
2577 PathList::new(&[Path::new("/a")]),
2578 cx,
2579 )
2580 })
2581 .await
2582 .unwrap();
2583 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2584
2585 let selector = connection.model_selector(&session_id).unwrap();
2586 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2587 .await
2588 .unwrap();
2589
2590 // Verify thinking is enabled after selecting the thinking model.
2591 let thread = agent.read_with(cx, |agent, _| {
2592 agent.sessions.get(&session_id).unwrap().thread.clone()
2593 });
2594 thread.read_with(cx, |thread, _| {
2595 assert!(
2596 thread.thinking_enabled(),
2597 "thinking should be enabled after selecting thinking model"
2598 );
2599 });
2600
2601 // Send a message so the thread gets persisted.
2602 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2603 let send = cx.foreground_executor().spawn(send);
2604 cx.run_until_parked();
2605
2606 thinking_model.send_last_completion_stream_text_chunk("Response.");
2607 thinking_model.end_last_completion_stream();
2608
2609 send.await.unwrap();
2610 cx.run_until_parked();
2611
2612 // Close the session so it can be reloaded from disk.
2613 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2614 .await
2615 .unwrap();
2616 drop(thread);
2617 drop(acp_thread);
2618 agent.read_with(cx, |agent, _| {
2619 assert!(agent.sessions.is_empty());
2620 });
2621
2622 // Reload the thread and verify thinking_enabled is still true.
2623 let reloaded_acp_thread = agent
2624 .update(cx, |agent, cx| {
2625 agent.open_thread(session_id.clone(), project.clone(), cx)
2626 })
2627 .await
2628 .unwrap();
2629 let reloaded_thread = agent.read_with(cx, |agent, _| {
2630 agent.sessions.get(&session_id).unwrap().thread.clone()
2631 });
2632 reloaded_thread.read_with(cx, |thread, _| {
2633 assert!(
2634 thread.thinking_enabled(),
2635 "thinking_enabled should be preserved when reloading a thread with a thinking model"
2636 );
2637 });
2638
2639 drop(reloaded_acp_thread);
2640 }
2641
2642 #[gpui::test]
2643 async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2644 init_test(cx);
2645 let fs = FakeFs::new(cx.executor());
2646 fs.insert_tree("/", json!({ "a": {} })).await;
2647 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2648 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2649 let agent = cx.update(|cx| {
2650 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2651 });
2652 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2653
2654 // Register a model where id() != name(), like real Anthropic models
2655 // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2656 let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2657 "fake-corp",
2658 "custom-model-id",
2659 "Custom Model Display Name",
2660 false,
2661 ));
2662 let provider = Arc::new(
2663 FakeLanguageModelProvider::new(
2664 LanguageModelProviderId::from("fake-corp".to_string()),
2665 LanguageModelProviderName::from("Fake Corp".to_string()),
2666 )
2667 .with_models(vec![model.clone()]),
2668 );
2669 cx.update(|cx| {
2670 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2671 registry.register_provider(provider, cx);
2672 });
2673 });
2674 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2675
2676 // Create a thread and select the model.
2677 let acp_thread = cx
2678 .update(|cx| {
2679 connection.clone().new_session(
2680 project.clone(),
2681 PathList::new(&[Path::new("/a")]),
2682 cx,
2683 )
2684 })
2685 .await
2686 .unwrap();
2687 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2688
2689 let selector = connection.model_selector(&session_id).unwrap();
2690 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2691 .await
2692 .unwrap();
2693
2694 let thread = agent.read_with(cx, |agent, _| {
2695 agent.sessions.get(&session_id).unwrap().thread.clone()
2696 });
2697 thread.read_with(cx, |thread, _| {
2698 assert_eq!(
2699 thread.model().unwrap().id().0.as_ref(),
2700 "custom-model-id",
2701 "model should be set before persisting"
2702 );
2703 });
2704
2705 // Send a message so the thread gets persisted.
2706 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2707 let send = cx.foreground_executor().spawn(send);
2708 cx.run_until_parked();
2709
2710 model.send_last_completion_stream_text_chunk("Response.");
2711 model.end_last_completion_stream();
2712
2713 send.await.unwrap();
2714 cx.run_until_parked();
2715
2716 // Close the session so it can be reloaded from disk.
2717 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2718 .await
2719 .unwrap();
2720 drop(thread);
2721 drop(acp_thread);
2722 agent.read_with(cx, |agent, _| {
2723 assert!(agent.sessions.is_empty());
2724 });
2725
2726 // Reload the thread and verify the model was preserved.
2727 let reloaded_acp_thread = agent
2728 .update(cx, |agent, cx| {
2729 agent.open_thread(session_id.clone(), project.clone(), cx)
2730 })
2731 .await
2732 .unwrap();
2733 let reloaded_thread = agent.read_with(cx, |agent, _| {
2734 agent.sessions.get(&session_id).unwrap().thread.clone()
2735 });
2736 reloaded_thread.read_with(cx, |thread, _| {
2737 let reloaded_model = thread
2738 .model()
2739 .expect("model should be present after reload");
2740 assert_eq!(
2741 reloaded_model.id().0.as_ref(),
2742 "custom-model-id",
2743 "reloaded thread should have the same model, not fall back to the default"
2744 );
2745 });
2746
2747 drop(reloaded_acp_thread);
2748 }
2749
2750 #[gpui::test]
2751 async fn test_save_load_thread(cx: &mut TestAppContext) {
2752 init_test(cx);
2753 let fs = FakeFs::new(cx.executor());
2754 fs.insert_tree(
2755 "/",
2756 json!({
2757 "a": {
2758 "b.md": "Lorem"
2759 }
2760 }),
2761 )
2762 .await;
2763 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2764 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2765 let agent = cx.update(|cx| {
2766 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2767 });
2768 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2769
2770 let acp_thread = cx
2771 .update(|cx| {
2772 connection
2773 .clone()
2774 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
2775 })
2776 .await
2777 .unwrap();
2778 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2779 let thread = agent.read_with(cx, |agent, _| {
2780 agent.sessions.get(&session_id).unwrap().thread.clone()
2781 });
2782
2783 // Ensure empty threads are not saved, even if they get mutated.
2784 let model = Arc::new(FakeLanguageModel::default());
2785 let summary_model = Arc::new(FakeLanguageModel::default());
2786 thread.update(cx, |thread, cx| {
2787 thread.set_model(model.clone(), cx);
2788 thread.set_summarization_model(Some(summary_model.clone()), cx);
2789 });
2790 cx.run_until_parked();
2791 assert_eq!(thread_entries(&thread_store, cx), vec![]);
2792
2793 let send = acp_thread.update(cx, |thread, cx| {
2794 thread.send(
2795 vec![
2796 "What does ".into(),
2797 acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
2798 "b.md",
2799 MentionUri::File {
2800 abs_path: path!("/a/b.md").into(),
2801 }
2802 .to_uri()
2803 .to_string(),
2804 )),
2805 " mean?".into(),
2806 ],
2807 cx,
2808 )
2809 });
2810 let send = cx.foreground_executor().spawn(send);
2811 cx.run_until_parked();
2812
2813 model.send_last_completion_stream_text_chunk("Lorem.");
2814 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2815 language_model::TokenUsage {
2816 input_tokens: 150,
2817 output_tokens: 75,
2818 ..Default::default()
2819 },
2820 ));
2821 model.end_last_completion_stream();
2822 cx.run_until_parked();
2823 summary_model
2824 .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
2825 summary_model.end_last_completion_stream();
2826
2827 send.await.unwrap();
2828 let uri = MentionUri::File {
2829 abs_path: path!("/a/b.md").into(),
2830 }
2831 .to_uri();
2832 acp_thread.read_with(cx, |thread, cx| {
2833 assert_eq!(
2834 thread.to_markdown(cx),
2835 formatdoc! {"
2836 ## User
2837
2838 What does [@b.md]({uri}) mean?
2839
2840 ## Assistant
2841
2842 Lorem.
2843
2844 "}
2845 )
2846 });
2847
2848 cx.run_until_parked();
2849
2850 // Set a draft prompt with rich content blocks and scroll position
2851 // AFTER run_until_parked, so the only save that captures these
2852 // changes is the one performed by close_session itself.
2853 let draft_blocks = vec![
2854 acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
2855 acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
2856 acp::ContentBlock::Text(acp::TextContent::new(" please")),
2857 ];
2858 acp_thread.update(cx, |thread, _cx| {
2859 thread.set_draft_prompt(Some(draft_blocks.clone()));
2860 });
2861 thread.update(cx, |thread, _cx| {
2862 thread.set_ui_scroll_position(Some(gpui::ListOffset {
2863 item_ix: 5,
2864 offset_in_item: gpui::px(12.5),
2865 }));
2866 });
2867
2868 // Close the session so it can be reloaded from disk.
2869 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2870 .await
2871 .unwrap();
2872 drop(thread);
2873 drop(acp_thread);
2874 agent.read_with(cx, |agent, _| {
2875 assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
2876 });
2877
2878 // Ensure the thread can be reloaded from disk.
2879 assert_eq!(
2880 thread_entries(&thread_store, cx),
2881 vec![(
2882 session_id.clone(),
2883 format!("Explaining {}", path!("/a/b.md"))
2884 )]
2885 );
2886 let acp_thread = agent
2887 .update(cx, |agent, cx| {
2888 agent.open_thread(session_id.clone(), project.clone(), cx)
2889 })
2890 .await
2891 .unwrap();
2892 acp_thread.read_with(cx, |thread, cx| {
2893 assert_eq!(
2894 thread.to_markdown(cx),
2895 formatdoc! {"
2896 ## User
2897
2898 What does [@b.md]({uri}) mean?
2899
2900 ## Assistant
2901
2902 Lorem.
2903
2904 "}
2905 )
2906 });
2907
2908 // Ensure the draft prompt with rich content blocks survived the round-trip.
2909 acp_thread.read_with(cx, |thread, _| {
2910 assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
2911 });
2912
2913 // Ensure token usage survived the round-trip.
2914 acp_thread.read_with(cx, |thread, _| {
2915 let usage = thread
2916 .token_usage()
2917 .expect("token usage should be restored after reload");
2918 assert_eq!(usage.input_tokens, 150);
2919 assert_eq!(usage.output_tokens, 75);
2920 });
2921
2922 // Ensure scroll position survived the round-trip.
2923 acp_thread.read_with(cx, |thread, _| {
2924 let scroll = thread
2925 .ui_scroll_position()
2926 .expect("scroll position should be restored after reload");
2927 assert_eq!(scroll.item_ix, 5);
2928 assert_eq!(scroll.offset_in_item, gpui::px(12.5));
2929 });
2930 }
2931
2932 #[gpui::test]
2933 async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
2934 init_test(cx);
2935 let fs = FakeFs::new(cx.executor());
2936 fs.insert_tree(
2937 "/",
2938 json!({
2939 "a": {
2940 "file.txt": "hello"
2941 }
2942 }),
2943 )
2944 .await;
2945 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2946 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2947 let agent = cx.update(|cx| {
2948 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2949 });
2950 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2951
2952 let acp_thread = cx
2953 .update(|cx| {
2954 connection
2955 .clone()
2956 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
2957 })
2958 .await
2959 .unwrap();
2960 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2961 let thread = agent.read_with(cx, |agent, _| {
2962 agent.sessions.get(&session_id).unwrap().thread.clone()
2963 });
2964
2965 let model = Arc::new(FakeLanguageModel::default());
2966 thread.update(cx, |thread, cx| {
2967 thread.set_model(model.clone(), cx);
2968 });
2969
2970 // Send a message so the thread is non-empty (empty threads aren't saved).
2971 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
2972 let send = cx.foreground_executor().spawn(send);
2973 cx.run_until_parked();
2974
2975 model.send_last_completion_stream_text_chunk("world");
2976 model.end_last_completion_stream();
2977 send.await.unwrap();
2978 cx.run_until_parked();
2979
2980 // Set a draft prompt WITHOUT calling run_until_parked afterwards.
2981 // This means no observe-triggered save has run for this change.
2982 // The only way this data gets persisted is if close_session
2983 // itself performs the save.
2984 let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
2985 "unsaved draft",
2986 ))];
2987 acp_thread.update(cx, |thread, _cx| {
2988 thread.set_draft_prompt(Some(draft_blocks.clone()));
2989 });
2990
2991 // Close the session immediately — no run_until_parked in between.
2992 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2993 .await
2994 .unwrap();
2995 cx.run_until_parked();
2996
2997 // Reopen and verify the draft prompt was saved.
2998 let reloaded = agent
2999 .update(cx, |agent, cx| {
3000 agent.open_thread(session_id.clone(), project.clone(), cx)
3001 })
3002 .await
3003 .unwrap();
3004 reloaded.read_with(cx, |thread, _| {
3005 assert_eq!(
3006 thread.draft_prompt(),
3007 Some(draft_blocks.as_slice()),
3008 "close_session must save the thread; draft prompt was lost"
3009 );
3010 });
3011 }
3012
3013 #[gpui::test]
3014 async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
3015 // Regression test: rapid title changes must not cause a propagation loop
3016 // between Thread and AcpThread via handle_thread_title_updated.
3017 init_test(cx);
3018 let fs = FakeFs::new(cx.executor());
3019 fs.insert_tree("/", json!({ "a": {} })).await;
3020 let project = Project::test(fs.clone(), [], cx).await;
3021 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3022 let agent = cx.update(|cx| {
3023 NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3024 });
3025 let connection = Rc::new(NativeAgentConnection(agent.clone()));
3026
3027 let acp_thread = cx
3028 .update(|cx| {
3029 connection
3030 .clone()
3031 .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3032 })
3033 .await
3034 .unwrap();
3035
3036 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3037 let thread = agent.read_with(cx, |agent, _| {
3038 agent.sessions.get(&session_id).unwrap().thread.clone()
3039 });
3040
3041 let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
3042 cx.update(|cx| {
3043 let count = title_updated_count.clone();
3044 cx.subscribe(
3045 &thread,
3046 move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
3047 let new_count = {
3048 let mut count = count.borrow_mut();
3049 *count += 1;
3050 *count
3051 };
3052 assert!(
3053 new_count <= 2,
3054 "TitleUpdated fired {new_count} times; \
3055 title updates are looping"
3056 );
3057 },
3058 )
3059 .detach();
3060 });
3061
3062 thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
3063 thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));
3064
3065 cx.run_until_parked();
3066
3067 thread.read_with(cx, |thread, _| {
3068 assert_eq!(thread.title(), Some("second".into()));
3069 });
3070 acp_thread.read_with(cx, |acp_thread, _| {
3071 assert_eq!(acp_thread.title(), Some("second".into()));
3072 });
3073
3074 assert_eq!(*title_updated_count.borrow(), 2);
3075 }
3076
3077 fn thread_entries(
3078 thread_store: &Entity<ThreadStore>,
3079 cx: &mut TestAppContext,
3080 ) -> Vec<(acp::SessionId, String)> {
3081 thread_store.read_with(cx, |store, _| {
3082 store
3083 .entries()
3084 .map(|entry| (entry.id.clone(), entry.title.to_string()))
3085 .collect::<Vec<_>>()
3086 })
3087 }
3088
3089 fn init_test(cx: &mut TestAppContext) {
3090 env_logger::try_init().ok();
3091 cx.update(|cx| {
3092 let settings_store = SettingsStore::test(cx);
3093 cx.set_global(settings_store);
3094
3095 LanguageModelRegistry::test(cx);
3096 });
3097 }
3098}
3099
3100fn mcp_message_content_to_acp_content_block(
3101 content: context_server::types::MessageContent,
3102) -> acp::ContentBlock {
3103 match content {
3104 context_server::types::MessageContent::Text {
3105 text,
3106 annotations: _,
3107 } => text.into(),
3108 context_server::types::MessageContent::Image {
3109 data,
3110 mime_type,
3111 annotations: _,
3112 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3113 context_server::types::MessageContent::Audio {
3114 data,
3115 mime_type,
3116 annotations: _,
3117 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3118 context_server::types::MessageContent::Resource {
3119 resource,
3120 annotations: _,
3121 } => {
3122 let mut link =
3123 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3124 if let Some(mime_type) = resource.mime_type {
3125 link = link.mime_type(mime_type);
3126 }
3127 acp::ContentBlock::ResourceLink(link)
3128 }
3129 }
3130}