1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use templates::*;
20pub use thread::*;
21pub use thread_store::*;
22pub use tool_permissions::*;
23pub use tools::*;
24
25use acp_thread::{
26 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
27 AgentSessionListResponse, UserMessageId,
28};
29use agent_client_protocol as acp;
30use anyhow::{Context as _, Result, anyhow};
31use chrono::{DateTime, Utc};
32use collections::{HashMap, HashSet, IndexMap};
33use fs::Fs;
34use futures::channel::{mpsc, oneshot};
35use futures::future::Shared;
36use futures::{FutureExt as _, StreamExt as _, future};
37use gpui::{
38 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
39};
40use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
41use project::{Project, ProjectItem, ProjectPath, Worktree};
42use prompt_store::{
43 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
44 WorktreeContext,
45};
46use serde::{Deserialize, Serialize};
47use settings::{LanguageModelSelection, update_settings_file};
48use std::any::Any;
49use std::path::{Path, PathBuf};
50use std::rc::Rc;
51use std::sync::Arc;
52use std::time::Duration;
53use util::ResultExt;
54use util::rel_path::RelPath;
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct ProjectSnapshot {
58 pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
59 pub timestamp: DateTime<Utc>,
60}
61
62pub struct RulesLoadingError {
63 pub message: SharedString,
64}
65
66/// Holds both the internal Thread and the AcpThread for a session
67struct Session {
68 /// The internal thread that processes messages
69 thread: Entity<Thread>,
70 /// The ACP thread that handles protocol communication
71 acp_thread: Entity<acp_thread::AcpThread>,
72 pending_save: Task<()>,
73 _subscriptions: Vec<Subscription>,
74}
75
76pub struct LanguageModels {
77 /// Access language model by ID
78 models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
79 /// Cached list for returning language model information
80 model_list: acp_thread::AgentModelList,
81 refresh_models_rx: watch::Receiver<()>,
82 refresh_models_tx: watch::Sender<()>,
83 _authenticate_all_providers_task: Task<()>,
84}
85
86impl LanguageModels {
87 fn new(cx: &mut App) -> Self {
88 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
89
90 let mut this = Self {
91 models: HashMap::default(),
92 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
93 refresh_models_rx,
94 refresh_models_tx,
95 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
96 };
97 this.refresh_list(cx);
98 this
99 }
100
101 fn refresh_list(&mut self, cx: &App) {
102 let providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .visible_providers()
105 .into_iter()
106 .filter(|provider| provider.is_authenticated(cx))
107 .collect::<Vec<_>>();
108
109 let mut language_model_list = IndexMap::default();
110 let mut recommended_models = HashSet::default();
111
112 let mut recommended = Vec::new();
113 for provider in &providers {
114 for model in provider.recommended_models(cx) {
115 recommended_models.insert((model.provider_id(), model.id()));
116 recommended.push(Self::map_language_model_to_info(&model, provider));
117 }
118 }
119 if !recommended.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName("Recommended".into()),
122 recommended,
123 );
124 }
125
126 let mut models = HashMap::default();
127 for provider in providers {
128 let mut provider_models = Vec::new();
129 for model in provider.provided_models(cx) {
130 let model_info = Self::map_language_model_to_info(&model, &provider);
131 let model_id = model_info.id.clone();
132 provider_models.push(model_info);
133 models.insert(model_id, model);
134 }
135 if !provider_models.is_empty() {
136 language_model_list.insert(
137 acp_thread::AgentModelGroupName(provider.name().0.clone()),
138 provider_models,
139 );
140 }
141 }
142
143 self.models = models;
144 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
145 self.refresh_models_tx.send(()).ok();
146 }
147
148 fn watch(&self) -> watch::Receiver<()> {
149 self.refresh_models_rx.clone()
150 }
151
152 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
153 self.models.get(model_id).cloned()
154 }
155
156 fn map_language_model_to_info(
157 model: &Arc<dyn LanguageModel>,
158 provider: &Arc<dyn LanguageModelProvider>,
159 ) -> acp_thread::AgentModelInfo {
160 acp_thread::AgentModelInfo {
161 id: Self::model_id(model),
162 name: model.name().0,
163 description: None,
164 icon: Some(match provider.icon() {
165 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
166 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
167 }),
168 is_latest: model.is_latest(),
169 }
170 }
171
172 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
173 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
174 }
175
176 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
177 let authenticate_all_providers = LanguageModelRegistry::global(cx)
178 .read(cx)
179 .visible_providers()
180 .iter()
181 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
182 .collect::<Vec<_>>();
183
184 cx.background_spawn(async move {
185 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
186 if let Err(err) = authenticate_task.await {
187 match err {
188 language_model::AuthenticateError::CredentialsNotFound => {
189 // Since we're authenticating these providers in the
190 // background for the purposes of populating the
191 // language selector, we don't care about providers
192 // where the credentials are not found.
193 }
194 language_model::AuthenticateError::ConnectionRefused => {
195 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
196 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
197 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
198 }
199 _ => {
200 // Some providers have noisy failure states that we
201 // don't want to spam the logs with every time the
202 // language model selector is initialized.
203 //
204 // Ideally these should have more clear failure modes
205 // that we know are safe to ignore here, like what we do
206 // with `CredentialsNotFound` above.
207 match provider_id.0.as_ref() {
208 "lmstudio" | "ollama" => {
209 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
210 //
211 // These fail noisily, so we don't log them.
212 }
213 "copilot_chat" => {
214 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
215 }
216 _ => {
217 log::error!(
218 "Failed to authenticate provider: {}: {err:#}",
219 provider_name.0
220 );
221 }
222 }
223 }
224 }
225 }
226 }
227 })
228 }
229}
230
231pub struct NativeAgent {
232 /// Session ID -> Session mapping
233 sessions: HashMap<acp::SessionId, Session>,
234 thread_store: Entity<ThreadStore>,
235 /// Shared project context for all threads
236 project_context: Entity<ProjectContext>,
237 project_context_needs_refresh: watch::Sender<()>,
238 _maintain_project_context: Task<Result<()>>,
239 context_server_registry: Entity<ContextServerRegistry>,
240 /// Shared templates for all threads
241 templates: Arc<Templates>,
242 /// Cached model information
243 models: LanguageModels,
244 project: Entity<Project>,
245 prompt_store: Option<Entity<PromptStore>>,
246 fs: Arc<dyn Fs>,
247 _subscriptions: Vec<Subscription>,
248}
249
250impl NativeAgent {
251 pub async fn new(
252 project: Entity<Project>,
253 thread_store: Entity<ThreadStore>,
254 templates: Arc<Templates>,
255 prompt_store: Option<Entity<PromptStore>>,
256 fs: Arc<dyn Fs>,
257 cx: &mut AsyncApp,
258 ) -> Result<Entity<NativeAgent>> {
259 log::debug!("Creating new NativeAgent");
260
261 let project_context = cx
262 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
263 .await;
264
265 Ok(cx.new(|cx| {
266 let context_server_store = project.read(cx).context_server_store();
267 let context_server_registry =
268 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
269
270 let mut subscriptions = vec![
271 cx.subscribe(&project, Self::handle_project_event),
272 cx.subscribe(
273 &LanguageModelRegistry::global(cx),
274 Self::handle_models_updated_event,
275 ),
276 cx.subscribe(
277 &context_server_store,
278 Self::handle_context_server_store_updated,
279 ),
280 cx.subscribe(
281 &context_server_registry,
282 Self::handle_context_server_registry_event,
283 ),
284 ];
285 if let Some(prompt_store) = prompt_store.as_ref() {
286 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
287 }
288
289 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
290 watch::channel(());
291 Self {
292 sessions: HashMap::default(),
293 thread_store,
294 project_context: cx.new(|_| project_context),
295 project_context_needs_refresh: project_context_needs_refresh_tx,
296 _maintain_project_context: cx.spawn(async move |this, cx| {
297 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
298 }),
299 context_server_registry,
300 templates,
301 models: LanguageModels::new(cx),
302 project,
303 prompt_store,
304 fs,
305 _subscriptions: subscriptions,
306 }
307 }))
308 }
309
310 fn new_session(
311 &mut self,
312 project: Entity<Project>,
313 cx: &mut Context<Self>,
314 ) -> Entity<AcpThread> {
315 // Create Thread
316 // Fetch default model from registry settings
317 let registry = LanguageModelRegistry::read_global(cx);
318 // Log available models for debugging
319 let available_count = registry.available_models(cx).count();
320 log::debug!("Total available models: {}", available_count);
321
322 let default_model = registry.default_model().and_then(|default_model| {
323 self.models
324 .model_from_id(&LanguageModels::model_id(&default_model.model))
325 });
326 let thread = cx.new(|cx| {
327 Thread::new(
328 project.clone(),
329 self.project_context.clone(),
330 self.context_server_registry.clone(),
331 self.templates.clone(),
332 default_model,
333 cx,
334 )
335 });
336
337 self.register_session(thread, None, cx)
338 }
339
340 fn register_session(
341 &mut self,
342 thread_handle: Entity<Thread>,
343 allowed_tool_names: Option<Vec<&str>>,
344 cx: &mut Context<Self>,
345 ) -> Entity<AcpThread> {
346 let connection = Rc::new(NativeAgentConnection(cx.entity()));
347
348 let thread = thread_handle.read(cx);
349 let session_id = thread.id().clone();
350 let parent_session_id = thread.parent_thread_id();
351 let title = thread.title();
352 let project = thread.project.clone();
353 let action_log = thread.action_log.clone();
354 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
355 let acp_thread = cx.new(|cx| {
356 acp_thread::AcpThread::new(
357 parent_session_id,
358 title,
359 connection,
360 project.clone(),
361 action_log.clone(),
362 session_id.clone(),
363 prompt_capabilities_rx,
364 cx,
365 )
366 });
367
368 let registry = LanguageModelRegistry::read_global(cx);
369 let summarization_model = registry.thread_summary_model().map(|c| c.model);
370
371 let weak = cx.weak_entity();
372 thread_handle.update(cx, |thread, cx| {
373 thread.set_summarization_model(summarization_model, cx);
374 thread.add_default_tools(
375 allowed_tool_names,
376 Rc::new(NativeThreadEnvironment {
377 acp_thread: acp_thread.downgrade(),
378 agent: weak,
379 }) as _,
380 cx,
381 )
382 });
383
384 let subscriptions = vec![
385 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
386 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
387 cx.observe(&thread_handle, move |this, thread, cx| {
388 this.save_thread(thread, cx)
389 }),
390 ];
391
392 self.sessions.insert(
393 session_id,
394 Session {
395 thread: thread_handle,
396 acp_thread: acp_thread.clone(),
397 _subscriptions: subscriptions,
398 pending_save: Task::ready(()),
399 },
400 );
401
402 self.update_available_commands(cx);
403
404 acp_thread
405 }
406
407 pub fn models(&self) -> &LanguageModels {
408 &self.models
409 }
410
411 async fn maintain_project_context(
412 this: WeakEntity<Self>,
413 mut needs_refresh: watch::Receiver<()>,
414 cx: &mut AsyncApp,
415 ) -> Result<()> {
416 while needs_refresh.changed().await.is_ok() {
417 let project_context = this
418 .update(cx, |this, cx| {
419 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
420 })?
421 .await;
422 this.update(cx, |this, cx| {
423 this.project_context = cx.new(|_| project_context);
424 })?;
425 }
426
427 Ok(())
428 }
429
430 fn build_project_context(
431 project: &Entity<Project>,
432 prompt_store: Option<&Entity<PromptStore>>,
433 cx: &mut App,
434 ) -> Task<ProjectContext> {
435 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
436 let worktree_tasks = worktrees
437 .into_iter()
438 .map(|worktree| {
439 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
440 })
441 .collect::<Vec<_>>();
442 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
443 prompt_store.read_with(cx, |prompt_store, cx| {
444 let prompts = prompt_store.default_prompt_metadata();
445 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
446 let contents = prompt_store.load(prompt_metadata.id, cx);
447 async move { (contents.await, prompt_metadata) }
448 });
449 cx.background_spawn(future::join_all(load_tasks))
450 })
451 } else {
452 Task::ready(vec![])
453 };
454
455 cx.spawn(async move |_cx| {
456 let (worktrees, default_user_rules) =
457 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
458
459 let worktrees = worktrees
460 .into_iter()
461 .map(|(worktree, _rules_error)| {
462 // TODO: show error message
463 // if let Some(rules_error) = rules_error {
464 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
465 // }
466 worktree
467 })
468 .collect::<Vec<_>>();
469
470 let default_user_rules = default_user_rules
471 .into_iter()
472 .flat_map(|(contents, prompt_metadata)| match contents {
473 Ok(contents) => Some(UserRulesContext {
474 uuid: prompt_metadata.id.as_user()?,
475 title: prompt_metadata.title.map(|title| title.to_string()),
476 contents,
477 }),
478 Err(_err) => {
479 // TODO: show error message
480 // this.update(cx, |_, cx| {
481 // cx.emit(RulesLoadingError {
482 // message: format!("{err:?}").into(),
483 // });
484 // })
485 // .ok();
486 None
487 }
488 })
489 .collect::<Vec<_>>();
490
491 ProjectContext::new(worktrees, default_user_rules)
492 })
493 }
494
495 fn load_worktree_info_for_system_prompt(
496 worktree: Entity<Worktree>,
497 project: Entity<Project>,
498 cx: &mut App,
499 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
500 let tree = worktree.read(cx);
501 let root_name = tree.root_name_str().into();
502 let abs_path = tree.abs_path();
503
504 let mut context = WorktreeContext {
505 root_name,
506 abs_path,
507 rules_file: None,
508 };
509
510 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
511 let Some(rules_task) = rules_task else {
512 return Task::ready((context, None));
513 };
514
515 cx.spawn(async move |_| {
516 let (rules_file, rules_file_error) = match rules_task.await {
517 Ok(rules_file) => (Some(rules_file), None),
518 Err(err) => (
519 None,
520 Some(RulesLoadingError {
521 message: format!("{err}").into(),
522 }),
523 ),
524 };
525 context.rules_file = rules_file;
526 (context, rules_file_error)
527 })
528 }
529
530 fn load_worktree_rules_file(
531 worktree: Entity<Worktree>,
532 project: Entity<Project>,
533 cx: &mut App,
534 ) -> Option<Task<Result<RulesFileContext>>> {
535 let worktree = worktree.read(cx);
536 let worktree_id = worktree.id();
537 let selected_rules_file = RULES_FILE_NAMES
538 .into_iter()
539 .filter_map(|name| {
540 worktree
541 .entry_for_path(RelPath::unix(name).unwrap())
542 .filter(|entry| entry.is_file())
543 .map(|entry| entry.path.clone())
544 })
545 .next();
546
547 // Note that Cline supports `.clinerules` being a directory, but that is not currently
548 // supported. This doesn't seem to occur often in GitHub repositories.
549 selected_rules_file.map(|path_in_worktree| {
550 let project_path = ProjectPath {
551 worktree_id,
552 path: path_in_worktree.clone(),
553 };
554 let buffer_task =
555 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
556 let rope_task = cx.spawn(async move |cx| {
557 let buffer = buffer_task.await?;
558 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
559 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
560 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
561 })?;
562 anyhow::Ok((project_entry_id, rope))
563 });
564 // Build a string from the rope on a background thread.
565 cx.background_spawn(async move {
566 let (project_entry_id, rope) = rope_task.await?;
567 anyhow::Ok(RulesFileContext {
568 path_in_worktree,
569 text: rope.to_string().trim().to_string(),
570 project_entry_id: project_entry_id.to_usize(),
571 })
572 })
573 })
574 }
575
576 fn handle_thread_title_updated(
577 &mut self,
578 thread: Entity<Thread>,
579 _: &TitleUpdated,
580 cx: &mut Context<Self>,
581 ) {
582 let session_id = thread.read(cx).id();
583 let Some(session) = self.sessions.get(session_id) else {
584 return;
585 };
586 let thread = thread.downgrade();
587 let acp_thread = session.acp_thread.downgrade();
588 cx.spawn(async move |_, cx| {
589 let title = thread.read_with(cx, |thread, _| thread.title())?;
590 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
591 task.await
592 })
593 .detach_and_log_err(cx);
594 }
595
596 fn handle_thread_token_usage_updated(
597 &mut self,
598 thread: Entity<Thread>,
599 usage: &TokenUsageUpdated,
600 cx: &mut Context<Self>,
601 ) {
602 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
603 return;
604 };
605 session.acp_thread.update(cx, |acp_thread, cx| {
606 acp_thread.update_token_usage(usage.0.clone(), cx);
607 });
608 }
609
610 fn handle_project_event(
611 &mut self,
612 _project: Entity<Project>,
613 event: &project::Event,
614 _cx: &mut Context<Self>,
615 ) {
616 match event {
617 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
618 self.project_context_needs_refresh.send(()).ok();
619 }
620 project::Event::WorktreeUpdatedEntries(_, items) => {
621 if items.iter().any(|(path, _, _)| {
622 RULES_FILE_NAMES
623 .iter()
624 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
625 }) {
626 self.project_context_needs_refresh.send(()).ok();
627 }
628 }
629 _ => {}
630 }
631 }
632
633 fn handle_prompts_updated_event(
634 &mut self,
635 _prompt_store: Entity<PromptStore>,
636 _event: &prompt_store::PromptsUpdatedEvent,
637 _cx: &mut Context<Self>,
638 ) {
639 self.project_context_needs_refresh.send(()).ok();
640 }
641
642 fn handle_models_updated_event(
643 &mut self,
644 _registry: Entity<LanguageModelRegistry>,
645 _event: &language_model::Event,
646 cx: &mut Context<Self>,
647 ) {
648 self.models.refresh_list(cx);
649
650 let registry = LanguageModelRegistry::read_global(cx);
651 let default_model = registry.default_model().map(|m| m.model);
652 let summarization_model = registry.thread_summary_model().map(|m| m.model);
653
654 for session in self.sessions.values_mut() {
655 session.thread.update(cx, |thread, cx| {
656 if thread.model().is_none()
657 && let Some(model) = default_model.clone()
658 {
659 thread.set_model(model, cx);
660 cx.notify();
661 }
662 thread.set_summarization_model(summarization_model.clone(), cx);
663 });
664 }
665 }
666
667 fn handle_context_server_store_updated(
668 &mut self,
669 _store: Entity<project::context_server_store::ContextServerStore>,
670 _event: &project::context_server_store::ServerStatusChangedEvent,
671 cx: &mut Context<Self>,
672 ) {
673 self.update_available_commands(cx);
674 }
675
676 fn handle_context_server_registry_event(
677 &mut self,
678 _registry: Entity<ContextServerRegistry>,
679 event: &ContextServerRegistryEvent,
680 cx: &mut Context<Self>,
681 ) {
682 match event {
683 ContextServerRegistryEvent::ToolsChanged => {}
684 ContextServerRegistryEvent::PromptsChanged => {
685 self.update_available_commands(cx);
686 }
687 }
688 }
689
690 fn update_available_commands(&self, cx: &mut Context<Self>) {
691 let available_commands = self.build_available_commands(cx);
692 for session in self.sessions.values() {
693 session.acp_thread.update(cx, |thread, cx| {
694 thread
695 .handle_session_update(
696 acp::SessionUpdate::AvailableCommandsUpdate(
697 acp::AvailableCommandsUpdate::new(available_commands.clone()),
698 ),
699 cx,
700 )
701 .log_err();
702 });
703 }
704 }
705
706 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
707 let registry = self.context_server_registry.read(cx);
708
709 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
710 for context_server_prompt in registry.prompts() {
711 *prompt_name_counts
712 .entry(context_server_prompt.prompt.name.as_str())
713 .or_insert(0) += 1;
714 }
715
716 registry
717 .prompts()
718 .flat_map(|context_server_prompt| {
719 let prompt = &context_server_prompt.prompt;
720
721 let should_prefix = prompt_name_counts
722 .get(prompt.name.as_str())
723 .copied()
724 .unwrap_or(0)
725 > 1;
726
727 let name = if should_prefix {
728 format!("{}.{}", context_server_prompt.server_id, prompt.name)
729 } else {
730 prompt.name.clone()
731 };
732
733 let mut command = acp::AvailableCommand::new(
734 name,
735 prompt.description.clone().unwrap_or_default(),
736 );
737
738 match prompt.arguments.as_deref() {
739 Some([arg]) => {
740 let hint = format!("<{}>", arg.name);
741
742 command = command.input(acp::AvailableCommandInput::Unstructured(
743 acp::UnstructuredCommandInput::new(hint),
744 ));
745 }
746 Some([]) | None => {}
747 Some(_) => {
748 // skip >1 argument commands since we don't support them yet
749 return None;
750 }
751 }
752
753 Some(command)
754 })
755 .collect()
756 }
757
758 pub fn load_thread(
759 &mut self,
760 id: acp::SessionId,
761 cx: &mut Context<Self>,
762 ) -> Task<Result<Entity<Thread>>> {
763 let database_future = ThreadsDatabase::connect(cx);
764 cx.spawn(async move |this, cx| {
765 let database = database_future.await.map_err(|err| anyhow!(err))?;
766 let db_thread = database
767 .load_thread(id.clone())
768 .await?
769 .with_context(|| format!("no thread found with ID: {id:?}"))?;
770
771 this.update(cx, |this, cx| {
772 let summarization_model = LanguageModelRegistry::read_global(cx)
773 .thread_summary_model()
774 .map(|c| c.model);
775
776 cx.new(|cx| {
777 let mut thread = Thread::from_db(
778 id.clone(),
779 db_thread,
780 this.project.clone(),
781 this.project_context.clone(),
782 this.context_server_registry.clone(),
783 this.templates.clone(),
784 cx,
785 );
786 thread.set_summarization_model(summarization_model, cx);
787 thread
788 })
789 })
790 })
791 }
792
793 pub fn open_thread(
794 &mut self,
795 id: acp::SessionId,
796 cx: &mut Context<Self>,
797 ) -> Task<Result<Entity<AcpThread>>> {
798 if let Some(session) = self.sessions.get(&id) {
799 return Task::ready(Ok(session.acp_thread.clone()));
800 }
801
802 let task = self.load_thread(id, cx);
803 cx.spawn(async move |this, cx| {
804 let thread = task.await?;
805 let acp_thread = this.update(cx, |this, cx| {
806 this.register_session(thread.clone(), None, cx)
807 })?;
808 let events = thread.update(cx, |thread, cx| thread.replay(cx));
809 cx.update(|cx| {
810 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
811 })
812 .await?;
813 Ok(acp_thread)
814 })
815 }
816
817 pub fn thread_summary(
818 &mut self,
819 id: acp::SessionId,
820 cx: &mut Context<Self>,
821 ) -> Task<Result<SharedString>> {
822 let thread = self.open_thread(id.clone(), cx);
823 cx.spawn(async move |this, cx| {
824 let acp_thread = thread.await?;
825 let result = this
826 .update(cx, |this, cx| {
827 this.sessions
828 .get(&id)
829 .unwrap()
830 .thread
831 .update(cx, |thread, cx| thread.summary(cx))
832 })?
833 .await
834 .context("Failed to generate summary")?;
835 drop(acp_thread);
836 Ok(result)
837 })
838 }
839
840 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
841 if thread.read(cx).is_empty() {
842 return;
843 }
844
845 let database_future = ThreadsDatabase::connect(cx);
846 let (id, db_thread) =
847 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
848 let Some(session) = self.sessions.get_mut(&id) else {
849 return;
850 };
851 let thread_store = self.thread_store.clone();
852 session.pending_save = cx.spawn(async move |_, cx| {
853 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
854 return;
855 };
856 let db_thread = db_thread.await;
857 database.save_thread(id, db_thread).await.log_err();
858 thread_store.update(cx, |store, cx| store.reload(cx));
859 });
860 }
861
862 fn send_mcp_prompt(
863 &self,
864 message_id: UserMessageId,
865 session_id: agent_client_protocol::SessionId,
866 prompt_name: String,
867 server_id: ContextServerId,
868 arguments: HashMap<String, String>,
869 original_content: Vec<acp::ContentBlock>,
870 cx: &mut Context<Self>,
871 ) -> Task<Result<acp::PromptResponse>> {
872 let server_store = self.context_server_registry.read(cx).server_store().clone();
873 let path_style = self.project.read(cx).path_style(cx);
874
875 cx.spawn(async move |this, cx| {
876 let prompt =
877 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
878
879 let (acp_thread, thread) = this.update(cx, |this, _cx| {
880 let session = this
881 .sessions
882 .get(&session_id)
883 .context("Failed to get session")?;
884 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
885 })??;
886
887 let mut last_is_user = true;
888
889 thread.update(cx, |thread, cx| {
890 thread.push_acp_user_block(
891 message_id,
892 original_content.into_iter().skip(1),
893 path_style,
894 cx,
895 );
896 });
897
898 for message in prompt.messages {
899 let context_server::types::PromptMessage { role, content } = message;
900 let block = mcp_message_content_to_acp_content_block(content);
901
902 match role {
903 context_server::types::Role::User => {
904 let id = acp_thread::UserMessageId::new();
905
906 acp_thread.update(cx, |acp_thread, cx| {
907 acp_thread.push_user_content_block_with_indent(
908 Some(id.clone()),
909 block.clone(),
910 true,
911 cx,
912 );
913 });
914
915 thread.update(cx, |thread, cx| {
916 thread.push_acp_user_block(id, [block], path_style, cx);
917 });
918 }
919 context_server::types::Role::Assistant => {
920 acp_thread.update(cx, |acp_thread, cx| {
921 acp_thread.push_assistant_content_block_with_indent(
922 block.clone(),
923 false,
924 true,
925 cx,
926 );
927 });
928
929 thread.update(cx, |thread, cx| {
930 thread.push_acp_agent_block(block, cx);
931 });
932 }
933 }
934
935 last_is_user = role == context_server::types::Role::User;
936 }
937
938 let response_stream = thread.update(cx, |thread, cx| {
939 if last_is_user {
940 thread.send_existing(cx)
941 } else {
942 // Resume if MCP prompt did not end with a user message
943 thread.resume(cx)
944 }
945 })?;
946
947 cx.update(|cx| {
948 NativeAgentConnection::handle_thread_events(
949 response_stream,
950 acp_thread.downgrade(),
951 cx,
952 )
953 })
954 .await
955 })
956 }
957}
958
959/// Wrapper struct that implements the AgentConnection trait
960#[derive(Clone)]
961pub struct NativeAgentConnection(pub Entity<NativeAgent>);
962
963impl NativeAgentConnection {
964 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
965 self.0
966 .read(cx)
967 .sessions
968 .get(session_id)
969 .map(|session| session.thread.clone())
970 }
971
972 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
973 self.0.update(cx, |this, cx| this.load_thread(id, cx))
974 }
975
976 fn run_turn(
977 &self,
978 session_id: acp::SessionId,
979 cx: &mut App,
980 f: impl 'static
981 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
982 ) -> Task<Result<acp::PromptResponse>> {
983 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
984 agent
985 .sessions
986 .get_mut(&session_id)
987 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
988 }) else {
989 return Task::ready(Err(anyhow!("Session not found")));
990 };
991 log::debug!("Found session for: {}", session_id);
992
993 let response_stream = match f(thread, cx) {
994 Ok(stream) => stream,
995 Err(err) => return Task::ready(Err(err)),
996 };
997 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
998 }
999
1000 fn handle_thread_events(
1001 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1002 acp_thread: WeakEntity<AcpThread>,
1003 cx: &App,
1004 ) -> Task<Result<acp::PromptResponse>> {
1005 cx.spawn(async move |cx| {
1006 // Handle response stream and forward to session.acp_thread
1007 while let Some(result) = events.next().await {
1008 match result {
1009 Ok(event) => {
1010 log::trace!("Received completion event: {:?}", event);
1011
1012 match event {
1013 ThreadEvent::UserMessage(message) => {
1014 acp_thread.update(cx, |thread, cx| {
1015 for content in message.content {
1016 thread.push_user_content_block(
1017 Some(message.id.clone()),
1018 content.into(),
1019 cx,
1020 );
1021 }
1022 })?;
1023 }
1024 ThreadEvent::AgentText(text) => {
1025 acp_thread.update(cx, |thread, cx| {
1026 thread.push_assistant_content_block(text.into(), false, cx)
1027 })?;
1028 }
1029 ThreadEvent::AgentThinking(text) => {
1030 acp_thread.update(cx, |thread, cx| {
1031 thread.push_assistant_content_block(text.into(), true, cx)
1032 })?;
1033 }
1034 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1035 tool_call,
1036 options,
1037 response,
1038 context: _,
1039 }) => {
1040 let outcome_task = acp_thread.update(cx, |thread, cx| {
1041 thread.request_tool_call_authorization(
1042 tool_call, options, true, cx,
1043 )
1044 })??;
1045 cx.background_spawn(async move {
1046 if let acp::RequestPermissionOutcome::Selected(
1047 acp::SelectedPermissionOutcome { option_id, .. },
1048 ) = outcome_task.await
1049 {
1050 response
1051 .send(option_id)
1052 .map(|_| anyhow!("authorization receiver was dropped"))
1053 .log_err();
1054 }
1055 })
1056 .detach();
1057 }
1058 ThreadEvent::ToolCall(tool_call) => {
1059 acp_thread.update(cx, |thread, cx| {
1060 thread.upsert_tool_call(tool_call, cx)
1061 })??;
1062 }
1063 ThreadEvent::ToolCallUpdate(update) => {
1064 acp_thread.update(cx, |thread, cx| {
1065 thread.update_tool_call(update, cx)
1066 })??;
1067 }
1068 ThreadEvent::SubagentSpawned(session_id) => {
1069 acp_thread.update(cx, |thread, cx| {
1070 thread.subagent_spawned(session_id, cx);
1071 })?;
1072 }
1073 ThreadEvent::Retry(status) => {
1074 acp_thread.update(cx, |thread, cx| {
1075 thread.update_retry_status(status, cx)
1076 })?;
1077 }
1078 ThreadEvent::Stop(stop_reason) => {
1079 log::debug!("Assistant message complete: {:?}", stop_reason);
1080 return Ok(acp::PromptResponse::new(stop_reason));
1081 }
1082 }
1083 }
1084 Err(e) => {
1085 log::error!("Error in model response stream: {:?}", e);
1086 return Err(e);
1087 }
1088 }
1089 }
1090
1091 log::debug!("Response stream completed");
1092 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1093 })
1094 }
1095}
1096
1097struct Command<'a> {
1098 prompt_name: &'a str,
1099 arg_value: &'a str,
1100 explicit_server_id: Option<&'a str>,
1101}
1102
1103impl<'a> Command<'a> {
1104 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1105 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1106 return None;
1107 };
1108 let text = text_content.text.trim();
1109 let command = text.strip_prefix('/')?;
1110 let (command, arg_value) = command
1111 .split_once(char::is_whitespace)
1112 .unwrap_or((command, ""));
1113
1114 if let Some((server_id, prompt_name)) = command.split_once('.') {
1115 Some(Self {
1116 prompt_name,
1117 arg_value,
1118 explicit_server_id: Some(server_id),
1119 })
1120 } else {
1121 Some(Self {
1122 prompt_name: command,
1123 arg_value,
1124 explicit_server_id: None,
1125 })
1126 }
1127 }
1128}
1129
1130struct NativeAgentModelSelector {
1131 session_id: acp::SessionId,
1132 connection: NativeAgentConnection,
1133}
1134
1135impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1136 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1137 log::debug!("NativeAgentConnection::list_models called");
1138 let list = self.connection.0.read(cx).models.model_list.clone();
1139 Task::ready(if list.is_empty() {
1140 Err(anyhow::anyhow!("No models available"))
1141 } else {
1142 Ok(list)
1143 })
1144 }
1145
1146 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1147 log::debug!(
1148 "Setting model for session {}: {}",
1149 self.session_id,
1150 model_id
1151 );
1152 let Some(thread) = self
1153 .connection
1154 .0
1155 .read(cx)
1156 .sessions
1157 .get(&self.session_id)
1158 .map(|session| session.thread.clone())
1159 else {
1160 return Task::ready(Err(anyhow!("Session not found")));
1161 };
1162
1163 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1164 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1165 };
1166
1167 // We want to reset the effort level when switching models, as the currently-selected effort level may
1168 // not be compatible.
1169 let effort = model
1170 .default_effort_level()
1171 .map(|effort_level| effort_level.value.to_string());
1172
1173 thread.update(cx, |thread, cx| {
1174 thread.set_model(model.clone(), cx);
1175 thread.set_thinking_effort(effort.clone(), cx);
1176 thread.set_thinking_enabled(model.supports_thinking(), cx);
1177 });
1178
1179 update_settings_file(
1180 self.connection.0.read(cx).fs.clone(),
1181 cx,
1182 move |settings, cx| {
1183 let provider = model.provider_id().0.to_string();
1184 let model = model.id().0.to_string();
1185 let enable_thinking = thread.read(cx).thinking_enabled();
1186 settings
1187 .agent
1188 .get_or_insert_default()
1189 .set_model(LanguageModelSelection {
1190 provider: provider.into(),
1191 model,
1192 enable_thinking,
1193 effort,
1194 });
1195 },
1196 );
1197
1198 Task::ready(Ok(()))
1199 }
1200
1201 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1202 let Some(thread) = self
1203 .connection
1204 .0
1205 .read(cx)
1206 .sessions
1207 .get(&self.session_id)
1208 .map(|session| session.thread.clone())
1209 else {
1210 return Task::ready(Err(anyhow!("Session not found")));
1211 };
1212 let Some(model) = thread.read(cx).model() else {
1213 return Task::ready(Err(anyhow!("Model not found")));
1214 };
1215 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1216 else {
1217 return Task::ready(Err(anyhow!("Provider not found")));
1218 };
1219 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1220 model, &provider,
1221 )))
1222 }
1223
1224 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1225 Some(self.connection.0.read(cx).models.watch())
1226 }
1227
1228 fn should_render_footer(&self) -> bool {
1229 true
1230 }
1231}
1232
1233impl acp_thread::AgentConnection for NativeAgentConnection {
1234 fn telemetry_id(&self) -> SharedString {
1235 "zed".into()
1236 }
1237
1238 fn new_session(
1239 self: Rc<Self>,
1240 project: Entity<Project>,
1241 cwd: &Path,
1242 cx: &mut App,
1243 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1244 log::debug!("Creating new thread for project at: {cwd:?}");
1245 Task::ready(Ok(self
1246 .0
1247 .update(cx, |agent, cx| agent.new_session(project, cx))))
1248 }
1249
1250 fn supports_load_session(&self, _cx: &App) -> bool {
1251 true
1252 }
1253
1254 fn load_session(
1255 self: Rc<Self>,
1256 session: AgentSessionInfo,
1257 _project: Entity<Project>,
1258 _cwd: &Path,
1259 cx: &mut App,
1260 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1261 self.0
1262 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1263 }
1264
1265 fn supports_close_session(&self, _cx: &App) -> bool {
1266 true
1267 }
1268
1269 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1270 self.0.update(cx, |agent, _cx| {
1271 agent.sessions.remove(session_id);
1272 });
1273 Task::ready(Ok(()))
1274 }
1275
1276 fn auth_methods(&self) -> &[acp::AuthMethod] {
1277 &[] // No auth for in-process
1278 }
1279
1280 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1281 Task::ready(Ok(()))
1282 }
1283
1284 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1285 Some(Rc::new(NativeAgentModelSelector {
1286 session_id: session_id.clone(),
1287 connection: self.clone(),
1288 }) as Rc<dyn AgentModelSelector>)
1289 }
1290
1291 fn prompt(
1292 &self,
1293 id: Option<acp_thread::UserMessageId>,
1294 params: acp::PromptRequest,
1295 cx: &mut App,
1296 ) -> Task<Result<acp::PromptResponse>> {
1297 let id = id.expect("UserMessageId is required");
1298 let session_id = params.session_id.clone();
1299 log::info!("Received prompt request for session: {}", session_id);
1300 log::debug!("Prompt blocks count: {}", params.prompt.len());
1301
1302 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1303 let registry = self.0.read(cx).context_server_registry.read(cx);
1304
1305 let explicit_server_id = parsed_command
1306 .explicit_server_id
1307 .map(|server_id| ContextServerId(server_id.into()));
1308
1309 if let Some(prompt) =
1310 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1311 {
1312 let arguments = if !parsed_command.arg_value.is_empty()
1313 && let Some(arg_name) = prompt
1314 .prompt
1315 .arguments
1316 .as_ref()
1317 .and_then(|args| args.first())
1318 .map(|arg| arg.name.clone())
1319 {
1320 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1321 } else {
1322 Default::default()
1323 };
1324
1325 let prompt_name = prompt.prompt.name.clone();
1326 let server_id = prompt.server_id.clone();
1327
1328 return self.0.update(cx, |agent, cx| {
1329 agent.send_mcp_prompt(
1330 id,
1331 session_id.clone(),
1332 prompt_name,
1333 server_id,
1334 arguments,
1335 params.prompt,
1336 cx,
1337 )
1338 });
1339 };
1340 };
1341
1342 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1343
1344 self.run_turn(session_id, cx, move |thread, cx| {
1345 let content: Vec<UserMessageContent> = params
1346 .prompt
1347 .into_iter()
1348 .map(|block| UserMessageContent::from_content_block(block, path_style))
1349 .collect::<Vec<_>>();
1350 log::debug!("Converted prompt to message: {} chars", content.len());
1351 log::debug!("Message id: {:?}", id);
1352 log::debug!("Message content: {:?}", content);
1353
1354 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1355 })
1356 }
1357
1358 fn retry(
1359 &self,
1360 session_id: &acp::SessionId,
1361 _cx: &App,
1362 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1363 Some(Rc::new(NativeAgentSessionRetry {
1364 connection: self.clone(),
1365 session_id: session_id.clone(),
1366 }) as _)
1367 }
1368
1369 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1370 log::info!("Cancelling on session: {}", session_id);
1371 self.0.update(cx, |agent, cx| {
1372 if let Some(agent) = agent.sessions.get(session_id) {
1373 agent
1374 .thread
1375 .update(cx, |thread, cx| thread.cancel(cx))
1376 .detach();
1377 }
1378 });
1379 }
1380
1381 fn truncate(
1382 &self,
1383 session_id: &agent_client_protocol::SessionId,
1384 cx: &App,
1385 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1386 self.0.read_with(cx, |agent, _cx| {
1387 agent.sessions.get(session_id).map(|session| {
1388 Rc::new(NativeAgentSessionTruncate {
1389 thread: session.thread.clone(),
1390 acp_thread: session.acp_thread.downgrade(),
1391 }) as _
1392 })
1393 })
1394 }
1395
1396 fn set_title(
1397 &self,
1398 session_id: &acp::SessionId,
1399 _cx: &App,
1400 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1401 Some(Rc::new(NativeAgentSessionSetTitle {
1402 connection: self.clone(),
1403 session_id: session_id.clone(),
1404 }) as _)
1405 }
1406
1407 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1408 let thread_store = self.0.read(cx).thread_store.clone();
1409 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1410 }
1411
1412 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1413 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1414 }
1415
1416 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1417 self
1418 }
1419}
1420
1421impl acp_thread::AgentTelemetry for NativeAgentConnection {
1422 fn thread_data(
1423 &self,
1424 session_id: &acp::SessionId,
1425 cx: &mut App,
1426 ) -> Task<Result<serde_json::Value>> {
1427 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1428 return Task::ready(Err(anyhow!("Session not found")));
1429 };
1430
1431 let task = session.thread.read(cx).to_db(cx);
1432 cx.background_spawn(async move {
1433 serde_json::to_value(task.await).context("Failed to serialize thread")
1434 })
1435 }
1436}
1437
1438pub struct NativeAgentSessionList {
1439 thread_store: Entity<ThreadStore>,
1440 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1441 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1442 _subscription: Subscription,
1443}
1444
1445impl NativeAgentSessionList {
1446 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1447 let (tx, rx) = smol::channel::unbounded();
1448 let this_tx = tx.clone();
1449 let subscription = cx.observe(&thread_store, move |_, _| {
1450 this_tx
1451 .try_send(acp_thread::SessionListUpdate::Refresh)
1452 .ok();
1453 });
1454 Self {
1455 thread_store,
1456 updates_tx: tx,
1457 updates_rx: rx,
1458 _subscription: subscription,
1459 }
1460 }
1461
1462 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1463 AgentSessionInfo {
1464 session_id: entry.id,
1465 cwd: None,
1466 title: Some(entry.title),
1467 updated_at: Some(entry.updated_at),
1468 meta: None,
1469 }
1470 }
1471
1472 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1473 &self.thread_store
1474 }
1475}
1476
1477impl AgentSessionList for NativeAgentSessionList {
1478 fn list_sessions(
1479 &self,
1480 _request: AgentSessionListRequest,
1481 cx: &mut App,
1482 ) -> Task<Result<AgentSessionListResponse>> {
1483 let sessions = self
1484 .thread_store
1485 .read(cx)
1486 .entries()
1487 .map(Self::to_session_info)
1488 .collect();
1489 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1490 }
1491
1492 fn supports_delete(&self) -> bool {
1493 true
1494 }
1495
1496 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1497 self.thread_store
1498 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1499 }
1500
1501 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1502 self.thread_store
1503 .update(cx, |store, cx| store.delete_threads(cx))
1504 }
1505
1506 fn watch(
1507 &self,
1508 _cx: &mut App,
1509 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1510 Some(self.updates_rx.clone())
1511 }
1512
1513 fn notify_refresh(&self) {
1514 self.updates_tx
1515 .try_send(acp_thread::SessionListUpdate::Refresh)
1516 .ok();
1517 }
1518
1519 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1520 self
1521 }
1522}
1523
1524struct NativeAgentSessionTruncate {
1525 thread: Entity<Thread>,
1526 acp_thread: WeakEntity<AcpThread>,
1527}
1528
1529impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1530 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1531 match self.thread.update(cx, |thread, cx| {
1532 thread.truncate(message_id.clone(), cx)?;
1533 Ok(thread.latest_token_usage())
1534 }) {
1535 Ok(usage) => {
1536 self.acp_thread
1537 .update(cx, |thread, cx| {
1538 thread.update_token_usage(usage, cx);
1539 })
1540 .ok();
1541 Task::ready(Ok(()))
1542 }
1543 Err(error) => Task::ready(Err(error)),
1544 }
1545 }
1546}
1547
1548struct NativeAgentSessionRetry {
1549 connection: NativeAgentConnection,
1550 session_id: acp::SessionId,
1551}
1552
1553impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1554 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1555 self.connection
1556 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1557 thread.update(cx, |thread, cx| thread.resume(cx))
1558 })
1559 }
1560}
1561
1562struct NativeAgentSessionSetTitle {
1563 connection: NativeAgentConnection,
1564 session_id: acp::SessionId,
1565}
1566
1567impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1568 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1569 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1570 return Task::ready(Err(anyhow!("session not found")));
1571 };
1572 let thread = session.thread.clone();
1573 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1574 Task::ready(Ok(()))
1575 }
1576}
1577
1578pub struct NativeThreadEnvironment {
1579 agent: WeakEntity<NativeAgent>,
1580 acp_thread: WeakEntity<AcpThread>,
1581}
1582
1583impl NativeThreadEnvironment {
1584 pub(crate) fn create_subagent_thread(
1585 agent: WeakEntity<NativeAgent>,
1586 parent_thread_entity: Entity<Thread>,
1587 label: String,
1588 initial_prompt: String,
1589 timeout: Option<Duration>,
1590 allowed_tools: Option<Vec<String>>,
1591 cx: &mut App,
1592 ) -> Result<Rc<dyn SubagentHandle>> {
1593 let parent_thread = parent_thread_entity.read(cx);
1594 let current_depth = parent_thread.depth();
1595
1596 if current_depth >= MAX_SUBAGENT_DEPTH {
1597 return Err(anyhow!(
1598 "Maximum subagent depth ({}) reached",
1599 MAX_SUBAGENT_DEPTH
1600 ));
1601 }
1602
1603 let running_count = parent_thread.running_subagent_count();
1604 if running_count >= MAX_PARALLEL_SUBAGENTS {
1605 return Err(anyhow!(
1606 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
1607 MAX_PARALLEL_SUBAGENTS
1608 ));
1609 }
1610
1611 let allowed_tools = match allowed_tools {
1612 Some(tools) => {
1613 let parent_tool_names: std::collections::HashSet<&str> =
1614 parent_thread.tools.keys().map(|s| s.as_str()).collect();
1615 Some(
1616 tools
1617 .into_iter()
1618 .filter(|t| parent_tool_names.contains(t.as_str()))
1619 .collect::<Vec<_>>(),
1620 )
1621 }
1622 None => Some(parent_thread.tools.keys().map(|s| s.to_string()).collect()),
1623 };
1624
1625 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1626 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1627 thread.set_title(label.into(), cx);
1628 thread
1629 });
1630
1631 let session_id = subagent_thread.read(cx).id().clone();
1632
1633 let acp_thread = agent.update(cx, |agent, cx| {
1634 agent.register_session(
1635 subagent_thread.clone(),
1636 allowed_tools
1637 .as_ref()
1638 .map(|v| v.iter().map(|s| s.as_str()).collect()),
1639 cx,
1640 )
1641 })?;
1642
1643 parent_thread_entity.update(cx, |parent_thread, _cx| {
1644 parent_thread.register_running_subagent(subagent_thread.downgrade())
1645 });
1646
1647 let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));
1648
1649 let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1650 let wait_for_prompt_to_complete = cx
1651 .background_spawn(async move {
1652 if let Some(timer) = timeout_timer {
1653 futures::select! {
1654 _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1655 _ = task.fuse() => SubagentInitialPromptResult::Completed,
1656 }
1657 } else {
1658 task.await.log_err();
1659 SubagentInitialPromptResult::Completed
1660 }
1661 })
1662 .shared();
1663
1664 let mut user_stop_rx: watch::Receiver<bool> =
1665 acp_thread.update(cx, |thread, _| thread.user_stop_receiver());
1666
1667 let user_cancelled = cx
1668 .background_spawn(async move {
1669 loop {
1670 if *user_stop_rx.borrow() {
1671 return;
1672 }
1673 if user_stop_rx.changed().await.is_err() {
1674 std::future::pending::<()>().await;
1675 }
1676 }
1677 })
1678 .shared();
1679
1680 Ok(Rc::new(NativeSubagentHandle {
1681 session_id,
1682 subagent_thread,
1683 parent_thread: parent_thread_entity.downgrade(),
1684 acp_thread,
1685 wait_for_prompt_to_complete,
1686 user_cancelled,
1687 }) as _)
1688 }
1689}
1690
1691impl ThreadEnvironment for NativeThreadEnvironment {
1692 fn create_terminal(
1693 &self,
1694 command: String,
1695 cwd: Option<PathBuf>,
1696 output_byte_limit: Option<u64>,
1697 cx: &mut AsyncApp,
1698 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1699 let task = self.acp_thread.update(cx, |thread, cx| {
1700 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1701 });
1702
1703 let acp_thread = self.acp_thread.clone();
1704 cx.spawn(async move |cx| {
1705 let terminal = task?.await?;
1706
1707 let (drop_tx, drop_rx) = oneshot::channel();
1708 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1709
1710 cx.spawn(async move |cx| {
1711 drop_rx.await.ok();
1712 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1713 })
1714 .detach();
1715
1716 let handle = AcpTerminalHandle {
1717 terminal,
1718 _drop_tx: Some(drop_tx),
1719 };
1720
1721 Ok(Rc::new(handle) as _)
1722 })
1723 }
1724
1725 fn create_subagent(
1726 &self,
1727 parent_thread_entity: Entity<Thread>,
1728 label: String,
1729 initial_prompt: String,
1730 timeout: Option<Duration>,
1731 allowed_tools: Option<Vec<String>>,
1732 cx: &mut App,
1733 ) -> Result<Rc<dyn SubagentHandle>> {
1734 Self::create_subagent_thread(
1735 self.agent.clone(),
1736 parent_thread_entity,
1737 label,
1738 initial_prompt,
1739 timeout,
1740 allowed_tools,
1741 cx,
1742 )
1743 }
1744}
1745
1746#[derive(Debug, Clone, Copy)]
1747enum SubagentInitialPromptResult {
1748 Completed,
1749 Timeout,
1750}
1751
1752pub struct NativeSubagentHandle {
1753 session_id: acp::SessionId,
1754 parent_thread: WeakEntity<Thread>,
1755 subagent_thread: Entity<Thread>,
1756 acp_thread: Entity<AcpThread>,
1757 wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
1758 user_cancelled: Shared<Task<()>>,
1759}
1760
1761impl SubagentHandle for NativeSubagentHandle {
1762 fn id(&self) -> acp::SessionId {
1763 self.session_id.clone()
1764 }
1765
1766 fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
1767 let thread = self.subagent_thread.clone();
1768 let acp_thread = self.acp_thread.clone();
1769 let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1770
1771 let wait_for_summary_task = cx.spawn(async move |cx| {
1772 let timed_out = match wait_for_prompt.await {
1773 SubagentInitialPromptResult::Completed => false,
1774 SubagentInitialPromptResult::Timeout => true,
1775 };
1776
1777 let summary_prompt = if timed_out {
1778 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1779 format!("{}\n{}", "The time to complete the task was exceeded. Stop with the task and follow the directions below:", summary_prompt)
1780 } else {
1781 summary_prompt
1782 };
1783
1784 acp_thread
1785 .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
1786 .await?;
1787
1788 thread.read_with(cx, |thread, _cx| {
1789 thread
1790 .last_message()
1791 .map(|m| m.to_markdown())
1792 .context("No response from subagent")
1793 })
1794 });
1795
1796 let user_cancelled = self.user_cancelled.clone();
1797 let thread = self.subagent_thread.clone();
1798 let subagent_session_id = self.session_id.clone();
1799 let parent_thread = self.parent_thread.clone();
1800 cx.spawn(async move |cx| {
1801 let result = futures::select! {
1802 result = wait_for_summary_task.fuse() => result,
1803 _ = user_cancelled.fuse() => {
1804 thread.update(cx, |thread, cx| thread.cancel(cx).detach());
1805 Err(anyhow!("User cancelled"))
1806 },
1807 };
1808 parent_thread
1809 .update(cx, |parent_thread, cx| {
1810 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1811 })
1812 .ok();
1813 result
1814 })
1815 }
1816}
1817
1818pub struct AcpTerminalHandle {
1819 terminal: Entity<acp_thread::Terminal>,
1820 _drop_tx: Option<oneshot::Sender<()>>,
1821}
1822
1823impl TerminalHandle for AcpTerminalHandle {
1824 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1825 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1826 }
1827
1828 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1829 Ok(self
1830 .terminal
1831 .read_with(cx, |term, _cx| term.wait_for_exit()))
1832 }
1833
1834 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1835 Ok(self
1836 .terminal
1837 .read_with(cx, |term, cx| term.current_output(cx)))
1838 }
1839
1840 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1841 cx.update(|cx| {
1842 self.terminal.update(cx, |terminal, cx| {
1843 terminal.kill(cx);
1844 });
1845 });
1846 Ok(())
1847 }
1848
1849 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1850 Ok(self
1851 .terminal
1852 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1853 }
1854}
1855
1856#[cfg(test)]
1857mod internal_tests {
1858 use super::*;
1859 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1860 use fs::FakeFs;
1861 use gpui::TestAppContext;
1862 use indoc::formatdoc;
1863 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1864 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1865 use serde_json::json;
1866 use settings::SettingsStore;
1867 use util::{path, rel_path::rel_path};
1868
1869 #[gpui::test]
1870 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1871 init_test(cx);
1872 let fs = FakeFs::new(cx.executor());
1873 fs.insert_tree(
1874 "/",
1875 json!({
1876 "a": {}
1877 }),
1878 )
1879 .await;
1880 let project = Project::test(fs.clone(), [], cx).await;
1881 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1882 let agent = NativeAgent::new(
1883 project.clone(),
1884 thread_store,
1885 Templates::new(),
1886 None,
1887 fs.clone(),
1888 &mut cx.to_async(),
1889 )
1890 .await
1891 .unwrap();
1892 agent.read_with(cx, |agent, cx| {
1893 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1894 });
1895
1896 let worktree = project
1897 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1898 .await
1899 .unwrap();
1900 cx.run_until_parked();
1901 agent.read_with(cx, |agent, cx| {
1902 assert_eq!(
1903 agent.project_context.read(cx).worktrees,
1904 vec![WorktreeContext {
1905 root_name: "a".into(),
1906 abs_path: Path::new("/a").into(),
1907 rules_file: None
1908 }]
1909 )
1910 });
1911
1912 // Creating `/a/.rules` updates the project context.
1913 fs.insert_file("/a/.rules", Vec::new()).await;
1914 cx.run_until_parked();
1915 agent.read_with(cx, |agent, cx| {
1916 let rules_entry = worktree
1917 .read(cx)
1918 .entry_for_path(rel_path(".rules"))
1919 .unwrap();
1920 assert_eq!(
1921 agent.project_context.read(cx).worktrees,
1922 vec![WorktreeContext {
1923 root_name: "a".into(),
1924 abs_path: Path::new("/a").into(),
1925 rules_file: Some(RulesFileContext {
1926 path_in_worktree: rel_path(".rules").into(),
1927 text: "".into(),
1928 project_entry_id: rules_entry.id.to_usize()
1929 })
1930 }]
1931 )
1932 });
1933 }
1934
1935 #[gpui::test]
1936 async fn test_listing_models(cx: &mut TestAppContext) {
1937 init_test(cx);
1938 let fs = FakeFs::new(cx.executor());
1939 fs.insert_tree("/", json!({ "a": {} })).await;
1940 let project = Project::test(fs.clone(), [], cx).await;
1941 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1942 let connection = NativeAgentConnection(
1943 NativeAgent::new(
1944 project.clone(),
1945 thread_store,
1946 Templates::new(),
1947 None,
1948 fs.clone(),
1949 &mut cx.to_async(),
1950 )
1951 .await
1952 .unwrap(),
1953 );
1954
1955 // Create a thread/session
1956 let acp_thread = cx
1957 .update(|cx| {
1958 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1959 })
1960 .await
1961 .unwrap();
1962
1963 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1964
1965 let models = cx
1966 .update(|cx| {
1967 connection
1968 .model_selector(&session_id)
1969 .unwrap()
1970 .list_models(cx)
1971 })
1972 .await
1973 .unwrap();
1974
1975 let acp_thread::AgentModelList::Grouped(models) = models else {
1976 panic!("Unexpected model group");
1977 };
1978 assert_eq!(
1979 models,
1980 IndexMap::from_iter([(
1981 AgentModelGroupName("Fake".into()),
1982 vec![AgentModelInfo {
1983 id: acp::ModelId::new("fake/fake"),
1984 name: "Fake".into(),
1985 description: None,
1986 icon: Some(acp_thread::AgentModelIcon::Named(
1987 ui::IconName::ZedAssistant
1988 )),
1989 is_latest: false,
1990 }]
1991 )])
1992 );
1993 }
1994
1995 #[gpui::test]
1996 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1997 init_test(cx);
1998 let fs = FakeFs::new(cx.executor());
1999 fs.create_dir(paths::settings_file().parent().unwrap())
2000 .await
2001 .unwrap();
2002 fs.insert_file(
2003 paths::settings_file(),
2004 json!({
2005 "agent": {
2006 "default_model": {
2007 "provider": "foo",
2008 "model": "bar"
2009 }
2010 }
2011 })
2012 .to_string()
2013 .into_bytes(),
2014 )
2015 .await;
2016 let project = Project::test(fs.clone(), [], cx).await;
2017
2018 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2019
2020 // Create the agent and connection
2021 let agent = NativeAgent::new(
2022 project.clone(),
2023 thread_store,
2024 Templates::new(),
2025 None,
2026 fs.clone(),
2027 &mut cx.to_async(),
2028 )
2029 .await
2030 .unwrap();
2031 let connection = NativeAgentConnection(agent.clone());
2032
2033 // Create a thread/session
2034 let acp_thread = cx
2035 .update(|cx| {
2036 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2037 })
2038 .await
2039 .unwrap();
2040
2041 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2042
2043 // Select a model
2044 let selector = connection.model_selector(&session_id).unwrap();
2045 let model_id = acp::ModelId::new("fake/fake");
2046 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2047 .await
2048 .unwrap();
2049
2050 // Verify the thread has the selected model
2051 agent.read_with(cx, |agent, _| {
2052 let session = agent.sessions.get(&session_id).unwrap();
2053 session.thread.read_with(cx, |thread, _| {
2054 assert_eq!(thread.model().unwrap().id().0, "fake");
2055 });
2056 });
2057
2058 cx.run_until_parked();
2059
2060 // Verify settings file was updated
2061 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2062 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2063
2064 // Check that the agent settings contain the selected model
2065 assert_eq!(
2066 settings_json["agent"]["default_model"]["model"],
2067 json!("fake")
2068 );
2069 assert_eq!(
2070 settings_json["agent"]["default_model"]["provider"],
2071 json!("fake")
2072 );
2073
2074 // Register a thinking model and select it.
2075 cx.update(|cx| {
2076 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2077 "fake-corp",
2078 "fake-thinking",
2079 "Fake Thinking",
2080 true,
2081 ));
2082 let thinking_provider = Arc::new(
2083 FakeLanguageModelProvider::new(
2084 LanguageModelProviderId::from("fake-corp".to_string()),
2085 LanguageModelProviderName::from("Fake Corp".to_string()),
2086 )
2087 .with_models(vec![thinking_model]),
2088 );
2089 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2090 registry.register_provider(thinking_provider, cx);
2091 });
2092 });
2093 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2094
2095 let selector = connection.model_selector(&session_id).unwrap();
2096 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2097 .await
2098 .unwrap();
2099 cx.run_until_parked();
2100
2101 // Verify enable_thinking was written to settings as true.
2102 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2103 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2104 assert_eq!(
2105 settings_json["agent"]["default_model"]["enable_thinking"],
2106 json!(true),
2107 "selecting a thinking model should persist enable_thinking: true to settings"
2108 );
2109 }
2110
2111 #[gpui::test]
2112 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2113 init_test(cx);
2114 let fs = FakeFs::new(cx.executor());
2115 fs.create_dir(paths::settings_file().parent().unwrap())
2116 .await
2117 .unwrap();
2118 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2119 let project = Project::test(fs.clone(), [], cx).await;
2120
2121 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2122 let agent = NativeAgent::new(
2123 project.clone(),
2124 thread_store,
2125 Templates::new(),
2126 None,
2127 fs.clone(),
2128 &mut cx.to_async(),
2129 )
2130 .await
2131 .unwrap();
2132 let connection = NativeAgentConnection(agent.clone());
2133
2134 let acp_thread = cx
2135 .update(|cx| {
2136 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2137 })
2138 .await
2139 .unwrap();
2140 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2141
2142 // Register a second provider with a thinking model.
2143 cx.update(|cx| {
2144 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2145 "fake-corp",
2146 "fake-thinking",
2147 "Fake Thinking",
2148 true,
2149 ));
2150 let thinking_provider = Arc::new(
2151 FakeLanguageModelProvider::new(
2152 LanguageModelProviderId::from("fake-corp".to_string()),
2153 LanguageModelProviderName::from("Fake Corp".to_string()),
2154 )
2155 .with_models(vec![thinking_model]),
2156 );
2157 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2158 registry.register_provider(thinking_provider, cx);
2159 });
2160 });
2161 // Refresh the agent's model list so it picks up the new provider.
2162 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2163
2164 // Thread starts with thinking_enabled = false (the default).
2165 agent.read_with(cx, |agent, _| {
2166 let session = agent.sessions.get(&session_id).unwrap();
2167 session.thread.read_with(cx, |thread, _| {
2168 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2169 });
2170 });
2171
2172 // Select the thinking model via select_model.
2173 let selector = connection.model_selector(&session_id).unwrap();
2174 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2175 .await
2176 .unwrap();
2177
2178 // select_model should have enabled thinking based on the model's supports_thinking().
2179 agent.read_with(cx, |agent, _| {
2180 let session = agent.sessions.get(&session_id).unwrap();
2181 session.thread.read_with(cx, |thread, _| {
2182 assert!(
2183 thread.thinking_enabled(),
2184 "select_model should enable thinking when model supports it"
2185 );
2186 });
2187 });
2188
2189 // Switch back to the non-thinking model.
2190 let selector = connection.model_selector(&session_id).unwrap();
2191 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2192 .await
2193 .unwrap();
2194
2195 // select_model should have disabled thinking.
2196 agent.read_with(cx, |agent, _| {
2197 let session = agent.sessions.get(&session_id).unwrap();
2198 session.thread.read_with(cx, |thread, _| {
2199 assert!(
2200 !thread.thinking_enabled(),
2201 "select_model should disable thinking when model does not support it"
2202 );
2203 });
2204 });
2205 }
2206
    /// Checks that `thinking_enabled` survives a persist / close / reopen cycle.
    ///
    /// Registers a "fake-corp" provider with a fake model built via
    /// `FakeLanguageModel::with_id_and_thinking(.., true)` and selects it on a new
    /// session, which turns thinking on. It then sends one message so the thread
    /// is persisted, closes the session, and reopens it with
    /// `NativeAgent::open_thread`. Finally it asserts that the reloaded `Thread`
    /// still reports `thinking_enabled() == true`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // NOTE(review): the trailing `true` presumably sets the fake model's
        // supports-thinking flag; confirm in `FakeLanguageModel::with_id_and_thinking`.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        // The provider gets a clone of the model so this test keeps its own
        // handle and can drive the completion stream later.
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Model ids passed to `select_model` are written as "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned instead of awaited right away, so it can make
        // progress while the test feeds the fake model's response below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Answer the pending completion with a single text chunk, then end the stream.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old entities before reopening.
        drop(thread);
        drop(acp_thread);
        // After closing, the agent should no longer track any session.
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // Explicitly release the reloaded AcpThread handle at the end of the test.
        drop(reloaded_acp_thread);
    }
