1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, ThreadId, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
41 WeakEntity,
42};
43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
45use prompt_store::{
46 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
47 WorktreeContext,
48};
49use serde::{Deserialize, Serialize};
50use settings::{LanguageModelSelection, update_settings_file};
51use std::any::Any;
52use std::path::PathBuf;
53use std::rc::Rc;
54use std::sync::{Arc, LazyLock};
55use util::ResultExt;
56use util::path_list::PathList;
57use util::rel_path::RelPath;
58
/// A serializable record of a project's worktree telemetry snapshots together
/// with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp attached to this snapshot (presumably the capture time —
    /// confirm at the construction site).
    pub timestamp: DateTime<Utc>,
}
64
/// Error describing why a worktree's rules file could not be loaded while
/// building the system-prompt project context.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
68
/// Per-project state shared by every session that runs against the same
/// [`Project`].
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Worktree and user-rules context used when building prompts; rebuilt in
    /// place by the `_maintain_project_context` task.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `_maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds `project_context` whenever a refresh is
    /// requested.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers (tools and prompts) for this project.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store and registry events.
    _subscriptions: Vec<Subscription>,
}
77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key of this session's entry in `NativeAgent::projects`.
    project_id: EntityId,
    /// Most recent save task; replacing it drops (and so cancels) any
    /// previous in-flight save.
    pending_save: Task<Result<()>>,
    /// Subscriptions to the thread's title, token-usage and change events.
    _subscriptions: Vec<Subscription>,
}
88
/// Cache of the language models offered by authenticated providers, in the
/// shape the ACP model selector expects.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out by `watch()`; it is notified after every refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used by `refresh_list` to signal that the cache changed.
    refresh_models_tx: watch::Sender<()>,
    /// Keeps the background provider-authentication task alive.
    _authenticate_all_providers_task: Task<()>,
}
98
99impl LanguageModels {
100 fn new(cx: &mut App) -> Self {
101 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
102
103 let mut this = Self {
104 models: HashMap::default(),
105 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
106 refresh_models_rx,
107 refresh_models_tx,
108 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
109 };
110 this.refresh_list(cx);
111 this
112 }
113
114 fn refresh_list(&mut self, cx: &App) {
115 let providers = LanguageModelRegistry::global(cx)
116 .read(cx)
117 .visible_providers()
118 .into_iter()
119 .filter(|provider| provider.is_authenticated(cx))
120 .collect::<Vec<_>>();
121
122 let mut language_model_list = IndexMap::default();
123 let mut recommended_models = HashSet::default();
124
125 let mut recommended = Vec::new();
126 for provider in &providers {
127 for model in provider.recommended_models(cx) {
128 recommended_models.insert((model.provider_id(), model.id()));
129 recommended.push(Self::map_language_model_to_info(&model, provider));
130 }
131 }
132 if !recommended.is_empty() {
133 language_model_list.insert(
134 acp_thread::AgentModelGroupName("Recommended".into()),
135 recommended,
136 );
137 }
138
139 let mut models = HashMap::default();
140 for provider in providers {
141 let mut provider_models = Vec::new();
142 for model in provider.provided_models(cx) {
143 let model_info = Self::map_language_model_to_info(&model, &provider);
144 let model_id = model_info.id.clone();
145 provider_models.push(model_info);
146 models.insert(model_id, model);
147 }
148 if !provider_models.is_empty() {
149 language_model_list.insert(
150 acp_thread::AgentModelGroupName(provider.name().0.clone()),
151 provider_models,
152 );
153 }
154 }
155
156 self.models = models;
157 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
158 self.refresh_models_tx.send(()).ok();
159 }
160
161 fn watch(&self) -> watch::Receiver<()> {
162 self.refresh_models_rx.clone()
163 }
164
165 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
166 self.models.get(model_id).cloned()
167 }
168
169 fn map_language_model_to_info(
170 model: &Arc<dyn LanguageModel>,
171 provider: &Arc<dyn LanguageModelProvider>,
172 ) -> acp_thread::AgentModelInfo {
173 acp_thread::AgentModelInfo {
174 id: Self::model_id(model),
175 name: model.name().0,
176 description: None,
177 icon: Some(match provider.icon() {
178 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
179 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
180 }),
181 is_latest: model.is_latest(),
182 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
183 }
184 }
185
186 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
187 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
188 }
189
190 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
191 let authenticate_all_providers = LanguageModelRegistry::global(cx)
192 .read(cx)
193 .visible_providers()
194 .iter()
195 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
196 .collect::<Vec<_>>();
197
198 cx.background_spawn(async move {
199 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
200 if let Err(err) = authenticate_task.await {
201 match err {
202 language_model::AuthenticateError::CredentialsNotFound => {
203 // Since we're authenticating these providers in the
204 // background for the purposes of populating the
205 // language selector, we don't care about providers
206 // where the credentials are not found.
207 }
208 language_model::AuthenticateError::ConnectionRefused => {
209 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
210 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
211 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
212 }
213 _ => {
214 // Some providers have noisy failure states that we
215 // don't want to spam the logs with every time the
216 // language model selector is initialized.
217 //
218 // Ideally these should have more clear failure modes
219 // that we know are safe to ignore here, like what we do
220 // with `CredentialsNotFound` above.
221 match provider_id.0.as_ref() {
222 "lmstudio" | "ollama" => {
223 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
224 //
225 // These fail noisily, so we don't log them.
226 }
227 "copilot_chat" => {
228 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
229 }
230 _ => {
231 log::error!(
232 "Failed to authenticate provider: {}: {err:#}",
233 provider_name.0
234 );
235 }
236 }
237 }
238 }
239 }
240 }
241 })
242 }
243}
244
/// The built-in agent: owns every active session, the per-project state those
/// sessions share, and the cached list of available language models.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store that is reloaded after each thread save.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional source of default user rules included in project context.
    prompt_store: Option<Entity<PromptStore>>,
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language model registry and (optionally) the
    /// prompt store.
    _subscriptions: Vec<Subscription>,
}
259
260impl NativeAgent {
261 pub fn new(
262 thread_store: Entity<ThreadStore>,
263 templates: Arc<Templates>,
264 prompt_store: Option<Entity<PromptStore>>,
265 fs: Arc<dyn Fs>,
266 cx: &mut App,
267 ) -> Entity<NativeAgent> {
268 log::debug!("Creating new NativeAgent");
269
270 cx.new(|cx| {
271 let mut subscriptions = vec![cx.subscribe(
272 &LanguageModelRegistry::global(cx),
273 Self::handle_models_updated_event,
274 )];
275 if let Some(prompt_store) = prompt_store.as_ref() {
276 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
277 }
278
279 Self {
280 sessions: HashMap::default(),
281 thread_store,
282 projects: HashMap::default(),
283 templates,
284 models: LanguageModels::new(cx),
285 prompt_store,
286 fs,
287 _subscriptions: subscriptions,
288 }
289 })
290 }
291
292 fn new_session(
293 &mut self,
294 project: Entity<Project>,
295 cx: &mut Context<Self>,
296 ) -> Entity<AcpThread> {
297 let project_id = self.get_or_create_project_state(&project, cx);
298 let project_state = &self.projects[&project_id];
299
300 let registry = LanguageModelRegistry::read_global(cx);
301 let available_count = registry.available_models(cx).count();
302 log::debug!("Total available models: {}", available_count);
303
304 let default_model = registry.default_model().and_then(|default_model| {
305 self.models
306 .model_from_id(&LanguageModels::model_id(&default_model.model))
307 });
308 let thread = cx.new(|cx| {
309 Thread::new(
310 project,
311 project_state.project_context.clone(),
312 project_state.context_server_registry.clone(),
313 self.templates.clone(),
314 default_model,
315 cx,
316 )
317 });
318
319 self.register_session(thread, project_id, cx)
320 }
321
322 fn register_session(
323 &mut self,
324 thread_handle: Entity<Thread>,
325 project_id: EntityId,
326 cx: &mut Context<Self>,
327 ) -> Entity<AcpThread> {
328 let connection = Rc::new(NativeAgentConnection(cx.entity()));
329
330 let thread = thread_handle.read(cx);
331 let thread_id = thread.id();
332 let session_id = thread.session_id().clone();
333 let parent_thread_id = thread.parent_thread_id();
334 let title = thread.title();
335 let draft_prompt = thread.draft_prompt().map(Vec::from);
336 let scroll_position = thread.ui_scroll_position();
337 let token_usage = thread.latest_token_usage();
338 let project = thread.project.clone();
339 let action_log = thread.action_log.clone();
340 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
341 let acp_thread = cx.new(|cx| {
342 let mut acp_thread = acp_thread::AcpThread::new(
343 parent_thread_id,
344 title,
345 None,
346 connection,
347 project.clone(),
348 action_log.clone(),
349 thread_id,
350 session_id.clone(),
351 prompt_capabilities_rx,
352 cx,
353 );
354 acp_thread.set_draft_prompt(draft_prompt);
355 acp_thread.set_ui_scroll_position(scroll_position);
356 acp_thread.update_token_usage(token_usage, cx);
357 acp_thread
358 });
359
360 let registry = LanguageModelRegistry::read_global(cx);
361 let summarization_model = registry.thread_summary_model().map(|c| c.model);
362
363 let weak = cx.weak_entity();
364 let weak_thread = thread_handle.downgrade();
365 thread_handle.update(cx, |thread, cx| {
366 thread.set_summarization_model(summarization_model, cx);
367 thread.add_default_tools(
368 Rc::new(NativeThreadEnvironment {
369 acp_thread: acp_thread.downgrade(),
370 thread: weak_thread,
371 agent: weak,
372 }) as _,
373 cx,
374 )
375 });
376
377 let subscriptions = vec![
378 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
379 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
380 cx.observe(&thread_handle, move |this, thread, cx| {
381 this.save_thread(thread, cx)
382 }),
383 ];
384
385 self.sessions.insert(
386 session_id,
387 Session {
388 thread: thread_handle,
389 acp_thread: acp_thread.clone(),
390 project_id,
391 _subscriptions: subscriptions,
392 pending_save: Task::ready(Ok(())),
393 },
394 );
395
396 self.update_available_commands_for_project(project_id, cx);
397
398 acp_thread
399 }
400
401 pub fn models(&self) -> &LanguageModels {
402 &self.models
403 }
404
405 fn get_or_create_project_state(
406 &mut self,
407 project: &Entity<Project>,
408 cx: &mut Context<Self>,
409 ) -> EntityId {
410 let project_id = project.entity_id();
411 if self.projects.contains_key(&project_id) {
412 return project_id;
413 }
414
415 let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
416 self.register_project_with_initial_context(project.clone(), project_context, cx);
417 if let Some(state) = self.projects.get_mut(&project_id) {
418 state.project_context_needs_refresh.send(()).ok();
419 }
420 project_id
421 }
422
423 fn register_project_with_initial_context(
424 &mut self,
425 project: Entity<Project>,
426 project_context: Entity<ProjectContext>,
427 cx: &mut Context<Self>,
428 ) {
429 let project_id = project.entity_id();
430
431 let context_server_store = project.read(cx).context_server_store();
432 let context_server_registry =
433 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
434
435 let subscriptions = vec![
436 cx.subscribe(&project, Self::handle_project_event),
437 cx.subscribe(
438 &context_server_store,
439 Self::handle_context_server_store_updated,
440 ),
441 cx.subscribe(
442 &context_server_registry,
443 Self::handle_context_server_registry_event,
444 ),
445 ];
446
447 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
448 watch::channel(());
449
450 self.projects.insert(
451 project_id,
452 ProjectState {
453 project,
454 project_context,
455 project_context_needs_refresh: project_context_needs_refresh_tx,
456 _maintain_project_context: cx.spawn(async move |this, cx| {
457 Self::maintain_project_context(
458 this,
459 project_id,
460 project_context_needs_refresh_rx,
461 cx,
462 )
463 .await
464 }),
465 context_server_registry,
466 _subscriptions: subscriptions,
467 },
468 );
469 }
470
471 fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
472 self.sessions
473 .get(session_id)
474 .and_then(|session| self.projects.get(&session.project_id))
475 }
476
477 async fn maintain_project_context(
478 this: WeakEntity<Self>,
479 project_id: EntityId,
480 mut needs_refresh: watch::Receiver<()>,
481 cx: &mut AsyncApp,
482 ) -> Result<()> {
483 while needs_refresh.changed().await.is_ok() {
484 let project_context = this
485 .update(cx, |this, cx| {
486 let state = this
487 .projects
488 .get(&project_id)
489 .context("project state not found")?;
490 anyhow::Ok(Self::build_project_context(
491 &state.project,
492 this.prompt_store.as_ref(),
493 cx,
494 ))
495 })??
496 .await;
497 this.update(cx, |this, cx| {
498 if let Some(state) = this.projects.get(&project_id) {
499 state
500 .project_context
501 .update(cx, |current_project_context, _cx| {
502 *current_project_context = project_context;
503 });
504 }
505 })?;
506 }
507
508 Ok(())
509 }
510
511 fn build_project_context(
512 project: &Entity<Project>,
513 prompt_store: Option<&Entity<PromptStore>>,
514 cx: &mut App,
515 ) -> Task<ProjectContext> {
516 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
517 let worktree_tasks = worktrees
518 .into_iter()
519 .map(|worktree| {
520 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
521 })
522 .collect::<Vec<_>>();
523 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
524 prompt_store.read_with(cx, |prompt_store, cx| {
525 let prompts = prompt_store.default_prompt_metadata();
526 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
527 let contents = prompt_store.load(prompt_metadata.id, cx);
528 async move { (contents.await, prompt_metadata) }
529 });
530 cx.background_spawn(future::join_all(load_tasks))
531 })
532 } else {
533 Task::ready(vec![])
534 };
535
536 cx.spawn(async move |_cx| {
537 let (worktrees, default_user_rules) =
538 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
539
540 let worktrees = worktrees
541 .into_iter()
542 .map(|(worktree, _rules_error)| {
543 // TODO: show error message
544 // if let Some(rules_error) = rules_error {
545 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
546 // }
547 worktree
548 })
549 .collect::<Vec<_>>();
550
551 let default_user_rules = default_user_rules
552 .into_iter()
553 .flat_map(|(contents, prompt_metadata)| match contents {
554 Ok(contents) => Some(UserRulesContext {
555 uuid: prompt_metadata.id.as_user()?,
556 title: prompt_metadata.title.map(|title| title.to_string()),
557 contents,
558 }),
559 Err(_err) => {
560 // TODO: show error message
561 // this.update(cx, |_, cx| {
562 // cx.emit(RulesLoadingError {
563 // message: format!("{err:?}").into(),
564 // });
565 // })
566 // .ok();
567 None
568 }
569 })
570 .collect::<Vec<_>>();
571
572 ProjectContext::new(worktrees, default_user_rules)
573 })
574 }
575
576 fn load_worktree_info_for_system_prompt(
577 worktree: Entity<Worktree>,
578 project: Entity<Project>,
579 cx: &mut App,
580 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
581 let tree = worktree.read(cx);
582 let root_name = tree.root_name_str().into();
583 let abs_path = tree.abs_path();
584
585 let mut context = WorktreeContext {
586 root_name,
587 abs_path,
588 rules_file: None,
589 };
590
591 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
592 let Some(rules_task) = rules_task else {
593 return Task::ready((context, None));
594 };
595
596 cx.spawn(async move |_| {
597 let (rules_file, rules_file_error) = match rules_task.await {
598 Ok(rules_file) => (Some(rules_file), None),
599 Err(err) => (
600 None,
601 Some(RulesLoadingError {
602 message: format!("{err}").into(),
603 }),
604 ),
605 };
606 context.rules_file = rules_file;
607 (context, rules_file_error)
608 })
609 }
610
611 fn load_worktree_rules_file(
612 worktree: Entity<Worktree>,
613 project: Entity<Project>,
614 cx: &mut App,
615 ) -> Option<Task<Result<RulesFileContext>>> {
616 let worktree = worktree.read(cx);
617 let worktree_id = worktree.id();
618 let selected_rules_file = RULES_FILE_NAMES
619 .into_iter()
620 .filter_map(|name| {
621 worktree
622 .entry_for_path(RelPath::unix(name).unwrap())
623 .filter(|entry| entry.is_file())
624 .map(|entry| entry.path.clone())
625 })
626 .next();
627
628 // Note that Cline supports `.clinerules` being a directory, but that is not currently
629 // supported. This doesn't seem to occur often in GitHub repositories.
630 selected_rules_file.map(|path_in_worktree| {
631 let project_path = ProjectPath {
632 worktree_id,
633 path: path_in_worktree.clone(),
634 };
635 let buffer_task =
636 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
637 let rope_task = cx.spawn(async move |cx| {
638 let buffer = buffer_task.await?;
639 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
640 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
641 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
642 })?;
643 anyhow::Ok((project_entry_id, rope))
644 });
645 // Build a string from the rope on a background thread.
646 cx.background_spawn(async move {
647 let (project_entry_id, rope) = rope_task.await?;
648 anyhow::Ok(RulesFileContext {
649 path_in_worktree,
650 text: rope.to_string().trim().to_string(),
651 project_entry_id: project_entry_id.to_usize(),
652 })
653 })
654 })
655 }
656
657 fn handle_thread_title_updated(
658 &mut self,
659 thread: Entity<Thread>,
660 _: &TitleUpdated,
661 cx: &mut Context<Self>,
662 ) {
663 let session_id = thread.read(cx).session_id();
664 let Some(session) = self.sessions.get(session_id) else {
665 return;
666 };
667
668 let thread = thread.downgrade();
669 let acp_thread = session.acp_thread.downgrade();
670 cx.spawn(async move |_, cx| {
671 let title = thread.read_with(cx, |thread, _| thread.title())?;
672 if let Some(title) = title {
673 let task =
674 acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
675 task.await?;
676 }
677 anyhow::Ok(())
678 })
679 .detach_and_log_err(cx);
680 }
681
682 fn handle_thread_token_usage_updated(
683 &mut self,
684 thread: Entity<Thread>,
685 usage: &TokenUsageUpdated,
686 cx: &mut Context<Self>,
687 ) {
688 let Some(session) = self.sessions.get(thread.read(cx).session_id()) else {
689 return;
690 };
691 session.acp_thread.update(cx, |acp_thread, cx| {
692 acp_thread.update_token_usage(usage.0.clone(), cx);
693 });
694 }
695
696 fn handle_project_event(
697 &mut self,
698 project: Entity<Project>,
699 event: &project::Event,
700 _cx: &mut Context<Self>,
701 ) {
702 let project_id = project.entity_id();
703 let Some(state) = self.projects.get_mut(&project_id) else {
704 return;
705 };
706 match event {
707 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
708 state.project_context_needs_refresh.send(()).ok();
709 }
710 project::Event::WorktreeUpdatedEntries(_, items) => {
711 if items.iter().any(|(path, _, _)| {
712 RULES_FILE_NAMES
713 .iter()
714 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
715 }) {
716 state.project_context_needs_refresh.send(()).ok();
717 }
718 }
719 _ => {}
720 }
721 }
722
723 fn handle_prompts_updated_event(
724 &mut self,
725 _prompt_store: Entity<PromptStore>,
726 _event: &prompt_store::PromptsUpdatedEvent,
727 _cx: &mut Context<Self>,
728 ) {
729 for state in self.projects.values_mut() {
730 state.project_context_needs_refresh.send(()).ok();
731 }
732 }
733
734 fn handle_models_updated_event(
735 &mut self,
736 _registry: Entity<LanguageModelRegistry>,
737 event: &language_model::Event,
738 cx: &mut Context<Self>,
739 ) {
740 self.models.refresh_list(cx);
741
742 let registry = LanguageModelRegistry::read_global(cx);
743 let default_model = registry.default_model().map(|m| m.model);
744 let summarization_model = registry.thread_summary_model().map(|m| m.model);
745
746 for session in self.sessions.values_mut() {
747 session.thread.update(cx, |thread, cx| {
748 if thread.model().is_none()
749 && let Some(model) = default_model.clone()
750 {
751 thread.set_model(model, cx);
752 cx.notify();
753 }
754 if let Some(model) = summarization_model.clone() {
755 if thread.summarization_model().is_none()
756 || matches!(event, language_model::Event::ThreadSummaryModelChanged)
757 {
758 thread.set_summarization_model(Some(model), cx);
759 }
760 }
761 });
762 }
763 }
764
765 fn handle_context_server_store_updated(
766 &mut self,
767 store: Entity<project::context_server_store::ContextServerStore>,
768 _event: &project::context_server_store::ServerStatusChangedEvent,
769 cx: &mut Context<Self>,
770 ) {
771 let project_id = self.projects.iter().find_map(|(id, state)| {
772 if *state.context_server_registry.read(cx).server_store() == store {
773 Some(*id)
774 } else {
775 None
776 }
777 });
778 if let Some(project_id) = project_id {
779 self.update_available_commands_for_project(project_id, cx);
780 }
781 }
782
783 fn handle_context_server_registry_event(
784 &mut self,
785 registry: Entity<ContextServerRegistry>,
786 event: &ContextServerRegistryEvent,
787 cx: &mut Context<Self>,
788 ) {
789 match event {
790 ContextServerRegistryEvent::ToolsChanged => {}
791 ContextServerRegistryEvent::PromptsChanged => {
792 let project_id = self.projects.iter().find_map(|(id, state)| {
793 if state.context_server_registry == registry {
794 Some(*id)
795 } else {
796 None
797 }
798 });
799 if let Some(project_id) = project_id {
800 self.update_available_commands_for_project(project_id, cx);
801 }
802 }
803 }
804 }
805
806 fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
807 let available_commands =
808 Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
809 for session in self.sessions.values() {
810 if session.project_id != project_id {
811 continue;
812 }
813 session.acp_thread.update(cx, |thread, cx| {
814 thread
815 .handle_session_update(
816 acp::SessionUpdate::AvailableCommandsUpdate(
817 acp::AvailableCommandsUpdate::new(available_commands.clone()),
818 ),
819 cx,
820 )
821 .log_err();
822 });
823 }
824 }
825
826 fn build_available_commands_for_project(
827 project_state: Option<&ProjectState>,
828 cx: &App,
829 ) -> Vec<acp::AvailableCommand> {
830 let Some(state) = project_state else {
831 return vec![];
832 };
833 let registry = state.context_server_registry.read(cx);
834
835 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
836 for context_server_prompt in registry.prompts() {
837 *prompt_name_counts
838 .entry(context_server_prompt.prompt.name.as_str())
839 .or_insert(0) += 1;
840 }
841
842 registry
843 .prompts()
844 .flat_map(|context_server_prompt| {
845 let prompt = &context_server_prompt.prompt;
846
847 let should_prefix = prompt_name_counts
848 .get(prompt.name.as_str())
849 .copied()
850 .unwrap_or(0)
851 > 1;
852
853 let name = if should_prefix {
854 format!("{}.{}", context_server_prompt.server_id, prompt.name)
855 } else {
856 prompt.name.clone()
857 };
858
859 let mut command = acp::AvailableCommand::new(
860 name,
861 prompt.description.clone().unwrap_or_default(),
862 );
863
864 match prompt.arguments.as_deref() {
865 Some([arg]) => {
866 let hint = format!("<{}>", arg.name);
867
868 command = command.input(acp::AvailableCommandInput::Unstructured(
869 acp::UnstructuredCommandInput::new(hint),
870 ));
871 }
872 Some([]) | None => {}
873 Some(_) => {
874 // skip >1 argument commands since we don't support them yet
875 return None;
876 }
877 }
878
879 Some(command)
880 })
881 .collect()
882 }
883
884 pub fn load_thread(
885 &mut self,
886 id: acp::SessionId,
887 project: Entity<Project>,
888 cx: &mut Context<Self>,
889 ) -> Task<Result<Entity<Thread>>> {
890 let database_future = ThreadsDatabase::connect(cx);
891 cx.spawn(async move |this, cx| {
892 let database = database_future.await.map_err(|err| anyhow!(err))?;
893 let db_thread = database
894 .load_thread(id.clone())
895 .await?
896 .with_context(|| format!("no thread found with ID: {id:?}"))?;
897
898 this.update(cx, |this, cx| {
899 let project_id = this.get_or_create_project_state(&project, cx);
900 let project_state = this
901 .projects
902 .get(&project_id)
903 .context("project state not found")?;
904 let summarization_model = LanguageModelRegistry::read_global(cx)
905 .thread_summary_model()
906 .map(|c| c.model);
907
908 Ok(cx.new(|cx| {
909 let mut thread = Thread::from_db(
910 ThreadId::new(),
911 id.clone(),
912 db_thread,
913 project_state.project.clone(),
914 project_state.project_context.clone(),
915 project_state.context_server_registry.clone(),
916 this.templates.clone(),
917 cx,
918 );
919 thread.set_summarization_model(summarization_model, cx);
920 thread
921 }))
922 })?
923 })
924 }
925
926 pub fn open_thread(
927 &mut self,
928 id: acp::SessionId,
929 project: Entity<Project>,
930 cx: &mut Context<Self>,
931 ) -> Task<Result<Entity<AcpThread>>> {
932 if let Some(session) = self.sessions.get(&id) {
933 return Task::ready(Ok(session.acp_thread.clone()));
934 }
935
936 let task = self.load_thread(id, project.clone(), cx);
937 cx.spawn(async move |this, cx| {
938 let thread = task.await?;
939 let acp_thread = this.update(cx, |this, cx| {
940 let project_id = this.get_or_create_project_state(&project, cx);
941 this.register_session(thread.clone(), project_id, cx)
942 })?;
943 let events = thread.update(cx, |thread, cx| thread.replay(cx));
944 cx.update(|cx| {
945 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
946 })
947 .await?;
948 acp_thread.update(cx, |thread, cx| {
949 thread.snapshot_completed_plan(cx);
950 });
951 Ok(acp_thread)
952 })
953 }
954
955 pub fn thread_summary(
956 &mut self,
957 id: acp::SessionId,
958 project: Entity<Project>,
959 cx: &mut Context<Self>,
960 ) -> Task<Result<SharedString>> {
961 let thread = self.open_thread(id.clone(), project, cx);
962 cx.spawn(async move |this, cx| {
963 let acp_thread = thread.await?;
964 let summary_task = this.update(cx, |this, cx| {
965 let session = this.sessions.get(&id).context("session not found")?;
966 anyhow::Ok(session.thread.update(cx, |thread, cx| thread.summary(cx)))
967 })??;
968 let result = summary_task.await.context("Failed to generate summary")?;
969 drop(acp_thread);
970 Ok(result)
971 })
972 }
973
974 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
975 if thread.read(cx).is_empty() {
976 return;
977 }
978
979 let session_id = thread.read(cx).session_id().clone();
980 let Some(session) = self.sessions.get_mut(&session_id) else {
981 return;
982 };
983
984 let project_id = session.project_id;
985 let Some(state) = self.projects.get(&project_id) else {
986 return;
987 };
988
989 let folder_paths = PathList::new(
990 &state
991 .project
992 .read(cx)
993 .visible_worktrees(cx)
994 .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
995 .collect::<Vec<_>>(),
996 );
997
998 let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
999 let database_future = ThreadsDatabase::connect(cx);
1000 let db_thread = thread.update(cx, |thread, cx| {
1001 thread.set_draft_prompt(draft_prompt);
1002 thread.to_db(cx)
1003 });
1004 let thread_store = self.thread_store.clone();
1005 session.pending_save = cx.spawn(async move |_, cx| {
1006 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1007 return Ok(());
1008 };
1009 let db_thread = db_thread.await;
1010 database
1011 .save_thread(session_id, db_thread, folder_paths)
1012 .await
1013 .log_err();
1014 thread_store.update(cx, |store, cx| store.reload(cx));
1015 Ok(())
1016 });
1017 }
1018
1019 fn send_mcp_prompt(
1020 &self,
1021 message_id: UserMessageId,
1022 session_id: acp::SessionId,
1023 prompt_name: String,
1024 server_id: ContextServerId,
1025 arguments: HashMap<String, String>,
1026 original_content: Vec<acp::ContentBlock>,
1027 cx: &mut Context<Self>,
1028 ) -> Task<Result<acp::PromptResponse>> {
1029 let Some(state) = self.session_project_state(&session_id) else {
1030 return Task::ready(Err(anyhow!("Project state not found for session")));
1031 };
1032 let server_store = state
1033 .context_server_registry
1034 .read(cx)
1035 .server_store()
1036 .clone();
1037 let path_style = state.project.read(cx).path_style(cx);
1038
1039 cx.spawn(async move |this, cx| {
1040 let prompt =
1041 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
1042
1043 let (acp_thread, thread) = this.update(cx, |this, _cx| {
1044 let session = this
1045 .sessions
1046 .get(&session_id)
1047 .context("Failed to get session")?;
1048 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
1049 })??;
1050
1051 let mut last_is_user = true;
1052
1053 thread.update(cx, |thread, cx| {
1054 thread.push_acp_user_block(
1055 message_id,
1056 original_content.into_iter().skip(1),
1057 path_style,
1058 cx,
1059 );
1060 });
1061
1062 for message in prompt.messages {
1063 let context_server::types::PromptMessage { role, content } = message;
1064 let block = mcp_message_content_to_acp_content_block(content);
1065
1066 match role {
1067 context_server::types::Role::User => {
1068 let id = acp_thread::UserMessageId::new();
1069
1070 acp_thread.update(cx, |acp_thread, cx| {
1071 acp_thread.push_user_content_block_with_indent(
1072 Some(id.clone()),
1073 block.clone(),
1074 true,
1075 cx,
1076 );
1077 });
1078
1079 thread.update(cx, |thread, cx| {
1080 thread.push_acp_user_block(id, [block], path_style, cx);
1081 });
1082 }
1083 context_server::types::Role::Assistant => {
1084 acp_thread.update(cx, |acp_thread, cx| {
1085 acp_thread.push_assistant_content_block_with_indent(
1086 block.clone(),
1087 false,
1088 true,
1089 cx,
1090 );
1091 });
1092
1093 thread.update(cx, |thread, cx| {
1094 thread.push_acp_agent_block(block, cx);
1095 });
1096 }
1097 }
1098
1099 last_is_user = role == context_server::types::Role::User;
1100 }
1101
1102 let response_stream = thread.update(cx, |thread, cx| {
1103 if last_is_user {
1104 thread.send_existing(cx)
1105 } else {
1106 // Resume if MCP prompt did not end with a user message
1107 thread.resume(cx)
1108 }
1109 })?;
1110
1111 cx.update(|cx| {
1112 NativeAgentConnection::handle_thread_events(
1113 response_stream,
1114 acp_thread.downgrade(),
1115 cx,
1116 )
1117 })
1118 .await
1119 })
1120 }
1121}
1122
/// Wrapper around the [`NativeAgent`] entity that implements
/// [`acp_thread::AgentConnection`], exposing the in-process native agent through
/// the ACP connection interface.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1126
1127impl NativeAgentConnection {
1128 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1129 self.0
1130 .read(cx)
1131 .sessions
1132 .get(session_id)
1133 .map(|session| session.thread.clone())
1134 }
1135
1136 pub fn load_thread(
1137 &self,
1138 id: acp::SessionId,
1139 project: Entity<Project>,
1140 cx: &mut App,
1141 ) -> Task<Result<Entity<Thread>>> {
1142 self.0
1143 .update(cx, |this, cx| this.load_thread(id, project, cx))
1144 }
1145
1146 fn run_turn(
1147 &self,
1148 session_id: acp::SessionId,
1149 cx: &mut App,
1150 f: impl 'static
1151 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1152 ) -> Task<Result<acp::PromptResponse>> {
1153 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1154 agent
1155 .sessions
1156 .get(&session_id)
1157 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1158 }) else {
1159 return Task::ready(Err(anyhow!("Session not found")));
1160 };
1161 log::debug!("Found session for: {}", session_id);
1162
1163 let response_stream = match f(thread, cx) {
1164 Ok(stream) => stream,
1165 Err(err) => return Task::ready(Err(err)),
1166 };
1167 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1168 }
1169
1170 fn handle_thread_events(
1171 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1172 acp_thread: WeakEntity<AcpThread>,
1173 cx: &App,
1174 ) -> Task<Result<acp::PromptResponse>> {
1175 cx.spawn(async move |cx| {
1176 // Handle response stream and forward to session.acp_thread
1177 while let Some(result) = events.next().await {
1178 match result {
1179 Ok(event) => {
1180 log::trace!("Received completion event: {:?}", event);
1181
1182 match event {
1183 ThreadEvent::UserMessage(message) => {
1184 acp_thread.update(cx, |thread, cx| {
1185 for content in message.content {
1186 thread.push_user_content_block(
1187 Some(message.id.clone()),
1188 content.into(),
1189 cx,
1190 );
1191 }
1192 })?;
1193 }
1194 ThreadEvent::AgentText(text) => {
1195 acp_thread.update(cx, |thread, cx| {
1196 thread.push_assistant_content_block(text.into(), false, cx)
1197 })?;
1198 }
1199 ThreadEvent::AgentThinking(text) => {
1200 acp_thread.update(cx, |thread, cx| {
1201 thread.push_assistant_content_block(text.into(), true, cx)
1202 })?;
1203 }
1204 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1205 tool_call,
1206 options,
1207 response,
1208 context: _,
1209 }) => {
1210 let outcome_task = acp_thread.update(cx, |thread, cx| {
1211 thread.request_tool_call_authorization(tool_call, options, cx)
1212 })??;
1213 cx.background_spawn(async move {
1214 if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1215 outcome_task.await
1216 {
1217 response
1218 .send(outcome)
1219 .map(|_| anyhow!("authorization receiver was dropped"))
1220 .log_err();
1221 }
1222 })
1223 .detach();
1224 }
1225 ThreadEvent::ToolCall(tool_call) => {
1226 acp_thread.update(cx, |thread, cx| {
1227 thread.upsert_tool_call(tool_call, cx)
1228 })??;
1229 }
1230 ThreadEvent::ToolCallUpdate(update) => {
1231 acp_thread.update(cx, |thread, cx| {
1232 thread.update_tool_call(update, cx)
1233 })??;
1234 }
1235 ThreadEvent::Plan(plan) => {
1236 acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1237 }
1238 ThreadEvent::SubagentSpawned(session_id) => {
1239 acp_thread.update(cx, |thread, cx| {
1240 thread.subagent_spawned(session_id, cx);
1241 })?;
1242 }
1243 ThreadEvent::Retry(status) => {
1244 acp_thread.update(cx, |thread, cx| {
1245 thread.update_retry_status(status, cx)
1246 })?;
1247 }
1248 ThreadEvent::Stop(stop_reason) => {
1249 log::debug!("Assistant message complete: {:?}", stop_reason);
1250 return Ok(acp::PromptResponse::new(stop_reason));
1251 }
1252 }
1253 }
1254 Err(e) => {
1255 log::error!("Error in model response stream: {:?}", e);
1256 return Err(e);
1257 }
1258 }
1259 }
1260
1261 log::debug!("Response stream completed");
1262 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1263 })
1264 }
1265}
1266
1267struct Command<'a> {
1268 prompt_name: &'a str,
1269 arg_value: &'a str,
1270 explicit_server_id: Option<&'a str>,
1271}
1272
1273impl<'a> Command<'a> {
1274 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1275 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1276 return None;
1277 };
1278 let text = text_content.text.trim();
1279 let command = text.strip_prefix('/')?;
1280 let (command, arg_value) = command
1281 .split_once(char::is_whitespace)
1282 .unwrap_or((command, ""));
1283
1284 if let Some((server_id, prompt_name)) = command.split_once('.') {
1285 Some(Self {
1286 prompt_name,
1287 arg_value,
1288 explicit_server_id: Some(server_id),
1289 })
1290 } else {
1291 Some(Self {
1292 prompt_name: command,
1293 arg_value,
1294 explicit_server_id: None,
1295 })
1296 }
1297 }
1298}
1299
1300struct NativeAgentModelSelector {
1301 session_id: acp::SessionId,
1302 connection: NativeAgentConnection,
1303}
1304
1305impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1306 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1307 log::debug!("NativeAgentConnection::list_models called");
1308 let list = self.connection.0.read(cx).models.model_list.clone();
1309 Task::ready(if list.is_empty() {
1310 Err(anyhow::anyhow!("No models available"))
1311 } else {
1312 Ok(list)
1313 })
1314 }
1315
1316 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1317 log::debug!(
1318 "Setting model for session {}: {}",
1319 self.session_id,
1320 model_id
1321 );
1322 let Some(thread) = self
1323 .connection
1324 .0
1325 .read(cx)
1326 .sessions
1327 .get(&self.session_id)
1328 .map(|session| session.thread.clone())
1329 else {
1330 return Task::ready(Err(anyhow!("Session not found")));
1331 };
1332
1333 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1334 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1335 };
1336
1337 // We want to reset the effort level when switching models, as the currently-selected effort level may
1338 // not be compatible.
1339 let effort = model
1340 .default_effort_level()
1341 .map(|effort_level| effort_level.value.to_string());
1342
1343 thread.update(cx, |thread, cx| {
1344 thread.set_model(model.clone(), cx);
1345 thread.set_thinking_effort(effort.clone(), cx);
1346 thread.set_thinking_enabled(model.supports_thinking(), cx);
1347 });
1348
1349 update_settings_file(
1350 self.connection.0.read(cx).fs.clone(),
1351 cx,
1352 move |settings, cx| {
1353 let provider = model.provider_id().0.to_string();
1354 let model = model.id().0.to_string();
1355 let enable_thinking = thread.read(cx).thinking_enabled();
1356 let speed = thread.read(cx).speed();
1357 settings
1358 .agent
1359 .get_or_insert_default()
1360 .set_model(LanguageModelSelection {
1361 provider: provider.into(),
1362 model,
1363 enable_thinking,
1364 effort,
1365 speed,
1366 });
1367 },
1368 );
1369
1370 Task::ready(Ok(()))
1371 }
1372
1373 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1374 let Some(thread) = self
1375 .connection
1376 .0
1377 .read(cx)
1378 .sessions
1379 .get(&self.session_id)
1380 .map(|session| session.thread.clone())
1381 else {
1382 return Task::ready(Err(anyhow!("Session not found")));
1383 };
1384 let Some(model) = thread.read(cx).model() else {
1385 return Task::ready(Err(anyhow!("Model not found")));
1386 };
1387 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1388 else {
1389 return Task::ready(Err(anyhow!("Provider not found")));
1390 };
1391 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1392 model, &provider,
1393 )))
1394 }
1395
1396 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1397 Some(self.connection.0.read(cx).models.watch())
1398 }
1399
1400 fn should_render_footer(&self) -> bool {
1401 true
1402 }
1403}
1404
/// Identifier of the built-in native agent, returned by
/// `NativeAgentConnection::agent_id`.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1406
1407impl acp_thread::AgentConnection for NativeAgentConnection {
1408 fn agent_id(&self) -> AgentId {
1409 ZED_AGENT_ID.clone()
1410 }
1411
1412 fn telemetry_id(&self) -> SharedString {
1413 "zed".into()
1414 }
1415
1416 fn new_session(
1417 self: Rc<Self>,
1418 project: Entity<Project>,
1419 work_dirs: PathList,
1420 cx: &mut App,
1421 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1422 log::debug!("Creating new thread for project at: {work_dirs:?}");
1423 Task::ready(Ok(self
1424 .0
1425 .update(cx, |agent, cx| agent.new_session(project, cx))))
1426 }
1427
1428 fn supports_load_session(&self) -> bool {
1429 true
1430 }
1431
1432 fn load_session(
1433 self: Rc<Self>,
1434 session_id: acp::SessionId,
1435 project: Entity<Project>,
1436 _work_dirs: PathList,
1437 _title: Option<SharedString>,
1438 cx: &mut App,
1439 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1440 self.0
1441 .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1442 }
1443
1444 fn supports_close_session(&self) -> bool {
1445 true
1446 }
1447
1448 fn close_session(
1449 self: Rc<Self>,
1450 session_id: &acp::SessionId,
1451 cx: &mut App,
1452 ) -> Task<Result<()>> {
1453 self.0.update(cx, |agent, cx| {
1454 let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
1455 if let Some(thread) = thread {
1456 agent.save_thread(thread, cx);
1457 }
1458
1459 let Some(session) = agent.sessions.remove(session_id) else {
1460 return Task::ready(Ok(()));
1461 };
1462 let project_id = session.project_id;
1463
1464 let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1465 if !has_remaining {
1466 agent.projects.remove(&project_id);
1467 }
1468
1469 session.pending_save
1470 })
1471 }
1472
1473 fn auth_methods(&self) -> &[acp::AuthMethod] {
1474 &[] // No auth for in-process
1475 }
1476
1477 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1478 Task::ready(Ok(()))
1479 }
1480
1481 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1482 Some(Rc::new(NativeAgentModelSelector {
1483 session_id: session_id.clone(),
1484 connection: self.clone(),
1485 }) as Rc<dyn AgentModelSelector>)
1486 }
1487
1488 fn prompt(
1489 &self,
1490 id: Option<acp_thread::UserMessageId>,
1491 params: acp::PromptRequest,
1492 cx: &mut App,
1493 ) -> Task<Result<acp::PromptResponse>> {
1494 let id = id.expect("UserMessageId is required");
1495 let session_id = params.session_id.clone();
1496 log::info!("Received prompt request for session: {}", session_id);
1497 log::debug!("Prompt blocks count: {}", params.prompt.len());
1498
1499 let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1500 return Task::ready(Err(anyhow::anyhow!("Session not found")));
1501 };
1502
1503 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1504 let registry = project_state.context_server_registry.read(cx);
1505
1506 let explicit_server_id = parsed_command
1507 .explicit_server_id
1508 .map(|server_id| ContextServerId(server_id.into()));
1509
1510 if let Some(prompt) =
1511 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1512 {
1513 let arguments = if !parsed_command.arg_value.is_empty()
1514 && let Some(arg_name) = prompt
1515 .prompt
1516 .arguments
1517 .as_ref()
1518 .and_then(|args| args.first())
1519 .map(|arg| arg.name.clone())
1520 {
1521 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1522 } else {
1523 Default::default()
1524 };
1525
1526 let prompt_name = prompt.prompt.name.clone();
1527 let server_id = prompt.server_id.clone();
1528
1529 return self.0.update(cx, |agent, cx| {
1530 agent.send_mcp_prompt(
1531 id,
1532 session_id.clone(),
1533 prompt_name,
1534 server_id,
1535 arguments,
1536 params.prompt,
1537 cx,
1538 )
1539 });
1540 }
1541 };
1542
1543 let path_style = project_state.project.read(cx).path_style(cx);
1544
1545 self.run_turn(session_id, cx, move |thread, cx| {
1546 let content: Vec<UserMessageContent> = params
1547 .prompt
1548 .into_iter()
1549 .map(|block| UserMessageContent::from_content_block(block, path_style))
1550 .collect::<Vec<_>>();
1551 log::debug!("Converted prompt to message: {} chars", content.len());
1552 log::debug!("Message id: {:?}", id);
1553 log::debug!("Message content: {:?}", content);
1554
1555 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1556 })
1557 }
1558
1559 fn retry(
1560 &self,
1561 session_id: &acp::SessionId,
1562 _cx: &App,
1563 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1564 Some(Rc::new(NativeAgentSessionRetry {
1565 connection: self.clone(),
1566 session_id: session_id.clone(),
1567 }) as _)
1568 }
1569
1570 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1571 log::info!("Cancelling on session: {}", session_id);
1572 self.0.update(cx, |agent, cx| {
1573 if let Some(session) = agent.sessions.get(session_id) {
1574 session
1575 .thread
1576 .update(cx, |thread, cx| thread.cancel(cx))
1577 .detach();
1578 }
1579 });
1580 }
1581
1582 fn truncate(
1583 &self,
1584 session_id: &acp::SessionId,
1585 cx: &App,
1586 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1587 self.0.read_with(cx, |agent, _cx| {
1588 agent.sessions.get(session_id).map(|session| {
1589 Rc::new(NativeAgentSessionTruncate {
1590 thread: session.thread.clone(),
1591 acp_thread: session.acp_thread.downgrade(),
1592 }) as _
1593 })
1594 })
1595 }
1596
1597 fn set_title(
1598 &self,
1599 session_id: &acp::SessionId,
1600 cx: &App,
1601 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1602 self.0.read_with(cx, |agent, _cx| {
1603 agent
1604 .sessions
1605 .get(session_id)
1606 .filter(|s| !s.thread.read(cx).is_subagent())
1607 .map(|session| {
1608 Rc::new(NativeAgentSessionSetTitle {
1609 thread: session.thread.clone(),
1610 }) as _
1611 })
1612 })
1613 }
1614
1615 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1616 let thread_store = self.0.read(cx).thread_store.clone();
1617 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1618 }
1619
1620 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1621 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1622 }
1623
1624 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1625 self
1626 }
1627}
1628
1629impl acp_thread::AgentTelemetry for NativeAgentConnection {
1630 fn thread_data(
1631 &self,
1632 session_id: &acp::SessionId,
1633 cx: &mut App,
1634 ) -> Task<Result<serde_json::Value>> {
1635 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1636 return Task::ready(Err(anyhow!("Session not found")));
1637 };
1638
1639 let task = session.thread.read(cx).to_db(cx);
1640 cx.background_spawn(async move {
1641 serde_json::to_value(task.await).context("Failed to serialize thread")
1642 })
1643 }
1644}
1645
1646pub struct NativeAgentSessionList {
1647 thread_store: Entity<ThreadStore>,
1648 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1649 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1650 _subscription: Subscription,
1651}
1652
1653impl NativeAgentSessionList {
1654 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1655 let (tx, rx) = smol::channel::unbounded();
1656 let this_tx = tx.clone();
1657 let subscription = cx.observe(&thread_store, move |_, _| {
1658 this_tx
1659 .try_send(acp_thread::SessionListUpdate::Refresh)
1660 .ok();
1661 });
1662 Self {
1663 thread_store,
1664 updates_tx: tx,
1665 updates_rx: rx,
1666 _subscription: subscription,
1667 }
1668 }
1669
1670 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1671 &self.thread_store
1672 }
1673}
1674
1675impl AgentSessionList for NativeAgentSessionList {
1676 fn list_sessions(
1677 &self,
1678 _request: AgentSessionListRequest,
1679 cx: &mut App,
1680 ) -> Task<Result<AgentSessionListResponse>> {
1681 let sessions = self
1682 .thread_store
1683 .read(cx)
1684 .entries()
1685 .map(|entry| AgentSessionInfo::from(&entry))
1686 .collect();
1687 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1688 }
1689
1690 fn supports_delete(&self) -> bool {
1691 true
1692 }
1693
1694 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1695 self.thread_store
1696 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1697 }
1698
1699 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1700 self.thread_store
1701 .update(cx, |store, cx| store.delete_threads(cx))
1702 }
1703
1704 fn watch(
1705 &self,
1706 _cx: &mut App,
1707 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1708 Some(self.updates_rx.clone())
1709 }
1710
1711 fn notify_refresh(&self) {
1712 self.updates_tx
1713 .try_send(acp_thread::SessionListUpdate::Refresh)
1714 .ok();
1715 }
1716
1717 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1718 self
1719 }
1720}
1721
1722struct NativeAgentSessionTruncate {
1723 thread: Entity<Thread>,
1724 acp_thread: WeakEntity<AcpThread>,
1725}
1726
1727impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1728 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1729 match self.thread.update(cx, |thread, cx| {
1730 thread.truncate(message_id.clone(), cx)?;
1731 Ok(thread.latest_token_usage())
1732 }) {
1733 Ok(usage) => {
1734 self.acp_thread
1735 .update(cx, |thread, cx| {
1736 thread.update_token_usage(usage, cx);
1737 })
1738 .ok();
1739 Task::ready(Ok(()))
1740 }
1741 Err(error) => Task::ready(Err(error)),
1742 }
1743 }
1744}
1745
1746struct NativeAgentSessionRetry {
1747 connection: NativeAgentConnection,
1748 session_id: acp::SessionId,
1749}
1750
1751impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1752 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1753 self.connection
1754 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1755 thread.update(cx, |thread, cx| thread.resume(cx))
1756 })
1757 }
1758}
1759
1760struct NativeAgentSessionSetTitle {
1761 thread: Entity<Thread>,
1762}
1763
1764impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1765 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1766 self.thread
1767 .update(cx, |thread, cx| thread.set_title(title, cx));
1768 Task::ready(Ok(()))
1769 }
1770}
1771
1772pub struct NativeThreadEnvironment {
1773 agent: WeakEntity<NativeAgent>,
1774 thread: WeakEntity<Thread>,
1775 acp_thread: WeakEntity<AcpThread>,
1776}
1777
1778impl NativeThreadEnvironment {
1779 pub(crate) fn create_subagent_thread(
1780 &self,
1781 label: String,
1782 cx: &mut App,
1783 ) -> Result<Rc<dyn SubagentHandle>> {
1784 let Some(parent_thread_entity) = self.thread.upgrade() else {
1785 anyhow::bail!("Parent thread no longer exists".to_string());
1786 };
1787 let parent_thread = parent_thread_entity.read(cx);
1788 let current_depth = parent_thread.depth();
1789 let parent_session_id = parent_thread.session_id().clone();
1790
1791 if current_depth >= MAX_SUBAGENT_DEPTH {
1792 return Err(anyhow!(
1793 "Maximum subagent depth ({}) reached",
1794 MAX_SUBAGENT_DEPTH
1795 ));
1796 }
1797
1798 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1799 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1800 thread.set_title(label.into(), cx);
1801 thread
1802 });
1803
1804 let subagent_thread_id = subagent_thread.read(cx).id();
1805
1806 let acp_thread = self
1807 .agent
1808 .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1809 let project_id = agent
1810 .sessions
1811 .get(&parent_session_id)
1812 .map(|s| s.project_id)
1813 .context("parent session not found")?;
1814 Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1815 })??;
1816
1817 let depth = current_depth + 1;
1818
1819 telemetry::event!(
1820 "Subagent Started",
1821 session = parent_thread_entity.read(cx).id().to_string(),
1822 subagent_session = subagent_thread_id.to_string(),
1823 depth,
1824 is_resumed = false,
1825 );
1826
1827 self.prompt_subagent(subagent_thread_id, subagent_thread, acp_thread)
1828 }
1829
1830 pub(crate) fn resume_subagent_thread(
1831 &self,
1832 session_id: acp::SessionId,
1833 cx: &mut App,
1834 ) -> Result<Rc<dyn SubagentHandle>> {
1835 let (subagent_thread, acp_thread, thread_id) = self.agent.update(cx, |agent, _cx| {
1836 let session = agent
1837 .sessions
1838 .get(&session_id)
1839 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1840 let thread_id = session.thread.read(_cx).id();
1841 anyhow::Ok((
1842 session.thread.clone(),
1843 session.acp_thread.clone(),
1844 thread_id,
1845 ))
1846 })??;
1847
1848 let depth = subagent_thread.read(cx).depth();
1849
1850 if let Some(parent_thread_entity) = self.thread.upgrade() {
1851 telemetry::event!(
1852 "Subagent Started",
1853 session = parent_thread_entity.read(cx).id().to_string(),
1854 subagent_session = session_id.to_string(),
1855 depth,
1856 is_resumed = true,
1857 );
1858 }
1859
1860 self.prompt_subagent(thread_id, subagent_thread, acp_thread)
1861 }
1862
1863 fn prompt_subagent(
1864 &self,
1865 thread_id: ThreadId,
1866 subagent_thread: Entity<Thread>,
1867 acp_thread: Entity<acp_thread::AcpThread>,
1868 ) -> Result<Rc<dyn SubagentHandle>> {
1869 let Some(parent_thread_entity) = self.thread.upgrade() else {
1870 anyhow::bail!("Parent thread no longer exists".to_string());
1871 };
1872 Ok(Rc::new(NativeSubagentHandle::new(
1873 thread_id,
1874 subagent_thread,
1875 acp_thread,
1876 parent_thread_entity,
1877 )) as _)
1878 }
1879}
1880
1881impl ThreadEnvironment for NativeThreadEnvironment {
1882 fn create_terminal(
1883 &self,
1884 command: String,
1885 cwd: Option<PathBuf>,
1886 output_byte_limit: Option<u64>,
1887 cx: &mut AsyncApp,
1888 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1889 let task = self.acp_thread.update(cx, |thread, cx| {
1890 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1891 });
1892
1893 let acp_thread = self.acp_thread.clone();
1894 cx.spawn(async move |cx| {
1895 let terminal = task?.await?;
1896
1897 let (drop_tx, drop_rx) = oneshot::channel();
1898 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1899
1900 cx.spawn(async move |cx| {
1901 drop_rx.await.ok();
1902 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1903 })
1904 .detach();
1905
1906 let handle = AcpTerminalHandle {
1907 terminal,
1908 _drop_tx: Some(drop_tx),
1909 };
1910
1911 Ok(Rc::new(handle) as _)
1912 })
1913 }
1914
1915 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1916 self.create_subagent_thread(label, cx)
1917 }
1918
1919 fn resume_subagent(
1920 &self,
1921 session_id: acp::SessionId,
1922 cx: &mut App,
1923 ) -> Result<Rc<dyn SubagentHandle>> {
1924 self.resume_subagent_thread(session_id, cx)
1925 }
1926}
1927
/// Outcome of waiting for a subagent's prompt to finish.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn stopped for a reason not treated as a failure (e.g. `EndTurn`).
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// The turn failed; the message is reported back as the error text.
    Error(String),
    /// Token usage moved from the normal range into the warning range mid-turn.
    ContextWindowWarning,
}
1935
1936pub struct NativeSubagentHandle {
1937 thread_id: ThreadId,
1938 parent_thread: WeakEntity<Thread>,
1939 subagent_thread: Entity<Thread>,
1940 acp_thread: Entity<acp_thread::AcpThread>,
1941}
1942
1943impl NativeSubagentHandle {
1944 fn new(
1945 thread_id: ThreadId,
1946 subagent_thread: Entity<Thread>,
1947 acp_thread: Entity<acp_thread::AcpThread>,
1948 parent_thread_entity: Entity<Thread>,
1949 ) -> Self {
1950 NativeSubagentHandle {
1951 thread_id,
1952 subagent_thread,
1953 parent_thread: parent_thread_entity.downgrade(),
1954 acp_thread,
1955 }
1956 }
1957}
1958
// `SubagentHandle` implementation that drives the subagent's native `Thread` through
// its `AcpThread` and reports running state to the (weakly held) parent thread.
impl SubagentHandle for NativeSubagentHandle {
    /// The subagent thread's id, as captured when the handle was created.
    fn id(&self) -> ThreadId {
        self.thread_id
    }

