1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// A serializable snapshot of a project's worktrees (as telemetry worktree
/// snapshots) together with the moment the snapshot was taken.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One entry per captured worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was captured, in UTC.
    pub timestamp: DateTime<Utc>,
}
64
/// Describes a failure to load a worktree's rules file.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (built from the underlying
    /// error's `Display` output).
    pub message: SharedString,
}
68
/// Per-project state shared by every session opened against that project.
// NOTE(review): fields drop in declaration order; keep the order stable.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// The context (worktrees, rules files, default user rules) handed to threads.
    /// Its contents are replaced in place whenever a rebuild completes.
    project_context: Entity<ProjectContext>,
    /// Sending on this asks the `_maintain_project_context` task to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds `project_context` on request.
    _maintain_project_context: Task<Result<()>>,
    /// Tracks the tools and prompts exposed by this project's context servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store and registry events.
    _subscriptions: Vec<Subscription>,
}
77
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session runs in.
    project_id: EntityId,
    /// The most recently started save of `thread` to the threads database.
    pending_save: Task<()>,
    /// Subscriptions to the thread's title/token-usage events and change
    /// notifications.
    _subscriptions: Vec<Subscription>,
}
88
/// Cache of the language models offered by visible, authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached grouped list for returning language model information.
    model_list: acp_thread::AgentModelList,
    /// Template receiver cloned out by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled after every model-list refresh.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers at startup.
    _authenticate_all_providers_task: Task<()>,
}
98
99impl LanguageModels {
100 fn new(cx: &mut App) -> Self {
101 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
102
103 let mut this = Self {
104 models: HashMap::default(),
105 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
106 refresh_models_rx,
107 refresh_models_tx,
108 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
109 };
110 this.refresh_list(cx);
111 this
112 }
113
114 fn refresh_list(&mut self, cx: &App) {
115 let providers = LanguageModelRegistry::global(cx)
116 .read(cx)
117 .visible_providers()
118 .into_iter()
119 .filter(|provider| provider.is_authenticated(cx))
120 .collect::<Vec<_>>();
121
122 let mut language_model_list = IndexMap::default();
123 let mut recommended_models = HashSet::default();
124
125 let mut recommended = Vec::new();
126 for provider in &providers {
127 for model in provider.recommended_models(cx) {
128 recommended_models.insert((model.provider_id(), model.id()));
129 recommended.push(Self::map_language_model_to_info(&model, provider));
130 }
131 }
132 if !recommended.is_empty() {
133 language_model_list.insert(
134 acp_thread::AgentModelGroupName("Recommended".into()),
135 recommended,
136 );
137 }
138
139 let mut models = HashMap::default();
140 for provider in providers {
141 let mut provider_models = Vec::new();
142 for model in provider.provided_models(cx) {
143 let model_info = Self::map_language_model_to_info(&model, &provider);
144 let model_id = model_info.id.clone();
145 provider_models.push(model_info);
146 models.insert(model_id, model);
147 }
148 if !provider_models.is_empty() {
149 language_model_list.insert(
150 acp_thread::AgentModelGroupName(provider.name().0.clone()),
151 provider_models,
152 );
153 }
154 }
155
156 self.models = models;
157 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
158 self.refresh_models_tx.send(()).ok();
159 }
160
161 fn watch(&self) -> watch::Receiver<()> {
162 self.refresh_models_rx.clone()
163 }
164
165 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
166 self.models.get(model_id).cloned()
167 }
168
169 fn map_language_model_to_info(
170 model: &Arc<dyn LanguageModel>,
171 provider: &Arc<dyn LanguageModelProvider>,
172 ) -> acp_thread::AgentModelInfo {
173 acp_thread::AgentModelInfo {
174 id: Self::model_id(model),
175 name: model.name().0,
176 description: None,
177 icon: Some(match provider.icon() {
178 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
179 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
180 }),
181 is_latest: model.is_latest(),
182 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
183 }
184 }
185
186 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
187 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
188 }
189
190 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
191 let authenticate_all_providers = LanguageModelRegistry::global(cx)
192 .read(cx)
193 .visible_providers()
194 .iter()
195 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
196 .collect::<Vec<_>>();
197
198 cx.background_spawn(async move {
199 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
200 if let Err(err) = authenticate_task.await {
201 match err {
202 language_model::AuthenticateError::CredentialsNotFound => {
203 // Since we're authenticating these providers in the
204 // background for the purposes of populating the
205 // language selector, we don't care about providers
206 // where the credentials are not found.
207 }
208 language_model::AuthenticateError::ConnectionRefused => {
209 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
210 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
211 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
212 }
213 _ => {
214 // Some providers have noisy failure states that we
215 // don't want to spam the logs with every time the
216 // language model selector is initialized.
217 //
218 // Ideally these should have more clear failure modes
219 // that we know are safe to ignore here, like what we do
220 // with `CredentialsNotFound` above.
221 match provider_id.0.as_ref() {
222 "lmstudio" | "ollama" => {
223 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
224 //
225 // These fail noisily, so we don't log them.
226 }
227 "copilot_chat" => {
228 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
229 }
230 _ => {
231 log::error!(
232 "Failed to authenticate provider: {}: {err:#}",
233 provider_name.0
234 );
235 }
236 }
237 }
238 }
239 }
240 }
241 })
242 }
243}
244
/// The built-in agent: owns every session, the per-project state the sessions
/// share, and the cached language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping.
    sessions: HashMap<acp::SessionId, Session>,
    /// Store reloaded after each thread save.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId.
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads.
    templates: Arc<Templates>,
    /// Cached model information.
    models: LanguageModels,
    /// Source of default user rules for project contexts, if available.
    prompt_store: Option<Entity<PromptStore>>,
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language-model registry and (optionally) prompt store.
    _subscriptions: Vec<Subscription>,
}
259
260impl NativeAgent {
261 pub fn new(
262 thread_store: Entity<ThreadStore>,
263 templates: Arc<Templates>,
264 prompt_store: Option<Entity<PromptStore>>,
265 fs: Arc<dyn Fs>,
266 cx: &mut App,
267 ) -> Entity<NativeAgent> {
268 log::debug!("Creating new NativeAgent");
269
270 cx.new(|cx| {
271 let mut subscriptions = vec![cx.subscribe(
272 &LanguageModelRegistry::global(cx),
273 Self::handle_models_updated_event,
274 )];
275 if let Some(prompt_store) = prompt_store.as_ref() {
276 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
277 }
278
279 Self {
280 sessions: HashMap::default(),
281 thread_store,
282 projects: HashMap::default(),
283 templates,
284 models: LanguageModels::new(cx),
285 prompt_store,
286 fs,
287 _subscriptions: subscriptions,
288 }
289 })
290 }
291
292 fn new_session(
293 &mut self,
294 project: Entity<Project>,
295 cx: &mut Context<Self>,
296 ) -> Entity<AcpThread> {
297 let project_id = self.get_or_create_project_state(&project, cx);
298 let project_state = &self.projects[&project_id];
299
300 let registry = LanguageModelRegistry::read_global(cx);
301 let available_count = registry.available_models(cx).count();
302 log::debug!("Total available models: {}", available_count);
303
304 let default_model = registry.default_model().and_then(|default_model| {
305 self.models
306 .model_from_id(&LanguageModels::model_id(&default_model.model))
307 });
308 let thread = cx.new(|cx| {
309 Thread::new(
310 project,
311 project_state.project_context.clone(),
312 project_state.context_server_registry.clone(),
313 self.templates.clone(),
314 default_model,
315 cx,
316 )
317 });
318
319 self.register_session(thread, project_id, cx)
320 }
321
322 fn register_session(
323 &mut self,
324 thread_handle: Entity<Thread>,
325 project_id: EntityId,
326 cx: &mut Context<Self>,
327 ) -> Entity<AcpThread> {
328 let connection = Rc::new(NativeAgentConnection(cx.entity()));
329
330 let thread = thread_handle.read(cx);
331 let session_id = thread.id().clone();
332 let parent_session_id = thread.parent_thread_id();
333 let title = thread.title();
334 let draft_prompt = thread.draft_prompt().map(Vec::from);
335 let scroll_position = thread.ui_scroll_position();
336 let token_usage = thread.latest_token_usage();
337 let project = thread.project.clone();
338 let action_log = thread.action_log.clone();
339 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
340 let acp_thread = cx.new(|cx| {
341 let mut acp_thread = acp_thread::AcpThread::new(
342 parent_session_id,
343 title,
344 None,
345 connection,
346 project.clone(),
347 action_log.clone(),
348 session_id.clone(),
349 prompt_capabilities_rx,
350 cx,
351 );
352 acp_thread.set_draft_prompt(draft_prompt);
353 acp_thread.set_ui_scroll_position(scroll_position);
354 acp_thread.update_token_usage(token_usage, cx);
355 acp_thread
356 });
357
358 let registry = LanguageModelRegistry::read_global(cx);
359 let summarization_model = registry.thread_summary_model().map(|c| c.model);
360
361 let weak = cx.weak_entity();
362 let weak_thread = thread_handle.downgrade();
363 thread_handle.update(cx, |thread, cx| {
364 thread.set_summarization_model(summarization_model, cx);
365 thread.add_default_tools(
366 Rc::new(NativeThreadEnvironment {
367 acp_thread: acp_thread.downgrade(),
368 thread: weak_thread,
369 agent: weak,
370 }) as _,
371 cx,
372 )
373 });
374
375 let subscriptions = vec![
376 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
377 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
378 cx.observe(&thread_handle, move |this, thread, cx| {
379 this.save_thread(thread, cx)
380 }),
381 ];
382
383 self.sessions.insert(
384 session_id,
385 Session {
386 thread: thread_handle,
387 acp_thread: acp_thread.clone(),
388 project_id,
389 _subscriptions: subscriptions,
390 pending_save: Task::ready(()),
391 },
392 );
393
394 self.update_available_commands_for_project(project_id, cx);
395
396 acp_thread
397 }
398
399 pub fn models(&self) -> &LanguageModels {
400 &self.models
401 }
402
403 fn get_or_create_project_state(
404 &mut self,
405 project: &Entity<Project>,
406 cx: &mut Context<Self>,
407 ) -> EntityId {
408 let project_id = project.entity_id();
409 if self.projects.contains_key(&project_id) {
410 return project_id;
411 }
412
413 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
414 self.register_project_with_initial_context(project.clone(), project_context, cx);
415 if let Some(state) = self.projects.get_mut(&project_id) {
416 state.project_context_needs_refresh.send(()).ok();
417 }
418 project_id
419 }
420
421 fn register_project_with_initial_context(
422 &mut self,
423 project: Entity<Project>,
424 project_context: Entity<ProjectContext>,
425 cx: &mut Context<Self>,
426 ) {
427 let project_id = project.entity_id();
428
429 let context_server_store = project.read(cx).context_server_store();
430 let context_server_registry =
431 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
432
433 let subscriptions = vec![
434 cx.subscribe(&project, Self::handle_project_event),
435 cx.subscribe(
436 &context_server_store,
437 Self::handle_context_server_store_updated,
438 ),
439 cx.subscribe(
440 &context_server_registry,
441 Self::handle_context_server_registry_event,
442 ),
443 ];
444
445 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
446 watch::channel(());
447
448 self.projects.insert(
449 project_id,
450 ProjectState {
451 project,
452 project_context,
453 project_context_needs_refresh: project_context_needs_refresh_tx,
454 _maintain_project_context: cx.spawn(async move |this, cx| {
455 Self::maintain_project_context(
456 this,
457 project_id,
458 project_context_needs_refresh_rx,
459 cx,
460 )
461 .await
462 }),
463 context_server_registry,
464 _subscriptions: subscriptions,
465 },
466 );
467 }
468
469 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
470 self.sessions
471 .get(session_id)
472 .and_then(|session| self.projects.get(&session.project_id))
473 }
474
475 async fn maintain_project_context(
476 this: WeakEntity<Self>,
477 project_id: EntityId,
478 mut needs_refresh: watch::Receiver<()>,
479 cx: &mut AsyncApp,
480 ) -> Result<()> {
481 while needs_refresh.changed().await.is_ok() {
482 let project_context = this
483 .update(cx, |this, cx| {
484 let state = this
485 .projects
486 .get(&project_id)
487 .context("project state not found")?;
488 anyhow::Ok(Self::build_project_context(
489 &state.project,
490 this.prompt_store.as_ref(),
491 cx,
492 ))
493 })??
494 .await;
495 this.update(cx, |this, cx| {
496 if let Some(state) = this.projects.get(&project_id) {
497 state
498 .project_context
499 .update(cx, |current_project_context, _cx| {
500 *current_project_context = project_context;
501 });
502 }
503 })?;
504 }
505
506 Ok(())
507 }
508
509 fn build_project_context(
510 project: &Entity<Project>,
511 prompt_store: Option<&Entity<PromptStore>>,
512 cx: &mut App,
513 ) -> Task<ProjectContext> {
514 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
515 let worktree_tasks = worktrees
516 .into_iter()
517 .map(|worktree| {
518 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
519 })
520 .collect::<Vec<_>>();
521 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
522 prompt_store.read_with(cx, |prompt_store, cx| {
523 let prompts = prompt_store.default_prompt_metadata();
524 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
525 let contents = prompt_store.load(prompt_metadata.id, cx);
526 async move { (contents.await, prompt_metadata) }
527 });
528 cx.background_spawn(future::join_all(load_tasks))
529 })
530 } else {
531 Task::ready(vec![])
532 };
533
534 cx.spawn(async move |_cx| {
535 let (worktrees, default_user_rules) =
536 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
537
538 let worktrees = worktrees
539 .into_iter()
540 .map(|(worktree, _rules_error)| {
541 // TODO: show error message
542 // if let Some(rules_error) = rules_error {
543 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
544 // }
545 worktree
546 })
547 .collect::<Vec<_>>();
548
549 let default_user_rules = default_user_rules
550 .into_iter()
551 .flat_map(|(contents, prompt_metadata)| match contents {
552 Ok(contents) => Some(UserRulesContext {
553 uuid: prompt_metadata.id.as_user()?,
554 title: prompt_metadata.title.map(|title| title.to_string()),
555 contents,
556 }),
557 Err(_err) => {
558 // TODO: show error message
559 // this.update(cx, |_, cx| {
560 // cx.emit(RulesLoadingError {
561 // message: format!("{err:?}").into(),
562 // });
563 // })
564 // .ok();
565 None
566 }
567 })
568 .collect::<Vec<_>>();
569
570 ProjectContext::new(worktrees, default_user_rules)
571 })
572 }
573
574 fn load_worktree_info_for_system_prompt(
575 worktree: Entity<Worktree>,
576 project: Entity<Project>,
577 cx: &mut App,
578 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
579 let tree = worktree.read(cx);
580 let root_name = tree.root_name_str().into();
581 let abs_path = tree.abs_path();
582
583 let mut context = WorktreeContext {
584 root_name,
585 abs_path,
586 rules_file: None,
587 };
588
589 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
590 let Some(rules_task) = rules_task else {
591 return Task::ready((context, None));
592 };
593
594 cx.spawn(async move |_| {
595 let (rules_file, rules_file_error) = match rules_task.await {
596 Ok(rules_file) => (Some(rules_file), None),
597 Err(err) => (
598 None,
599 Some(RulesLoadingError {
600 message: format!("{err}").into(),
601 }),
602 ),
603 };
604 context.rules_file = rules_file;
605 (context, rules_file_error)
606 })
607 }
608
609 fn load_worktree_rules_file(
610 worktree: Entity<Worktree>,
611 project: Entity<Project>,
612 cx: &mut App,
613 ) -> Option<Task<Result<RulesFileContext>>> {
614 let worktree = worktree.read(cx);
615 let worktree_id = worktree.id();
616 let selected_rules_file = RULES_FILE_NAMES
617 .into_iter()
618 .filter_map(|name| {
619 worktree
620 .entry_for_path(RelPath::unix(name).unwrap())
621 .filter(|entry| entry.is_file())
622 .map(|entry| entry.path.clone())
623 })
624 .next();
625
626 // Note that Cline supports `.clinerules` being a directory, but that is not currently
627 // supported. This doesn't seem to occur often in GitHub repositories.
628 selected_rules_file.map(|path_in_worktree| {
629 let project_path = ProjectPath {
630 worktree_id,
631 path: path_in_worktree.clone(),
632 };
633 let buffer_task =
634 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
635 let rope_task = cx.spawn(async move |cx| {
636 let buffer = buffer_task.await?;
637 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
638 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
639 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
640 })?;
641 anyhow::Ok((project_entry_id, rope))
642 });
643 // Build a string from the rope on a background thread.
644 cx.background_spawn(async move {
645 let (project_entry_id, rope) = rope_task.await?;
646 anyhow::Ok(RulesFileContext {
647 path_in_worktree,
648 text: rope.to_string().trim().to_string(),
649 project_entry_id: project_entry_id.to_usize(),
650 })
651 })
652 })
653 }
654
655 fn handle_thread_title_updated(
656 &mut self,
657 thread: Entity<Thread>,
658 _: &TitleUpdated,
659 cx: &mut Context<Self>,
660 ) {
661 let session_id = thread.read(cx).id();
662 let Some(session) = self.sessions.get(session_id) else {
663 return;
664 };
665 let thread = thread.downgrade();
666 let acp_thread = session.acp_thread.downgrade();
667 cx.spawn(async move |_, cx| {
668 let title = thread.read_with(cx, |thread, _| thread.title())?;
669 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
670 task.await
671 })
672 .detach_and_log_err(cx);
673 }
674
675 fn handle_thread_token_usage_updated(
676 &mut self,
677 thread: Entity<Thread>,
678 usage: &TokenUsageUpdated,
679 cx: &mut Context<Self>,
680 ) {
681 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
682 return;
683 };
684 session.acp_thread.update(cx, |acp_thread, cx| {
685 acp_thread.update_token_usage(usage.0.clone(), cx);
686 });
687 }
688
689 fn handle_project_event(
690 &mut self,
691 project: Entity<Project>,
692 event: &project::Event,
693 _cx: &mut Context<Self>,
694 ) {
695 let project_id = project.entity_id();
696 let Some(state) = self.projects.get_mut(&project_id) else {
697 return;
698 };
699 match event {
700 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
701 state.project_context_needs_refresh.send(()).ok();
702 }
703 project::Event::WorktreeUpdatedEntries(_, items) => {
704 if items.iter().any(|(path, _, _)| {
705 RULES_FILE_NAMES
706 .iter()
707 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
708 }) {
709 state.project_context_needs_refresh.send(()).ok();
710 }
711 }
712 _ => {}
713 }
714 }
715
716 fn handle_prompts_updated_event(
717 &mut self,
718 _prompt_store: Entity<PromptStore>,
719 _event: &prompt_store::PromptsUpdatedEvent,
720 _cx: &mut Context<Self>,
721 ) {
722 for state in self.projects.values_mut() {
723 state.project_context_needs_refresh.send(()).ok();
724 }
725 }
726
727 fn handle_models_updated_event(
728 &mut self,
729 _registry: Entity<LanguageModelRegistry>,
730 event: &language_model::Event,
731 cx: &mut Context<Self>,
732 ) {
733 self.models.refresh_list(cx);
734
735 let registry = LanguageModelRegistry::read_global(cx);
736 let default_model = registry.default_model().map(|m| m.model);
737 let summarization_model = registry.thread_summary_model().map(|m| m.model);
738
739 for session in self.sessions.values_mut() {
740 session.thread.update(cx, |thread, cx| {
741 if thread.model().is_none()
742 && let Some(model) = default_model.clone()
743 {
744 thread.set_model(model, cx);
745 cx.notify();
746 }
747 if let Some(model) = summarization_model.clone() {
748 if thread.summarization_model().is_none()
749 || matches!(event, language_model::Event::ThreadSummaryModelChanged)
750 {
751 thread.set_summarization_model(Some(model), cx);
752 }
753 }
754 });
755 }
756 }
757
758 fn handle_context_server_store_updated(
759 &mut self,
760 store: Entity<project::context_server_store::ContextServerStore>,
761 _event: &project::context_server_store::ServerStatusChangedEvent,
762 cx: &mut Context<Self>,
763 ) {
764 let project_id = self.projects.iter().find_map(|(id, state)| {
765 if *state.context_server_registry.read(cx).server_store() == store {
766 Some(*id)
767 } else {
768 None
769 }
770 });
771 if let Some(project_id) = project_id {
772 self.update_available_commands_for_project(project_id, cx);
773 }
774 }
775
776 fn handle_context_server_registry_event(
777 &mut self,
778 registry: Entity<ContextServerRegistry>,
779 event: &ContextServerRegistryEvent,
780 cx: &mut Context<Self>,
781 ) {
782 match event {
783 ContextServerRegistryEvent::ToolsChanged => {}
784 ContextServerRegistryEvent::PromptsChanged => {
785 let project_id = self.projects.iter().find_map(|(id, state)| {
786 if state.context_server_registry == registry {
787 Some(*id)
788 } else {
789 None
790 }
791 });
792 if let Some(project_id) = project_id {
793 self.update_available_commands_for_project(project_id, cx);
794 }
795 }
796 }
797 }
798
799 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
800 let available_commands =
801 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
802 for session in self.sessions.values() {
803 if session.project_id != project_id {
804 continue;
805 }
806 session.acp_thread.update(cx, |thread, cx| {
807 thread
808 .handle_session_update(
809 acp::SessionUpdate::AvailableCommandsUpdate(
810 acp::AvailableCommandsUpdate::new(available_commands.clone()),
811 ),
812 cx,
813 )
814 .log_err();
815 });
816 }
817 }
818
819 fn build_available_commands_for_project(
820 project_state: Option<&ProjectState>,
821 cx: &App,
822 ) -> Vec<acp::AvailableCommand> {
823 let Some(state) = project_state else {
824 return vec![];
825 };
826 let registry = state.context_server_registry.read(cx);
827
828 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
829 for context_server_prompt in registry.prompts() {
830 *prompt_name_counts
831 .entry(context_server_prompt.prompt.name.as_str())
832 .or_insert(0) += 1;
833 }
834
835 registry
836 .prompts()
837 .flat_map(|context_server_prompt| {
838 let prompt = &context_server_prompt.prompt;
839
840 let should_prefix = prompt_name_counts
841 .get(prompt.name.as_str())
842 .copied()
843 .unwrap_or(0)
844 > 1;
845
846 let name = if should_prefix {
847 format!("{}.{}", context_server_prompt.server_id, prompt.name)
848 } else {
849 prompt.name.clone()
850 };
851
852 let mut command = acp::AvailableCommand::new(
853 name,
854 prompt.description.clone().unwrap_or_default(),
855 );
856
857 match prompt.arguments.as_deref() {
858 Some([arg]) => {
859 let hint = format!("<{}>", arg.name);
860
861 command = command.input(acp::AvailableCommandInput::Unstructured(
862 acp::UnstructuredCommandInput::new(hint),
863 ));
864 }
865 Some([]) | None => {}
866 Some(_) => {
867 // skip >1 argument commands since we don't support them yet
868 return None;
869 }
870 }
871
872 Some(command)
873 })
874 .collect()
875 }
876
877 pub fn load_thread(
878 &mut self,
879 id: acp::SessionId,
880 project: Entity<Project>,
881 cx: &mut Context<Self>,
882 ) -> Task<Result<Entity<Thread>>> {
883 let database_future = ThreadsDatabase::connect(cx);
884 cx.spawn(async move |this, cx| {
885 let database = database_future.await.map_err(|err| anyhow!(err))?;
886 let db_thread = database
887 .load_thread(id.clone())
888 .await?
889 .with_context(|| format!("no thread found with ID: {id:?}"))?;
890
891 this.update(cx, |this, cx| {
892 let project_id = this.get_or_create_project_state(&project, cx);
893 let project_state = this
894 .projects
895 .get(&project_id)
896 .context("project state not found")?;
897 let summarization_model = LanguageModelRegistry::read_global(cx)
898 .thread_summary_model()
899 .map(|c| c.model);
900
901 Ok(cx.new(|cx| {
902 let mut thread = Thread::from_db(
903 id.clone(),
904 db_thread,
905 project_state.project.clone(),
906 project_state.project_context.clone(),
907 project_state.context_server_registry.clone(),
908 this.templates.clone(),
909 cx,
910 );
911 thread.set_summarization_model(summarization_model, cx);
912 thread
913 }))
914 })?
915 })
916 }
917
918 pub fn open_thread(
919 &mut self,
920 id: acp::SessionId,
921 project: Entity<Project>,
922 cx: &mut Context<Self>,
923 ) -> Task<Result<Entity<AcpThread>>> {
924 if let Some(session) = self.sessions.get(&id) {
925 return Task::ready(Ok(session.acp_thread.clone()));
926 }
927
928 let task = self.load_thread(id, project.clone(), cx);
929 cx.spawn(async move |this, cx| {
930 let thread = task.await?;
931 let acp_thread = this.update(cx, |this, cx| {
932 let project_id = this.get_or_create_project_state(&project, cx);
933 this.register_session(thread.clone(), project_id, cx)
934 })?;
935 let events = thread.update(cx, |thread, cx| thread.replay(cx));
936 cx.update(|cx| {
937 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
938 })
939 .await?;
940 Ok(acp_thread)
941 })
942 }
943
944 pub fn thread_summary(
945 &mut self,
946 id: acp::SessionId,
947 project: Entity<Project>,
948 cx: &mut Context<Self>,
949 ) -> Task<Result<SharedString>> {
950 let thread = self.open_thread(id.clone(), project, cx);
951 cx.spawn(async move |this, cx| {
952 let acp_thread = thread.await?;
953 let result = this
954 .update(cx, |this, cx| {
955 this.sessions
956 .get(&id)
957 .unwrap()
958 .thread
959 .update(cx, |thread, cx| thread.summary(cx))
960 })?
961 .await
962 .context("Failed to generate summary")?;
963 drop(acp_thread);
964 Ok(result)
965 })
966 }
967
968 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
969 if thread.read(cx).is_empty() {
970 return;
971 }
972
973 let id = thread.read(cx).id().clone();
974 let Some(session) = self.sessions.get_mut(&id) else {
975 return;
976 };
977
978 let project_id = session.project_id;
979 let Some(state) = self.projects.get(&project_id) else {
980 return;
981 };
982
983 let folder_paths = PathList::new(
984 &state
985 .project
986 .read(cx)
987 .visible_worktrees(cx)
988 .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
989 .collect::<Vec<_>>(),
990 );
991
992 let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
993 let database_future = ThreadsDatabase::connect(cx);
994 let db_thread = thread.update(cx, |thread, cx| {
995 thread.set_draft_prompt(draft_prompt);
996 thread.to_db(cx)
997 });
998 let thread_store = self.thread_store.clone();
999 session.pending_save = cx.spawn(async move |_, cx| {
1000 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1001 return;
1002 };
1003 let db_thread = db_thread.await;
1004 database
1005 .save_thread(id, db_thread, folder_paths)
1006 .await
1007 .log_err();
1008 thread_store.update(cx, |store, cx| store.reload(cx));
1009 });
1010 }
1011
1012 fn send_mcp_prompt(
1013 &self,
1014 message_id: UserMessageId,
1015 session_id: acp::SessionId,
1016 prompt_name: String,
1017 server_id: ContextServerId,
1018 arguments: HashMap<String, String>,
1019 original_content: Vec<acp::ContentBlock>,
1020 cx: &mut Context<Self>,
1021 ) -> Task<Result<acp::PromptResponse>> {
1022 let Some(state) = self.session_project_state(&session_id) else {
1023 return Task::ready(Err(anyhow!("Project state not found for session")));
1024 };
1025 let server_store = state
1026 .context_server_registry
1027 .read(cx)
1028 .server_store()
1029 .clone();
1030 let path_style = state.project.read(cx).path_style(cx);
1031
1032 cx.spawn(async move |this, cx| {
1033 let prompt =
1034 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
1035
1036 let (acp_thread, thread) = this.update(cx, |this, _cx| {
1037 let session = this
1038 .sessions
1039 .get(&session_id)
1040 .context("Failed to get session")?;
1041 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
1042 })??;
1043
1044 let mut last_is_user = true;
1045
1046 thread.update(cx, |thread, cx| {
1047 thread.push_acp_user_block(
1048 message_id,
1049 original_content.into_iter().skip(1),
1050 path_style,
1051 cx,
1052 );
1053 });
1054
1055 for message in prompt.messages {
1056 let context_server::types::PromptMessage { role, content } = message;
1057 let block = mcp_message_content_to_acp_content_block(content);
1058
1059 match role {
1060 context_server::types::Role::User => {
1061 let id = acp_thread::UserMessageId::new();
1062
1063 acp_thread.update(cx, |acp_thread, cx| {
1064 acp_thread.push_user_content_block_with_indent(
1065 Some(id.clone()),
1066 block.clone(),
1067 true,
1068 cx,
1069 );
1070 });
1071
1072 thread.update(cx, |thread, cx| {
1073 thread.push_acp_user_block(id, [block], path_style, cx);
1074 });
1075 }
1076 context_server::types::Role::Assistant => {
1077 acp_thread.update(cx, |acp_thread, cx| {
1078 acp_thread.push_assistant_content_block_with_indent(
1079 block.clone(),
1080 false,
1081 true,
1082 cx,
1083 );
1084 });
1085
1086 thread.update(cx, |thread, cx| {
1087 thread.push_acp_agent_block(block, cx);
1088 });
1089 }
1090 }
1091
1092 last_is_user = role == context_server::types::Role::User;
1093 }
1094
1095 let response_stream = thread.update(cx, |thread, cx| {
1096 if last_is_user {
1097 thread.send_existing(cx)
1098 } else {
1099 // Resume if MCP prompt did not end with a user message
1100 thread.resume(cx)
1101 }
1102 })?;
1103
1104 cx.update(|cx| {
1105 NativeAgentConnection::handle_thread_events(
1106 response_stream,
1107 acp_thread.downgrade(),
1108 cx,
1109 )
1110 })
1111 .await
1112 })
1113 }
1114}
1115
/// Cloneable wrapper around a [`NativeAgent`] entity that implements the ACP
/// `AgentConnection` trait (plus its session handles) for the in-process Zed agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1119
1120impl NativeAgentConnection {
1121 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1122 self.0
1123 .read(cx)
1124 .sessions
1125 .get(session_id)
1126 .map(|session| session.thread.clone())
1127 }
1128
1129 pub fn load_thread(
1130 &self,
1131 id: acp::SessionId,
1132 project: Entity<Project>,
1133 cx: &mut App,
1134 ) -> Task<Result<Entity<Thread>>> {
1135 self.0
1136 .update(cx, |this, cx| this.load_thread(id, project, cx))
1137 }
1138
1139 fn run_turn(
1140 &self,
1141 session_id: acp::SessionId,
1142 cx: &mut App,
1143 f: impl 'static
1144 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1145 ) -> Task<Result<acp::PromptResponse>> {
1146 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1147 agent
1148 .sessions
1149 .get_mut(&session_id)
1150 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1151 }) else {
1152 return Task::ready(Err(anyhow!("Session not found")));
1153 };
1154 log::debug!("Found session for: {}", session_id);
1155
1156 let response_stream = match f(thread, cx) {
1157 Ok(stream) => stream,
1158 Err(err) => return Task::ready(Err(err)),
1159 };
1160 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1161 }
1162
1163 fn handle_thread_events(
1164 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1165 acp_thread: WeakEntity<AcpThread>,
1166 cx: &App,
1167 ) -> Task<Result<acp::PromptResponse>> {
1168 cx.spawn(async move |cx| {
1169 // Handle response stream and forward to session.acp_thread
1170 while let Some(result) = events.next().await {
1171 match result {
1172 Ok(event) => {
1173 log::trace!("Received completion event: {:?}", event);
1174
1175 match event {
1176 ThreadEvent::UserMessage(message) => {
1177 acp_thread.update(cx, |thread, cx| {
1178 for content in message.content {
1179 thread.push_user_content_block(
1180 Some(message.id.clone()),
1181 content.into(),
1182 cx,
1183 );
1184 }
1185 })?;
1186 }
1187 ThreadEvent::AgentText(text) => {
1188 acp_thread.update(cx, |thread, cx| {
1189 thread.push_assistant_content_block(text.into(), false, cx)
1190 })?;
1191 }
1192 ThreadEvent::AgentThinking(text) => {
1193 acp_thread.update(cx, |thread, cx| {
1194 thread.push_assistant_content_block(text.into(), true, cx)
1195 })?;
1196 }
1197 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1198 tool_call,
1199 options,
1200 response,
1201 context: _,
1202 }) => {
1203 let outcome_task = acp_thread.update(cx, |thread, cx| {
1204 thread.request_tool_call_authorization(tool_call, options, cx)
1205 })??;
1206 cx.background_spawn(async move {
1207 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1208 outcome_task.await
1209 {
1210 response
1211 .send(outcome)
1212 .map(|_| anyhow!("authorization receiver was dropped"))
1213 .log_err();
1214 }
1215 })
1216 .detach();
1217 }
1218 ThreadEvent::ToolCall(tool_call) => {
1219 acp_thread.update(cx, |thread, cx| {
1220 thread.upsert_tool_call(tool_call, cx)
1221 })??;
1222 }
1223 ThreadEvent::ToolCallUpdate(update) => {
1224 acp_thread.update(cx, |thread, cx| {
1225 thread.update_tool_call(update, cx)
1226 })??;
1227 }
1228 ThreadEvent::Plan(plan) => {
1229 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1230 }
1231 ThreadEvent::SubagentSpawned(session_id) => {
1232 acp_thread.update(cx, |thread, cx| {
1233 thread.subagent_spawned(session_id, cx);
1234 })?;
1235 }
1236 ThreadEvent::Retry(status) => {
1237 acp_thread.update(cx, |thread, cx| {
1238 thread.update_retry_status(status, cx)
1239 })?;
1240 }
1241 ThreadEvent::Stop(stop_reason) => {
1242 log::debug!("Assistant message complete: {:?}", stop_reason);
1243 return Ok(acp::PromptResponse::new(stop_reason));
1244 }
1245 }
1246 }
1247 Err(e) => {
1248 log::error!("Error in model response stream: {:?}", e);
1249 return Err(e);
1250 }
1251 }
1252 }
1253
1254 log::debug!("Response stream completed");
1255 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1256 })
1257 }
1258}
1259
1260struct Command<'a> {
1261 prompt_name: &'a str,
1262 arg_value: &'a str,
1263 explicit_server_id: Option<&'a str>,
1264}
1265
1266impl<'a> Command<'a> {
1267 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1268 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1269 return None;
1270 };
1271 let text = text_content.text.trim();
1272 let command = text.strip_prefix('/')?;
1273 let (command, arg_value) = command
1274 .split_once(char::is_whitespace)
1275 .unwrap_or((command, ""));
1276
1277 if let Some((server_id, prompt_name)) = command.split_once('.') {
1278 Some(Self {
1279 prompt_name,
1280 arg_value,
1281 explicit_server_id: Some(server_id),
1282 })
1283 } else {
1284 Some(Self {
1285 prompt_name: command,
1286 arg_value,
1287 explicit_server_id: None,
1288 })
1289 }
1290 }
1291}
1292
1293struct NativeAgentModelSelector {
1294 session_id: acp::SessionId,
1295 connection: NativeAgentConnection,
1296}
1297
1298impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1299 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1300 log::debug!("NativeAgentConnection::list_models called");
1301 let list = self.connection.0.read(cx).models.model_list.clone();
1302 Task::ready(if list.is_empty() {
1303 Err(anyhow::anyhow!("No models available"))
1304 } else {
1305 Ok(list)
1306 })
1307 }
1308
1309 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1310 log::debug!(
1311 "Setting model for session {}: {}",
1312 self.session_id,
1313 model_id
1314 );
1315 let Some(thread) = self
1316 .connection
1317 .0
1318 .read(cx)
1319 .sessions
1320 .get(&self.session_id)
1321 .map(|session| session.thread.clone())
1322 else {
1323 return Task::ready(Err(anyhow!("Session not found")));
1324 };
1325
1326 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1327 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1328 };
1329
1330 // We want to reset the effort level when switching models, as the currently-selected effort level may
1331 // not be compatible.
1332 let effort = model
1333 .default_effort_level()
1334 .map(|effort_level| effort_level.value.to_string());
1335
1336 thread.update(cx, |thread, cx| {
1337 thread.set_model(model.clone(), cx);
1338 thread.set_thinking_effort(effort.clone(), cx);
1339 thread.set_thinking_enabled(model.supports_thinking(), cx);
1340 });
1341
1342 update_settings_file(
1343 self.connection.0.read(cx).fs.clone(),
1344 cx,
1345 move |settings, cx| {
1346 let provider = model.provider_id().0.to_string();
1347 let model = model.id().0.to_string();
1348 let enable_thinking = thread.read(cx).thinking_enabled();
1349 settings
1350 .agent
1351 .get_or_insert_default()
1352 .set_model(LanguageModelSelection {
1353 provider: provider.into(),
1354 model,
1355 enable_thinking,
1356 effort,
1357 });
1358 },
1359 );
1360
1361 Task::ready(Ok(()))
1362 }
1363
1364 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1365 let Some(thread) = self
1366 .connection
1367 .0
1368 .read(cx)
1369 .sessions
1370 .get(&self.session_id)
1371 .map(|session| session.thread.clone())
1372 else {
1373 return Task::ready(Err(anyhow!("Session not found")));
1374 };
1375 let Some(model) = thread.read(cx).model() else {
1376 return Task::ready(Err(anyhow!("Model not found")));
1377 };
1378 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1379 else {
1380 return Task::ready(Err(anyhow!("Provider not found")));
1381 };
1382 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1383 model, &provider,
1384 )))
1385 }
1386
1387 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1388 Some(self.connection.0.read(cx).models.watch())
1389 }
1390
1391 fn should_render_footer(&self) -> bool {
1392 true
1393 }
1394}
1395
/// Agent id of the built-in Zed agent, returned by `NativeAgentConnection::agent_id`.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1397
1398impl acp_thread::AgentConnection for NativeAgentConnection {
1399 fn agent_id(&self) -> AgentId {
1400 ZED_AGENT_ID.clone()
1401 }
1402
1403 fn telemetry_id(&self) -> SharedString {
1404 "zed".into()
1405 }
1406
1407 fn new_session(
1408 self: Rc<Self>,
1409 project: Entity<Project>,
1410 work_dirs: PathList,
1411 cx: &mut App,
1412 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1413 log::debug!("Creating new thread for project at: {work_dirs:?}");
1414 Task::ready(Ok(self
1415 .0
1416 .update(cx, |agent, cx| agent.new_session(project, cx))))
1417 }
1418
1419 fn supports_load_session(&self) -> bool {
1420 true
1421 }
1422
1423 fn load_session(
1424 self: Rc<Self>,
1425 session_id: acp::SessionId,
1426 project: Entity<Project>,
1427 _work_dirs: PathList,
1428 _title: Option<SharedString>,
1429 cx: &mut App,
1430 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1431 self.0
1432 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1433 }
1434
1435 fn supports_close_session(&self) -> bool {
1436 true
1437 }
1438
1439 fn close_session(
1440 self: Rc<Self>,
1441 session_id: &acp::SessionId,
1442 cx: &mut App,
1443 ) -> Task<Result<()>> {
1444 self.0.update(cx, |agent, cx| {
1445 let Some(session) = agent.sessions.remove(session_id) else {
1446 return;
1447 };
1448 let project_id = session.project_id;
1449 agent.save_thread(session.thread, cx);
1450
1451 let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1452 if !has_remaining {
1453 agent.projects.remove(&project_id);
1454 }
1455 });
1456 Task::ready(Ok(()))
1457 }
1458
1459 fn auth_methods(&self) -> &[acp::AuthMethod] {
1460 &[] // No auth for in-process
1461 }
1462
1463 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1464 Task::ready(Ok(()))
1465 }
1466
1467 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1468 Some(Rc::new(NativeAgentModelSelector {
1469 session_id: session_id.clone(),
1470 connection: self.clone(),
1471 }) as Rc<dyn AgentModelSelector>)
1472 }
1473
1474 fn prompt(
1475 &self,
1476 id: Option<acp_thread::UserMessageId>,
1477 params: acp::PromptRequest,
1478 cx: &mut App,
1479 ) -> Task<Result<acp::PromptResponse>> {
1480 let id = id.expect("UserMessageId is required");
1481 let session_id = params.session_id.clone();
1482 log::info!("Received prompt request for session: {}", session_id);
1483 log::debug!("Prompt blocks count: {}", params.prompt.len());
1484
1485 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1486 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1487 };
1488
1489 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1490 let registry = project_state.context_server_registry.read(cx);
1491
1492 let explicit_server_id = parsed_command
1493 .explicit_server_id
1494 .map(|server_id| ContextServerId(server_id.into()));
1495
1496 if let Some(prompt) =
1497 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1498 {
1499 let arguments = if !parsed_command.arg_value.is_empty()
1500 && let Some(arg_name) = prompt
1501 .prompt
1502 .arguments
1503 .as_ref()
1504 .and_then(|args| args.first())
1505 .map(|arg| arg.name.clone())
1506 {
1507 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1508 } else {
1509 Default::default()
1510 };
1511
1512 let prompt_name = prompt.prompt.name.clone();
1513 let server_id = prompt.server_id.clone();
1514
1515 return self.0.update(cx, |agent, cx| {
1516 agent.send_mcp_prompt(
1517 id,
1518 session_id.clone(),
1519 prompt_name,
1520 server_id,
1521 arguments,
1522 params.prompt,
1523 cx,
1524 )
1525 });
1526 }
1527 };
1528
1529 let path_style = project_state.project.read(cx).path_style(cx);
1530
1531 self.run_turn(session_id, cx, move |thread, cx| {
1532 let content: Vec<UserMessageContent> = params
1533 .prompt
1534 .into_iter()
1535 .map(|block| UserMessageContent::from_content_block(block, path_style))
1536 .collect::<Vec<_>>();
1537 log::debug!("Converted prompt to message: {} chars", content.len());
1538 log::debug!("Message id: {:?}", id);
1539 log::debug!("Message content: {:?}", content);
1540
1541 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1542 })
1543 }
1544
1545 fn retry(
1546 &self,
1547 session_id: &acp::SessionId,
1548 _cx: &App,
1549 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1550 Some(Rc::new(NativeAgentSessionRetry {
1551 connection: self.clone(),
1552 session_id: session_id.clone(),
1553 }) as _)
1554 }
1555
1556 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1557 log::info!("Cancelling on session: {}", session_id);
1558 self.0.update(cx, |agent, cx| {
1559 if let Some(session) = agent.sessions.get(session_id) {
1560 session
1561 .thread
1562 .update(cx, |thread, cx| thread.cancel(cx))
1563 .detach();
1564 }
1565 });
1566 }
1567
1568 fn truncate(
1569 &self,
1570 session_id: &acp::SessionId,
1571 cx: &App,
1572 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1573 self.0.read_with(cx, |agent, _cx| {
1574 agent.sessions.get(session_id).map(|session| {
1575 Rc::new(NativeAgentSessionTruncate {
1576 thread: session.thread.clone(),
1577 acp_thread: session.acp_thread.downgrade(),
1578 }) as _
1579 })
1580 })
1581 }
1582
1583 fn set_title(
1584 &self,
1585 session_id: &acp::SessionId,
1586 cx: &App,
1587 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1588 self.0.read_with(cx, |agent, _cx| {
1589 agent
1590 .sessions
1591 .get(session_id)
1592 .filter(|s| !s.thread.read(cx).is_subagent())
1593 .map(|session| {
1594 Rc::new(NativeAgentSessionSetTitle {
1595 thread: session.thread.clone(),
1596 }) as _
1597 })
1598 })
1599 }
1600
1601 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1602 let thread_store = self.0.read(cx).thread_store.clone();
1603 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1604 }
1605
1606 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1607 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1608 }
1609
1610 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1611 self
1612 }
1613}
1614
1615impl acp_thread::AgentTelemetry for NativeAgentConnection {
1616 fn thread_data(
1617 &self,
1618 session_id: &acp::SessionId,
1619 cx: &mut App,
1620 ) -> Task<Result<serde_json::Value>> {
1621 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1622 return Task::ready(Err(anyhow!("Session not found")));
1623 };
1624
1625 let task = session.thread.read(cx).to_db(cx);
1626 cx.background_spawn(async move {
1627 serde_json::to_value(task.await).context("Failed to serialize thread")
1628 })
1629 }
1630}
1631
1632pub struct NativeAgentSessionList {
1633 thread_store: Entity<ThreadStore>,
1634 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1635 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1636 _subscription: Subscription,
1637}
1638
1639impl NativeAgentSessionList {
1640 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1641 let (tx, rx) = smol::channel::unbounded();
1642 let this_tx = tx.clone();
1643 let subscription = cx.observe(&thread_store, move |_, _| {
1644 this_tx
1645 .try_send(acp_thread::SessionListUpdate::Refresh)
1646 .ok();
1647 });
1648 Self {
1649 thread_store,
1650 updates_tx: tx,
1651 updates_rx: rx,
1652 _subscription: subscription,
1653 }
1654 }
1655
1656 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1657 &self.thread_store
1658 }
1659}
1660
1661impl AgentSessionList for NativeAgentSessionList {
1662 fn list_sessions(
1663 &self,
1664 _request: AgentSessionListRequest,
1665 cx: &mut App,
1666 ) -> Task<Result<AgentSessionListResponse>> {
1667 let sessions = self
1668 .thread_store
1669 .read(cx)
1670 .entries()
1671 .map(|entry| AgentSessionInfo::from(&entry))
1672 .collect();
1673 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1674 }
1675
1676 fn supports_delete(&self) -> bool {
1677 true
1678 }
1679
1680 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1681 self.thread_store
1682 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1683 }
1684
1685 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1686 self.thread_store
1687 .update(cx, |store, cx| store.delete_threads(cx))
1688 }
1689
1690 fn watch(
1691 &self,
1692 _cx: &mut App,
1693 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1694 Some(self.updates_rx.clone())
1695 }
1696
1697 fn notify_refresh(&self) {
1698 self.updates_tx
1699 .try_send(acp_thread::SessionListUpdate::Refresh)
1700 .ok();
1701 }
1702
1703 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1704 self
1705 }
1706}
1707
1708struct NativeAgentSessionTruncate {
1709 thread: Entity<Thread>,
1710 acp_thread: WeakEntity<AcpThread>,
1711}
1712
1713impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1714 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1715 match self.thread.update(cx, |thread, cx| {
1716 thread.truncate(message_id.clone(), cx)?;
1717 Ok(thread.latest_token_usage())
1718 }) {
1719 Ok(usage) => {
1720 self.acp_thread
1721 .update(cx, |thread, cx| {
1722 thread.update_token_usage(usage, cx);
1723 })
1724 .ok();
1725 Task::ready(Ok(()))
1726 }
1727 Err(error) => Task::ready(Err(error)),
1728 }
1729 }
1730}
1731
1732struct NativeAgentSessionRetry {
1733 connection: NativeAgentConnection,
1734 session_id: acp::SessionId,
1735}
1736
1737impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1738 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1739 self.connection
1740 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1741 thread.update(cx, |thread, cx| thread.resume(cx))
1742 })
1743 }
1744}
1745
1746struct NativeAgentSessionSetTitle {
1747 thread: Entity<Thread>,
1748}
1749
1750impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1751 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1752 self.thread
1753 .update(cx, |thread, cx| thread.set_title(title, cx));
1754 Task::ready(Ok(()))
1755 }
1756}
1757
1758pub struct NativeThreadEnvironment {
1759 agent: WeakEntity<NativeAgent>,
1760 thread: WeakEntity<Thread>,
1761 acp_thread: WeakEntity<AcpThread>,
1762}
1763
1764impl NativeThreadEnvironment {
1765 pub(crate) fn create_subagent_thread(
1766 &self,
1767 label: String,
1768 cx: &mut App,
1769 ) -> Result<Rc<dyn SubagentHandle>> {
1770 let Some(parent_thread_entity) = self.thread.upgrade() else {
1771 anyhow::bail!("Parent thread no longer exists".to_string());
1772 };
1773 let parent_thread = parent_thread_entity.read(cx);
1774 let current_depth = parent_thread.depth();
1775 let parent_session_id = parent_thread.id().clone();
1776
1777 if current_depth >= MAX_SUBAGENT_DEPTH {
1778 return Err(anyhow!(
1779 "Maximum subagent depth ({}) reached",
1780 MAX_SUBAGENT_DEPTH
1781 ));
1782 }
1783
1784 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1785 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1786 thread.set_title(label.into(), cx);
1787 thread
1788 });
1789
1790 let session_id = subagent_thread.read(cx).id().clone();
1791
1792 let acp_thread = self
1793 .agent
1794 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1795 let project_id = agent
1796 .sessions
1797 .get(&parent_session_id)
1798 .map(|s| s.project_id)
1799 .context("parent session not found")?;
1800 Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1801 })??;
1802
1803 let depth = current_depth + 1;
1804
1805 telemetry::event!(
1806 "Subagent Started",
1807 session = parent_thread_entity.read(cx).id().to_string(),
1808 subagent_session = session_id.to_string(),
1809 depth,
1810 is_resumed = false,
1811 );
1812
1813 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1814 }
1815
1816 pub(crate) fn resume_subagent_thread(
1817 &self,
1818 session_id: acp::SessionId,
1819 cx: &mut App,
1820 ) -> Result<Rc<dyn SubagentHandle>> {
1821 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1822 let session = agent
1823 .sessions
1824 .get(&session_id)
1825 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1826 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1827 })??;
1828
1829 let depth = subagent_thread.read(cx).depth();
1830
1831 if let Some(parent_thread_entity) = self.thread.upgrade() {
1832 telemetry::event!(
1833 "Subagent Started",
1834 session = parent_thread_entity.read(cx).id().to_string(),
1835 subagent_session = session_id.to_string(),
1836 depth,
1837 is_resumed = true,
1838 );
1839 }
1840
1841 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1842 }
1843
1844 fn prompt_subagent(
1845 &self,
1846 session_id: acp::SessionId,
1847 subagent_thread: Entity<Thread>,
1848 acp_thread: Entity<acp_thread::AcpThread>,
1849 ) -> Result<Rc<dyn SubagentHandle>> {
1850 let Some(parent_thread_entity) = self.thread.upgrade() else {
1851 anyhow::bail!("Parent thread no longer exists".to_string());
1852 };
1853 Ok(Rc::new(NativeSubagentHandle::new(
1854 session_id,
1855 subagent_thread,
1856 acp_thread,
1857 parent_thread_entity,
1858 )) as _)
1859 }
1860}
1861
1862impl ThreadEnvironment for NativeThreadEnvironment {
1863 fn create_terminal(
1864 &self,
1865 command: String,
1866 cwd: Option<PathBuf>,
1867 output_byte_limit: Option<u64>,
1868 cx: &mut AsyncApp,
1869 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1870 let task = self.acp_thread.update(cx, |thread, cx| {
1871 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1872 });
1873
1874 let acp_thread = self.acp_thread.clone();
1875 cx.spawn(async move |cx| {
1876 let terminal = task?.await?;
1877
1878 let (drop_tx, drop_rx) = oneshot::channel();
1879 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1880
1881 cx.spawn(async move |cx| {
1882 drop_rx.await.ok();
1883 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1884 })
1885 .detach();
1886
1887 let handle = AcpTerminalHandle {
1888 terminal,
1889 _drop_tx: Some(drop_tx),
1890 };
1891
1892 Ok(Rc::new(handle) as _)
1893 })
1894 }
1895
1896 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1897 self.create_subagent_thread(label, cx)
1898 }
1899
1900 fn resume_subagent(
1901 &self,
1902 session_id: acp::SessionId,
1903 cx: &mut App,
1904 ) -> Result<Rc<dyn SubagentHandle>> {
1905 self.resume_subagent_thread(session_id, cx)
1906 }
1907}
1908
/// How a subagent turn started by `NativeSubagentHandle::send` ended.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn stopped with `EndTurn`, or with any stop reason not mapped to an error.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage moved from `Normal` to `Warning` while the turn was running.
    ContextWindowWarning,
    /// The turn failed; carries a message describing why.
    Error(String),
}
1916
1917pub struct NativeSubagentHandle {
1918 session_id: acp::SessionId,
1919 parent_thread: WeakEntity<Thread>,
1920 subagent_thread: Entity<Thread>,
1921 acp_thread: Entity<acp_thread::AcpThread>,
1922}
1923
1924impl NativeSubagentHandle {
1925 fn new(
1926 session_id: acp::SessionId,
1927 subagent_thread: Entity<Thread>,
1928 acp_thread: Entity<acp_thread::AcpThread>,
1929 parent_thread_entity: Entity<Thread>,
1930 ) -> Self {
1931 NativeSubagentHandle {
1932 session_id,
1933 subagent_thread,
1934 parent_thread: parent_thread_entity.downgrade(),
1935 acp_thread,
1936 }
1937 }
1938}
1939
1940impl SubagentHandle for NativeSubagentHandle {
1941 fn id(&self) -> acp::SessionId {
1942 self.session_id.clone()
1943 }
1944
1945 fn num_entries(&self, cx: &App) -> usize {
1946 self.acp_thread.read(cx).entries().len()
1947 }
1948
1949 fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1950 let thread = self.subagent_thread.clone();
1951 let acp_thread = self.acp_thread.clone();
1952 let subagent_session_id = self.session_id.clone();
1953 let parent_thread = self.parent_thread.clone();
1954
1955 cx.spawn(async move |cx| {
1956 let (task, _subscription) = cx.update(|cx| {
1957 let ratio_before_prompt = thread
1958 .read(cx)
1959 .latest_token_usage()
1960 .map(|usage| usage.ratio());
1961
1962 parent_thread
1963 .update(cx, |parent_thread, _cx| {
1964 parent_thread.register_running_subagent(thread.downgrade())
1965 })
1966 .ok();
1967
1968 let task = acp_thread.update(cx, |acp_thread, cx| {
1969 acp_thread.send(vec![message.into()], cx)
1970 });
1971
1972 let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1973 let mut token_limit_tx = Some(token_limit_tx);
1974
1975 let subscription = cx.subscribe(
1976 &thread,
1977 move |_thread, event: &TokenUsageUpdated, _cx| {
1978 if let Some(usage) = &event.0 {
1979 let old_ratio = ratio_before_prompt
1980 .clone()
1981 .unwrap_or(TokenUsageRatio::Normal);
1982 let new_ratio = usage.ratio();
1983 if old_ratio == TokenUsageRatio::Normal
1984 && new_ratio == TokenUsageRatio::Warning
1985 {
1986 if let Some(tx) = token_limit_tx.take() {
1987 tx.send(()).ok();
1988 }
1989 }
1990 }
1991 },
1992 );
1993
1994 let wait_for_prompt = cx
1995 .background_spawn(async move {
1996 futures::select! {
1997 response = task.fuse() => match response {
1998 Ok(Some(response)) => {
1999 match response.stop_reason {
2000 acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
2001 acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
2002 acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
2003 acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
2004 acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
2005 }
2006 }
2007 Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
2008 Err(error) => SubagentPromptResult::Error(error.to_string()),
2009 },
2010 _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
2011 }
2012 });
2013
2014 (wait_for_prompt, subscription)
2015 });
2016
2017 let result = match task.await {
2018 SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2019 thread
2020 .last_message()
2021 .and_then(|message| {
2022 let content = message.as_agent_message()?
2023 .content
2024 .iter()
2025 .filter_map(|c| match c {
2026 AgentMessageContent::Text(text) => Some(text.as_str()),
2027 _ => None,
2028 })
2029 .join("\n\n");
2030 if content.is_empty() {
2031 None
2032 } else {
2033 Some( content)
2034 }
2035 })
2036 .context("No response from subagent")
2037 }),
2038 SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2039 SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2040 SubagentPromptResult::ContextWindowWarning => {
2041 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2042 Err(anyhow!(
2043 "The agent is nearing the end of its context window and has been \
2044 stopped. You can prompt the thread again to have the agent wrap up \
2045 or hand off its work."
2046 ))
2047 }
2048 };
2049
2050 parent_thread
2051 .update(cx, |parent_thread, cx| {
2052 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2053 })
2054 .ok();
2055
2056 result
2057 })
2058 }
2059}
2060
2061pub struct AcpTerminalHandle {
2062 terminal: Entity<acp_thread::Terminal>,
2063 _drop_tx: Option<oneshot::Sender<()>>,
2064}
2065
2066impl TerminalHandle for AcpTerminalHandle {
2067 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2068 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2069 }
2070
2071 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2072 Ok(self
2073 .terminal
2074 .read_with(cx, |term, _cx| term.wait_for_exit()))
2075 }
2076
2077 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2078 Ok(self
2079 .terminal
2080 .read_with(cx, |term, cx| term.current_output(cx)))
2081 }
2082
2083 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2084 cx.update(|cx| {
2085 self.terminal.update(cx, |terminal, cx| {
2086 terminal.kill(cx);
2087 });
2088 });
2089 Ok(())
2090 }
2091
2092 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2093 Ok(self
2094 .terminal
2095 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2096 }
2097}
2098
2099#[cfg(test)]
2100mod internal_tests {
2101 use std::path::Path;
2102
2103 use super::*;
2104 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2105 use fs::FakeFs;
2106 use gpui::TestAppContext;
2107 use indoc::formatdoc;
2108 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2109 use language_model::{
2110 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2111 };
2112 use serde_json::json;
2113 use settings::SettingsStore;
2114 use util::{path, rel_path::rel_path};
2115
2116 #[gpui::test]
2117 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
2118 init_test(cx);
2119 let fs = FakeFs::new(cx.executor());
2120 fs.insert_tree(
2121 "/",
2122 json!({
2123 "a": {}
2124 }),
2125 )
2126 .await;
2127 let project = Project::test(fs.clone(), [], cx).await;
2128 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2129 let agent =
2130 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2131
2132 // Creating a session registers the project and triggers context building.
2133 let connection = NativeAgentConnection(agent.clone());
2134 let _acp_thread = cx
2135 .update(|cx| {
2136 Rc::new(connection).new_session(
2137 project.clone(),
2138 PathList::new(&[Path::new("/")]),
2139 cx,
2140 )
2141 })
2142 .await
2143 .unwrap();
2144 cx.run_until_parked();
2145
2146 let thread = agent.read_with(cx, |agent, _cx| {
2147 agent.sessions.values().next().unwrap().thread.clone()
2148 });
2149
2150 agent.read_with(cx, |agent, cx| {
2151 let project_id = project.entity_id();
2152 let state = agent.projects.get(&project_id).unwrap();
2153 assert_eq!(state.project_context.read(cx).worktrees, vec![]);
2154 assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
2155 });
2156
2157 let worktree = project
2158 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
2159 .await
2160 .unwrap();
2161 cx.run_until_parked();
2162 agent.read_with(cx, |agent, cx| {
2163 let project_id = project.entity_id();
2164 let state = agent.projects.get(&project_id).unwrap();
2165 let expected_worktrees = vec![WorktreeContext {
2166 root_name: "a".into(),
2167 abs_path: Path::new("/a").into(),
2168 rules_file: None,
2169 }];
2170 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2171 assert_eq!(
2172 thread.read(cx).project_context().read(cx).worktrees,
2173 expected_worktrees
2174 );
2175 });
2176
2177 // Creating `/a/.rules` updates the project context.
2178 fs.insert_file("/a/.rules", Vec::new()).await;
2179 cx.run_until_parked();
2180 agent.read_with(cx, |agent, cx| {
2181 let project_id = project.entity_id();
2182 let state = agent.projects.get(&project_id).unwrap();
2183 let rules_entry = worktree
2184 .read(cx)
2185 .entry_for_path(rel_path(".rules"))
2186 .unwrap();
2187 let expected_worktrees = vec![WorktreeContext {
2188 root_name: "a".into(),
2189 abs_path: Path::new("/a").into(),
2190 rules_file: Some(RulesFileContext {
2191 path_in_worktree: rel_path(".rules").into(),
2192 text: "".into(),
2193 project_entry_id: rules_entry.id.to_usize(),
2194 }),
2195 }];
2196 assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2197 assert_eq!(
2198 thread.read(cx).project_context().read(cx).worktrees,
2199 expected_worktrees
2200 );
2201 });
2202 }
2203
2204 #[gpui::test]
2205 async fn test_listing_models(cx: &mut TestAppContext) {
2206 init_test(cx);
2207 let fs = FakeFs::new(cx.executor());
2208 fs.insert_tree("/", json!({ "a": {} })).await;
2209 let project = Project::test(fs.clone(), [], cx).await;
2210 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2211 let connection =
2212 NativeAgentConnection(cx.update(|cx| {
2213 NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2214 }));
2215
2216 // Create a thread/session
2217 let acp_thread = cx
2218 .update(|cx| {
2219 Rc::new(connection.clone()).new_session(
2220 project.clone(),
2221 PathList::new(&[Path::new("/a")]),
2222 cx,
2223 )
2224 })
2225 .await
2226 .unwrap();
2227
2228 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2229
2230 let models = cx
2231 .update(|cx| {
2232 connection
2233 .model_selector(&session_id)
2234 .unwrap()
2235 .list_models(cx)
2236 })
2237 .await
2238 .unwrap();
2239
2240 let acp_thread::AgentModelList::Grouped(models) = models else {
2241 panic!("Unexpected model group");
2242 };
2243 assert_eq!(
2244 models,
2245 IndexMap::from_iter([(
2246 AgentModelGroupName("Fake".into()),
2247 vec![AgentModelInfo {
2248 id: acp::ModelId::new("fake/fake"),
2249 name: "Fake".into(),
2250 description: None,
2251 icon: Some(acp_thread::AgentModelIcon::Named(
2252 ui::IconName::ZedAssistant
2253 )),
2254 is_latest: false,
2255 cost: None,
2256 }]
2257 )])
2258 );
2259 }
2260
2261 #[gpui::test]
2262 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2263 init_test(cx);
2264 let fs = FakeFs::new(cx.executor());
2265 fs.create_dir(paths::settings_file().parent().unwrap())
2266 .await
2267 .unwrap();
2268 fs.insert_file(
2269 paths::settings_file(),
2270 json!({
2271 "agent": {
2272 "default_model": {
2273 "provider": "foo",
2274 "model": "bar"
2275 }
2276 }
2277 })
2278 .to_string()
2279 .into_bytes(),
2280 )
2281 .await;
2282 let project = Project::test(fs.clone(), [], cx).await;
2283
2284 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2285
2286 // Create the agent and connection
2287 let agent =
2288 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2289 let connection = NativeAgentConnection(agent.clone());
2290
2291 // Create a thread/session
2292 let acp_thread = cx
2293 .update(|cx| {
2294 Rc::new(connection.clone()).new_session(
2295 project.clone(),
2296 PathList::new(&[Path::new("/a")]),
2297 cx,
2298 )
2299 })
2300 .await
2301 .unwrap();
2302
2303 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2304
2305 // Select a model
2306 let selector = connection.model_selector(&session_id).unwrap();
2307 let model_id = acp::ModelId::new("fake/fake");
2308 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2309 .await
2310 .unwrap();
2311
2312 // Verify the thread has the selected model
2313 agent.read_with(cx, |agent, _| {
2314 let session = agent.sessions.get(&session_id).unwrap();
2315 session.thread.read_with(cx, |thread, _| {
2316 assert_eq!(thread.model().unwrap().id().0, "fake");
2317 });
2318 });
2319
2320 cx.run_until_parked();
2321
2322 // Verify settings file was updated
2323 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2324 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2325
2326 // Check that the agent settings contain the selected model
2327 assert_eq!(
2328 settings_json["agent"]["default_model"]["model"],
2329 json!("fake")
2330 );
2331 assert_eq!(
2332 settings_json["agent"]["default_model"]["provider"],
2333 json!("fake")
2334 );
2335
2336 // Register a thinking model and select it.
2337 cx.update(|cx| {
2338 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2339 "fake-corp",
2340 "fake-thinking",
2341 "Fake Thinking",
2342 true,
2343 ));
2344 let thinking_provider = Arc::new(
2345 FakeLanguageModelProvider::new(
2346 LanguageModelProviderId::from("fake-corp".to_string()),
2347 LanguageModelProviderName::from("Fake Corp".to_string()),
2348 )
2349 .with_models(vec![thinking_model]),
2350 );
2351 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2352 registry.register_provider(thinking_provider, cx);
2353 });
2354 });
2355 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2356
2357 let selector = connection.model_selector(&session_id).unwrap();
2358 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2359 .await
2360 .unwrap();
2361 cx.run_until_parked();
2362
2363 // Verify enable_thinking was written to settings as true.
2364 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2365 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2366 assert_eq!(
2367 settings_json["agent"]["default_model"]["enable_thinking"],
2368 json!(true),
2369 "selecting a thinking model should persist enable_thinking: true to settings"
2370 );
2371 }
2372
2373 #[gpui::test]
2374 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2375 init_test(cx);
2376 let fs = FakeFs::new(cx.executor());
2377 fs.create_dir(paths::settings_file().parent().unwrap())
2378 .await
2379 .unwrap();
2380 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2381 let project = Project::test(fs.clone(), [], cx).await;
2382
2383 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2384 let agent =
2385 cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2386 let connection = NativeAgentConnection(agent.clone());
2387
2388 let acp_thread = cx
2389 .update(|cx| {
2390 Rc::new(connection.clone()).new_session(
2391 project.clone(),
2392 PathList::new(&[Path::new("/a")]),
2393 cx,
2394 )
2395 })
2396 .await
2397 .unwrap();
2398 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2399
2400 // Register a second provider with a thinking model.
2401 cx.update(|cx| {
2402 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2403 "fake-corp",
2404 "fake-thinking",
2405 "Fake Thinking",
2406 true,
2407 ));
2408 let thinking_provider = Arc::new(
2409 FakeLanguageModelProvider::new(
2410 LanguageModelProviderId::from("fake-corp".to_string()),
2411 LanguageModelProviderName::from("Fake Corp".to_string()),
2412 )
2413 .with_models(vec![thinking_model]),
2414 );
2415 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2416 registry.register_provider(thinking_provider, cx);
2417 });
2418 });
2419 // Refresh the agent's model list so it picks up the new provider.
2420 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2421
2422 // Thread starts with thinking_enabled = false (the default).
2423 agent.read_with(cx, |agent, _| {
2424 let session = agent.sessions.get(&session_id).unwrap();
2425 session.thread.read_with(cx, |thread, _| {
2426 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2427 });
2428 });
2429
2430 // Select the thinking model via select_model.
2431 let selector = connection.model_selector(&session_id).unwrap();
2432 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2433 .await
2434 .unwrap();
2435
2436 // select_model should have enabled thinking based on the model's supports_thinking().
2437 agent.read_with(cx, |agent, _| {
2438 let session = agent.sessions.get(&session_id).unwrap();
2439 session.thread.read_with(cx, |thread, _| {
2440 assert!(
2441 thread.thinking_enabled(),
2442 "select_model should enable thinking when model supports it"
2443 );
2444 });
2445 });
2446
2447 // Switch back to the non-thinking model.
2448 let selector = connection.model_selector(&session_id).unwrap();
2449 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2450 .await
2451 .unwrap();
2452
2453 // select_model should have disabled thinking.
2454 agent.read_with(cx, |agent, _| {
2455 let session = agent.sessions.get(&session_id).unwrap();
2456 session.thread.read_with(cx, |thread, _| {
2457 assert!(
2458 !thread.thinking_enabled(),
2459 "select_model should disable thinking when model does not support it"
2460 );
2461 });
2462 });
2463 }
2464
    /// Regression test: a session's summarization model must not be lost when
    /// the language-model registry's default model is cleared (as happens
    /// during a transient provider outage).
    #[gpui::test]
    async fn test_summarization_model_survives_transient_registry_clearing(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Grab the native thread backing the new session.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Precondition: the new session already has a summarization model.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.summarization_model().is_some(),
                "session should have a summarization model from the test registry"
            );
        });

        // Simulate what happens during a provider blip:
        // update_active_language_model_from_settings calls set_default_model(None)
        // when it can't resolve the model, clearing all fallbacks.
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(None, cx);
            });
        });
        // Let any observers of the registry change run before checking.
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.summarization_model().is_some(),
                "summarization model should survive a transient default model clearing"
            );
        });
    }
