1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
41};
42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
43use project::{Project, ProjectItem, ProjectPath, Worktree};
44use prompt_store::{
45 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
46 WorktreeContext,
47};
48use serde::{Deserialize, Serialize};
49use settings::{LanguageModelSelection, update_settings_file};
50use std::any::Any;
51use std::path::{Path, PathBuf};
52use std::rc::Rc;
53use std::sync::Arc;
54use util::ResultExt;
55use util::path_list::PathList;
56use util::rel_path::RelPath;
57
/// Serializable, comparable snapshot of a project's worktrees.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Presumably the moment the snapshot was taken — TODO confirm with the producer.
    pub timestamp: DateTime<Utc>,
}
63
/// Error describing why a worktree rules file could not be loaded.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt`; it is currently not
/// surfaced to the user (see the TODOs in `NativeAgent::build_project_context`).
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
67
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// The most recent save started by `NativeAgent::save_thread`; each new save replaces
    /// (and drops) the previous task.
    pending_save: Task<()>,
    /// Subscriptions to the internal thread's events; held for as long as the session lives.
    _subscriptions: Vec<Subscription>,
}
77
/// Cache of the language models offered by authenticated providers, grouped for display.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch()`; fires after every `refresh_list`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; held so it is not dropped.
    _authenticate_all_providers_task: Task<()>,
}
87
88impl LanguageModels {
89 fn new(cx: &mut App) -> Self {
90 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
91
92 let mut this = Self {
93 models: HashMap::default(),
94 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
95 refresh_models_rx,
96 refresh_models_tx,
97 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
98 };
99 this.refresh_list(cx);
100 this
101 }
102
103 fn refresh_list(&mut self, cx: &App) {
104 let providers = LanguageModelRegistry::global(cx)
105 .read(cx)
106 .visible_providers()
107 .into_iter()
108 .filter(|provider| provider.is_authenticated(cx))
109 .collect::<Vec<_>>();
110
111 let mut language_model_list = IndexMap::default();
112 let mut recommended_models = HashSet::default();
113
114 let mut recommended = Vec::new();
115 for provider in &providers {
116 for model in provider.recommended_models(cx) {
117 recommended_models.insert((model.provider_id(), model.id()));
118 recommended.push(Self::map_language_model_to_info(&model, provider));
119 }
120 }
121 if !recommended.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName("Recommended".into()),
124 recommended,
125 );
126 }
127
128 let mut models = HashMap::default();
129 for provider in providers {
130 let mut provider_models = Vec::new();
131 for model in provider.provided_models(cx) {
132 let model_info = Self::map_language_model_to_info(&model, &provider);
133 let model_id = model_info.id.clone();
134 provider_models.push(model_info);
135 models.insert(model_id, model);
136 }
137 if !provider_models.is_empty() {
138 language_model_list.insert(
139 acp_thread::AgentModelGroupName(provider.name().0.clone()),
140 provider_models,
141 );
142 }
143 }
144
145 self.models = models;
146 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
147 self.refresh_models_tx.send(()).ok();
148 }
149
150 fn watch(&self) -> watch::Receiver<()> {
151 self.refresh_models_rx.clone()
152 }
153
154 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
155 self.models.get(model_id).cloned()
156 }
157
158 fn map_language_model_to_info(
159 model: &Arc<dyn LanguageModel>,
160 provider: &Arc<dyn LanguageModelProvider>,
161 ) -> acp_thread::AgentModelInfo {
162 acp_thread::AgentModelInfo {
163 id: Self::model_id(model),
164 name: model.name().0,
165 description: None,
166 icon: Some(match provider.icon() {
167 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
168 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
169 }),
170 is_latest: model.is_latest(),
171 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
172 }
173 }
174
175 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
176 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
177 }
178
179 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
180 let authenticate_all_providers = LanguageModelRegistry::global(cx)
181 .read(cx)
182 .visible_providers()
183 .iter()
184 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
185 .collect::<Vec<_>>();
186
187 cx.background_spawn(async move {
188 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
189 if let Err(err) = authenticate_task.await {
190 match err {
191 language_model::AuthenticateError::CredentialsNotFound => {
192 // Since we're authenticating these providers in the
193 // background for the purposes of populating the
194 // language selector, we don't care about providers
195 // where the credentials are not found.
196 }
197 language_model::AuthenticateError::ConnectionRefused => {
198 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
199 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
200 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
201 }
202 _ => {
203 // Some providers have noisy failure states that we
204 // don't want to spam the logs with every time the
205 // language model selector is initialized.
206 //
207 // Ideally these should have more clear failure modes
208 // that we know are safe to ignore here, like what we do
209 // with `CredentialsNotFound` above.
210 match provider_id.0.as_ref() {
211 "lmstudio" | "ollama" => {
212 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
213 //
214 // These fail noisily, so we don't log them.
215 }
216 "copilot_chat" => {
217 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
218 }
219 _ => {
220 log::error!(
221 "Failed to authenticate provider: {}: {err:#}",
222 provider_name.0
223 );
224 }
225 }
226 }
227 }
228 }
229 }
230 })
231 }
232}
233
/// The built-in agent.
///
/// It owns every open session of a project, together with the state those sessions share:
/// project context, templates, the model cache and the context servers.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store of persisted threads; reloaded after each save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalling this asks `_maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Long-running task that rebuilds the project context on request.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server (MCP) tools and prompts.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Source of the default user rules, if available.
    prompt_store: Option<Entity<PromptStore>>,
    /// File-system handle; not read within this impl block — TODO confirm its consumers.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, registry, context-server and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