    /// The ACP session id of the subagent thread.
    fn session_id(&self, cx: &App) -> acp::SessionId {
        self.subagent_thread.read(cx).session_id().clone()
    }

    /// Number of entries currently in the subagent's `AcpThread`.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent as a user prompt and waits for the turn to end.
    ///
    /// While the turn runs, the subagent is registered as running on the parent
    /// thread (if the parent still exists) and unregistered afterwards.
    ///
    /// Resolves with the subagent's last agent message (its text parts joined by
    /// blank lines) when the turn completes. It fails in these cases:
    /// - the turn was cancelled or failed;
    /// - the final message has no text;
    /// - token usage went from `Normal` to `Warning` during the turn; the subagent
    ///   thread is then cancelled before returning.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_thread_id = self.thread_id;
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is bound (not discarded) so the token-usage subscription
            // stays in scope until this task finishes.
            let (task, _subscription) = cx.update(|cx| {
                // Baseline ratio; a missing usage counts as `Normal` below, so a thread
                // already past `Normal` will not trigger the warning path.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // Wrapped in an Option so the signal is sent at most once.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            // Fire only on the Normal -> Warning transition.
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the token-limit signal; whichever finishes
                // first decides the result.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Collect the text parts of the last agent message; an empty result
                // is reported as "No response from subagent".
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent and wait for the cancellation to finish before
                    // reporting back.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // Runs on every path above, so the parent never keeps a stale entry.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(subagent_thread_id, cx)
                })
                .ok();

