1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
41};
42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
43use project::{Project, ProjectItem, ProjectPath, Worktree};
44use prompt_store::{
45 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
46 WorktreeContext,
47};
48use serde::{Deserialize, Serialize};
49use settings::{LanguageModelSelection, update_settings_file};
50use std::any::Any;
51use std::path::{Path, PathBuf};
52use std::rc::Rc;
53use std::sync::Arc;
54use util::ResultExt;
55use util::path_list::PathList;
56use util::rel_path::RelPath;
57
/// A snapshot of the project's worktrees paired with a timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshot of each worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
63
/// Error produced when a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` text).
    pub message: SharedString,
}
67
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// The most recently started database save for this session. Every save assigns a
    /// new task here, dropping the previous one.
    pending_save: Task<()>,
    /// Subscriptions to the thread, kept alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
77
/// Cache of the language models offered by visible, authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half of the refresh signal; `watch()` hands out clones of it.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list()` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; stored so it stays alive.
    _authenticate_all_providers_task: Task<()>,
}
87
88impl LanguageModels {
89 fn new(cx: &mut App) -> Self {
90 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
91
92 let mut this = Self {
93 models: HashMap::default(),
94 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
95 refresh_models_rx,
96 refresh_models_tx,
97 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
98 };
99 this.refresh_list(cx);
100 this
101 }
102
103 fn refresh_list(&mut self, cx: &App) {
104 let providers = LanguageModelRegistry::global(cx)
105 .read(cx)
106 .visible_providers()
107 .into_iter()
108 .filter(|provider| provider.is_authenticated(cx))
109 .collect::<Vec<_>>();
110
111 let mut language_model_list = IndexMap::default();
112 let mut recommended_models = HashSet::default();
113
114 let mut recommended = Vec::new();
115 for provider in &providers {
116 for model in provider.recommended_models(cx) {
117 recommended_models.insert((model.provider_id(), model.id()));
118 recommended.push(Self::map_language_model_to_info(&model, provider));
119 }
120 }
121 if !recommended.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName("Recommended".into()),
124 recommended,
125 );
126 }
127
128 let mut models = HashMap::default();
129 for provider in providers {
130 let mut provider_models = Vec::new();
131 for model in provider.provided_models(cx) {
132 let model_info = Self::map_language_model_to_info(&model, &provider);
133 let model_id = model_info.id.clone();
134 provider_models.push(model_info);
135 models.insert(model_id, model);
136 }
137 if !provider_models.is_empty() {
138 language_model_list.insert(
139 acp_thread::AgentModelGroupName(provider.name().0.clone()),
140 provider_models,
141 );
142 }
143 }
144
145 self.models = models;
146 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
147 self.refresh_models_tx.send(()).ok();
148 }
149
150 fn watch(&self) -> watch::Receiver<()> {
151 self.refresh_models_rx.clone()
152 }
153
154 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
155 self.models.get(model_id).cloned()
156 }
157
158 fn map_language_model_to_info(
159 model: &Arc<dyn LanguageModel>,
160 provider: &Arc<dyn LanguageModelProvider>,
161 ) -> acp_thread::AgentModelInfo {
162 acp_thread::AgentModelInfo {
163 id: Self::model_id(model),
164 name: model.name().0,
165 description: None,
166 icon: Some(match provider.icon() {
167 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
168 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
169 }),
170 is_latest: model.is_latest(),
171 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
172 }
173 }
174
175 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
176 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
177 }
178
179 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
180 let authenticate_all_providers = LanguageModelRegistry::global(cx)
181 .read(cx)
182 .visible_providers()
183 .iter()
184 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
185 .collect::<Vec<_>>();
186
187 cx.background_spawn(async move {
188 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
189 if let Err(err) = authenticate_task.await {
190 match err {
191 language_model::AuthenticateError::CredentialsNotFound => {
192 // Since we're authenticating these providers in the
193 // background for the purposes of populating the
194 // language selector, we don't care about providers
195 // where the credentials are not found.
196 }
197 language_model::AuthenticateError::ConnectionRefused => {
198 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
199 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
200 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
201 }
202 _ => {
203 // Some providers have noisy failure states that we
204 // don't want to spam the logs with every time the
205 // language model selector is initialized.
206 //
207 // Ideally these should have more clear failure modes
208 // that we know are safe to ignore here, like what we do
209 // with `CredentialsNotFound` above.
210 match provider_id.0.as_ref() {
211 "lmstudio" | "ollama" => {
212 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
213 //
214 // These fail noisily, so we don't log them.
215 }
216 "copilot_chat" => {
217 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
218 }
219 _ => {
220 log::error!(
221 "Failed to authenticate provider: {}: {err:#}",
222 provider_name.0
223 );
224 }
225 }
226 }
227 }
228 }
229 }
230 })
231 }
232}
233
/// The built-in agent. It owns every session plus the state shared between them:
/// project context, context-server registry, templates and the model cache.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store that is reloaded after every thread save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled to ask `_maintain_project_context` to rebuild the project context.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds the project context on each refresh request.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server prompts and tools, shared with every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Source of the default user rules included in the project context, if any.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Event subscriptions created in `new`, kept alive for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
252
253impl NativeAgent {
    /// Creates the agent entity for `project`.
    ///
    /// The initial [`ProjectContext`] is built before the entity exists. The agent then
    /// subscribes to events from the project, the language-model registry, the
    /// context-server store and registry, and (when present) the prompt store. It also
    /// spawns the task that rebuilds the project context whenever a refresh is requested.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        // Build the initial project context before creating the entity.
        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // Messages sent on the `tx` half wake `maintain_project_context`, which
            // rebuilds the project context.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
312
313 fn new_session(
314 &mut self,
315 project: Entity<Project>,
316 cx: &mut Context<Self>,
317 ) -> Entity<AcpThread> {
318 // Create Thread
319 // Fetch default model from registry settings
320 let registry = LanguageModelRegistry::read_global(cx);
321 // Log available models for debugging
322 let available_count = registry.available_models(cx).count();
323 log::debug!("Total available models: {}", available_count);
324
325 let default_model = registry.default_model().and_then(|default_model| {
326 self.models
327 .model_from_id(&LanguageModels::model_id(&default_model.model))
328 });
329 let thread = cx.new(|cx| {
330 Thread::new(
331 project.clone(),
332 self.project_context.clone(),
333 self.context_server_registry.clone(),
334 self.templates.clone(),
335 default_model,
336 cx,
337 )
338 });
339
340 self.register_session(thread, cx)
341 }
342
    /// Wraps `thread_handle` in an [`AcpThread`] and records both under the thread's
    /// session id, replacing any session already stored under that id.
    ///
    /// The `AcpThread` is seeded with the thread's parent id, title, draft prompt,
    /// scroll position and latest token usage. The thread receives the current
    /// summarization model and its default tools. Title and token-usage events are
    /// forwarded to the `AcpThread`, and every notification from the thread triggers
    /// `save_thread`. Finally, the advertised commands are pushed to all sessions.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything needed out of the thread while `cx` is only borrowed for reading.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // Tools get weak handles to the ACP thread, the thread and the agent.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread whenever it notifies.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands(cx);

        acp_thread
    }
