1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
41};
42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
43use project::{Project, ProjectItem, ProjectPath, Worktree};
44use prompt_store::{
45 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
46 WorktreeContext,
47};
48use serde::{Deserialize, Serialize};
49use settings::{LanguageModelSelection, update_settings_file};
50use std::any::Any;
51use std::path::{Path, PathBuf};
52use std::rc::Rc;
53use std::sync::Arc;
54use util::ResultExt;
55use util::path_list::PathList;
56use util::rel_path::RelPath;
57
/// Snapshot of the project's worktrees, in their telemetry representation,
/// together with an associated UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot (presumably when it was taken).
    pub timestamp: DateTime<Utc>,
}
63
/// Error describing why a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
67
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Most recent background save of `thread`; each new save replaces it.
    pending_save: Task<()>,
    /// Subscriptions on `thread` (title, token usage, save-on-change) kept for
    /// as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
77
/// Cache of the language models offered by visible, authenticated providers,
/// keyed by agent model ID, plus the grouped list that is shown to clients.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; signalled each time the list is rebuilt.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used by `refresh_list` to notify watchers.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider, held for the
    /// lifetime of this cache.
    _authenticate_all_providers_task: Task<()>,
}
87
88impl LanguageModels {
89 fn new(cx: &mut App) -> Self {
90 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
91
92 let mut this = Self {
93 models: HashMap::default(),
94 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
95 refresh_models_rx,
96 refresh_models_tx,
97 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
98 };
99 this.refresh_list(cx);
100 this
101 }
102
103 fn refresh_list(&mut self, cx: &App) {
104 let providers = LanguageModelRegistry::global(cx)
105 .read(cx)
106 .visible_providers()
107 .into_iter()
108 .filter(|provider| provider.is_authenticated(cx))
109 .collect::<Vec<_>>();
110
111 let mut language_model_list = IndexMap::default();
112 let mut recommended_models = HashSet::default();
113
114 let mut recommended = Vec::new();
115 for provider in &providers {
116 for model in provider.recommended_models(cx) {
117 recommended_models.insert((model.provider_id(), model.id()));
118 recommended.push(Self::map_language_model_to_info(&model, provider));
119 }
120 }
121 if !recommended.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName("Recommended".into()),
124 recommended,
125 );
126 }
127
128 let mut models = HashMap::default();
129 for provider in providers {
130 let mut provider_models = Vec::new();
131 for model in provider.provided_models(cx) {
132 let model_info = Self::map_language_model_to_info(&model, &provider);
133 let model_id = model_info.id.clone();
134 provider_models.push(model_info);
135 models.insert(model_id, model);
136 }
137 if !provider_models.is_empty() {
138 language_model_list.insert(
139 acp_thread::AgentModelGroupName(provider.name().0.clone()),
140 provider_models,
141 );
142 }
143 }
144
145 self.models = models;
146 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
147 self.refresh_models_tx.send(()).ok();
148 }
149
150 fn watch(&self) -> watch::Receiver<()> {
151 self.refresh_models_rx.clone()
152 }
153
154 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
155 self.models.get(model_id).cloned()
156 }
157
158 fn map_language_model_to_info(
159 model: &Arc<dyn LanguageModel>,
160 provider: &Arc<dyn LanguageModelProvider>,
161 ) -> acp_thread::AgentModelInfo {
162 acp_thread::AgentModelInfo {
163 id: Self::model_id(model),
164 name: model.name().0,
165 description: None,
166 icon: Some(match provider.icon() {
167 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
168 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
169 }),
170 is_latest: model.is_latest(),
171 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
172 }
173 }
174
175 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
176 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
177 }
178
179 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
180 let authenticate_all_providers = LanguageModelRegistry::global(cx)
181 .read(cx)
182 .visible_providers()
183 .iter()
184 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
185 .collect::<Vec<_>>();
186
187 cx.background_spawn(async move {
188 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
189 if let Err(err) = authenticate_task.await {
190 match err {
191 language_model::AuthenticateError::CredentialsNotFound => {
192 // Since we're authenticating these providers in the
193 // background for the purposes of populating the
194 // language selector, we don't care about providers
195 // where the credentials are not found.
196 }
197 language_model::AuthenticateError::ConnectionRefused => {
198 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
199 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
200 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
201 }
202 _ => {
203 // Some providers have noisy failure states that we
204 // don't want to spam the logs with every time the
205 // language model selector is initialized.
206 //
207 // Ideally these should have more clear failure modes
208 // that we know are safe to ignore here, like what we do
209 // with `CredentialsNotFound` above.
210 match provider_id.0.as_ref() {
211 "lmstudio" | "ollama" => {
212 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
213 //
214 // These fail noisily, so we don't log them.
215 }
216 "copilot_chat" => {
217 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
218 }
219 _ => {
220 log::error!(
221 "Failed to authenticate provider: {}: {err:#}",
222 provider_name.0
223 );
224 }
225 }
226 }
227 }
228 }
229 }
230 })
231 }
232}
233
/// Agent state shared by every session: the open sessions themselves plus the
/// project context, templates, model cache and context-server registry that the
/// sessions' threads are built from.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store that is reloaded after each thread save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signals `maintain_project_context` that the project context must be rebuilt.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds the project context on request.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server tools/prompts registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project whose worktrees feed the project context and saved folder paths.
    project: Entity<Project>,
    /// Source of default user rules, when available.
    prompt_store: Option<Entity<PromptStore>>,
    /// File-system handle held by the agent.
    fs: Arc<dyn Fs>,
    /// Event subscriptions created in `new`, kept alive with the agent.
    _subscriptions: Vec<Subscription>,
}
252
253impl NativeAgent {
254 pub async fn new(
255 project: Entity<Project>,
256 thread_store: Entity<ThreadStore>,
257 templates: Arc<Templates>,
258 prompt_store: Option<Entity<PromptStore>>,
259 fs: Arc<dyn Fs>,
260 cx: &mut AsyncApp,
261 ) -> Result<Entity<NativeAgent>> {
262 log::debug!("Creating new NativeAgent");
263
264 let project_context = cx
265 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
266 .await;
267
268 Ok(cx.new(|cx| {
269 let context_server_store = project.read(cx).context_server_store();
270 let context_server_registry =
271 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
272
273 let mut subscriptions = vec![
274 cx.subscribe(&project, Self::handle_project_event),
275 cx.subscribe(
276 &LanguageModelRegistry::global(cx),
277 Self::handle_models_updated_event,
278 ),
279 cx.subscribe(
280 &context_server_store,
281 Self::handle_context_server_store_updated,
282 ),
283 cx.subscribe(
284 &context_server_registry,
285 Self::handle_context_server_registry_event,
286 ),
287 ];
288 if let Some(prompt_store) = prompt_store.as_ref() {
289 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
290 }
291
292 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
293 watch::channel(());
294 Self {
295 sessions: HashMap::default(),
296 thread_store,
297 project_context: cx.new(|_| project_context),
298 project_context_needs_refresh: project_context_needs_refresh_tx,
299 _maintain_project_context: cx.spawn(async move |this, cx| {
300 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
301 }),
302 context_server_registry,
303 templates,
304 models: LanguageModels::new(cx),
305 project,
306 prompt_store,
307 fs,
308 _subscriptions: subscriptions,
309 }
310 }))
311 }
312
313 fn new_session(
314 &mut self,
315 project: Entity<Project>,
316 cx: &mut Context<Self>,
317 ) -> Entity<AcpThread> {
318 // Create Thread
319 // Fetch default model from registry settings
320 let registry = LanguageModelRegistry::read_global(cx);
321 // Log available models for debugging
322 let available_count = registry.available_models(cx).count();
323 log::debug!("Total available models: {}", available_count);
324
325 let default_model = registry.default_model().and_then(|default_model| {
326 self.models
327 .model_from_id(&LanguageModels::model_id(&default_model.model))
328 });
329 let thread = cx.new(|cx| {
330 Thread::new(
331 project.clone(),
332 self.project_context.clone(),
333 self.context_server_registry.clone(),
334 self.templates.clone(),
335 default_model,
336 cx,
337 )
338 });
339
340 self.register_session(thread, cx)
341 }
342
    /// Wraps an existing [`Thread`] in a new session and returns the session's
    /// [`AcpThread`].
    ///
    /// - The `AcpThread` is seeded from the thread's current state: parent
    ///   session, title, draft prompt, scroll position and latest token usage.
    /// - The thread receives the registry's thread-summary model and the default
    ///   tool set.
    /// - The agent subscribes to the thread so that title and token-usage changes
    ///   are forwarded, and every thread notification triggers `save_thread`.
    /// - Finally, the session is stored and the available slash commands are
    ///   re-sent to all sessions.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything the AcpThread needs before `cx` is borrowed mutably.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            // The tool environment receives weak handles to the AcpThread, the
            // thread itself and this agent.
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread whenever it notifies observers.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        // Push the current slash-command list, which also reaches the new session.
        self.update_available_commands(cx);

        acp_thread
    }
