1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use std::time::Duration;
54use util::ResultExt;
55use util::rel_path::RelPath;
56
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
58pub struct ProjectSnapshot {
59 pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
60 pub timestamp: DateTime<Utc>,
61}
62
63pub struct RulesLoadingError {
64 pub message: SharedString,
65}
66
67/// Holds both the internal Thread and the AcpThread for a session
68struct Session {
69 /// The internal thread that processes messages
70 thread: Entity<Thread>,
71 /// The ACP thread that handles protocol communication
72 acp_thread: Entity<acp_thread::AcpThread>,
73 pending_save: Task<()>,
74 _subscriptions: Vec<Subscription>,
75}
76
77pub struct LanguageModels {
78 /// Access language model by ID
79 models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
80 /// Cached list for returning language model information
81 model_list: acp_thread::AgentModelList,
82 refresh_models_rx: watch::Receiver<()>,
83 refresh_models_tx: watch::Sender<()>,
84 _authenticate_all_providers_task: Task<()>,
85}
86
87impl LanguageModels {
88 fn new(cx: &mut App) -> Self {
89 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
90
91 let mut this = Self {
92 models: HashMap::default(),
93 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
94 refresh_models_rx,
95 refresh_models_tx,
96 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
97 };
98 this.refresh_list(cx);
99 this
100 }
101
102 fn refresh_list(&mut self, cx: &App) {
103 let providers = LanguageModelRegistry::global(cx)
104 .read(cx)
105 .visible_providers()
106 .into_iter()
107 .filter(|provider| provider.is_authenticated(cx))
108 .collect::<Vec<_>>();
109
110 let mut language_model_list = IndexMap::default();
111 let mut recommended_models = HashSet::default();
112
113 let mut recommended = Vec::new();
114 for provider in &providers {
115 for model in provider.recommended_models(cx) {
116 recommended_models.insert((model.provider_id(), model.id()));
117 recommended.push(Self::map_language_model_to_info(&model, provider));
118 }
119 }
120 if !recommended.is_empty() {
121 language_model_list.insert(
122 acp_thread::AgentModelGroupName("Recommended".into()),
123 recommended,
124 );
125 }
126
127 let mut models = HashMap::default();
128 for provider in providers {
129 let mut provider_models = Vec::new();
130 for model in provider.provided_models(cx) {
131 let model_info = Self::map_language_model_to_info(&model, &provider);
132 let model_id = model_info.id.clone();
133 provider_models.push(model_info);
134 models.insert(model_id, model);
135 }
136 if !provider_models.is_empty() {
137 language_model_list.insert(
138 acp_thread::AgentModelGroupName(provider.name().0.clone()),
139 provider_models,
140 );
141 }
142 }
143
144 self.models = models;
145 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
146 self.refresh_models_tx.send(()).ok();
147 }
148
149 fn watch(&self) -> watch::Receiver<()> {
150 self.refresh_models_rx.clone()
151 }
152
153 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
154 self.models.get(model_id).cloned()
155 }
156
157 fn map_language_model_to_info(
158 model: &Arc<dyn LanguageModel>,
159 provider: &Arc<dyn LanguageModelProvider>,
160 ) -> acp_thread::AgentModelInfo {
161 acp_thread::AgentModelInfo {
162 id: Self::model_id(model),
163 name: model.name().0,
164 description: None,
165 icon: Some(match provider.icon() {
166 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
167 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
168 }),
169 is_latest: model.is_latest(),
170 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
171 }
172 }
173
174 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
175 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
176 }
177
178 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
179 let authenticate_all_providers = LanguageModelRegistry::global(cx)
180 .read(cx)
181 .visible_providers()
182 .iter()
183 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184 .collect::<Vec<_>>();
185
186 cx.background_spawn(async move {
187 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
188 if let Err(err) = authenticate_task.await {
189 match err {
190 language_model::AuthenticateError::CredentialsNotFound => {
191 // Since we're authenticating these providers in the
192 // background for the purposes of populating the
193 // language selector, we don't care about providers
194 // where the credentials are not found.
195 }
196 language_model::AuthenticateError::ConnectionRefused => {
197 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
198 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
199 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
200 }
201 _ => {
202 // Some providers have noisy failure states that we
203 // don't want to spam the logs with every time the
204 // language model selector is initialized.
205 //
206 // Ideally these should have more clear failure modes
207 // that we know are safe to ignore here, like what we do
208 // with `CredentialsNotFound` above.
209 match provider_id.0.as_ref() {
210 "lmstudio" | "ollama" => {
211 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
212 //
213 // These fail noisily, so we don't log them.
214 }
215 "copilot_chat" => {
216 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
217 }
218 _ => {
219 log::error!(
220 "Failed to authenticate provider: {}: {err:#}",
221 provider_name.0
222 );
223 }
224 }
225 }
226 }
227 }
228 }
229 })
230 }
231}
232
233pub struct NativeAgent {
234 /// Session ID -> Session mapping
235 sessions: HashMap<acp::SessionId, Session>,
236 thread_store: Entity<ThreadStore>,
237 /// Shared project context for all threads
238 project_context: Entity<ProjectContext>,
239 project_context_needs_refresh: watch::Sender<()>,
240 _maintain_project_context: Task<Result<()>>,
241 context_server_registry: Entity<ContextServerRegistry>,
242 /// Shared templates for all threads
243 templates: Arc<Templates>,
244 /// Cached model information
245 models: LanguageModels,
246 project: Entity<Project>,
247 prompt_store: Option<Entity<PromptStore>>,
248 fs: Arc<dyn Fs>,
249 _subscriptions: Vec<Subscription>,
250}
251
252impl NativeAgent {
253 pub async fn new(
254 project: Entity<Project>,
255 thread_store: Entity<ThreadStore>,
256 templates: Arc<Templates>,
257 prompt_store: Option<Entity<PromptStore>>,
258 fs: Arc<dyn Fs>,
259 cx: &mut AsyncApp,
260 ) -> Result<Entity<NativeAgent>> {
261 log::debug!("Creating new NativeAgent");
262
263 let project_context = cx
264 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
265 .await;
266
267 Ok(cx.new(|cx| {
268 let context_server_store = project.read(cx).context_server_store();
269 let context_server_registry =
270 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
271
272 let mut subscriptions = vec![
273 cx.subscribe(&project, Self::handle_project_event),
274 cx.subscribe(
275 &LanguageModelRegistry::global(cx),
276 Self::handle_models_updated_event,
277 ),
278 cx.subscribe(
279 &context_server_store,
280 Self::handle_context_server_store_updated,
281 ),
282 cx.subscribe(
283 &context_server_registry,
284 Self::handle_context_server_registry_event,
285 ),
286 ];
287 if let Some(prompt_store) = prompt_store.as_ref() {
288 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
289 }
290
291 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
292 watch::channel(());
293 Self {
294 sessions: HashMap::default(),
295 thread_store,
296 project_context: cx.new(|_| project_context),
297 project_context_needs_refresh: project_context_needs_refresh_tx,
298 _maintain_project_context: cx.spawn(async move |this, cx| {
299 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
300 }),
301 context_server_registry,
302 templates,
303 models: LanguageModels::new(cx),
304 project,
305 prompt_store,
306 fs,
307 _subscriptions: subscriptions,
308 }
309 }))
310 }
311
312 fn new_session(
313 &mut self,
314 project: Entity<Project>,
315 cx: &mut Context<Self>,
316 ) -> Entity<AcpThread> {
317 // Create Thread
318 // Fetch default model from registry settings
319 let registry = LanguageModelRegistry::read_global(cx);
320 // Log available models for debugging
321 let available_count = registry.available_models(cx).count();
322 log::debug!("Total available models: {}", available_count);
323
324 let default_model = registry.default_model().and_then(|default_model| {
325 self.models
326 .model_from_id(&LanguageModels::model_id(&default_model.model))
327 });
328 let thread = cx.new(|cx| {
329 Thread::new(
330 project.clone(),
331 self.project_context.clone(),
332 self.context_server_registry.clone(),
333 self.templates.clone(),
334 default_model,
335 cx,
336 )
337 });
338
339 self.register_session(thread, None, cx)
340 }
341
342 fn register_session(
343 &mut self,
344 thread_handle: Entity<Thread>,
345 allowed_tool_names: Option<Vec<&str>>,
346 cx: &mut Context<Self>,
347 ) -> Entity<AcpThread> {
348 let connection = Rc::new(NativeAgentConnection(cx.entity()));
349
350 let thread = thread_handle.read(cx);
351 let session_id = thread.id().clone();
352 let parent_session_id = thread.parent_thread_id();
353 let title = thread.title();
354 let project = thread.project.clone();
355 let action_log = thread.action_log.clone();
356 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
357 let acp_thread = cx.new(|cx| {
358 acp_thread::AcpThread::new(
359 parent_session_id,
360 title,
361 connection,
362 project.clone(),
363 action_log.clone(),
364 session_id.clone(),
365 prompt_capabilities_rx,
366 cx,
367 )
368 });
369
370 let registry = LanguageModelRegistry::read_global(cx);
371 let summarization_model = registry.thread_summary_model().map(|c| c.model);
372
373 let weak = cx.weak_entity();
374 thread_handle.update(cx, |thread, cx| {
375 thread.set_summarization_model(summarization_model, cx);
376 thread.add_default_tools(
377 allowed_tool_names,
378 Rc::new(NativeThreadEnvironment {
379 acp_thread: acp_thread.downgrade(),
380 agent: weak,
381 }) as _,
382 cx,
383 )
384 });
385
386 let subscriptions = vec![
387 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
388 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
389 cx.observe(&thread_handle, move |this, thread, cx| {
390 this.save_thread(thread, cx)
391 }),
392 ];
393
394 self.sessions.insert(
395 session_id,
396 Session {
397 thread: thread_handle,
398 acp_thread: acp_thread.clone(),
399 _subscriptions: subscriptions,
400 pending_save: Task::ready(()),
401 },
402 );
403
404 self.update_available_commands(cx);
405
406 acp_thread
407 }
408
409 pub fn models(&self) -> &LanguageModels {
410 &self.models
411 }
412
413 async fn maintain_project_context(
414 this: WeakEntity<Self>,
415 mut needs_refresh: watch::Receiver<()>,
416 cx: &mut AsyncApp,
417 ) -> Result<()> {
418 while needs_refresh.changed().await.is_ok() {
419 let project_context = this
420 .update(cx, |this, cx| {
421 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
422 })?
423 .await;
424 this.update(cx, |this, cx| {
425 this.project_context = cx.new(|_| project_context);
426 })?;
427 }
428
429 Ok(())
430 }
431
432 fn build_project_context(
433 project: &Entity<Project>,
434 prompt_store: Option<&Entity<PromptStore>>,
435 cx: &mut App,
436 ) -> Task<ProjectContext> {
437 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
438 let worktree_tasks = worktrees
439 .into_iter()
440 .map(|worktree| {
441 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
442 })
443 .collect::<Vec<_>>();
444 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
445 prompt_store.read_with(cx, |prompt_store, cx| {
446 let prompts = prompt_store.default_prompt_metadata();
447 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
448 let contents = prompt_store.load(prompt_metadata.id, cx);
449 async move { (contents.await, prompt_metadata) }
450 });
451 cx.background_spawn(future::join_all(load_tasks))
452 })
453 } else {
454 Task::ready(vec![])
455 };
456
457 cx.spawn(async move |_cx| {
458 let (worktrees, default_user_rules) =
459 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
460
461 let worktrees = worktrees
462 .into_iter()
463 .map(|(worktree, _rules_error)| {
464 // TODO: show error message
465 // if let Some(rules_error) = rules_error {
466 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
467 // }
468 worktree
469 })
470 .collect::<Vec<_>>();
471
472 let default_user_rules = default_user_rules
473 .into_iter()
474 .flat_map(|(contents, prompt_metadata)| match contents {
475 Ok(contents) => Some(UserRulesContext {
476 uuid: prompt_metadata.id.as_user()?,
477 title: prompt_metadata.title.map(|title| title.to_string()),
478 contents,
479 }),
480 Err(_err) => {
481 // TODO: show error message
482 // this.update(cx, |_, cx| {
483 // cx.emit(RulesLoadingError {
484 // message: format!("{err:?}").into(),
485 // });
486 // })
487 // .ok();
488 None
489 }
490 })
491 .collect::<Vec<_>>();
492
493 ProjectContext::new(worktrees, default_user_rules)
494 })
495 }
496
497 fn load_worktree_info_for_system_prompt(
498 worktree: Entity<Worktree>,
499 project: Entity<Project>,
500 cx: &mut App,
501 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
502 let tree = worktree.read(cx);
503 let root_name = tree.root_name_str().into();
504 let abs_path = tree.abs_path();
505
506 let mut context = WorktreeContext {
507 root_name,
508 abs_path,
509 rules_file: None,
510 };
511
512 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
513 let Some(rules_task) = rules_task else {
514 return Task::ready((context, None));
515 };
516
517 cx.spawn(async move |_| {
518 let (rules_file, rules_file_error) = match rules_task.await {
519 Ok(rules_file) => (Some(rules_file), None),
520 Err(err) => (
521 None,
522 Some(RulesLoadingError {
523 message: format!("{err}").into(),
524 }),
525 ),
526 };
527 context.rules_file = rules_file;
528 (context, rules_file_error)
529 })
530 }
531
532 fn load_worktree_rules_file(
533 worktree: Entity<Worktree>,
534 project: Entity<Project>,
535 cx: &mut App,
536 ) -> Option<Task<Result<RulesFileContext>>> {
537 let worktree = worktree.read(cx);
538 let worktree_id = worktree.id();
539 let selected_rules_file = RULES_FILE_NAMES
540 .into_iter()
541 .filter_map(|name| {
542 worktree
543 .entry_for_path(RelPath::unix(name).unwrap())
544 .filter(|entry| entry.is_file())
545 .map(|entry| entry.path.clone())
546 })
547 .next();
548
549 // Note that Cline supports `.clinerules` being a directory, but that is not currently
550 // supported. This doesn't seem to occur often in GitHub repositories.
551 selected_rules_file.map(|path_in_worktree| {
552 let project_path = ProjectPath {
553 worktree_id,
554 path: path_in_worktree.clone(),
555 };
556 let buffer_task =
557 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
558 let rope_task = cx.spawn(async move |cx| {
559 let buffer = buffer_task.await?;
560 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
561 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
562 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
563 })?;
564 anyhow::Ok((project_entry_id, rope))
565 });
566 // Build a string from the rope on a background thread.
567 cx.background_spawn(async move {
568 let (project_entry_id, rope) = rope_task.await?;
569 anyhow::Ok(RulesFileContext {
570 path_in_worktree,
571 text: rope.to_string().trim().to_string(),
572 project_entry_id: project_entry_id.to_usize(),
573 })
574 })
575 })
576 }
577
578 fn handle_thread_title_updated(
579 &mut self,
580 thread: Entity<Thread>,
581 _: &TitleUpdated,
582 cx: &mut Context<Self>,
583 ) {
584 let session_id = thread.read(cx).id();
585 let Some(session) = self.sessions.get(session_id) else {
586 return;
587 };
588 let thread = thread.downgrade();
589 let acp_thread = session.acp_thread.downgrade();
590 cx.spawn(async move |_, cx| {
591 let title = thread.read_with(cx, |thread, _| thread.title())?;
592 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
593 task.await
594 })
595 .detach_and_log_err(cx);
596 }
597
598 fn handle_thread_token_usage_updated(
599 &mut self,
600 thread: Entity<Thread>,
601 usage: &TokenUsageUpdated,
602 cx: &mut Context<Self>,
603 ) {
604 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
605 return;
606 };
607 session.acp_thread.update(cx, |acp_thread, cx| {
608 acp_thread.update_token_usage(usage.0.clone(), cx);
609 });
610 }
611
612 fn handle_project_event(
613 &mut self,
614 _project: Entity<Project>,
615 event: &project::Event,
616 _cx: &mut Context<Self>,
617 ) {
618 match event {
619 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
620 self.project_context_needs_refresh.send(()).ok();
621 }
622 project::Event::WorktreeUpdatedEntries(_, items) => {
623 if items.iter().any(|(path, _, _)| {
624 RULES_FILE_NAMES
625 .iter()
626 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
627 }) {
628 self.project_context_needs_refresh.send(()).ok();
629 }
630 }
631 _ => {}
632 }
633 }
634
635 fn handle_prompts_updated_event(
636 &mut self,
637 _prompt_store: Entity<PromptStore>,
638 _event: &prompt_store::PromptsUpdatedEvent,
639 _cx: &mut Context<Self>,
640 ) {
641 self.project_context_needs_refresh.send(()).ok();
642 }
643
644 fn handle_models_updated_event(
645 &mut self,
646 _registry: Entity<LanguageModelRegistry>,
647 _event: &language_model::Event,
648 cx: &mut Context<Self>,
649 ) {
650 self.models.refresh_list(cx);
651
652 let registry = LanguageModelRegistry::read_global(cx);
653 let default_model = registry.default_model().map(|m| m.model);
654 let summarization_model = registry.thread_summary_model().map(|m| m.model);
655
656 for session in self.sessions.values_mut() {
657 session.thread.update(cx, |thread, cx| {
658 if thread.model().is_none()
659 && let Some(model) = default_model.clone()
660 {
661 thread.set_model(model, cx);
662 cx.notify();
663 }
664 thread.set_summarization_model(summarization_model.clone(), cx);
665 });
666 }
667 }
668
669 fn handle_context_server_store_updated(
670 &mut self,
671 _store: Entity<project::context_server_store::ContextServerStore>,
672 _event: &project::context_server_store::ServerStatusChangedEvent,
673 cx: &mut Context<Self>,
674 ) {
675 self.update_available_commands(cx);
676 }
677
678 fn handle_context_server_registry_event(
679 &mut self,
680 _registry: Entity<ContextServerRegistry>,
681 event: &ContextServerRegistryEvent,
682 cx: &mut Context<Self>,
683 ) {
684 match event {
685 ContextServerRegistryEvent::ToolsChanged => {}
686 ContextServerRegistryEvent::PromptsChanged => {
687 self.update_available_commands(cx);
688 }
689 }
690 }
691
692 fn update_available_commands(&self, cx: &mut Context<Self>) {
693 let available_commands = self.build_available_commands(cx);
694 for session in self.sessions.values() {
695 session.acp_thread.update(cx, |thread, cx| {
696 thread
697 .handle_session_update(
698 acp::SessionUpdate::AvailableCommandsUpdate(
699 acp::AvailableCommandsUpdate::new(available_commands.clone()),
700 ),
701 cx,
702 )
703 .log_err();
704 });
705 }
706 }
707
708 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
709 let registry = self.context_server_registry.read(cx);
710
711 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
712 for context_server_prompt in registry.prompts() {
713 *prompt_name_counts
714 .entry(context_server_prompt.prompt.name.as_str())
715 .or_insert(0) += 1;
716 }
717
718 registry
719 .prompts()
720 .flat_map(|context_server_prompt| {
721 let prompt = &context_server_prompt.prompt;
722
723 let should_prefix = prompt_name_counts
724 .get(prompt.name.as_str())
725 .copied()
726 .unwrap_or(0)
727 > 1;
728
729 let name = if should_prefix {
730 format!("{}.{}", context_server_prompt.server_id, prompt.name)
731 } else {
732 prompt.name.clone()
733 };
734
735 let mut command = acp::AvailableCommand::new(
736 name,
737 prompt.description.clone().unwrap_or_default(),
738 );
739
740 match prompt.arguments.as_deref() {
741 Some([arg]) => {
742 let hint = format!("<{}>", arg.name);
743
744 command = command.input(acp::AvailableCommandInput::Unstructured(
745 acp::UnstructuredCommandInput::new(hint),
746 ));
747 }
748 Some([]) | None => {}
749 Some(_) => {
750 // skip >1 argument commands since we don't support them yet
751 return None;
752 }
753 }
754
755 Some(command)
756 })
757 .collect()
758 }
759
760 pub fn load_thread(
761 &mut self,
762 id: acp::SessionId,
763 cx: &mut Context<Self>,
764 ) -> Task<Result<Entity<Thread>>> {
765 let database_future = ThreadsDatabase::connect(cx);
766 cx.spawn(async move |this, cx| {
767 let database = database_future.await.map_err(|err| anyhow!(err))?;
768 let db_thread = database
769 .load_thread(id.clone())
770 .await?
771 .with_context(|| format!("no thread found with ID: {id:?}"))?;
772
773 this.update(cx, |this, cx| {
774 let summarization_model = LanguageModelRegistry::read_global(cx)
775 .thread_summary_model()
776 .map(|c| c.model);
777
778 cx.new(|cx| {
779 let mut thread = Thread::from_db(
780 id.clone(),
781 db_thread,
782 this.project.clone(),
783 this.project_context.clone(),
784 this.context_server_registry.clone(),
785 this.templates.clone(),
786 cx,
787 );
788 thread.set_summarization_model(summarization_model, cx);
789 thread
790 })
791 })
792 })
793 }
794
795 pub fn open_thread(
796 &mut self,
797 id: acp::SessionId,
798 cx: &mut Context<Self>,
799 ) -> Task<Result<Entity<AcpThread>>> {
800 if let Some(session) = self.sessions.get(&id) {
801 return Task::ready(Ok(session.acp_thread.clone()));
802 }
803
804 let task = self.load_thread(id, cx);
805 cx.spawn(async move |this, cx| {
806 let thread = task.await?;
807 let acp_thread = this.update(cx, |this, cx| {
808 this.register_session(thread.clone(), None, cx)
809 })?;
810 let events = thread.update(cx, |thread, cx| thread.replay(cx));
811 cx.update(|cx| {
812 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
813 })
814 .await?;
815 Ok(acp_thread)
816 })
817 }
818
819 pub fn thread_summary(
820 &mut self,
821 id: acp::SessionId,
822 cx: &mut Context<Self>,
823 ) -> Task<Result<SharedString>> {
824 let thread = self.open_thread(id.clone(), cx);
825 cx.spawn(async move |this, cx| {
826 let acp_thread = thread.await?;
827 let result = this
828 .update(cx, |this, cx| {
829 this.sessions
830 .get(&id)
831 .unwrap()
832 .thread
833 .update(cx, |thread, cx| thread.summary(cx))
834 })?
835 .await
836 .context("Failed to generate summary")?;
837 drop(acp_thread);
838 Ok(result)
839 })
840 }
841
842 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
843 if thread.read(cx).is_empty() {
844 return;
845 }
846
847 let database_future = ThreadsDatabase::connect(cx);
848 let (id, db_thread) =
849 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
850 let Some(session) = self.sessions.get_mut(&id) else {
851 return;
852 };
853 let thread_store = self.thread_store.clone();
854 session.pending_save = cx.spawn(async move |_, cx| {
855 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
856 return;
857 };
858 let db_thread = db_thread.await;
859 database.save_thread(id, db_thread).await.log_err();
860 thread_store.update(cx, |store, cx| store.reload(cx));
861 });
862 }
863
864 fn send_mcp_prompt(
865 &self,
866 message_id: UserMessageId,
867 session_id: agent_client_protocol::SessionId,
868 prompt_name: String,
869 server_id: ContextServerId,
870 arguments: HashMap<String, String>,
871 original_content: Vec<acp::ContentBlock>,
872 cx: &mut Context<Self>,
873 ) -> Task<Result<acp::PromptResponse>> {
874 let server_store = self.context_server_registry.read(cx).server_store().clone();
875 let path_style = self.project.read(cx).path_style(cx);
876
877 cx.spawn(async move |this, cx| {
878 let prompt =
879 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
880
881 let (acp_thread, thread) = this.update(cx, |this, _cx| {
882 let session = this
883 .sessions
884 .get(&session_id)
885 .context("Failed to get session")?;
886 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
887 })??;
888
889 let mut last_is_user = true;
890
891 thread.update(cx, |thread, cx| {
892 thread.push_acp_user_block(
893 message_id,
894 original_content.into_iter().skip(1),
895 path_style,
896 cx,
897 );
898 });
899
900 for message in prompt.messages {
901 let context_server::types::PromptMessage { role, content } = message;
902 let block = mcp_message_content_to_acp_content_block(content);
903
904 match role {
905 context_server::types::Role::User => {
906 let id = acp_thread::UserMessageId::new();
907
908 acp_thread.update(cx, |acp_thread, cx| {
909 acp_thread.push_user_content_block_with_indent(
910 Some(id.clone()),
911 block.clone(),
912 true,
913 cx,
914 );
915 });
916
917 thread.update(cx, |thread, cx| {
918 thread.push_acp_user_block(id, [block], path_style, cx);
919 });
920 }
921 context_server::types::Role::Assistant => {
922 acp_thread.update(cx, |acp_thread, cx| {
923 acp_thread.push_assistant_content_block_with_indent(
924 block.clone(),
925 false,
926 true,
927 cx,
928 );
929 });
930
931 thread.update(cx, |thread, cx| {
932 thread.push_acp_agent_block(block, cx);
933 });
934 }
935 }
936
937 last_is_user = role == context_server::types::Role::User;
938 }
939
940 let response_stream = thread.update(cx, |thread, cx| {
941 if last_is_user {
942 thread.send_existing(cx)
943 } else {
944 // Resume if MCP prompt did not end with a user message
945 thread.resume(cx)
946 }
947 })?;
948
949 cx.update(|cx| {
950 NativeAgentConnection::handle_thread_events(
951 response_stream,
952 acp_thread.downgrade(),
953 cx,
954 )
955 })
956 .await
957 })
958 }
959}
960
961/// Wrapper struct that implements the AgentConnection trait
962#[derive(Clone)]
963pub struct NativeAgentConnection(pub Entity<NativeAgent>);
964
965impl NativeAgentConnection {
966 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
967 self.0
968 .read(cx)
969 .sessions
970 .get(session_id)
971 .map(|session| session.thread.clone())
972 }
973
974 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
975 self.0.update(cx, |this, cx| this.load_thread(id, cx))
976 }
977
978 fn run_turn(
979 &self,
980 session_id: acp::SessionId,
981 cx: &mut App,
982 f: impl 'static
983 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
984 ) -> Task<Result<acp::PromptResponse>> {
985 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
986 agent
987 .sessions
988 .get_mut(&session_id)
989 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
990 }) else {
991 return Task::ready(Err(anyhow!("Session not found")));
992 };
993 log::debug!("Found session for: {}", session_id);
994
995 let response_stream = match f(thread, cx) {
996 Ok(stream) => stream,
997 Err(err) => return Task::ready(Err(err)),
998 };
999 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1000 }
1001
1002 fn handle_thread_events(
1003 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1004 acp_thread: WeakEntity<AcpThread>,
1005 cx: &App,
1006 ) -> Task<Result<acp::PromptResponse>> {
1007 cx.spawn(async move |cx| {
1008 // Handle response stream and forward to session.acp_thread
1009 while let Some(result) = events.next().await {
1010 match result {
1011 Ok(event) => {
1012 log::trace!("Received completion event: {:?}", event);
1013
1014 match event {
1015 ThreadEvent::UserMessage(message) => {
1016 acp_thread.update(cx, |thread, cx| {
1017 for content in message.content {
1018 thread.push_user_content_block(
1019 Some(message.id.clone()),
1020 content.into(),
1021 cx,
1022 );
1023 }
1024 })?;
1025 }
1026 ThreadEvent::AgentText(text) => {
1027 acp_thread.update(cx, |thread, cx| {
1028 thread.push_assistant_content_block(text.into(), false, cx)
1029 })?;
1030 }
1031 ThreadEvent::AgentThinking(text) => {
1032 acp_thread.update(cx, |thread, cx| {
1033 thread.push_assistant_content_block(text.into(), true, cx)
1034 })?;
1035 }
1036 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1037 tool_call,
1038 options,
1039 response,
1040 context: _,
1041 }) => {
1042 let outcome_task = acp_thread.update(cx, |thread, cx| {
1043 thread.request_tool_call_authorization(tool_call, options, cx)
1044 })??;
1045 cx.background_spawn(async move {
1046 if let acp::RequestPermissionOutcome::Selected(
1047 acp::SelectedPermissionOutcome { option_id, .. },
1048 ) = outcome_task.await
1049 {
1050 response
1051 .send(option_id)
1052 .map(|_| anyhow!("authorization receiver was dropped"))
1053 .log_err();
1054 }
1055 })
1056 .detach();
1057 }
1058 ThreadEvent::ToolCall(tool_call) => {
1059 acp_thread.update(cx, |thread, cx| {
1060 thread.upsert_tool_call(tool_call, cx)
1061 })??;
1062 }
1063 ThreadEvent::ToolCallUpdate(update) => {
1064 acp_thread.update(cx, |thread, cx| {
1065 thread.update_tool_call(update, cx)
1066 })??;
1067 }
1068 ThreadEvent::SubagentSpawned(session_id) => {
1069 acp_thread.update(cx, |thread, cx| {
1070 thread.subagent_spawned(session_id, cx);
1071 })?;
1072 }
1073 ThreadEvent::Retry(status) => {
1074 acp_thread.update(cx, |thread, cx| {
1075 thread.update_retry_status(status, cx)
1076 })?;
1077 }
1078 ThreadEvent::Stop(stop_reason) => {
1079 log::debug!("Assistant message complete: {:?}", stop_reason);
1080 return Ok(acp::PromptResponse::new(stop_reason));
1081 }
1082 }
1083 }
1084 Err(e) => {
1085 log::error!("Error in model response stream: {:?}", e);
1086 return Err(e);
1087 }
1088 }
1089 }
1090
1091 log::debug!("Response stream completed");
1092 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1093 })
1094 }
1095}
1096
1097struct Command<'a> {
1098 prompt_name: &'a str,
1099 arg_value: &'a str,
1100 explicit_server_id: Option<&'a str>,
1101}
1102
1103impl<'a> Command<'a> {
1104 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1105 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1106 return None;
1107 };
1108 let text = text_content.text.trim();
1109 let command = text.strip_prefix('/')?;
1110 let (command, arg_value) = command
1111 .split_once(char::is_whitespace)
1112 .unwrap_or((command, ""));
1113
1114 if let Some((server_id, prompt_name)) = command.split_once('.') {
1115 Some(Self {
1116 prompt_name,
1117 arg_value,
1118 explicit_server_id: Some(server_id),
1119 })
1120 } else {
1121 Some(Self {
1122 prompt_name: command,
1123 arg_value,
1124 explicit_server_id: None,
1125 })
1126 }
1127 }
1128}
1129
1130struct NativeAgentModelSelector {
1131 session_id: acp::SessionId,
1132 connection: NativeAgentConnection,
1133}
1134
1135impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1136 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1137 log::debug!("NativeAgentConnection::list_models called");
1138 let list = self.connection.0.read(cx).models.model_list.clone();
1139 Task::ready(if list.is_empty() {
1140 Err(anyhow::anyhow!("No models available"))
1141 } else {
1142 Ok(list)
1143 })
1144 }
1145
1146 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1147 log::debug!(
1148 "Setting model for session {}: {}",
1149 self.session_id,
1150 model_id
1151 );
1152 let Some(thread) = self
1153 .connection
1154 .0
1155 .read(cx)
1156 .sessions
1157 .get(&self.session_id)
1158 .map(|session| session.thread.clone())
1159 else {
1160 return Task::ready(Err(anyhow!("Session not found")));
1161 };
1162
1163 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1164 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1165 };
1166
1167 // We want to reset the effort level when switching models, as the currently-selected effort level may
1168 // not be compatible.
1169 let effort = model
1170 .default_effort_level()
1171 .map(|effort_level| effort_level.value.to_string());
1172
1173 thread.update(cx, |thread, cx| {
1174 thread.set_model(model.clone(), cx);
1175 thread.set_thinking_effort(effort.clone(), cx);
1176 thread.set_thinking_enabled(model.supports_thinking(), cx);
1177 });
1178
1179 update_settings_file(
1180 self.connection.0.read(cx).fs.clone(),
1181 cx,
1182 move |settings, cx| {
1183 let provider = model.provider_id().0.to_string();
1184 let model = model.id().0.to_string();
1185 let enable_thinking = thread.read(cx).thinking_enabled();
1186 settings
1187 .agent
1188 .get_or_insert_default()
1189 .set_model(LanguageModelSelection {
1190 provider: provider.into(),
1191 model,
1192 enable_thinking,
1193 effort,
1194 });
1195 },
1196 );
1197
1198 Task::ready(Ok(()))
1199 }
1200
1201 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1202 let Some(thread) = self
1203 .connection
1204 .0
1205 .read(cx)
1206 .sessions
1207 .get(&self.session_id)
1208 .map(|session| session.thread.clone())
1209 else {
1210 return Task::ready(Err(anyhow!("Session not found")));
1211 };
1212 let Some(model) = thread.read(cx).model() else {
1213 return Task::ready(Err(anyhow!("Model not found")));
1214 };
1215 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1216 else {
1217 return Task::ready(Err(anyhow!("Provider not found")));
1218 };
1219 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1220 model, &provider,
1221 )))
1222 }
1223
1224 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1225 Some(self.connection.0.read(cx).models.watch())
1226 }
1227
1228 fn should_render_footer(&self) -> bool {
1229 true
1230 }
1231}
1232
1233impl acp_thread::AgentConnection for NativeAgentConnection {
1234 fn telemetry_id(&self) -> SharedString {
1235 "zed".into()
1236 }
1237
1238 fn new_session(
1239 self: Rc<Self>,
1240 project: Entity<Project>,
1241 cwd: &Path,
1242 cx: &mut App,
1243 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1244 log::debug!("Creating new thread for project at: {cwd:?}");
1245 Task::ready(Ok(self
1246 .0
1247 .update(cx, |agent, cx| agent.new_session(project, cx))))
1248 }
1249
1250 fn supports_load_session(&self) -> bool {
1251 true
1252 }
1253
1254 fn load_session(
1255 self: Rc<Self>,
1256 session: AgentSessionInfo,
1257 _project: Entity<Project>,
1258 _cwd: &Path,
1259 cx: &mut App,
1260 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1261 self.0
1262 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1263 }
1264
1265 fn supports_close_session(&self) -> bool {
1266 true
1267 }
1268
1269 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1270 self.0.update(cx, |agent, _cx| {
1271 agent.sessions.remove(session_id);
1272 });
1273 Task::ready(Ok(()))
1274 }
1275
1276 fn auth_methods(&self) -> &[acp::AuthMethod] {
1277 &[] // No auth for in-process
1278 }
1279
1280 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1281 Task::ready(Ok(()))
1282 }
1283
1284 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1285 Some(Rc::new(NativeAgentModelSelector {
1286 session_id: session_id.clone(),
1287 connection: self.clone(),
1288 }) as Rc<dyn AgentModelSelector>)
1289 }
1290
1291 fn prompt(
1292 &self,
1293 id: Option<acp_thread::UserMessageId>,
1294 params: acp::PromptRequest,
1295 cx: &mut App,
1296 ) -> Task<Result<acp::PromptResponse>> {
1297 let id = id.expect("UserMessageId is required");
1298 let session_id = params.session_id.clone();
1299 log::info!("Received prompt request for session: {}", session_id);
1300 log::debug!("Prompt blocks count: {}", params.prompt.len());
1301
1302 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1303 let registry = self.0.read(cx).context_server_registry.read(cx);
1304
1305 let explicit_server_id = parsed_command
1306 .explicit_server_id
1307 .map(|server_id| ContextServerId(server_id.into()));
1308
1309 if let Some(prompt) =
1310 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1311 {
1312 let arguments = if !parsed_command.arg_value.is_empty()
1313 && let Some(arg_name) = prompt
1314 .prompt
1315 .arguments
1316 .as_ref()
1317 .and_then(|args| args.first())
1318 .map(|arg| arg.name.clone())
1319 {
1320 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1321 } else {
1322 Default::default()
1323 };
1324
1325 let prompt_name = prompt.prompt.name.clone();
1326 let server_id = prompt.server_id.clone();
1327
1328 return self.0.update(cx, |agent, cx| {
1329 agent.send_mcp_prompt(
1330 id,
1331 session_id.clone(),
1332 prompt_name,
1333 server_id,
1334 arguments,
1335 params.prompt,
1336 cx,
1337 )
1338 });
1339 };
1340 };
1341
1342 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1343
1344 self.run_turn(session_id, cx, move |thread, cx| {
1345 let content: Vec<UserMessageContent> = params
1346 .prompt
1347 .into_iter()
1348 .map(|block| UserMessageContent::from_content_block(block, path_style))
1349 .collect::<Vec<_>>();
1350 log::debug!("Converted prompt to message: {} chars", content.len());
1351 log::debug!("Message id: {:?}", id);
1352 log::debug!("Message content: {:?}", content);
1353
1354 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1355 })
1356 }
1357
1358 fn retry(
1359 &self,
1360 session_id: &acp::SessionId,
1361 _cx: &App,
1362 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1363 Some(Rc::new(NativeAgentSessionRetry {
1364 connection: self.clone(),
1365 session_id: session_id.clone(),
1366 }) as _)
1367 }
1368
1369 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1370 log::info!("Cancelling on session: {}", session_id);
1371 self.0.update(cx, |agent, cx| {
1372 if let Some(agent) = agent.sessions.get(session_id) {
1373 agent
1374 .thread
1375 .update(cx, |thread, cx| thread.cancel(cx))
1376 .detach();
1377 }
1378 });
1379 }
1380
1381 fn truncate(
1382 &self,
1383 session_id: &agent_client_protocol::SessionId,
1384 cx: &App,
1385 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1386 self.0.read_with(cx, |agent, _cx| {
1387 agent.sessions.get(session_id).map(|session| {
1388 Rc::new(NativeAgentSessionTruncate {
1389 thread: session.thread.clone(),
1390 acp_thread: session.acp_thread.downgrade(),
1391 }) as _
1392 })
1393 })
1394 }
1395
1396 fn set_title(
1397 &self,
1398 session_id: &acp::SessionId,
1399 cx: &App,
1400 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1401 self.0.read_with(cx, |agent, _cx| {
1402 agent
1403 .sessions
1404 .get(session_id)
1405 .filter(|s| !s.thread.read(cx).is_subagent())
1406 .map(|session| {
1407 Rc::new(NativeAgentSessionSetTitle {
1408 thread: session.thread.clone(),
1409 }) as _
1410 })
1411 })
1412 }
1413
1414 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1415 let thread_store = self.0.read(cx).thread_store.clone();
1416 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1417 }
1418
1419 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1420 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1421 }
1422
1423 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1424 self
1425 }
1426}
1427
1428impl acp_thread::AgentTelemetry for NativeAgentConnection {
1429 fn thread_data(
1430 &self,
1431 session_id: &acp::SessionId,
1432 cx: &mut App,
1433 ) -> Task<Result<serde_json::Value>> {
1434 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1435 return Task::ready(Err(anyhow!("Session not found")));
1436 };
1437
1438 let task = session.thread.read(cx).to_db(cx);
1439 cx.background_spawn(async move {
1440 serde_json::to_value(task.await).context("Failed to serialize thread")
1441 })
1442 }
1443}
1444
1445pub struct NativeAgentSessionList {
1446 thread_store: Entity<ThreadStore>,
1447 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1448 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1449 _subscription: Subscription,
1450}
1451
1452impl NativeAgentSessionList {
1453 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1454 let (tx, rx) = smol::channel::unbounded();
1455 let this_tx = tx.clone();
1456 let subscription = cx.observe(&thread_store, move |_, _| {
1457 this_tx
1458 .try_send(acp_thread::SessionListUpdate::Refresh)
1459 .ok();
1460 });
1461 Self {
1462 thread_store,
1463 updates_tx: tx,
1464 updates_rx: rx,
1465 _subscription: subscription,
1466 }
1467 }
1468
1469 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1470 AgentSessionInfo {
1471 session_id: entry.id,
1472 cwd: None,
1473 title: Some(entry.title),
1474 updated_at: Some(entry.updated_at),
1475 meta: None,
1476 }
1477 }
1478
1479 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1480 &self.thread_store
1481 }
1482}
1483
1484impl AgentSessionList for NativeAgentSessionList {
1485 fn list_sessions(
1486 &self,
1487 _request: AgentSessionListRequest,
1488 cx: &mut App,
1489 ) -> Task<Result<AgentSessionListResponse>> {
1490 let sessions = self
1491 .thread_store
1492 .read(cx)
1493 .entries()
1494 .map(Self::to_session_info)
1495 .collect();
1496 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1497 }
1498
1499 fn supports_delete(&self) -> bool {
1500 true
1501 }
1502
1503 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1504 self.thread_store
1505 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1506 }
1507
1508 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1509 self.thread_store
1510 .update(cx, |store, cx| store.delete_threads(cx))
1511 }
1512
1513 fn watch(
1514 &self,
1515 _cx: &mut App,
1516 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1517 Some(self.updates_rx.clone())
1518 }
1519
1520 fn notify_refresh(&self) {
1521 self.updates_tx
1522 .try_send(acp_thread::SessionListUpdate::Refresh)
1523 .ok();
1524 }
1525
1526 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1527 self
1528 }
1529}
1530
1531struct NativeAgentSessionTruncate {
1532 thread: Entity<Thread>,
1533 acp_thread: WeakEntity<AcpThread>,
1534}
1535
1536impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1537 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1538 match self.thread.update(cx, |thread, cx| {
1539 thread.truncate(message_id.clone(), cx)?;
1540 Ok(thread.latest_token_usage())
1541 }) {
1542 Ok(usage) => {
1543 self.acp_thread
1544 .update(cx, |thread, cx| {
1545 thread.update_token_usage(usage, cx);
1546 })
1547 .ok();
1548 Task::ready(Ok(()))
1549 }
1550 Err(error) => Task::ready(Err(error)),
1551 }
1552 }
1553}
1554
1555struct NativeAgentSessionRetry {
1556 connection: NativeAgentConnection,
1557 session_id: acp::SessionId,
1558}
1559
1560impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1561 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1562 self.connection
1563 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1564 thread.update(cx, |thread, cx| thread.resume(cx))
1565 })
1566 }
1567}
1568
1569struct NativeAgentSessionSetTitle {
1570 thread: Entity<Thread>,
1571}
1572
1573impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1574 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1575 self.thread
1576 .update(cx, |thread, cx| thread.set_title(title, cx));
1577 Task::ready(Ok(()))
1578 }
1579}
1580
/// `ThreadEnvironment` for native threads: creates terminals through the thread's `AcpThread`
/// and spawns subagents through the owning `NativeAgent`.
pub struct NativeThreadEnvironment {
    agent: WeakEntity<NativeAgent>,
    acp_thread: WeakEntity<AcpThread>,
}