            result
        })
    }
}
2083
2084pub struct AcpTerminalHandle {
2085 terminal: Entity<acp_thread::Terminal>,
2086 _drop_tx: Option<oneshot::Sender<()>>,
2087}
2088
2089impl TerminalHandle for AcpTerminalHandle {
2090 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2091 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2092 }
2093
2094 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2095 Ok(self
2096 .terminal
2097 .read_with(cx, |term, _cx| term.wait_for_exit()))
2098 }
2099
2100 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2101 Ok(self
2102 .terminal
2103 .read_with(cx, |term, cx| term.current_output(cx)))
2104 }
2105
2106 fn kill(&self, cx: &AsyncApp) -> Result<()> {
2107 cx.update(|cx| {
2108 self.terminal.update(cx, |terminal, cx| {
2109 terminal.kill(cx);
2110 });
2111 });
2112 Ok(())
2113 }
2114
2115 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2116 Ok(self
2117 .terminal
2118 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2119 }
2120}
2121
2122#[cfg(test)]
2123mod internal_tests {
2124 use std::path::Path;
2125
2126 use super::*;
2127 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2128 use fs::FakeFs;
2129 use gpui::TestAppContext;
2130 use indoc::formatdoc;
2131 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2132 use language_model::{
2133 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2134 };
2135 use serde_json::json;
2136 use settings::SettingsStore;
2137 use util::{path, rel_path::rel_path};
2138
    // Verifies that the agent's per-project `ProjectContext`, and the thread's
    // view of it, tracks the project's worktrees.
    // - It starts empty.
    // - It gains a `WorktreeContext` (with no rules file) once `/a` is added
    //   as a worktree.
    // - It picks up the rules file once `/a/.rules` is created on disk.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees; `/a` is added further below.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Only one session has been created in this test, so the first session
        // in the map is the one we just opened.
        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // No worktrees yet: both the agent-side and thread-side contexts are empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding `/a` as a worktree should surface it in both contexts, with no
        // rules file because `/a` is still empty.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            // The expected entry id is read back from the worktree so the
            // assertion does not hard-code an id.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    // The file was created with empty contents above.
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2226
    // Verifies that a session's model selector lists the test registry's models
    // grouped by provider. With only the test registry installed, the result is
    // a single "Fake" group holding the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to report grouped models; any other
        // list shape fails the test.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2283
    // Verifies that selecting a model through a session's model selector:
    // - updates the session's thread;
    // - rewrites `agent.default_model` in the settings file (which initially
    //   points at provider "foo" / model "bar");
    // - writes `enable_thinking: true` when a thinking-capable model is chosen.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model so the test can
        // observe it being overwritten.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings write scheduled by the selection finish before
        // reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2395
    // Verifies that `select_model` toggles the thread's `thinking_enabled`
    // flag to match the chosen model. The flag starts false, is switched on
    // when a thinking-capable model is selected, and is switched off again
    // when switching back to the non-thinking `fake/fake` model.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Start from an empty settings object so no default model is configured.
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2487
    // Verifies that a session's summarization model is not dropped when the
    // global registry's default model is temporarily cleared with
    // `set_default_model(None)`.
    #[gpui::test]
    async fn test_summarization_model_survives_transient_registry_clearing(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Precondition: the new session picked up a summarization model.
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.summarization_model().is_some(),
                "session should have a summarization model from the test registry"
            );
        });

        // Simulate what happens during a provider blip:
        // update_active_language_model_from_settings calls set_default_model(None)
        // when it can't resolve the model, clearing all fallbacks.
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(None, cx);
            });
        });
        // Let any registry-change observers run before checking the thread.
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(
                thread.summarization_model().is_some(),
                "summarization model should survive a transient default model clearing"
            );
        });
    }