416
417 pub fn models(&self) -> &LanguageModels {
418 &self.models
419 }
420
421 async fn maintain_project_context(
422 this: WeakEntity<Self>,
423 mut needs_refresh: watch::Receiver<()>,
424 cx: &mut AsyncApp,
425 ) -> Result<()> {
426 while needs_refresh.changed().await.is_ok() {
427 let project_context = this
428 .update(cx, |this, cx| {
429 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
430 })?
431 .await;
432 this.update(cx, |this, cx| {
433 this.project_context = cx.new(|_| project_context);
434 })?;
435 }
436
437 Ok(())
438 }
439
440 fn build_project_context(
441 project: &Entity<Project>,
442 prompt_store: Option<&Entity<PromptStore>>,
443 cx: &mut App,
444 ) -> Task<ProjectContext> {
445 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
446 let worktree_tasks = worktrees
447 .into_iter()
448 .map(|worktree| {
449 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
450 })
451 .collect::<Vec<_>>();
452 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
453 prompt_store.read_with(cx, |prompt_store, cx| {
454 let prompts = prompt_store.default_prompt_metadata();
455 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
456 let contents = prompt_store.load(prompt_metadata.id, cx);
457 async move { (contents.await, prompt_metadata) }
458 });
459 cx.background_spawn(future::join_all(load_tasks))
460 })
461 } else {
462 Task::ready(vec![])
463 };
464
465 cx.spawn(async move |_cx| {
466 let (worktrees, default_user_rules) =
467 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
468
469 let worktrees = worktrees
470 .into_iter()
471 .map(|(worktree, _rules_error)| {
472 // TODO: show error message
473 // if let Some(rules_error) = rules_error {
474 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
475 // }
476 worktree
477 })
478 .collect::<Vec<_>>();
479
480 let default_user_rules = default_user_rules
481 .into_iter()
482 .flat_map(|(contents, prompt_metadata)| match contents {
483 Ok(contents) => Some(UserRulesContext {
484 uuid: prompt_metadata.id.as_user()?,
485 title: prompt_metadata.title.map(|title| title.to_string()),
486 contents,
487 }),
488 Err(_err) => {
489 // TODO: show error message
490 // this.update(cx, |_, cx| {
491 // cx.emit(RulesLoadingError {
492 // message: format!("{err:?}").into(),
493 // });
494 // })
495 // .ok();
496 None
497 }
498 })
499 .collect::<Vec<_>>();
500
501 ProjectContext::new(worktrees, default_user_rules)
502 })
503 }
504
505 fn load_worktree_info_for_system_prompt(
506 worktree: Entity<Worktree>,
507 project: Entity<Project>,
508 cx: &mut App,
509 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
510 let tree = worktree.read(cx);
511 let root_name = tree.root_name_str().into();
512 let abs_path = tree.abs_path();
513
514 let mut context = WorktreeContext {
515 root_name,
516 abs_path,
517 rules_file: None,
518 };
519
520 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
521 let Some(rules_task) = rules_task else {
522 return Task::ready((context, None));
523 };
524
525 cx.spawn(async move |_| {
526 let (rules_file, rules_file_error) = match rules_task.await {
527 Ok(rules_file) => (Some(rules_file), None),
528 Err(err) => (
529 None,
530 Some(RulesLoadingError {
531 message: format!("{err}").into(),
532 }),
533 ),
534 };
535 context.rules_file = rules_file;
536 (context, rules_file_error)
537 })
538 }
539
540 fn load_worktree_rules_file(
541 worktree: Entity<Worktree>,
542 project: Entity<Project>,
543 cx: &mut App,
544 ) -> Option<Task<Result<RulesFileContext>>> {
545 let worktree = worktree.read(cx);
546 let worktree_id = worktree.id();
547 let selected_rules_file = RULES_FILE_NAMES
548 .into_iter()
549 .filter_map(|name| {
550 worktree
551 .entry_for_path(RelPath::unix(name).unwrap())
552 .filter(|entry| entry.is_file())
553 .map(|entry| entry.path.clone())
554 })
555 .next();
556
557 // Note that Cline supports `.clinerules` being a directory, but that is not currently
558 // supported. This doesn't seem to occur often in GitHub repositories.
559 selected_rules_file.map(|path_in_worktree| {
560 let project_path = ProjectPath {
561 worktree_id,
562 path: path_in_worktree.clone(),
563 };
564 let buffer_task =
565 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
566 let rope_task = cx.spawn(async move |cx| {
567 let buffer = buffer_task.await?;
568 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
569 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
570 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
571 })?;
572 anyhow::Ok((project_entry_id, rope))
573 });
574 // Build a string from the rope on a background thread.
575 cx.background_spawn(async move {
576 let (project_entry_id, rope) = rope_task.await?;
577 anyhow::Ok(RulesFileContext {
578 path_in_worktree,
579 text: rope.to_string().trim().to_string(),
580 project_entry_id: project_entry_id.to_usize(),
581 })
582 })
583 })
584 }
585
586 fn handle_thread_title_updated(
587 &mut self,
588 thread: Entity<Thread>,
589 _: &TitleUpdated,
590 cx: &mut Context<Self>,
591 ) {
592 let session_id = thread.read(cx).id();
593 let Some(session) = self.sessions.get(session_id) else {
594 return;
595 };
596 let thread = thread.downgrade();
597 let acp_thread = session.acp_thread.downgrade();
598 cx.spawn(async move |_, cx| {
599 let title = thread.read_with(cx, |thread, _| thread.title())?;
600 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
601 task.await
602 })
603 .detach_and_log_err(cx);
604 }
605
606 fn handle_thread_token_usage_updated(
607 &mut self,
608 thread: Entity<Thread>,
609 usage: &TokenUsageUpdated,
610 cx: &mut Context<Self>,
611 ) {
612 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
613 return;
614 };
615 session.acp_thread.update(cx, |acp_thread, cx| {
616 acp_thread.update_token_usage(usage.0.clone(), cx);
617 });
618 }
619
620 fn handle_project_event(
621 &mut self,
622 _project: Entity<Project>,
623 event: &project::Event,
624 _cx: &mut Context<Self>,
625 ) {
626 match event {
627 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
628 self.project_context_needs_refresh.send(()).ok();
629 }
630 project::Event::WorktreeUpdatedEntries(_, items) => {
631 if items.iter().any(|(path, _, _)| {
632 RULES_FILE_NAMES
633 .iter()
634 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
635 }) {
636 self.project_context_needs_refresh.send(()).ok();
637 }
638 }
639 _ => {}
640 }
641 }
642
643 fn handle_prompts_updated_event(
644 &mut self,
645 _prompt_store: Entity<PromptStore>,
646 _event: &prompt_store::PromptsUpdatedEvent,
647 _cx: &mut Context<Self>,
648 ) {
649 self.project_context_needs_refresh.send(()).ok();
650 }
651
652 fn handle_models_updated_event(
653 &mut self,
654 _registry: Entity<LanguageModelRegistry>,
655 _event: &language_model::Event,
656 cx: &mut Context<Self>,
657 ) {
658 self.models.refresh_list(cx);
659
660 let registry = LanguageModelRegistry::read_global(cx);
661 let default_model = registry.default_model().map(|m| m.model);
662 let summarization_model = registry.thread_summary_model().map(|m| m.model);
663
664 for session in self.sessions.values_mut() {
665 session.thread.update(cx, |thread, cx| {
666 if thread.model().is_none()
667 && let Some(model) = default_model.clone()
668 {
669 thread.set_model(model, cx);
670 cx.notify();
671 }
672 thread.set_summarization_model(summarization_model.clone(), cx);
673 });
674 }
675 }
676
677 fn handle_context_server_store_updated(
678 &mut self,
679 _store: Entity<project::context_server_store::ContextServerStore>,
680 _event: &project::context_server_store::ServerStatusChangedEvent,
681 cx: &mut Context<Self>,
682 ) {
683 self.update_available_commands(cx);
684 }
685
686 fn handle_context_server_registry_event(
687 &mut self,
688 _registry: Entity<ContextServerRegistry>,
689 event: &ContextServerRegistryEvent,
690 cx: &mut Context<Self>,
691 ) {
692 match event {
693 ContextServerRegistryEvent::ToolsChanged => {}
694 ContextServerRegistryEvent::PromptsChanged => {
695 self.update_available_commands(cx);
696 }
697 }
698 }
699
700 fn update_available_commands(&self, cx: &mut Context<Self>) {
701 let available_commands = self.build_available_commands(cx);
702 for session in self.sessions.values() {
703 session.acp_thread.update(cx, |thread, cx| {
704 thread
705 .handle_session_update(
706 acp::SessionUpdate::AvailableCommandsUpdate(
707 acp::AvailableCommandsUpdate::new(available_commands.clone()),
708 ),
709 cx,
710 )
711 .log_err();
712 });
713 }
714 }
715
716 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
717 let registry = self.context_server_registry.read(cx);
718
719 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
720 for context_server_prompt in registry.prompts() {
721 *prompt_name_counts
722 .entry(context_server_prompt.prompt.name.as_str())
723 .or_insert(0) += 1;
724 }
725
726 registry
727 .prompts()
728 .flat_map(|context_server_prompt| {
729 let prompt = &context_server_prompt.prompt;
730
731 let should_prefix = prompt_name_counts
732 .get(prompt.name.as_str())
733 .copied()
734 .unwrap_or(0)
735 > 1;
736
737 let name = if should_prefix {
738 format!("{}.{}", context_server_prompt.server_id, prompt.name)
739 } else {
740 prompt.name.clone()
741 };
742
743 let mut command = acp::AvailableCommand::new(
744 name,
745 prompt.description.clone().unwrap_or_default(),
746 );
747
748 match prompt.arguments.as_deref() {
749 Some([arg]) => {
750 let hint = format!("<{}>", arg.name);
751
752 command = command.input(acp::AvailableCommandInput::Unstructured(
753 acp::UnstructuredCommandInput::new(hint),
754 ));
755 }
756 Some([]) | None => {}
757 Some(_) => {
758 // skip >1 argument commands since we don't support them yet
759 return None;
760 }
761 }
762
763 Some(command)
764 })
765 .collect()
766 }
767
768 pub fn load_thread(
769 &mut self,
770 id: acp::SessionId,
771 cx: &mut Context<Self>,
772 ) -> Task<Result<Entity<Thread>>> {
773 let database_future = ThreadsDatabase::connect(cx);
774 cx.spawn(async move |this, cx| {
775 let database = database_future.await.map_err(|err| anyhow!(err))?;
776 let db_thread = database
777 .load_thread(id.clone())
778 .await?
779 .with_context(|| format!("no thread found with ID: {id:?}"))?;
780
781 this.update(cx, |this, cx| {
782 let summarization_model = LanguageModelRegistry::read_global(cx)
783 .thread_summary_model()
784 .map(|c| c.model);
785
786 cx.new(|cx| {
787 let mut thread = Thread::from_db(
788 id.clone(),
789 db_thread,
790 this.project.clone(),
791 this.project_context.clone(),
792 this.context_server_registry.clone(),
793 this.templates.clone(),
794 cx,
795 );
796 thread.set_summarization_model(summarization_model, cx);
797 thread
798 })
799 })
800 })
801 }
802
803 pub fn open_thread(
804 &mut self,
805 id: acp::SessionId,
806 cx: &mut Context<Self>,
807 ) -> Task<Result<Entity<AcpThread>>> {
808 if let Some(session) = self.sessions.get(&id) {
809 return Task::ready(Ok(session.acp_thread.clone()));
810 }
811
812 let task = self.load_thread(id, cx);
813 cx.spawn(async move |this, cx| {
814 let thread = task.await?;
815 let acp_thread =
816 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
817 let events = thread.update(cx, |thread, cx| thread.replay(cx));
818 cx.update(|cx| {
819 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
820 })
821 .await?;
822 Ok(acp_thread)
823 })
824 }
825
826 pub fn thread_summary(
827 &mut self,
828 id: acp::SessionId,
829 cx: &mut Context<Self>,
830 ) -> Task<Result<SharedString>> {
831 let thread = self.open_thread(id.clone(), cx);
832 cx.spawn(async move |this, cx| {
833 let acp_thread = thread.await?;
834 let result = this
835 .update(cx, |this, cx| {
836 this.sessions
837 .get(&id)
838 .unwrap()
839 .thread
840 .update(cx, |thread, cx| thread.summary(cx))
841 })?
842 .await
843 .context("Failed to generate summary")?;
844 drop(acp_thread);
845 Ok(result)
846 })
847 }
848
    /// Persists `thread` to the threads database, then reloads the thread store.
    ///
    /// Empty threads, and threads without a registered session, are skipped. Before
    /// serializing, the `AcpThread`'s draft prompt is copied onto the thread. The thread
    /// is saved together with the absolute paths of the project's visible worktrees.
    ///
    /// Database connection and save errors are logged. The save runs as the session's
    /// `pending_save` task, which replaces (and thereby drops) any previous one.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let folder_paths = PathList::new(
            &self
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        // Sync the draft prompt from the ACP side before serializing the thread.
        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return;
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            // Reload even if the save failed; the error was already logged above.
            thread_store.update(cx, |store, cx| store.reload(cx));
        });
    }
