1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
41};
42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
43use project::{Project, ProjectItem, ProjectPath, Worktree};
44use prompt_store::{
45 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
46 WorktreeContext,
47};
48use serde::{Deserialize, Serialize};
49use settings::{LanguageModelSelection, update_settings_file};
50use std::any::Any;
51use std::path::{Path, PathBuf};
52use std::rc::Rc;
53use std::sync::Arc;
54use util::ResultExt;
55use util::path_list::PathList;
56use util::rel_path::RelPath;
57
/// A serializable capture of the project's worktrees together with the time
/// associated with that capture.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp stored alongside the snapshots.
    pub timestamp: DateTime<Utc>,
}
63
/// Describes a failure to load a rules file; produced in
/// `load_worktree_info_for_system_prompt` from the underlying error's display text.
pub struct RulesLoadingError {
    /// Human-readable description of what went wrong.
    pub message: SharedString,
}
67
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Task for the most recently started save of this session's thread;
    /// overwritten by every call to `save_thread`.
    pending_save: Task<()>,
    /// Subscriptions and observer registered on `thread` in `register_session`.
    _subscriptions: Vec<Subscription>,
}
77
/// Cache of the language models offered by authenticated providers, keyed by
/// ACP model ID and also kept as a grouped list.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; `watch()` hands out clones of it.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; signalled at the end of `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task started in `new` that authenticates all visible providers.
    _authenticate_all_providers_task: Task<()>,
}
87
88impl LanguageModels {
89 fn new(cx: &mut App) -> Self {
90 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
91
92 let mut this = Self {
93 models: HashMap::default(),
94 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
95 refresh_models_rx,
96 refresh_models_tx,
97 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
98 };
99 this.refresh_list(cx);
100 this
101 }
102
103 fn refresh_list(&mut self, cx: &App) {
104 let providers = LanguageModelRegistry::global(cx)
105 .read(cx)
106 .visible_providers()
107 .into_iter()
108 .filter(|provider| provider.is_authenticated(cx))
109 .collect::<Vec<_>>();
110
111 let mut language_model_list = IndexMap::default();
112 let mut recommended_models = HashSet::default();
113
114 let mut recommended = Vec::new();
115 for provider in &providers {
116 for model in provider.recommended_models(cx) {
117 recommended_models.insert((model.provider_id(), model.id()));
118 recommended.push(Self::map_language_model_to_info(&model, provider));
119 }
120 }
121 if !recommended.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName("Recommended".into()),
124 recommended,
125 );
126 }
127
128 let mut models = HashMap::default();
129 for provider in providers {
130 let mut provider_models = Vec::new();
131 for model in provider.provided_models(cx) {
132 let model_info = Self::map_language_model_to_info(&model, &provider);
133 let model_id = model_info.id.clone();
134 provider_models.push(model_info);
135 models.insert(model_id, model);
136 }
137 if !provider_models.is_empty() {
138 language_model_list.insert(
139 acp_thread::AgentModelGroupName(provider.name().0.clone()),
140 provider_models,
141 );
142 }
143 }
144
145 self.models = models;
146 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
147 self.refresh_models_tx.send(()).ok();
148 }
149
150 fn watch(&self) -> watch::Receiver<()> {
151 self.refresh_models_rx.clone()
152 }
153
154 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
155 self.models.get(model_id).cloned()
156 }
157
158 fn map_language_model_to_info(
159 model: &Arc<dyn LanguageModel>,
160 provider: &Arc<dyn LanguageModelProvider>,
161 ) -> acp_thread::AgentModelInfo {
162 acp_thread::AgentModelInfo {
163 id: Self::model_id(model),
164 name: model.name().0,
165 description: None,
166 icon: Some(match provider.icon() {
167 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
168 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
169 }),
170 is_latest: model.is_latest(),
171 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
172 }
173 }
174
175 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
176 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
177 }
178
179 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
180 let authenticate_all_providers = LanguageModelRegistry::global(cx)
181 .read(cx)
182 .visible_providers()
183 .iter()
184 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
185 .collect::<Vec<_>>();
186
187 cx.background_spawn(async move {
188 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
189 if let Err(err) = authenticate_task.await {
190 match err {
191 language_model::AuthenticateError::CredentialsNotFound => {
192 // Since we're authenticating these providers in the
193 // background for the purposes of populating the
194 // language selector, we don't care about providers
195 // where the credentials are not found.
196 }
197 language_model::AuthenticateError::ConnectionRefused => {
198 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
199 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
200 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
201 }
202 _ => {
203 // Some providers have noisy failure states that we
204 // don't want to spam the logs with every time the
205 // language model selector is initialized.
206 //
207 // Ideally these should have more clear failure modes
208 // that we know are safe to ignore here, like what we do
209 // with `CredentialsNotFound` above.
210 match provider_id.0.as_ref() {
211 "lmstudio" | "ollama" => {
212 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
213 //
214 // These fail noisily, so we don't log them.
215 }
216 "copilot_chat" => {
217 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
218 }
219 _ => {
220 log::error!(
221 "Failed to authenticate provider: {}: {err:#}",
222 provider_name.0
223 );
224 }
225 }
226 }
227 }
228 }
229 }
230 })
231 }
232}
233
/// The native agent: owns the open sessions and the state shared between
/// them (project context, templates, cached models).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store that is asked to reload after a thread has been saved (see `save_thread`).
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sender used to request a rebuild of `project_context`; the receiving
    /// side is consumed by `maintain_project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`.
    _maintain_project_context: Task<Result<()>>,
    /// Registry created in `new` from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project whose worktrees and rules files feed the project context.
    project: Entity<Project>,
    /// Optional prompt store supplying the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle passed to `new`.
    fs: Arc<dyn Fs>,
    /// Subscriptions created in `new` (project, model registry, context
    /// servers and, when present, the prompt store).
    _subscriptions: Vec<Subscription>,
}
252
253impl NativeAgent {
254 pub async fn new(
255 project: Entity<Project>,
256 thread_store: Entity<ThreadStore>,
257 templates: Arc<Templates>,
258 prompt_store: Option<Entity<PromptStore>>,
259 fs: Arc<dyn Fs>,
260 cx: &mut AsyncApp,
261 ) -> Result<Entity<NativeAgent>> {
262 log::debug!("Creating new NativeAgent");
263
264 let project_context = cx
265 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
266 .await;
267
268 Ok(cx.new(|cx| {
269 let context_server_store = project.read(cx).context_server_store();
270 let context_server_registry =
271 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
272
273 let mut subscriptions = vec![
274 cx.subscribe(&project, Self::handle_project_event),
275 cx.subscribe(
276 &LanguageModelRegistry::global(cx),
277 Self::handle_models_updated_event,
278 ),
279 cx.subscribe(
280 &context_server_store,
281 Self::handle_context_server_store_updated,
282 ),
283 cx.subscribe(
284 &context_server_registry,
285 Self::handle_context_server_registry_event,
286 ),
287 ];
288 if let Some(prompt_store) = prompt_store.as_ref() {
289 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
290 }
291
292 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
293 watch::channel(());
294 Self {
295 sessions: HashMap::default(),
296 thread_store,
297 project_context: cx.new(|_| project_context),
298 project_context_needs_refresh: project_context_needs_refresh_tx,
299 _maintain_project_context: cx.spawn(async move |this, cx| {
300 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
301 }),
302 context_server_registry,
303 templates,
304 models: LanguageModels::new(cx),
305 project,
306 prompt_store,
307 fs,
308 _subscriptions: subscriptions,
309 }
310 }))
311 }
312
313 fn new_session(
314 &mut self,
315 project: Entity<Project>,
316 cx: &mut Context<Self>,
317 ) -> Entity<AcpThread> {
318 // Create Thread
319 // Fetch default model from registry settings
320 let registry = LanguageModelRegistry::read_global(cx);
321 // Log available models for debugging
322 let available_count = registry.available_models(cx).count();
323 log::debug!("Total available models: {}", available_count);
324
325 let default_model = registry.default_model().and_then(|default_model| {
326 self.models
327 .model_from_id(&LanguageModels::model_id(&default_model.model))
328 });
329 let thread = cx.new(|cx| {
330 Thread::new(
331 project.clone(),
332 self.project_context.clone(),
333 self.context_server_registry.clone(),
334 self.templates.clone(),
335 default_model,
336 cx,
337 )
338 });
339
340 self.register_session(thread, cx)
341 }
342
343 fn register_session(
344 &mut self,
345 thread_handle: Entity<Thread>,
346 cx: &mut Context<Self>,
347 ) -> Entity<AcpThread> {
348 let connection = Rc::new(NativeAgentConnection(cx.entity()));
349
350 let thread = thread_handle.read(cx);
351 let session_id = thread.id().clone();
352 let parent_session_id = thread.parent_thread_id();
353 let title = thread.title();
354 let draft_prompt = thread.draft_prompt().map(Vec::from);
355 let project = thread.project.clone();
356 let action_log = thread.action_log.clone();
357 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
358 let acp_thread = cx.new(|cx| {
359 let mut acp_thread = acp_thread::AcpThread::new(
360 parent_session_id,
361 title,
362 connection,
363 project.clone(),
364 action_log.clone(),
365 session_id.clone(),
366 prompt_capabilities_rx,
367 cx,
368 );
369 acp_thread.set_draft_prompt(draft_prompt);
370 acp_thread
371 });
372
373 let registry = LanguageModelRegistry::read_global(cx);
374 let summarization_model = registry.thread_summary_model().map(|c| c.model);
375
376 let weak = cx.weak_entity();
377 let weak_thread = thread_handle.downgrade();
378 thread_handle.update(cx, |thread, cx| {
379 thread.set_summarization_model(summarization_model, cx);
380 thread.add_default_tools(
381 Rc::new(NativeThreadEnvironment {
382 acp_thread: acp_thread.downgrade(),
383 thread: weak_thread,
384 agent: weak,
385 }) as _,
386 cx,
387 )
388 });
389
390 let subscriptions = vec![
391 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
392 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
393 cx.observe(&thread_handle, move |this, thread, cx| {
394 this.save_thread(thread, cx)
395 }),
396 ];
397
398 self.sessions.insert(
399 session_id,
400 Session {
401 thread: thread_handle,
402 acp_thread: acp_thread.clone(),
403 _subscriptions: subscriptions,
404 pending_save: Task::ready(()),
405 },
406 );
407
408 self.update_available_commands(cx);
409
410 acp_thread
411 }
412
    /// Returns the cached language model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
416
417 async fn maintain_project_context(
418 this: WeakEntity<Self>,
419 mut needs_refresh: watch::Receiver<()>,
420 cx: &mut AsyncApp,
421 ) -> Result<()> {
422 while needs_refresh.changed().await.is_ok() {
423 let project_context = this
424 .update(cx, |this, cx| {
425 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
426 })?
427 .await;
428 this.update(cx, |this, cx| {
429 this.project_context = cx.new(|_| project_context);
430 })?;
431 }
432
433 Ok(())
434 }
435
436 fn build_project_context(
437 project: &Entity<Project>,
438 prompt_store: Option<&Entity<PromptStore>>,
439 cx: &mut App,
440 ) -> Task<ProjectContext> {
441 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
442 let worktree_tasks = worktrees
443 .into_iter()
444 .map(|worktree| {
445 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
446 })
447 .collect::<Vec<_>>();
448 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
449 prompt_store.read_with(cx, |prompt_store, cx| {
450 let prompts = prompt_store.default_prompt_metadata();
451 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
452 let contents = prompt_store.load(prompt_metadata.id, cx);
453 async move { (contents.await, prompt_metadata) }
454 });
455 cx.background_spawn(future::join_all(load_tasks))
456 })
457 } else {
458 Task::ready(vec![])
459 };
460
461 cx.spawn(async move |_cx| {
462 let (worktrees, default_user_rules) =
463 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
464
465 let worktrees = worktrees
466 .into_iter()
467 .map(|(worktree, _rules_error)| {
468 // TODO: show error message
469 // if let Some(rules_error) = rules_error {
470 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
471 // }
472 worktree
473 })
474 .collect::<Vec<_>>();
475
476 let default_user_rules = default_user_rules
477 .into_iter()
478 .flat_map(|(contents, prompt_metadata)| match contents {
479 Ok(contents) => Some(UserRulesContext {
480 uuid: prompt_metadata.id.as_user()?,
481 title: prompt_metadata.title.map(|title| title.to_string()),
482 contents,
483 }),
484 Err(_err) => {
485 // TODO: show error message
486 // this.update(cx, |_, cx| {
487 // cx.emit(RulesLoadingError {
488 // message: format!("{err:?}").into(),
489 // });
490 // })
491 // .ok();
492 None
493 }
494 })
495 .collect::<Vec<_>>();
496
497 ProjectContext::new(worktrees, default_user_rules)
498 })
499 }
500
501 fn load_worktree_info_for_system_prompt(
502 worktree: Entity<Worktree>,
503 project: Entity<Project>,
504 cx: &mut App,
505 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
506 let tree = worktree.read(cx);
507 let root_name = tree.root_name_str().into();
508 let abs_path = tree.abs_path();
509
510 let mut context = WorktreeContext {
511 root_name,
512 abs_path,
513 rules_file: None,
514 };
515
516 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
517 let Some(rules_task) = rules_task else {
518 return Task::ready((context, None));
519 };
520
521 cx.spawn(async move |_| {
522 let (rules_file, rules_file_error) = match rules_task.await {
523 Ok(rules_file) => (Some(rules_file), None),
524 Err(err) => (
525 None,
526 Some(RulesLoadingError {
527 message: format!("{err}").into(),
528 }),
529 ),
530 };
531 context.rules_file = rules_file;
532 (context, rules_file_error)
533 })
534 }
535
536 fn load_worktree_rules_file(
537 worktree: Entity<Worktree>,
538 project: Entity<Project>,
539 cx: &mut App,
540 ) -> Option<Task<Result<RulesFileContext>>> {
541 let worktree = worktree.read(cx);
542 let worktree_id = worktree.id();
543 let selected_rules_file = RULES_FILE_NAMES
544 .into_iter()
545 .filter_map(|name| {
546 worktree
547 .entry_for_path(RelPath::unix(name).unwrap())
548 .filter(|entry| entry.is_file())
549 .map(|entry| entry.path.clone())
550 })
551 .next();
552
553 // Note that Cline supports `.clinerules` being a directory, but that is not currently
554 // supported. This doesn't seem to occur often in GitHub repositories.
555 selected_rules_file.map(|path_in_worktree| {
556 let project_path = ProjectPath {
557 worktree_id,
558 path: path_in_worktree.clone(),
559 };
560 let buffer_task =
561 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
562 let rope_task = cx.spawn(async move |cx| {
563 let buffer = buffer_task.await?;
564 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
565 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
566 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
567 })?;
568 anyhow::Ok((project_entry_id, rope))
569 });
570 // Build a string from the rope on a background thread.
571 cx.background_spawn(async move {
572 let (project_entry_id, rope) = rope_task.await?;
573 anyhow::Ok(RulesFileContext {
574 path_in_worktree,
575 text: rope.to_string().trim().to_string(),
576 project_entry_id: project_entry_id.to_usize(),
577 })
578 })
579 })
580 }
581
582 fn handle_thread_title_updated(
583 &mut self,
584 thread: Entity<Thread>,
585 _: &TitleUpdated,
586 cx: &mut Context<Self>,
587 ) {
588 let session_id = thread.read(cx).id();
589 let Some(session) = self.sessions.get(session_id) else {
590 return;
591 };
592 let thread = thread.downgrade();
593 let acp_thread = session.acp_thread.downgrade();
594 cx.spawn(async move |_, cx| {
595 let title = thread.read_with(cx, |thread, _| thread.title())?;
596 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
597 task.await
598 })
599 .detach_and_log_err(cx);
600 }
601
602 fn handle_thread_token_usage_updated(
603 &mut self,
604 thread: Entity<Thread>,
605 usage: &TokenUsageUpdated,
606 cx: &mut Context<Self>,
607 ) {
608 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
609 return;
610 };
611 session.acp_thread.update(cx, |acp_thread, cx| {
612 acp_thread.update_token_usage(usage.0.clone(), cx);
613 });
614 }
615
616 fn handle_project_event(
617 &mut self,
618 _project: Entity<Project>,
619 event: &project::Event,
620 _cx: &mut Context<Self>,
621 ) {
622 match event {
623 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
624 self.project_context_needs_refresh.send(()).ok();
625 }
626 project::Event::WorktreeUpdatedEntries(_, items) => {
627 if items.iter().any(|(path, _, _)| {
628 RULES_FILE_NAMES
629 .iter()
630 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
631 }) {
632 self.project_context_needs_refresh.send(()).ok();
633 }
634 }
635 _ => {}
636 }
637 }
638
639 fn handle_prompts_updated_event(
640 &mut self,
641 _prompt_store: Entity<PromptStore>,
642 _event: &prompt_store::PromptsUpdatedEvent,
643 _cx: &mut Context<Self>,
644 ) {
645 self.project_context_needs_refresh.send(()).ok();
646 }
647
648 fn handle_models_updated_event(
649 &mut self,
650 _registry: Entity<LanguageModelRegistry>,
651 _event: &language_model::Event,
652 cx: &mut Context<Self>,
653 ) {
654 self.models.refresh_list(cx);
655
656 let registry = LanguageModelRegistry::read_global(cx);
657 let default_model = registry.default_model().map(|m| m.model);
658 let summarization_model = registry.thread_summary_model().map(|m| m.model);
659
660 for session in self.sessions.values_mut() {
661 session.thread.update(cx, |thread, cx| {
662 if thread.model().is_none()
663 && let Some(model) = default_model.clone()
664 {
665 thread.set_model(model, cx);
666 cx.notify();
667 }
668 thread.set_summarization_model(summarization_model.clone(), cx);
669 });
670 }
671 }
672
    /// Re-sends the available commands to every session whenever a context
    /// server's status changes.
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::ServerStatusChangedEvent,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
681
682 fn handle_context_server_registry_event(
683 &mut self,
684 _registry: Entity<ContextServerRegistry>,
685 event: &ContextServerRegistryEvent,
686 cx: &mut Context<Self>,
687 ) {
688 match event {
689 ContextServerRegistryEvent::ToolsChanged => {}
690 ContextServerRegistryEvent::PromptsChanged => {
691 self.update_available_commands(cx);
692 }
693 }
694 }
695
696 fn update_available_commands(&self, cx: &mut Context<Self>) {
697 let available_commands = self.build_available_commands(cx);
698 for session in self.sessions.values() {
699 session.acp_thread.update(cx, |thread, cx| {
700 thread
701 .handle_session_update(
702 acp::SessionUpdate::AvailableCommandsUpdate(
703 acp::AvailableCommandsUpdate::new(available_commands.clone()),
704 ),
705 cx,
706 )
707 .log_err();
708 });
709 }
710 }
711
712 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
713 let registry = self.context_server_registry.read(cx);
714
715 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
716 for context_server_prompt in registry.prompts() {
717 *prompt_name_counts
718 .entry(context_server_prompt.prompt.name.as_str())
719 .or_insert(0) += 1;
720 }
721
722 registry
723 .prompts()
724 .flat_map(|context_server_prompt| {
725 let prompt = &context_server_prompt.prompt;
726
727 let should_prefix = prompt_name_counts
728 .get(prompt.name.as_str())
729 .copied()
730 .unwrap_or(0)
731 > 1;
732
733 let name = if should_prefix {
734 format!("{}.{}", context_server_prompt.server_id, prompt.name)
735 } else {
736 prompt.name.clone()
737 };
738
739 let mut command = acp::AvailableCommand::new(
740 name,
741 prompt.description.clone().unwrap_or_default(),
742 );
743
744 match prompt.arguments.as_deref() {
745 Some([arg]) => {
746 let hint = format!("<{}>", arg.name);
747
748 command = command.input(acp::AvailableCommandInput::Unstructured(
749 acp::UnstructuredCommandInput::new(hint),
750 ));
751 }
752 Some([]) | None => {}
753 Some(_) => {
754 // skip >1 argument commands since we don't support them yet
755 return None;
756 }
757 }
758
759 Some(command)
760 })
761 .collect()
762 }
763
764 pub fn load_thread(
765 &mut self,
766 id: acp::SessionId,
767 cx: &mut Context<Self>,
768 ) -> Task<Result<Entity<Thread>>> {
769 let database_future = ThreadsDatabase::connect(cx);
770 cx.spawn(async move |this, cx| {
771 let database = database_future.await.map_err(|err| anyhow!(err))?;
772 let db_thread = database
773 .load_thread(id.clone())
774 .await?
775 .with_context(|| format!("no thread found with ID: {id:?}"))?;
776
777 this.update(cx, |this, cx| {
778 let summarization_model = LanguageModelRegistry::read_global(cx)
779 .thread_summary_model()
780 .map(|c| c.model);
781
782 cx.new(|cx| {
783 let mut thread = Thread::from_db(
784 id.clone(),
785 db_thread,
786 this.project.clone(),
787 this.project_context.clone(),
788 this.context_server_registry.clone(),
789 this.templates.clone(),
790 cx,
791 );
792 thread.set_summarization_model(summarization_model, cx);
793 thread
794 })
795 })
796 })
797 }
798
799 pub fn open_thread(
800 &mut self,
801 id: acp::SessionId,
802 cx: &mut Context<Self>,
803 ) -> Task<Result<Entity<AcpThread>>> {
804 if let Some(session) = self.sessions.get(&id) {
805 return Task::ready(Ok(session.acp_thread.clone()));
806 }
807
808 let task = self.load_thread(id, cx);
809 cx.spawn(async move |this, cx| {
810 let thread = task.await?;
811 let acp_thread =
812 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
813 let events = thread.update(cx, |thread, cx| thread.replay(cx));
814 cx.update(|cx| {
815 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
816 })
817 .await?;
818 Ok(acp_thread)
819 })
820 }
821
822 pub fn thread_summary(
823 &mut self,
824 id: acp::SessionId,
825 cx: &mut Context<Self>,
826 ) -> Task<Result<SharedString>> {
827 let thread = self.open_thread(id.clone(), cx);
828 cx.spawn(async move |this, cx| {
829 let acp_thread = thread.await?;
830 let result = this
831 .update(cx, |this, cx| {
832 this.sessions
833 .get(&id)
834 .unwrap()
835 .thread
836 .update(cx, |thread, cx| thread.summary(cx))
837 })?
838 .await
839 .context("Failed to generate summary")?;
840 drop(acp_thread);
841 Ok(result)
842 })
843 }
844
845 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
846 if thread.read(cx).is_empty() {
847 return;
848 }
849
850 let id = thread.read(cx).id().clone();
851 let Some(session) = self.sessions.get_mut(&id) else {
852 return;
853 };
854
855 let folder_paths = PathList::new(
856 &self
857 .project
858 .read(cx)
859 .visible_worktrees(cx)
860 .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
861 .collect::<Vec<_>>(),
862 );
863
864 let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
865 let database_future = ThreadsDatabase::connect(cx);
866 let db_thread = thread.update(cx, |thread, cx| {
867 thread.set_draft_prompt(draft_prompt);
868 thread.to_db(cx)
869 });
870 let thread_store = self.thread_store.clone();
871 session.pending_save = cx.spawn(async move |_, cx| {
872 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
873 return;
874 };
875 let db_thread = db_thread.await;
876 database
877 .save_thread(id, db_thread, folder_paths)
878 .await
879 .log_err();
880 thread_store.update(cx, |store, cx| store.reload(cx));
881 });
882 }
883
884 fn send_mcp_prompt(
885 &self,
886 message_id: UserMessageId,
887 session_id: agent_client_protocol::SessionId,
888 prompt_name: String,
889 server_id: ContextServerId,
890 arguments: HashMap<String, String>,
891 original_content: Vec<acp::ContentBlock>,
892 cx: &mut Context<Self>,
893 ) -> Task<Result<acp::PromptResponse>> {
894 let server_store = self.context_server_registry.read(cx).server_store().clone();
895 let path_style = self.project.read(cx).path_style(cx);
896
897 cx.spawn(async move |this, cx| {
898 let prompt =
899 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
900
901 let (acp_thread, thread) = this.update(cx, |this, _cx| {
902 let session = this
903 .sessions
904 .get(&session_id)
905 .context("Failed to get session")?;
906 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
907 })??;
908
909 let mut last_is_user = true;
910
911 thread.update(cx, |thread, cx| {
912 thread.push_acp_user_block(
913 message_id,
914 original_content.into_iter().skip(1),
915 path_style,
916 cx,
917 );
918 });
919
920 for message in prompt.messages {
921 let context_server::types::PromptMessage { role, content } = message;
922 let block = mcp_message_content_to_acp_content_block(content);
923
924 match role {
925 context_server::types::Role::User => {
926 let id = acp_thread::UserMessageId::new();
927
928 acp_thread.update(cx, |acp_thread, cx| {
929 acp_thread.push_user_content_block_with_indent(
930 Some(id.clone()),
931 block.clone(),
932 true,
933 cx,
934 );
935 });
936
937 thread.update(cx, |thread, cx| {
938 thread.push_acp_user_block(id, [block], path_style, cx);
939 });
940 }
941 context_server::types::Role::Assistant => {
942 acp_thread.update(cx, |acp_thread, cx| {
943 acp_thread.push_assistant_content_block_with_indent(
944 block.clone(),
945 false,
946 true,
947 cx,
948 );
949 });
950
951 thread.update(cx, |thread, cx| {
952 thread.push_acp_agent_block(block, cx);
953 });
954 }
955 }
956
957 last_is_user = role == context_server::types::Role::User;
958 }
959
960 let response_stream = thread.update(cx, |thread, cx| {
961 if last_is_user {
962 thread.send_existing(cx)
963 } else {
964 // Resume if MCP prompt did not end with a user message
965 thread.resume(cx)
966 }
967 })?;
968
969 cx.update(|cx| {
970 NativeAgentConnection::handle_thread_events(
971 response_stream,
972 acp_thread.downgrade(),
973 cx,
974 )
975 })
976 .await
977 })
978 }
979}
980
/// Wrapper struct that implements the AgentConnection trait
///
/// Field `0` is the agent entity that every method on the connection goes through.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
984
985impl NativeAgentConnection {
986 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
987 self.0
988 .read(cx)
989 .sessions
990 .get(session_id)
991 .map(|session| session.thread.clone())
992 }
993
994 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
995 self.0.update(cx, |this, cx| this.load_thread(id, cx))
996 }
997
998 fn run_turn(
999 &self,
1000 session_id: acp::SessionId,
1001 cx: &mut App,
1002 f: impl 'static
1003 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1004 ) -> Task<Result<acp::PromptResponse>> {
1005 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1006 agent
1007 .sessions
1008 .get_mut(&session_id)
1009 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1010 }) else {
1011 return Task::ready(Err(anyhow!("Session not found")));
1012 };
1013 log::debug!("Found session for: {}", session_id);
1014
1015 let response_stream = match f(thread, cx) {
1016 Ok(stream) => stream,
1017 Err(err) => return Task::ready(Err(err)),
1018 };
1019 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1020 }
1021
1022 fn handle_thread_events(
1023 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1024 acp_thread: WeakEntity<AcpThread>,
1025 cx: &App,
1026 ) -> Task<Result<acp::PromptResponse>> {
1027 cx.spawn(async move |cx| {
1028 // Handle response stream and forward to session.acp_thread
1029 while let Some(result) = events.next().await {
1030 match result {
1031 Ok(event) => {
1032 log::trace!("Received completion event: {:?}", event);
1033
1034 match event {
1035 ThreadEvent::UserMessage(message) => {
1036 acp_thread.update(cx, |thread, cx| {
1037 for content in message.content {
1038 thread.push_user_content_block(
1039 Some(message.id.clone()),
1040 content.into(),
1041 cx,
1042 );
1043 }
1044 })?;
1045 }
1046 ThreadEvent::AgentText(text) => {
1047 acp_thread.update(cx, |thread, cx| {
1048 thread.push_assistant_content_block(text.into(), false, cx)
1049 })?;
1050 }
1051 ThreadEvent::AgentThinking(text) => {
1052 acp_thread.update(cx, |thread, cx| {
1053 thread.push_assistant_content_block(text.into(), true, cx)
1054 })?;
1055 }
1056 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1057 tool_call,
1058 options,
1059 response,
1060 context: _,
1061 }) => {
1062 let outcome_task = acp_thread.update(cx, |thread, cx| {
1063 thread.request_tool_call_authorization(tool_call, options, cx)
1064 })??;
1065 cx.background_spawn(async move {
1066 if let acp::RequestPermissionOutcome::Selected(
1067 acp::SelectedPermissionOutcome { option_id, .. },
1068 ) = outcome_task.await
1069 {
1070 response
1071 .send(option_id)
1072 .map(|_| anyhow!("authorization receiver was dropped"))
1073 .log_err();
1074 }
1075 })
1076 .detach();
1077 }
1078 ThreadEvent::ToolCall(tool_call) => {
1079 acp_thread.update(cx, |thread, cx| {
1080 thread.upsert_tool_call(tool_call, cx)
1081 })??;
1082 }
1083 ThreadEvent::ToolCallUpdate(update) => {
1084 acp_thread.update(cx, |thread, cx| {
1085 thread.update_tool_call(update, cx)
1086 })??;
1087 }
1088 ThreadEvent::SubagentSpawned(session_id) => {
1089 acp_thread.update(cx, |thread, cx| {
1090 thread.subagent_spawned(session_id, cx);
1091 })?;
1092 }
1093 ThreadEvent::Retry(status) => {
1094 acp_thread.update(cx, |thread, cx| {
1095 thread.update_retry_status(status, cx)
1096 })?;
1097 }
1098 ThreadEvent::Stop(stop_reason) => {
1099 log::debug!("Assistant message complete: {:?}", stop_reason);
1100 return Ok(acp::PromptResponse::new(stop_reason));
1101 }
1102 }
1103 }
1104 Err(e) => {
1105 log::error!("Error in model response stream: {:?}", e);
1106 return Err(e);
1107 }
1108 }
1109 }
1110
1111 log::debug!("Response stream completed");
1112 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1113 })
1114 }
1115}
1116
1117struct Command<'a> {
1118 prompt_name: &'a str,
1119 arg_value: &'a str,
1120 explicit_server_id: Option<&'a str>,
1121}
1122
1123impl<'a> Command<'a> {
1124 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1125 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1126 return None;
1127 };
1128 let text = text_content.text.trim();
1129 let command = text.strip_prefix('/')?;
1130 let (command, arg_value) = command
1131 .split_once(char::is_whitespace)
1132 .unwrap_or((command, ""));
1133
1134 if let Some((server_id, prompt_name)) = command.split_once('.') {
1135 Some(Self {
1136 prompt_name,
1137 arg_value,
1138 explicit_server_id: Some(server_id),
1139 })
1140 } else {
1141 Some(Self {
1142 prompt_name: command,
1143 arg_value,
1144 explicit_server_id: None,
1145 })
1146 }
1147 }
1148}
1149
1150struct NativeAgentModelSelector {
1151 session_id: acp::SessionId,
1152 connection: NativeAgentConnection,
1153}
1154
1155impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1156 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1157 log::debug!("NativeAgentConnection::list_models called");
1158 let list = self.connection.0.read(cx).models.model_list.clone();
1159 Task::ready(if list.is_empty() {
1160 Err(anyhow::anyhow!("No models available"))
1161 } else {
1162 Ok(list)
1163 })
1164 }
1165
1166 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1167 log::debug!(
1168 "Setting model for session {}: {}",
1169 self.session_id,
1170 model_id
1171 );
1172 let Some(thread) = self
1173 .connection
1174 .0
1175 .read(cx)
1176 .sessions
1177 .get(&self.session_id)
1178 .map(|session| session.thread.clone())
1179 else {
1180 return Task::ready(Err(anyhow!("Session not found")));
1181 };
1182
1183 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1184 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1185 };
1186
1187 // We want to reset the effort level when switching models, as the currently-selected effort level may
1188 // not be compatible.
1189 let effort = model
1190 .default_effort_level()
1191 .map(|effort_level| effort_level.value.to_string());
1192
1193 thread.update(cx, |thread, cx| {
1194 thread.set_model(model.clone(), cx);
1195 thread.set_thinking_effort(effort.clone(), cx);
1196 thread.set_thinking_enabled(model.supports_thinking(), cx);
1197 });
1198
1199 update_settings_file(
1200 self.connection.0.read(cx).fs.clone(),
1201 cx,
1202 move |settings, cx| {
1203 let provider = model.provider_id().0.to_string();
1204 let model = model.id().0.to_string();
1205 let enable_thinking = thread.read(cx).thinking_enabled();
1206 settings
1207 .agent
1208 .get_or_insert_default()
1209 .set_model(LanguageModelSelection {
1210 provider: provider.into(),
1211 model,
1212 enable_thinking,
1213 effort,
1214 });
1215 },
1216 );
1217
1218 Task::ready(Ok(()))
1219 }
1220
1221 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1222 let Some(thread) = self
1223 .connection
1224 .0
1225 .read(cx)
1226 .sessions
1227 .get(&self.session_id)
1228 .map(|session| session.thread.clone())
1229 else {
1230 return Task::ready(Err(anyhow!("Session not found")));
1231 };
1232 let Some(model) = thread.read(cx).model() else {
1233 return Task::ready(Err(anyhow!("Model not found")));
1234 };
1235 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1236 else {
1237 return Task::ready(Err(anyhow!("Provider not found")));
1238 };
1239 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1240 model, &provider,
1241 )))
1242 }
1243
1244 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1245 Some(self.connection.0.read(cx).models.watch())
1246 }
1247
1248 fn should_render_footer(&self) -> bool {
1249 true
1250 }
1251}
1252
1253impl acp_thread::AgentConnection for NativeAgentConnection {
1254 fn telemetry_id(&self) -> SharedString {
1255 "zed".into()
1256 }
1257
1258 fn new_session(
1259 self: Rc<Self>,
1260 project: Entity<Project>,
1261 cwd: &Path,
1262 cx: &mut App,
1263 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1264 log::debug!("Creating new thread for project at: {cwd:?}");
1265 Task::ready(Ok(self
1266 .0
1267 .update(cx, |agent, cx| agent.new_session(project, cx))))
1268 }
1269
1270 fn supports_load_session(&self) -> bool {
1271 true
1272 }
1273
1274 fn load_session(
1275 self: Rc<Self>,
1276 session: AgentSessionInfo,
1277 _project: Entity<Project>,
1278 _cwd: &Path,
1279 cx: &mut App,
1280 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1281 self.0
1282 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1283 }
1284
1285 fn supports_close_session(&self) -> bool {
1286 true
1287 }
1288
1289 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1290 self.0.update(cx, |agent, _cx| {
1291 agent.sessions.remove(session_id);
1292 });
1293 Task::ready(Ok(()))
1294 }
1295
1296 fn auth_methods(&self) -> &[acp::AuthMethod] {
1297 &[] // No auth for in-process
1298 }
1299
1300 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1301 Task::ready(Ok(()))
1302 }
1303
1304 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1305 Some(Rc::new(NativeAgentModelSelector {
1306 session_id: session_id.clone(),
1307 connection: self.clone(),
1308 }) as Rc<dyn AgentModelSelector>)
1309 }
1310
1311 fn prompt(
1312 &self,
1313 id: Option<acp_thread::UserMessageId>,
1314 params: acp::PromptRequest,
1315 cx: &mut App,
1316 ) -> Task<Result<acp::PromptResponse>> {
1317 let id = id.expect("UserMessageId is required");
1318 let session_id = params.session_id.clone();
1319 log::info!("Received prompt request for session: {}", session_id);
1320 log::debug!("Prompt blocks count: {}", params.prompt.len());
1321
1322 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1323 let registry = self.0.read(cx).context_server_registry.read(cx);
1324
1325 let explicit_server_id = parsed_command
1326 .explicit_server_id
1327 .map(|server_id| ContextServerId(server_id.into()));
1328
1329 if let Some(prompt) =
1330 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1331 {
1332 let arguments = if !parsed_command.arg_value.is_empty()
1333 && let Some(arg_name) = prompt
1334 .prompt
1335 .arguments
1336 .as_ref()
1337 .and_then(|args| args.first())
1338 .map(|arg| arg.name.clone())
1339 {
1340 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1341 } else {
1342 Default::default()
1343 };
1344
1345 let prompt_name = prompt.prompt.name.clone();
1346 let server_id = prompt.server_id.clone();
1347
1348 return self.0.update(cx, |agent, cx| {
1349 agent.send_mcp_prompt(
1350 id,
1351 session_id.clone(),
1352 prompt_name,
1353 server_id,
1354 arguments,
1355 params.prompt,
1356 cx,
1357 )
1358 });
1359 };
1360 };
1361
1362 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1363
1364 self.run_turn(session_id, cx, move |thread, cx| {
1365 let content: Vec<UserMessageContent> = params
1366 .prompt
1367 .into_iter()
1368 .map(|block| UserMessageContent::from_content_block(block, path_style))
1369 .collect::<Vec<_>>();
1370 log::debug!("Converted prompt to message: {} chars", content.len());
1371 log::debug!("Message id: {:?}", id);
1372 log::debug!("Message content: {:?}", content);
1373
1374 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1375 })
1376 }
1377
1378 fn retry(
1379 &self,
1380 session_id: &acp::SessionId,
1381 _cx: &App,
1382 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1383 Some(Rc::new(NativeAgentSessionRetry {
1384 connection: self.clone(),
1385 session_id: session_id.clone(),
1386 }) as _)
1387 }
1388
1389 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1390 log::info!("Cancelling on session: {}", session_id);
1391 self.0.update(cx, |agent, cx| {
1392 if let Some(session) = agent.sessions.get(session_id) {
1393 session
1394 .thread
1395 .update(cx, |thread, cx| thread.cancel(cx))
1396 .detach();
1397 }
1398 });
1399 }
1400
1401 fn truncate(
1402 &self,
1403 session_id: &agent_client_protocol::SessionId,
1404 cx: &App,
1405 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1406 self.0.read_with(cx, |agent, _cx| {
1407 agent.sessions.get(session_id).map(|session| {
1408 Rc::new(NativeAgentSessionTruncate {
1409 thread: session.thread.clone(),
1410 acp_thread: session.acp_thread.downgrade(),
1411 }) as _
1412 })
1413 })
1414 }
1415
1416 fn set_title(
1417 &self,
1418 session_id: &acp::SessionId,
1419 cx: &App,
1420 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1421 self.0.read_with(cx, |agent, _cx| {
1422 agent
1423 .sessions
1424 .get(session_id)
1425 .filter(|s| !s.thread.read(cx).is_subagent())
1426 .map(|session| {
1427 Rc::new(NativeAgentSessionSetTitle {
1428 thread: session.thread.clone(),
1429 }) as _
1430 })
1431 })
1432 }
1433
1434 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1435 let thread_store = self.0.read(cx).thread_store.clone();
1436 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1437 }
1438
1439 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1440 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1441 }
1442
1443 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1444 self
1445 }
1446}
1447
1448impl acp_thread::AgentTelemetry for NativeAgentConnection {
1449 fn thread_data(
1450 &self,
1451 session_id: &acp::SessionId,
1452 cx: &mut App,
1453 ) -> Task<Result<serde_json::Value>> {
1454 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1455 return Task::ready(Err(anyhow!("Session not found")));
1456 };
1457
1458 let task = session.thread.read(cx).to_db(cx);
1459 cx.background_spawn(async move {
1460 serde_json::to_value(task.await).context("Failed to serialize thread")
1461 })
1462 }
1463}
1464
1465pub struct NativeAgentSessionList {
1466 thread_store: Entity<ThreadStore>,
1467 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1468 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1469 _subscription: Subscription,
1470}
1471
1472impl NativeAgentSessionList {
1473 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1474 let (tx, rx) = smol::channel::unbounded();
1475 let this_tx = tx.clone();
1476 let subscription = cx.observe(&thread_store, move |_, _| {
1477 this_tx
1478 .try_send(acp_thread::SessionListUpdate::Refresh)
1479 .ok();
1480 });
1481 Self {
1482 thread_store,
1483 updates_tx: tx,
1484 updates_rx: rx,
1485 _subscription: subscription,
1486 }
1487 }
1488
1489 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1490 AgentSessionInfo {
1491 session_id: entry.id,
1492 cwd: None,
1493 title: Some(entry.title),
1494 updated_at: Some(entry.updated_at),
1495 meta: None,
1496 }
1497 }
1498
1499 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1500 &self.thread_store
1501 }
1502}
1503
1504impl AgentSessionList for NativeAgentSessionList {
1505 fn list_sessions(
1506 &self,
1507 _request: AgentSessionListRequest,
1508 cx: &mut App,
1509 ) -> Task<Result<AgentSessionListResponse>> {
1510 let sessions = self
1511 .thread_store
1512 .read(cx)
1513 .entries()
1514 .map(Self::to_session_info)
1515 .collect();
1516 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1517 }
1518
1519 fn supports_delete(&self) -> bool {
1520 true
1521 }
1522
1523 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1524 self.thread_store
1525 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1526 }
1527
1528 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1529 self.thread_store
1530 .update(cx, |store, cx| store.delete_threads(cx))
1531 }
1532
1533 fn watch(
1534 &self,
1535 _cx: &mut App,
1536 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1537 Some(self.updates_rx.clone())
1538 }
1539
1540 fn notify_refresh(&self) {
1541 self.updates_tx
1542 .try_send(acp_thread::SessionListUpdate::Refresh)
1543 .ok();
1544 }
1545
1546 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1547 self
1548 }
1549}
1550
1551struct NativeAgentSessionTruncate {
1552 thread: Entity<Thread>,
1553 acp_thread: WeakEntity<AcpThread>,
1554}
1555
1556impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1557 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1558 match self.thread.update(cx, |thread, cx| {
1559 thread.truncate(message_id.clone(), cx)?;
1560 Ok(thread.latest_token_usage())
1561 }) {
1562 Ok(usage) => {
1563 self.acp_thread
1564 .update(cx, |thread, cx| {
1565 thread.update_token_usage(usage, cx);
1566 })
1567 .ok();
1568 Task::ready(Ok(()))
1569 }
1570 Err(error) => Task::ready(Err(error)),
1571 }
1572 }
1573}
1574
1575struct NativeAgentSessionRetry {
1576 connection: NativeAgentConnection,
1577 session_id: acp::SessionId,
1578}
1579
1580impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1581 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1582 self.connection
1583 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1584 thread.update(cx, |thread, cx| thread.resume(cx))
1585 })
1586 }
1587}
1588
1589struct NativeAgentSessionSetTitle {
1590 thread: Entity<Thread>,
1591}
1592
1593impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1594 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1595 self.thread
1596 .update(cx, |thread, cx| thread.set_title(title, cx));
1597 Task::ready(Ok(()))
1598 }
1599}
1600
1601pub struct NativeThreadEnvironment {
1602 agent: WeakEntity<NativeAgent>,
1603 thread: WeakEntity<Thread>,
1604 acp_thread: WeakEntity<AcpThread>,
1605}
1606
1607impl NativeThreadEnvironment {
1608 pub(crate) fn create_subagent_thread(
1609 &self,
1610 label: String,
1611 cx: &mut App,
1612 ) -> Result<Rc<dyn SubagentHandle>> {
1613 let Some(parent_thread_entity) = self.thread.upgrade() else {
1614 anyhow::bail!("Parent thread no longer exists".to_string());
1615 };
1616 let parent_thread = parent_thread_entity.read(cx);
1617 let current_depth = parent_thread.depth();
1618
1619 if current_depth >= MAX_SUBAGENT_DEPTH {
1620 return Err(anyhow!(
1621 "Maximum subagent depth ({}) reached",
1622 MAX_SUBAGENT_DEPTH
1623 ));
1624 }
1625
1626 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1627 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1628 thread.set_title(label.into(), cx);
1629 thread
1630 });
1631
1632 let session_id = subagent_thread.read(cx).id().clone();
1633
1634 let acp_thread = self.agent.update(cx, |agent, cx| {
1635 agent.register_session(subagent_thread.clone(), cx)
1636 })?;
1637
1638 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1639 }
1640
1641 pub(crate) fn resume_subagent_thread(
1642 &self,
1643 session_id: acp::SessionId,
1644 cx: &mut App,
1645 ) -> Result<Rc<dyn SubagentHandle>> {
1646 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1647 let session = agent
1648 .sessions
1649 .get(&session_id)
1650 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1651 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1652 })??;
1653
1654 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1655 }
1656
1657 fn prompt_subagent(
1658 &self,
1659 session_id: acp::SessionId,
1660 subagent_thread: Entity<Thread>,
1661 acp_thread: Entity<acp_thread::AcpThread>,
1662 ) -> Result<Rc<dyn SubagentHandle>> {
1663 let Some(parent_thread_entity) = self.thread.upgrade() else {
1664 anyhow::bail!("Parent thread no longer exists".to_string());
1665 };
1666 Ok(Rc::new(NativeSubagentHandle::new(
1667 session_id,
1668 subagent_thread,
1669 acp_thread,
1670 parent_thread_entity,
1671 )) as _)
1672 }
1673}
1674
1675impl ThreadEnvironment for NativeThreadEnvironment {
1676 fn create_terminal(
1677 &self,
1678 command: String,
1679 cwd: Option<PathBuf>,
1680 output_byte_limit: Option<u64>,
1681 cx: &mut AsyncApp,
1682 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1683 let task = self.acp_thread.update(cx, |thread, cx| {
1684 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1685 });
1686
1687 let acp_thread = self.acp_thread.clone();
1688 cx.spawn(async move |cx| {
1689 let terminal = task?.await?;
1690
1691 let (drop_tx, drop_rx) = oneshot::channel();
1692 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1693
1694 cx.spawn(async move |cx| {
1695 drop_rx.await.ok();
1696 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1697 })
1698 .detach();
1699
1700 let handle = AcpTerminalHandle {
1701 terminal,
1702 _drop_tx: Some(drop_tx),
1703 };
1704
1705 Ok(Rc::new(handle) as _)
1706 })
1707 }
1708
1709 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1710 self.create_subagent_thread(label, cx)
1711 }
1712
1713 fn resume_subagent(
1714 &self,
1715 session_id: acp::SessionId,
1716 cx: &mut App,
1717 ) -> Result<Rc<dyn SubagentHandle>> {
1718 self.resume_subagent_thread(session_id, cx)
1719 }
1720}
1721
/// Outcome of one prompt sent to a subagent, as produced inside `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The prompt returned a response whose stop reason is not mapped to another variant.
    Completed,
    /// The prompt returned with `StopReason::Cancelled`.
    Cancelled,
    /// The token-limit signal fired before the prompt returned.
    ContextWindowWarning,
    /// The prompt failed, returned no response, or stopped for a reason reported as an
    /// error; the string is the message to surface.
    Error(String),
}
1729
1730pub struct NativeSubagentHandle {
1731 session_id: acp::SessionId,
1732 parent_thread: WeakEntity<Thread>,
1733 subagent_thread: Entity<Thread>,
1734 acp_thread: Entity<acp_thread::AcpThread>,
1735}
1736
1737impl NativeSubagentHandle {
1738 fn new(
1739 session_id: acp::SessionId,
1740 subagent_thread: Entity<Thread>,
1741 acp_thread: Entity<acp_thread::AcpThread>,
1742 parent_thread_entity: Entity<Thread>,
1743 ) -> Self {
1744 NativeSubagentHandle {
1745 session_id,
1746 subagent_thread,
1747 parent_thread: parent_thread_entity.downgrade(),
1748 acp_thread,
1749 }
1750 }
1751}
1752
impl SubagentHandle for NativeSubagentHandle {
    /// Returns a clone of the subagent's session id.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Number of entries currently in the subagent's ACP thread.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves to the text of its last message.
    ///
    /// The subagent is registered on the parent thread as running before the prompt is sent
    /// and unregistered once a result is available. Resolves to an error when the prompt is
    /// cancelled, fails, yields no text, or when token usage moves from `Normal` to
    /// `Warning` during the prompt (in which case the subagent thread is cancelled).
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is held until the end of this task so token usage updates are
            // still observed while `task` is awaited below.
            let (task, _subscription) = cx.update(|cx| {
                // Token usage ratio before this prompt; compared against later updates.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                // Best effort: a dropped parent thread is ignored here.
                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // The sender is wrapped in an `Option` so the signal is sent at most once.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            // Signal only on a Normal -> Warning transition relative to the
                            // ratio captured before the prompt.
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Whichever completes first decides the result: the prompt response or the
                // token-limit signal.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Join the text parts of the last message, if it is an agent message; a
                // missing message or empty text is reported as an error.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent before reporting the warning to the caller.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // Runs for every outcome, after the result has been computed.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
1873
1874pub struct AcpTerminalHandle {
1875 terminal: Entity<acp_thread::Terminal>,
1876 _drop_tx: Option<oneshot::Sender<()>>,
1877}
1878
1879impl TerminalHandle for AcpTerminalHandle {
1880 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1881 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1882 }
1883
1884 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1885 Ok(self
1886 .terminal
1887 .read_with(cx, |term, _cx| term.wait_for_exit()))
1888 }
1889
1890 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1891 Ok(self
1892 .terminal
1893 .read_with(cx, |term, cx| term.current_output(cx)))
1894 }
1895
1896 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1897 cx.update(|cx| {
1898 self.terminal.update(cx, |terminal, cx| {
1899 terminal.kill(cx);
1900 });
1901 });
1902 Ok(())
1903 }
1904
1905 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1906 Ok(self
1907 .terminal
1908 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1909 }
1910}
1911
1912#[cfg(test)]
1913mod internal_tests {
1914 use super::*;
1915 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1916 use fs::FakeFs;
1917 use gpui::TestAppContext;
1918 use indoc::formatdoc;
1919 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1920 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1921 use serde_json::json;
1922 use settings::SettingsStore;
1923 use util::{path, rel_path::rel_path};
1924
1925 #[gpui::test]
1926 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1927 init_test(cx);
1928 let fs = FakeFs::new(cx.executor());
1929 fs.insert_tree(
1930 "/",
1931 json!({
1932 "a": {}
1933 }),
1934 )
1935 .await;
1936 let project = Project::test(fs.clone(), [], cx).await;
1937 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1938 let agent = NativeAgent::new(
1939 project.clone(),
1940 thread_store,
1941 Templates::new(),
1942 None,
1943 fs.clone(),
1944 &mut cx.to_async(),
1945 )
1946 .await
1947 .unwrap();
1948 agent.read_with(cx, |agent, cx| {
1949 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1950 });
1951
1952 let worktree = project
1953 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1954 .await
1955 .unwrap();
1956 cx.run_until_parked();
1957 agent.read_with(cx, |agent, cx| {
1958 assert_eq!(
1959 agent.project_context.read(cx).worktrees,
1960 vec![WorktreeContext {
1961 root_name: "a".into(),
1962 abs_path: Path::new("/a").into(),
1963 rules_file: None
1964 }]
1965 )
1966 });
1967
1968 // Creating `/a/.rules` updates the project context.
1969 fs.insert_file("/a/.rules", Vec::new()).await;
1970 cx.run_until_parked();
1971 agent.read_with(cx, |agent, cx| {
1972 let rules_entry = worktree
1973 .read(cx)
1974 .entry_for_path(rel_path(".rules"))
1975 .unwrap();
1976 assert_eq!(
1977 agent.project_context.read(cx).worktrees,
1978 vec![WorktreeContext {
1979 root_name: "a".into(),
1980 abs_path: Path::new("/a").into(),
1981 rules_file: Some(RulesFileContext {
1982 path_in_worktree: rel_path(".rules").into(),
1983 text: "".into(),
1984 project_entry_id: rules_entry.id.to_usize()
1985 })
1986 }]
1987 )
1988 });
1989 }
1990
1991 #[gpui::test]
1992 async fn test_listing_models(cx: &mut TestAppContext) {
1993 init_test(cx);
1994 let fs = FakeFs::new(cx.executor());
1995 fs.insert_tree("/", json!({ "a": {} })).await;
1996 let project = Project::test(fs.clone(), [], cx).await;
1997 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1998 let connection = NativeAgentConnection(
1999 NativeAgent::new(
2000 project.clone(),
2001 thread_store,
2002 Templates::new(),
2003 None,
2004 fs.clone(),
2005 &mut cx.to_async(),
2006 )
2007 .await
2008 .unwrap(),
2009 );
2010
2011 // Create a thread/session
2012 let acp_thread = cx
2013 .update(|cx| {
2014 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2015 })
2016 .await
2017 .unwrap();
2018
2019 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2020
2021 let models = cx
2022 .update(|cx| {
2023 connection
2024 .model_selector(&session_id)
2025 .unwrap()
2026 .list_models(cx)
2027 })
2028 .await
2029 .unwrap();
2030
2031 let acp_thread::AgentModelList::Grouped(models) = models else {
2032 panic!("Unexpected model group");
2033 };
2034 assert_eq!(
2035 models,
2036 IndexMap::from_iter([(
2037 AgentModelGroupName("Fake".into()),
2038 vec![AgentModelInfo {
2039 id: acp::ModelId::new("fake/fake"),
2040 name: "Fake".into(),
2041 description: None,
2042 icon: Some(acp_thread::AgentModelIcon::Named(
2043 ui::IconName::ZedAssistant
2044 )),
2045 is_latest: false,
2046 cost: None,
2047 }]
2048 )])
2049 );
2050 }
2051
    /// Checks that selecting a model through the model selector updates the session's thread
    /// and is written to the settings file, including `enable_thinking` for a thinking model.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a default model that differs from the one selected
        // below.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending work (including the settings write) run before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list after registering the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2167