2519
    /// Regression test: after selecting a thinking model and persisting the
    /// thread, closing the session and reopening it with `open_thread` must
    /// still report `thinking_enabled() == true`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Make the agent's model list pick up the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the fake model's pending completion so the send can resolve.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the remaining handles; the agent must now hold no sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // Explicitly release the reloaded ACP thread handle at the end of the test.
        drop(reloaded_acp_thread);
    }
2622
    /// Regression test: a persisted thread whose model has an id that differs
    /// from its display name must come back with that same model after being
    /// closed and reopened with `open_thread`, rather than falling back to the
    /// default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Make the agent's model list pick up the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector id is "<provider id>/<model id>", not the display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the fake model's pending completion so the send can resolve.
        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the remaining handles; the agent must now hold no sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Explicitly release the reloaded ACP thread handle at the end of the test.
        drop(reloaded_acp_thread);
    }
2730
    /// End-to-end persistence test.
    ///
    /// An empty thread is never written to the thread store, even after its
    /// models are changed. After one message exchange, the following all
    /// survive closing the session and reopening it with `open_thread`:
    /// - the transcript;
    /// - the summary-derived title;
    /// - the draft prompt;
    /// - token usage;
    /// - the UI scroll position.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // NOTE(review): unlike the sibling tests, this session is created with an
        // empty path rather than `/a`; confirm that is intentional.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a message that mixes plain text with a mention of `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the main model's response (including a token-usage update), then
        // have the summarization model emit the text later checked as the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        // Also record a UI scroll position so its round-trip can be checked.
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // NOTE(review): presumably this notify is what prompts the thread to be
        // saved again with the draft/scroll state set above — confirm.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2912