2542
    // Verifies that `thinking_enabled` survives a round trip through the thread
    // store. The test selects a thinking model, sends one message so the thread
    // is persisted, closes the session, then reopens the thread and checks that
    // the flag is still true.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The model handle is kept so the test can drive its completion stream
        // below.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned and parked so the request reaches the fake
        // model before the test feeds it a response and ends its stream.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        // The local handles are dropped too, and the session map is checked to
        // be empty before reopening.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2645
    // Verifies that a thread reloaded from the store keeps the model it was
    // saved with, rather than falling back to the default model. The test uses
    // a model whose id ("custom-model-id") differs from its display name, so
    // the check fails if the id and the name are mixed up anywhere along the
    // round trip.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selection is addressed by "<provider>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned and parked so the request reaches the fake
        // model before the test feeds it a response and ends its stream.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2753
    // End-to-end persistence test for a native thread. It checks that:
    // - an empty thread is not saved, even after its models are changed;
    // - once a message exchange happens, the thread is saved under the title
    //   produced by the summarization model;
    // - after closing and reopening, the transcript, draft prompt, token usage
    //   and UI scroll position are all restored.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // The fake models installed here are also the ones the test drives
        // below: `model` answers the user, `summary_model` produces the title.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The prompt mixes plain text with a resource link to `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Answer the request, including a usage update so token counts can be
        // checked after the reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summarization request is expected to be pending by now; its reply
        // becomes the thread title asserted below.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The store should hold exactly one entry, titled by the summary model.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2935
    // Regression test: `close_session` must itself save the thread. The draft
    // prompt is changed with no `run_until_parked` before the close, so any
    // background or observer-triggered save cannot have captured it. The draft
    // can therefore only survive the reopen if the close path performs the
    // save.
    #[gpui::test]
    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Install a fake model so the test can drive the completion stream.
        let model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Send a message so the thread is non-empty (empty threads aren't saved).
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
        // This means no observe-triggered save has run for this change.
        // The only way this data gets persisted is if close_session
        // itself performs the save.
        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
            "unsaved draft",
        ))];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });

        // Close the session immediately — no run_until_parked in between.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Reopen and verify the draft prompt was saved.
        // NOTE(review): unlike test_save_load_thread, the `thread` and
        // `acp_thread` handles are still alive at this point; the reopen below
        // still goes through `open_thread`.
        let reloaded = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        reloaded.read_with(cx, |thread, _| {
            assert_eq!(
                thread.draft_prompt(),
                Some(draft_blocks.as_slice()),
                "close_session must save the thread; draft prompt was lost"
            );
        });
    }