2168 #[gpui::test]
2169 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2170 init_test(cx);
2171 let fs = FakeFs::new(cx.executor());
2172 fs.create_dir(paths::settings_file().parent().unwrap())
2173 .await
2174 .unwrap();
2175 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2176 let project = Project::test(fs.clone(), [], cx).await;
2177
2178 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2179 let agent = NativeAgent::new(
2180 project.clone(),
2181 thread_store,
2182 Templates::new(),
2183 None,
2184 fs.clone(),
2185 &mut cx.to_async(),
2186 )
2187 .await
2188 .unwrap();
2189 let connection = NativeAgentConnection(agent.clone());
2190
2191 let acp_thread = cx
2192 .update(|cx| {
2193 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2194 })
2195 .await
2196 .unwrap();
2197 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2198
2199 // Register a second provider with a thinking model.
2200 cx.update(|cx| {
2201 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2202 "fake-corp",
2203 "fake-thinking",
2204 "Fake Thinking",
2205 true,
2206 ));
2207 let thinking_provider = Arc::new(
2208 FakeLanguageModelProvider::new(
2209 LanguageModelProviderId::from("fake-corp".to_string()),
2210 LanguageModelProviderName::from("Fake Corp".to_string()),
2211 )
2212 .with_models(vec![thinking_model]),
2213 );
2214 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2215 registry.register_provider(thinking_provider, cx);
2216 });
2217 });
2218 // Refresh the agent's model list so it picks up the new provider.
2219 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2220
2221 // Thread starts with thinking_enabled = false (the default).
2222 agent.read_with(cx, |agent, _| {
2223 let session = agent.sessions.get(&session_id).unwrap();
2224 session.thread.read_with(cx, |thread, _| {
2225 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2226 });
2227 });
2228
2229 // Select the thinking model via select_model.
2230 let selector = connection.model_selector(&session_id).unwrap();
2231 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2232 .await
2233 .unwrap();
2234
2235 // select_model should have enabled thinking based on the model's supports_thinking().
2236 agent.read_with(cx, |agent, _| {
2237 let session = agent.sessions.get(&session_id).unwrap();
2238 session.thread.read_with(cx, |thread, _| {
2239 assert!(
2240 thread.thinking_enabled(),
2241 "select_model should enable thinking when model supports it"
2242 );
2243 });
2244 });
2245
2246 // Switch back to the non-thinking model.
2247 let selector = connection.model_selector(&session_id).unwrap();
2248 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2249 .await
2250 .unwrap();
2251
2252 // select_model should have disabled thinking.
2253 agent.read_with(cx, |agent, _| {
2254 let session = agent.sessions.get(&session_id).unwrap();
2255 session.thread.read_with(cx, |thread, _| {
2256 assert!(
2257 !thread.thinking_enabled(),
2258 "select_model should disable thinking when model does not support it"
2259 );
2260 });
2261 });
2262 }
2263
2264 #[gpui::test]
2265 async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2266 init_test(cx);
2267 let fs = FakeFs::new(cx.executor());
2268 fs.insert_tree("/", json!({ "a": {} })).await;
2269 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2270 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2271 let agent = NativeAgent::new(
2272 project.clone(),
2273 thread_store.clone(),
2274 Templates::new(),
2275 None,
2276 fs.clone(),
2277 &mut cx.to_async(),
2278 )
2279 .await
2280 .unwrap();
2281 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2282
2283 // Register a thinking model.
2284 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2285 "fake-corp",
2286 "fake-thinking",
2287 "Fake Thinking",
2288 true,
2289 ));
2290 let thinking_provider = Arc::new(
2291 FakeLanguageModelProvider::new(
2292 LanguageModelProviderId::from("fake-corp".to_string()),
2293 LanguageModelProviderName::from("Fake Corp".to_string()),
2294 )
2295 .with_models(vec![thinking_model.clone()]),
2296 );
2297 cx.update(|cx| {
2298 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2299 registry.register_provider(thinking_provider, cx);
2300 });
2301 });
2302 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2303
2304 // Create a thread and select the thinking model.
2305 let acp_thread = cx
2306 .update(|cx| {
2307 connection
2308 .clone()
2309 .new_session(project.clone(), Path::new("/a"), cx)
2310 })
2311 .await
2312 .unwrap();
2313 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2314
2315 let selector = connection.model_selector(&session_id).unwrap();
2316 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2317 .await
2318 .unwrap();
2319
2320 // Verify thinking is enabled after selecting the thinking model.
2321 let thread = agent.read_with(cx, |agent, _| {
2322 agent.sessions.get(&session_id).unwrap().thread.clone()
2323 });
2324 thread.read_with(cx, |thread, _| {
2325 assert!(
2326 thread.thinking_enabled(),
2327 "thinking should be enabled after selecting thinking model"
2328 );
2329 });
2330
2331 // Send a message so the thread gets persisted.
2332 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2333 let send = cx.foreground_executor().spawn(send);
2334 cx.run_until_parked();
2335
2336 thinking_model.send_last_completion_stream_text_chunk("Response.");
2337 thinking_model.end_last_completion_stream();
2338
2339 send.await.unwrap();
2340 cx.run_until_parked();
2341
2342 // Close the session so it can be reloaded from disk.
2343 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2344 .await
2345 .unwrap();
2346 drop(thread);
2347 drop(acp_thread);
2348 agent.read_with(cx, |agent, _| {
2349 assert!(agent.sessions.is_empty());
2350 });
2351
2352 // Reload the thread and verify thinking_enabled is still true.
2353 let reloaded_acp_thread = agent
2354 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2355 .await
2356 .unwrap();
2357 let reloaded_thread = agent.read_with(cx, |agent, _| {
2358 agent.sessions.get(&session_id).unwrap().thread.clone()
2359 });
2360 reloaded_thread.read_with(cx, |thread, _| {
2361 assert!(
2362 thread.thinking_enabled(),
2363 "thinking_enabled should be preserved when reloading a thread with a thinking model"
2364 );
2365 });
2366
2367 drop(reloaded_acp_thread);
2368 }
2369
2370 #[gpui::test]
2371 async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2372 init_test(cx);
2373 let fs = FakeFs::new(cx.executor());
2374 fs.insert_tree("/", json!({ "a": {} })).await;
2375 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2376 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2377 let agent = NativeAgent::new(
2378 project.clone(),
2379 thread_store.clone(),
2380 Templates::new(),
2381 None,
2382 fs.clone(),
2383 &mut cx.to_async(),
2384 )
2385 .await
2386 .unwrap();
2387 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2388
2389 // Register a model where id() != name(), like real Anthropic models
2390 // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2391 let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2392 "fake-corp",
2393 "custom-model-id",
2394 "Custom Model Display Name",
2395 false,
2396 ));
2397 let provider = Arc::new(
2398 FakeLanguageModelProvider::new(
2399 LanguageModelProviderId::from("fake-corp".to_string()),
2400 LanguageModelProviderName::from("Fake Corp".to_string()),
2401 )
2402 .with_models(vec![model.clone()]),
2403 );
2404 cx.update(|cx| {
2405 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2406 registry.register_provider(provider, cx);
2407 });
2408 });
2409 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2410
2411 // Create a thread and select the model.
2412 let acp_thread = cx
2413 .update(|cx| {
2414 connection
2415 .clone()
2416 .new_session(project.clone(), Path::new("/a"), cx)
2417 })
2418 .await
2419 .unwrap();
2420 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2421
2422 let selector = connection.model_selector(&session_id).unwrap();
2423 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2424 .await
2425 .unwrap();
2426
2427 let thread = agent.read_with(cx, |agent, _| {
2428 agent.sessions.get(&session_id).unwrap().thread.clone()
2429 });
2430 thread.read_with(cx, |thread, _| {
2431 assert_eq!(
2432 thread.model().unwrap().id().0.as_ref(),
2433 "custom-model-id",
2434 "model should be set before persisting"
2435 );
2436 });
2437
2438 // Send a message so the thread gets persisted.
2439 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2440 let send = cx.foreground_executor().spawn(send);
2441 cx.run_until_parked();
2442
2443 model.send_last_completion_stream_text_chunk("Response.");
2444 model.end_last_completion_stream();
2445
2446 send.await.unwrap();
2447 cx.run_until_parked();
2448
2449 // Close the session so it can be reloaded from disk.
2450 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2451 .await
2452 .unwrap();
2453 drop(thread);
2454 drop(acp_thread);
2455 agent.read_with(cx, |agent, _| {
2456 assert!(agent.sessions.is_empty());
2457 });
2458
2459 // Reload the thread and verify the model was preserved.
2460 let reloaded_acp_thread = agent
2461 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2462 .await
2463 .unwrap();
2464 let reloaded_thread = agent.read_with(cx, |agent, _| {
2465 agent.sessions.get(&session_id).unwrap().thread.clone()
2466 });
2467 reloaded_thread.read_with(cx, |thread, _| {
2468 let reloaded_model = thread
2469 .model()
2470 .expect("model should be present after reload");
2471 assert_eq!(
2472 reloaded_model.id().0.as_ref(),
2473 "custom-model-id",
2474 "reloaded thread should have the same model, not fall back to the default"
2475 );
2476 });
2477
2478 drop(reloaded_acp_thread);
2479 }
2480
2481 #[gpui::test]
2482 async fn test_save_load_thread(cx: &mut TestAppContext) {
2483 init_test(cx);
2484 let fs = FakeFs::new(cx.executor());
2485 fs.insert_tree(
2486 "/",
2487 json!({
2488 "a": {
2489 "b.md": "Lorem"
2490 }
2491 }),
2492 )
2493 .await;
2494 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2495 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2496 let agent = NativeAgent::new(
2497 project.clone(),
2498 thread_store.clone(),
2499 Templates::new(),
2500 None,
2501 fs.clone(),
2502 &mut cx.to_async(),
2503 )
2504 .await
2505 .unwrap();
2506 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2507
2508 let acp_thread = cx
2509 .update(|cx| {
2510 connection
2511 .clone()
2512 .new_session(project.clone(), Path::new(""), cx)
2513 })
2514 .await
2515 .unwrap();
2516 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2517 let thread = agent.read_with(cx, |agent, _| {
2518 agent.sessions.get(&session_id).unwrap().thread.clone()
2519 });
2520
2521 // Ensure empty threads are not saved, even if they get mutated.
2522 let model = Arc::new(FakeLanguageModel::default());
2523 let summary_model = Arc::new(FakeLanguageModel::default());
2524 thread.update(cx, |thread, cx| {
2525 thread.set_model(model.clone(), cx);
2526 thread.set_summarization_model(Some(summary_model.clone()), cx);
2527 });
2528 cx.run_until_parked();
2529 assert_eq!(thread_entries(&thread_store, cx), vec![]);
2530
2531 let send = acp_thread.update(cx, |thread, cx| {
2532 thread.send(
2533 vec![
2534 "What does ".into(),
2535 acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
2536 "b.md",
2537 MentionUri::File {
2538 abs_path: path!("/a/b.md").into(),
2539 }
2540 .to_uri()
2541 .to_string(),
2542 )),
2543 " mean?".into(),
2544 ],
2545 cx,
2546 )
2547 });
2548 let send = cx.foreground_executor().spawn(send);
2549 cx.run_until_parked();
2550
2551 model.send_last_completion_stream_text_chunk("Lorem.");
2552 model.end_last_completion_stream();
2553 cx.run_until_parked();
2554 summary_model
2555 .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
2556 summary_model.end_last_completion_stream();
2557
2558 send.await.unwrap();
2559 let uri = MentionUri::File {
2560 abs_path: path!("/a/b.md").into(),
2561 }
2562 .to_uri();
2563 acp_thread.read_with(cx, |thread, cx| {
2564 assert_eq!(
2565 thread.to_markdown(cx),
2566 formatdoc! {"
2567 ## User
2568
2569 What does [@b.md]({uri}) mean?
2570
2571 ## Assistant
2572
2573 Lorem.
2574
2575 "}
2576 )
2577 });
2578
2579 cx.run_until_parked();
2580
2581 // Set a draft prompt with rich content blocks before saving.
2582 let draft_blocks = vec![
2583 acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
2584 acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
2585 acp::ContentBlock::Text(acp::TextContent::new(" please")),
2586 ];
2587 acp_thread.update(cx, |thread, _cx| {
2588 thread.set_draft_prompt(Some(draft_blocks.clone()));
2589 });
2590 thread.update(cx, |_thread, cx| cx.notify());
2591 cx.run_until_parked();
2592
2593 // Close the session so it can be reloaded from disk.
2594 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2595 .await
2596 .unwrap();
2597 drop(thread);
2598 drop(acp_thread);
2599 agent.read_with(cx, |agent, _| {
2600 assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
2601 });
2602
2603 // Ensure the thread can be reloaded from disk.
2604 assert_eq!(
2605 thread_entries(&thread_store, cx),
2606 vec![(
2607 session_id.clone(),
2608 format!("Explaining {}", path!("/a/b.md"))
2609 )]
2610 );
2611 let acp_thread = agent
2612 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2613 .await
2614 .unwrap();
2615 acp_thread.read_with(cx, |thread, cx| {
2616 assert_eq!(
2617 thread.to_markdown(cx),
2618 formatdoc! {"
2619 ## User
2620
2621 What does [@b.md]({uri}) mean?
2622
2623 ## Assistant
2624
2625 Lorem.
2626
2627 "}
2628 )
2629 });
2630
2631 // Ensure the draft prompt with rich content blocks survived the round-trip.
2632 acp_thread.read_with(cx, |thread, _| {
2633 assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
2634 });
2635 }
2636
2637 fn thread_entries(
2638 thread_store: &Entity<ThreadStore>,
2639 cx: &mut TestAppContext,
2640 ) -> Vec<(acp::SessionId, String)> {
2641 thread_store.read_with(cx, |store, _| {
2642 store
2643 .entries()
2644 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2645 .collect::<Vec<_>>()
2646 })
2647 }
2648
2649 fn init_test(cx: &mut TestAppContext) {
2650 env_logger::try_init().ok();
2651 cx.update(|cx| {
2652 let settings_store = SettingsStore::test(cx);
2653 cx.set_global(settings_store);
2654
2655 LanguageModelRegistry::test(cx);
2656 });
2657 }
2658}
2659
2660fn mcp_message_content_to_acp_content_block(
2661 content: context_server::types::MessageContent,
2662) -> acp::ContentBlock {
2663 match content {
2664 context_server::types::MessageContent::Text {
2665 text,
2666 annotations: _,
2667 } => text.into(),
2668 context_server::types::MessageContent::Image {
2669 data,
2670 mime_type,
2671 annotations: _,
2672 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2673 context_server::types::MessageContent::Audio {
2674 data,
2675 mime_type,
2676 annotations: _,
2677 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2678 context_server::types::MessageContent::Resource {
2679 resource,
2680 annotations: _,
2681 } => {
2682 let mut link =
2683 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2684 if let Some(mime_type) = resource.mime_type {
2685 link = link.mime_type(mime_type);
2686 }
2687 acp::ContentBlock::ResourceLink(link)
2688 }
2689 }
2690}