2913 fn thread_entries(
2914 thread_store: &Entity<ThreadStore>,
2915 cx: &mut TestAppContext,
2916 ) -> Vec<(acp::SessionId, String)> {
2917 thread_store.read_with(cx, |store, _| {
2918 store
2919 .entries()
2920 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2921 .collect::<Vec<_>>()
2922 })
2923 }
2924
2925 fn init_test(cx: &mut TestAppContext) {
2926 env_logger::try_init().ok();
2927 cx.update(|cx| {
2928 let settings_store = SettingsStore::test(cx);
2929 cx.set_global(settings_store);
2930
2931 LanguageModelRegistry::test(cx);
2932 });
2933 }
2934}
2935
2936fn mcp_message_content_to_acp_content_block(
2937 content: context_server::types::MessageContent,
2938) -> acp::ContentBlock {
2939 match content {
2940 context_server::types::MessageContent::Text {
2941 text,
2942 annotations: _,
2943 } => text.into(),
2944 context_server::types::MessageContent::Image {
2945 data,
2946 mime_type,
2947 annotations: _,
2948 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2949 context_server::types::MessageContent::Audio {
2950 data,
2951 mime_type,
2952 annotations: _,
2953 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2954 context_server::types::MessageContent::Resource {
2955 resource,
2956 annotations: _,
2957 } => {
2958 let mut link =
2959 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2960 if let Some(mime_type) = resource.mime_type {
2961 link = link.mime_type(mime_type);
2962 }
2963 acp::ContentBlock::ResourceLink(link)
2964 }
2965 }
2966}