1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// Serializable snapshot of a project's worktrees, built from
/// `project::telemetry_snapshot::TelemetryWorktreeSnapshot` entries.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp associated with the snapshot (presumably the capture time — confirm at the
    /// construction site).
    pub timestamp: DateTime<Utc>,
}
64
/// Error reported when a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted underlying error).
    pub message: SharedString,
}
68
/// Per-project state shared by every session that belongs to the same project.
struct ProjectState {
    /// The project this state describes.
    project: Entity<Project>,
    /// Context (worktrees, rules files, default user rules) handed to this project's threads.
    /// Its contents are replaced in place whenever the context is rebuilt.
    project_context: Entity<ProjectContext>,
    /// Send `()` here to request a rebuild of `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Owned handle to the task that rebuilds `project_context` on refresh requests.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers (tools and prompts) available to this project.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project and context-server events; kept alive with the state.
    _subscriptions: Vec<Subscription>,
}
77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session belongs to.
    project_id: EntityId,
    /// The most recently started save of this session's thread.
    pending_save: Task<Result<()>>,
    /// Subscriptions to the thread's events; kept alive with the session.
    _subscriptions: Vec<Subscription>,
}
88
/// Cache of the language models offered by authenticated providers, in both
/// lookup-by-id and grouped-list form.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out to observers; signalled after every list refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used to signal `refresh_models_rx` after a refresh.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates all visible providers at startup.
    _authenticate_all_providers_task: Task<()>,
}
98
99impl LanguageModels {
100 fn new(cx: &mut App) -> Self {
101 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
102
103 let mut this = Self {
104 models: HashMap::default(),
105 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
106 refresh_models_rx,
107 refresh_models_tx,
108 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
109 };
110 this.refresh_list(cx);
111 this
112 }
113
114 fn refresh_list(&mut self, cx: &App) {
115 let providers = LanguageModelRegistry::global(cx)
116 .read(cx)
117 .visible_providers()
118 .into_iter()
119 .filter(|provider| provider.is_authenticated(cx))
120 .collect::<Vec<_>>();
121
122 let mut language_model_list = IndexMap::default();
123 let mut recommended_models = HashSet::default();
124
125 let mut recommended = Vec::new();
126 for provider in &providers {
127 for model in provider.recommended_models(cx) {
128 recommended_models.insert((model.provider_id(), model.id()));
129 recommended.push(Self::map_language_model_to_info(&model, provider));
130 }
131 }
132 if !recommended.is_empty() {
133 language_model_list.insert(
134 acp_thread::AgentModelGroupName("Recommended".into()),
135 recommended,
136 );
137 }
138
139 let mut models = HashMap::default();
140 for provider in providers {
141 let mut provider_models = Vec::new();
142 for model in provider.provided_models(cx) {
143 let model_info = Self::map_language_model_to_info(&model, &provider);
144 let model_id = model_info.id.clone();
145 provider_models.push(model_info);
146 models.insert(model_id, model);
147 }
148 if !provider_models.is_empty() {
149 language_model_list.insert(
150 acp_thread::AgentModelGroupName(provider.name().0.clone()),
151 provider_models,
152 );
153 }
154 }
155
156 self.models = models;
157 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
158 self.refresh_models_tx.send(()).ok();
159 }
160
161 fn watch(&self) -> watch::Receiver<()> {
162 self.refresh_models_rx.clone()
163 }
164
165 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
166 self.models.get(model_id).cloned()
167 }
168
169 fn map_language_model_to_info(
170 model: &Arc<dyn LanguageModel>,
171 provider: &Arc<dyn LanguageModelProvider>,
172 ) -> acp_thread::AgentModelInfo {
173 acp_thread::AgentModelInfo {
174 id: Self::model_id(model),
175 name: model.name().0,
176 description: None,
177 icon: Some(match provider.icon() {
178 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
179 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
180 }),
181 is_latest: model.is_latest(),
182 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
183 }
184 }
185
186 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
187 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
188 }
189
190 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
191 let authenticate_all_providers = LanguageModelRegistry::global(cx)
192 .read(cx)
193 .visible_providers()
194 .iter()
195 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
196 .collect::<Vec<_>>();
197
198 cx.background_spawn(async move {
199 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
200 if let Err(err) = authenticate_task.await {
201 match err {
202 language_model::AuthenticateError::CredentialsNotFound => {
203 // Since we're authenticating these providers in the
204 // background for the purposes of populating the
205 // language selector, we don't care about providers
206 // where the credentials are not found.
207 }
208 language_model::AuthenticateError::ConnectionRefused => {
209 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
210 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
211 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
212 }
213 _ => {
214 // Some providers have noisy failure states that we
215 // don't want to spam the logs with every time the
216 // language model selector is initialized.
217 //
218 // Ideally these should have more clear failure modes
219 // that we know are safe to ignore here, like what we do
220 // with `CredentialsNotFound` above.
221 match provider_id.0.as_ref() {
222 "lmstudio" | "ollama" => {
223 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
224 //
225 // These fail noisily, so we don't log them.
226 }
227 "copilot_chat" => {
228 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
229 }
230 _ => {
231 log::error!(
232 "Failed to authenticate provider: {}: {err:#}",
233 provider_name.0
234 );
235 }
236 }
237 }
238 }
239 }
240 }
241 })
242 }
243}
244
/// In-process agent that owns the threads for all of its sessions, along with
/// per-project context and the cached language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store reloaded after each thread save.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Source of default user rules for project contexts, if available.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle (not used within this part of the file).
    fs: Arc<dyn Fs>,
    /// Subscriptions to the model registry and (optionally) the prompt store.
    _subscriptions: Vec<Subscription>,
}
259
260impl NativeAgent {
261 pub fn new(
262 thread_store: Entity<ThreadStore>,
263 templates: Arc<Templates>,
264 prompt_store: Option<Entity<PromptStore>>,
265 fs: Arc<dyn Fs>,
266 cx: &mut App,
267 ) -> Entity<NativeAgent> {
268 log::debug!("Creating new NativeAgent");
269
270 cx.new(|cx| {
271 let mut subscriptions = vec![cx.subscribe(
272 &LanguageModelRegistry::global(cx),
273 Self::handle_models_updated_event,
274 )];
275 if let Some(prompt_store) = prompt_store.as_ref() {
276 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
277 }
278
279 Self {
280 sessions: HashMap::default(),
281 thread_store,
282 projects: HashMap::default(),
283 templates,
284 models: LanguageModels::new(cx),
285 prompt_store,
286 fs,
287 _subscriptions: subscriptions,
288 }
289 })
290 }
291
292 fn new_session(
293 &mut self,
294 project: Entity<Project>,
295 cx: &mut Context<Self>,
296 ) -> Entity<AcpThread> {
297 let project_id = self.get_or_create_project_state(&project, cx);
298 let project_state = &self.projects[&project_id];
299
300 let registry = LanguageModelRegistry::read_global(cx);
301 let available_count = registry.available_models(cx).count();
302 log::debug!("Total available models: {}", available_count);
303
304 let default_model = registry.default_model().and_then(|default_model| {
305 self.models
306 .model_from_id(&LanguageModels::model_id(&default_model.model))
307 });
308 let thread = cx.new(|cx| {
309 Thread::new(
310 project,
311 project_state.project_context.clone(),
312 project_state.context_server_registry.clone(),
313 self.templates.clone(),
314 default_model,
315 cx,
316 )
317 });
318
319 self.register_session(thread, project_id, cx)
320 }
321
322 fn register_session(
323 &mut self,
324 thread_handle: Entity<Thread>,
325 project_id: EntityId,
326 cx: &mut Context<Self>,
327 ) -> Entity<AcpThread> {
328 let connection = Rc::new(NativeAgentConnection(cx.entity()));
329
330 let thread = thread_handle.read(cx);
331 let session_id = thread.id().clone();
332 let parent_session_id = thread.parent_thread_id();
333 let title = thread.title();
334 let draft_prompt = thread.draft_prompt().map(Vec::from);
335 let scroll_position = thread.ui_scroll_position();
336 let token_usage = thread.latest_token_usage();
337 let project = thread.project.clone();
338 let action_log = thread.action_log.clone();
339 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
340 let acp_thread = cx.new(|cx| {
341 let mut acp_thread = acp_thread::AcpThread::new(
342 parent_session_id,
343 title,
344 None,
345 connection,
346 project.clone(),
347 action_log.clone(),
348 session_id.clone(),
349 prompt_capabilities_rx,
350 cx,
351 );
352 acp_thread.set_draft_prompt(draft_prompt);
353 acp_thread.set_ui_scroll_position(scroll_position);
354 acp_thread.update_token_usage(token_usage, cx);
355 acp_thread
356 });
357
358 let registry = LanguageModelRegistry::read_global(cx);
359 let summarization_model = registry.thread_summary_model().map(|c| c.model);
360
361 let weak = cx.weak_entity();
362 let weak_thread = thread_handle.downgrade();
363 thread_handle.update(cx, |thread, cx| {
364 thread.set_summarization_model(summarization_model, cx);
365 thread.add_default_tools(
366 Rc::new(NativeThreadEnvironment {
367 acp_thread: acp_thread.downgrade(),
368 thread: weak_thread,
369 agent: weak,
370 }) as _,
371 cx,
372 )
373 });
374
375 let subscriptions = vec![
376 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
377 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
378 cx.observe(&thread_handle, move |this, thread, cx| {
379 this.save_thread(thread, cx)
380 }),
381 ];
382
383 self.sessions.insert(
384 session_id,
385 Session {
386 thread: thread_handle,
387 acp_thread: acp_thread.clone(),
388 project_id,
389 _subscriptions: subscriptions,
390 pending_save: Task::ready(Ok(())),
391 },
392 );
393
394 self.update_available_commands_for_project(project_id, cx);
395
396 acp_thread
397 }
398
399 pub fn models(&self) -> &LanguageModels {
400 &self.models
401 }
402
403 fn get_or_create_project_state(
404 &mut self,
405 project: &Entity<Project>,
406 cx: &mut Context<Self>,
407 ) -> EntityId {
408 let project_id = project.entity_id();
409 if self.projects.contains_key(&project_id) {
410 return project_id;
411 }
412
413 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
414 self.register_project_with_initial_context(project.clone(), project_context, cx);
415 if let Some(state) = self.projects.get_mut(&project_id) {
416 state.project_context_needs_refresh.send(()).ok();
417 }
418 project_id
419 }
420
421 fn register_project_with_initial_context(
422 &mut self,
423 project: Entity<Project>,
424 project_context: Entity<ProjectContext>,
425 cx: &mut Context<Self>,
426 ) {
427 let project_id = project.entity_id();
428
429 let context_server_store = project.read(cx).context_server_store();
430 let context_server_registry =
431 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
432
433 let subscriptions = vec![
434 cx.subscribe(&project, Self::handle_project_event),
435 cx.subscribe(
436 &context_server_store,
437 Self::handle_context_server_store_updated,
438 ),
439 cx.subscribe(
440 &context_server_registry,
441 Self::handle_context_server_registry_event,
442 ),
443 ];
444
445 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
446 watch::channel(());
447
448 self.projects.insert(
449 project_id,
450 ProjectState {
451 project,
452 project_context,
453 project_context_needs_refresh: project_context_needs_refresh_tx,
454 _maintain_project_context: cx.spawn(async move |this, cx| {
455 Self::maintain_project_context(
456 this,
457 project_id,
458 project_context_needs_refresh_rx,
459 cx,
460 )
461 .await
462 }),
463 context_server_registry,
464 _subscriptions: subscriptions,
465 },
466 );
467 }
468
469 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
470 self.sessions
471 .get(session_id)
472 .and_then(|session| self.projects.get(&session.project_id))
473 }
474
475 async fn maintain_project_context(
476 this: WeakEntity<Self>,
477 project_id: EntityId,
478 mut needs_refresh: watch::Receiver<()>,
479 cx: &mut AsyncApp,
480 ) -> Result<()> {
481 while needs_refresh.changed().await.is_ok() {
482 let project_context = this
483 .update(cx, |this, cx| {
484 let state = this
485 .projects
486 .get(&project_id)
487 .context("project state not found")?;
488 anyhow::Ok(Self::build_project_context(
489 &state.project,
490 this.prompt_store.as_ref(),
491 cx,
492 ))
493 })??
494 .await;
495 this.update(cx, |this, cx| {
496 if let Some(state) = this.projects.get(&project_id) {
497 state
498 .project_context
499 .update(cx, |current_project_context, _cx| {
500 *current_project_context = project_context;
501 });
502 }
503 })?;
504 }
505
506 Ok(())
507 }
508
509 fn build_project_context(
510 project: &Entity<Project>,
511 prompt_store: Option<&Entity<PromptStore>>,
512 cx: &mut App,
513 ) -> Task<ProjectContext> {
514 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
515 let worktree_tasks = worktrees
516 .into_iter()
517 .map(|worktree| {
518 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
519 })
520 .collect::<Vec<_>>();
521 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
522 prompt_store.read_with(cx, |prompt_store, cx| {
523 let prompts = prompt_store.default_prompt_metadata();
524 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
525 let contents = prompt_store.load(prompt_metadata.id, cx);
526 async move { (contents.await, prompt_metadata) }
527 });
528 cx.background_spawn(future::join_all(load_tasks))
529 })
530 } else {
531 Task::ready(vec![])
532 };
533
534 cx.spawn(async move |_cx| {
535 let (worktrees, default_user_rules) =
536 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
537
538 let worktrees = worktrees
539 .into_iter()
540 .map(|(worktree, _rules_error)| {
541 // TODO: show error message
542 // if let Some(rules_error) = rules_error {
543 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
544 // }
545 worktree
546 })
547 .collect::<Vec<_>>();
548
549 let default_user_rules = default_user_rules
550 .into_iter()
551 .flat_map(|(contents, prompt_metadata)| match contents {
552 Ok(contents) => Some(UserRulesContext {
553 uuid: prompt_metadata.id.as_user()?,
554 title: prompt_metadata.title.map(|title| title.to_string()),
555 contents,
556 }),
557 Err(_err) => {
558 // TODO: show error message
559 // this.update(cx, |_, cx| {
560 // cx.emit(RulesLoadingError {
561 // message: format!("{err:?}").into(),
562 // });
563 // })
564 // .ok();
565 None
566 }
567 })
568 .collect::<Vec<_>>();
569
570 ProjectContext::new(worktrees, default_user_rules)
571 })
572 }
573
574 fn load_worktree_info_for_system_prompt(
575 worktree: Entity<Worktree>,
576 project: Entity<Project>,
577 cx: &mut App,
578 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
579 let tree = worktree.read(cx);
580 let root_name = tree.root_name_str().into();
581 let abs_path = tree.abs_path();
582
583 let mut context = WorktreeContext {
584 root_name,
585 abs_path,
586 rules_file: None,
587 };
588
589 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
590 let Some(rules_task) = rules_task else {
591 return Task::ready((context, None));
592 };
593
594 cx.spawn(async move |_| {
595 let (rules_file, rules_file_error) = match rules_task.await {
596 Ok(rules_file) => (Some(rules_file), None),
597 Err(err) => (
598 None,
599 Some(RulesLoadingError {
600 message: format!("{err}").into(),
601 }),
602 ),
603 };
604 context.rules_file = rules_file;
605 (context, rules_file_error)
606 })
607 }
608
609 fn load_worktree_rules_file(
610 worktree: Entity<Worktree>,
611 project: Entity<Project>,
612 cx: &mut App,
613 ) -> Option<Task<Result<RulesFileContext>>> {
614 let worktree = worktree.read(cx);
615 let worktree_id = worktree.id();
616 let selected_rules_file = RULES_FILE_NAMES
617 .into_iter()
618 .filter_map(|name| {
619 worktree
620 .entry_for_path(RelPath::unix(name).unwrap())
621 .filter(|entry| entry.is_file())
622 .map(|entry| entry.path.clone())
623 })
624 .next();
625
626 // Note that Cline supports `.clinerules` being a directory, but that is not currently
627 // supported. This doesn't seem to occur often in GitHub repositories.
628 selected_rules_file.map(|path_in_worktree| {
629 let project_path = ProjectPath {
630 worktree_id,
631 path: path_in_worktree.clone(),
632 };
633 let buffer_task =
634 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
635 let rope_task = cx.spawn(async move |cx| {
636 let buffer = buffer_task.await?;
637 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
638 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
639 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
640 })?;
641 anyhow::Ok((project_entry_id, rope))
642 });
643 // Build a string from the rope on a background thread.
644 cx.background_spawn(async move {
645 let (project_entry_id, rope) = rope_task.await?;
646 anyhow::Ok(RulesFileContext {
647 path_in_worktree,
648 text: rope.to_string().trim().to_string(),
649 project_entry_id: project_entry_id.to_usize(),
650 })
651 })
652 })
653 }
654
655 fn handle_thread_title_updated(
656 &mut self,
657 thread: Entity<Thread>,
658 _: &TitleUpdated,
659 cx: &mut Context<Self>,
660 ) {
661 let session_id = thread.read(cx).id();
662 let Some(session) = self.sessions.get(session_id) else {
663 return;
664 };
665
666 if let Some(title) = thread.read(cx).title() {
667 let acp_thread = session.acp_thread.downgrade();
668 cx.spawn(async move |_, cx| {
669 let task =
670 acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
671 task.await
672 })
673 .detach_and_log_err(cx);
674 }
675 }
676
677 fn handle_thread_token_usage_updated(
678 &mut self,
679 thread: Entity<Thread>,
680 usage: &TokenUsageUpdated,
681 cx: &mut Context<Self>,
682 ) {
683 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
684 return;
685 };
686 session.acp_thread.update(cx, |acp_thread, cx| {
687 acp_thread.update_token_usage(usage.0.clone(), cx);
688 });
689 }
690
691 fn handle_project_event(
692 &mut self,
693 project: Entity<Project>,
694 event: &project::Event,
695 _cx: &mut Context<Self>,
696 ) {
697 let project_id = project.entity_id();
698 let Some(state) = self.projects.get_mut(&project_id) else {
699 return;
700 };
701 match event {
702 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
703 state.project_context_needs_refresh.send(()).ok();
704 }
705 project::Event::WorktreeUpdatedEntries(_, items) => {
706 if items.iter().any(|(path, _, _)| {
707 RULES_FILE_NAMES
708 .iter()
709 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
710 }) {
711 state.project_context_needs_refresh.send(()).ok();
712 }
713 }
714 _ => {}
715 }
716 }
717
718 fn handle_prompts_updated_event(
719 &mut self,
720 _prompt_store: Entity<PromptStore>,
721 _event: &prompt_store::PromptsUpdatedEvent,
722 _cx: &mut Context<Self>,
723 ) {
724 for state in self.projects.values_mut() {
725 state.project_context_needs_refresh.send(()).ok();
726 }
727 }
728
729 fn handle_models_updated_event(
730 &mut self,
731 _registry: Entity<LanguageModelRegistry>,
732 event: &language_model::Event,
733 cx: &mut Context<Self>,
734 ) {
735 self.models.refresh_list(cx);
736
737 let registry = LanguageModelRegistry::read_global(cx);
738 let default_model = registry.default_model().map(|m| m.model);
739 let summarization_model = registry.thread_summary_model().map(|m| m.model);
740
741 for session in self.sessions.values_mut() {
742 session.thread.update(cx, |thread, cx| {
743 if thread.model().is_none()
744 && let Some(model) = default_model.clone()
745 {
746 thread.set_model(model, cx);
747 cx.notify();
748 }
749 if let Some(model) = summarization_model.clone() {
750 if thread.summarization_model().is_none()
751 || matches!(event, language_model::Event::ThreadSummaryModelChanged)
752 {
753 thread.set_summarization_model(Some(model), cx);
754 }
755 }
756 });
757 }
758 }
759
760 fn handle_context_server_store_updated(
761 &mut self,
762 store: Entity<project::context_server_store::ContextServerStore>,
763 _event: &project::context_server_store::ServerStatusChangedEvent,
764 cx: &mut Context<Self>,
765 ) {
766 let project_id = self.projects.iter().find_map(|(id, state)| {
767 if *state.context_server_registry.read(cx).server_store() == store {
768 Some(*id)
769 } else {
770 None
771 }
772 });
773 if let Some(project_id) = project_id {
774 self.update_available_commands_for_project(project_id, cx);
775 }
776 }
777
778 fn handle_context_server_registry_event(
779 &mut self,
780 registry: Entity<ContextServerRegistry>,
781 event: &ContextServerRegistryEvent,
782 cx: &mut Context<Self>,
783 ) {
784 match event {
785 ContextServerRegistryEvent::ToolsChanged => {}
786 ContextServerRegistryEvent::PromptsChanged => {
787 let project_id = self.projects.iter().find_map(|(id, state)| {
788 if state.context_server_registry == registry {
789 Some(*id)
790 } else {
791 None
792 }
793 });
794 if let Some(project_id) = project_id {
795 self.update_available_commands_for_project(project_id, cx);
796 }
797 }
798 }
799 }
800
801 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
802 let available_commands =
803 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
804 for session in self.sessions.values() {
805 if session.project_id != project_id {
806 continue;
807 }
808 session.acp_thread.update(cx, |thread, cx| {
809 thread
810 .handle_session_update(
811 acp::SessionUpdate::AvailableCommandsUpdate(
812 acp::AvailableCommandsUpdate::new(available_commands.clone()),
813 ),
814 cx,
815 )
816 .log_err();
817 });
818 }
819 }
820
821 fn build_available_commands_for_project(
822 project_state: Option<&ProjectState>,
823 cx: &App,
824 ) -> Vec<acp::AvailableCommand> {
825 let Some(state) = project_state else {
826 return vec![];
827 };
828 let registry = state.context_server_registry.read(cx);
829
830 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
831 for context_server_prompt in registry.prompts() {
832 *prompt_name_counts
833 .entry(context_server_prompt.prompt.name.as_str())
834 .or_insert(0) += 1;
835 }
836
837 registry
838 .prompts()
839 .flat_map(|context_server_prompt| {
840 let prompt = &context_server_prompt.prompt;
841
842 let should_prefix = prompt_name_counts
843 .get(prompt.name.as_str())
844 .copied()
845 .unwrap_or(0)
846 > 1;
847
848 let name = if should_prefix {
849 format!("{}.{}", context_server_prompt.server_id, prompt.name)
850 } else {
851 prompt.name.clone()
852 };
853
854 let mut command = acp::AvailableCommand::new(
855 name,
856 prompt.description.clone().unwrap_or_default(),
857 );
858
859 match prompt.arguments.as_deref() {
860 Some([arg]) => {
861 let hint = format!("<{}>", arg.name);
862
863 command = command.input(acp::AvailableCommandInput::Unstructured(
864 acp::UnstructuredCommandInput::new(hint),
865 ));
866 }
867 Some([]) | None => {}
868 Some(_) => {
869 // skip >1 argument commands since we don't support them yet
870 return None;
871 }
872 }
873
874 Some(command)
875 })
876 .collect()
877 }
878
879 pub fn load_thread(
880 &mut self,
881 id: acp::SessionId,
882 project: Entity<Project>,
883 cx: &mut Context<Self>,
884 ) -> Task<Result<Entity<Thread>>> {
885 let database_future = ThreadsDatabase::connect(cx);
886 cx.spawn(async move |this, cx| {
887 let database = database_future.await.map_err(|err| anyhow!(err))?;
888 let db_thread = database
889 .load_thread(id.clone())
890 .await?
891 .with_context(|| format!("no thread found with ID: {id:?}"))?;
892
893 this.update(cx, |this, cx| {
894 let project_id = this.get_or_create_project_state(&project, cx);
895 let project_state = this
896 .projects
897 .get(&project_id)
898 .context("project state not found")?;
899 let summarization_model = LanguageModelRegistry::read_global(cx)
900 .thread_summary_model()
901 .map(|c| c.model);
902
903 Ok(cx.new(|cx| {
904 let mut thread = Thread::from_db(
905 id.clone(),
906 db_thread,
907 project_state.project.clone(),
908 project_state.project_context.clone(),
909 project_state.context_server_registry.clone(),
910 this.templates.clone(),
911 cx,
912 );
913 thread.set_summarization_model(summarization_model, cx);
914 thread
915 }))
916 })?
917 })
918 }
919
920 pub fn open_thread(
921 &mut self,
922 id: acp::SessionId,
923 project: Entity<Project>,
924 cx: &mut Context<Self>,
925 ) -> Task<Result<Entity<AcpThread>>> {
926 if let Some(session) = self.sessions.get(&id) {
927 return Task::ready(Ok(session.acp_thread.clone()));
928 }
929
930 let task = self.load_thread(id, project.clone(), cx);
931 cx.spawn(async move |this, cx| {
932 let thread = task.await?;
933 let acp_thread = this.update(cx, |this, cx| {
934 let project_id = this.get_or_create_project_state(&project, cx);
935 this.register_session(thread.clone(), project_id, cx)
936 })?;
937 let events = thread.update(cx, |thread, cx| thread.replay(cx));
938 cx.update(|cx| {
939 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
940 })
941 .await?;
942 Ok(acp_thread)
943 })
944 }
945
946 pub fn thread_summary(
947 &mut self,
948 id: acp::SessionId,
949 project: Entity<Project>,
950 cx: &mut Context<Self>,
951 ) -> Task<Result<SharedString>> {
952 let thread = self.open_thread(id.clone(), project, cx);
953 cx.spawn(async move |this, cx| {
954 let acp_thread = thread.await?;
955 let result = this
956 .update(cx, |this, cx| {
957 this.sessions
958 .get(&id)
959 .unwrap()
960 .thread
961 .update(cx, |thread, cx| thread.summary(cx))
962 })?
963 .await
964 .context("Failed to generate summary")?;
965 drop(acp_thread);
966 Ok(result)
967 })
968 }
969
970 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
971 if thread.read(cx).is_empty() {
972 return;
973 }
974
975 let id = thread.read(cx).id().clone();
976 let Some(session) = self.sessions.get_mut(&id) else {
977 return;
978 };
979
980 let project_id = session.project_id;
981 let Some(state) = self.projects.get(&project_id) else {
982 return;
983 };
984
985 let folder_paths = PathList::new(
986 &state
987 .project
988 .read(cx)
989 .visible_worktrees(cx)
990 .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
991 .collect::<Vec<_>>(),
992 );
993
994 let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
995 let database_future = ThreadsDatabase::connect(cx);
996 let db_thread = thread.update(cx, |thread, cx| {
997 thread.set_draft_prompt(draft_prompt);
998 thread.to_db(cx)
999 });
1000 let thread_store = self.thread_store.clone();
1001 session.pending_save = cx.spawn(async move |_, cx| {
1002 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1003 return Ok(());
1004 };
1005 let db_thread = db_thread.await;
1006 database
1007 .save_thread(id, db_thread, folder_paths)
1008 .await
1009 .log_err();
1010 thread_store.update(cx, |store, cx| store.reload(cx));
1011 Ok(())
1012 });
1013 }
1014
1015 fn send_mcp_prompt(
1016 &self,
1017 message_id: UserMessageId,
1018 session_id: acp::SessionId,
1019 prompt_name: String,
1020 server_id: ContextServerId,
1021 arguments: HashMap<String, String>,
1022 original_content: Vec<acp::ContentBlock>,
1023 cx: &mut Context<Self>,
1024 ) -> Task<Result<acp::PromptResponse>> {
1025 let Some(state) = self.session_project_state(&session_id) else {
1026 return Task::ready(Err(anyhow!("Project state not found for session")));
1027 };
1028 let server_store = state
1029 .context_server_registry
1030 .read(cx)
1031 .server_store()
1032 .clone();
1033 let path_style = state.project.read(cx).path_style(cx);
1034
1035 cx.spawn(async move |this, cx| {
1036 let prompt =
1037 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
1038
1039 let (acp_thread, thread) = this.update(cx, |this, _cx| {
1040 let session = this
1041 .sessions
1042 .get(&session_id)
1043 .context("Failed to get session")?;
1044 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
1045 })??;
1046
1047 let mut last_is_user = true;
1048
1049 thread.update(cx, |thread, cx| {
1050 thread.push_acp_user_block(
1051 message_id,
1052 original_content.into_iter().skip(1),
1053 path_style,
1054 cx,
1055 );
1056 });
1057
1058 for message in prompt.messages {
1059 let context_server::types::PromptMessage { role, content } = message;
1060 let block = mcp_message_content_to_acp_content_block(content);
1061
1062 match role {
1063 context_server::types::Role::User => {
1064 let id = acp_thread::UserMessageId::new();
1065
1066 acp_thread.update(cx, |acp_thread, cx| {
1067 acp_thread.push_user_content_block_with_indent(
1068 Some(id.clone()),
1069 block.clone(),
1070 true,
1071 cx,
1072 );
1073 });
1074
1075 thread.update(cx, |thread, cx| {
1076 thread.push_acp_user_block(id, [block], path_style, cx);
1077 });
1078 }
1079 context_server::types::Role::Assistant => {
1080 acp_thread.update(cx, |acp_thread, cx| {
1081 acp_thread.push_assistant_content_block_with_indent(
1082 block.clone(),
1083 false,
1084 true,
1085 cx,
1086 );
1087 });
1088
1089 thread.update(cx, |thread, cx| {
1090 thread.push_acp_agent_block(block, cx);
1091 });
1092 }
1093 }
1094
1095 last_is_user = role == context_server::types::Role::User;
1096 }
1097
1098 let response_stream = thread.update(cx, |thread, cx| {
1099 if last_is_user {
1100 thread.send_existing(cx)
1101 } else {
1102 // Resume if MCP prompt did not end with a user message
1103 thread.resume(cx)
1104 }
1105 })?;
1106
1107 cx.update(|cx| {
1108 NativeAgentConnection::handle_thread_events(
1109 response_stream,
1110 acp_thread.downgrade(),
1111 cx,
1112 )
1113 })
1114 .await
1115 })
1116 }
1117}
1118
1119/// Wrapper struct that implements the AgentConnection trait
1120#[derive(Clone)]
1121pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1122
1123impl NativeAgentConnection {
1124 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1125 self.0
1126 .read(cx)
1127 .sessions
1128 .get(session_id)
1129 .map(|session| session.thread.clone())
1130 }
1131
1132 pub fn load_thread(
1133 &self,
1134 id: acp::SessionId,
1135 project: Entity<Project>,
1136 cx: &mut App,
1137 ) -> Task<Result<Entity<Thread>>> {
1138 self.0
1139 .update(cx, |this, cx| this.load_thread(id, project, cx))
1140 }
1141
1142 fn run_turn(
1143 &self,
1144 session_id: acp::SessionId,
1145 cx: &mut App,
1146 f: impl 'static
1147 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1148 ) -> Task<Result<acp::PromptResponse>> {
1149 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1150 agent
1151 .sessions
1152 .get_mut(&session_id)
1153 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1154 }) else {
1155 return Task::ready(Err(anyhow!("Session not found")));
1156 };
1157 log::debug!("Found session for: {}", session_id);
1158
1159 let response_stream = match f(thread, cx) {
1160 Ok(stream) => stream,
1161 Err(err) => return Task::ready(Err(err)),
1162 };
1163 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1164 }
1165
1166 fn handle_thread_events(
1167 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1168 acp_thread: WeakEntity<AcpThread>,
1169 cx: &App,
1170 ) -> Task<Result<acp::PromptResponse>> {
1171 cx.spawn(async move |cx| {
1172 // Handle response stream and forward to session.acp_thread
1173 while let Some(result) = events.next().await {
1174 match result {
1175 Ok(event) => {
1176 log::trace!("Received completion event: {:?}", event);
1177
1178 match event {
1179 ThreadEvent::UserMessage(message) => {
1180 acp_thread.update(cx, |thread, cx| {
1181 for content in message.content {
1182 thread.push_user_content_block(
1183 Some(message.id.clone()),
1184 content.into(),
1185 cx,
1186 );
1187 }
1188 })?;
1189 }
1190 ThreadEvent::AgentText(text) => {
1191 acp_thread.update(cx, |thread, cx| {
1192 thread.push_assistant_content_block(text.into(), false, cx)
1193 })?;
1194 }
1195 ThreadEvent::AgentThinking(text) => {
1196 acp_thread.update(cx, |thread, cx| {
1197 thread.push_assistant_content_block(text.into(), true, cx)
1198 })?;
1199 }
1200 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1201 tool_call,
1202 options,
1203 response,
1204 context: _,
1205 }) => {
1206 let outcome_task = acp_thread.update(cx, |thread, cx| {
1207 thread.request_tool_call_authorization(tool_call, options, cx)
1208 })??;
1209 cx.background_spawn(async move {
1210 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1211 outcome_task.await
1212 {
1213 response
1214 .send(outcome)
1215 .map(|_| anyhow!("authorization receiver was dropped"))
1216 .log_err();
1217 }
1218 })
1219 .detach();
1220 }
1221 ThreadEvent::ToolCall(tool_call) => {
1222 acp_thread.update(cx, |thread, cx| {
1223 thread.upsert_tool_call(tool_call, cx)
1224 })??;
1225 }
1226 ThreadEvent::ToolCallUpdate(update) => {
1227 acp_thread.update(cx, |thread, cx| {
1228 thread.update_tool_call(update, cx)
1229 })??;
1230 }
1231 ThreadEvent::Plan(plan) => {
1232 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1233 }
1234 ThreadEvent::SubagentSpawned(session_id) => {
1235 acp_thread.update(cx, |thread, cx| {
1236 thread.subagent_spawned(session_id, cx);
1237 })?;
1238 }
1239 ThreadEvent::Retry(status) => {
1240 acp_thread.update(cx, |thread, cx| {
1241 thread.update_retry_status(status, cx)
1242 })?;
1243 }
1244 ThreadEvent::Stop(stop_reason) => {
1245 log::debug!("Assistant message complete: {:?}", stop_reason);
1246 return Ok(acp::PromptResponse::new(stop_reason));
1247 }
1248 }
1249 }
1250 Err(e) => {
1251 log::error!("Error in model response stream: {:?}", e);
1252 return Err(e);
1253 }
1254 }
1255 }
1256
1257 log::debug!("Response stream completed");
1258 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1259 })
1260 }
1261}
1262
1263struct Command<'a> {
1264 prompt_name: &'a str,
1265 arg_value: &'a str,
1266 explicit_server_id: Option<&'a str>,
1267}
1268
1269impl<'a> Command<'a> {
1270 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1271 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1272 return None;
1273 };
1274 let text = text_content.text.trim();
1275 let command = text.strip_prefix('/')?;
1276 let (command, arg_value) = command
1277 .split_once(char::is_whitespace)
1278 .unwrap_or((command, ""));
1279
1280 if let Some((server_id, prompt_name)) = command.split_once('.') {
1281 Some(Self {
1282 prompt_name,
1283 arg_value,
1284 explicit_server_id: Some(server_id),
1285 })
1286 } else {
1287 Some(Self {
1288 prompt_name: command,
1289 arg_value,
1290 explicit_server_id: None,
1291 })
1292 }
1293 }
1294}
1295
1296struct NativeAgentModelSelector {
1297 session_id: acp::SessionId,
1298 connection: NativeAgentConnection,
1299}
1300
1301impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1302 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1303 log::debug!("NativeAgentConnection::list_models called");
1304 let list = self.connection.0.read(cx).models.model_list.clone();
1305 Task::ready(if list.is_empty() {
1306 Err(anyhow::anyhow!("No models available"))
1307 } else {
1308 Ok(list)
1309 })
1310 }
1311
1312 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1313 log::debug!(
1314 "Setting model for session {}: {}",
1315 self.session_id,
1316 model_id
1317 );
1318 let Some(thread) = self
1319 .connection
1320 .0
1321 .read(cx)
1322 .sessions
1323 .get(&self.session_id)
1324 .map(|session| session.thread.clone())
1325 else {
1326 return Task::ready(Err(anyhow!("Session not found")));
1327 };
1328
1329 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1330 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1331 };
1332
1333 // We want to reset the effort level when switching models, as the currently-selected effort level may
1334 // not be compatible.
1335 let effort = model
1336 .default_effort_level()
1337 .map(|effort_level| effort_level.value.to_string());
1338
1339 thread.update(cx, |thread, cx| {
1340 thread.set_model(model.clone(), cx);
1341 thread.set_thinking_effort(effort.clone(), cx);
1342 thread.set_thinking_enabled(model.supports_thinking(), cx);
1343 });
1344
1345 update_settings_file(
1346 self.connection.0.read(cx).fs.clone(),
1347 cx,
1348 move |settings, cx| {
1349 let provider = model.provider_id().0.to_string();
1350 let model = model.id().0.to_string();
1351 let enable_thinking = thread.read(cx).thinking_enabled();
1352 settings
1353 .agent
1354 .get_or_insert_default()
1355 .set_model(LanguageModelSelection {
1356 provider: provider.into(),
1357 model,
1358 enable_thinking,
1359 effort,
1360 });
1361 },
1362 );
1363
1364 Task::ready(Ok(()))
1365 }
1366
1367 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1368 let Some(thread) = self
1369 .connection
1370 .0
1371 .read(cx)
1372 .sessions
1373 .get(&self.session_id)
1374 .map(|session| session.thread.clone())
1375 else {
1376 return Task::ready(Err(anyhow!("Session not found")));
1377 };
1378 let Some(model) = thread.read(cx).model() else {
1379 return Task::ready(Err(anyhow!("Model not found")));
1380 };
1381 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1382 else {
1383 return Task::ready(Err(anyhow!("Provider not found")));
1384 };
1385 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1386 model, &provider,
1387 )))
1388 }
1389
1390 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1391 Some(self.connection.0.read(cx).models.watch())
1392 }
1393
1394 fn should_render_footer(&self) -> bool {
1395 true
1396 }
1397}
1398
1399pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1400
1401impl acp_thread::AgentConnection for NativeAgentConnection {
1402 fn agent_id(&self) -> AgentId {
1403 ZED_AGENT_ID.clone()
1404 }
1405
1406 fn telemetry_id(&self) -> SharedString {
1407 "zed".into()
1408 }
1409
1410 fn new_session(
1411 self: Rc<Self>,
1412 project: Entity<Project>,
1413 work_dirs: PathList,
1414 cx: &mut App,
1415 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1416 log::debug!("Creating new thread for project at: {work_dirs:?}");
1417 Task::ready(Ok(self
1418 .0
1419 .update(cx, |agent, cx| agent.new_session(project, cx))))
1420 }
1421
1422 fn supports_load_session(&self) -> bool {
1423 true
1424 }
1425
1426 fn load_session(
1427 self: Rc<Self>,
1428 session_id: acp::SessionId,
1429 project: Entity<Project>,
1430 _work_dirs: PathList,
1431 _title: Option<SharedString>,
1432 cx: &mut App,
1433 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1434 self.0
1435 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1436 }
1437
1438 fn supports_close_session(&self) -> bool {
1439 true
1440 }
1441
1442 fn close_session(
1443 self: Rc<Self>,
1444 session_id: &acp::SessionId,
1445 cx: &mut App,
1446 ) -> Task<Result<()>> {
1447 self.0.update(cx, |agent, cx| {
1448 let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
1449 if let Some(thread) = thread {
1450 agent.save_thread(thread, cx);
1451 }
1452
1453 let Some(session) = agent.sessions.remove(session_id) else {
1454 return Task::ready(Ok(()));
1455 };
1456 let project_id = session.project_id;
1457
1458 let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1459 if !has_remaining {
1460 agent.projects.remove(&project_id);
1461 }
1462
1463 session.pending_save
1464 })
1465 }
1466
1467 fn auth_methods(&self) -> &[acp::AuthMethod] {
1468 &[] // No auth for in-process
1469 }
1470
1471 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1472 Task::ready(Ok(()))
1473 }
1474
1475 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1476 Some(Rc::new(NativeAgentModelSelector {
1477 session_id: session_id.clone(),
1478 connection: self.clone(),
1479 }) as Rc<dyn AgentModelSelector>)
1480 }
1481
1482 fn prompt(
1483 &self,
1484 id: Option<acp_thread::UserMessageId>,
1485 params: acp::PromptRequest,
1486 cx: &mut App,
1487 ) -> Task<Result<acp::PromptResponse>> {
1488 let id = id.expect("UserMessageId is required");
1489 let session_id = params.session_id.clone();
1490 log::info!("Received prompt request for session: {}", session_id);
1491 log::debug!("Prompt blocks count: {}", params.prompt.len());
1492
1493 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1494 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1495 };
1496
1497 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1498 let registry = project_state.context_server_registry.read(cx);
1499
1500 let explicit_server_id = parsed_command
1501 .explicit_server_id
1502 .map(|server_id| ContextServerId(server_id.into()));
1503
1504 if let Some(prompt) =
1505 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1506 {
1507 let arguments = if !parsed_command.arg_value.is_empty()
1508 && let Some(arg_name) = prompt
1509 .prompt
1510 .arguments
1511 .as_ref()
1512 .and_then(|args| args.first())
1513 .map(|arg| arg.name.clone())
1514 {
1515 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1516 } else {
1517 Default::default()
1518 };
1519
1520 let prompt_name = prompt.prompt.name.clone();
1521 let server_id = prompt.server_id.clone();
1522
1523 return self.0.update(cx, |agent, cx| {
1524 agent.send_mcp_prompt(
1525 id,
1526 session_id.clone(),
1527 prompt_name,
1528 server_id,
1529 arguments,
1530 params.prompt,
1531 cx,
1532 )
1533 });
1534 }
1535 };
1536
1537 let path_style = project_state.project.read(cx).path_style(cx);
1538
1539 self.run_turn(session_id, cx, move |thread, cx| {
1540 let content: Vec<UserMessageContent> = params
1541 .prompt
1542 .into_iter()
1543 .map(|block| UserMessageContent::from_content_block(block, path_style))
1544 .collect::<Vec<_>>();
1545 log::debug!("Converted prompt to message: {} chars", content.len());
1546 log::debug!("Message id: {:?}", id);
1547 log::debug!("Message content: {:?}", content);
1548
1549 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1550 })
1551 }
1552
1553 fn retry(
1554 &self,
1555 session_id: &acp::SessionId,
1556 _cx: &App,
1557 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1558 Some(Rc::new(NativeAgentSessionRetry {
1559 connection: self.clone(),
1560 session_id: session_id.clone(),
1561 }) as _)
1562 }
1563
1564 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1565 log::info!("Cancelling on session: {}", session_id);
1566 self.0.update(cx, |agent, cx| {
1567 if let Some(session) = agent.sessions.get(session_id) {
1568 session
1569 .thread
1570 .update(cx, |thread, cx| thread.cancel(cx))
1571 .detach();
1572 }
1573 });
1574 }
1575
1576 fn truncate(
1577 &self,
1578 session_id: &acp::SessionId,
1579 cx: &App,
1580 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1581 self.0.read_with(cx, |agent, _cx| {
1582 agent.sessions.get(session_id).map(|session| {
1583 Rc::new(NativeAgentSessionTruncate {
1584 thread: session.thread.clone(),
1585 acp_thread: session.acp_thread.downgrade(),
1586 }) as _
1587 })
1588 })
1589 }
1590
1591 fn set_title(
1592 &self,
1593 session_id: &acp::SessionId,
1594 cx: &App,
1595 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1596 self.0.read_with(cx, |agent, _cx| {
1597 agent
1598 .sessions
1599 .get(session_id)
1600 .filter(|s| !s.thread.read(cx).is_subagent())
1601 .map(|session| {
1602 Rc::new(NativeAgentSessionSetTitle {
1603 thread: session.thread.clone(),
1604 }) as _
1605 })
1606 })
1607 }
1608
1609 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1610 let thread_store = self.0.read(cx).thread_store.clone();
1611 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1612 }
1613
1614 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1615 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1616 }
1617
1618 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1619 self
1620 }
1621}
1622
1623impl acp_thread::AgentTelemetry for NativeAgentConnection {
1624 fn thread_data(
1625 &self,
1626 session_id: &acp::SessionId,
1627 cx: &mut App,
1628 ) -> Task<Result<serde_json::Value>> {
1629 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1630 return Task::ready(Err(anyhow!("Session not found")));
1631 };
1632
1633 let task = session.thread.read(cx).to_db(cx);
1634 cx.background_spawn(async move {
1635 serde_json::to_value(task.await).context("Failed to serialize thread")
1636 })
1637 }
1638}
1639
1640pub struct NativeAgentSessionList {
1641 thread_store: Entity<ThreadStore>,
1642 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1643 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1644 _subscription: Subscription,
1645}
1646
1647impl NativeAgentSessionList {
1648 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1649 let (tx, rx) = smol::channel::unbounded();
1650 let this_tx = tx.clone();
1651 let subscription = cx.observe(&thread_store, move |_, _| {
1652 this_tx
1653 .try_send(acp_thread::SessionListUpdate::Refresh)
1654 .ok();
1655 });
1656 Self {
1657 thread_store,
1658 updates_tx: tx,
1659 updates_rx: rx,
1660 _subscription: subscription,
1661 }
1662 }
1663
1664 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1665 &self.thread_store
1666 }
1667}
1668
1669impl AgentSessionList for NativeAgentSessionList {
1670 fn list_sessions(
1671 &self,
1672 _request: AgentSessionListRequest,
1673 cx: &mut App,
1674 ) -> Task<Result<AgentSessionListResponse>> {
1675 let sessions = self
1676 .thread_store
1677 .read(cx)
1678 .entries()
1679 .map(|entry| AgentSessionInfo::from(&entry))
1680 .collect();
1681 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1682 }
1683
1684 fn supports_delete(&self) -> bool {
1685 true
1686 }
1687
1688 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1689 self.thread_store
1690 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1691 }
1692
1693 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1694 self.thread_store
1695 .update(cx, |store, cx| store.delete_threads(cx))
1696 }
1697
1698 fn watch(
1699 &self,
1700 _cx: &mut App,
1701 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1702 Some(self.updates_rx.clone())
1703 }
1704
1705 fn notify_refresh(&self) {
1706 self.updates_tx
1707 .try_send(acp_thread::SessionListUpdate::Refresh)
1708 .ok();
1709 }
1710
1711 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1712 self
1713 }
1714}
1715
1716struct NativeAgentSessionTruncate {
1717 thread: Entity<Thread>,
1718 acp_thread: WeakEntity<AcpThread>,
1719}
1720
1721impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1722 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1723 match self.thread.update(cx, |thread, cx| {
1724 thread.truncate(message_id.clone(), cx)?;
1725 Ok(thread.latest_token_usage())
1726 }) {
1727 Ok(usage) => {
1728 self.acp_thread
1729 .update(cx, |thread, cx| {
1730 thread.update_token_usage(usage, cx);
1731 })
1732 .ok();
1733 Task::ready(Ok(()))
1734 }
1735 Err(error) => Task::ready(Err(error)),
1736 }
1737 }
1738}
1739
1740struct NativeAgentSessionRetry {
1741 connection: NativeAgentConnection,
1742 session_id: acp::SessionId,
1743}
1744
1745impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1746 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1747 self.connection
1748 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1749 thread.update(cx, |thread, cx| thread.resume(cx))
1750 })
1751 }
1752}
1753
1754struct NativeAgentSessionSetTitle {
1755 thread: Entity<Thread>,
1756}
1757
1758impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1759 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1760 self.thread
1761 .update(cx, |thread, cx| thread.set_title(title, cx));
1762 Task::ready(Ok(()))
1763 }
1764}
1765
1766pub struct NativeThreadEnvironment {
1767 agent: WeakEntity<NativeAgent>,
1768 thread: WeakEntity<Thread>,
1769 acp_thread: WeakEntity<AcpThread>,
1770}
1771
1772impl NativeThreadEnvironment {
1773 pub(crate) fn create_subagent_thread(
1774 &self,
1775 label: String,
1776 cx: &mut App,
1777 ) -> Result<Rc<dyn SubagentHandle>> {
1778 let Some(parent_thread_entity) = self.thread.upgrade() else {
1779 anyhow::bail!("Parent thread no longer exists".to_string());
1780 };
1781 let parent_thread = parent_thread_entity.read(cx);
1782 let current_depth = parent_thread.depth();
1783 let parent_session_id = parent_thread.id().clone();
1784
1785 if current_depth >= MAX_SUBAGENT_DEPTH {
1786 return Err(anyhow!(
1787 "Maximum subagent depth ({}) reached",
1788 MAX_SUBAGENT_DEPTH
1789 ));
1790 }
1791
1792 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1793 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1794 thread.set_title(label.into(), cx);
1795 thread
1796 });
1797
1798 let session_id = subagent_thread.read(cx).id().clone();
1799
1800 let acp_thread = self
1801 .agent
1802 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1803 let project_id = agent
1804 .sessions
1805 .get(&parent_session_id)
1806 .map(|s| s.project_id)
1807 .context("parent session not found")?;
1808 Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1809 })??;
1810
1811 let depth = current_depth + 1;
1812
1813 telemetry::event!(
1814 "Subagent Started",
1815 session = parent_thread_entity.read(cx).id().to_string(),
1816 subagent_session = session_id.to_string(),
1817 depth,
1818 is_resumed = false,
1819 );
1820
1821 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1822 }
1823
1824 pub(crate) fn resume_subagent_thread(
1825 &self,
1826 session_id: acp::SessionId,
1827 cx: &mut App,
1828 ) -> Result<Rc<dyn SubagentHandle>> {
1829 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1830 let session = agent
1831 .sessions
1832 .get(&session_id)
1833 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1834 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1835 })??;
1836
1837 let depth = subagent_thread.read(cx).depth();
1838
1839 if let Some(parent_thread_entity) = self.thread.upgrade() {
1840 telemetry::event!(
1841 "Subagent Started",
1842 session = parent_thread_entity.read(cx).id().to_string(),
1843 subagent_session = session_id.to_string(),
1844 depth,
1845 is_resumed = true,
1846 );
1847 }
1848
1849 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1850 }
1851
1852 fn prompt_subagent(
1853 &self,
1854 session_id: acp::SessionId,
1855 subagent_thread: Entity<Thread>,
1856 acp_thread: Entity<acp_thread::AcpThread>,
1857 ) -> Result<Rc<dyn SubagentHandle>> {
1858 let Some(parent_thread_entity) = self.thread.upgrade() else {
1859 anyhow::bail!("Parent thread no longer exists".to_string());
1860 };
1861 Ok(Rc::new(NativeSubagentHandle::new(
1862 session_id,
1863 subagent_thread,
1864 acp_thread,
1865 parent_thread_entity,
1866 )) as _)
1867 }
1868}
1869
1870impl ThreadEnvironment for NativeThreadEnvironment {
1871 fn create_terminal(
1872 &self,
1873 command: String,
1874 cwd: Option<PathBuf>,
1875 output_byte_limit: Option<u64>,
1876 cx: &mut AsyncApp,
1877 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1878 let task = self.acp_thread.update(cx, |thread, cx| {
1879 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1880 });
1881
1882 let acp_thread = self.acp_thread.clone();
1883 cx.spawn(async move |cx| {
1884 let terminal = task?.await?;
1885
1886 let (drop_tx, drop_rx) = oneshot::channel();
1887 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1888
1889 cx.spawn(async move |cx| {
1890 drop_rx.await.ok();
1891 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1892 })
1893 .detach();
1894
1895 let handle = AcpTerminalHandle {
1896 terminal,
1897 _drop_tx: Some(drop_tx),
1898 };
1899
1900 Ok(Rc::new(handle) as _)
1901 })
1902 }
1903
1904 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1905 self.create_subagent_thread(label, cx)
1906 }
1907
1908 fn resume_subagent(
1909 &self,
1910 session_id: acp::SessionId,
1911 cx: &mut App,
1912 ) -> Result<Rc<dyn SubagentHandle>> {
1913 self.resume_subagent_thread(session_id, cx)
1914 }
1915}
1916
1917#[derive(Debug, Clone)]
1918enum SubagentPromptResult {
1919 Completed,
1920 Cancelled,
1921 ContextWindowWarning,
1922 Error(String),
1923}
1924
1925pub struct NativeSubagentHandle {
1926 session_id: acp::SessionId,
1927 parent_thread: WeakEntity<Thread>,
1928 subagent_thread: Entity<Thread>,
1929 acp_thread: Entity<acp_thread::AcpThread>,
1930}
1931
1932impl NativeSubagentHandle {
1933 fn new(
1934 session_id: acp::SessionId,
1935 subagent_thread: Entity<Thread>,
1936 acp_thread: Entity<acp_thread::AcpThread>,
1937 parent_thread_entity: Entity<Thread>,
1938 ) -> Self {
1939 NativeSubagentHandle {
1940 session_id,
1941 subagent_thread,
1942 parent_thread: parent_thread_entity.downgrade(),
1943 acp_thread,
1944 }
1945 }
1946}
1947
1948impl SubagentHandle for NativeSubagentHandle {
1949 fn id(&self) -> acp::SessionId {
1950 self.session_id.clone()
1951 }
1952
1953 fn num_entries(&self, cx: &App) -> usize {
1954 self.acp_thread.read(cx).entries().len()
1955 }
1956
1957 fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1958 let thread = self.subagent_thread.clone();
1959 let acp_thread = self.acp_thread.clone();
1960 let subagent_session_id = self.session_id.clone();
1961 let parent_thread = self.parent_thread.clone();
1962
1963 cx.spawn(async move |cx| {
1964 let (task, _subscription) = cx.update(|cx| {
1965 let ratio_before_prompt = thread
1966 .read(cx)
1967 .latest_token_usage()
1968 .map(|usage| usage.ratio());
1969
1970 parent_thread
1971 .update(cx, |parent_thread, _cx| {
1972 parent_thread.register_running_subagent(thread.downgrade())
1973 })
1974 .ok();
1975
1976 let task = acp_thread.update(cx, |acp_thread, cx| {
1977 acp_thread.send(vec![message.into()], cx)
1978 });
1979
1980 let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1981 let mut token_limit_tx = Some(token_limit_tx);
1982
1983 let subscription = cx.subscribe(
1984 &thread,
1985 move |_thread, event: &TokenUsageUpdated, _cx| {
1986 if let Some(usage) = &event.0 {
1987 let old_ratio = ratio_before_prompt
1988 .clone()
1989 .unwrap_or(TokenUsageRatio::Normal);
1990 let new_ratio = usage.ratio();
1991 if old_ratio == TokenUsageRatio::Normal
1992 && new_ratio == TokenUsageRatio::Warning
1993 {
1994 if let Some(tx) = token_limit_tx.take() {
1995 tx.send(()).ok();
1996 }
1997 }
1998 }
1999 },
2000 );
2001
2002 let wait_for_prompt = cx
2003 .background_spawn(async move {
2004 futures::select! {
2005 response = task.fuse() => match response {
2006 Ok(Some(response)) => {
2007 match response.stop_reason {
2008 acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
2009 acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
2010 acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
2011 acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
2012 acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
2013 }
2014 }
2015 Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
2016 Err(error) => SubagentPromptResult::Error(error.to_string()),
2017 },
2018 _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
2019 }
2020 });
2021
2022 (wait_for_prompt, subscription)
2023 });
2024
2025 let result = match task.await {
2026 SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2027 thread
2028 .last_message()
2029 .and_then(|message| {
2030 let content = message.as_agent_message()?
2031 .content
2032 .iter()
2033 .filter_map(|c| match c {
2034 AgentMessageContent::Text(text) => Some(text.as_str()),
2035 _ => None,
2036 })
2037 .join("\n\n");
2038 if content.is_empty() {
2039 None
2040 } else {
2041 Some( content)
2042 }
2043 })
2044 .context("No response from subagent")
2045 }),
2046 SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2047 SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2048 SubagentPromptResult::ContextWindowWarning => {
2049 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2050 Err(anyhow!(
2051 "The agent is nearing the end of its context window and has been \
2052 stopped. You can prompt the thread again to have the agent wrap up \
2053 or hand off its work."
2054 ))
2055 }
2056 };
2057
2058 parent_thread
2059 .update(cx, |parent_thread, cx| {
2060 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2061 })
2062 .ok();
2063
2064 result
2065 })
2066 }
2067}
2068
2069pub struct AcpTerminalHandle {
2070 terminal: Entity<acp_thread::Terminal>,
2071 _drop_tx: Option<oneshot::Sender<()>>,
2072}
2073
2074impl TerminalHandle for AcpTerminalHandle {
2075 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2076 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2077 }
2078
2079 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2080 Ok(self
2081 .terminal
2082 .read_with(cx, |term, _cx| term.wait_for_exit()))
2083 }
2084
2085 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2086 Ok(self
2087 .terminal
2088 .read_with(cx, |term, cx| term.current_output(cx)))
2089 }
2090
2091 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2092 cx.update(|cx| {
2093 self.terminal.update(cx, |terminal, cx| {
2094 terminal.kill(cx);
2095 });
2096 });
2097 Ok(())
2098 }
2099
2100 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2101 Ok(self
2102 .terminal
2103 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2104 }
2105}
2106
2107#[cfg(test)]
2108mod internal_tests {
2109 use std::path::Path;
2110
2111 use super::*;
2112 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2113 use fs::FakeFs;
2114 use gpui::TestAppContext;
2115 use indoc::formatdoc;
2116 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2117 use language_model::{
2118 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2119 };
2120 use serde_json::json;
2121 use settings::SettingsStore;
2122 use util::{path, rel_path::rel_path};
2123
2124 #[gpui::test]
2125 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
2126 init_test(cx);
2127 let fs = FakeFs::new(cx.executor());
2128 fs.insert_tree(
2129 "/",
2130 json!({
2131 "a": {}
2132 }),
2133 )
2134 .await;
2135 let project = Project::test(fs.clone(), [], cx).await;
2136 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2137 let agent =
2138 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2139
2140 // Creating a session registers the project and triggers context building.
2141 let connection = NativeAgentConnection(agent.clone());
2142 let _acp_thread = cx
2143 .update(|cx| {
2144 Rc::new(connection).new_session(
2145 project.clone(),
2146 PathList::new(&[Path::new("/")]),
2147 cx,
2148 )
2149 })
2150 .await
2151 .unwrap();
2152 cx.run_until_parked();
2153
2154 let thread = agent.read_with(cx, |agent, _cx| {
2155 agent.sessions.values().next().unwrap().thread.clone()
2156 });
2157
2158 agent.read_with(cx, |agent, cx| {
2159 let project_id = project.entity_id();
2160 let state = agent.projects.get(&project_id).unwrap();
2161 assert_eq!(state.project_context.read(cx).worktrees, vec![]);
2162 assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
2163 });
2164
2165 let worktree = project
2166 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
2167 .await
2168 .unwrap();
2169 cx.run_until_parked();
2170 agent.read_with(cx, |agent, cx| {
2171 let project_id = project.entity_id();
2172 let state = agent.projects.get(&project_id).unwrap();
2173 let expected_worktrees = vec![WorktreeContext {
2174 root_name: "a".into(),
2175 abs_path: Path::new("/a").into(),
2176 rules_file: None,
2177 }];
2178 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2179 assert_eq!(
2180 thread.read(cx).project_context().read(cx).worktrees,
2181 expected_worktrees
2182 );
2183 });
2184
2185 // Creating `/a/.rules` updates the project context.
2186 fs.insert_file("/a/.rules", Vec::new()).await;
2187 cx.run_until_parked();
2188 agent.read_with(cx, |agent, cx| {
2189 let project_id = project.entity_id();
2190 let state = agent.projects.get(&project_id).unwrap();
2191 let rules_entry = worktree
2192 .read(cx)
2193 .entry_for_path(rel_path(".rules"))
2194 .unwrap();
2195 let expected_worktrees = vec![WorktreeContext {
2196 root_name: "a".into(),
2197 abs_path: Path::new("/a").into(),
2198 rules_file: Some(RulesFileContext {
2199 path_in_worktree: rel_path(".rules").into(),
2200 text: "".into(),
2201 project_entry_id: rules_entry.id.to_usize(),
2202 }),
2203 }];
2204 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2205 assert_eq!(
2206 thread.read(cx).project_context().read(cx).worktrees,
2207 expected_worktrees
2208 );
2209 });
2210 }
2211
2212 #[gpui::test]
2213 async fn test_listing_models(cx: &mut TestAppContext) {
2214 init_test(cx);
2215 let fs = FakeFs::new(cx.executor());
2216 fs.insert_tree("/", json!({ "a": {} })).await;
2217 let project = Project::test(fs.clone(), [], cx).await;
2218 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2219 let connection =
2220 NativeAgentConnection(cx.update(|cx| {
2221 NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2222 }));
2223
2224 // Create a thread/session
2225 let acp_thread = cx
2226 .update(|cx| {
2227 Rc::new(connection.clone()).new_session(
2228 project.clone(),
2229 PathList::new(&[Path::new("/a")]),
2230 cx,
2231 )
2232 })
2233 .await
2234 .unwrap();
2235
2236 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2237
2238 let models = cx
2239 .update(|cx| {
2240 connection
2241 .model_selector(&session_id)
2242 .unwrap()
2243 .list_models(cx)
2244 })
2245 .await
2246 .unwrap();
2247
2248 let acp_thread::AgentModelList::Grouped(models) = models else {
2249 panic!("Unexpected model group");
2250 };
2251 assert_eq!(
2252 models,
2253 IndexMap::from_iter([(
2254 AgentModelGroupName("Fake".into()),
2255 vec![AgentModelInfo {
2256 id: acp::ModelId::new("fake/fake"),
2257 name: "Fake".into(),
2258 description: None,
2259 icon: Some(acp_thread::AgentModelIcon::Named(
2260 ui::IconName::ZedAssistant
2261 )),
2262 is_latest: false,
2263 cost: None,
2264 }]
2265 )])
2266 );
2267 }
2268
2269 #[gpui::test]
2270 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2271 init_test(cx);
2272 let fs = FakeFs::new(cx.executor());
2273 fs.create_dir(paths::settings_file().parent().unwrap())
2274 .await
2275 .unwrap();
2276 fs.insert_file(
2277 paths::settings_file(),
2278 json!({
2279 "agent": {
2280 "default_model": {
2281 "provider": "foo",
2282 "model": "bar"
2283 }
2284 }
2285 })
2286 .to_string()
2287 .into_bytes(),
2288 )
2289 .await;
2290 let project = Project::test(fs.clone(), [], cx).await;
2291
2292 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2293
2294 // Create the agent and connection
2295 let agent =
2296 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2297 let connection = NativeAgentConnection(agent.clone());
2298
2299 // Create a thread/session
2300 let acp_thread = cx
2301 .update(|cx| {
2302 Rc::new(connection.clone()).new_session(
2303 project.clone(),
2304 PathList::new(&[Path::new("/a")]),
2305 cx,
2306 )
2307 })
2308 .await
2309 .unwrap();
2310
2311 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2312
2313 // Select a model
2314 let selector = connection.model_selector(&session_id).unwrap();
2315 let model_id = acp::ModelId::new("fake/fake");
2316 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2317 .await
2318 .unwrap();
2319
2320 // Verify the thread has the selected model
2321 agent.read_with(cx, |agent, _| {
2322 let session = agent.sessions.get(&session_id).unwrap();
2323 session.thread.read_with(cx, |thread, _| {
2324 assert_eq!(thread.model().unwrap().id().0, "fake");
2325 });
2326 });
2327
2328 cx.run_until_parked();
2329
2330 // Verify settings file was updated
2331 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2332 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2333
2334 // Check that the agent settings contain the selected model
2335 assert_eq!(
2336 settings_json["agent"]["default_model"]["model"],
2337 json!("fake")
2338 );
2339 assert_eq!(
2340 settings_json["agent"]["default_model"]["provider"],
2341 json!("fake")
2342 );
2343
2344 // Register a thinking model and select it.
2345 cx.update(|cx| {
2346 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2347 "fake-corp",
2348 "fake-thinking",
2349 "Fake Thinking",
2350 true,
2351 ));
2352 let thinking_provider = Arc::new(
2353 FakeLanguageModelProvider::new(
2354 LanguageModelProviderId::from("fake-corp".to_string()),
2355 LanguageModelProviderName::from("Fake Corp".to_string()),
2356 )
2357 .with_models(vec![thinking_model]),
2358 );
2359 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2360 registry.register_provider(thinking_provider, cx);
2361 });
2362 });
2363 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2364
2365 let selector = connection.model_selector(&session_id).unwrap();
2366 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2367 .await
2368 .unwrap();
2369 cx.run_until_parked();
2370
2371 // Verify enable_thinking was written to settings as true.
2372 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2373 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2374 assert_eq!(
2375 settings_json["agent"]["default_model"]["enable_thinking"],
2376 json!(true),
2377 "selecting a thinking model should persist enable_thinking: true to settings"
2378 );
2379 }
2380
/// Switching models through the ACP model selector keeps the thread's
/// `thinking_enabled` flag in sync with the selected model.
///
/// Selecting the model registered with thinking support turns the flag on.
/// Switching back to the default fake model turns it off again.
#[gpui::test]
async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    // Start from an empty settings file. Presumably this is needed because
    // select_model also persists the choice to settings.
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
    let project = Project::test(fs.clone(), [], cx).await;

    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent =
        cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
    let connection = NativeAgentConnection(agent.clone());

    let acp_thread = cx
        .update(|cx| {
            Rc::new(connection.clone()).new_session(
                project.clone(),
                PathList::new(&[Path::new("/a")]),
                cx,
            )
        })
        .await
        .unwrap();
    let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

    // Register a second provider with a thinking model.
    cx.update(|cx| {
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model]),
        );
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.register_provider(thinking_provider, cx);
        });
    });
    // Refresh the agent's model list so it picks up the new provider.
    agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

    // Thread starts with thinking_enabled = false (the default).
    agent.read_with(cx, |agent, _| {
        let session = agent.sessions.get(&session_id).unwrap();
        session.thread.read_with(cx, |thread, _| {
            assert!(!thread.thinking_enabled(), "thinking defaults to false");
        });
    });

    // Select the thinking model via select_model.
    let selector = connection.model_selector(&session_id).unwrap();
    cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
        .await
        .unwrap();

    // select_model should have enabled thinking based on the model's supports_thinking().
    agent.read_with(cx, |agent, _| {
        let session = agent.sessions.get(&session_id).unwrap();
        session.thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "select_model should enable thinking when model supports it"
            );
        });
    });

    // Switch back to the non-thinking model.
    let selector = connection.model_selector(&session_id).unwrap();
    cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
        .await
        .unwrap();

    // select_model should have disabled thinking.
    agent.read_with(cx, |agent, _| {
        let session = agent.sessions.get(&session_id).unwrap();
        session.thread.read_with(cx, |thread, _| {
            assert!(
                !thread.thinking_enabled(),
                "select_model should disable thinking when model does not support it"
            );
        });
    });
}
2472
/// A session's summarization model must survive the registry's default model
/// being cleared via `set_default_model(None)`, which simulates a transient
/// provider blip.
#[gpui::test]
async fn test_summarization_model_survives_transient_registry_clearing(
    cx: &mut TestAppContext,
) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/", json!({ "a": {} })).await;
    let project = Project::test(fs.clone(), [], cx).await;

    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent =
        cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    let acp_thread = cx
        .update(|cx| {
            connection.clone().new_session(
                project.clone(),
                PathList::new(&[Path::new("/a")]),
                cx,
            )
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Grab the native thread entity backing the new session.
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });

    // Precondition: the session starts out with a summarization model.
    thread.read_with(cx, |thread, _| {
        assert!(
            thread.summarization_model().is_some(),
            "session should have a summarization model from the test registry"
        );
    });

    // Simulate what happens during a provider blip:
    // update_active_language_model_from_settings calls set_default_model(None)
    // when it can't resolve the model, clearing all fallbacks.
    cx.update(|cx| {
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.set_default_model(None, cx);
        });
    });
    // Drain pending tasks so any reaction to the registry change has run
    // before the thread is inspected again.
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert!(
            thread.summarization_model().is_some(),
            "summarization model should survive a transient default model clearing"
        );
    });
}
2527
/// `thinking_enabled` must survive a persist/close/reload round-trip for a
/// thread whose selected model supports thinking.
#[gpui::test]
async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/", json!({ "a": {} })).await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = cx.update(|cx| {
        NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
    });
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    // Register a thinking model.
    // The model handle is kept so the test can drive its completion stream later.
    let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
        "fake-corp",
        "fake-thinking",
        "Fake Thinking",
        true,
    ));
    let thinking_provider = Arc::new(
        FakeLanguageModelProvider::new(
            LanguageModelProviderId::from("fake-corp".to_string()),
            LanguageModelProviderName::from("Fake Corp".to_string()),
        )
        .with_models(vec![thinking_model.clone()]),
    );
    cx.update(|cx| {
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.register_provider(thinking_provider, cx);
        });
    });
    // Refresh the agent's model list so the new provider's model can be selected.
    agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

    // Create a thread and select the thinking model.
    let acp_thread = cx
        .update(|cx| {
            connection.clone().new_session(
                project.clone(),
                PathList::new(&[Path::new("/a")]),
                cx,
            )
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    let selector = connection.model_selector(&session_id).unwrap();
    cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
        .await
        .unwrap();

    // Verify thinking is enabled after selecting the thinking model.
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    thread.read_with(cx, |thread, _| {
        assert!(
            thread.thinking_enabled(),
            "thinking should be enabled after selecting thinking model"
        );
    });

    // Send a message so the thread gets persisted.
    let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
    let send = cx.foreground_executor().spawn(send);
    cx.run_until_parked();

    // Finish the fake completion stream so the send can complete.
    thinking_model.send_last_completion_stream_text_chunk("Response.");
    thinking_model.end_last_completion_stream();

    send.await.unwrap();
    cx.run_until_parked();

    // Close the session so it can be reloaded from disk.
    cx.update(|cx| connection.clone().close_session(&session_id, cx))
        .await
        .unwrap();
    // Release the test's remaining handles to the old session's threads.
    drop(thread);
    drop(acp_thread);
    agent.read_with(cx, |agent, _| {
        assert!(agent.sessions.is_empty());
    });

    // Reload the thread and verify thinking_enabled is still true.
    let reloaded_acp_thread = agent
        .update(cx, |agent, cx| {
            agent.open_thread(session_id.clone(), project.clone(), cx)
        })
        .await
        .unwrap();
    let reloaded_thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    reloaded_thread.read_with(cx, |thread, _| {
        assert!(
            thread.thinking_enabled(),
            "thinking_enabled should be preserved when reloading a thread with a thinking model"
        );
    });

    // The reloaded ACP thread handle is held until the assertions above have run.
    drop(reloaded_acp_thread);
}
2630
/// The selected model must survive a persist/close/reload round-trip, even when
/// the model's id differs from its display name. The reloaded thread must not
/// fall back to the default model.
#[gpui::test]
async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/", json!({ "a": {} })).await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = cx.update(|cx| {
        NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
    });
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    // Register a model where id() != name(), like real Anthropic models
    // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
    let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
        "fake-corp",
        "custom-model-id",
        "Custom Model Display Name",
        false,
    ));
    let provider = Arc::new(
        FakeLanguageModelProvider::new(
            LanguageModelProviderId::from("fake-corp".to_string()),
            LanguageModelProviderName::from("Fake Corp".to_string()),
        )
        .with_models(vec![model.clone()]),
    );
    cx.update(|cx| {
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.register_provider(provider, cx);
        });
    });
    // Refresh the agent's model list so the new provider's model can be selected.
    agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

    // Create a thread and select the model.
    let acp_thread = cx
        .update(|cx| {
            connection.clone().new_session(
                project.clone(),
                PathList::new(&[Path::new("/a")]),
                cx,
            )
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Models are addressed as "<provider id>/<model id>", not by display name.
    let selector = connection.model_selector(&session_id).unwrap();
    cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
        .await
        .unwrap();

    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.model().unwrap().id().0.as_ref(),
            "custom-model-id",
            "model should be set before persisting"
        );
    });

    // Send a message so the thread gets persisted.
    let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
    let send = cx.foreground_executor().spawn(send);
    cx.run_until_parked();

    // Finish the fake completion stream so the send can complete.
    model.send_last_completion_stream_text_chunk("Response.");
    model.end_last_completion_stream();

    send.await.unwrap();
    cx.run_until_parked();

    // Close the session so it can be reloaded from disk.
    cx.update(|cx| connection.clone().close_session(&session_id, cx))
        .await
        .unwrap();
    // Release the test's remaining handles to the old session's threads.
    drop(thread);
    drop(acp_thread);
    agent.read_with(cx, |agent, _| {
        assert!(agent.sessions.is_empty());
    });

    // Reload the thread and verify the model was preserved.
    let reloaded_acp_thread = agent
        .update(cx, |agent, cx| {
            agent.open_thread(session_id.clone(), project.clone(), cx)
        })
        .await
        .unwrap();
    let reloaded_thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    reloaded_thread.read_with(cx, |thread, _| {
        let reloaded_model = thread
            .model()
            .expect("model should be present after reload");
        assert_eq!(
            reloaded_model.id().0.as_ref(),
            "custom-model-id",
            "reloaded thread should have the same model, not fall back to the default"
        );
    });

    // The reloaded ACP thread handle is held until the assertions above have run.
    drop(reloaded_acp_thread);
}
2738
/// End-to-end persistence round-trip for a native thread.
///
/// - An empty thread is not saved, even after being mutated.
/// - A thread with a completed exchange is saved under the title produced by
///   the summarization model.
/// - After closing and reopening the session, the following are restored:
///   - the rendered messages;
///   - the draft prompt (including resource links);
///   - token usage;
///   - the UI scroll position.
#[gpui::test]
async fn test_save_load_thread(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/",
        json!({
            "a": {
                "b.md": "Lorem"
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = cx.update(|cx| {
        NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
    });
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });

    // Ensure empty threads are not saved, even if they get mutated.
    let model = Arc::new(FakeLanguageModel::default());
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_model(model.clone(), cx);
        thread.set_summarization_model(Some(summary_model.clone()), cx);
    });
    cx.run_until_parked();
    assert_eq!(thread_entries(&thread_store, cx), vec![]);

    // Send a user message that mentions `b.md` through a resource link.
    let send = acp_thread.update(cx, |thread, cx| {
        thread.send(
            vec![
                "What does ".into(),
                acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                    "b.md",
                    MentionUri::File {
                        abs_path: path!("/a/b.md").into(),
                    }
                    .to_uri()
                    .to_string(),
                )),
                " mean?".into(),
            ],
            cx,
        )
    });
    let send = cx.foreground_executor().spawn(send);
    cx.run_until_parked();

    // Stream the assistant reply plus a token-usage update, then end the stream.
    model.send_last_completion_stream_text_chunk("Lorem.");
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 150,
            output_tokens: 75,
            ..Default::default()
        },
    ));
    model.end_last_completion_stream();
    cx.run_until_parked();
    // Answer the summarization model's pending request; this text is later
    // expected as the saved thread's title.
    summary_model
        .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
    summary_model.end_last_completion_stream();

    send.await.unwrap();
    let uri = MentionUri::File {
        abs_path: path!("/a/b.md").into(),
    }
    .to_uri();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            formatdoc! {"
                ## User

                What does [@b.md]({uri}) mean?

                ## Assistant

                Lorem.

            "}
        )
    });

    cx.run_until_parked();

    // Set a draft prompt with rich content blocks and scroll position
    // AFTER run_until_parked, so the only save that captures these
    // changes is the one performed by close_session itself.
    let draft_blocks = vec![
        acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
        acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
        acp::ContentBlock::Text(acp::TextContent::new(" please")),
    ];
    acp_thread.update(cx, |thread, _cx| {
        thread.set_draft_prompt(Some(draft_blocks.clone()));
    });
    thread.update(cx, |thread, _cx| {
        thread.set_ui_scroll_position(Some(gpui::ListOffset {
            item_ix: 5,
            offset_in_item: gpui::px(12.5),
        }));
    });

    // Close the session so it can be reloaded from disk.
    cx.update(|cx| connection.clone().close_session(&session_id, cx))
        .await
        .unwrap();
    // Release the test's remaining handles to the old session's threads.
    drop(thread);
    drop(acp_thread);
    agent.read_with(cx, |agent, _| {
        assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
    });

    // Ensure the thread can be reloaded from disk.
    assert_eq!(
        thread_entries(&thread_store, cx),
        vec![(
            session_id.clone(),
            format!("Explaining {}", path!("/a/b.md"))
        )]
    );
    let acp_thread = agent
        .update(cx, |agent, cx| {
            agent.open_thread(session_id.clone(), project.clone(), cx)
        })
        .await
        .unwrap();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            formatdoc! {"
                ## User

                What does [@b.md]({uri}) mean?

                ## Assistant

                Lorem.

            "}
        )
    });

    // Ensure the draft prompt with rich content blocks survived the round-trip.
    acp_thread.read_with(cx, |thread, _| {
        assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
    });

    // Ensure token usage survived the round-trip.
    acp_thread.read_with(cx, |thread, _| {
        let usage = thread
            .token_usage()
            .expect("token usage should be restored after reload");
        assert_eq!(usage.input_tokens, 150);
        assert_eq!(usage.output_tokens, 75);
    });

    // Ensure scroll position survived the round-trip.
    acp_thread.read_with(cx, |thread, _| {
        let scroll = thread
            .ui_scroll_position()
            .expect("scroll position should be restored after reload");
        assert_eq!(scroll.item_ix, 5);
        assert_eq!(scroll.offset_in_item, gpui::px(12.5));
    });
}
2920
/// `close_session` must itself save the thread. State changed immediately
/// before closing (here, a draft prompt) is therefore persisted even though no
/// observer-triggered save has had a chance to run.
#[gpui::test]
async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/",
        json!({
            "a": {
                "file.txt": "hello"
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = cx.update(|cx| {
        NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
    });
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });

    // Use a fake model so the test controls the completion stream.
    let model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_model(model.clone(), cx);
    });

    // Send a message so the thread is non-empty (empty threads aren't saved).
    let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
    let send = cx.foreground_executor().spawn(send);
    cx.run_until_parked();

    model.send_last_completion_stream_text_chunk("world");
    model.end_last_completion_stream();
    send.await.unwrap();
    cx.run_until_parked();

    // Set a draft prompt WITHOUT calling run_until_parked afterwards.
    // This means no observe-triggered save has run for this change.
    // The only way this data gets persisted is if close_session
    // itself performs the save.
    let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
        "unsaved draft",
    ))];
    acp_thread.update(cx, |thread, _cx| {
        thread.set_draft_prompt(Some(draft_blocks.clone()));
    });

    // Close the session immediately — no run_until_parked in between.
    cx.update(|cx| connection.clone().close_session(&session_id, cx))
        .await
        .unwrap();
    cx.run_until_parked();

    // Reopen and verify the draft prompt was saved.
    let reloaded = agent
        .update(cx, |agent, cx| {
            agent.open_thread(session_id.clone(), project.clone(), cx)
        })
        .await
        .unwrap();
    reloaded.read_with(cx, |thread, _| {
        assert_eq!(
            thread.draft_prompt(),
            Some(draft_blocks.as_slice()),
            "close_session must save the thread; draft prompt was lost"
        );
    });
}
3001
3002 fn thread_entries(
3003 thread_store: &Entity<ThreadStore>,
3004 cx: &mut TestAppContext,
3005 ) -> Vec<(acp::SessionId, String)> {
3006 thread_store.read_with(cx, |store, _| {
3007 store
3008 .entries()
3009 .map(|entry| (entry.id.clone(), entry.title.to_string()))
3010 .collect::<Vec<_>>()
3011 })
3012 }
3013
3014 fn init_test(cx: &mut TestAppContext) {
3015 env_logger::try_init().ok();
3016 cx.update(|cx| {
3017 let settings_store = SettingsStore::test(cx);
3018 cx.set_global(settings_store);
3019
3020 LanguageModelRegistry::test(cx);
3021 });
3022 }
3023}
3024
3025fn mcp_message_content_to_acp_content_block(
3026 content: context_server::types::MessageContent,
3027) -> acp::ContentBlock {
3028 match content {
3029 context_server::types::MessageContent::Text {
3030 text,
3031 annotations: _,
3032 } => text.into(),
3033 context_server::types::MessageContent::Image {
3034 data,
3035 mime_type,
3036 annotations: _,
3037 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3038 context_server::types::MessageContent::Audio {
3039 data,
3040 mime_type,
3041 annotations: _,
3042 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3043 context_server::types::MessageContent::Resource {
3044 resource,
3045 annotations: _,
3046 } => {
3047 let mut link =
3048 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3049 if let Some(mime_type) = resource.mime_type {
3050 link = link.mime_type(mime_type);
3051 }
3052 acp::ContentBlock::ResourceLink(link)
3053 }
3054 }
3055}