887
    /// Runs MCP prompt `prompt_name` from server `server_id` as a turn in session
    /// `session_id`.
    ///
    /// Steps:
    ///
    /// 1. Fetch the prompt with `arguments`.
    /// 2. Append `original_content`, minus its first block, to the thread as message
    ///    `message_id`. The skipped block is presumably the slash-command text itself;
    ///    confirm with the caller.
    /// 3. Replay each prompt message into both the `AcpThread` and the thread.
    /// 4. If the last message came from the user (or there were none), send the thread;
    ///    otherwise resume it.
    /// 5. Forward the resulting events to the `AcpThread`.
    ///
    /// # Errors
    /// Fails if the prompt cannot be fetched, the session is not registered, the agent
    /// has been released, or starting or running the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Starts as `true` so a prompt with no messages still sends the user's content.
            let mut last_is_user = true;

            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Each message goes to both the ACP thread (for display) and the thread.
                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
983}
984
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds a strong handle to the [`NativeAgent`]. Each session's `AcpThread` is given
/// one of these (wrapped in an `Rc`) as its connection.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
988
989impl NativeAgentConnection {
990 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
991 self.0
992 .read(cx)
993 .sessions
994 .get(session_id)
995 .map(|session| session.thread.clone())
996 }
997
998 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
999 self.0.update(cx, |this, cx| this.load_thread(id, cx))
1000 }
1001
1002 fn run_turn(
1003 &self,
1004 session_id: acp::SessionId,
1005 cx: &mut App,
1006 f: impl 'static
1007 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1008 ) -> Task<Result<acp::PromptResponse>> {
1009 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1010 agent
1011 .sessions
1012 .get_mut(&session_id)
1013 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1014 }) else {
1015 return Task::ready(Err(anyhow!("Session not found")));
1016 };
1017 log::debug!("Found session for: {}", session_id);
1018
1019 let response_stream = match f(thread, cx) {
1020 Ok(stream) => stream,
1021 Err(err) => return Task::ready(Err(err)),
1022 };
1023 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1024 }
1025
1026 fn handle_thread_events(
1027 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1028 acp_thread: WeakEntity<AcpThread>,
1029 cx: &App,
1030 ) -> Task<Result<acp::PromptResponse>> {
1031 cx.spawn(async move |cx| {
1032 // Handle response stream and forward to session.acp_thread
1033 while let Some(result) = events.next().await {
1034 match result {
1035 Ok(event) => {
1036 log::trace!("Received completion event: {:?}", event);
1037
1038 match event {
1039 ThreadEvent::UserMessage(message) => {
1040 acp_thread.update(cx, |thread, cx| {
1041 for content in message.content {
1042 thread.push_user_content_block(
1043 Some(message.id.clone()),
1044 content.into(),
1045 cx,
1046 );
1047 }
1048 })?;
1049 }
1050 ThreadEvent::AgentText(text) => {
1051 acp_thread.update(cx, |thread, cx| {
1052 thread.push_assistant_content_block(text.into(), false, cx)
1053 })?;
1054 }
1055 ThreadEvent::AgentThinking(text) => {
1056 acp_thread.update(cx, |thread, cx| {
1057 thread.push_assistant_content_block(text.into(), true, cx)
1058 })?;
1059 }
1060 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1061 tool_call,
1062 options,
1063 response,
1064 context: _,
1065 }) => {
1066 let outcome_task = acp_thread.update(cx, |thread, cx| {
1067 thread.request_tool_call_authorization(tool_call, options, cx)
1068 })??;
1069 cx.background_spawn(async move {
1070 if let acp::RequestPermissionOutcome::Selected(
1071 acp::SelectedPermissionOutcome { option_id, .. },
1072 ) = outcome_task.await
1073 {
1074 response
1075 .send(option_id)
1076 .map(|_| anyhow!("authorization receiver was dropped"))
1077 .log_err();
1078 }
1079 })
1080 .detach();
1081 }
1082 ThreadEvent::ToolCall(tool_call) => {
1083 acp_thread.update(cx, |thread, cx| {
1084 thread.upsert_tool_call(tool_call, cx)
1085 })??;
1086 }
1087 ThreadEvent::ToolCallUpdate(update) => {
1088 acp_thread.update(cx, |thread, cx| {
1089 thread.update_tool_call(update, cx)
1090 })??;
1091 }
1092 ThreadEvent::SubagentSpawned(session_id) => {
1093 acp_thread.update(cx, |thread, cx| {
1094 thread.subagent_spawned(session_id, cx);
1095 })?;
1096 }
1097 ThreadEvent::Retry(status) => {
1098 acp_thread.update(cx, |thread, cx| {
1099 thread.update_retry_status(status, cx)
1100 })?;
1101 }
1102 ThreadEvent::Stop(stop_reason) => {
1103 log::debug!("Assistant message complete: {:?}", stop_reason);
1104 return Ok(acp::PromptResponse::new(stop_reason));
1105 }
1106 }
1107 }
1108 Err(e) => {
1109 log::error!("Error in model response stream: {:?}", e);
1110 return Err(e);
1111 }
1112 }
1113 }
1114
1115 log::debug!("Response stream completed");
1116 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1117 })
1118 }
1119}
1120
1121struct Command<'a> {
1122 prompt_name: &'a str,
1123 arg_value: &'a str,
1124 explicit_server_id: Option<&'a str>,
1125}
1126
1127impl<'a> Command<'a> {
1128 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1129 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1130 return None;
1131 };
1132 let text = text_content.text.trim();
1133 let command = text.strip_prefix('/')?;
1134 let (command, arg_value) = command
1135 .split_once(char::is_whitespace)
1136 .unwrap_or((command, ""));
1137
1138 if let Some((server_id, prompt_name)) = command.split_once('.') {
1139 Some(Self {
1140 prompt_name,
1141 arg_value,
1142 explicit_server_id: Some(server_id),
1143 })
1144 } else {
1145 Some(Self {
1146 prompt_name: command,
1147 arg_value,
1148 explicit_server_id: None,
1149 })
1150 }
1151 }
1152}
1153
1154struct NativeAgentModelSelector {
1155 session_id: acp::SessionId,
1156 connection: NativeAgentConnection,
1157}
1158
1159impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1160 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1161 log::debug!("NativeAgentConnection::list_models called");
1162 let list = self.connection.0.read(cx).models.model_list.clone();
1163 Task::ready(if list.is_empty() {
1164 Err(anyhow::anyhow!("No models available"))
1165 } else {
1166 Ok(list)
1167 })
1168 }
1169
1170 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1171 log::debug!(
1172 "Setting model for session {}: {}",
1173 self.session_id,
1174 model_id
1175 );
1176 let Some(thread) = self
1177 .connection
1178 .0
1179 .read(cx)
1180 .sessions
1181 .get(&self.session_id)
1182 .map(|session| session.thread.clone())
1183 else {
1184 return Task::ready(Err(anyhow!("Session not found")));
1185 };
1186
1187 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1188 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1189 };
1190
1191 // We want to reset the effort level when switching models, as the currently-selected effort level may
1192 // not be compatible.
1193 let effort = model
1194 .default_effort_level()
1195 .map(|effort_level| effort_level.value.to_string());
1196
1197 thread.update(cx, |thread, cx| {
1198 thread.set_model(model.clone(), cx);
1199 thread.set_thinking_effort(effort.clone(), cx);
1200 thread.set_thinking_enabled(model.supports_thinking(), cx);
1201 });
1202
1203 update_settings_file(
1204 self.connection.0.read(cx).fs.clone(),
1205 cx,
1206 move |settings, cx| {
1207 let provider = model.provider_id().0.to_string();
1208 let model = model.id().0.to_string();
1209 let enable_thinking = thread.read(cx).thinking_enabled();
1210 settings
1211 .agent
1212 .get_or_insert_default()
1213 .set_model(LanguageModelSelection {
1214 provider: provider.into(),
1215 model,
1216 enable_thinking,
1217 effort,
1218 });
1219 },
1220 );
1221
1222 Task::ready(Ok(()))
1223 }
1224
1225 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1226 let Some(thread) = self
1227 .connection
1228 .0
1229 .read(cx)
1230 .sessions
1231 .get(&self.session_id)
1232 .map(|session| session.thread.clone())
1233 else {
1234 return Task::ready(Err(anyhow!("Session not found")));
1235 };
1236 let Some(model) = thread.read(cx).model() else {
1237 return Task::ready(Err(anyhow!("Model not found")));
1238 };
1239 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1240 else {
1241 return Task::ready(Err(anyhow!("Provider not found")));
1242 };
1243 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1244 model, &provider,
1245 )))
1246 }
1247
1248 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1249 Some(self.connection.0.read(cx).models.watch())
1250 }
1251
1252 fn should_render_footer(&self) -> bool {
1253 true
1254 }
1255}
1256
1257impl acp_thread::AgentConnection for NativeAgentConnection {
1258 fn telemetry_id(&self) -> SharedString {
1259 "zed".into()
1260 }
1261
1262 fn new_session(
1263 self: Rc<Self>,
1264 project: Entity<Project>,
1265 cwd: &Path,
1266 cx: &mut App,
1267 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1268 log::debug!("Creating new thread for project at: {cwd:?}");
1269 Task::ready(Ok(self
1270 .0
1271 .update(cx, |agent, cx| agent.new_session(project, cx))))
1272 }
1273
1274 fn supports_load_session(&self) -> bool {
1275 true
1276 }
1277
1278 fn load_session(
1279 self: Rc<Self>,
1280 session: AgentSessionInfo,
1281 _project: Entity<Project>,
1282 _cwd: &Path,
1283 cx: &mut App,
1284 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1285 self.0
1286 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1287 }
1288
1289 fn supports_close_session(&self) -> bool {
1290 true
1291 }
1292
1293 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1294 self.0.update(cx, |agent, _cx| {
1295 agent.sessions.remove(session_id);
1296 });
1297 Task::ready(Ok(()))
1298 }
1299
1300 fn auth_methods(&self) -> &[acp::AuthMethod] {
1301 &[] // No auth for in-process
1302 }
1303
1304 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1305 Task::ready(Ok(()))
1306 }
1307
1308 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1309 Some(Rc::new(NativeAgentModelSelector {
1310 session_id: session_id.clone(),
1311 connection: self.clone(),
1312 }) as Rc<dyn AgentModelSelector>)
1313 }
1314
1315 fn prompt(
1316 &self,
1317 id: Option<acp_thread::UserMessageId>,
1318 params: acp::PromptRequest,
1319 cx: &mut App,
1320 ) -> Task<Result<acp::PromptResponse>> {
1321 let id = id.expect("UserMessageId is required");
1322 let session_id = params.session_id.clone();
1323 log::info!("Received prompt request for session: {}", session_id);
1324 log::debug!("Prompt blocks count: {}", params.prompt.len());
1325
1326 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1327 let registry = self.0.read(cx).context_server_registry.read(cx);
1328
1329 let explicit_server_id = parsed_command
1330 .explicit_server_id
1331 .map(|server_id| ContextServerId(server_id.into()));
1332
1333 if let Some(prompt) =
1334 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1335 {
1336 let arguments = if !parsed_command.arg_value.is_empty()
1337 && let Some(arg_name) = prompt
1338 .prompt
1339 .arguments
1340 .as_ref()
1341 .and_then(|args| args.first())
1342 .map(|arg| arg.name.clone())
1343 {
1344 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1345 } else {
1346 Default::default()
1347 };
1348
1349 let prompt_name = prompt.prompt.name.clone();
1350 let server_id = prompt.server_id.clone();
1351
1352 return self.0.update(cx, |agent, cx| {
1353 agent.send_mcp_prompt(
1354 id,
1355 session_id.clone(),
1356 prompt_name,
1357 server_id,
1358 arguments,
1359 params.prompt,
1360 cx,
1361 )
1362 });
1363 };
1364 };
1365
1366 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1367
1368 self.run_turn(session_id, cx, move |thread, cx| {
1369 let content: Vec<UserMessageContent> = params
1370 .prompt
1371 .into_iter()
1372 .map(|block| UserMessageContent::from_content_block(block, path_style))
1373 .collect::<Vec<_>>();
1374 log::debug!("Converted prompt to message: {} chars", content.len());
1375 log::debug!("Message id: {:?}", id);
1376 log::debug!("Message content: {:?}", content);
1377
1378 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1379 })
1380 }
1381
1382 fn retry(
1383 &self,
1384 session_id: &acp::SessionId,
1385 _cx: &App,
1386 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1387 Some(Rc::new(NativeAgentSessionRetry {
1388 connection: self.clone(),
1389 session_id: session_id.clone(),
1390 }) as _)
1391 }
1392
1393 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1394 log::info!("Cancelling on session: {}", session_id);
1395 self.0.update(cx, |agent, cx| {
1396 if let Some(session) = agent.sessions.get(session_id) {
1397 session
1398 .thread
1399 .update(cx, |thread, cx| thread.cancel(cx))
1400 .detach();
1401 }
1402 });
1403 }
1404
1405 fn truncate(
1406 &self,
1407 session_id: &agent_client_protocol::SessionId,
1408 cx: &App,
1409 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1410 self.0.read_with(cx, |agent, _cx| {
1411 agent.sessions.get(session_id).map(|session| {
1412 Rc::new(NativeAgentSessionTruncate {
1413 thread: session.thread.clone(),
1414 acp_thread: session.acp_thread.downgrade(),
1415 }) as _
1416 })
1417 })
1418 }
1419
1420 fn set_title(
1421 &self,
1422 session_id: &acp::SessionId,
1423 cx: &App,
1424 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1425 self.0.read_with(cx, |agent, _cx| {
1426 agent
1427 .sessions
1428 .get(session_id)
1429 .filter(|s| !s.thread.read(cx).is_subagent())
1430 .map(|session| {
1431 Rc::new(NativeAgentSessionSetTitle {
1432 thread: session.thread.clone(),
1433 }) as _
1434 })
1435 })
1436 }
1437
1438 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1439 let thread_store = self.0.read(cx).thread_store.clone();
1440 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1441 }
1442
1443 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1444 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1445 }
1446
1447 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1448 self
1449 }
1450}
1451
1452impl acp_thread::AgentTelemetry for NativeAgentConnection {
1453 fn thread_data(
1454 &self,
1455 session_id: &acp::SessionId,
1456 cx: &mut App,
1457 ) -> Task<Result<serde_json::Value>> {
1458 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1459 return Task::ready(Err(anyhow!("Session not found")));
1460 };
1461
1462 let task = session.thread.read(cx).to_db(cx);
1463 cx.background_spawn(async move {
1464 serde_json::to_value(task.await).context("Failed to serialize thread")
1465 })
1466 }
1467}
1468
1469pub struct NativeAgentSessionList {
1470 thread_store: Entity<ThreadStore>,
1471 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1472 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1473 _subscription: Subscription,
1474}
1475
1476impl NativeAgentSessionList {
1477 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1478 let (tx, rx) = smol::channel::unbounded();
1479 let this_tx = tx.clone();
1480 let subscription = cx.observe(&thread_store, move |_, _| {
1481 this_tx
1482 .try_send(acp_thread::SessionListUpdate::Refresh)
1483 .ok();
1484 });
1485 Self {
1486 thread_store,
1487 updates_tx: tx,
1488 updates_rx: rx,
1489 _subscription: subscription,
1490 }
1491 }
1492
1493 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1494 AgentSessionInfo {
1495 session_id: entry.id,
1496 cwd: None,
1497 title: Some(entry.title),
1498 updated_at: Some(entry.updated_at),
1499 meta: None,
1500 }
1501 }
1502
1503 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1504 &self.thread_store
1505 }
1506}
1507
1508impl AgentSessionList for NativeAgentSessionList {
1509 fn list_sessions(
1510 &self,
1511 _request: AgentSessionListRequest,
1512 cx: &mut App,
1513 ) -> Task<Result<AgentSessionListResponse>> {
1514 let sessions = self
1515 .thread_store
1516 .read(cx)
1517 .entries()
1518 .map(Self::to_session_info)
1519 .collect();
1520 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1521 }
1522
1523 fn supports_delete(&self) -> bool {
1524 true
1525 }
1526
1527 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1528 self.thread_store
1529 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1530 }
1531
1532 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1533 self.thread_store
1534 .update(cx, |store, cx| store.delete_threads(cx))
1535 }
1536
1537 fn watch(
1538 &self,
1539 _cx: &mut App,
1540 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1541 Some(self.updates_rx.clone())
1542 }
1543
1544 fn notify_refresh(&self) {
1545 self.updates_tx
1546 .try_send(acp_thread::SessionListUpdate::Refresh)
1547 .ok();
1548 }
1549
1550 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1551 self
1552 }
1553}
1554
1555struct NativeAgentSessionTruncate {
1556 thread: Entity<Thread>,
1557 acp_thread: WeakEntity<AcpThread>,
1558}
1559
1560impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1561 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1562 match self.thread.update(cx, |thread, cx| {
1563 thread.truncate(message_id.clone(), cx)?;
1564 Ok(thread.latest_token_usage())
1565 }) {
1566 Ok(usage) => {
1567 self.acp_thread
1568 .update(cx, |thread, cx| {
1569 thread.update_token_usage(usage, cx);
1570 })
1571 .ok();
1572 Task::ready(Ok(()))
1573 }
1574 Err(error) => Task::ready(Err(error)),
1575 }
1576 }
1577}
1578
1579struct NativeAgentSessionRetry {
1580 connection: NativeAgentConnection,
1581 session_id: acp::SessionId,
1582}
1583
1584impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1585 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1586 self.connection
1587 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1588 thread.update(cx, |thread, cx| thread.resume(cx))
1589 })
1590 }
1591}
1592
1593struct NativeAgentSessionSetTitle {
1594 thread: Entity<Thread>,
1595}
1596
1597impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1598 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1599 self.thread
1600 .update(cx, |thread, cx| thread.set_title(title, cx));
1601 Task::ready(Ok(()))
1602 }
1603}
1604
1605pub struct NativeThreadEnvironment {
1606 agent: WeakEntity<NativeAgent>,
1607 thread: WeakEntity<Thread>,
1608 acp_thread: WeakEntity<AcpThread>,
1609}
1610
1611impl NativeThreadEnvironment {
1612 pub(crate) fn create_subagent_thread(
1613 &self,
1614 label: String,
1615 cx: &mut App,
1616 ) -> Result<Rc<dyn SubagentHandle>> {
1617 let Some(parent_thread_entity) = self.thread.upgrade() else {
1618 anyhow::bail!("Parent thread no longer exists".to_string());
1619 };
1620 let parent_thread = parent_thread_entity.read(cx);
1621 let current_depth = parent_thread.depth();
1622
1623 if current_depth >= MAX_SUBAGENT_DEPTH {
1624 return Err(anyhow!(
1625 "Maximum subagent depth ({}) reached",
1626 MAX_SUBAGENT_DEPTH
1627 ));
1628 }
1629
1630 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1631 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1632 thread.set_title(label.into(), cx);
1633 thread
1634 });
1635
1636 let session_id = subagent_thread.read(cx).id().clone();
1637
1638 let acp_thread = self.agent.update(cx, |agent, cx| {
1639 agent.register_session(subagent_thread.clone(), cx)
1640 })?;
1641
1642 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1643 }
1644
1645 pub(crate) fn resume_subagent_thread(
1646 &self,
1647 session_id: acp::SessionId,
1648 cx: &mut App,
1649 ) -> Result<Rc<dyn SubagentHandle>> {
1650 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1651 let session = agent
1652 .sessions
1653 .get(&session_id)
1654 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1655 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1656 })??;
1657
1658 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1659 }
1660
1661 fn prompt_subagent(
1662 &self,
1663 session_id: acp::SessionId,
1664 subagent_thread: Entity<Thread>,
1665 acp_thread: Entity<acp_thread::AcpThread>,
1666 ) -> Result<Rc<dyn SubagentHandle>> {
1667 let Some(parent_thread_entity) = self.thread.upgrade() else {
1668 anyhow::bail!("Parent thread no longer exists".to_string());
1669 };
1670 Ok(Rc::new(NativeSubagentHandle::new(
1671 session_id,
1672 subagent_thread,
1673 acp_thread,
1674 parent_thread_entity,
1675 )) as _)
1676 }
1677}
1678
1679impl ThreadEnvironment for NativeThreadEnvironment {
1680 fn create_terminal(
1681 &self,
1682 command: String,
1683 cwd: Option<PathBuf>,
1684 output_byte_limit: Option<u64>,
1685 cx: &mut AsyncApp,
1686 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1687 let task = self.acp_thread.update(cx, |thread, cx| {
1688 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1689 });
1690
1691 let acp_thread = self.acp_thread.clone();
1692 cx.spawn(async move |cx| {
1693 let terminal = task?.await?;
1694
1695 let (drop_tx, drop_rx) = oneshot::channel();
1696 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1697
1698 cx.spawn(async move |cx| {
1699 drop_rx.await.ok();
1700 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1701 })
1702 .detach();
1703
1704 let handle = AcpTerminalHandle {
1705 terminal,
1706 _drop_tx: Some(drop_tx),
1707 };
1708
1709 Ok(Rc::new(handle) as _)
1710 })
1711 }
1712
1713 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1714 self.create_subagent_thread(label, cx)
1715 }
1716
1717 fn resume_subagent(
1718 &self,
1719 session_id: acp::SessionId,
1720 cx: &mut App,
1721 ) -> Result<Rc<dyn SubagentHandle>> {
1722 self.resume_subagent_thread(session_id, cx)
1723 }
1724}
1725
1726#[derive(Debug, Clone)]
1727enum SubagentPromptResult {
1728 Completed,
1729 Cancelled,
1730 ContextWindowWarning,
1731 Error(String),
1732}
1733
1734pub struct NativeSubagentHandle {
1735 session_id: acp::SessionId,
1736 parent_thread: WeakEntity<Thread>,
1737 subagent_thread: Entity<Thread>,
1738 acp_thread: Entity<acp_thread::AcpThread>,
1739}
1740
1741impl NativeSubagentHandle {
1742 fn new(
1743 session_id: acp::SessionId,
1744 subagent_thread: Entity<Thread>,
1745 acp_thread: Entity<acp_thread::AcpThread>,
1746 parent_thread_entity: Entity<Thread>,
1747 ) -> Self {
1748 NativeSubagentHandle {
1749 session_id,
1750 subagent_thread,
1751 parent_thread: parent_thread_entity.downgrade(),
1752 acp_thread,
1753 }
1754 }
1755}
1756
1757impl SubagentHandle for NativeSubagentHandle {
1758 fn id(&self) -> acp::SessionId {
1759 self.session_id.clone()
1760 }
1761
1762 fn num_entries(&self, cx: &App) -> usize {
1763 self.acp_thread.read(cx).entries().len()
1764 }
1765
1766 fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1767 let thread = self.subagent_thread.clone();
1768 let acp_thread = self.acp_thread.clone();
1769 let subagent_session_id = self.session_id.clone();
1770 let parent_thread = self.parent_thread.clone();
1771
1772 cx.spawn(async move |cx| {
1773 let (task, _subscription) = cx.update(|cx| {
1774 let ratio_before_prompt = thread
1775 .read(cx)
1776 .latest_token_usage()
1777 .map(|usage| usage.ratio());
1778
1779 parent_thread
1780 .update(cx, |parent_thread, _cx| {
1781 parent_thread.register_running_subagent(thread.downgrade())
1782 })
1783 .ok();
1784
1785 let task = acp_thread.update(cx, |acp_thread, cx| {
1786 acp_thread.send(vec![message.into()], cx)
1787 });
1788
1789 let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1790 let mut token_limit_tx = Some(token_limit_tx);
1791
1792 let subscription = cx.subscribe(
1793 &thread,
1794 move |_thread, event: &TokenUsageUpdated, _cx| {
1795 if let Some(usage) = &event.0 {
1796 let old_ratio = ratio_before_prompt
1797 .clone()
1798 .unwrap_or(TokenUsageRatio::Normal);
1799 let new_ratio = usage.ratio();
1800 if old_ratio == TokenUsageRatio::Normal
1801 && new_ratio == TokenUsageRatio::Warning
1802 {
1803 if let Some(tx) = token_limit_tx.take() {
1804 tx.send(()).ok();
1805 }
1806 }
1807 }
1808 },
1809 );
1810
1811 let wait_for_prompt = cx
1812 .background_spawn(async move {
1813 futures::select! {
1814 response = task.fuse() => match response {
1815 Ok(Some(response)) => {
1816 match response.stop_reason {
1817 acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1818 acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1819 acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1820 acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1821 acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1822 }
1823 }
1824 Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1825 Err(error) => SubagentPromptResult::Error(error.to_string()),
1826 },
1827 _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1828 }
1829 });
1830
1831 (wait_for_prompt, subscription)
1832 });
1833
1834 let result = match task.await {
1835 SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1836 thread
1837 .last_message()
1838 .and_then(|message| {
1839 let content = message.as_agent_message()?
1840 .content
1841 .iter()
1842 .filter_map(|c| match c {
1843 AgentMessageContent::Text(text) => Some(text.as_str()),
1844 _ => None,
1845 })
1846 .join("\n\n");
1847 if content.is_empty() {
1848 None
1849 } else {
1850 Some( content)
1851 }
1852 })
1853 .context("No response from subagent")
1854 }),
1855 SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
1856 SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1857 SubagentPromptResult::ContextWindowWarning => {
1858 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1859 Err(anyhow!(
1860 "The agent is nearing the end of its context window and has been \
1861 stopped. You can prompt the thread again to have the agent wrap up \
1862 or hand off its work."
1863 ))
1864 }
1865 };
1866
1867 parent_thread
1868 .update(cx, |parent_thread, cx| {
1869 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1870 })
1871 .ok();
1872
1873 result
1874 })
1875 }
1876}
1877
1878pub struct AcpTerminalHandle {
1879 terminal: Entity<acp_thread::Terminal>,
1880 _drop_tx: Option<oneshot::Sender<()>>,
1881}
1882
1883impl TerminalHandle for AcpTerminalHandle {
1884 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1885 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1886 }
1887
1888 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1889 Ok(self
1890 .terminal
1891 .read_with(cx, |term, _cx| term.wait_for_exit()))
1892 }
1893
1894 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1895 Ok(self
1896 .terminal
1897 .read_with(cx, |term, cx| term.current_output(cx)))
1898 }
1899
1900 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1901 cx.update(|cx| {
1902 self.terminal.update(cx, |terminal, cx| {
1903 terminal.kill(cx);
1904 });
1905 });
1906 Ok(())
1907 }
1908
1909 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1910 Ok(self
1911 .terminal
1912 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1913 }
1914}
1915
1916#[cfg(test)]
1917mod internal_tests {
1918 use super::*;
1919 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1920 use fs::FakeFs;
1921 use gpui::TestAppContext;
1922 use indoc::formatdoc;
1923 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1924 use language_model::{
1925 LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
1926 };
1927 use serde_json::json;
1928 use settings::SettingsStore;
1929 use util::{path, rel_path::rel_path};
1930
1931 #[gpui::test]
1932 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1933 init_test(cx);
1934 let fs = FakeFs::new(cx.executor());
1935 fs.insert_tree(
1936 "/",
1937 json!({
1938 "a": {}
1939 }),
1940 )
1941 .await;
1942 let project = Project::test(fs.clone(), [], cx).await;
1943 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1944 let agent = NativeAgent::new(
1945 project.clone(),
1946 thread_store,
1947 Templates::new(),
1948 None,
1949 fs.clone(),
1950 &mut cx.to_async(),
1951 )
1952 .await
1953 .unwrap();
1954 agent.read_with(cx, |agent, cx| {
1955 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1956 });
1957
1958 let worktree = project
1959 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1960 .await
1961 .unwrap();
1962 cx.run_until_parked();
1963 agent.read_with(cx, |agent, cx| {
1964 assert_eq!(
1965 agent.project_context.read(cx).worktrees,
1966 vec![WorktreeContext {
1967 root_name: "a".into(),
1968 abs_path: Path::new("/a").into(),
1969 rules_file: None
1970 }]
1971 )
1972 });
1973
1974 // Creating `/a/.rules` updates the project context.
1975 fs.insert_file("/a/.rules", Vec::new()).await;
1976 cx.run_until_parked();
1977 agent.read_with(cx, |agent, cx| {
1978 let rules_entry = worktree
1979 .read(cx)
1980 .entry_for_path(rel_path(".rules"))
1981 .unwrap();
1982 assert_eq!(
1983 agent.project_context.read(cx).worktrees,
1984 vec![WorktreeContext {
1985 root_name: "a".into(),
1986 abs_path: Path::new("/a").into(),
1987 rules_file: Some(RulesFileContext {
1988 path_in_worktree: rel_path(".rules").into(),
1989 text: "".into(),
1990 project_entry_id: rules_entry.id.to_usize()
1991 })
1992 }]
1993 )
1994 });
1995 }
1996
1997 #[gpui::test]
1998 async fn test_listing_models(cx: &mut TestAppContext) {
1999 init_test(cx);
2000 let fs = FakeFs::new(cx.executor());
2001 fs.insert_tree("/", json!({ "a": {} })).await;
2002 let project = Project::test(fs.clone(), [], cx).await;
2003 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2004 let connection = NativeAgentConnection(
2005 NativeAgent::new(
2006 project.clone(),
2007 thread_store,
2008 Templates::new(),
2009 None,
2010 fs.clone(),
2011 &mut cx.to_async(),
2012 )
2013 .await
2014 .unwrap(),
2015 );
2016
2017 // Create a thread/session
2018 let acp_thread = cx
2019 .update(|cx| {
2020 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2021 })
2022 .await
2023 .unwrap();
2024
2025 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2026
2027 let models = cx
2028 .update(|cx| {
2029 connection
2030 .model_selector(&session_id)
2031 .unwrap()
2032 .list_models(cx)
2033 })
2034 .await
2035 .unwrap();
2036
2037 let acp_thread::AgentModelList::Grouped(models) = models else {
2038 panic!("Unexpected model group");
2039 };
2040 assert_eq!(
2041 models,
2042 IndexMap::from_iter([(
2043 AgentModelGroupName("Fake".into()),
2044 vec![AgentModelInfo {
2045 id: acp::ModelId::new("fake/fake"),
2046 name: "Fake".into(),
2047 description: None,
2048 icon: Some(acp_thread::AgentModelIcon::Named(
2049 ui::IconName::ZedAssistant
2050 )),
2051 is_latest: false,
2052 cost: None,
2053 }]
2054 )])
2055 );
2056 }
2057
2058 #[gpui::test]
2059 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2060 init_test(cx);
2061 let fs = FakeFs::new(cx.executor());
2062 fs.create_dir(paths::settings_file().parent().unwrap())
2063 .await
2064 .unwrap();
2065 fs.insert_file(
2066 paths::settings_file(),
2067 json!({
2068 "agent": {
2069 "default_model": {
2070 "provider": "foo",
2071 "model": "bar"
2072 }
2073 }
2074 })
2075 .to_string()
2076 .into_bytes(),
2077 )
2078 .await;
2079 let project = Project::test(fs.clone(), [], cx).await;
2080
2081 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2082
2083 // Create the agent and connection
2084 let agent = NativeAgent::new(
2085 project.clone(),
2086 thread_store,
2087 Templates::new(),
2088 None,
2089 fs.clone(),
2090 &mut cx.to_async(),
2091 )
2092 .await
2093 .unwrap();
2094 let connection = NativeAgentConnection(agent.clone());
2095
2096 // Create a thread/session
2097 let acp_thread = cx
2098 .update(|cx| {
2099 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2100 })
2101 .await
2102 .unwrap();
2103
2104 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2105
2106 // Select a model
2107 let selector = connection.model_selector(&session_id).unwrap();
2108 let model_id = acp::ModelId::new("fake/fake");
2109 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2110 .await
2111 .unwrap();
2112
2113 // Verify the thread has the selected model
2114 agent.read_with(cx, |agent, _| {
2115 let session = agent.sessions.get(&session_id).unwrap();
2116 session.thread.read_with(cx, |thread, _| {
2117 assert_eq!(thread.model().unwrap().id().0, "fake");
2118 });
2119 });
2120
2121 cx.run_until_parked();
2122
2123 // Verify settings file was updated
2124 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2125 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2126
2127 // Check that the agent settings contain the selected model
2128 assert_eq!(
2129 settings_json["agent"]["default_model"]["model"],
2130 json!("fake")
2131 );
2132 assert_eq!(
2133 settings_json["agent"]["default_model"]["provider"],
2134 json!("fake")
2135 );
2136
2137 // Register a thinking model and select it.
2138 cx.update(|cx| {
2139 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2140 "fake-corp",
2141 "fake-thinking",
2142 "Fake Thinking",
2143 true,
2144 ));
2145 let thinking_provider = Arc::new(
2146 FakeLanguageModelProvider::new(
2147 LanguageModelProviderId::from("fake-corp".to_string()),
2148 LanguageModelProviderName::from("Fake Corp".to_string()),
2149 )
2150 .with_models(vec![thinking_model]),
2151 );
2152 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2153 registry.register_provider(thinking_provider, cx);
2154 });
2155 });
2156 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2157
2158 let selector = connection.model_selector(&session_id).unwrap();
2159 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2160 .await
2161 .unwrap();
2162 cx.run_until_parked();
2163
2164 // Verify enable_thinking was written to settings as true.
2165 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2166 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2167 assert_eq!(
2168 settings_json["agent"]["default_model"]["enable_thinking"],
2169 json!(true),
2170 "selecting a thinking model should persist enable_thinking: true to settings"
2171 );
2172 }
2173
2174 #[gpui::test]
2175 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2176 init_test(cx);
2177 let fs = FakeFs::new(cx.executor());
2178 fs.create_dir(paths::settings_file().parent().unwrap())
2179 .await
2180 .unwrap();
2181 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2182 let project = Project::test(fs.clone(), [], cx).await;
2183
2184 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2185 let agent = NativeAgent::new(
2186 project.clone(),
2187 thread_store,
2188 Templates::new(),
2189 None,
2190 fs.clone(),
2191 &mut cx.to_async(),
2192 )
2193 .await
2194 .unwrap();
2195 let connection = NativeAgentConnection(agent.clone());
2196
2197 let acp_thread = cx
2198 .update(|cx| {
2199 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2200 })
2201 .await
2202 .unwrap();
2203 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2204
2205 // Register a second provider with a thinking model.
2206 cx.update(|cx| {
2207 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2208 "fake-corp",
2209 "fake-thinking",
2210 "Fake Thinking",
2211 true,
2212 ));
2213 let thinking_provider = Arc::new(
2214 FakeLanguageModelProvider::new(
2215 LanguageModelProviderId::from("fake-corp".to_string()),
2216 LanguageModelProviderName::from("Fake Corp".to_string()),
2217 )
2218 .with_models(vec![thinking_model]),
2219 );
2220 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2221 registry.register_provider(thinking_provider, cx);
2222 });
2223 });
2224 // Refresh the agent's model list so it picks up the new provider.
2225 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2226
2227 // Thread starts with thinking_enabled = false (the default).
2228 agent.read_with(cx, |agent, _| {
2229 let session = agent.sessions.get(&session_id).unwrap();
2230 session.thread.read_with(cx, |thread, _| {
2231 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2232 });
2233 });
2234
2235 // Select the thinking model via select_model.
2236 let selector = connection.model_selector(&session_id).unwrap();
2237 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2238 .await
2239 .unwrap();
2240
2241 // select_model should have enabled thinking based on the model's supports_thinking().
2242 agent.read_with(cx, |agent, _| {
2243 let session = agent.sessions.get(&session_id).unwrap();
2244 session.thread.read_with(cx, |thread, _| {
2245 assert!(
2246 thread.thinking_enabled(),
2247 "select_model should enable thinking when model supports it"
2248 );
2249 });
2250 });
2251
2252 // Switch back to the non-thinking model.
2253 let selector = connection.model_selector(&session_id).unwrap();
2254 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2255 .await
2256 .unwrap();
2257
2258 // select_model should have disabled thinking.
2259 agent.read_with(cx, |agent, _| {
2260 let session = agent.sessions.get(&session_id).unwrap();
2261 session.thread.read_with(cx, |thread, _| {
2262 assert!(
2263 !thread.thinking_enabled(),
2264 "select_model should disable thinking when model does not support it"
2265 );
2266 });
2267 });
2268 }
2269
// Checks that a thread's `thinking_enabled` flag is persisted. The test selects a
// thinking-capable model, sends one message so the thread is saved, and closes the
// session. Reopening the thread via `NativeAgent::open_thread` must still report
// thinking as enabled.
#[gpui::test]
async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/", json!({ "a": {} })).await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = NativeAgent::new(
        project.clone(),
        thread_store.clone(),
        Templates::new(),
        None,
        fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    // Register a thinking model.
    // `thinking_model` is kept (cloned into the provider) so the test can drive its
    // completion stream below.
    let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
        "fake-corp",
        "fake-thinking",
        "Fake Thinking",
        true,
    ));
    let thinking_provider = Arc::new(
        FakeLanguageModelProvider::new(
            LanguageModelProviderId::from("fake-corp".to_string()),
            LanguageModelProviderName::from("Fake Corp".to_string()),
        )
        .with_models(vec![thinking_model.clone()]),
    );
    cx.update(|cx| {
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.register_provider(thinking_provider, cx);
        });
    });
    // Make the agent's model list pick up the newly registered provider.
    agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

    // Create a thread and select the thinking model.
    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), Path::new("/a"), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // The selector id has the form "<provider id>/<model id>".
    let selector = connection.model_selector(&session_id).unwrap();
    cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
        .await
        .unwrap();

    // Verify thinking is enabled after selecting the thinking model.
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    thread.read_with(cx, |thread, _| {
        assert!(
            thread.thinking_enabled(),
            "thinking should be enabled after selecting thinking model"
        );
    });

    // Send a message so the thread gets persisted.
    // The send future is spawned rather than awaited directly: it cannot finish until
    // the fake model's completion stream below has been fed and ended.
    let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
    let send = cx.foreground_executor().spawn(send);
    cx.run_until_parked();

    thinking_model.send_last_completion_stream_text_chunk("Response.");
    thinking_model.end_last_completion_stream();

    send.await.unwrap();
    cx.run_until_parked();

    // Close the session so it can be reloaded from disk.
    // The local entity handles are dropped too, and the assertion confirms no session
    // remains, so the reopen below cannot reuse a live session.
    cx.update(|cx| connection.clone().close_session(&session_id, cx))
        .await
        .unwrap();
    drop(thread);
    drop(acp_thread);
    agent.read_with(cx, |agent, _| {
        assert!(agent.sessions.is_empty());
    });

    // Reload the thread and verify thinking_enabled is still true.
    let reloaded_acp_thread = agent
        .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
        .await
        .unwrap();
    let reloaded_thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    reloaded_thread.read_with(cx, |thread, _| {
        assert!(
            thread.thinking_enabled(),
            "thinking_enabled should be preserved when reloading a thread with a thinking model"
        );
    });

    drop(reloaded_acp_thread);
}
2375
// Checks that the model selected for a thread survives a save/close/reopen round trip.
// The model's id ("custom-model-id") deliberately differs from its display name. The
// post-reload assertion compares `id()`, so the test fails if the reloaded thread ends
// up with a different (e.g. default) model.
#[gpui::test]
async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/", json!({ "a": {} })).await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = NativeAgent::new(
        project.clone(),
        thread_store.clone(),
        Templates::new(),
        None,
        fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    // Register a model where id() != name(), like real Anthropic models
    // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
    // The final `false` argument marks this model as not supporting thinking.
    let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
        "fake-corp",
        "custom-model-id",
        "Custom Model Display Name",
        false,
    ));
    let provider = Arc::new(
        FakeLanguageModelProvider::new(
            LanguageModelProviderId::from("fake-corp".to_string()),
            LanguageModelProviderName::from("Fake Corp".to_string()),
        )
        .with_models(vec![model.clone()]),
    );
    cx.update(|cx| {
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.register_provider(provider, cx);
        });
    });
    // Make the agent's model list pick up the newly registered provider.
    agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

    // Create a thread and select the model.
    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), Path::new("/a"), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Selection uses "<provider id>/<model id>", not the display name.
    let selector = connection.model_selector(&session_id).unwrap();
    cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
        .await
        .unwrap();

    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.model().unwrap().id().0.as_ref(),
            "custom-model-id",
            "model should be set before persisting"
        );
    });

    // Send a message so the thread gets persisted.
    // The send future is spawned so the test can feed and end the fake model's
    // completion stream while the send is still in flight.
    let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
    let send = cx.foreground_executor().spawn(send);
    cx.run_until_parked();

    model.send_last_completion_stream_text_chunk("Response.");
    model.end_last_completion_stream();

    send.await.unwrap();
    cx.run_until_parked();

    // Close the session so it can be reloaded from disk.
    // After dropping the local handles, the session map must be empty before reopening.
    cx.update(|cx| connection.clone().close_session(&session_id, cx))
        .await
        .unwrap();
    drop(thread);
    drop(acp_thread);
    agent.read_with(cx, |agent, _| {
        assert!(agent.sessions.is_empty());
    });

    // Reload the thread and verify the model was preserved.
    let reloaded_acp_thread = agent
        .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
        .await
        .unwrap();
    let reloaded_thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    reloaded_thread.read_with(cx, |thread, _| {
        let reloaded_model = thread
            .model()
            .expect("model should be present after reload");
        assert_eq!(
            reloaded_model.id().0.as_ref(),
            "custom-model-id",
            "reloaded thread should have the same model, not fall back to the default"
        );
    });

    drop(reloaded_acp_thread);
}
2486
// End-to-end persistence test for a native thread:
// 1. An empty thread is not saved, even after its model and summarization model change.
// 2. After one exchange, the store lists the thread under the summary model's title.
//    The exchange includes a file mention link and a token-usage update.
// 3. A draft prompt and a UI scroll position are set and the session is closed.
//    Reopening must restore the transcript, draft prompt, token usage, and scroll position.
#[gpui::test]
async fn test_save_load_thread(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/",
        json!({
            "a": {
                "b.md": "Lorem"
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = NativeAgent::new(
        project.clone(),
        thread_store.clone(),
        Templates::new(),
        None,
        fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), Path::new(""), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });

    // Ensure empty threads are not saved, even if they get mutated.
    // `model` answers the user's message; `summary_model` produces the title that
    // `thread_entries` reports after the thread is saved.
    let model = Arc::new(FakeLanguageModel::default());
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_model(model.clone(), cx);
        thread.set_summarization_model(Some(summary_model.clone()), cx);
    });
    cx.run_until_parked();
    assert_eq!(thread_entries(&thread_store, cx), vec![]);

    // The prompt mixes plain text with a resource link to /a/b.md, which renders as
    // `[@b.md](<uri>)` in the markdown assertions below.
    let send = acp_thread.update(cx, |thread, cx| {
        thread.send(
            vec![
                "What does ".into(),
                acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                    "b.md",
                    MentionUri::File {
                        abs_path: path!("/a/b.md").into(),
                    }
                    .to_uri()
                    .to_string(),
                )),
                " mean?".into(),
            ],
            cx,
        )
    });
    // Spawned so the fake completion streams can be driven while the send is pending.
    let send = cx.foreground_executor().spawn(send);
    cx.run_until_parked();

    // The usage numbers sent here are checked again after the reload.
    model.send_last_completion_stream_text_chunk("Lorem.");
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 150,
            output_tokens: 75,
            ..Default::default()
        },
    ));
    model.end_last_completion_stream();
    cx.run_until_parked();
    summary_model
        .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
    summary_model.end_last_completion_stream();

    send.await.unwrap();
    let uri = MentionUri::File {
        abs_path: path!("/a/b.md").into(),
    }
    .to_uri();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            formatdoc! {"
## User

What does [@b.md]({uri}) mean?

## Assistant

Lorem.

"}
        )
    });

    cx.run_until_parked();

    // Set a draft prompt with rich content blocks before saving.
    let draft_blocks = vec![
        acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
        acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
        acp::ContentBlock::Text(acp::TextContent::new(" please")),
    ];
    acp_thread.update(cx, |thread, _cx| {
        thread.set_draft_prompt(Some(draft_blocks.clone()));
    });
    thread.update(cx, |thread, _cx| {
        thread.set_ui_scroll_position(Some(gpui::ListOffset {
            item_ix: 5,
            offset_in_item: gpui::px(12.5),
        }));
    });
    // NOTE(review): this explicit notify presumably exists so the thread is re-saved with
    // the draft prompt and scroll position set above. Confirm that those setters do not
    // notify on their own.
    thread.update(cx, |_thread, cx| cx.notify());
    cx.run_until_parked();

    // Close the session so it can be reloaded from disk.
    cx.update(|cx| connection.clone().close_session(&session_id, cx))
        .await
        .unwrap();
    drop(thread);
    drop(acp_thread);
    agent.read_with(cx, |agent, _| {
        assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
    });

    // Ensure the thread can be reloaded from disk.
    // The stored entry's title is the text streamed from `summary_model` above.
    assert_eq!(
        thread_entries(&thread_store, cx),
        vec![(
            session_id.clone(),
            format!("Explaining {}", path!("/a/b.md"))
        )]
    );
    let acp_thread = agent
        .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
        .await
        .unwrap();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            formatdoc! {"
## User

What does [@b.md]({uri}) mean?

## Assistant

Lorem.

"}
        )
    });

    // Ensure the draft prompt with rich content blocks survived the round-trip.
    acp_thread.read_with(cx, |thread, _| {
        assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
    });

    // Ensure token usage survived the round-trip.
    acp_thread.read_with(cx, |thread, _| {
        let usage = thread
            .token_usage()
            .expect("token usage should be restored after reload");
        assert_eq!(usage.input_tokens, 150);
        assert_eq!(usage.output_tokens, 75);
    });

    // Ensure scroll position survived the round-trip.
    acp_thread.read_with(cx, |thread, _| {
        let scroll = thread
            .ui_scroll_position()
            .expect("scroll position should be restored after reload");
        assert_eq!(scroll.item_ix, 5);
        assert_eq!(scroll.offset_in_item, gpui::px(12.5));
    });
}
2673
2674 fn thread_entries(
2675 thread_store: &Entity<ThreadStore>,
2676 cx: &mut TestAppContext,
2677 ) -> Vec<(acp::SessionId, String)> {
2678 thread_store.read_with(cx, |store, _| {
2679 store
2680 .entries()
2681 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2682 .collect::<Vec<_>>()
2683 })
2684 }
2685
2686 fn init_test(cx: &mut TestAppContext) {
2687 env_logger::try_init().ok();
2688 cx.update(|cx| {
2689 let settings_store = SettingsStore::test(cx);
2690 cx.set_global(settings_store);
2691
2692 LanguageModelRegistry::test(cx);
2693 });
2694 }
2695}
2696
2697fn mcp_message_content_to_acp_content_block(
2698 content: context_server::types::MessageContent,
2699) -> acp::ContentBlock {
2700 match content {
2701 context_server::types::MessageContent::Text {
2702 text,
2703 annotations: _,
2704 } => text.into(),
2705 context_server::types::MessageContent::Image {
2706 data,
2707 mime_type,
2708 annotations: _,
2709 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2710 context_server::types::MessageContent::Audio {
2711 data,
2712 mime_type,
2713 annotations: _,
2714 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2715 context_server::types::MessageContent::Resource {
2716 resource,
2717 annotations: _,
2718 } => {
2719 let mut link =
2720 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2721 if let Some(mime_type) = resource.mime_type {
2722 link = link.mime_type(mime_type);
2723 }
2724 acp::ContentBlock::ResourceLink(link)
2725 }
2726 }
2727}