impl NativeThreadEnvironment {
    /// Spawns a subagent thread under `parent_thread_entity` and sends it `initial_prompt`.
    ///
    /// The subagent is titled `label`, registered as a session with `agent`, and recorded as a
    /// running subagent of the parent. Its tools are `allowed_tools` intersected with the
    /// parent's tools; `None` grants all of the parent's tools.
    ///
    /// When `timeout` is set, waiting on the initial prompt ends once it elapses and reports
    /// `SubagentInitialPromptResult::Timeout`.
    ///
    /// # Errors
    ///
    /// Fails if the parent is already at `MAX_SUBAGENT_DEPTH`, already has
    /// `MAX_PARALLEL_SUBAGENTS` running subagents, or if `agent` has been released.
    pub(crate) fn create_subagent_thread(
        agent: WeakEntity<NativeAgent>,
        parent_thread_entity: Entity<Thread>,
        label: String,
        initial_prompt: String,
        timeout: Option<Duration>,
        allowed_tools: Option<Vec<String>>,
        cx: &mut App,
    ) -> Result<Rc<dyn SubagentHandle>> {
        let parent_thread = parent_thread_entity.read(cx);
        let current_depth = parent_thread.depth();

        // Refuse to nest subagents beyond the depth limit.
        if current_depth >= MAX_SUBAGENT_DEPTH {
            return Err(anyhow!(
                "Maximum subagent depth ({}) reached",
                MAX_SUBAGENT_DEPTH
            ));
        }

        // Cap how many subagents one parent may run at the same time.
        let running_count = parent_thread.running_subagent_count();
        if running_count >= MAX_PARALLEL_SUBAGENTS {
            return Err(anyhow!(
                "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
                MAX_PARALLEL_SUBAGENTS
            ));
        }

        // A subagent may only use tools its parent has: requested tools are filtered against
        // the parent's set, and with no request it inherits the parent's full set.
        let allowed_tools = match allowed_tools {
            Some(tools) => {
                let parent_tool_names: std::collections::HashSet<&str> =
                    parent_thread.tools.keys().map(|s| s.as_str()).collect();
                Some(
                    tools
                        .into_iter()
                        .filter(|t| parent_tool_names.contains(t.as_str()))
                        .collect::<Vec<_>>(),
                )
            }
            None => Some(parent_thread.tools.keys().map(|s| s.to_string()).collect()),
        };

        let subagent_thread: Entity<Thread> = cx.new(|cx| {
            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
            thread.set_title(label.into(), cx);
            thread
        });

        let session_id = subagent_thread.read(cx).id().clone();

        // Register the subagent as its own session; `?` fails if the agent has been dropped.
        let acp_thread = agent.update(cx, |agent, cx| {
            agent.register_session(
                subagent_thread.clone(),
                allowed_tools
                    .as_ref()
                    .map(|v| v.iter().map(|s| s.as_str()).collect()),
                cx,
            )
        })?;

        // Counted against MAX_PARALLEL_SUBAGENTS; `NativeSubagentHandle::wait_for_summary`
        // unregisters it when it finishes.
        parent_thread_entity.update(cx, |parent_thread, _cx| {
            parent_thread.register_running_subagent(subagent_thread.downgrade())
        });

        let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));

        // Resolves when the initial prompt finishes or, if a timeout was given, when the timer
        // fires first. The prompt's own result is not propagated (it is only logged in the
        // no-timeout case). Shared so it can be awaited from the handle.
        let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
        let wait_for_prompt_to_complete = cx
            .background_spawn(async move {
                if let Some(timer) = timeout_timer {
                    futures::select! {
                        _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
                        _ = task.fuse() => SubagentInitialPromptResult::Completed,
                    }
                } else {
                    task.await.log_err();
                    SubagentInitialPromptResult::Completed
                }
            })
            .shared();

        let mut user_stop_rx: watch::Receiver<bool> =
            acp_thread.update(cx, |thread, _| thread.user_stop_receiver());

        // Resolves once the user-stop flag becomes `true`. If the sender side goes away without
        // a stop, this future stays pending forever rather than reporting a cancellation.
        let user_cancelled = cx
            .background_spawn(async move {
                loop {
                    if *user_stop_rx.borrow() {
                        return;
                    }
                    if user_stop_rx.changed().await.is_err() {
                        std::future::pending::<()>().await;
                    }
                }
            })
            .shared();

        Ok(Rc::new(NativeSubagentHandle {
            session_id,
            subagent_thread,
            parent_thread: parent_thread_entity.downgrade(),
            acp_thread,
            wait_for_prompt_to_complete,
            user_cancelled,
        }) as _)
    }
}
1693
1694impl ThreadEnvironment for NativeThreadEnvironment {
1695 fn create_terminal(
1696 &self,
1697 command: String,
1698 cwd: Option<PathBuf>,
1699 output_byte_limit: Option<u64>,
1700 cx: &mut AsyncApp,
1701 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1702 let task = self.acp_thread.update(cx, |thread, cx| {
1703 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1704 });
1705
1706 let acp_thread = self.acp_thread.clone();
1707 cx.spawn(async move |cx| {
1708 let terminal = task?.await?;
1709
1710 let (drop_tx, drop_rx) = oneshot::channel();
1711 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1712
1713 cx.spawn(async move |cx| {
1714 drop_rx.await.ok();
1715 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1716 })
1717 .detach();
1718
1719 let handle = AcpTerminalHandle {
1720 terminal,
1721 _drop_tx: Some(drop_tx),
1722 };
1723
1724 Ok(Rc::new(handle) as _)
1725 })
1726 }
1727
1728 fn create_subagent(
1729 &self,
1730 parent_thread_entity: Entity<Thread>,
1731 label: String,
1732 initial_prompt: String,
1733 timeout: Option<Duration>,
1734 allowed_tools: Option<Vec<String>>,
1735 cx: &mut App,
1736 ) -> Result<Rc<dyn SubagentHandle>> {
1737 Self::create_subagent_thread(
1738 self.agent.clone(),
1739 parent_thread_entity,
1740 label,
1741 initial_prompt,
1742 timeout,
1743 allowed_tools,
1744 cx,
1745 )
1746 }
1747}
1748
1749#[derive(Debug, Clone, Copy)]
1750enum SubagentInitialPromptResult {
1751 Completed,
1752 Timeout,
1753}
1754
1755pub struct NativeSubagentHandle {
1756 session_id: acp::SessionId,
1757 parent_thread: WeakEntity<Thread>,
1758 subagent_thread: Entity<Thread>,
1759 acp_thread: Entity<AcpThread>,
1760 wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
1761 user_cancelled: Shared<Task<()>>,
1762}
1763
1764impl SubagentHandle for NativeSubagentHandle {
1765 fn id(&self) -> acp::SessionId {
1766 self.session_id.clone()
1767 }
1768
1769 fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
1770 let thread = self.subagent_thread.clone();
1771 let acp_thread = self.acp_thread.clone();
1772 let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1773
1774 let wait_for_summary_task = cx.spawn(async move |cx| {
1775 let timed_out = match wait_for_prompt.await {
1776 SubagentInitialPromptResult::Completed => false,
1777 SubagentInitialPromptResult::Timeout => true,
1778 };
1779
1780 let summary_prompt = if timed_out {
1781 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1782 format!("{}\n{}", "The time to complete the task was exceeded. Stop with the task and follow the directions below:", summary_prompt)
1783 } else {
1784 summary_prompt
1785 };
1786
1787 acp_thread
1788 .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
1789 .await?;
1790
1791 thread.read_with(cx, |thread, _cx| {
1792 thread
1793 .last_message()
1794 .map(|m| m.to_markdown())
1795 .context("No response from subagent")
1796 })
1797 });
1798
1799 let user_cancelled = self.user_cancelled.clone();
1800 let thread = self.subagent_thread.clone();
1801 let subagent_session_id = self.session_id.clone();
1802 let parent_thread = self.parent_thread.clone();
1803 cx.spawn(async move |cx| {
1804 let result = futures::select! {
1805 result = wait_for_summary_task.fuse() => result,
1806 _ = user_cancelled.fuse() => {
1807 thread.update(cx, |thread, cx| thread.cancel(cx).detach());
1808 Err(anyhow!("User cancelled"))
1809 },
1810 };
1811 parent_thread
1812 .update(cx, |parent_thread, cx| {
1813 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1814 })
1815 .ok();
1816 result
1817 })
1818 }
1819}
1820
1821pub struct AcpTerminalHandle {
1822 terminal: Entity<acp_thread::Terminal>,
1823 _drop_tx: Option<oneshot::Sender<()>>,
1824}
1825
1826impl TerminalHandle for AcpTerminalHandle {
1827 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1828 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1829 }
1830
1831 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1832 Ok(self
1833 .terminal
1834 .read_with(cx, |term, _cx| term.wait_for_exit()))
1835 }
1836
1837 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1838 Ok(self
1839 .terminal
1840 .read_with(cx, |term, cx| term.current_output(cx)))
1841 }
1842
1843 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1844 cx.update(|cx| {
1845 self.terminal.update(cx, |terminal, cx| {
1846 terminal.kill(cx);
1847 });
1848 });
1849 Ok(())
1850 }
1851
1852 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1853 Ok(self
1854 .terminal
1855 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1856 }
1857}
1858
1859#[cfg(test)]
1860mod internal_tests {
1861 use super::*;
1862 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1863 use fs::FakeFs;
1864 use gpui::TestAppContext;
1865 use indoc::formatdoc;
1866 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1867 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1868 use serde_json::json;
1869 use settings::SettingsStore;
1870 use util::{path, rel_path::rel_path};
1871
1872 #[gpui::test]
1873 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1874 init_test(cx);
1875 let fs = FakeFs::new(cx.executor());
1876 fs.insert_tree(
1877 "/",
1878 json!({
1879 "a": {}
1880 }),
1881 )
1882 .await;
1883 let project = Project::test(fs.clone(), [], cx).await;
1884 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1885 let agent = NativeAgent::new(
1886 project.clone(),
1887 thread_store,
1888 Templates::new(),
1889 None,
1890 fs.clone(),
1891 &mut cx.to_async(),
1892 )
1893 .await
1894 .unwrap();
1895 agent.read_with(cx, |agent, cx| {
1896 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1897 });
1898
1899 let worktree = project
1900 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1901 .await
1902 .unwrap();
1903 cx.run_until_parked();
1904 agent.read_with(cx, |agent, cx| {
1905 assert_eq!(
1906 agent.project_context.read(cx).worktrees,
1907 vec![WorktreeContext {
1908 root_name: "a".into(),
1909 abs_path: Path::new("/a").into(),
1910 rules_file: None
1911 }]
1912 )
1913 });
1914
1915 // Creating `/a/.rules` updates the project context.
1916 fs.insert_file("/a/.rules", Vec::new()).await;
1917 cx.run_until_parked();
1918 agent.read_with(cx, |agent, cx| {
1919 let rules_entry = worktree
1920 .read(cx)
1921 .entry_for_path(rel_path(".rules"))
1922 .unwrap();
1923 assert_eq!(
1924 agent.project_context.read(cx).worktrees,
1925 vec![WorktreeContext {
1926 root_name: "a".into(),
1927 abs_path: Path::new("/a").into(),
1928 rules_file: Some(RulesFileContext {
1929 path_in_worktree: rel_path(".rules").into(),
1930 text: "".into(),
1931 project_entry_id: rules_entry.id.to_usize()
1932 })
1933 }]
1934 )
1935 });
1936 }
1937
1938 #[gpui::test]
1939 async fn test_listing_models(cx: &mut TestAppContext) {
1940 init_test(cx);
1941 let fs = FakeFs::new(cx.executor());
1942 fs.insert_tree("/", json!({ "a": {} })).await;
1943 let project = Project::test(fs.clone(), [], cx).await;
1944 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1945 let connection = NativeAgentConnection(
1946 NativeAgent::new(
1947 project.clone(),
1948 thread_store,
1949 Templates::new(),
1950 None,
1951 fs.clone(),
1952 &mut cx.to_async(),
1953 )
1954 .await
1955 .unwrap(),
1956 );
1957
1958 // Create a thread/session
1959 let acp_thread = cx
1960 .update(|cx| {
1961 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1962 })
1963 .await
1964 .unwrap();
1965
1966 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1967
1968 let models = cx
1969 .update(|cx| {
1970 connection
1971 .model_selector(&session_id)
1972 .unwrap()
1973 .list_models(cx)
1974 })
1975 .await
1976 .unwrap();
1977
1978 let acp_thread::AgentModelList::Grouped(models) = models else {
1979 panic!("Unexpected model group");
1980 };
1981 assert_eq!(
1982 models,
1983 IndexMap::from_iter([(
1984 AgentModelGroupName("Fake".into()),
1985 vec![AgentModelInfo {
1986 id: acp::ModelId::new("fake/fake"),
1987 name: "Fake".into(),
1988 description: None,
1989 icon: Some(acp_thread::AgentModelIcon::Named(
1990 ui::IconName::ZedAssistant
1991 )),
1992 is_latest: false,
1993 cost: None,
1994 }]
1995 )])
1996 );
1997 }
1998
1999 #[gpui::test]
2000 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2001 init_test(cx);
2002 let fs = FakeFs::new(cx.executor());
2003 fs.create_dir(paths::settings_file().parent().unwrap())
2004 .await
2005 .unwrap();
2006 fs.insert_file(
2007 paths::settings_file(),
2008 json!({
2009 "agent": {
2010 "default_model": {
2011 "provider": "foo",
2012 "model": "bar"
2013 }
2014 }
2015 })
2016 .to_string()
2017 .into_bytes(),
2018 )
2019 .await;
2020 let project = Project::test(fs.clone(), [], cx).await;
2021
2022 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2023
2024 // Create the agent and connection
2025 let agent = NativeAgent::new(
2026 project.clone(),
2027 thread_store,
2028 Templates::new(),
2029 None,
2030 fs.clone(),
2031 &mut cx.to_async(),
2032 )
2033 .await
2034 .unwrap();
2035 let connection = NativeAgentConnection(agent.clone());
2036
2037 // Create a thread/session
2038 let acp_thread = cx
2039 .update(|cx| {
2040 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2041 })
2042 .await
2043 .unwrap();
2044
2045 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2046
2047 // Select a model
2048 let selector = connection.model_selector(&session_id).unwrap();
2049 let model_id = acp::ModelId::new("fake/fake");
2050 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2051 .await
2052 .unwrap();
2053
2054 // Verify the thread has the selected model
2055 agent.read_with(cx, |agent, _| {
2056 let session = agent.sessions.get(&session_id).unwrap();
2057 session.thread.read_with(cx, |thread, _| {
2058 assert_eq!(thread.model().unwrap().id().0, "fake");
2059 });
2060 });
2061
2062 cx.run_until_parked();
2063
2064 // Verify settings file was updated
2065 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2066 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2067
2068 // Check that the agent settings contain the selected model
2069 assert_eq!(
2070 settings_json["agent"]["default_model"]["model"],
2071 json!("fake")
2072 );
2073 assert_eq!(
2074 settings_json["agent"]["default_model"]["provider"],
2075 json!("fake")
2076 );
2077
2078 // Register a thinking model and select it.
2079 cx.update(|cx| {
2080 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2081 "fake-corp",
2082 "fake-thinking",
2083 "Fake Thinking",
2084 true,
2085 ));
2086 let thinking_provider = Arc::new(
2087 FakeLanguageModelProvider::new(
2088 LanguageModelProviderId::from("fake-corp".to_string()),
2089 LanguageModelProviderName::from("Fake Corp".to_string()),
2090 )
2091 .with_models(vec![thinking_model]),
2092 );
2093 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2094 registry.register_provider(thinking_provider, cx);
2095 });
2096 });
2097 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2098
2099 let selector = connection.model_selector(&session_id).unwrap();
2100 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2101 .await
2102 .unwrap();
2103 cx.run_until_parked();
2104
2105 // Verify enable_thinking was written to settings as true.
2106 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2107 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2108 assert_eq!(
2109 settings_json["agent"]["default_model"]["enable_thinking"],
2110 json!(true),
2111 "selecting a thinking model should persist enable_thinking: true to settings"
2112 );
2113 }
2114
2115 #[gpui::test]
2116 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2117 init_test(cx);
2118 let fs = FakeFs::new(cx.executor());
2119 fs.create_dir(paths::settings_file().parent().unwrap())
2120 .await
2121 .unwrap();
2122 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2123 let project = Project::test(fs.clone(), [], cx).await;
2124
2125 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2126 let agent = NativeAgent::new(
2127 project.clone(),
2128 thread_store,
2129 Templates::new(),
2130 None,
2131 fs.clone(),
2132 &mut cx.to_async(),
2133 )
2134 .await
2135 .unwrap();
2136 let connection = NativeAgentConnection(agent.clone());
2137
2138 let acp_thread = cx
2139 .update(|cx| {
2140 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2141 })
2142 .await
2143 .unwrap();
2144 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2145
2146 // Register a second provider with a thinking model.
2147 cx.update(|cx| {
2148 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2149 "fake-corp",
2150 "fake-thinking",
2151 "Fake Thinking",
2152 true,
2153 ));
2154 let thinking_provider = Arc::new(
2155 FakeLanguageModelProvider::new(
2156 LanguageModelProviderId::from("fake-corp".to_string()),
2157 LanguageModelProviderName::from("Fake Corp".to_string()),
2158 )
2159 .with_models(vec![thinking_model]),
2160 );
2161 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2162 registry.register_provider(thinking_provider, cx);
2163 });
2164 });
2165 // Refresh the agent's model list so it picks up the new provider.
2166 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2167
2168 // Thread starts with thinking_enabled = false (the default).
2169 agent.read_with(cx, |agent, _| {
2170 let session = agent.sessions.get(&session_id).unwrap();
2171 session.thread.read_with(cx, |thread, _| {
2172 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2173 });
2174 });
2175
2176 // Select the thinking model via select_model.
2177 let selector = connection.model_selector(&session_id).unwrap();
2178 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2179 .await
2180 .unwrap();
2181
2182 // select_model should have enabled thinking based on the model's supports_thinking().
2183 agent.read_with(cx, |agent, _| {
2184 let session = agent.sessions.get(&session_id).unwrap();
2185 session.thread.read_with(cx, |thread, _| {
2186 assert!(
2187 thread.thinking_enabled(),
2188 "select_model should enable thinking when model supports it"
2189 );
2190 });
2191 });
2192
2193 // Switch back to the non-thinking model.
2194 let selector = connection.model_selector(&session_id).unwrap();
2195 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2196 .await
2197 .unwrap();
2198
2199 // select_model should have disabled thinking.
2200 agent.read_with(cx, |agent, _| {
2201 let session = agent.sessions.get(&session_id).unwrap();
2202 session.thread.read_with(cx, |thread, _| {
2203 assert!(
2204 !thread.thinking_enabled(),
2205 "select_model should disable thinking when model does not support it"
2206 );
2207 });
2208 });
2209 }
2210
    // Round-trip test: a thread that selected a thinking-capable model should still
    // report `thinking_enabled() == true` after it is persisted, closed, and reopened
    // through `NativeAgent::open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model. The test keeps its own handle (`thinking_model`)
        // so it can drive the model's completion stream after sending a message.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector id used here is the provider id and model id joined by `/`.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it keeps making progress while the test
        // feeds the fake model's completion stream below; `run_until_parked` lets
        // it advance until it is waiting on that stream.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old session's entities, then
        // confirm the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The reloaded AcpThread handle is held until here, after the assertions.
        drop(reloaded_acp_thread);
    }