3016
    // Two back-to-back `set_title` calls on the native `Thread` must emit
    // exactly two `TitleUpdated` events. Both the `Thread` and its `AcpThread`
    // must end up with the last title.
    #[gpui::test]
    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
        // Regression test: rapid title changes must not cause a propagation loop
        // between Thread and AcpThread via handle_thread_title_updated.
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();

        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Count every `TitleUpdated` emitted by the thread. The subscriber
        // panics as soon as a third event arrives, so a loop fails fast
        // instead of spinning.
        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
        cx.update(|cx| {
            let count = title_updated_count.clone();
            cx.subscribe(
                &thread,
                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
                    // Scope the mutable borrow so it is released before asserting.
                    let new_count = {
                        let mut count = count.borrow_mut();
                        *count += 1;
                        *count
                    };
                    assert!(
                        new_count <= 2,
                        "TitleUpdated fired {new_count} times; \
                         title updates are looping"
                    );
                },
            )
            .detach();
        });

        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));

        cx.run_until_parked();

        // The last title wins on both sides.
        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.title(), Some("second".into()));
        });
        acp_thread.read_with(cx, |acp_thread, _| {
            assert_eq!(acp_thread.title(), Some("second".into()));
        });

        // Exactly one event per `set_title` call.
        assert_eq!(*title_updated_count.borrow(), 2);
    }