416
417 pub fn models(&self) -> &LanguageModels {
418 &self.models
419 }
420
421 async fn maintain_project_context(
422 this: WeakEntity<Self>,
423 mut needs_refresh: watch::Receiver<()>,
424 cx: &mut AsyncApp,
425 ) -> Result<()> {
426 while needs_refresh.changed().await.is_ok() {
427 let project_context = this
428 .update(cx, |this, cx| {
429 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
430 })?
431 .await;
432 this.update(cx, |this, cx| {
433 this.project_context = cx.new(|_| project_context);
434 })?;
435 }
436
437 Ok(())
438 }
439
440 fn build_project_context(
441 project: &Entity<Project>,
442 prompt_store: Option<&Entity<PromptStore>>,
443 cx: &mut App,
444 ) -> Task<ProjectContext> {
445 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
446 let worktree_tasks = worktrees
447 .into_iter()
448 .map(|worktree| {
449 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
450 })
451 .collect::<Vec<_>>();
452 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
453 prompt_store.read_with(cx, |prompt_store, cx| {
454 let prompts = prompt_store.default_prompt_metadata();
455 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
456 let contents = prompt_store.load(prompt_metadata.id, cx);
457 async move { (contents.await, prompt_metadata) }
458 });
459 cx.background_spawn(future::join_all(load_tasks))
460 })
461 } else {
462 Task::ready(vec![])
463 };
464
465 cx.spawn(async move |_cx| {
466 let (worktrees, default_user_rules) =
467 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
468
469 let worktrees = worktrees
470 .into_iter()
471 .map(|(worktree, _rules_error)| {
472 // TODO: show error message
473 // if let Some(rules_error) = rules_error {
474 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
475 // }
476 worktree
477 })
478 .collect::<Vec<_>>();
479
480 let default_user_rules = default_user_rules
481 .into_iter()
482 .flat_map(|(contents, prompt_metadata)| match contents {
483 Ok(contents) => Some(UserRulesContext {
484 uuid: prompt_metadata.id.as_user()?,
485 title: prompt_metadata.title.map(|title| title.to_string()),
486 contents,
487 }),
488 Err(_err) => {
489 // TODO: show error message
490 // this.update(cx, |_, cx| {
491 // cx.emit(RulesLoadingError {
492 // message: format!("{err:?}").into(),
493 // });
494 // })
495 // .ok();
496 None
497 }
498 })
499 .collect::<Vec<_>>();
500
501 ProjectContext::new(worktrees, default_user_rules)
502 })
503 }
504
505 fn load_worktree_info_for_system_prompt(
506 worktree: Entity<Worktree>,
507 project: Entity<Project>,
508 cx: &mut App,
509 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
510 let tree = worktree.read(cx);
511 let root_name = tree.root_name_str().into();
512 let abs_path = tree.abs_path();
513
514 let mut context = WorktreeContext {
515 root_name,
516 abs_path,
517 rules_file: None,
518 };
519
520 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
521 let Some(rules_task) = rules_task else {
522 return Task::ready((context, None));
523 };
524
525 cx.spawn(async move |_| {
526 let (rules_file, rules_file_error) = match rules_task.await {
527 Ok(rules_file) => (Some(rules_file), None),
528 Err(err) => (
529 None,
530 Some(RulesLoadingError {
531 message: format!("{err}").into(),
532 }),
533 ),
534 };
535 context.rules_file = rules_file;
536 (context, rules_file_error)
537 })
538 }
539
540 fn load_worktree_rules_file(
541 worktree: Entity<Worktree>,
542 project: Entity<Project>,
543 cx: &mut App,
544 ) -> Option<Task<Result<RulesFileContext>>> {
545 let worktree = worktree.read(cx);
546 let worktree_id = worktree.id();
547 let selected_rules_file = RULES_FILE_NAMES
548 .into_iter()
549 .filter_map(|name| {
550 worktree
551 .entry_for_path(RelPath::unix(name).unwrap())
552 .filter(|entry| entry.is_file())
553 .map(|entry| entry.path.clone())
554 })
555 .next();
556
557 // Note that Cline supports `.clinerules` being a directory, but that is not currently
558 // supported. This doesn't seem to occur often in GitHub repositories.
559 selected_rules_file.map(|path_in_worktree| {
560 let project_path = ProjectPath {
561 worktree_id,
562 path: path_in_worktree.clone(),
563 };
564 let buffer_task =
565 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
566 let rope_task = cx.spawn(async move |cx| {
567 let buffer = buffer_task.await?;
568 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
569 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
570 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
571 })?;
572 anyhow::Ok((project_entry_id, rope))
573 });
574 // Build a string from the rope on a background thread.
575 cx.background_spawn(async move {
576 let (project_entry_id, rope) = rope_task.await?;
577 anyhow::Ok(RulesFileContext {
578 path_in_worktree,
579 text: rope.to_string().trim().to_string(),
580 project_entry_id: project_entry_id.to_usize(),
581 })
582 })
583 })
584 }
585
586 fn handle_thread_title_updated(
587 &mut self,
588 thread: Entity<Thread>,
589 _: &TitleUpdated,
590 cx: &mut Context<Self>,
591 ) {
592 let session_id = thread.read(cx).id();
593 let Some(session) = self.sessions.get(session_id) else {
594 return;
595 };
596 let thread = thread.downgrade();
597 let acp_thread = session.acp_thread.downgrade();
598 cx.spawn(async move |_, cx| {
599 let title = thread.read_with(cx, |thread, _| thread.title())?;
600 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
601 task.await
602 })
603 .detach_and_log_err(cx);
604 }
605
606 fn handle_thread_token_usage_updated(
607 &mut self,
608 thread: Entity<Thread>,
609 usage: &TokenUsageUpdated,
610 cx: &mut Context<Self>,
611 ) {
612 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
613 return;
614 };
615 session.acp_thread.update(cx, |acp_thread, cx| {
616 acp_thread.update_token_usage(usage.0.clone(), cx);
617 });
618 }
619
620 fn handle_project_event(
621 &mut self,
622 _project: Entity<Project>,
623 event: &project::Event,
624 _cx: &mut Context<Self>,
625 ) {
626 match event {
627 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
628 self.project_context_needs_refresh.send(()).ok();
629 }
630 project::Event::WorktreeUpdatedEntries(_, items) => {
631 if items.iter().any(|(path, _, _)| {
632 RULES_FILE_NAMES
633 .iter()
634 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
635 }) {
636 self.project_context_needs_refresh.send(()).ok();
637 }
638 }
639 _ => {}
640 }
641 }
642
643 fn handle_prompts_updated_event(
644 &mut self,
645 _prompt_store: Entity<PromptStore>,
646 _event: &prompt_store::PromptsUpdatedEvent,
647 _cx: &mut Context<Self>,
648 ) {
649 self.project_context_needs_refresh.send(()).ok();
650 }
651
652 fn handle_models_updated_event(
653 &mut self,
654 _registry: Entity<LanguageModelRegistry>,
655 _event: &language_model::Event,
656 cx: &mut Context<Self>,
657 ) {
658 self.models.refresh_list(cx);
659
660 let registry = LanguageModelRegistry::read_global(cx);
661 let default_model = registry.default_model().map(|m| m.model);
662 let summarization_model = registry.thread_summary_model().map(|m| m.model);
663
664 for session in self.sessions.values_mut() {
665 session.thread.update(cx, |thread, cx| {
666 if thread.model().is_none()
667 && let Some(model) = default_model.clone()
668 {
669 thread.set_model(model, cx);
670 cx.notify();
671 }
672 thread.set_summarization_model(summarization_model.clone(), cx);
673 });
674 }
675 }
676
677 fn handle_context_server_store_updated(
678 &mut self,
679 _store: Entity<project::context_server_store::ContextServerStore>,
680 _event: &project::context_server_store::ServerStatusChangedEvent,
681 cx: &mut Context<Self>,
682 ) {
683 self.update_available_commands(cx);
684 }
685
686 fn handle_context_server_registry_event(
687 &mut self,
688 _registry: Entity<ContextServerRegistry>,
689 event: &ContextServerRegistryEvent,
690 cx: &mut Context<Self>,
691 ) {
692 match event {
693 ContextServerRegistryEvent::ToolsChanged => {}
694 ContextServerRegistryEvent::PromptsChanged => {
695 self.update_available_commands(cx);
696 }
697 }
698 }
699
700 fn update_available_commands(&self, cx: &mut Context<Self>) {
701 let available_commands = self.build_available_commands(cx);
702 for session in self.sessions.values() {
703 session.acp_thread.update(cx, |thread, cx| {
704 thread
705 .handle_session_update(
706 acp::SessionUpdate::AvailableCommandsUpdate(
707 acp::AvailableCommandsUpdate::new(available_commands.clone()),
708 ),
709 cx,
710 )
711 .log_err();
712 });
713 }
714 }
715
716 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
717 let registry = self.context_server_registry.read(cx);
718
719 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
720 for context_server_prompt in registry.prompts() {
721 *prompt_name_counts
722 .entry(context_server_prompt.prompt.name.as_str())
723 .or_insert(0) += 1;
724 }
725
726 registry
727 .prompts()
728 .flat_map(|context_server_prompt| {
729 let prompt = &context_server_prompt.prompt;
730
731 let should_prefix = prompt_name_counts
732 .get(prompt.name.as_str())
733 .copied()
734 .unwrap_or(0)
735 > 1;
736
737 let name = if should_prefix {
738 format!("{}.{}", context_server_prompt.server_id, prompt.name)
739 } else {
740 prompt.name.clone()
741 };
742
743 let mut command = acp::AvailableCommand::new(
744 name,
745 prompt.description.clone().unwrap_or_default(),
746 );
747
748 match prompt.arguments.as_deref() {
749 Some([arg]) => {
750 let hint = format!("<{}>", arg.name);
751
752 command = command.input(acp::AvailableCommandInput::Unstructured(
753 acp::UnstructuredCommandInput::new(hint),
754 ));
755 }
756 Some([]) | None => {}
757 Some(_) => {
758 // skip >1 argument commands since we don't support them yet
759 return None;
760 }
761 }
762
763 Some(command)
764 })
765 .collect()
766 }
767
768 pub fn load_thread(
769 &mut self,
770 id: acp::SessionId,
771 cx: &mut Context<Self>,
772 ) -> Task<Result<Entity<Thread>>> {
773 let database_future = ThreadsDatabase::connect(cx);
774 cx.spawn(async move |this, cx| {
775 let database = database_future.await.map_err(|err| anyhow!(err))?;
776 let db_thread = database
777 .load_thread(id.clone())
778 .await?
779 .with_context(|| format!("no thread found with ID: {id:?}"))?;
780
781 this.update(cx, |this, cx| {
782 let summarization_model = LanguageModelRegistry::read_global(cx)
783 .thread_summary_model()
784 .map(|c| c.model);
785
786 cx.new(|cx| {
787 let mut thread = Thread::from_db(
788 id.clone(),
789 db_thread,
790 this.project.clone(),
791 this.project_context.clone(),
792 this.context_server_registry.clone(),
793 this.templates.clone(),
794 cx,
795 );
796 thread.set_summarization_model(summarization_model, cx);
797 thread
798 })
799 })
800 })
801 }
802
803 pub fn open_thread(
804 &mut self,
805 id: acp::SessionId,
806 cx: &mut Context<Self>,
807 ) -> Task<Result<Entity<AcpThread>>> {
808 if let Some(session) = self.sessions.get(&id) {
809 return Task::ready(Ok(session.acp_thread.clone()));
810 }
811
812 let task = self.load_thread(id, cx);
813 cx.spawn(async move |this, cx| {
814 let thread = task.await?;
815 let acp_thread =
816 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
817 let events = thread.update(cx, |thread, cx| thread.replay(cx));
818 cx.update(|cx| {
819 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
820 })
821 .await?;
822 Ok(acp_thread)
823 })
824 }
825
826 pub fn thread_summary(
827 &mut self,
828 id: acp::SessionId,
829 cx: &mut Context<Self>,
830 ) -> Task<Result<SharedString>> {
831 let thread = self.open_thread(id.clone(), cx);
832 cx.spawn(async move |this, cx| {
833 let acp_thread = thread.await?;
834 let result = this
835 .update(cx, |this, cx| {
836 this.sessions
837 .get(&id)
838 .unwrap()
839 .thread
840 .update(cx, |thread, cx| thread.summary(cx))
841 })?
842 .await
843 .context("Failed to generate summary")?;
844 drop(acp_thread);
845 Ok(result)
846 })
847 }
848
    /// Persists `thread` to the threads database in the background.
    ///
    /// Does nothing if the thread is empty or has no registered session.
    ///
    /// Before serializing, the session's current draft prompt is copied onto the
    /// thread. The paths of the project's visible worktrees are saved alongside
    /// it. The save task becomes the session's `pending_save`, replacing the
    /// previous one.
    ///
    /// Errors opening or writing the database are logged. If the database was
    /// opened, the [`ThreadStore`] is reloaded after the write attempt.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        // Saved alongside the thread: absolute paths of the visible worktrees.
        let folder_paths = PathList::new(
            &self
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        // NOTE(review): assigning drops the previous pending save; if dropping a
        // gpui `Task` cancels it, an older in-flight save is abandoned in favour of
        // this one — confirm that is intended.
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return;
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            thread_store.update(cx, |store, cx| store.reload(cx));
        });
    }