252
253impl NativeAgent {
254 pub async fn new(
255 project: Entity<Project>,
256 thread_store: Entity<ThreadStore>,
257 templates: Arc<Templates>,
258 prompt_store: Option<Entity<PromptStore>>,
259 fs: Arc<dyn Fs>,
260 cx: &mut AsyncApp,
261 ) -> Result<Entity<NativeAgent>> {
262 log::debug!("Creating new NativeAgent");
263
264 let project_context = cx
265 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
266 .await;
267
268 Ok(cx.new(|cx| {
269 let context_server_store = project.read(cx).context_server_store();
270 let context_server_registry =
271 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
272
273 let mut subscriptions = vec![
274 cx.subscribe(&project, Self::handle_project_event),
275 cx.subscribe(
276 &LanguageModelRegistry::global(cx),
277 Self::handle_models_updated_event,
278 ),
279 cx.subscribe(
280 &context_server_store,
281 Self::handle_context_server_store_updated,
282 ),
283 cx.subscribe(
284 &context_server_registry,
285 Self::handle_context_server_registry_event,
286 ),
287 ];
288 if let Some(prompt_store) = prompt_store.as_ref() {
289 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
290 }
291
292 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
293 watch::channel(());
294 Self {
295 sessions: HashMap::default(),
296 thread_store,
297 project_context: cx.new(|_| project_context),
298 project_context_needs_refresh: project_context_needs_refresh_tx,
299 _maintain_project_context: cx.spawn(async move |this, cx| {
300 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
301 }),
302 context_server_registry,
303 templates,
304 models: LanguageModels::new(cx),
305 project,
306 prompt_store,
307 fs,
308 _subscriptions: subscriptions,
309 }
310 }))
311 }
312
313 fn new_session(
314 &mut self,
315 project: Entity<Project>,
316 cx: &mut Context<Self>,
317 ) -> Entity<AcpThread> {
318 // Create Thread
319 // Fetch default model from registry settings
320 let registry = LanguageModelRegistry::read_global(cx);
321 // Log available models for debugging
322 let available_count = registry.available_models(cx).count();
323 log::debug!("Total available models: {}", available_count);
324
325 let default_model = registry.default_model().and_then(|default_model| {
326 self.models
327 .model_from_id(&LanguageModels::model_id(&default_model.model))
328 });
329 let thread = cx.new(|cx| {
330 Thread::new(
331 project.clone(),
332 self.project_context.clone(),
333 self.context_server_registry.clone(),
334 self.templates.clone(),
335 default_model,
336 cx,
337 )
338 });
339
340 self.register_session(thread, cx)
341 }
342
343 fn register_session(
344 &mut self,
345 thread_handle: Entity<Thread>,
346 cx: &mut Context<Self>,
347 ) -> Entity<AcpThread> {
348 let connection = Rc::new(NativeAgentConnection(cx.entity()));
349
350 let thread = thread_handle.read(cx);
351 let session_id = thread.id().clone();
352 let parent_session_id = thread.parent_thread_id();
353 let title = thread.title();
354 let draft_prompt = thread.draft_prompt().map(Vec::from);
355 let scroll_position = thread.ui_scroll_position();
356 let token_usage = thread.latest_token_usage();
357 let project = thread.project.clone();
358 let action_log = thread.action_log.clone();
359 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
360 let acp_thread = cx.new(|cx| {
361 let mut acp_thread = acp_thread::AcpThread::new(
362 parent_session_id,
363 title,
364 None,
365 connection,
366 project.clone(),
367 action_log.clone(),
368 session_id.clone(),
369 prompt_capabilities_rx,
370 cx,
371 );
372 acp_thread.set_draft_prompt(draft_prompt);
373 acp_thread.set_ui_scroll_position(scroll_position);
374 acp_thread.update_token_usage(token_usage, cx);
375 acp_thread
376 });
377
378 let registry = LanguageModelRegistry::read_global(cx);
379 let summarization_model = registry.thread_summary_model().map(|c| c.model);
380
381 let weak = cx.weak_entity();
382 let weak_thread = thread_handle.downgrade();
383 thread_handle.update(cx, |thread, cx| {
384 thread.set_summarization_model(summarization_model, cx);
385 thread.add_default_tools(
386 Rc::new(NativeThreadEnvironment {
387 acp_thread: acp_thread.downgrade(),
388 thread: weak_thread,
389 agent: weak,
390 }) as _,
391 cx,
392 )
393 });
394
395 let subscriptions = vec![
396 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
397 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
398 cx.observe(&thread_handle, move |this, thread, cx| {
399 this.save_thread(thread, cx)
400 }),
401 ];
402
403 self.sessions.insert(
404 session_id,
405 Session {
406 thread: thread_handle,
407 acp_thread: acp_thread.clone(),
408 _subscriptions: subscriptions,
409 pending_save: Task::ready(()),
410 },
411 );
412
413 self.update_available_commands(cx);
414
415 acp_thread
416 }
417
418 pub fn models(&self) -> &LanguageModels {
419 &self.models
420 }
421
422 async fn maintain_project_context(
423 this: WeakEntity<Self>,
424 mut needs_refresh: watch::Receiver<()>,
425 cx: &mut AsyncApp,
426 ) -> Result<()> {
427 while needs_refresh.changed().await.is_ok() {
428 let project_context = this
429 .update(cx, |this, cx| {
430 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
431 })?
432 .await;
433 this.update(cx, |this, cx| {
434 this.project_context = cx.new(|_| project_context);
435 })?;
436 }
437
438 Ok(())
439 }
440
441 fn build_project_context(
442 project: &Entity<Project>,
443 prompt_store: Option<&Entity<PromptStore>>,
444 cx: &mut App,
445 ) -> Task<ProjectContext> {
446 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
447 let worktree_tasks = worktrees
448 .into_iter()
449 .map(|worktree| {
450 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
451 })
452 .collect::<Vec<_>>();
453 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
454 prompt_store.read_with(cx, |prompt_store, cx| {
455 let prompts = prompt_store.default_prompt_metadata();
456 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
457 let contents = prompt_store.load(prompt_metadata.id, cx);
458 async move { (contents.await, prompt_metadata) }
459 });
460 cx.background_spawn(future::join_all(load_tasks))
461 })
462 } else {
463 Task::ready(vec![])
464 };
465
466 cx.spawn(async move |_cx| {
467 let (worktrees, default_user_rules) =
468 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
469
470 let worktrees = worktrees
471 .into_iter()
472 .map(|(worktree, _rules_error)| {
473 // TODO: show error message
474 // if let Some(rules_error) = rules_error {
475 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
476 // }
477 worktree
478 })
479 .collect::<Vec<_>>();
480
481 let default_user_rules = default_user_rules
482 .into_iter()
483 .flat_map(|(contents, prompt_metadata)| match contents {
484 Ok(contents) => Some(UserRulesContext {
485 uuid: prompt_metadata.id.as_user()?,
486 title: prompt_metadata.title.map(|title| title.to_string()),
487 contents,
488 }),
489 Err(_err) => {
490 // TODO: show error message
491 // this.update(cx, |_, cx| {
492 // cx.emit(RulesLoadingError {
493 // message: format!("{err:?}").into(),
494 // });
495 // })
496 // .ok();
497 None
498 }
499 })
500 .collect::<Vec<_>>();
501
502 ProjectContext::new(worktrees, default_user_rules)
503 })
504 }
505
506 fn load_worktree_info_for_system_prompt(
507 worktree: Entity<Worktree>,
508 project: Entity<Project>,
509 cx: &mut App,
510 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
511 let tree = worktree.read(cx);
512 let root_name = tree.root_name_str().into();
513 let abs_path = tree.abs_path();
514
515 let mut context = WorktreeContext {
516 root_name,
517 abs_path,
518 rules_file: None,
519 };
520
521 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
522 let Some(rules_task) = rules_task else {
523 return Task::ready((context, None));
524 };
525
526 cx.spawn(async move |_| {
527 let (rules_file, rules_file_error) = match rules_task.await {
528 Ok(rules_file) => (Some(rules_file), None),
529 Err(err) => (
530 None,
531 Some(RulesLoadingError {
532 message: format!("{err}").into(),
533 }),
534 ),
535 };
536 context.rules_file = rules_file;
537 (context, rules_file_error)
538 })
539 }
540
541 fn load_worktree_rules_file(
542 worktree: Entity<Worktree>,
543 project: Entity<Project>,
544 cx: &mut App,
545 ) -> Option<Task<Result<RulesFileContext>>> {
546 let worktree = worktree.read(cx);
547 let worktree_id = worktree.id();
548 let selected_rules_file = RULES_FILE_NAMES
549 .into_iter()
550 .filter_map(|name| {
551 worktree
552 .entry_for_path(RelPath::unix(name).unwrap())
553 .filter(|entry| entry.is_file())
554 .map(|entry| entry.path.clone())
555 })
556 .next();
557
558 // Note that Cline supports `.clinerules` being a directory, but that is not currently
559 // supported. This doesn't seem to occur often in GitHub repositories.
560 selected_rules_file.map(|path_in_worktree| {
561 let project_path = ProjectPath {
562 worktree_id,
563 path: path_in_worktree.clone(),
564 };
565 let buffer_task =
566 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
567 let rope_task = cx.spawn(async move |cx| {
568 let buffer = buffer_task.await?;
569 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
570 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
571 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
572 })?;
573 anyhow::Ok((project_entry_id, rope))
574 });
575 // Build a string from the rope on a background thread.
576 cx.background_spawn(async move {
577 let (project_entry_id, rope) = rope_task.await?;
578 anyhow::Ok(RulesFileContext {
579 path_in_worktree,
580 text: rope.to_string().trim().to_string(),
581 project_entry_id: project_entry_id.to_usize(),
582 })
583 })
584 })
585 }
586
587 fn handle_thread_title_updated(
588 &mut self,
589 thread: Entity<Thread>,
590 _: &TitleUpdated,
591 cx: &mut Context<Self>,
592 ) {
593 let session_id = thread.read(cx).id();
594 let Some(session) = self.sessions.get(session_id) else {
595 return;
596 };
597 let thread = thread.downgrade();
598 let acp_thread = session.acp_thread.downgrade();
599 cx.spawn(async move |_, cx| {
600 let title = thread.read_with(cx, |thread, _| thread.title())?;
601 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
602 task.await
603 })
604 .detach_and_log_err(cx);
605 }
606
607 fn handle_thread_token_usage_updated(
608 &mut self,
609 thread: Entity<Thread>,
610 usage: &TokenUsageUpdated,
611 cx: &mut Context<Self>,
612 ) {
613 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
614 return;
615 };
616 session.acp_thread.update(cx, |acp_thread, cx| {
617 acp_thread.update_token_usage(usage.0.clone(), cx);
618 });
619 }
620
621 fn handle_project_event(
622 &mut self,
623 _project: Entity<Project>,
624 event: &project::Event,
625 _cx: &mut Context<Self>,
626 ) {
627 match event {
628 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
629 self.project_context_needs_refresh.send(()).ok();
630 }
631 project::Event::WorktreeUpdatedEntries(_, items) => {
632 if items.iter().any(|(path, _, _)| {
633 RULES_FILE_NAMES
634 .iter()
635 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
636 }) {
637 self.project_context_needs_refresh.send(()).ok();
638 }
639 }
640 _ => {}
641 }
642 }
643
644 fn handle_prompts_updated_event(
645 &mut self,
646 _prompt_store: Entity<PromptStore>,
647 _event: &prompt_store::PromptsUpdatedEvent,
648 _cx: &mut Context<Self>,
649 ) {
650 self.project_context_needs_refresh.send(()).ok();
651 }
652
653 fn handle_models_updated_event(
654 &mut self,
655 _registry: Entity<LanguageModelRegistry>,
656 _event: &language_model::Event,
657 cx: &mut Context<Self>,
658 ) {
659 self.models.refresh_list(cx);
660
661 let registry = LanguageModelRegistry::read_global(cx);
662 let default_model = registry.default_model().map(|m| m.model);
663 let summarization_model = registry.thread_summary_model().map(|m| m.model);
664
665 for session in self.sessions.values_mut() {
666 session.thread.update(cx, |thread, cx| {
667 if thread.model().is_none()
668 && let Some(model) = default_model.clone()
669 {
670 thread.set_model(model, cx);
671 cx.notify();
672 }
673 thread.set_summarization_model(summarization_model.clone(), cx);
674 });
675 }
676 }
677
678 fn handle_context_server_store_updated(
679 &mut self,
680 _store: Entity<project::context_server_store::ContextServerStore>,
681 _event: &project::context_server_store::ServerStatusChangedEvent,
682 cx: &mut Context<Self>,
683 ) {
684 self.update_available_commands(cx);
685 }
686
687 fn handle_context_server_registry_event(
688 &mut self,
689 _registry: Entity<ContextServerRegistry>,
690 event: &ContextServerRegistryEvent,
691 cx: &mut Context<Self>,
692 ) {
693 match event {
694 ContextServerRegistryEvent::ToolsChanged => {}
695 ContextServerRegistryEvent::PromptsChanged => {
696 self.update_available_commands(cx);
697 }
698 }
699 }
700
701 fn update_available_commands(&self, cx: &mut Context<Self>) {
702 let available_commands = self.build_available_commands(cx);
703 for session in self.sessions.values() {
704 session.acp_thread.update(cx, |thread, cx| {
705 thread
706 .handle_session_update(
707 acp::SessionUpdate::AvailableCommandsUpdate(
708 acp::AvailableCommandsUpdate::new(available_commands.clone()),
709 ),
710 cx,
711 )
712 .log_err();
713 });
714 }
715 }
716
717 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
718 let registry = self.context_server_registry.read(cx);
719
720 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
721 for context_server_prompt in registry.prompts() {
722 *prompt_name_counts
723 .entry(context_server_prompt.prompt.name.as_str())
724 .or_insert(0) += 1;
725 }
726
727 registry
728 .prompts()
729 .flat_map(|context_server_prompt| {
730 let prompt = &context_server_prompt.prompt;
731
732 let should_prefix = prompt_name_counts
733 .get(prompt.name.as_str())
734 .copied()
735 .unwrap_or(0)
736 > 1;
737
738 let name = if should_prefix {
739 format!("{}.{}", context_server_prompt.server_id, prompt.name)
740 } else {
741 prompt.name.clone()
742 };
743
744 let mut command = acp::AvailableCommand::new(
745 name,
746 prompt.description.clone().unwrap_or_default(),
747 );
748
749 match prompt.arguments.as_deref() {
750 Some([arg]) => {
751 let hint = format!("<{}>", arg.name);
752
753 command = command.input(acp::AvailableCommandInput::Unstructured(
754 acp::UnstructuredCommandInput::new(hint),
755 ));
756 }
757 Some([]) | None => {}
758 Some(_) => {
759 // skip >1 argument commands since we don't support them yet
760 return None;
761 }
762 }
763
764 Some(command)
765 })
766 .collect()
767 }
768
769 pub fn load_thread(
770 &mut self,
771 id: acp::SessionId,
772 cx: &mut Context<Self>,
773 ) -> Task<Result<Entity<Thread>>> {
774 let database_future = ThreadsDatabase::connect(cx);
775 cx.spawn(async move |this, cx| {
776 let database = database_future.await.map_err(|err| anyhow!(err))?;
777 let db_thread = database
778 .load_thread(id.clone())
779 .await?
780 .with_context(|| format!("no thread found with ID: {id:?}"))?;
781
782 this.update(cx, |this, cx| {
783 let summarization_model = LanguageModelRegistry::read_global(cx)
784 .thread_summary_model()
785 .map(|c| c.model);
786
787 cx.new(|cx| {
788 let mut thread = Thread::from_db(
789 id.clone(),
790 db_thread,
791 this.project.clone(),
792 this.project_context.clone(),
793 this.context_server_registry.clone(),
794 this.templates.clone(),
795 cx,
796 );
797 thread.set_summarization_model(summarization_model, cx);
798 thread
799 })
800 })
801 })
802 }
803
804 pub fn open_thread(
805 &mut self,
806 id: acp::SessionId,
807 cx: &mut Context<Self>,
808 ) -> Task<Result<Entity<AcpThread>>> {
809 if let Some(session) = self.sessions.get(&id) {
810 return Task::ready(Ok(session.acp_thread.clone()));
811 }
812
813 let task = self.load_thread(id, cx);
814 cx.spawn(async move |this, cx| {
815 let thread = task.await?;
816 let acp_thread =
817 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
818 let events = thread.update(cx, |thread, cx| thread.replay(cx));
819 cx.update(|cx| {
820 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
821 })
822 .await?;
823 Ok(acp_thread)
824 })
825 }
826
827 pub fn thread_summary(
828 &mut self,
829 id: acp::SessionId,
830 cx: &mut Context<Self>,
831 ) -> Task<Result<SharedString>> {
832 let thread = self.open_thread(id.clone(), cx);
833 cx.spawn(async move |this, cx| {
834 let acp_thread = thread.await?;
835 let result = this
836 .update(cx, |this, cx| {
837 this.sessions
838 .get(&id)
839 .unwrap()
840 .thread
841 .update(cx, |thread, cx| thread.summary(cx))
842 })?
843 .await
844 .context("Failed to generate summary")?;
845 drop(acp_thread);
846 Ok(result)
847 })
848 }
849
850 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
851 if thread.read(cx).is_empty() {
852 return;
853 }
854
855 let id = thread.read(cx).id().clone();
856 let Some(session) = self.sessions.get_mut(&id) else {
857 return;
858 };
859
860 let folder_paths = PathList::new(
861 &self
862 .project
863 .read(cx)
864 .visible_worktrees(cx)
865 .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
866 .collect::<Vec<_>>(),
867 );
868
869 let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
870 let database_future = ThreadsDatabase::connect(cx);
871 let db_thread = thread.update(cx, |thread, cx| {
872 thread.set_draft_prompt(draft_prompt);
873 thread.to_db(cx)
874 });
875 let thread_store = self.thread_store.clone();
876 session.pending_save = cx.spawn(async move |_, cx| {
877 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
878 return;
879 };
880 let db_thread = db_thread.await;
881 database
882 .save_thread(id, db_thread, folder_paths)
883 .await
884 .log_err();
885 thread_store.update(cx, |store, cx| store.reload(cx));
886 });
887 }
888
889 fn send_mcp_prompt(
890 &self,
891 message_id: UserMessageId,
892 session_id: agent_client_protocol::SessionId,
893 prompt_name: String,
894 server_id: ContextServerId,
895 arguments: HashMap<String, String>,
896 original_content: Vec<acp::ContentBlock>,
897 cx: &mut Context<Self>,
898 ) -> Task<Result<acp::PromptResponse>> {
899 let server_store = self.context_server_registry.read(cx).server_store().clone();
900 let path_style = self.project.read(cx).path_style(cx);
901
902 cx.spawn(async move |this, cx| {
903 let prompt =
904 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
905
906 let (acp_thread, thread) = this.update(cx, |this, _cx| {
907 let session = this
908 .sessions
909 .get(&session_id)
910 .context("Failed to get session")?;
911 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
912 })??;
913
914 let mut last_is_user = true;
915
916 thread.update(cx, |thread, cx| {
917 thread.push_acp_user_block(
918 message_id,
919 original_content.into_iter().skip(1),
920 path_style,
921 cx,
922 );
923 });
924
925 for message in prompt.messages {
926 let context_server::types::PromptMessage { role, content } = message;
927 let block = mcp_message_content_to_acp_content_block(content);
928
929 match role {
930 context_server::types::Role::User => {
931 let id = acp_thread::UserMessageId::new();
932
933 acp_thread.update(cx, |acp_thread, cx| {
934 acp_thread.push_user_content_block_with_indent(
935 Some(id.clone()),
936 block.clone(),
937 true,
938 cx,
939 );
940 });
941
942 thread.update(cx, |thread, cx| {
943 thread.push_acp_user_block(id, [block], path_style, cx);
944 });
945 }
946 context_server::types::Role::Assistant => {
947 acp_thread.update(cx, |acp_thread, cx| {
948 acp_thread.push_assistant_content_block_with_indent(
949 block.clone(),
950 false,
951 true,
952 cx,
953 );
954 });
955
956 thread.update(cx, |thread, cx| {
957 thread.push_acp_agent_block(block, cx);
958 });
959 }
960 }
961
962 last_is_user = role == context_server::types::Role::User;
963 }
964
965 let response_stream = thread.update(cx, |thread, cx| {
966 if last_is_user {
967 thread.send_existing(cx)
968 } else {
969 // Resume if MCP prompt did not end with a user message
970 thread.resume(cx)
971 }
972 })?;
973
974 cx.update(|cx| {
975 NativeAgentConnection::handle_thread_events(
976 response_stream,
977 acp_thread.downgrade(),
978 cx,
979 )
980 })
981 .await
982 })
983 }
984}
985
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds only a handle to the [`NativeAgent`] entity, so cloning it clones the handle rather
/// than the agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
989
990impl NativeAgentConnection {
991 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
992 self.0
993 .read(cx)
994 .sessions
995 .get(session_id)
996 .map(|session| session.thread.clone())
997 }
998
999 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
1000 self.0.update(cx, |this, cx| this.load_thread(id, cx))
1001 }
1002
1003 fn run_turn(
1004 &self,
1005 session_id: acp::SessionId,
1006 cx: &mut App,
1007 f: impl 'static
1008 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1009 ) -> Task<Result<acp::PromptResponse>> {
1010 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1011 agent
1012 .sessions
1013 .get_mut(&session_id)
1014 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1015 }) else {
1016 return Task::ready(Err(anyhow!("Session not found")));
1017 };
1018 log::debug!("Found session for: {}", session_id);
1019
1020 let response_stream = match f(thread, cx) {
1021 Ok(stream) => stream,
1022 Err(err) => return Task::ready(Err(err)),
1023 };
1024 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1025 }
1026
1027 fn handle_thread_events(
1028 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1029 acp_thread: WeakEntity<AcpThread>,
1030 cx: &App,
1031 ) -> Task<Result<acp::PromptResponse>> {
1032 cx.spawn(async move |cx| {
1033 // Handle response stream and forward to session.acp_thread
1034 while let Some(result) = events.next().await {
1035 match result {
1036 Ok(event) => {
1037 log::trace!("Received completion event: {:?}", event);
1038
1039 match event {
1040 ThreadEvent::UserMessage(message) => {
1041 acp_thread.update(cx, |thread, cx| {
1042 for content in message.content {
1043 thread.push_user_content_block(
1044 Some(message.id.clone()),
1045 content.into(),
1046 cx,
1047 );
1048 }
1049 })?;
1050 }
1051 ThreadEvent::AgentText(text) => {
1052 acp_thread.update(cx, |thread, cx| {
1053 thread.push_assistant_content_block(text.into(), false, cx)
1054 })?;
1055 }
1056 ThreadEvent::AgentThinking(text) => {
1057 acp_thread.update(cx, |thread, cx| {
1058 thread.push_assistant_content_block(text.into(), true, cx)
1059 })?;
1060 }
1061 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1062 tool_call,
1063 options,
1064 response,
1065 context: _,
1066 }) => {
1067 let outcome_task = acp_thread.update(cx, |thread, cx| {
1068 thread.request_tool_call_authorization(tool_call, options, cx)
1069 })??;
1070 cx.background_spawn(async move {
1071 if let acp::RequestPermissionOutcome::Selected(
1072 acp::SelectedPermissionOutcome { option_id, .. },
1073 ) = outcome_task.await
1074 {
1075 response
1076 .send(option_id)
1077 .map(|_| anyhow!("authorization receiver was dropped"))
1078 .log_err();
1079 }
1080 })
1081 .detach();
1082 }
1083 ThreadEvent::ToolCall(tool_call) => {
1084 acp_thread.update(cx, |thread, cx| {
1085 thread.upsert_tool_call(tool_call, cx)
1086 })??;
1087 }
1088 ThreadEvent::ToolCallUpdate(update) => {
1089 acp_thread.update(cx, |thread, cx| {
1090 thread.update_tool_call(update, cx)
1091 })??;
1092 }
1093 ThreadEvent::SubagentSpawned(session_id) => {
1094 acp_thread.update(cx, |thread, cx| {
1095 thread.subagent_spawned(session_id, cx);
1096 })?;
1097 }
1098 ThreadEvent::Retry(status) => {
1099 acp_thread.update(cx, |thread, cx| {
1100 thread.update_retry_status(status, cx)
1101 })?;
1102 }
1103 ThreadEvent::Stop(stop_reason) => {
1104 log::debug!("Assistant message complete: {:?}", stop_reason);
1105 return Ok(acp::PromptResponse::new(stop_reason));
1106 }
1107 }
1108 }
1109 Err(e) => {
1110 log::error!("Error in model response stream: {:?}", e);
1111 return Err(e);
1112 }
1113 }
1114 }
1115
1116 log::debug!("Response stream completed");
1117 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1118 })
1119 }
1120}
1121
1122struct Command<'a> {
1123 prompt_name: &'a str,
1124 arg_value: &'a str,
1125 explicit_server_id: Option<&'a str>,
1126}
1127
1128impl<'a> Command<'a> {
1129 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1130 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1131 return None;
1132 };
1133 let text = text_content.text.trim();
1134 let command = text.strip_prefix('/')?;
1135 let (command, arg_value) = command
1136 .split_once(char::is_whitespace)
1137 .unwrap_or((command, ""));
1138
1139 if let Some((server_id, prompt_name)) = command.split_once('.') {
1140 Some(Self {
1141 prompt_name,
1142 arg_value,
1143 explicit_server_id: Some(server_id),
1144 })
1145 } else {
1146 Some(Self {
1147 prompt_name: command,
1148 arg_value,
1149 explicit_server_id: None,
1150 })
1151 }
1152 }
1153}
1154
/// Model selector bound to a single native agent session.
struct NativeAgentModelSelector {
    // Session whose thread is read and updated by this selector.
    session_id: acp::SessionId,
    // Connection used to reach the agent's sessions, model list, and fs.
    connection: NativeAgentConnection,
}
1159
1160impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1161 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1162 log::debug!("NativeAgentConnection::list_models called");
1163 let list = self.connection.0.read(cx).models.model_list.clone();
1164 Task::ready(if list.is_empty() {
1165 Err(anyhow::anyhow!("No models available"))
1166 } else {
1167 Ok(list)
1168 })
1169 }
1170
1171 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1172 log::debug!(
1173 "Setting model for session {}: {}",
1174 self.session_id,
1175 model_id
1176 );
1177 let Some(thread) = self
1178 .connection
1179 .0
1180 .read(cx)
1181 .sessions
1182 .get(&self.session_id)
1183 .map(|session| session.thread.clone())
1184 else {
1185 return Task::ready(Err(anyhow!("Session not found")));
1186 };
1187
1188 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1189 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1190 };
1191
1192 // We want to reset the effort level when switching models, as the currently-selected effort level may
1193 // not be compatible.
1194 let effort = model
1195 .default_effort_level()
1196 .map(|effort_level| effort_level.value.to_string());
1197
1198 thread.update(cx, |thread, cx| {
1199 thread.set_model(model.clone(), cx);
1200 thread.set_thinking_effort(effort.clone(), cx);
1201 thread.set_thinking_enabled(model.supports_thinking(), cx);
1202 });
1203
1204 update_settings_file(
1205 self.connection.0.read(cx).fs.clone(),
1206 cx,
1207 move |settings, cx| {
1208 let provider = model.provider_id().0.to_string();
1209 let model = model.id().0.to_string();
1210 let enable_thinking = thread.read(cx).thinking_enabled();
1211 settings
1212 .agent
1213 .get_or_insert_default()
1214 .set_model(LanguageModelSelection {
1215 provider: provider.into(),
1216 model,
1217 enable_thinking,
1218 effort,
1219 });
1220 },
1221 );
1222
1223 Task::ready(Ok(()))
1224 }
1225
1226 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1227 let Some(thread) = self
1228 .connection
1229 .0
1230 .read(cx)
1231 .sessions
1232 .get(&self.session_id)
1233 .map(|session| session.thread.clone())
1234 else {
1235 return Task::ready(Err(anyhow!("Session not found")));
1236 };
1237 let Some(model) = thread.read(cx).model() else {
1238 return Task::ready(Err(anyhow!("Model not found")));
1239 };
1240 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1241 else {
1242 return Task::ready(Err(anyhow!("Provider not found")));
1243 };
1244 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1245 model, &provider,
1246 )))
1247 }
1248
1249 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1250 Some(self.connection.0.read(cx).models.watch())
1251 }
1252
1253 fn should_render_footer(&self) -> bool {
1254 true
1255 }
1256}
1257
1258impl acp_thread::AgentConnection for NativeAgentConnection {
1259 fn telemetry_id(&self) -> SharedString {
1260 "zed".into()
1261 }
1262
1263 fn new_session(
1264 self: Rc<Self>,
1265 project: Entity<Project>,
1266 cwd: &Path,
1267 cx: &mut App,
1268 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1269 log::debug!("Creating new thread for project at: {cwd:?}");
1270 Task::ready(Ok(self
1271 .0
1272 .update(cx, |agent, cx| agent.new_session(project, cx))))
1273 }
1274
1275 fn supports_load_session(&self) -> bool {
1276 true
1277 }
1278
1279 fn load_session(
1280 self: Rc<Self>,
1281 session_id: acp::SessionId,
1282 _project: Entity<Project>,
1283 _cwd: &Path,
1284 _title: Option<SharedString>,
1285 cx: &mut App,
1286 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1287 self.0
1288 .update(cx, |agent, cx| agent.open_thread(session_id, cx))
1289 }
1290
1291 fn supports_close_session(&self) -> bool {
1292 true
1293 }
1294
1295 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1296 self.0.update(cx, |agent, _cx| {
1297 agent.sessions.remove(session_id);
1298 });
1299 Task::ready(Ok(()))
1300 }
1301
1302 fn auth_methods(&self) -> &[acp::AuthMethod] {
1303 &[] // No auth for in-process
1304 }
1305
1306 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1307 Task::ready(Ok(()))
1308 }
1309
1310 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1311 Some(Rc::new(NativeAgentModelSelector {
1312 session_id: session_id.clone(),
1313 connection: self.clone(),
1314 }) as Rc<dyn AgentModelSelector>)
1315 }
1316
1317 fn prompt(
1318 &self,
1319 id: Option<acp_thread::UserMessageId>,
1320 params: acp::PromptRequest,
1321 cx: &mut App,
1322 ) -> Task<Result<acp::PromptResponse>> {
1323 let id = id.expect("UserMessageId is required");
1324 let session_id = params.session_id.clone();
1325 log::info!("Received prompt request for session: {}", session_id);
1326 log::debug!("Prompt blocks count: {}", params.prompt.len());
1327
1328 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1329 let registry = self.0.read(cx).context_server_registry.read(cx);
1330
1331 let explicit_server_id = parsed_command
1332 .explicit_server_id
1333 .map(|server_id| ContextServerId(server_id.into()));
1334
1335 if let Some(prompt) =
1336 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1337 {
1338 let arguments = if !parsed_command.arg_value.is_empty()
1339 && let Some(arg_name) = prompt
1340 .prompt
1341 .arguments
1342 .as_ref()
1343 .and_then(|args| args.first())
1344 .map(|arg| arg.name.clone())
1345 {
1346 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1347 } else {
1348 Default::default()
1349 };
1350
1351 let prompt_name = prompt.prompt.name.clone();
1352 let server_id = prompt.server_id.clone();
1353
1354 return self.0.update(cx, |agent, cx| {
1355 agent.send_mcp_prompt(
1356 id,
1357 session_id.clone(),
1358 prompt_name,
1359 server_id,
1360 arguments,
1361 params.prompt,
1362 cx,
1363 )
1364 });
1365 };
1366 };
1367
1368 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1369
1370 self.run_turn(session_id, cx, move |thread, cx| {
1371 let content: Vec<UserMessageContent> = params
1372 .prompt
1373 .into_iter()
1374 .map(|block| UserMessageContent::from_content_block(block, path_style))
1375 .collect::<Vec<_>>();
1376 log::debug!("Converted prompt to message: {} chars", content.len());
1377 log::debug!("Message id: {:?}", id);
1378 log::debug!("Message content: {:?}", content);
1379
1380 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1381 })
1382 }
1383
1384 fn retry(
1385 &self,
1386 session_id: &acp::SessionId,
1387 _cx: &App,
1388 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1389 Some(Rc::new(NativeAgentSessionRetry {
1390 connection: self.clone(),
1391 session_id: session_id.clone(),
1392 }) as _)
1393 }
1394
1395 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1396 log::info!("Cancelling on session: {}", session_id);
1397 self.0.update(cx, |agent, cx| {
1398 if let Some(session) = agent.sessions.get(session_id) {
1399 session
1400 .thread
1401 .update(cx, |thread, cx| thread.cancel(cx))
1402 .detach();
1403 }
1404 });
1405 }
1406
1407 fn truncate(
1408 &self,
1409 session_id: &agent_client_protocol::SessionId,
1410 cx: &App,
1411 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1412 self.0.read_with(cx, |agent, _cx| {
1413 agent.sessions.get(session_id).map(|session| {
1414 Rc::new(NativeAgentSessionTruncate {
1415 thread: session.thread.clone(),
1416 acp_thread: session.acp_thread.downgrade(),
1417 }) as _
1418 })
1419 })
1420 }
1421
1422 fn set_title(
1423 &self,
1424 session_id: &acp::SessionId,
1425 cx: &App,
1426 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1427 self.0.read_with(cx, |agent, _cx| {
1428 agent
1429 .sessions
1430 .get(session_id)
1431 .filter(|s| !s.thread.read(cx).is_subagent())
1432 .map(|session| {
1433 Rc::new(NativeAgentSessionSetTitle {
1434 thread: session.thread.clone(),
1435 }) as _
1436 })
1437 })
1438 }
1439
1440 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1441 let thread_store = self.0.read(cx).thread_store.clone();
1442 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1443 }
1444
1445 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1446 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1447 }
1448
1449 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1450 self
1451 }
1452}
1453
1454impl acp_thread::AgentTelemetry for NativeAgentConnection {
1455 fn thread_data(
1456 &self,
1457 session_id: &acp::SessionId,
1458 cx: &mut App,
1459 ) -> Task<Result<serde_json::Value>> {
1460 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1461 return Task::ready(Err(anyhow!("Session not found")));
1462 };
1463
1464 let task = session.thread.read(cx).to_db(cx);
1465 cx.background_spawn(async move {
1466 serde_json::to_value(task.await).context("Failed to serialize thread")
1467 })
1468 }
1469}
1470
/// Session list backed by the agent's [`ThreadStore`].
pub struct NativeAgentSessionList {
    thread_store: Entity<ThreadStore>,
    // Both ends of the refresh channel: `updates_tx` is used by `notify_refresh`,
    // and `updates_rx` is cloned out to watchers by `watch`.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    // Keeps the thread-store observation (which sends refreshes) alive.
    _subscription: Subscription,
}
1477
1478impl NativeAgentSessionList {
1479 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1480 let (tx, rx) = smol::channel::unbounded();
1481 let this_tx = tx.clone();
1482 let subscription = cx.observe(&thread_store, move |_, _| {
1483 this_tx
1484 .try_send(acp_thread::SessionListUpdate::Refresh)
1485 .ok();
1486 });
1487 Self {
1488 thread_store,
1489 updates_tx: tx,
1490 updates_rx: rx,
1491 _subscription: subscription,
1492 }
1493 }
1494
1495 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1496 &self.thread_store
1497 }
1498}
1499
1500impl AgentSessionList for NativeAgentSessionList {
1501 fn list_sessions(
1502 &self,
1503 _request: AgentSessionListRequest,
1504 cx: &mut App,
1505 ) -> Task<Result<AgentSessionListResponse>> {
1506 let sessions = self
1507 .thread_store
1508 .read(cx)
1509 .entries()
1510 .map(|entry| AgentSessionInfo::from(&entry))
1511 .collect();
1512 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1513 }
1514
1515 fn supports_delete(&self) -> bool {
1516 true
1517 }
1518
1519 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1520 self.thread_store
1521 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1522 }
1523
1524 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1525 self.thread_store
1526 .update(cx, |store, cx| store.delete_threads(cx))
1527 }
1528
1529 fn watch(
1530 &self,
1531 _cx: &mut App,
1532 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1533 Some(self.updates_rx.clone())
1534 }
1535
1536 fn notify_refresh(&self) {
1537 self.updates_tx
1538 .try_send(acp_thread::SessionListUpdate::Refresh)
1539 .ok();
1540 }
1541
1542 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1543 self
1544 }
1545}
1546
/// Truncation handle for a native session.
struct NativeAgentSessionTruncate {
    thread: Entity<Thread>,
    // Weak: updated token usage is pushed only if the ACP thread is still alive.
    acp_thread: WeakEntity<AcpThread>,
}
1551
1552impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1553 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1554 match self.thread.update(cx, |thread, cx| {
1555 thread.truncate(message_id.clone(), cx)?;
1556 Ok(thread.latest_token_usage())
1557 }) {
1558 Ok(usage) => {
1559 self.acp_thread
1560 .update(cx, |thread, cx| {
1561 thread.update_token_usage(usage, cx);
1562 })
1563 .ok();
1564 Task::ready(Ok(()))
1565 }
1566 Err(error) => Task::ready(Err(error)),
1567 }
1568 }
1569}
1570
/// Retry handle for a native session.
struct NativeAgentSessionRetry {
    // Connection used to run the retried turn.
    connection: NativeAgentConnection,
    session_id: acp::SessionId,
}
1575
1576impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1577 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1578 self.connection
1579 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1580 thread.update(cx, |thread, cx| thread.resume(cx))
1581 })
1582 }
1583}
1584
/// Title setter for a native (non-subagent) session.
struct NativeAgentSessionSetTitle {
    thread: Entity<Thread>,
}
1588
1589impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1590 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1591 self.thread
1592 .update(cx, |thread, cx| thread.set_title(title, cx));
1593 Task::ready(Ok(()))
1594 }
1595}
1596
/// Environment through which a native thread creates terminals and subagents.
pub struct NativeThreadEnvironment {
    agent: WeakEntity<NativeAgent>,
    // Parent thread for any subagents created through this environment.
    thread: WeakEntity<Thread>,
    // ACP thread that terminals are created on and released from.
    acp_thread: WeakEntity<AcpThread>,
}
1602
1603impl NativeThreadEnvironment {
1604 pub(crate) fn create_subagent_thread(
1605 &self,
1606 label: String,
1607 cx: &mut App,
1608 ) -> Result<Rc<dyn SubagentHandle>> {
1609 let Some(parent_thread_entity) = self.thread.upgrade() else {
1610 anyhow::bail!("Parent thread no longer exists".to_string());
1611 };
1612 let parent_thread = parent_thread_entity.read(cx);
1613 let current_depth = parent_thread.depth();
1614
1615 if current_depth >= MAX_SUBAGENT_DEPTH {
1616 return Err(anyhow!(
1617 "Maximum subagent depth ({}) reached",
1618 MAX_SUBAGENT_DEPTH
1619 ));
1620 }
1621
1622 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1623 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1624 thread.set_title(label.into(), cx);
1625 thread
1626 });
1627
1628 let session_id = subagent_thread.read(cx).id().clone();
1629
1630 let acp_thread = self.agent.update(cx, |agent, cx| {
1631 agent.register_session(subagent_thread.clone(), cx)
1632 })?;
1633
1634 let depth = current_depth + 1;
1635
1636 telemetry::event!(
1637 "Subagent Started",
1638 session = parent_thread_entity.read(cx).id().to_string(),
1639 subagent_session = session_id.to_string(),
1640 depth,
1641 is_resumed = false,
1642 );
1643
1644 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1645 }
1646
1647 pub(crate) fn resume_subagent_thread(
1648 &self,
1649 session_id: acp::SessionId,
1650 cx: &mut App,
1651 ) -> Result<Rc<dyn SubagentHandle>> {
1652 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1653 let session = agent
1654 .sessions
1655 .get(&session_id)
1656 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1657 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1658 })??;
1659
1660 let depth = subagent_thread.read(cx).depth();
1661
1662 if let Some(parent_thread_entity) = self.thread.upgrade() {
1663 telemetry::event!(
1664 "Subagent Started",
1665 session = parent_thread_entity.read(cx).id().to_string(),
1666 subagent_session = session_id.to_string(),
1667 depth,
1668 is_resumed = true,
1669 );
1670 }
1671
1672 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1673 }
1674
1675 fn prompt_subagent(
1676 &self,
1677 session_id: acp::SessionId,
1678 subagent_thread: Entity<Thread>,
1679 acp_thread: Entity<acp_thread::AcpThread>,
1680 ) -> Result<Rc<dyn SubagentHandle>> {
1681 let Some(parent_thread_entity) = self.thread.upgrade() else {
1682 anyhow::bail!("Parent thread no longer exists".to_string());
1683 };
1684 Ok(Rc::new(NativeSubagentHandle::new(
1685 session_id,
1686 subagent_thread,
1687 acp_thread,
1688 parent_thread_entity,
1689 )) as _)
1690 }
1691}
1692
1693impl ThreadEnvironment for NativeThreadEnvironment {
1694 fn create_terminal(
1695 &self,
1696 command: String,
1697 cwd: Option<PathBuf>,
1698 output_byte_limit: Option<u64>,
1699 cx: &mut AsyncApp,
1700 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1701 let task = self.acp_thread.update(cx, |thread, cx| {
1702 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1703 });
1704
1705 let acp_thread = self.acp_thread.clone();
1706 cx.spawn(async move |cx| {
1707 let terminal = task?.await?;
1708
1709 let (drop_tx, drop_rx) = oneshot::channel();
1710 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1711
1712 cx.spawn(async move |cx| {
1713 drop_rx.await.ok();
1714 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1715 })
1716 .detach();
1717
1718 let handle = AcpTerminalHandle {
1719 terminal,
1720 _drop_tx: Some(drop_tx),
1721 };
1722
1723 Ok(Rc::new(handle) as _)
1724 })
1725 }
1726
1727 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1728 self.create_subagent_thread(label, cx)
1729 }
1730
1731 fn resume_subagent(
1732 &self,
1733 session_id: acp::SessionId,
1734 cx: &mut App,
1735 ) -> Result<Rc<dyn SubagentHandle>> {
1736 self.resume_subagent_thread(session_id, cx)
1737 }
1738}
1739
/// Outcome of waiting on a single subagent prompt.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended with a stop reason not treated as an error (e.g. `EndTurn`).
    Completed,
    /// The turn's stop reason was `Cancelled`.
    Cancelled,
    /// Token usage moved from `Normal` to `Warning` before the turn finished.
    ContextWindowWarning,
    /// The turn failed; carries the message reported to the caller.
    Error(String),
}
1747
/// Handle through which a parent thread prompts a subagent session.
pub struct NativeSubagentHandle {
    session_id: acp::SessionId,
    // Weak reference: the handle does not keep the parent thread alive.
    parent_thread: WeakEntity<Thread>,
    subagent_thread: Entity<Thread>,
    acp_thread: Entity<acp_thread::AcpThread>,
}
1754
1755impl NativeSubagentHandle {
1756 fn new(
1757 session_id: acp::SessionId,
1758 subagent_thread: Entity<Thread>,
1759 acp_thread: Entity<acp_thread::AcpThread>,
1760 parent_thread_entity: Entity<Thread>,
1761 ) -> Self {
1762 NativeSubagentHandle {
1763 session_id,
1764 subagent_thread,
1765 parent_thread: parent_thread_entity.downgrade(),
1766 acp_thread,
1767 }
1768 }
1769}
1770
impl SubagentHandle for NativeSubagentHandle {
    /// Session id of the subagent thread.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Number of entries currently in the subagent's ACP thread.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and waits for the turn to finish.
    ///
    /// On normal completion, returns the text parts of the subagent's last agent
    /// message, joined by blank lines. Returns an error when:
    /// - the turn is cancelled or fails;
    /// - the last message has no text;
    /// - token usage rises from `Normal` to `Warning` during the turn. In this
    ///   case the subagent is cancelled before the error is returned.
    ///
    /// The subagent is registered as running on the parent thread while this
    /// runs and unregistered afterwards. Both steps are skipped if the parent
    /// has been dropped.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            let (task, _subscription) = cx.update(|cx| {
                // Baseline ratio before prompting, so only a Normal -> Warning
                // transition during this prompt triggers the warning.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                // Fires at most once (the sender is taken on first use).
                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the context-window warning; the first to
                // complete decides the result.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                // Return the subscription too, so it stays alive (bound to
                // `_subscription`) until `task` has resolved.
                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent before reporting the warning to the caller.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // Unregister on every path, after the result has been computed.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
1891
/// [`TerminalHandle`] backed by a terminal owned by an ACP thread.
pub struct AcpTerminalHandle {
    terminal: Entity<acp_thread::Terminal>,
    // Never sent on; it is dropped with the handle. That wakes the release task
    // spawned in `create_terminal`, which releases the terminal.
    _drop_tx: Option<oneshot::Sender<()>>,
}
1896
1897impl TerminalHandle for AcpTerminalHandle {
1898 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1899 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1900 }
1901
1902 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1903 Ok(self
1904 .terminal
1905 .read_with(cx, |term, _cx| term.wait_for_exit()))
1906 }
1907
1908 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1909 Ok(self
1910 .terminal
1911 .read_with(cx, |term, cx| term.current_output(cx)))
1912 }
1913
1914 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1915 cx.update(|cx| {
1916 self.terminal.update(cx, |terminal, cx| {
1917 terminal.kill(cx);
1918 });
1919 });
1920 Ok(())
1921 }
1922
1923 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1924 Ok(self
1925 .terminal
1926 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1927 }
1928}
1929
1930#[cfg(test)]
1931mod internal_tests {
1932 use super::*;
1933 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1934 use fs::FakeFs;
1935 use gpui::TestAppContext;
1936 use indoc::formatdoc;
1937 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1938 use language_model::{
1939 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
1940 };
1941 use serde_json::json;
1942 use settings::SettingsStore;
1943 use util::{path, rel_path::rel_path};
1944
1945 #[gpui::test]
1946 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1947 init_test(cx);
1948 let fs = FakeFs::new(cx.executor());
1949 fs.insert_tree(
1950 "/",
1951 json!({
1952 "a": {}
1953 }),
1954 )
1955 .await;
1956 let project = Project::test(fs.clone(), [], cx).await;
1957 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1958 let agent = NativeAgent::new(
1959 project.clone(),
1960 thread_store,
1961 Templates::new(),
1962 None,
1963 fs.clone(),
1964 &mut cx.to_async(),
1965 )
1966 .await
1967 .unwrap();
1968 agent.read_with(cx, |agent, cx| {
1969 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1970 });
1971
1972 let worktree = project
1973 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1974 .await
1975 .unwrap();
1976 cx.run_until_parked();
1977 agent.read_with(cx, |agent, cx| {
1978 assert_eq!(
1979 agent.project_context.read(cx).worktrees,
1980 vec![WorktreeContext {
1981 root_name: "a".into(),
1982 abs_path: Path::new("/a").into(),
1983 rules_file: None
1984 }]
1985 )
1986 });
1987
1988 // Creating `/a/.rules` updates the project context.
1989 fs.insert_file("/a/.rules", Vec::new()).await;
1990 cx.run_until_parked();
1991 agent.read_with(cx, |agent, cx| {
1992 let rules_entry = worktree
1993 .read(cx)
1994 .entry_for_path(rel_path(".rules"))
1995 .unwrap();
1996 assert_eq!(
1997 agent.project_context.read(cx).worktrees,
1998 vec![WorktreeContext {
1999 root_name: "a".into(),
2000 abs_path: Path::new("/a").into(),
2001 rules_file: Some(RulesFileContext {
2002 path_in_worktree: rel_path(".rules").into(),
2003 text: "".into(),
2004 project_entry_id: rules_entry.id.to_usize()
2005 })
2006 }]
2007 )
2008 });
2009 }
2010
2011 #[gpui::test]
2012 async fn test_listing_models(cx: &mut TestAppContext) {
2013 init_test(cx);
2014 let fs = FakeFs::new(cx.executor());
2015 fs.insert_tree("/", json!({ "a": {} })).await;
2016 let project = Project::test(fs.clone(), [], cx).await;
2017 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2018 let connection = NativeAgentConnection(
2019 NativeAgent::new(
2020 project.clone(),
2021 thread_store,
2022 Templates::new(),
2023 None,
2024 fs.clone(),
2025 &mut cx.to_async(),
2026 )
2027 .await
2028 .unwrap(),
2029 );
2030
2031 // Create a thread/session
2032 let acp_thread = cx
2033 .update(|cx| {
2034 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2035 })
2036 .await
2037 .unwrap();
2038
2039 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2040
2041 let models = cx
2042 .update(|cx| {
2043 connection
2044 .model_selector(&session_id)
2045 .unwrap()
2046 .list_models(cx)
2047 })
2048 .await
2049 .unwrap();
2050
2051 let acp_thread::AgentModelList::Grouped(models) = models else {
2052 panic!("Unexpected model group");
2053 };
2054 assert_eq!(
2055 models,
2056 IndexMap::from_iter([(
2057 AgentModelGroupName("Fake".into()),
2058 vec![AgentModelInfo {
2059 id: acp::ModelId::new("fake/fake"),
2060 name: "Fake".into(),
2061 description: None,
2062 icon: Some(acp_thread::AgentModelIcon::Named(
2063 ui::IconName::ZedAssistant
2064 )),
2065 is_latest: false,
2066 cost: None,
2067 }]
2068 )])
2069 );
2070 }
2071
2072 #[gpui::test]
2073 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2074 init_test(cx);
2075 let fs = FakeFs::new(cx.executor());
2076 fs.create_dir(paths::settings_file().parent().unwrap())
2077 .await
2078 .unwrap();
2079 fs.insert_file(
2080 paths::settings_file(),
2081 json!({
2082 "agent": {
2083 "default_model": {
2084 "provider": "foo",
2085 "model": "bar"
2086 }
2087 }
2088 })
2089 .to_string()
2090 .into_bytes(),
2091 )
2092 .await;
2093 let project = Project::test(fs.clone(), [], cx).await;
2094
2095 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2096
2097 // Create the agent and connection
2098 let agent = NativeAgent::new(
2099 project.clone(),
2100 thread_store,
2101 Templates::new(),
2102 None,
2103 fs.clone(),
2104 &mut cx.to_async(),
2105 )
2106 .await
2107 .unwrap();
2108 let connection = NativeAgentConnection(agent.clone());
2109
2110 // Create a thread/session
2111 let acp_thread = cx
2112 .update(|cx| {
2113 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2114 })
2115 .await
2116 .unwrap();
2117
2118 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2119
2120 // Select a model
2121 let selector = connection.model_selector(&session_id).unwrap();
2122 let model_id = acp::ModelId::new("fake/fake");
2123 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2124 .await
2125 .unwrap();
2126
2127 // Verify the thread has the selected model
2128 agent.read_with(cx, |agent, _| {
2129 let session = agent.sessions.get(&session_id).unwrap();
2130 session.thread.read_with(cx, |thread, _| {
2131 assert_eq!(thread.model().unwrap().id().0, "fake");
2132 });
2133 });
2134
2135 cx.run_until_parked();
2136
2137 // Verify settings file was updated
2138 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2139 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2140
2141 // Check that the agent settings contain the selected model
2142 assert_eq!(
2143 settings_json["agent"]["default_model"]["model"],
2144 json!("fake")
2145 );
2146 assert_eq!(
2147 settings_json["agent"]["default_model"]["provider"],
2148 json!("fake")
2149 );
2150
2151 // Register a thinking model and select it.
2152 cx.update(|cx| {
2153 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2154 "fake-corp",
2155 "fake-thinking",
2156 "Fake Thinking",
2157 true,
2158 ));
2159 let thinking_provider = Arc::new(
2160 FakeLanguageModelProvider::new(
2161 LanguageModelProviderId::from("fake-corp".to_string()),
2162 LanguageModelProviderName::from("Fake Corp".to_string()),
2163 )
2164 .with_models(vec![thinking_model]),
2165 );
2166 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2167 registry.register_provider(thinking_provider, cx);
2168 });
2169 });
2170 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2171
2172 let selector = connection.model_selector(&session_id).unwrap();
2173 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2174 .await
2175 .unwrap();
2176 cx.run_until_parked();
2177
2178 // Verify enable_thinking was written to settings as true.
2179 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2180 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2181 assert_eq!(
2182 settings_json["agent"]["default_model"]["enable_thinking"],
2183 json!(true),
2184 "selecting a thinking model should persist enable_thinking: true to settings"
2185 );
2186 }
2187
2188 #[gpui::test]
2189 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2190 init_test(cx);
2191 let fs = FakeFs::new(cx.executor());
2192 fs.create_dir(paths::settings_file().parent().unwrap())
2193 .await
2194 .unwrap();
2195 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2196 let project = Project::test(fs.clone(), [], cx).await;
2197
2198 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2199 let agent = NativeAgent::new(
2200 project.clone(),
2201 thread_store,
2202 Templates::new(),
2203 None,
2204 fs.clone(),
2205 &mut cx.to_async(),
2206 )
2207 .await
2208 .unwrap();
2209 let connection = NativeAgentConnection(agent.clone());
2210
2211 let acp_thread = cx
2212 .update(|cx| {
2213 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2214 })
2215 .await
2216 .unwrap();
2217 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2218
2219 // Register a second provider with a thinking model.
2220 cx.update(|cx| {
2221 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2222 "fake-corp",
2223 "fake-thinking",
2224 "Fake Thinking",
2225 true,
2226 ));
2227 let thinking_provider = Arc::new(
2228 FakeLanguageModelProvider::new(
2229 LanguageModelProviderId::from("fake-corp".to_string()),
2230 LanguageModelProviderName::from("Fake Corp".to_string()),
2231 )
2232 .with_models(vec![thinking_model]),
2233 );
2234 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2235 registry.register_provider(thinking_provider, cx);
2236 });
2237 });
2238 // Refresh the agent's model list so it picks up the new provider.
2239 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2240
2241 // Thread starts with thinking_enabled = false (the default).
2242 agent.read_with(cx, |agent, _| {
2243 let session = agent.sessions.get(&session_id).unwrap();
2244 session.thread.read_with(cx, |thread, _| {
2245 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2246 });
2247 });
2248
2249 // Select the thinking model via select_model.
2250 let selector = connection.model_selector(&session_id).unwrap();
2251 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2252 .await
2253 .unwrap();
2254
2255 // select_model should have enabled thinking based on the model's supports_thinking().
2256 agent.read_with(cx, |agent, _| {
2257 let session = agent.sessions.get(&session_id).unwrap();
2258 session.thread.read_with(cx, |thread, _| {
2259 assert!(
2260 thread.thinking_enabled(),
2261 "select_model should enable thinking when model supports it"
2262 );
2263 });
2264 });
2265
2266 // Switch back to the non-thinking model.
2267 let selector = connection.model_selector(&session_id).unwrap();
2268 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2269 .await
2270 .unwrap();
2271
2272 // select_model should have disabled thinking.
2273 agent.read_with(cx, |agent, _| {
2274 let session = agent.sessions.get(&session_id).unwrap();
2275 session.thread.read_with(cx, |thread, _| {
2276 assert!(
2277 !thread.thinking_enabled(),
2278 "select_model should disable thinking when model does not support it"
2279 );
2280 });
2281 });
2282 }
2283
    // Round-trip test for the thinking flag.
    // 1. A thread has thinking enabled by selecting a thinking-capable model.
    // 2. The thread is persisted by sending one message, then its session is closed.
    // 3. The thread is reopened through `NativeAgent::open_thread`.
    // The reopened thread must still report `thinking_enabled()`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model. A handle to it is kept so the test can
        // stream the completion response itself further below.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Make the newly registered provider visible in the agent's model list.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned rather than awaited here because the test
        // must feed the fake model's response before the send can complete.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles to the old thread entities before reopening.
        // NOTE(review): presumably this lets them be dropped, so the reload
        // cannot observe the old in-memory thread. Confirm.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The named binding keeps the reloaded AcpThread alive until this point.
        // NOTE(review): presumably the session entry looked up above depends on it.
        drop(reloaded_acp_thread);
    }