2312
    /// Checks that the selected model survives a persist / close / reopen cycle.
    ///
    /// Uses a fake model whose id ("custom-model-id") differs from its display
    /// name ("Custom Model Display Name"). After the thread is persisted and
    /// reopened via `NativeAgent::open_thread`, the test asserts that
    /// `thread.model()` is present and still has id "custom-model-id", rather
    /// than falling back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // NOTE(review): the trailing `false` presumably means "no thinking support";
        // confirm in `FakeLanguageModel::with_id_and_thinking`.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        // The provider gets a clone so the test can still drive the model's stream.
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selection uses "<provider id>/<model id>", i.e. the model id, not its display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the live thread is using the selected model before persistence.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it can progress while the test feeds the
        // fake model's response below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Answer the pending completion with a single text chunk, then end the stream.
        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old entities before reopening.
        drop(thread);
        drop(acp_thread);
        // After closing, the agent should no longer track any session.
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Explicitly release the reloaded AcpThread handle at the end of the test.
        drop(reloaded_acp_thread);
    }
2423
    /// End-to-end save/load test for a native agent thread.
    ///
    /// The test checks four things:
    /// 1. A thread with no messages is not written to the `ThreadStore`, even
    ///    after its model and summarization model are changed.
    /// 2. After one user/assistant exchange (the user message contains a file
    ///    mention), the thread renders to the expected markdown.
    /// 3. The store then holds one entry whose title is the summarization
    ///    model's output.
    /// 4. After the session is closed, `open_thread` reloads it with identical
    ///    markdown.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are used for the main completion and for summarization,
        // so each stream can be driven independently below.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message is text plus a resource link to /a/b.md (a file mention).
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so it progresses while the test drives both fake model streams.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant reply first, then the summary (used as the stored title).
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // The same URI is expected to appear in the rendered markdown link.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old entities before reopening.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The stored entry's title is the summarization model's output.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reloaded thread must render exactly as it did before closing.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2562
2563 fn thread_entries(
2564 thread_store: &Entity<ThreadStore>,
2565 cx: &mut TestAppContext,
2566 ) -> Vec<(acp::SessionId, String)> {
2567 thread_store.read_with(cx, |store, _| {
2568 store
2569 .entries()
2570 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2571 .collect::<Vec<_>>()
2572 })
2573 }
2574
2575 fn init_test(cx: &mut TestAppContext) {
2576 env_logger::try_init().ok();
2577 cx.update(|cx| {
2578 let settings_store = SettingsStore::test(cx);
2579 cx.set_global(settings_store);
2580
2581 LanguageModelRegistry::test(cx);
2582 });
2583 }
2584}
2585
2586fn mcp_message_content_to_acp_content_block(
2587 content: context_server::types::MessageContent,
2588) -> acp::ContentBlock {
2589 match content {
2590 context_server::types::MessageContent::Text {
2591 text,
2592 annotations: _,
2593 } => text.into(),
2594 context_server::types::MessageContent::Image {
2595 data,
2596 mime_type,
2597 annotations: _,
2598 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2599 context_server::types::MessageContent::Audio {
2600 data,
2601 mime_type,
2602 annotations: _,
2603 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2604 context_server::types::MessageContent::Resource {
2605 resource,
2606 annotations: _,
2607 } => {
2608 let mut link =
2609 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2610 if let Some(mime_type) = resource.mime_type {
2611 link = link.mime_type(mime_type);
2612 }
2613 acp::ContentBlock::ResourceLink(link)
2614 }
2615 }
2616}