887
    /// Expands an MCP prompt and runs it as a turn in session `session_id`.
    ///
    /// Steps:
    /// 1. Fetch prompt `prompt_name` from context server `server_id` with `arguments`.
    /// 2. Record `original_content` as a user message on the [`Thread`]. Its first
    ///    block (presumably the slash-command text itself) is skipped.
    /// 3. Append each prompt message to both the [`AcpThread`] (through the
    ///    `*_with_indent` variants) and the `Thread`, according to its role.
    /// 4. If the last message was from the user, or there were no messages, call
    ///    `send_existing`; otherwise call `resume`.
    /// 5. Forward the resulting events to the `AcpThread`.
    ///
    /// Fails if the prompt cannot be fetched or the session is not registered.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Tracks whether the most recently appended prompt message came from the
            // user; starts `true` so an empty prompt still sends the user's message.
            let mut last_is_user = true;

            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message in the prompt gets its own message id,
                        // shared between the AcpThread and the Thread.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
983}
984
985/// Wrapper struct that implements the AgentConnection trait
986#[derive(Clone)]
987pub struct NativeAgentConnection(pub Entity<NativeAgent>);
988
989impl NativeAgentConnection {
990 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
991 self.0
992 .read(cx)
993 .sessions
994 .get(session_id)
995 .map(|session| session.thread.clone())
996 }
997
998 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
999 self.0.update(cx, |this, cx| this.load_thread(id, cx))
1000 }
1001
1002 fn run_turn(
1003 &self,
1004 session_id: acp::SessionId,
1005 cx: &mut App,
1006 f: impl 'static
1007 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1008 ) -> Task<Result<acp::PromptResponse>> {
1009 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1010 agent
1011 .sessions
1012 .get_mut(&session_id)
1013 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1014 }) else {
1015 return Task::ready(Err(anyhow!("Session not found")));
1016 };
1017 log::debug!("Found session for: {}", session_id);
1018
1019 let response_stream = match f(thread, cx) {
1020 Ok(stream) => stream,
1021 Err(err) => return Task::ready(Err(err)),
1022 };
1023 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1024 }
1025
1026 fn handle_thread_events(
1027 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1028 acp_thread: WeakEntity<AcpThread>,
1029 cx: &App,
1030 ) -> Task<Result<acp::PromptResponse>> {
1031 cx.spawn(async move |cx| {
1032 // Handle response stream and forward to session.acp_thread
1033 while let Some(result) = events.next().await {
1034 match result {
1035 Ok(event) => {
1036 log::trace!("Received completion event: {:?}", event);
1037
1038 match event {
1039 ThreadEvent::UserMessage(message) => {
1040 acp_thread.update(cx, |thread, cx| {
1041 for content in message.content {
1042 thread.push_user_content_block(
1043 Some(message.id.clone()),
1044 content.into(),
1045 cx,
1046 );
1047 }
1048 })?;
1049 }
1050 ThreadEvent::AgentText(text) => {
1051 acp_thread.update(cx, |thread, cx| {
1052 thread.push_assistant_content_block(text.into(), false, cx)
1053 })?;
1054 }
1055 ThreadEvent::AgentThinking(text) => {
1056 acp_thread.update(cx, |thread, cx| {
1057 thread.push_assistant_content_block(text.into(), true, cx)
1058 })?;
1059 }
1060 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1061 tool_call,
1062 options,
1063 response,
1064 context: _,
1065 }) => {
1066 let outcome_task = acp_thread.update(cx, |thread, cx| {
1067 thread.request_tool_call_authorization(tool_call, options, cx)
1068 })??;
1069 cx.background_spawn(async move {
1070 if let acp::RequestPermissionOutcome::Selected(
1071 acp::SelectedPermissionOutcome { option_id, .. },
1072 ) = outcome_task.await
1073 {
1074 response
1075 .send(option_id)
1076 .map(|_| anyhow!("authorization receiver was dropped"))
1077 .log_err();
1078 }
1079 })
1080 .detach();
1081 }
1082 ThreadEvent::ToolCall(tool_call) => {
1083 acp_thread.update(cx, |thread, cx| {
1084 thread.upsert_tool_call(tool_call, cx)
1085 })??;
1086 }
1087 ThreadEvent::ToolCallUpdate(update) => {
1088 acp_thread.update(cx, |thread, cx| {
1089 thread.update_tool_call(update, cx)
1090 })??;
1091 }
1092 ThreadEvent::SubagentSpawned(session_id) => {
1093 acp_thread.update(cx, |thread, cx| {
1094 thread.subagent_spawned(session_id, cx);
1095 })?;
1096 }
1097 ThreadEvent::Retry(status) => {
1098 acp_thread.update(cx, |thread, cx| {
1099 thread.update_retry_status(status, cx)
1100 })?;
1101 }
1102 ThreadEvent::Stop(stop_reason) => {
1103 log::debug!("Assistant message complete: {:?}", stop_reason);
1104 return Ok(acp::PromptResponse::new(stop_reason));
1105 }
1106 }
1107 }
1108 Err(e) => {
1109 log::error!("Error in model response stream: {:?}", e);
1110 return Err(e);
1111 }
1112 }
1113 }
1114
1115 log::debug!("Response stream completed");
1116 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1117 })
1118 }
1119}
1120
1121struct Command<'a> {
1122 prompt_name: &'a str,
1123 arg_value: &'a str,
1124 explicit_server_id: Option<&'a str>,
1125}
1126
1127impl<'a> Command<'a> {
1128 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1129 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1130 return None;
1131 };
1132 let text = text_content.text.trim();
1133 let command = text.strip_prefix('/')?;
1134 let (command, arg_value) = command
1135 .split_once(char::is_whitespace)
1136 .unwrap_or((command, ""));
1137
1138 if let Some((server_id, prompt_name)) = command.split_once('.') {
1139 Some(Self {
1140 prompt_name,
1141 arg_value,
1142 explicit_server_id: Some(server_id),
1143 })
1144 } else {
1145 Some(Self {
1146 prompt_name: command,
1147 arg_value,
1148 explicit_server_id: None,
1149 })
1150 }
1151 }
1152}
1153
1154struct NativeAgentModelSelector {
1155 session_id: acp::SessionId,
1156 connection: NativeAgentConnection,
1157}
1158
1159impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1160 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1161 log::debug!("NativeAgentConnection::list_models called");
1162 let list = self.connection.0.read(cx).models.model_list.clone();
1163 Task::ready(if list.is_empty() {
1164 Err(anyhow::anyhow!("No models available"))
1165 } else {
1166 Ok(list)
1167 })
1168 }
1169
1170 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1171 log::debug!(
1172 "Setting model for session {}: {}",
1173 self.session_id,
1174 model_id
1175 );
1176 let Some(thread) = self
1177 .connection
1178 .0
1179 .read(cx)
1180 .sessions
1181 .get(&self.session_id)
1182 .map(|session| session.thread.clone())
1183 else {
1184 return Task::ready(Err(anyhow!("Session not found")));
1185 };
1186
1187 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1188 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1189 };
1190
1191 // We want to reset the effort level when switching models, as the currently-selected effort level may
1192 // not be compatible.
1193 let effort = model
1194 .default_effort_level()
1195 .map(|effort_level| effort_level.value.to_string());
1196
1197 thread.update(cx, |thread, cx| {
1198 thread.set_model(model.clone(), cx);
1199 thread.set_thinking_effort(effort.clone(), cx);
1200 thread.set_thinking_enabled(model.supports_thinking(), cx);
1201 });
1202
1203 update_settings_file(
1204 self.connection.0.read(cx).fs.clone(),
1205 cx,
1206 move |settings, cx| {
1207 let provider = model.provider_id().0.to_string();
1208 let model = model.id().0.to_string();
1209 let enable_thinking = thread.read(cx).thinking_enabled();
1210 settings
1211 .agent
1212 .get_or_insert_default()
1213 .set_model(LanguageModelSelection {
1214 provider: provider.into(),
1215 model,
1216 enable_thinking,
1217 effort,
1218 });
1219 },
1220 );
1221
1222 Task::ready(Ok(()))
1223 }
1224
1225 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1226 let Some(thread) = self
1227 .connection
1228 .0
1229 .read(cx)
1230 .sessions
1231 .get(&self.session_id)
1232 .map(|session| session.thread.clone())
1233 else {
1234 return Task::ready(Err(anyhow!("Session not found")));
1235 };
1236 let Some(model) = thread.read(cx).model() else {
1237 return Task::ready(Err(anyhow!("Model not found")));
1238 };
1239 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1240 else {
1241 return Task::ready(Err(anyhow!("Provider not found")));
1242 };
1243 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1244 model, &provider,
1245 )))
1246 }
1247
1248 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1249 Some(self.connection.0.read(cx).models.watch())
1250 }
1251
1252 fn should_render_footer(&self) -> bool {
1253 true
1254 }
1255}
1256
1257impl acp_thread::AgentConnection for NativeAgentConnection {
1258 fn telemetry_id(&self) -> SharedString {
1259 "zed".into()
1260 }
1261
1262 fn new_session(
1263 self: Rc<Self>,
1264 project: Entity<Project>,
1265 cwd: &Path,
1266 cx: &mut App,
1267 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1268 log::debug!("Creating new thread for project at: {cwd:?}");
1269 Task::ready(Ok(self
1270 .0
1271 .update(cx, |agent, cx| agent.new_session(project, cx))))
1272 }
1273
1274 fn supports_load_session(&self) -> bool {
1275 true
1276 }
1277
1278 fn load_session(
1279 self: Rc<Self>,
1280 session: AgentSessionInfo,
1281 _project: Entity<Project>,
1282 _cwd: &Path,
1283 cx: &mut App,
1284 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1285 self.0
1286 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1287 }
1288
1289 fn supports_close_session(&self) -> bool {
1290 true
1291 }
1292
1293 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1294 self.0.update(cx, |agent, _cx| {
1295 agent.sessions.remove(session_id);
1296 });
1297 Task::ready(Ok(()))
1298 }
1299
1300 fn auth_methods(&self) -> &[acp::AuthMethod] {
1301 &[] // No auth for in-process
1302 }
1303
1304 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1305 Task::ready(Ok(()))
1306 }
1307
1308 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1309 Some(Rc::new(NativeAgentModelSelector {
1310 session_id: session_id.clone(),
1311 connection: self.clone(),
1312 }) as Rc<dyn AgentModelSelector>)
1313 }
1314
1315 fn prompt(
1316 &self,
1317 id: Option<acp_thread::UserMessageId>,
1318 params: acp::PromptRequest,
1319 cx: &mut App,
1320 ) -> Task<Result<acp::PromptResponse>> {
1321 let id = id.expect("UserMessageId is required");
1322 let session_id = params.session_id.clone();
1323 log::info!("Received prompt request for session: {}", session_id);
1324 log::debug!("Prompt blocks count: {}", params.prompt.len());
1325
1326 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1327 let registry = self.0.read(cx).context_server_registry.read(cx);
1328
1329 let explicit_server_id = parsed_command
1330 .explicit_server_id
1331 .map(|server_id| ContextServerId(server_id.into()));
1332
1333 if let Some(prompt) =
1334 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1335 {
1336 let arguments = if !parsed_command.arg_value.is_empty()
1337 && let Some(arg_name) = prompt
1338 .prompt
1339 .arguments
1340 .as_ref()
1341 .and_then(|args| args.first())
1342 .map(|arg| arg.name.clone())
1343 {
1344 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1345 } else {
1346 Default::default()
1347 };
1348
1349 let prompt_name = prompt.prompt.name.clone();
1350 let server_id = prompt.server_id.clone();
1351
1352 return self.0.update(cx, |agent, cx| {
1353 agent.send_mcp_prompt(
1354 id,
1355 session_id.clone(),
1356 prompt_name,
1357 server_id,
1358 arguments,
1359 params.prompt,
1360 cx,
1361 )
1362 });
1363 };
1364 };
1365
1366 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1367
1368 self.run_turn(session_id, cx, move |thread, cx| {
1369 let content: Vec<UserMessageContent> = params
1370 .prompt
1371 .into_iter()
1372 .map(|block| UserMessageContent::from_content_block(block, path_style))
1373 .collect::<Vec<_>>();
1374 log::debug!("Converted prompt to message: {} chars", content.len());
1375 log::debug!("Message id: {:?}", id);
1376 log::debug!("Message content: {:?}", content);
1377
1378 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1379 })
1380 }
1381
1382 fn retry(
1383 &self,
1384 session_id: &acp::SessionId,
1385 _cx: &App,
1386 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1387 Some(Rc::new(NativeAgentSessionRetry {
1388 connection: self.clone(),
1389 session_id: session_id.clone(),
1390 }) as _)
1391 }
1392
1393 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1394 log::info!("Cancelling on session: {}", session_id);
1395 self.0.update(cx, |agent, cx| {
1396 if let Some(session) = agent.sessions.get(session_id) {
1397 session
1398 .thread
1399 .update(cx, |thread, cx| thread.cancel(cx))
1400 .detach();
1401 }
1402 });
1403 }
1404
1405 fn truncate(
1406 &self,
1407 session_id: &agent_client_protocol::SessionId,
1408 cx: &App,
1409 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1410 self.0.read_with(cx, |agent, _cx| {
1411 agent.sessions.get(session_id).map(|session| {
1412 Rc::new(NativeAgentSessionTruncate {
1413 thread: session.thread.clone(),
1414 acp_thread: session.acp_thread.downgrade(),
1415 }) as _
1416 })
1417 })
1418 }
1419
1420 fn set_title(
1421 &self,
1422 session_id: &acp::SessionId,
1423 cx: &App,
1424 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1425 self.0.read_with(cx, |agent, _cx| {
1426 agent
1427 .sessions
1428 .get(session_id)
1429 .filter(|s| !s.thread.read(cx).is_subagent())
1430 .map(|session| {
1431 Rc::new(NativeAgentSessionSetTitle {
1432 thread: session.thread.clone(),
1433 }) as _
1434 })
1435 })
1436 }
1437
1438 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1439 let thread_store = self.0.read(cx).thread_store.clone();
1440 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1441 }
1442
1443 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1444 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1445 }
1446
1447 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1448 self
1449 }
1450}
1451
1452impl acp_thread::AgentTelemetry for NativeAgentConnection {
1453 fn thread_data(
1454 &self,
1455 session_id: &acp::SessionId,
1456 cx: &mut App,
1457 ) -> Task<Result<serde_json::Value>> {
1458 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1459 return Task::ready(Err(anyhow!("Session not found")));
1460 };
1461
1462 let task = session.thread.read(cx).to_db(cx);
1463 cx.background_spawn(async move {
1464 serde_json::to_value(task.await).context("Failed to serialize thread")
1465 })
1466 }
1467}
1468
1469pub struct NativeAgentSessionList {
1470 thread_store: Entity<ThreadStore>,
1471 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1472 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1473 _subscription: Subscription,
1474}
1475
1476impl NativeAgentSessionList {
1477 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1478 let (tx, rx) = smol::channel::unbounded();
1479 let this_tx = tx.clone();
1480 let subscription = cx.observe(&thread_store, move |_, _| {
1481 this_tx
1482 .try_send(acp_thread::SessionListUpdate::Refresh)
1483 .ok();
1484 });
1485 Self {
1486 thread_store,
1487 updates_tx: tx,
1488 updates_rx: rx,
1489 _subscription: subscription,
1490 }
1491 }
1492
1493 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1494 &self.thread_store
1495 }
1496}
1497
1498impl AgentSessionList for NativeAgentSessionList {
1499 fn list_sessions(
1500 &self,
1501 _request: AgentSessionListRequest,
1502 cx: &mut App,
1503 ) -> Task<Result<AgentSessionListResponse>> {
1504 let sessions = self
1505 .thread_store
1506 .read(cx)
1507 .entries()
1508 .map(|entry| AgentSessionInfo::from(&entry))
1509 .collect();
1510 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1511 }
1512
1513 fn supports_delete(&self) -> bool {
1514 true
1515 }
1516
1517 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1518 self.thread_store
1519 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1520 }
1521
1522 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1523 self.thread_store
1524 .update(cx, |store, cx| store.delete_threads(cx))
1525 }
1526
1527 fn watch(
1528 &self,
1529 _cx: &mut App,
1530 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1531 Some(self.updates_rx.clone())
1532 }
1533
1534 fn notify_refresh(&self) {
1535 self.updates_tx
1536 .try_send(acp_thread::SessionListUpdate::Refresh)
1537 .ok();
1538 }
1539
1540 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1541 self
1542 }
1543}
1544
1545struct NativeAgentSessionTruncate {
1546 thread: Entity<Thread>,
1547 acp_thread: WeakEntity<AcpThread>,
1548}
1549
1550impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1551 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1552 match self.thread.update(cx, |thread, cx| {
1553 thread.truncate(message_id.clone(), cx)?;
1554 Ok(thread.latest_token_usage())
1555 }) {
1556 Ok(usage) => {
1557 self.acp_thread
1558 .update(cx, |thread, cx| {
1559 thread.update_token_usage(usage, cx);
1560 })
1561 .ok();
1562 Task::ready(Ok(()))
1563 }
1564 Err(error) => Task::ready(Err(error)),
1565 }
1566 }
1567}
1568
1569struct NativeAgentSessionRetry {
1570 connection: NativeAgentConnection,
1571 session_id: acp::SessionId,
1572}
1573
1574impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1575 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1576 self.connection
1577 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1578 thread.update(cx, |thread, cx| thread.resume(cx))
1579 })
1580 }
1581}
1582
1583struct NativeAgentSessionSetTitle {
1584 thread: Entity<Thread>,
1585}
1586
1587impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1588 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1589 self.thread
1590 .update(cx, |thread, cx| thread.set_title(title, cx));
1591 Task::ready(Ok(()))
1592 }
1593}
1594
1595pub struct NativeThreadEnvironment {
1596 agent: WeakEntity<NativeAgent>,
1597 thread: WeakEntity<Thread>,
1598 acp_thread: WeakEntity<AcpThread>,
1599}
1600
1601impl NativeThreadEnvironment {
1602 pub(crate) fn create_subagent_thread(
1603 &self,
1604 label: String,
1605 cx: &mut App,
1606 ) -> Result<Rc<dyn SubagentHandle>> {
1607 let Some(parent_thread_entity) = self.thread.upgrade() else {
1608 anyhow::bail!("Parent thread no longer exists".to_string());
1609 };
1610 let parent_thread = parent_thread_entity.read(cx);
1611 let current_depth = parent_thread.depth();
1612
1613 if current_depth >= MAX_SUBAGENT_DEPTH {
1614 return Err(anyhow!(
1615 "Maximum subagent depth ({}) reached",
1616 MAX_SUBAGENT_DEPTH
1617 ));
1618 }
1619
1620 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1621 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1622 thread.set_title(label.into(), cx);
1623 thread
1624 });
1625
1626 let session_id = subagent_thread.read(cx).id().clone();
1627
1628 let acp_thread = self.agent.update(cx, |agent, cx| {
1629 agent.register_session(subagent_thread.clone(), cx)
1630 })?;
1631
1632 let depth = current_depth + 1;
1633
1634 telemetry::event!(
1635 "Subagent Started",
1636 session = parent_thread_entity.read(cx).id().to_string(),
1637 subagent_session = session_id.to_string(),
1638 depth,
1639 is_resumed = false,
1640 );
1641
1642 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1643 }
1644
1645 pub(crate) fn resume_subagent_thread(
1646 &self,
1647 session_id: acp::SessionId,
1648 cx: &mut App,
1649 ) -> Result<Rc<dyn SubagentHandle>> {
1650 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1651 let session = agent
1652 .sessions
1653 .get(&session_id)
1654 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1655 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1656 })??;
1657
1658 let depth = subagent_thread.read(cx).depth();
1659
1660 if let Some(parent_thread_entity) = self.thread.upgrade() {
1661 telemetry::event!(
1662 "Subagent Started",
1663 session = parent_thread_entity.read(cx).id().to_string(),
1664 subagent_session = session_id.to_string(),
1665 depth,
1666 is_resumed = true,
1667 );
1668 }
1669
1670 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1671 }
1672
1673 fn prompt_subagent(
1674 &self,
1675 session_id: acp::SessionId,
1676 subagent_thread: Entity<Thread>,
1677 acp_thread: Entity<acp_thread::AcpThread>,
1678 ) -> Result<Rc<dyn SubagentHandle>> {
1679 let Some(parent_thread_entity) = self.thread.upgrade() else {
1680 anyhow::bail!("Parent thread no longer exists".to_string());
1681 };
1682 Ok(Rc::new(NativeSubagentHandle::new(
1683 session_id,
1684 subagent_thread,
1685 acp_thread,
1686 parent_thread_entity,
1687 )) as _)
1688 }
1689}
1690
1691impl ThreadEnvironment for NativeThreadEnvironment {
1692 fn create_terminal(
1693 &self,
1694 command: String,
1695 cwd: Option<PathBuf>,
1696 output_byte_limit: Option<u64>,
1697 cx: &mut AsyncApp,
1698 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1699 let task = self.acp_thread.update(cx, |thread, cx| {
1700 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1701 });
1702
1703 let acp_thread = self.acp_thread.clone();
1704 cx.spawn(async move |cx| {
1705 let terminal = task?.await?;
1706
1707 let (drop_tx, drop_rx) = oneshot::channel();
1708 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1709
1710 cx.spawn(async move |cx| {
1711 drop_rx.await.ok();
1712 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1713 })
1714 .detach();
1715
1716 let handle = AcpTerminalHandle {
1717 terminal,
1718 _drop_tx: Some(drop_tx),
1719 };
1720
1721 Ok(Rc::new(handle) as _)
1722 })
1723 }
1724
1725 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1726 self.create_subagent_thread(label, cx)
1727 }
1728
1729 fn resume_subagent(
1730 &self,
1731 session_id: acp::SessionId,
1732 cx: &mut App,
1733 ) -> Result<Rc<dyn SubagentHandle>> {
1734 self.resume_subagent_thread(session_id, cx)
1735 }
1736}
1737
1738#[derive(Debug, Clone)]
1739enum SubagentPromptResult {
1740 Completed,
1741 Cancelled,
1742 ContextWindowWarning,
1743 Error(String),
1744}
1745
1746pub struct NativeSubagentHandle {
1747 session_id: acp::SessionId,
1748 parent_thread: WeakEntity<Thread>,
1749 subagent_thread: Entity<Thread>,
1750 acp_thread: Entity<acp_thread::AcpThread>,
1751}
1752
1753impl NativeSubagentHandle {
1754 fn new(
1755 session_id: acp::SessionId,
1756 subagent_thread: Entity<Thread>,
1757 acp_thread: Entity<acp_thread::AcpThread>,
1758 parent_thread_entity: Entity<Thread>,
1759 ) -> Self {
1760 NativeSubagentHandle {
1761 session_id,
1762 subagent_thread,
1763 parent_thread: parent_thread_entity.downgrade(),
1764 acp_thread,
1765 }
1766 }
1767}
1768
impl SubagentHandle for NativeSubagentHandle {
    /// Returns the session id of the subagent thread this handle refers to.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Returns how many entries the subagent's ACP thread currently holds.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and waits for the resulting turn to finish.
    ///
    /// Before sending, the subagent is registered as running on the parent thread (skipped
    /// if the parent is gone). After the outcome is known, it is unregistered again.
    ///
    /// Resolves with the text of the subagent's last agent message (text parts joined by
    /// blank lines) when the turn completes. Fails when:
    /// - the turn is cancelled;
    /// - the turn stops for a non-completion reason;
    /// - the last message has no text;
    /// - token usage moves from `Normal` to `Warning` during the turn. In this case the
    ///   subagent thread is cancelled before the error is returned.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is held until this task ends, so the token-usage watcher stays
            // active while the prompt is in flight.
            let (task, _subscription) = cx.update(|cx| {
                // Ratio before this prompt; the watcher below compares against it.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                // Best-effort: skipped if the parent thread has been released.
                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // Fired at most once (the sender is `take`n), to stop the subagent when
                // its context window starts running low.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            // No usage recorded before the prompt counts as Normal.
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            // NOTE(review): only a report of exactly `Warning`, after a
                            // prompt that started at Normal, triggers this. Any other
                            // non-Normal ratio does not.
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the token-limit signal and classify the outcome.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Reply text is the concatenation of the text parts of the last agent message.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent and wait for the cancellation before reporting.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // NOTE(review): if this task is dropped before reaching here, the subagent is
            // never unregistered from the parent.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
1889
1890pub struct AcpTerminalHandle {
1891 terminal: Entity<acp_thread::Terminal>,
1892 _drop_tx: Option<oneshot::Sender<()>>,
1893}
1894
1895impl TerminalHandle for AcpTerminalHandle {
1896 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1897 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1898 }
1899
1900 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1901 Ok(self
1902 .terminal
1903 .read_with(cx, |term, _cx| term.wait_for_exit()))
1904 }
1905
1906 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1907 Ok(self
1908 .terminal
1909 .read_with(cx, |term, cx| term.current_output(cx)))
1910 }
1911
1912 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1913 cx.update(|cx| {
1914 self.terminal.update(cx, |terminal, cx| {
1915 terminal.kill(cx);
1916 });
1917 });
1918 Ok(())
1919 }
1920
1921 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1922 Ok(self
1923 .terminal
1924 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1925 }
1926}
1927
1928#[cfg(test)]
1929mod internal_tests {
1930 use super::*;
1931 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1932 use fs::FakeFs;
1933 use gpui::TestAppContext;
1934 use indoc::formatdoc;
1935 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1936 use language_model::{
1937 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
1938 };
1939 use serde_json::json;
1940 use settings::SettingsStore;
1941 use util::{path, rel_path::rel_path};
1942
1943 #[gpui::test]
1944 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1945 init_test(cx);
1946 let fs = FakeFs::new(cx.executor());
1947 fs.insert_tree(
1948 "/",
1949 json!({
1950 "a": {}
1951 }),
1952 )
1953 .await;
1954 let project = Project::test(fs.clone(), [], cx).await;
1955 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1956 let agent = NativeAgent::new(
1957 project.clone(),
1958 thread_store,
1959 Templates::new(),
1960 None,
1961 fs.clone(),
1962 &mut cx.to_async(),
1963 )
1964 .await
1965 .unwrap();
1966 agent.read_with(cx, |agent, cx| {
1967 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1968 });
1969
1970 let worktree = project
1971 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1972 .await
1973 .unwrap();
1974 cx.run_until_parked();
1975 agent.read_with(cx, |agent, cx| {
1976 assert_eq!(
1977 agent.project_context.read(cx).worktrees,
1978 vec![WorktreeContext {
1979 root_name: "a".into(),
1980 abs_path: Path::new("/a").into(),
1981 rules_file: None
1982 }]
1983 )
1984 });
1985
1986 // Creating `/a/.rules` updates the project context.
1987 fs.insert_file("/a/.rules", Vec::new()).await;
1988 cx.run_until_parked();
1989 agent.read_with(cx, |agent, cx| {
1990 let rules_entry = worktree
1991 .read(cx)
1992 .entry_for_path(rel_path(".rules"))
1993 .unwrap();
1994 assert_eq!(
1995 agent.project_context.read(cx).worktrees,
1996 vec![WorktreeContext {
1997 root_name: "a".into(),
1998 abs_path: Path::new("/a").into(),
1999 rules_file: Some(RulesFileContext {
2000 path_in_worktree: rel_path(".rules").into(),
2001 text: "".into(),
2002 project_entry_id: rules_entry.id.to_usize()
2003 })
2004 }]
2005 )
2006 });
2007 }
2008
2009 #[gpui::test]
2010 async fn test_listing_models(cx: &mut TestAppContext) {
2011 init_test(cx);
2012 let fs = FakeFs::new(cx.executor());
2013 fs.insert_tree("/", json!({ "a": {} })).await;
2014 let project = Project::test(fs.clone(), [], cx).await;
2015 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2016 let connection = NativeAgentConnection(
2017 NativeAgent::new(
2018 project.clone(),
2019 thread_store,
2020 Templates::new(),
2021 None,
2022 fs.clone(),
2023 &mut cx.to_async(),
2024 )
2025 .await
2026 .unwrap(),
2027 );
2028
2029 // Create a thread/session
2030 let acp_thread = cx
2031 .update(|cx| {
2032 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2033 })
2034 .await
2035 .unwrap();
2036
2037 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2038
2039 let models = cx
2040 .update(|cx| {
2041 connection
2042 .model_selector(&session_id)
2043 .unwrap()
2044 .list_models(cx)
2045 })
2046 .await
2047 .unwrap();
2048
2049 let acp_thread::AgentModelList::Grouped(models) = models else {
2050 panic!("Unexpected model group");
2051 };
2052 assert_eq!(
2053 models,
2054 IndexMap::from_iter([(
2055 AgentModelGroupName("Fake".into()),
2056 vec![AgentModelInfo {
2057 id: acp::ModelId::new("fake/fake"),
2058 name: "Fake".into(),
2059 description: None,
2060 icon: Some(acp_thread::AgentModelIcon::Named(
2061 ui::IconName::ZedAssistant
2062 )),
2063 is_latest: false,
2064 cost: None,
2065 }]
2066 )])
2067 );
2068 }
2069
2070 #[gpui::test]
2071 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2072 init_test(cx);
2073 let fs = FakeFs::new(cx.executor());
2074 fs.create_dir(paths::settings_file().parent().unwrap())
2075 .await
2076 .unwrap();
2077 fs.insert_file(
2078 paths::settings_file(),
2079 json!({
2080 "agent": {
2081 "default_model": {
2082 "provider": "foo",
2083 "model": "bar"
2084 }
2085 }
2086 })
2087 .to_string()
2088 .into_bytes(),
2089 )
2090 .await;
2091 let project = Project::test(fs.clone(), [], cx).await;
2092
2093 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2094
2095 // Create the agent and connection
2096 let agent = NativeAgent::new(
2097 project.clone(),
2098 thread_store,
2099 Templates::new(),
2100 None,
2101 fs.clone(),
2102 &mut cx.to_async(),
2103 )
2104 .await
2105 .unwrap();
2106 let connection = NativeAgentConnection(agent.clone());
2107
2108 // Create a thread/session
2109 let acp_thread = cx
2110 .update(|cx| {
2111 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2112 })
2113 .await
2114 .unwrap();
2115
2116 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2117
2118 // Select a model
2119 let selector = connection.model_selector(&session_id).unwrap();
2120 let model_id = acp::ModelId::new("fake/fake");
2121 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2122 .await
2123 .unwrap();
2124
2125 // Verify the thread has the selected model
2126 agent.read_with(cx, |agent, _| {
2127 let session = agent.sessions.get(&session_id).unwrap();
2128 session.thread.read_with(cx, |thread, _| {
2129 assert_eq!(thread.model().unwrap().id().0, "fake");
2130 });
2131 });
2132
2133 cx.run_until_parked();
2134
2135 // Verify settings file was updated
2136 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2137 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2138
2139 // Check that the agent settings contain the selected model
2140 assert_eq!(
2141 settings_json["agent"]["default_model"]["model"],
2142 json!("fake")
2143 );
2144 assert_eq!(
2145 settings_json["agent"]["default_model"]["provider"],
2146 json!("fake")
2147 );
2148
2149 // Register a thinking model and select it.
2150 cx.update(|cx| {
2151 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2152 "fake-corp",
2153 "fake-thinking",
2154 "Fake Thinking",
2155 true,
2156 ));
2157 let thinking_provider = Arc::new(
2158 FakeLanguageModelProvider::new(
2159 LanguageModelProviderId::from("fake-corp".to_string()),
2160 LanguageModelProviderName::from("Fake Corp".to_string()),
2161 )
2162 .with_models(vec![thinking_model]),
2163 );
2164 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2165 registry.register_provider(thinking_provider, cx);
2166 });
2167 });
2168 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2169
2170 let selector = connection.model_selector(&session_id).unwrap();
2171 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2172 .await
2173 .unwrap();
2174 cx.run_until_parked();
2175
2176 // Verify enable_thinking was written to settings as true.
2177 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2178 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2179 assert_eq!(
2180 settings_json["agent"]["default_model"]["enable_thinking"],
2181 json!(true),
2182 "selecting a thinking model should persist enable_thinking: true to settings"
2183 );
2184 }
2185
2186 #[gpui::test]
2187 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2188 init_test(cx);
2189 let fs = FakeFs::new(cx.executor());
2190 fs.create_dir(paths::settings_file().parent().unwrap())
2191 .await
2192 .unwrap();
2193 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2194 let project = Project::test(fs.clone(), [], cx).await;
2195
2196 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2197 let agent = NativeAgent::new(
2198 project.clone(),
2199 thread_store,
2200 Templates::new(),
2201 None,
2202 fs.clone(),
2203 &mut cx.to_async(),
2204 )
2205 .await
2206 .unwrap();
2207 let connection = NativeAgentConnection(agent.clone());
2208
2209 let acp_thread = cx
2210 .update(|cx| {
2211 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2212 })
2213 .await
2214 .unwrap();
2215 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2216
2217 // Register a second provider with a thinking model.
2218 cx.update(|cx| {
2219 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2220 "fake-corp",
2221 "fake-thinking",
2222 "Fake Thinking",
2223 true,
2224 ));
2225 let thinking_provider = Arc::new(
2226 FakeLanguageModelProvider::new(
2227 LanguageModelProviderId::from("fake-corp".to_string()),
2228 LanguageModelProviderName::from("Fake Corp".to_string()),
2229 )
2230 .with_models(vec![thinking_model]),
2231 );
2232 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2233 registry.register_provider(thinking_provider, cx);
2234 });
2235 });
2236 // Refresh the agent's model list so it picks up the new provider.
2237 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2238
2239 // Thread starts with thinking_enabled = false (the default).
2240 agent.read_with(cx, |agent, _| {
2241 let session = agent.sessions.get(&session_id).unwrap();
2242 session.thread.read_with(cx, |thread, _| {
2243 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2244 });
2245 });
2246
2247 // Select the thinking model via select_model.
2248 let selector = connection.model_selector(&session_id).unwrap();
2249 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2250 .await
2251 .unwrap();
2252
2253 // select_model should have enabled thinking based on the model's supports_thinking().
2254 agent.read_with(cx, |agent, _| {
2255 let session = agent.sessions.get(&session_id).unwrap();
2256 session.thread.read_with(cx, |thread, _| {
2257 assert!(
2258 thread.thinking_enabled(),
2259 "select_model should enable thinking when model supports it"
2260 );
2261 });
2262 });
2263
2264 // Switch back to the non-thinking model.
2265 let selector = connection.model_selector(&session_id).unwrap();
2266 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2267 .await
2268 .unwrap();
2269
2270 // select_model should have disabled thinking.
2271 agent.read_with(cx, |agent, _| {
2272 let session = agent.sessions.get(&session_id).unwrap();
2273 session.thread.read_with(cx, |thread, _| {
2274 assert!(
2275 !thread.thinking_enabled(),
2276 "select_model should disable thinking when model does not support it"
2277 );
2278 });
2279 });
2280 }
2281
2282 #[gpui::test]
2283 async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2284 init_test(cx);
2285 let fs = FakeFs::new(cx.executor());
2286 fs.insert_tree("/", json!({ "a": {} })).await;
2287 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2288 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2289 let agent = NativeAgent::new(
2290 project.clone(),
2291 thread_store.clone(),
2292 Templates::new(),
2293 None,
2294 fs.clone(),
2295 &mut cx.to_async(),
2296 )
2297 .await
2298 .unwrap();
2299 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2300
2301 // Register a thinking model.
2302 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2303 "fake-corp",
2304 "fake-thinking",
2305 "Fake Thinking",
2306 true,
2307 ));
2308 let thinking_provider = Arc::new(
2309 FakeLanguageModelProvider::new(
2310 LanguageModelProviderId::from("fake-corp".to_string()),
2311 LanguageModelProviderName::from("Fake Corp".to_string()),
2312 )
2313 .with_models(vec![thinking_model.clone()]),
2314 );
2315 cx.update(|cx| {
2316 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2317 registry.register_provider(thinking_provider, cx);
2318 });
2319 });
2320 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2321
2322 // Create a thread and select the thinking model.
2323 let acp_thread = cx
2324 .update(|cx| {
2325 connection
2326 .clone()
2327 .new_session(project.clone(), Path::new("/a"), cx)
2328 })
2329 .await
2330 .unwrap();
2331 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2332
2333 let selector = connection.model_selector(&session_id).unwrap();
2334 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2335 .await
2336 .unwrap();
2337
2338 // Verify thinking is enabled after selecting the thinking model.
2339 let thread = agent.read_with(cx, |agent, _| {
2340 agent.sessions.get(&session_id).unwrap().thread.clone()
2341 });
2342 thread.read_with(cx, |thread, _| {
2343 assert!(
2344 thread.thinking_enabled(),
2345 "thinking should be enabled after selecting thinking model"
2346 );
2347 });
2348
2349 // Send a message so the thread gets persisted.
2350 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2351 let send = cx.foreground_executor().spawn(send);
2352 cx.run_until_parked();
2353
2354 thinking_model.send_last_completion_stream_text_chunk("Response.");
2355 thinking_model.end_last_completion_stream();
2356
2357 send.await.unwrap();
2358 cx.run_until_parked();
2359
2360 // Close the session so it can be reloaded from disk.
2361 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2362 .await
2363 .unwrap();
2364 drop(thread);
2365 drop(acp_thread);
2366 agent.read_with(cx, |agent, _| {
2367 assert!(agent.sessions.is_empty());
2368 });
2369
2370 // Reload the thread and verify thinking_enabled is still true.
2371 let reloaded_acp_thread = agent
2372 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2373 .await
2374 .unwrap();
2375 let reloaded_thread = agent.read_with(cx, |agent, _| {
2376 agent.sessions.get(&session_id).unwrap().thread.clone()
2377 });
2378 reloaded_thread.read_with(cx, |thread, _| {
2379 assert!(
2380 thread.thinking_enabled(),
2381 "thinking_enabled should be preserved when reloading a thread with a thinking model"
2382 );
2383 });
2384
2385 drop(reloaded_acp_thread);
2386 }
2387
2388 #[gpui::test]
2389 async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2390 init_test(cx);
2391 let fs = FakeFs::new(cx.executor());
2392 fs.insert_tree("/", json!({ "a": {} })).await;
2393 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2394 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2395 let agent = NativeAgent::new(
2396 project.clone(),
2397 thread_store.clone(),
2398 Templates::new(),
2399 None,
2400 fs.clone(),
2401 &mut cx.to_async(),
2402 )
2403 .await
2404 .unwrap();
2405 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2406
2407 // Register a model where id() != name(), like real Anthropic models
2408 // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2409 let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2410 "fake-corp",
2411 "custom-model-id",
2412 "Custom Model Display Name",
2413 false,
2414 ));
2415 let provider = Arc::new(
2416 FakeLanguageModelProvider::new(
2417 LanguageModelProviderId::from("fake-corp".to_string()),
2418 LanguageModelProviderName::from("Fake Corp".to_string()),
2419 )
2420 .with_models(vec![model.clone()]),
2421 );
2422 cx.update(|cx| {
2423 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2424 registry.register_provider(provider, cx);
2425 });
2426 });
2427 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2428
2429 // Create a thread and select the model.
2430 let acp_thread = cx
2431 .update(|cx| {
2432 connection
2433 .clone()
2434 .new_session(project.clone(), Path::new("/a"), cx)
2435 })
2436 .await
2437 .unwrap();
2438 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2439
2440 let selector = connection.model_selector(&session_id).unwrap();
2441 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2442 .await
2443 .unwrap();
2444
2445 let thread = agent.read_with(cx, |agent, _| {
2446 agent.sessions.get(&session_id).unwrap().thread.clone()
2447 });
2448 thread.read_with(cx, |thread, _| {
2449 assert_eq!(
2450 thread.model().unwrap().id().0.as_ref(),
2451 "custom-model-id",
2452 "model should be set before persisting"
2453 );
2454 });
2455
2456 // Send a message so the thread gets persisted.
2457 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2458 let send = cx.foreground_executor().spawn(send);
2459 cx.run_until_parked();
2460
2461 model.send_last_completion_stream_text_chunk("Response.");
2462 model.end_last_completion_stream();
2463
2464 send.await.unwrap();
2465 cx.run_until_parked();
2466
2467 // Close the session so it can be reloaded from disk.
2468 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2469 .await
2470 .unwrap();
2471 drop(thread);
2472 drop(acp_thread);
2473 agent.read_with(cx, |agent, _| {
2474 assert!(agent.sessions.is_empty());
2475 });
2476
2477 // Reload the thread and verify the model was preserved.
2478 let reloaded_acp_thread = agent
2479 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2480 .await
2481 .unwrap();
2482 let reloaded_thread = agent.read_with(cx, |agent, _| {
2483 agent.sessions.get(&session_id).unwrap().thread.clone()
2484 });
2485 reloaded_thread.read_with(cx, |thread, _| {
2486 let reloaded_model = thread
2487 .model()
2488 .expect("model should be present after reload");
2489 assert_eq!(
2490 reloaded_model.id().0.as_ref(),
2491 "custom-model-id",
2492 "reloaded thread should have the same model, not fall back to the default"
2493 );
2494 });
2495
2496 drop(reloaded_acp_thread);
2497 }
2498
    // End-to-end persistence test for a native thread:
    // - a thread with no messages is not saved, even after its models are changed;
    // - after one user/assistant exchange the thread appears in the store under
    //   the title produced by the summarization model;
    // - once the session is closed and the thread is reopened, its messages,
    //   draft prompt, token usage, and UI scroll position are all restored.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Keep a handle to the native thread behind the ACP session so its
        // models and UI state can be set directly below.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // `summary_model` is the model whose output later becomes the thread's
        // title in the store (see the `thread_entries` assertion after reload).
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message that mentions `/a/b.md` through a resource link.
        // The send future is spawned so the fake models can be driven while it
        // is still pending.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream the assistant reply, including a usage update whose numbers
        // are checked again after the reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // Answer the summarization request; NOTE(review): this relies on the
        // summary completion already being pending once the turn has settled.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // Notify observers of the native thread after the direct state edits
        // above; presumably this is what schedules a save that captures them —
        // confirm against the thread-saving observer.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles as well before checking that no
        // session remains open.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The store lists it under the summarization model's output as title.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2685