2389
    // Round-trip test for the selected model. It uses a model whose `id()` differs
    // from its display name. After the thread is persisted, closed, and reopened
    // via `NativeAgent::open_thread`, the thread must resolve the same model id.
    // It must not fall back to a default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Make the newly registered provider visible in the agent's model list.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Models are addressed as "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned rather than awaited here because the test
        // must feed the fake model's response before the send can complete.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles to the old thread entities before reopening.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The named binding keeps the reloaded AcpThread alive until this point.
        drop(reloaded_acp_thread);
    }
2500
    // End-to-end persistence test for a native agent thread.
    // - A thread with no messages is not written to the thread store, even after
    //   its model and summarization model are changed.
    // - After one user/assistant exchange, the store lists the thread under the
    //   title produced by the summarization model.
    // - After closing the session and calling `open_thread`, the reloaded thread
    //   restores the markdown transcript, the draft prompt (including
    //   resource-link blocks), token usage, and the UI scroll position.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Both fake models are kept as handles so the test can stream their responses.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // User message: text, then a file mention (ResourceLink to /a/b.md), then text.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawned rather than awaited so the test can feed the fake models below
        // while the send is still in flight.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Assistant reply plus a usage update. The token counts are checked
        // again after the reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summarization model's output becomes the thread title that
        // `thread_entries` is expected to report below.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // Neither setter above receives a context, so notify observers explicitly.
        // NOTE(review): presumably this is what schedules a save that includes the
        // draft prompt and scroll position. Confirm.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2687