3080
3081 fn thread_entries(
3082 thread_store: &Entity<ThreadStore>,
3083 cx: &mut TestAppContext,
3084 ) -> Vec<(acp::SessionId, String)> {
3085 thread_store.read_with(cx, |store, _| {
3086 store
3087 .entries()
3088 .map(|entry| (entry.id.clone(), entry.title.to_string()))
3089 .collect::<Vec<_>>()
3090 })
3091 }
3092
3093 fn init_test(cx: &mut TestAppContext) {
3094 env_logger::try_init().ok();
3095 cx.update(|cx| {
3096 let settings_store = SettingsStore::test(cx);
3097 cx.set_global(settings_store);
3098
3099 LanguageModelRegistry::test(cx);
3100 });
3101 }
3102}
3103
3104fn mcp_message_content_to_acp_content_block(
3105 content: context_server::types::MessageContent,
3106) -> acp::ContentBlock {
3107 match content {
3108 context_server::types::MessageContent::Text {
3109 text,
3110 annotations: _,
3111 } => text.into(),
3112 context_server::types::MessageContent::Image {
3113 data,
3114 mime_type,
3115 annotations: _,
3116 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3117 context_server::types::MessageContent::Audio {
3118 data,
3119 mime_type,
3120 annotations: _,
3121 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3122 context_server::types::MessageContent::Resource {
3123 resource,
3124 annotations: _,
3125 } => {
3126 let mut link =
3127 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3128 if let Some(mime_type) = resource.mime_type {
3129 link = link.mime_type(mime_type);
3130 }
3131 acp::ContentBlock::ResourceLink(link)
3132 }
3133 }
3134}