2316
    // Round-trip test: the model selected for a thread must be restored by *id* when
    // the thread is persisted, closed, and reopened, rather than falling back to the
    // default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // The final `false` argument is the opposite of the one used for the thinking
        // model in the sibling test, so this test does not involve thinking.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the thread reports the model's id (not its display name)
        // before anything is persisted.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it keeps making progress while the test
        // feeds the fake model's completion stream below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old session's entities, then
        // confirm the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reloaded AcpThread handle is held until here, after the assertions.
        drop(reloaded_acp_thread);
    }
2427
    // Persistence test:
    // 1. A thread with no messages is not written to the thread store, even after
    //    its models are changed.
    // 2. After one user/assistant exchange, the thread is stored under the title
    //    produced by the summarization model.
    // 3. Reopening the thread from the store yields the same markdown transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are installed for the main completion and for
        // summarization so the test can drive each stream independently.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message mixes plain text with a resource link to `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so it keeps making progress while the test feeds the fake
        // models' completion streams below.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The title later read back from the thread store comes from this
        // summarization stream.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old session's entities, then
        // confirm the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread must render the same transcript as before closing.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2566
2567 fn thread_entries(
2568 thread_store: &Entity<ThreadStore>,
2569 cx: &mut TestAppContext,
2570 ) -> Vec<(acp::SessionId, String)> {
2571 thread_store.read_with(cx, |store, _| {
2572 store
2573 .entries()
2574 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2575 .collect::<Vec<_>>()
2576 })
2577 }
2578
2579 fn init_test(cx: &mut TestAppContext) {
2580 env_logger::try_init().ok();
2581 cx.update(|cx| {
2582 let settings_store = SettingsStore::test(cx);
2583 cx.set_global(settings_store);
2584
2585 LanguageModelRegistry::test(cx);
2586 });
2587 }
2588}
2589
2590fn mcp_message_content_to_acp_content_block(
2591 content: context_server::types::MessageContent,
2592) -> acp::ContentBlock {
2593 match content {
2594 context_server::types::MessageContent::Text {
2595 text,
2596 annotations: _,
2597 } => text.into(),
2598 context_server::types::MessageContent::Image {
2599 data,
2600 mime_type,
2601 annotations: _,
2602 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2603 context_server::types::MessageContent::Audio {
2604 data,
2605 mime_type,
2606 annotations: _,
2607 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2608 context_server::types::MessageContent::Resource {
2609 resource,
2610 annotations: _,
2611 } => {
2612 let mut link =
2613 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2614 if let Some(mime_type) = resource.mime_type {
2615 link = link.mime_type(mime_type);
2616 }
2617 acp::ContentBlock::ResourceLink(link)
2618 }
2619 }
2620}