2686 fn thread_entries(
2687 thread_store: &Entity<ThreadStore>,
2688 cx: &mut TestAppContext,
2689 ) -> Vec<(acp::SessionId, String)> {
2690 thread_store.read_with(cx, |store, _| {
2691 store
2692 .entries()
2693 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2694 .collect::<Vec<_>>()
2695 })
2696 }
2697
2698 fn init_test(cx: &mut TestAppContext) {
2699 env_logger::try_init().ok();
2700 cx.update(|cx| {
2701 let settings_store = SettingsStore::test(cx);
2702 cx.set_global(settings_store);
2703
2704 LanguageModelRegistry::test(cx);
2705 });
2706 }
2707}
2708
2709fn mcp_message_content_to_acp_content_block(
2710 content: context_server::types::MessageContent,
2711) -> acp::ContentBlock {
2712 match content {
2713 context_server::types::MessageContent::Text {
2714 text,
2715 annotations: _,
2716 } => text.into(),
2717 context_server::types::MessageContent::Image {
2718 data,
2719 mime_type,
2720 annotations: _,
2721 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2722 context_server::types::MessageContent::Audio {
2723 data,
2724 mime_type,
2725 annotations: _,
2726 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2727 context_server::types::MessageContent::Resource {
2728 resource,
2729 annotations: _,
2730 } => {
2731 let mut link =
2732 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2733 if let Some(mime_type) = resource.mime_type {
2734 link = link.mime_type(mime_type);
2735 }
2736 acp::ContentBlock::ResourceLink(link)
2737 }
2738 }
2739}