2688 fn thread_entries(
2689 thread_store: &Entity<ThreadStore>,
2690 cx: &mut TestAppContext,
2691 ) -> Vec<(acp::SessionId, String)> {
2692 thread_store.read_with(cx, |store, _| {
2693 store
2694 .entries()
2695 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2696 .collect::<Vec<_>>()
2697 })
2698 }
2699
2700 fn init_test(cx: &mut TestAppContext) {
2701 env_logger::try_init().ok();
2702 cx.update(|cx| {
2703 let settings_store = SettingsStore::test(cx);
2704 cx.set_global(settings_store);
2705
2706 LanguageModelRegistry::test(cx);
2707 });
2708 }
2709}
2710
2711fn mcp_message_content_to_acp_content_block(
2712 content: context_server::types::MessageContent,
2713) -> acp::ContentBlock {
2714 match content {
2715 context_server::types::MessageContent::Text {
2716 text,
2717 annotations: _,
2718 } => text.into(),
2719 context_server::types::MessageContent::Image {
2720 data,
2721 mime_type,
2722 annotations: _,
2723 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2724 context_server::types::MessageContent::Audio {
2725 data,
2726 mime_type,
2727 annotations: _,
2728 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2729 context_server::types::MessageContent::Resource {
2730 resource,
2731 annotations: _,
2732 } => {
2733 let mut link =
2734 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2735 if let Some(mime_type) = resource.mime_type {
2736 link = link.mime_type(mime_type);
2737 }
2738 acp::ContentBlock::ResourceLink(link)
2739 }
2740 }
2741}