1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use std::time::Duration;
54use util::ResultExt;
55use util::rel_path::RelPath;
56
/// Point-in-time snapshot of the project's worktrees.
///
/// The per-worktree entries are `project::telemetry_snapshot` values; the
/// struct derives `Serialize`/`Deserialize` so it can be persisted or sent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One entry per worktree captured in the snapshot.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp (UTC) associated with the snapshot — presumably capture time; confirm at the producer.
    pub timestamp: DateTime<Utc>,
}
62
/// Error describing why a worktree rules file could not be loaded.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt` when the
/// rules-file task fails.
pub struct RulesLoadingError {
    /// Human-readable error text (the `Display` form of the underlying error).
    pub message: SharedString,
}
66
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Task for the most recent save of `thread`; `save_thread` overwrites it on
    /// every save (presumably dropping/cancelling an older in-flight save — confirm gpui Task semantics).
    pending_save: Task<()>,
    /// Subscriptions tied to this session's lifetime (title, token usage, save-on-change).
    _subscriptions: Vec<Subscription>,
}
76
/// Cache of the language models offered by authenticated providers.
///
/// Rebuilt by `refresh_list`; every rebuild signals `refresh_models_tx` so
/// watchers obtained through `watch` can react.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`, see `model_id`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information.
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) to watchers of model-list refreshes.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled after each `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers; kept alive by ownership.
    _authenticate_all_providers_task: Task<()>,
}
86
87impl LanguageModels {
88 fn new(cx: &mut App) -> Self {
89 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
90
91 let mut this = Self {
92 models: HashMap::default(),
93 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
94 refresh_models_rx,
95 refresh_models_tx,
96 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
97 };
98 this.refresh_list(cx);
99 this
100 }
101
102 fn refresh_list(&mut self, cx: &App) {
103 let providers = LanguageModelRegistry::global(cx)
104 .read(cx)
105 .visible_providers()
106 .into_iter()
107 .filter(|provider| provider.is_authenticated(cx))
108 .collect::<Vec<_>>();
109
110 let mut language_model_list = IndexMap::default();
111 let mut recommended_models = HashSet::default();
112
113 let mut recommended = Vec::new();
114 for provider in &providers {
115 for model in provider.recommended_models(cx) {
116 recommended_models.insert((model.provider_id(), model.id()));
117 recommended.push(Self::map_language_model_to_info(&model, provider));
118 }
119 }
120 if !recommended.is_empty() {
121 language_model_list.insert(
122 acp_thread::AgentModelGroupName("Recommended".into()),
123 recommended,
124 );
125 }
126
127 let mut models = HashMap::default();
128 for provider in providers {
129 let mut provider_models = Vec::new();
130 for model in provider.provided_models(cx) {
131 let model_info = Self::map_language_model_to_info(&model, &provider);
132 let model_id = model_info.id.clone();
133 provider_models.push(model_info);
134 models.insert(model_id, model);
135 }
136 if !provider_models.is_empty() {
137 language_model_list.insert(
138 acp_thread::AgentModelGroupName(provider.name().0.clone()),
139 provider_models,
140 );
141 }
142 }
143
144 self.models = models;
145 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
146 self.refresh_models_tx.send(()).ok();
147 }
148
149 fn watch(&self) -> watch::Receiver<()> {
150 self.refresh_models_rx.clone()
151 }
152
153 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
154 self.models.get(model_id).cloned()
155 }
156
157 fn map_language_model_to_info(
158 model: &Arc<dyn LanguageModel>,
159 provider: &Arc<dyn LanguageModelProvider>,
160 ) -> acp_thread::AgentModelInfo {
161 acp_thread::AgentModelInfo {
162 id: Self::model_id(model),
163 name: model.name().0,
164 description: None,
165 icon: Some(match provider.icon() {
166 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
167 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
168 }),
169 is_latest: model.is_latest(),
170 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
171 }
172 }
173
174 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
175 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
176 }
177
178 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
179 let authenticate_all_providers = LanguageModelRegistry::global(cx)
180 .read(cx)
181 .visible_providers()
182 .iter()
183 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184 .collect::<Vec<_>>();
185
186 cx.background_spawn(async move {
187 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
188 if let Err(err) = authenticate_task.await {
189 match err {
190 language_model::AuthenticateError::CredentialsNotFound => {
191 // Since we're authenticating these providers in the
192 // background for the purposes of populating the
193 // language selector, we don't care about providers
194 // where the credentials are not found.
195 }
196 language_model::AuthenticateError::ConnectionRefused => {
197 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
198 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
199 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
200 }
201 _ => {
202 // Some providers have noisy failure states that we
203 // don't want to spam the logs with every time the
204 // language model selector is initialized.
205 //
206 // Ideally these should have more clear failure modes
207 // that we know are safe to ignore here, like what we do
208 // with `CredentialsNotFound` above.
209 match provider_id.0.as_ref() {
210 "lmstudio" | "ollama" => {
211 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
212 //
213 // These fail noisily, so we don't log them.
214 }
215 "copilot_chat" => {
216 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
217 }
218 _ => {
219 log::error!(
220 "Failed to authenticate provider: {}: {err:#}",
221 provider_name.0
222 );
223 }
224 }
225 }
226 }
227 }
228 }
229 })
230 }
231}
232
/// The built-in (native) agent: owns every open session for a project along
/// with the state those sessions share.
pub struct NativeAgent {
    /// Session ID -> Session mapping.
    sessions: HashMap<acp::SessionId, Session>,
    /// Store of persisted threads; reloaded after each save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads.
    project_context: Entity<ProjectContext>,
    /// Signalled whenever the project context should be rebuilt (worktree,
    /// rules-file or prompt-store changes).
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds `project_context` on refresh signals.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context (MCP) servers, used for tools and slash-command prompts.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads.
    templates: Arc<Templates>,
    /// Cached model information.
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store providing default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, registry and store events held for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
251
252impl NativeAgent {
    /// Creates the native agent entity for `project`.
    ///
    /// The initial project context (worktrees, rules files, default user rules)
    /// is built and awaited before the entity exists. The entity then subscribes
    /// to project, language-model-registry, context-server store/registry and
    /// (when present) prompt-store events, and spawns a loop that rebuilds the
    /// project context whenever `project_context_needs_refresh` is signalled.
    ///
    /// The visible body always returns `Ok`.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            // Prompt-store changes can alter the default user rules, so refresh on them too.
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
311
312 fn new_session(
313 &mut self,
314 project: Entity<Project>,
315 cx: &mut Context<Self>,
316 ) -> Entity<AcpThread> {
317 // Create Thread
318 // Fetch default model from registry settings
319 let registry = LanguageModelRegistry::read_global(cx);
320 // Log available models for debugging
321 let available_count = registry.available_models(cx).count();
322 log::debug!("Total available models: {}", available_count);
323
324 let default_model = registry.default_model().and_then(|default_model| {
325 self.models
326 .model_from_id(&LanguageModels::model_id(&default_model.model))
327 });
328 let thread = cx.new(|cx| {
329 Thread::new(
330 project.clone(),
331 self.project_context.clone(),
332 self.context_server_registry.clone(),
333 self.templates.clone(),
334 default_model,
335 cx,
336 )
337 });
338
339 self.register_session(thread, None, cx)
340 }
341
    /// Wraps `thread_handle` in a new `AcpThread` and records the pair as a session.
    ///
    /// Steps, in order:
    /// 1. Build an `AcpThread` from the thread's id, parent id, title, project,
    ///    action log and prompt capabilities, using a `NativeAgentConnection` to
    ///    this agent as its connection.
    /// 2. Apply the registry's thread-summary model and install the default tools
    ///    (restricted to `allowed_tool_names` when `Some`), giving them a
    ///    `NativeThreadEnvironment` that can reach the ACP thread and this agent.
    /// 3. Subscribe to title/token-usage events and save the thread whenever it
    ///    notifies.
    /// 4. Insert the session (replacing any existing session with the same id)
    ///    and push the current slash commands to all sessions.
    ///
    /// Returns the new `AcpThread`.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        allowed_tool_names: Option<Vec<&str>>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        let weak = cx.weak_entity();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                allowed_tool_names,
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread every time it notifies (i.e. on any change).
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        // Make the new session aware of the currently available slash commands.
        self.update_available_commands(cx);

        acp_thread
    }
408
409 pub fn models(&self) -> &LanguageModels {
410 &self.models
411 }
412
413 async fn maintain_project_context(
414 this: WeakEntity<Self>,
415 mut needs_refresh: watch::Receiver<()>,
416 cx: &mut AsyncApp,
417 ) -> Result<()> {
418 while needs_refresh.changed().await.is_ok() {
419 let project_context = this
420 .update(cx, |this, cx| {
421 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
422 })?
423 .await;
424 this.update(cx, |this, cx| {
425 this.project_context = cx.new(|_| project_context);
426 })?;
427 }
428
429 Ok(())
430 }
431
    /// Builds a `ProjectContext` from the project's visible worktrees and the
    /// prompt store's default user rules.
    ///
    /// For each visible worktree, this records its root name, absolute path and
    /// (if one is found) its rules file. Default prompts are loaded from
    /// `prompt_store` when it is `Some`. A prompt is skipped if it fails to load
    /// or if its id is not a user prompt (`as_user()` returns `None`). Rules
    /// loading errors are currently discarded (see the TODOs below).
    fn build_project_context(
        project: &Entity<Project>,
        prompt_store: Option<&Entity<PromptStore>>,
        cx: &mut App,
    ) -> Task<ProjectContext> {
        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
            })
            .collect::<Vec<_>>();
        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
            prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            })
        } else {
            Task::ready(vec![])
        };

        cx.spawn(async move |_cx| {
            // Worktree info and default user rules are loaded concurrently.
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, _rules_error)| {
                    // TODO: show error message
                    // if let Some(rules_error) = rules_error {
                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    // }
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    // `?` drops prompts whose id is not a user prompt.
                    Ok(contents) => Some(UserRulesContext {
                        uuid: prompt_metadata.id.as_user()?,
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(_err) => {
                        // TODO: show error message
                        // this.update(cx, |_, cx| {
                        //     cx.emit(RulesLoadingError {
                        //         message: format!("{err:?}").into(),
                        //     });
                        // })
                        // .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            ProjectContext::new(worktrees, default_user_rules)
        })
    }
496
497 fn load_worktree_info_for_system_prompt(
498 worktree: Entity<Worktree>,
499 project: Entity<Project>,
500 cx: &mut App,
501 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
502 let tree = worktree.read(cx);
503 let root_name = tree.root_name_str().into();
504 let abs_path = tree.abs_path();
505
506 let mut context = WorktreeContext {
507 root_name,
508 abs_path,
509 rules_file: None,
510 };
511
512 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
513 let Some(rules_task) = rules_task else {
514 return Task::ready((context, None));
515 };
516
517 cx.spawn(async move |_| {
518 let (rules_file, rules_file_error) = match rules_task.await {
519 Ok(rules_file) => (Some(rules_file), None),
520 Err(err) => (
521 None,
522 Some(RulesLoadingError {
523 message: format!("{err}").into(),
524 }),
525 ),
526 };
527 context.rules_file = rules_file;
528 (context, rules_file_error)
529 })
530 }
531
532 fn load_worktree_rules_file(
533 worktree: Entity<Worktree>,
534 project: Entity<Project>,
535 cx: &mut App,
536 ) -> Option<Task<Result<RulesFileContext>>> {
537 let worktree = worktree.read(cx);
538 let worktree_id = worktree.id();
539 let selected_rules_file = RULES_FILE_NAMES
540 .into_iter()
541 .filter_map(|name| {
542 worktree
543 .entry_for_path(RelPath::unix(name).unwrap())
544 .filter(|entry| entry.is_file())
545 .map(|entry| entry.path.clone())
546 })
547 .next();
548
549 // Note that Cline supports `.clinerules` being a directory, but that is not currently
550 // supported. This doesn't seem to occur often in GitHub repositories.
551 selected_rules_file.map(|path_in_worktree| {
552 let project_path = ProjectPath {
553 worktree_id,
554 path: path_in_worktree.clone(),
555 };
556 let buffer_task =
557 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
558 let rope_task = cx.spawn(async move |cx| {
559 let buffer = buffer_task.await?;
560 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
561 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
562 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
563 })?;
564 anyhow::Ok((project_entry_id, rope))
565 });
566 // Build a string from the rope on a background thread.
567 cx.background_spawn(async move {
568 let (project_entry_id, rope) = rope_task.await?;
569 anyhow::Ok(RulesFileContext {
570 path_in_worktree,
571 text: rope.to_string().trim().to_string(),
572 project_entry_id: project_entry_id.to_usize(),
573 })
574 })
575 })
576 }
577
578 fn handle_thread_title_updated(
579 &mut self,
580 thread: Entity<Thread>,
581 _: &TitleUpdated,
582 cx: &mut Context<Self>,
583 ) {
584 let session_id = thread.read(cx).id();
585 let Some(session) = self.sessions.get(session_id) else {
586 return;
587 };
588 let thread = thread.downgrade();
589 let acp_thread = session.acp_thread.downgrade();
590 cx.spawn(async move |_, cx| {
591 let title = thread.read_with(cx, |thread, _| thread.title())?;
592 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
593 task.await
594 })
595 .detach_and_log_err(cx);
596 }
597
598 fn handle_thread_token_usage_updated(
599 &mut self,
600 thread: Entity<Thread>,
601 usage: &TokenUsageUpdated,
602 cx: &mut Context<Self>,
603 ) {
604 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
605 return;
606 };
607 session.acp_thread.update(cx, |acp_thread, cx| {
608 acp_thread.update_token_usage(usage.0.clone(), cx);
609 });
610 }
611
612 fn handle_project_event(
613 &mut self,
614 _project: Entity<Project>,
615 event: &project::Event,
616 _cx: &mut Context<Self>,
617 ) {
618 match event {
619 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
620 self.project_context_needs_refresh.send(()).ok();
621 }
622 project::Event::WorktreeUpdatedEntries(_, items) => {
623 if items.iter().any(|(path, _, _)| {
624 RULES_FILE_NAMES
625 .iter()
626 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
627 }) {
628 self.project_context_needs_refresh.send(()).ok();
629 }
630 }
631 _ => {}
632 }
633 }
634
635 fn handle_prompts_updated_event(
636 &mut self,
637 _prompt_store: Entity<PromptStore>,
638 _event: &prompt_store::PromptsUpdatedEvent,
639 _cx: &mut Context<Self>,
640 ) {
641 self.project_context_needs_refresh.send(()).ok();
642 }
643
644 fn handle_models_updated_event(
645 &mut self,
646 _registry: Entity<LanguageModelRegistry>,
647 _event: &language_model::Event,
648 cx: &mut Context<Self>,
649 ) {
650 self.models.refresh_list(cx);
651
652 let registry = LanguageModelRegistry::read_global(cx);
653 let default_model = registry.default_model().map(|m| m.model);
654 let summarization_model = registry.thread_summary_model().map(|m| m.model);
655
656 for session in self.sessions.values_mut() {
657 session.thread.update(cx, |thread, cx| {
658 if thread.model().is_none()
659 && let Some(model) = default_model.clone()
660 {
661 thread.set_model(model, cx);
662 cx.notify();
663 }
664 thread.set_summarization_model(summarization_model.clone(), cx);
665 });
666 }
667 }
668
669 fn handle_context_server_store_updated(
670 &mut self,
671 _store: Entity<project::context_server_store::ContextServerStore>,
672 _event: &project::context_server_store::ServerStatusChangedEvent,
673 cx: &mut Context<Self>,
674 ) {
675 self.update_available_commands(cx);
676 }
677
678 fn handle_context_server_registry_event(
679 &mut self,
680 _registry: Entity<ContextServerRegistry>,
681 event: &ContextServerRegistryEvent,
682 cx: &mut Context<Self>,
683 ) {
684 match event {
685 ContextServerRegistryEvent::ToolsChanged => {}
686 ContextServerRegistryEvent::PromptsChanged => {
687 self.update_available_commands(cx);
688 }
689 }
690 }
691
692 fn update_available_commands(&self, cx: &mut Context<Self>) {
693 let available_commands = self.build_available_commands(cx);
694 for session in self.sessions.values() {
695 session.acp_thread.update(cx, |thread, cx| {
696 thread
697 .handle_session_update(
698 acp::SessionUpdate::AvailableCommandsUpdate(
699 acp::AvailableCommandsUpdate::new(available_commands.clone()),
700 ),
701 cx,
702 )
703 .log_err();
704 });
705 }
706 }
707
708 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
709 let registry = self.context_server_registry.read(cx);
710
711 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
712 for context_server_prompt in registry.prompts() {
713 *prompt_name_counts
714 .entry(context_server_prompt.prompt.name.as_str())
715 .or_insert(0) += 1;
716 }
717
718 registry
719 .prompts()
720 .flat_map(|context_server_prompt| {
721 let prompt = &context_server_prompt.prompt;
722
723 let should_prefix = prompt_name_counts
724 .get(prompt.name.as_str())
725 .copied()
726 .unwrap_or(0)
727 > 1;
728
729 let name = if should_prefix {
730 format!("{}.{}", context_server_prompt.server_id, prompt.name)
731 } else {
732 prompt.name.clone()
733 };
734
735 let mut command = acp::AvailableCommand::new(
736 name,
737 prompt.description.clone().unwrap_or_default(),
738 );
739
740 match prompt.arguments.as_deref() {
741 Some([arg]) => {
742 let hint = format!("<{}>", arg.name);
743
744 command = command.input(acp::AvailableCommandInput::Unstructured(
745 acp::UnstructuredCommandInput::new(hint),
746 ));
747 }
748 Some([]) | None => {}
749 Some(_) => {
750 // skip >1 argument commands since we don't support them yet
751 return None;
752 }
753 }
754
755 Some(command)
756 })
757 .collect()
758 }
759
760 pub fn load_thread(
761 &mut self,
762 id: acp::SessionId,
763 cx: &mut Context<Self>,
764 ) -> Task<Result<Entity<Thread>>> {
765 let database_future = ThreadsDatabase::connect(cx);
766 cx.spawn(async move |this, cx| {
767 let database = database_future.await.map_err(|err| anyhow!(err))?;
768 let db_thread = database
769 .load_thread(id.clone())
770 .await?
771 .with_context(|| format!("no thread found with ID: {id:?}"))?;
772
773 this.update(cx, |this, cx| {
774 let summarization_model = LanguageModelRegistry::read_global(cx)
775 .thread_summary_model()
776 .map(|c| c.model);
777
778 cx.new(|cx| {
779 let mut thread = Thread::from_db(
780 id.clone(),
781 db_thread,
782 this.project.clone(),
783 this.project_context.clone(),
784 this.context_server_registry.clone(),
785 this.templates.clone(),
786 cx,
787 );
788 thread.set_summarization_model(summarization_model, cx);
789 thread
790 })
791 })
792 })
793 }
794
795 pub fn open_thread(
796 &mut self,
797 id: acp::SessionId,
798 cx: &mut Context<Self>,
799 ) -> Task<Result<Entity<AcpThread>>> {
800 if let Some(session) = self.sessions.get(&id) {
801 return Task::ready(Ok(session.acp_thread.clone()));
802 }
803
804 let task = self.load_thread(id, cx);
805 cx.spawn(async move |this, cx| {
806 let thread = task.await?;
807 let acp_thread = this.update(cx, |this, cx| {
808 this.register_session(thread.clone(), None, cx)
809 })?;
810 let events = thread.update(cx, |thread, cx| thread.replay(cx));
811 cx.update(|cx| {
812 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
813 })
814 .await?;
815 Ok(acp_thread)
816 })
817 }
818
819 pub fn thread_summary(
820 &mut self,
821 id: acp::SessionId,
822 cx: &mut Context<Self>,
823 ) -> Task<Result<SharedString>> {
824 let thread = self.open_thread(id.clone(), cx);
825 cx.spawn(async move |this, cx| {
826 let acp_thread = thread.await?;
827 let result = this
828 .update(cx, |this, cx| {
829 this.sessions
830 .get(&id)
831 .unwrap()
832 .thread
833 .update(cx, |thread, cx| thread.summary(cx))
834 })?
835 .await
836 .context("Failed to generate summary")?;
837 drop(acp_thread);
838 Ok(result)
839 })
840 }
841
842 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
843 if thread.read(cx).is_empty() {
844 return;
845 }
846
847 let database_future = ThreadsDatabase::connect(cx);
848 let (id, db_thread) =
849 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
850 let Some(session) = self.sessions.get_mut(&id) else {
851 return;
852 };
853 let thread_store = self.thread_store.clone();
854 session.pending_save = cx.spawn(async move |_, cx| {
855 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
856 return;
857 };
858 let db_thread = db_thread.await;
859 database.save_thread(id, db_thread).await.log_err();
860 thread_store.update(cx, |store, cx| store.reload(cx));
861 });
862 }
863
    /// Runs an MCP (context-server) prompt as a turn in session `session_id`.
    ///
    /// Steps:
    /// 1. Fetch `prompt_name` from `server_id` with `arguments`.
    /// 2. Push `original_content` minus its first block (presumably the
    ///    `/command` text itself) as a user message with `message_id`.
    /// 3. Mirror each prompt message into both the `AcpThread` and the `Thread`.
    ///    User messages get a fresh `UserMessageId`.
    /// 4. Continue the turn. If the last prompt message was from the user, or
    ///    there were no messages, `send_existing` is used; otherwise `resume`.
    /// 5. Forward the resulting events to the `AcpThread`.
    ///
    /// # Errors
    /// Fails if fetching the prompt fails, the session is missing, the agent has
    /// been released, or the turn itself fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Starts `true` so a prompt with no messages still sends the user's message.
            let mut last_is_user = true;

            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        // The `true` argument is presumably the indent flag (per the method name).
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        // Presumably `false` = not a thought and `true` = indent — confirm in AcpThread.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
959}
960
/// Wrapper struct that implements the AgentConnection trait for a [`NativeAgent`].
///
/// Each session's `AcpThread` is given one of these as its connection (see
/// `NativeAgent::register_session`).
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
964
965impl NativeAgentConnection {
966 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
967 self.0
968 .read(cx)
969 .sessions
970 .get(session_id)
971 .map(|session| session.thread.clone())
972 }
973
974 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
975 self.0.update(cx, |this, cx| this.load_thread(id, cx))
976 }
977
978 fn run_turn(
979 &self,
980 session_id: acp::SessionId,
981 cx: &mut App,
982 f: impl 'static
983 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
984 ) -> Task<Result<acp::PromptResponse>> {
985 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
986 agent
987 .sessions
988 .get_mut(&session_id)
989 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
990 }) else {
991 return Task::ready(Err(anyhow!("Session not found")));
992 };
993 log::debug!("Found session for: {}", session_id);
994
995 let response_stream = match f(thread, cx) {
996 Ok(stream) => stream,
997 Err(err) => return Task::ready(Err(err)),
998 };
999 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1000 }
1001
1002 fn handle_thread_events(
1003 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1004 acp_thread: WeakEntity<AcpThread>,
1005 cx: &App,
1006 ) -> Task<Result<acp::PromptResponse>> {
1007 cx.spawn(async move |cx| {
1008 // Handle response stream and forward to session.acp_thread
1009 while let Some(result) = events.next().await {
1010 match result {
1011 Ok(event) => {
1012 log::trace!("Received completion event: {:?}", event);
1013
1014 match event {
1015 ThreadEvent::UserMessage(message) => {
1016 acp_thread.update(cx, |thread, cx| {
1017 for content in message.content {
1018 thread.push_user_content_block(
1019 Some(message.id.clone()),
1020 content.into(),
1021 cx,
1022 );
1023 }
1024 })?;
1025 }
1026 ThreadEvent::AgentText(text) => {
1027 acp_thread.update(cx, |thread, cx| {
1028 thread.push_assistant_content_block(text.into(), false, cx)
1029 })?;
1030 }
1031 ThreadEvent::AgentThinking(text) => {
1032 acp_thread.update(cx, |thread, cx| {
1033 thread.push_assistant_content_block(text.into(), true, cx)
1034 })?;
1035 }
1036 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1037 tool_call,
1038 options,
1039 response,
1040 context: _,
1041 }) => {
1042 let outcome_task = acp_thread.update(cx, |thread, cx| {
1043 thread.request_tool_call_authorization(tool_call, options, cx)
1044 })??;
1045 cx.background_spawn(async move {
1046 if let acp::RequestPermissionOutcome::Selected(
1047 acp::SelectedPermissionOutcome { option_id, .. },
1048 ) = outcome_task.await
1049 {
1050 response
1051 .send(option_id)
1052 .map(|_| anyhow!("authorization receiver was dropped"))
1053 .log_err();
1054 }
1055 })
1056 .detach();
1057 }
1058 ThreadEvent::ToolCall(tool_call) => {
1059 acp_thread.update(cx, |thread, cx| {
1060 thread.upsert_tool_call(tool_call, cx)
1061 })??;
1062 }
1063 ThreadEvent::ToolCallUpdate(update) => {
1064 acp_thread.update(cx, |thread, cx| {
1065 thread.update_tool_call(update, cx)
1066 })??;
1067 }
1068 ThreadEvent::SubagentSpawned(session_id) => {
1069 acp_thread.update(cx, |thread, cx| {
1070 thread.subagent_spawned(session_id, cx);
1071 })?;
1072 }
1073 ThreadEvent::Retry(status) => {
1074 acp_thread.update(cx, |thread, cx| {
1075 thread.update_retry_status(status, cx)
1076 })?;
1077 }
1078 ThreadEvent::Stop(stop_reason) => {
1079 log::debug!("Assistant message complete: {:?}", stop_reason);
1080 return Ok(acp::PromptResponse::new(stop_reason));
1081 }
1082 }
1083 }
1084 Err(e) => {
1085 log::error!("Error in model response stream: {:?}", e);
1086 return Err(e);
1087 }
1088 }
1089 }
1090
1091 log::debug!("Response stream completed");
1092 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1093 })
1094 }
1095}
1096
1097struct Command<'a> {
1098 prompt_name: &'a str,
1099 arg_value: &'a str,
1100 explicit_server_id: Option<&'a str>,
1101}
1102
1103impl<'a> Command<'a> {
1104 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1105 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1106 return None;
1107 };
1108 let text = text_content.text.trim();
1109 let command = text.strip_prefix('/')?;
1110 let (command, arg_value) = command
1111 .split_once(char::is_whitespace)
1112 .unwrap_or((command, ""));
1113
1114 if let Some((server_id, prompt_name)) = command.split_once('.') {
1115 Some(Self {
1116 prompt_name,
1117 arg_value,
1118 explicit_server_id: Some(server_id),
1119 })
1120 } else {
1121 Some(Self {
1122 prompt_name: command,
1123 arg_value,
1124 explicit_server_id: None,
1125 })
1126 }
1127 }
1128}
1129
1130struct NativeAgentModelSelector {
1131 session_id: acp::SessionId,
1132 connection: NativeAgentConnection,
1133}
1134
1135impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1136 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1137 log::debug!("NativeAgentConnection::list_models called");
1138 let list = self.connection.0.read(cx).models.model_list.clone();
1139 Task::ready(if list.is_empty() {
1140 Err(anyhow::anyhow!("No models available"))
1141 } else {
1142 Ok(list)
1143 })
1144 }
1145
1146 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1147 log::debug!(
1148 "Setting model for session {}: {}",
1149 self.session_id,
1150 model_id
1151 );
1152 let Some(thread) = self
1153 .connection
1154 .0
1155 .read(cx)
1156 .sessions
1157 .get(&self.session_id)
1158 .map(|session| session.thread.clone())
1159 else {
1160 return Task::ready(Err(anyhow!("Session not found")));
1161 };
1162
1163 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1164 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1165 };
1166
1167 // We want to reset the effort level when switching models, as the currently-selected effort level may
1168 // not be compatible.
1169 let effort = model
1170 .default_effort_level()
1171 .map(|effort_level| effort_level.value.to_string());
1172
1173 thread.update(cx, |thread, cx| {
1174 thread.set_model(model.clone(), cx);
1175 thread.set_thinking_effort(effort.clone(), cx);
1176 thread.set_thinking_enabled(model.supports_thinking(), cx);
1177 });
1178
1179 update_settings_file(
1180 self.connection.0.read(cx).fs.clone(),
1181 cx,
1182 move |settings, cx| {
1183 let provider = model.provider_id().0.to_string();
1184 let model = model.id().0.to_string();
1185 let enable_thinking = thread.read(cx).thinking_enabled();
1186 settings
1187 .agent
1188 .get_or_insert_default()
1189 .set_model(LanguageModelSelection {
1190 provider: provider.into(),
1191 model,
1192 enable_thinking,
1193 effort,
1194 });
1195 },
1196 );
1197
1198 Task::ready(Ok(()))
1199 }
1200
1201 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1202 let Some(thread) = self
1203 .connection
1204 .0
1205 .read(cx)
1206 .sessions
1207 .get(&self.session_id)
1208 .map(|session| session.thread.clone())
1209 else {
1210 return Task::ready(Err(anyhow!("Session not found")));
1211 };
1212 let Some(model) = thread.read(cx).model() else {
1213 return Task::ready(Err(anyhow!("Model not found")));
1214 };
1215 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1216 else {
1217 return Task::ready(Err(anyhow!("Provider not found")));
1218 };
1219 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1220 model, &provider,
1221 )))
1222 }
1223
1224 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1225 Some(self.connection.0.read(cx).models.watch())
1226 }
1227
1228 fn should_render_footer(&self) -> bool {
1229 true
1230 }
1231}
1232
1233impl acp_thread::AgentConnection for NativeAgentConnection {
1234 fn telemetry_id(&self) -> SharedString {
1235 "zed".into()
1236 }
1237
1238 fn new_session(
1239 self: Rc<Self>,
1240 project: Entity<Project>,
1241 cwd: &Path,
1242 cx: &mut App,
1243 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1244 log::debug!("Creating new thread for project at: {cwd:?}");
1245 Task::ready(Ok(self
1246 .0
1247 .update(cx, |agent, cx| agent.new_session(project, cx))))
1248 }
1249
1250 fn supports_load_session(&self) -> bool {
1251 true
1252 }
1253
1254 fn load_session(
1255 self: Rc<Self>,
1256 session: AgentSessionInfo,
1257 _project: Entity<Project>,
1258 _cwd: &Path,
1259 cx: &mut App,
1260 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1261 self.0
1262 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1263 }
1264
1265 fn supports_close_session(&self) -> bool {
1266 true
1267 }
1268
1269 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1270 self.0.update(cx, |agent, _cx| {
1271 agent.sessions.remove(session_id);
1272 });
1273 Task::ready(Ok(()))
1274 }
1275
1276 fn auth_methods(&self) -> &[acp::AuthMethod] {
1277 &[] // No auth for in-process
1278 }
1279
1280 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1281 Task::ready(Ok(()))
1282 }
1283
1284 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1285 Some(Rc::new(NativeAgentModelSelector {
1286 session_id: session_id.clone(),
1287 connection: self.clone(),
1288 }) as Rc<dyn AgentModelSelector>)
1289 }
1290
1291 fn prompt(
1292 &self,
1293 id: Option<acp_thread::UserMessageId>,
1294 params: acp::PromptRequest,
1295 cx: &mut App,
1296 ) -> Task<Result<acp::PromptResponse>> {
1297 let id = id.expect("UserMessageId is required");
1298 let session_id = params.session_id.clone();
1299 log::info!("Received prompt request for session: {}", session_id);
1300 log::debug!("Prompt blocks count: {}", params.prompt.len());
1301
1302 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1303 let registry = self.0.read(cx).context_server_registry.read(cx);
1304
1305 let explicit_server_id = parsed_command
1306 .explicit_server_id
1307 .map(|server_id| ContextServerId(server_id.into()));
1308
1309 if let Some(prompt) =
1310 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1311 {
1312 let arguments = if !parsed_command.arg_value.is_empty()
1313 && let Some(arg_name) = prompt
1314 .prompt
1315 .arguments
1316 .as_ref()
1317 .and_then(|args| args.first())
1318 .map(|arg| arg.name.clone())
1319 {
1320 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1321 } else {
1322 Default::default()
1323 };
1324
1325 let prompt_name = prompt.prompt.name.clone();
1326 let server_id = prompt.server_id.clone();
1327
1328 return self.0.update(cx, |agent, cx| {
1329 agent.send_mcp_prompt(
1330 id,
1331 session_id.clone(),
1332 prompt_name,
1333 server_id,
1334 arguments,
1335 params.prompt,
1336 cx,
1337 )
1338 });
1339 };
1340 };
1341
1342 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1343
1344 self.run_turn(session_id, cx, move |thread, cx| {
1345 let content: Vec<UserMessageContent> = params
1346 .prompt
1347 .into_iter()
1348 .map(|block| UserMessageContent::from_content_block(block, path_style))
1349 .collect::<Vec<_>>();
1350 log::debug!("Converted prompt to message: {} chars", content.len());
1351 log::debug!("Message id: {:?}", id);
1352 log::debug!("Message content: {:?}", content);
1353
1354 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1355 })
1356 }
1357
1358 fn retry(
1359 &self,
1360 session_id: &acp::SessionId,
1361 _cx: &App,
1362 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1363 Some(Rc::new(NativeAgentSessionRetry {
1364 connection: self.clone(),
1365 session_id: session_id.clone(),
1366 }) as _)
1367 }
1368
1369 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1370 log::info!("Cancelling on session: {}", session_id);
1371 self.0.update(cx, |agent, cx| {
1372 if let Some(session) = agent.sessions.get(session_id) {
1373 session
1374 .thread
1375 .update(cx, |thread, cx| thread.cancel(cx))
1376 .detach();
1377 }
1378 });
1379 }
1380
1381 fn truncate(
1382 &self,
1383 session_id: &agent_client_protocol::SessionId,
1384 cx: &App,
1385 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1386 self.0.read_with(cx, |agent, _cx| {
1387 agent.sessions.get(session_id).map(|session| {
1388 Rc::new(NativeAgentSessionTruncate {
1389 thread: session.thread.clone(),
1390 acp_thread: session.acp_thread.downgrade(),
1391 }) as _
1392 })
1393 })
1394 }
1395
1396 fn set_title(
1397 &self,
1398 session_id: &acp::SessionId,
1399 cx: &App,
1400 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1401 self.0.read_with(cx, |agent, _cx| {
1402 agent
1403 .sessions
1404 .get(session_id)
1405 .filter(|s| !s.thread.read(cx).is_subagent())
1406 .map(|session| {
1407 Rc::new(NativeAgentSessionSetTitle {
1408 thread: session.thread.clone(),
1409 }) as _
1410 })
1411 })
1412 }
1413
1414 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1415 let thread_store = self.0.read(cx).thread_store.clone();
1416 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1417 }
1418
1419 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1420 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1421 }
1422
1423 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1424 self
1425 }
1426}
1427
1428impl acp_thread::AgentTelemetry for NativeAgentConnection {
1429 fn thread_data(
1430 &self,
1431 session_id: &acp::SessionId,
1432 cx: &mut App,
1433 ) -> Task<Result<serde_json::Value>> {
1434 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1435 return Task::ready(Err(anyhow!("Session not found")));
1436 };
1437
1438 let task = session.thread.read(cx).to_db(cx);
1439 cx.background_spawn(async move {
1440 serde_json::to_value(task.await).context("Failed to serialize thread")
1441 })
1442 }
1443}
1444
1445pub struct NativeAgentSessionList {
1446 thread_store: Entity<ThreadStore>,
1447 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1448 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1449 _subscription: Subscription,
1450}
1451
1452impl NativeAgentSessionList {
1453 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1454 let (tx, rx) = smol::channel::unbounded();
1455 let this_tx = tx.clone();
1456 let subscription = cx.observe(&thread_store, move |_, _| {
1457 this_tx
1458 .try_send(acp_thread::SessionListUpdate::Refresh)
1459 .ok();
1460 });
1461 Self {
1462 thread_store,
1463 updates_tx: tx,
1464 updates_rx: rx,
1465 _subscription: subscription,
1466 }
1467 }
1468
1469 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1470 AgentSessionInfo {
1471 session_id: entry.id,
1472 cwd: None,
1473 title: Some(entry.title),
1474 updated_at: Some(entry.updated_at),
1475 meta: None,
1476 }
1477 }
1478
1479 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1480 &self.thread_store
1481 }
1482}
1483
1484impl AgentSessionList for NativeAgentSessionList {
1485 fn list_sessions(
1486 &self,
1487 _request: AgentSessionListRequest,
1488 cx: &mut App,
1489 ) -> Task<Result<AgentSessionListResponse>> {
1490 let sessions = self
1491 .thread_store
1492 .read(cx)
1493 .entries()
1494 .map(Self::to_session_info)
1495 .collect();
1496 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1497 }
1498
1499 fn supports_delete(&self) -> bool {
1500 true
1501 }
1502
1503 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1504 self.thread_store
1505 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1506 }
1507
1508 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1509 self.thread_store
1510 .update(cx, |store, cx| store.delete_threads(cx))
1511 }
1512
1513 fn watch(
1514 &self,
1515 _cx: &mut App,
1516 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1517 Some(self.updates_rx.clone())
1518 }
1519
1520 fn notify_refresh(&self) {
1521 self.updates_tx
1522 .try_send(acp_thread::SessionListUpdate::Refresh)
1523 .ok();
1524 }
1525
1526 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1527 self
1528 }
1529}
1530
1531struct NativeAgentSessionTruncate {
1532 thread: Entity<Thread>,
1533 acp_thread: WeakEntity<AcpThread>,
1534}
1535
1536impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1537 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1538 match self.thread.update(cx, |thread, cx| {
1539 thread.truncate(message_id.clone(), cx)?;
1540 Ok(thread.latest_token_usage())
1541 }) {
1542 Ok(usage) => {
1543 self.acp_thread
1544 .update(cx, |thread, cx| {
1545 thread.update_token_usage(usage, cx);
1546 })
1547 .ok();
1548 Task::ready(Ok(()))
1549 }
1550 Err(error) => Task::ready(Err(error)),
1551 }
1552 }
1553}
1554
1555struct NativeAgentSessionRetry {
1556 connection: NativeAgentConnection,
1557 session_id: acp::SessionId,
1558}
1559
1560impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1561 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1562 self.connection
1563 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1564 thread.update(cx, |thread, cx| thread.resume(cx))
1565 })
1566 }
1567}
1568
1569struct NativeAgentSessionSetTitle {
1570 thread: Entity<Thread>,
1571}
1572
1573impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1574 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1575 self.thread
1576 .update(cx, |thread, cx| thread.set_title(title, cx));
1577 Task::ready(Ok(()))
1578 }
1579}
1580
1581pub struct NativeThreadEnvironment {
1582 agent: WeakEntity<NativeAgent>,
1583 acp_thread: WeakEntity<AcpThread>,
1584}
1585
1586impl NativeThreadEnvironment {
1587 pub(crate) fn create_subagent_thread(
1588 agent: WeakEntity<NativeAgent>,
1589 parent_thread_entity: Entity<Thread>,
1590 label: String,
1591 initial_prompt: String,
1592 timeout: Option<Duration>,
1593 allowed_tools: Option<Vec<String>>,
1594 cx: &mut App,
1595 ) -> Result<Rc<dyn SubagentHandle>> {
1596 let parent_thread = parent_thread_entity.read(cx);
1597 let current_depth = parent_thread.depth();
1598
1599 if current_depth >= MAX_SUBAGENT_DEPTH {
1600 return Err(anyhow!(
1601 "Maximum subagent depth ({}) reached",
1602 MAX_SUBAGENT_DEPTH
1603 ));
1604 }
1605
1606 let running_count = parent_thread.running_subagent_count();
1607 if running_count >= MAX_PARALLEL_SUBAGENTS {
1608 return Err(anyhow!(
1609 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
1610 MAX_PARALLEL_SUBAGENTS
1611 ));
1612 }
1613
1614 let allowed_tools = match allowed_tools {
1615 Some(tools) => {
1616 let parent_tool_names: std::collections::HashSet<&str> =
1617 parent_thread.tools.keys().map(|s| s.as_str()).collect();
1618 Some(
1619 tools
1620 .into_iter()
1621 .filter(|t| parent_tool_names.contains(t.as_str()))
1622 .collect::<Vec<_>>(),
1623 )
1624 }
1625 None => Some(parent_thread.tools.keys().map(|s| s.to_string()).collect()),
1626 };
1627
1628 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1629 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1630 thread.set_title(label.into(), cx);
1631 thread
1632 });
1633
1634 let session_id = subagent_thread.read(cx).id().clone();
1635
1636 let acp_thread = agent.update(cx, |agent, cx| {
1637 agent.register_session(
1638 subagent_thread.clone(),
1639 allowed_tools
1640 .as_ref()
1641 .map(|v| v.iter().map(|s| s.as_str()).collect()),
1642 cx,
1643 )
1644 })?;
1645
1646 parent_thread_entity.update(cx, |parent_thread, _cx| {
1647 parent_thread.register_running_subagent(subagent_thread.downgrade())
1648 });
1649
1650 let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));
1651
1652 let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1653 let wait_for_prompt_to_complete = cx
1654 .background_spawn(async move {
1655 if let Some(timer) = timeout_timer {
1656 futures::select! {
1657 _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1658 response = task.fuse() => {
1659 let response = response.log_err().flatten();
1660 if response.is_some_and(|response| {
1661 response.stop_reason == acp::StopReason::Cancelled
1662 })
1663 {
1664 SubagentInitialPromptResult::Cancelled
1665 } else {
1666 SubagentInitialPromptResult::Completed
1667 }
1668 },
1669 }
1670 } else {
1671 let response = task.await.log_err().flatten();
1672 if response
1673 .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
1674 {
1675 SubagentInitialPromptResult::Cancelled
1676 } else {
1677 SubagentInitialPromptResult::Completed
1678 }
1679 }
1680 })
1681 .shared();
1682
1683 Ok(Rc::new(NativeSubagentHandle {
1684 session_id,
1685 subagent_thread,
1686 parent_thread: parent_thread_entity.downgrade(),
1687 acp_thread,
1688 wait_for_prompt_to_complete,
1689 }) as _)
1690 }
1691}
1692
1693impl ThreadEnvironment for NativeThreadEnvironment {
1694 fn create_terminal(
1695 &self,
1696 command: String,
1697 cwd: Option<PathBuf>,
1698 output_byte_limit: Option<u64>,
1699 cx: &mut AsyncApp,
1700 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1701 let task = self.acp_thread.update(cx, |thread, cx| {
1702 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1703 });
1704
1705 let acp_thread = self.acp_thread.clone();
1706 cx.spawn(async move |cx| {
1707 let terminal = task?.await?;
1708
1709 let (drop_tx, drop_rx) = oneshot::channel();
1710 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1711
1712 cx.spawn(async move |cx| {
1713 drop_rx.await.ok();
1714 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1715 })
1716 .detach();
1717
1718 let handle = AcpTerminalHandle {
1719 terminal,
1720 _drop_tx: Some(drop_tx),
1721 };
1722
1723 Ok(Rc::new(handle) as _)
1724 })
1725 }
1726
1727 fn create_subagent(
1728 &self,
1729 parent_thread_entity: Entity<Thread>,
1730 label: String,
1731 initial_prompt: String,
1732 timeout: Option<Duration>,
1733 allowed_tools: Option<Vec<String>>,
1734 cx: &mut App,
1735 ) -> Result<Rc<dyn SubagentHandle>> {
1736 Self::create_subagent_thread(
1737 self.agent.clone(),
1738 parent_thread_entity,
1739 label,
1740 initial_prompt,
1741 timeout,
1742 allowed_tools,
1743 cx,
1744 )
1745 }
1746}
1747
/// How a subagent's initial prompt ended.
#[derive(Debug, Clone, Copy)]
enum SubagentInitialPromptResult {
    /// The prompt finished without a `Cancelled` stop reason (errors are logged and
    /// also land here).
    Completed,
    /// The configured timeout elapsed before the prompt finished.
    Timeout,
    /// The prompt stopped with `StopReason::Cancelled`.
    Cancelled,
}
1754
1755pub struct NativeSubagentHandle {
1756 session_id: acp::SessionId,
1757 parent_thread: WeakEntity<Thread>,
1758 subagent_thread: Entity<Thread>,
1759 acp_thread: Entity<AcpThread>,
1760 wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
1761}
1762
1763impl SubagentHandle for NativeSubagentHandle {
1764 fn id(&self) -> acp::SessionId {
1765 self.session_id.clone()
1766 }
1767
1768 fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
1769 let thread = self.subagent_thread.clone();
1770 let acp_thread = self.acp_thread.clone();
1771 let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1772
1773 let wait_for_summary_task = cx.spawn(async move |cx| {
1774 let timed_out = match wait_for_prompt.await {
1775 SubagentInitialPromptResult::Completed => false,
1776 SubagentInitialPromptResult::Timeout => true,
1777 SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")),
1778 };
1779
1780 let summary_prompt = if timed_out {
1781 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1782 format!("{}\n{}", "The time to complete the task was exceeded. Stop with the task and follow the directions below:", summary_prompt)
1783 } else {
1784 summary_prompt
1785 };
1786
1787 let response = acp_thread
1788 .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
1789 .await?;
1790
1791 let was_canceled = response.is_some_and(|r| r.stop_reason == acp::StopReason::Cancelled);
1792 if was_canceled {
1793 return Err(anyhow!("User cancelled"));
1794 }
1795
1796 thread.read_with(cx, |thread, _cx| {
1797 thread
1798 .last_message()
1799 .map(|m| m.to_markdown())
1800 .context("No response from subagent")
1801 })
1802 });
1803
1804 let subagent_session_id = self.session_id.clone();
1805 let parent_thread = self.parent_thread.clone();
1806 cx.spawn(async move |cx| {
1807 let result = wait_for_summary_task.await;
1808 parent_thread
1809 .update(cx, |parent_thread, cx| {
1810 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1811 })
1812 .ok();
1813 result
1814 })
1815 }
1816}
1817
1818pub struct AcpTerminalHandle {
1819 terminal: Entity<acp_thread::Terminal>,
1820 _drop_tx: Option<oneshot::Sender<()>>,
1821}
1822
1823impl TerminalHandle for AcpTerminalHandle {
1824 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1825 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1826 }
1827
1828 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1829 Ok(self
1830 .terminal
1831 .read_with(cx, |term, _cx| term.wait_for_exit()))
1832 }
1833
1834 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1835 Ok(self
1836 .terminal
1837 .read_with(cx, |term, cx| term.current_output(cx)))
1838 }
1839
1840 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1841 cx.update(|cx| {
1842 self.terminal.update(cx, |terminal, cx| {
1843 terminal.kill(cx);
1844 });
1845 });
1846 Ok(())
1847 }
1848
1849 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1850 Ok(self
1851 .terminal
1852 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1853 }
1854}
1855
1856#[cfg(test)]
1857mod internal_tests {
1858 use super::*;
1859 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1860 use fs::FakeFs;
1861 use gpui::TestAppContext;
1862 use indoc::formatdoc;
1863 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1864 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1865 use serde_json::json;
1866 use settings::SettingsStore;
1867 use util::{path, rel_path::rel_path};
1868
1869 #[gpui::test]
1870 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1871 init_test(cx);
1872 let fs = FakeFs::new(cx.executor());
1873 fs.insert_tree(
1874 "/",
1875 json!({
1876 "a": {}
1877 }),
1878 )
1879 .await;
1880 let project = Project::test(fs.clone(), [], cx).await;
1881 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1882 let agent = NativeAgent::new(
1883 project.clone(),
1884 thread_store,
1885 Templates::new(),
1886 None,
1887 fs.clone(),
1888 &mut cx.to_async(),
1889 )
1890 .await
1891 .unwrap();
1892 agent.read_with(cx, |agent, cx| {
1893 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1894 });
1895
1896 let worktree = project
1897 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1898 .await
1899 .unwrap();
1900 cx.run_until_parked();
1901 agent.read_with(cx, |agent, cx| {
1902 assert_eq!(
1903 agent.project_context.read(cx).worktrees,
1904 vec![WorktreeContext {
1905 root_name: "a".into(),
1906 abs_path: Path::new("/a").into(),
1907 rules_file: None
1908 }]
1909 )
1910 });
1911
1912 // Creating `/a/.rules` updates the project context.
1913 fs.insert_file("/a/.rules", Vec::new()).await;
1914 cx.run_until_parked();
1915 agent.read_with(cx, |agent, cx| {
1916 let rules_entry = worktree
1917 .read(cx)
1918 .entry_for_path(rel_path(".rules"))
1919 .unwrap();
1920 assert_eq!(
1921 agent.project_context.read(cx).worktrees,
1922 vec![WorktreeContext {
1923 root_name: "a".into(),
1924 abs_path: Path::new("/a").into(),
1925 rules_file: Some(RulesFileContext {
1926 path_in_worktree: rel_path(".rules").into(),
1927 text: "".into(),
1928 project_entry_id: rules_entry.id.to_usize()
1929 })
1930 }]
1931 )
1932 });
1933 }
1934
1935 #[gpui::test]
1936 async fn test_listing_models(cx: &mut TestAppContext) {
1937 init_test(cx);
1938 let fs = FakeFs::new(cx.executor());
1939 fs.insert_tree("/", json!({ "a": {} })).await;
1940 let project = Project::test(fs.clone(), [], cx).await;
1941 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1942 let connection = NativeAgentConnection(
1943 NativeAgent::new(
1944 project.clone(),
1945 thread_store,
1946 Templates::new(),
1947 None,
1948 fs.clone(),
1949 &mut cx.to_async(),
1950 )
1951 .await
1952 .unwrap(),
1953 );
1954
1955 // Create a thread/session
1956 let acp_thread = cx
1957 .update(|cx| {
1958 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1959 })
1960 .await
1961 .unwrap();
1962
1963 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1964
1965 let models = cx
1966 .update(|cx| {
1967 connection
1968 .model_selector(&session_id)
1969 .unwrap()
1970 .list_models(cx)
1971 })
1972 .await
1973 .unwrap();
1974
1975 let acp_thread::AgentModelList::Grouped(models) = models else {
1976 panic!("Unexpected model group");
1977 };
1978 assert_eq!(
1979 models,
1980 IndexMap::from_iter([(
1981 AgentModelGroupName("Fake".into()),
1982 vec![AgentModelInfo {
1983 id: acp::ModelId::new("fake/fake"),
1984 name: "Fake".into(),
1985 description: None,
1986 icon: Some(acp_thread::AgentModelIcon::Named(
1987 ui::IconName::ZedAssistant
1988 )),
1989 is_latest: false,
1990 cost: None,
1991 }]
1992 )])
1993 );
1994 }
1995
1996 #[gpui::test]
1997 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1998 init_test(cx);
1999 let fs = FakeFs::new(cx.executor());
2000 fs.create_dir(paths::settings_file().parent().unwrap())
2001 .await
2002 .unwrap();
2003 fs.insert_file(
2004 paths::settings_file(),
2005 json!({
2006 "agent": {
2007 "default_model": {
2008 "provider": "foo",
2009 "model": "bar"
2010 }
2011 }
2012 })
2013 .to_string()
2014 .into_bytes(),
2015 )
2016 .await;
2017 let project = Project::test(fs.clone(), [], cx).await;
2018
2019 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2020
2021 // Create the agent and connection
2022 let agent = NativeAgent::new(
2023 project.clone(),
2024 thread_store,
2025 Templates::new(),
2026 None,
2027 fs.clone(),
2028 &mut cx.to_async(),
2029 )
2030 .await
2031 .unwrap();
2032 let connection = NativeAgentConnection(agent.clone());
2033
2034 // Create a thread/session
2035 let acp_thread = cx
2036 .update(|cx| {
2037 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2038 })
2039 .await
2040 .unwrap();
2041
2042 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2043
2044 // Select a model
2045 let selector = connection.model_selector(&session_id).unwrap();
2046 let model_id = acp::ModelId::new("fake/fake");
2047 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2048 .await
2049 .unwrap();
2050
2051 // Verify the thread has the selected model
2052 agent.read_with(cx, |agent, _| {
2053 let session = agent.sessions.get(&session_id).unwrap();
2054 session.thread.read_with(cx, |thread, _| {
2055 assert_eq!(thread.model().unwrap().id().0, "fake");
2056 });
2057 });
2058
2059 cx.run_until_parked();
2060
2061 // Verify settings file was updated
2062 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2063 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2064
2065 // Check that the agent settings contain the selected model
2066 assert_eq!(
2067 settings_json["agent"]["default_model"]["model"],
2068 json!("fake")
2069 );
2070 assert_eq!(
2071 settings_json["agent"]["default_model"]["provider"],
2072 json!("fake")
2073 );
2074
2075 // Register a thinking model and select it.
2076 cx.update(|cx| {
2077 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2078 "fake-corp",
2079 "fake-thinking",
2080 "Fake Thinking",
2081 true,
2082 ));
2083 let thinking_provider = Arc::new(
2084 FakeLanguageModelProvider::new(
2085 LanguageModelProviderId::from("fake-corp".to_string()),
2086 LanguageModelProviderName::from("Fake Corp".to_string()),
2087 )
2088 .with_models(vec![thinking_model]),
2089 );
2090 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2091 registry.register_provider(thinking_provider, cx);
2092 });
2093 });
2094 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2095
2096 let selector = connection.model_selector(&session_id).unwrap();
2097 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2098 .await
2099 .unwrap();
2100 cx.run_until_parked();
2101
2102 // Verify enable_thinking was written to settings as true.
2103 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2104 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2105 assert_eq!(
2106 settings_json["agent"]["default_model"]["enable_thinking"],
2107 json!(true),
2108 "selecting a thinking model should persist enable_thinking: true to settings"
2109 );
2110 }
2111
2112 #[gpui::test]
2113 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2114 init_test(cx);
2115 let fs = FakeFs::new(cx.executor());
2116 fs.create_dir(paths::settings_file().parent().unwrap())
2117 .await
2118 .unwrap();
2119 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2120 let project = Project::test(fs.clone(), [], cx).await;
2121
2122 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2123 let agent = NativeAgent::new(
2124 project.clone(),
2125 thread_store,
2126 Templates::new(),
2127 None,
2128 fs.clone(),
2129 &mut cx.to_async(),
2130 )
2131 .await
2132 .unwrap();
2133 let connection = NativeAgentConnection(agent.clone());
2134
2135 let acp_thread = cx
2136 .update(|cx| {
2137 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2138 })
2139 .await
2140 .unwrap();
2141 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2142
2143 // Register a second provider with a thinking model.
2144 cx.update(|cx| {
2145 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2146 "fake-corp",
2147 "fake-thinking",
2148 "Fake Thinking",
2149 true,
2150 ));
2151 let thinking_provider = Arc::new(
2152 FakeLanguageModelProvider::new(
2153 LanguageModelProviderId::from("fake-corp".to_string()),
2154 LanguageModelProviderName::from("Fake Corp".to_string()),
2155 )
2156 .with_models(vec![thinking_model]),
2157 );
2158 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2159 registry.register_provider(thinking_provider, cx);
2160 });
2161 });
2162 // Refresh the agent's model list so it picks up the new provider.
2163 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2164
2165 // Thread starts with thinking_enabled = false (the default).
2166 agent.read_with(cx, |agent, _| {
2167 let session = agent.sessions.get(&session_id).unwrap();
2168 session.thread.read_with(cx, |thread, _| {
2169 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2170 });
2171 });
2172
2173 // Select the thinking model via select_model.
2174 let selector = connection.model_selector(&session_id).unwrap();
2175 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2176 .await
2177 .unwrap();
2178
2179 // select_model should have enabled thinking based on the model's supports_thinking().
2180 agent.read_with(cx, |agent, _| {
2181 let session = agent.sessions.get(&session_id).unwrap();
2182 session.thread.read_with(cx, |thread, _| {
2183 assert!(
2184 thread.thinking_enabled(),
2185 "select_model should enable thinking when model supports it"
2186 );
2187 });
2188 });
2189
2190 // Switch back to the non-thinking model.
2191 let selector = connection.model_selector(&session_id).unwrap();
2192 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2193 .await
2194 .unwrap();
2195
2196 // select_model should have disabled thinking.
2197 agent.read_with(cx, |agent, _| {
2198 let session = agent.sessions.get(&session_id).unwrap();
2199 session.thread.read_with(cx, |thread, _| {
2200 assert!(
2201 !thread.thinking_enabled(),
2202 "select_model should disable thinking when model does not support it"
2203 );
2204 });
2205 });
2206 }
2207
    /// Checks that a thread's `thinking_enabled` flag survives a save/reload round trip.
    ///
    /// After selecting a model that supports thinking, the test sends one message so the
    /// thread is persisted, closes the session, and reopens the thread with
    /// `NativeAgent::open_thread`. Thinking must still be enabled on the reloaded thread.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model (provider "fake-corp", id "fake-thinking"); the trailing
        // `true` makes it a model that supports thinking.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        // The provider gets a clone so the test keeps its own handle to the model and can
        // drive its completion stream below.
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Model ids passed to the selector take the form "<provider>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        // Spawn the send so it can make progress while the test feeds the fake model's
        // completion stream below.
        let send = cx.foreground_executor().spawn(send);
        // Run pending work so that a completion request exists for
        // `send_last_completion_stream_text_chunk` to target.
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities before checking
        // that the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The reloaded handle is bound rather than discarded so it stays alive through the
        // assertions above.
        // NOTE(review): presumably the session's lifetime is tied to this handle; confirm.
        drop(reloaded_acp_thread);
    }
2313
    /// Checks that the model selected for a thread is restored after a save/reload round trip,
    /// rather than falling back to the default model.
    ///
    /// The model is registered with an id ("custom-model-id") that differs from its display
    /// name. Presumably this is to catch a reload path that resolves the model by name instead
    /// of by id. The test persists the thread by sending one message, closes the session,
    /// reopens it with `NativeAgent::open_thread`, and compares model ids.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        // The provider gets a clone so the test keeps its own handle to the model and can
        // drive its completion stream below.
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selection uses "<provider>/<model id>", i.e. the id, not the display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        // Spawn the send so it can make progress while the test feeds the fake model's
        // completion stream below.
        let send = cx.foreground_executor().spawn(send);
        // Run pending work so that a completion request exists for
        // `send_last_completion_stream_text_chunk` to target.
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities before checking
        // that the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reloaded handle is bound rather than discarded so it stays alive through the
        // assertions above.
        // NOTE(review): presumably the session's lifetime is tied to this handle; confirm.
        drop(reloaded_acp_thread);
    }
2424
    /// Round-trips a thread through the thread store.
    ///
    /// Checks that:
    /// - a thread with no messages is not saved, even after its model settings are changed;
    /// - after one user/assistant exchange, the thread is saved under the title streamed by the
    ///   summarization model;
    /// - after closing the session, reopening the thread with `NativeAgent::open_thread`
    ///   reproduces the same conversation markdown, including the file mention.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are used for the conversation and for summarization so each
        // stream can be driven independently below.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message whose middle part is a resource link mentioning /a/b.md.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so it can make progress while the test feeds the fake models'
        // completion streams below.
        let send = cx.foreground_executor().spawn(send);
        // Run pending work so that a completion request exists on `model`.
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        // Run pending work so `summary_model` has a request to answer. Its streamed text
        // becomes the saved thread title (asserted after the session is closed).
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // The resource link is expected to render as a `[@b.md](<uri>)` mention in markdown.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities before checking
        // that the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // Exactly one entry is stored, titled with the text streamed by `summary_model`.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reloaded thread must render exactly the same markdown as before closing.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2563
2564 fn thread_entries(
2565 thread_store: &Entity<ThreadStore>,
2566 cx: &mut TestAppContext,
2567 ) -> Vec<(acp::SessionId, String)> {
2568 thread_store.read_with(cx, |store, _| {
2569 store
2570 .entries()
2571 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2572 .collect::<Vec<_>>()
2573 })
2574 }
2575
2576 fn init_test(cx: &mut TestAppContext) {
2577 env_logger::try_init().ok();
2578 cx.update(|cx| {
2579 let settings_store = SettingsStore::test(cx);
2580 cx.set_global(settings_store);
2581
2582 LanguageModelRegistry::test(cx);
2583 });
2584 }
2585}
2586
2587fn mcp_message_content_to_acp_content_block(
2588 content: context_server::types::MessageContent,
2589) -> acp::ContentBlock {
2590 match content {
2591 context_server::types::MessageContent::Text {
2592 text,
2593 annotations: _,
2594 } => text.into(),
2595 context_server::types::MessageContent::Image {
2596 data,
2597 mime_type,
2598 annotations: _,
2599 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2600 context_server::types::MessageContent::Audio {
2601 data,
2602 mime_type,
2603 annotations: _,
2604 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2605 context_server::types::MessageContent::Resource {
2606 resource,
2607 annotations: _,
2608 } => {
2609 let mut link =
2610 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2611 if let Some(mime_type) = resource.mime_type {
2612 link = link.mime_type(mime_type);
2613 }
2614 acp::ContentBlock::ResourceLink(link)
2615 }
2616 }
2617}