1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use util::ResultExt;
54use util::rel_path::RelPath;
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct ProjectSnapshot {
58 pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
59 pub timestamp: DateTime<Utc>,
60}
61
62pub struct RulesLoadingError {
63 pub message: SharedString,
64}
65
66/// Holds both the internal Thread and the AcpThread for a session
67struct Session {
68 /// The internal thread that processes messages
69 thread: Entity<Thread>,
70 /// The ACP thread that handles protocol communication
71 acp_thread: Entity<acp_thread::AcpThread>,
72 pending_save: Task<()>,
73 _subscriptions: Vec<Subscription>,
74}
75
76pub struct LanguageModels {
77 /// Access language model by ID
78 models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
79 /// Cached list for returning language model information
80 model_list: acp_thread::AgentModelList,
81 refresh_models_rx: watch::Receiver<()>,
82 refresh_models_tx: watch::Sender<()>,
83 _authenticate_all_providers_task: Task<()>,
84}
85
86impl LanguageModels {
87 fn new(cx: &mut App) -> Self {
88 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
89
90 let mut this = Self {
91 models: HashMap::default(),
92 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
93 refresh_models_rx,
94 refresh_models_tx,
95 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
96 };
97 this.refresh_list(cx);
98 this
99 }
100
101 fn refresh_list(&mut self, cx: &App) {
102 let providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .visible_providers()
105 .into_iter()
106 .filter(|provider| provider.is_authenticated(cx))
107 .collect::<Vec<_>>();
108
109 let mut language_model_list = IndexMap::default();
110 let mut recommended_models = HashSet::default();
111
112 let mut recommended = Vec::new();
113 for provider in &providers {
114 for model in provider.recommended_models(cx) {
115 recommended_models.insert((model.provider_id(), model.id()));
116 recommended.push(Self::map_language_model_to_info(&model, provider));
117 }
118 }
119 if !recommended.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName("Recommended".into()),
122 recommended,
123 );
124 }
125
126 let mut models = HashMap::default();
127 for provider in providers {
128 let mut provider_models = Vec::new();
129 for model in provider.provided_models(cx) {
130 let model_info = Self::map_language_model_to_info(&model, &provider);
131 let model_id = model_info.id.clone();
132 provider_models.push(model_info);
133 models.insert(model_id, model);
134 }
135 if !provider_models.is_empty() {
136 language_model_list.insert(
137 acp_thread::AgentModelGroupName(provider.name().0.clone()),
138 provider_models,
139 );
140 }
141 }
142
143 self.models = models;
144 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
145 self.refresh_models_tx.send(()).ok();
146 }
147
148 fn watch(&self) -> watch::Receiver<()> {
149 self.refresh_models_rx.clone()
150 }
151
152 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
153 self.models.get(model_id).cloned()
154 }
155
156 fn map_language_model_to_info(
157 model: &Arc<dyn LanguageModel>,
158 provider: &Arc<dyn LanguageModelProvider>,
159 ) -> acp_thread::AgentModelInfo {
160 acp_thread::AgentModelInfo {
161 id: Self::model_id(model),
162 name: model.name().0,
163 description: None,
164 icon: Some(match provider.icon() {
165 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
166 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
167 }),
168 is_latest: model.is_latest(),
169 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
170 }
171 }
172
173 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
174 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
175 }
176
177 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
178 let authenticate_all_providers = LanguageModelRegistry::global(cx)
179 .read(cx)
180 .visible_providers()
181 .iter()
182 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
183 .collect::<Vec<_>>();
184
185 cx.background_spawn(async move {
186 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
187 if let Err(err) = authenticate_task.await {
188 match err {
189 language_model::AuthenticateError::CredentialsNotFound => {
190 // Since we're authenticating these providers in the
191 // background for the purposes of populating the
192 // language selector, we don't care about providers
193 // where the credentials are not found.
194 }
195 language_model::AuthenticateError::ConnectionRefused => {
196 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
197 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
198 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
199 }
200 _ => {
201 // Some providers have noisy failure states that we
202 // don't want to spam the logs with every time the
203 // language model selector is initialized.
204 //
205 // Ideally these should have more clear failure modes
206 // that we know are safe to ignore here, like what we do
207 // with `CredentialsNotFound` above.
208 match provider_id.0.as_ref() {
209 "lmstudio" | "ollama" => {
210 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
211 //
212 // These fail noisily, so we don't log them.
213 }
214 "copilot_chat" => {
215 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
216 }
217 _ => {
218 log::error!(
219 "Failed to authenticate provider: {}: {err:#}",
220 provider_name.0
221 );
222 }
223 }
224 }
225 }
226 }
227 }
228 })
229 }
230}
231
232pub struct NativeAgent {
233 /// Session ID -> Session mapping
234 sessions: HashMap<acp::SessionId, Session>,
235 thread_store: Entity<ThreadStore>,
236 /// Shared project context for all threads
237 project_context: Entity<ProjectContext>,
238 project_context_needs_refresh: watch::Sender<()>,
239 _maintain_project_context: Task<Result<()>>,
240 context_server_registry: Entity<ContextServerRegistry>,
241 /// Shared templates for all threads
242 templates: Arc<Templates>,
243 /// Cached model information
244 models: LanguageModels,
245 project: Entity<Project>,
246 prompt_store: Option<Entity<PromptStore>>,
247 fs: Arc<dyn Fs>,
248 _subscriptions: Vec<Subscription>,
249}
250
251impl NativeAgent {
252 pub async fn new(
253 project: Entity<Project>,
254 thread_store: Entity<ThreadStore>,
255 templates: Arc<Templates>,
256 prompt_store: Option<Entity<PromptStore>>,
257 fs: Arc<dyn Fs>,
258 cx: &mut AsyncApp,
259 ) -> Result<Entity<NativeAgent>> {
260 log::debug!("Creating new NativeAgent");
261
262 let project_context = cx
263 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
264 .await;
265
266 Ok(cx.new(|cx| {
267 let context_server_store = project.read(cx).context_server_store();
268 let context_server_registry =
269 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
270
271 let mut subscriptions = vec![
272 cx.subscribe(&project, Self::handle_project_event),
273 cx.subscribe(
274 &LanguageModelRegistry::global(cx),
275 Self::handle_models_updated_event,
276 ),
277 cx.subscribe(
278 &context_server_store,
279 Self::handle_context_server_store_updated,
280 ),
281 cx.subscribe(
282 &context_server_registry,
283 Self::handle_context_server_registry_event,
284 ),
285 ];
286 if let Some(prompt_store) = prompt_store.as_ref() {
287 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
288 }
289
290 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
291 watch::channel(());
292 Self {
293 sessions: HashMap::default(),
294 thread_store,
295 project_context: cx.new(|_| project_context),
296 project_context_needs_refresh: project_context_needs_refresh_tx,
297 _maintain_project_context: cx.spawn(async move |this, cx| {
298 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
299 }),
300 context_server_registry,
301 templates,
302 models: LanguageModels::new(cx),
303 project,
304 prompt_store,
305 fs,
306 _subscriptions: subscriptions,
307 }
308 }))
309 }
310
311 fn new_session(
312 &mut self,
313 project: Entity<Project>,
314 cx: &mut Context<Self>,
315 ) -> Entity<AcpThread> {
316 // Create Thread
317 // Fetch default model from registry settings
318 let registry = LanguageModelRegistry::read_global(cx);
319 // Log available models for debugging
320 let available_count = registry.available_models(cx).count();
321 log::debug!("Total available models: {}", available_count);
322
323 let default_model = registry.default_model().and_then(|default_model| {
324 self.models
325 .model_from_id(&LanguageModels::model_id(&default_model.model))
326 });
327 let thread = cx.new(|cx| {
328 Thread::new(
329 project.clone(),
330 self.project_context.clone(),
331 self.context_server_registry.clone(),
332 self.templates.clone(),
333 default_model,
334 cx,
335 )
336 });
337
338 self.register_session(thread, cx)
339 }
340
341 fn register_session(
342 &mut self,
343 thread_handle: Entity<Thread>,
344 cx: &mut Context<Self>,
345 ) -> Entity<AcpThread> {
346 let connection = Rc::new(NativeAgentConnection(cx.entity()));
347
348 let thread = thread_handle.read(cx);
349 let session_id = thread.id().clone();
350 let parent_session_id = thread.parent_thread_id();
351 let title = thread.title();
352 let project = thread.project.clone();
353 let action_log = thread.action_log.clone();
354 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
355 let acp_thread = cx.new(|cx| {
356 acp_thread::AcpThread::new(
357 parent_session_id,
358 title,
359 connection,
360 project.clone(),
361 action_log.clone(),
362 session_id.clone(),
363 prompt_capabilities_rx,
364 cx,
365 )
366 });
367
368 let registry = LanguageModelRegistry::read_global(cx);
369 let summarization_model = registry.thread_summary_model().map(|c| c.model);
370
371 let weak = cx.weak_entity();
372 thread_handle.update(cx, |thread, cx| {
373 thread.set_summarization_model(summarization_model, cx);
374 thread.add_default_tools(
375 Rc::new(NativeThreadEnvironment {
376 acp_thread: acp_thread.downgrade(),
377 agent: weak,
378 }) as _,
379 cx,
380 )
381 });
382
383 let subscriptions = vec![
384 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
385 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
386 cx.observe(&thread_handle, move |this, thread, cx| {
387 this.save_thread(thread, cx)
388 }),
389 ];
390
391 self.sessions.insert(
392 session_id,
393 Session {
394 thread: thread_handle,
395 acp_thread: acp_thread.clone(),
396 _subscriptions: subscriptions,
397 pending_save: Task::ready(()),
398 },
399 );
400
401 self.update_available_commands(cx);
402
403 acp_thread
404 }
405
406 pub fn models(&self) -> &LanguageModels {
407 &self.models
408 }
409
410 async fn maintain_project_context(
411 this: WeakEntity<Self>,
412 mut needs_refresh: watch::Receiver<()>,
413 cx: &mut AsyncApp,
414 ) -> Result<()> {
415 while needs_refresh.changed().await.is_ok() {
416 let project_context = this
417 .update(cx, |this, cx| {
418 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
419 })?
420 .await;
421 this.update(cx, |this, cx| {
422 this.project_context = cx.new(|_| project_context);
423 })?;
424 }
425
426 Ok(())
427 }
428
429 fn build_project_context(
430 project: &Entity<Project>,
431 prompt_store: Option<&Entity<PromptStore>>,
432 cx: &mut App,
433 ) -> Task<ProjectContext> {
434 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
435 let worktree_tasks = worktrees
436 .into_iter()
437 .map(|worktree| {
438 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
439 })
440 .collect::<Vec<_>>();
441 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
442 prompt_store.read_with(cx, |prompt_store, cx| {
443 let prompts = prompt_store.default_prompt_metadata();
444 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
445 let contents = prompt_store.load(prompt_metadata.id, cx);
446 async move { (contents.await, prompt_metadata) }
447 });
448 cx.background_spawn(future::join_all(load_tasks))
449 })
450 } else {
451 Task::ready(vec![])
452 };
453
454 cx.spawn(async move |_cx| {
455 let (worktrees, default_user_rules) =
456 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
457
458 let worktrees = worktrees
459 .into_iter()
460 .map(|(worktree, _rules_error)| {
461 // TODO: show error message
462 // if let Some(rules_error) = rules_error {
463 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
464 // }
465 worktree
466 })
467 .collect::<Vec<_>>();
468
469 let default_user_rules = default_user_rules
470 .into_iter()
471 .flat_map(|(contents, prompt_metadata)| match contents {
472 Ok(contents) => Some(UserRulesContext {
473 uuid: prompt_metadata.id.as_user()?,
474 title: prompt_metadata.title.map(|title| title.to_string()),
475 contents,
476 }),
477 Err(_err) => {
478 // TODO: show error message
479 // this.update(cx, |_, cx| {
480 // cx.emit(RulesLoadingError {
481 // message: format!("{err:?}").into(),
482 // });
483 // })
484 // .ok();
485 None
486 }
487 })
488 .collect::<Vec<_>>();
489
490 ProjectContext::new(worktrees, default_user_rules)
491 })
492 }
493
494 fn load_worktree_info_for_system_prompt(
495 worktree: Entity<Worktree>,
496 project: Entity<Project>,
497 cx: &mut App,
498 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
499 let tree = worktree.read(cx);
500 let root_name = tree.root_name_str().into();
501 let abs_path = tree.abs_path();
502
503 let mut context = WorktreeContext {
504 root_name,
505 abs_path,
506 rules_file: None,
507 };
508
509 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
510 let Some(rules_task) = rules_task else {
511 return Task::ready((context, None));
512 };
513
514 cx.spawn(async move |_| {
515 let (rules_file, rules_file_error) = match rules_task.await {
516 Ok(rules_file) => (Some(rules_file), None),
517 Err(err) => (
518 None,
519 Some(RulesLoadingError {
520 message: format!("{err}").into(),
521 }),
522 ),
523 };
524 context.rules_file = rules_file;
525 (context, rules_file_error)
526 })
527 }
528
529 fn load_worktree_rules_file(
530 worktree: Entity<Worktree>,
531 project: Entity<Project>,
532 cx: &mut App,
533 ) -> Option<Task<Result<RulesFileContext>>> {
534 let worktree = worktree.read(cx);
535 let worktree_id = worktree.id();
536 let selected_rules_file = RULES_FILE_NAMES
537 .into_iter()
538 .filter_map(|name| {
539 worktree
540 .entry_for_path(RelPath::unix(name).unwrap())
541 .filter(|entry| entry.is_file())
542 .map(|entry| entry.path.clone())
543 })
544 .next();
545
546 // Note that Cline supports `.clinerules` being a directory, but that is not currently
547 // supported. This doesn't seem to occur often in GitHub repositories.
548 selected_rules_file.map(|path_in_worktree| {
549 let project_path = ProjectPath {
550 worktree_id,
551 path: path_in_worktree.clone(),
552 };
553 let buffer_task =
554 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
555 let rope_task = cx.spawn(async move |cx| {
556 let buffer = buffer_task.await?;
557 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
558 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
559 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
560 })?;
561 anyhow::Ok((project_entry_id, rope))
562 });
563 // Build a string from the rope on a background thread.
564 cx.background_spawn(async move {
565 let (project_entry_id, rope) = rope_task.await?;
566 anyhow::Ok(RulesFileContext {
567 path_in_worktree,
568 text: rope.to_string().trim().to_string(),
569 project_entry_id: project_entry_id.to_usize(),
570 })
571 })
572 })
573 }
574
575 fn handle_thread_title_updated(
576 &mut self,
577 thread: Entity<Thread>,
578 _: &TitleUpdated,
579 cx: &mut Context<Self>,
580 ) {
581 let session_id = thread.read(cx).id();
582 let Some(session) = self.sessions.get(session_id) else {
583 return;
584 };
585 let thread = thread.downgrade();
586 let acp_thread = session.acp_thread.downgrade();
587 cx.spawn(async move |_, cx| {
588 let title = thread.read_with(cx, |thread, _| thread.title())?;
589 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
590 task.await
591 })
592 .detach_and_log_err(cx);
593 }
594
595 fn handle_thread_token_usage_updated(
596 &mut self,
597 thread: Entity<Thread>,
598 usage: &TokenUsageUpdated,
599 cx: &mut Context<Self>,
600 ) {
601 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
602 return;
603 };
604 session.acp_thread.update(cx, |acp_thread, cx| {
605 acp_thread.update_token_usage(usage.0.clone(), cx);
606 });
607 }
608
609 fn handle_project_event(
610 &mut self,
611 _project: Entity<Project>,
612 event: &project::Event,
613 _cx: &mut Context<Self>,
614 ) {
615 match event {
616 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
617 self.project_context_needs_refresh.send(()).ok();
618 }
619 project::Event::WorktreeUpdatedEntries(_, items) => {
620 if items.iter().any(|(path, _, _)| {
621 RULES_FILE_NAMES
622 .iter()
623 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
624 }) {
625 self.project_context_needs_refresh.send(()).ok();
626 }
627 }
628 _ => {}
629 }
630 }
631
632 fn handle_prompts_updated_event(
633 &mut self,
634 _prompt_store: Entity<PromptStore>,
635 _event: &prompt_store::PromptsUpdatedEvent,
636 _cx: &mut Context<Self>,
637 ) {
638 self.project_context_needs_refresh.send(()).ok();
639 }
640
641 fn handle_models_updated_event(
642 &mut self,
643 _registry: Entity<LanguageModelRegistry>,
644 _event: &language_model::Event,
645 cx: &mut Context<Self>,
646 ) {
647 self.models.refresh_list(cx);
648
649 let registry = LanguageModelRegistry::read_global(cx);
650 let default_model = registry.default_model().map(|m| m.model);
651 let summarization_model = registry.thread_summary_model().map(|m| m.model);
652
653 for session in self.sessions.values_mut() {
654 session.thread.update(cx, |thread, cx| {
655 if thread.model().is_none()
656 && let Some(model) = default_model.clone()
657 {
658 thread.set_model(model, cx);
659 cx.notify();
660 }
661 thread.set_summarization_model(summarization_model.clone(), cx);
662 });
663 }
664 }
665
666 fn handle_context_server_store_updated(
667 &mut self,
668 _store: Entity<project::context_server_store::ContextServerStore>,
669 _event: &project::context_server_store::ServerStatusChangedEvent,
670 cx: &mut Context<Self>,
671 ) {
672 self.update_available_commands(cx);
673 }
674
675 fn handle_context_server_registry_event(
676 &mut self,
677 _registry: Entity<ContextServerRegistry>,
678 event: &ContextServerRegistryEvent,
679 cx: &mut Context<Self>,
680 ) {
681 match event {
682 ContextServerRegistryEvent::ToolsChanged => {}
683 ContextServerRegistryEvent::PromptsChanged => {
684 self.update_available_commands(cx);
685 }
686 }
687 }
688
689 fn update_available_commands(&self, cx: &mut Context<Self>) {
690 let available_commands = self.build_available_commands(cx);
691 for session in self.sessions.values() {
692 session.acp_thread.update(cx, |thread, cx| {
693 thread
694 .handle_session_update(
695 acp::SessionUpdate::AvailableCommandsUpdate(
696 acp::AvailableCommandsUpdate::new(available_commands.clone()),
697 ),
698 cx,
699 )
700 .log_err();
701 });
702 }
703 }
704
705 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
706 let registry = self.context_server_registry.read(cx);
707
708 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
709 for context_server_prompt in registry.prompts() {
710 *prompt_name_counts
711 .entry(context_server_prompt.prompt.name.as_str())
712 .or_insert(0) += 1;
713 }
714
715 registry
716 .prompts()
717 .flat_map(|context_server_prompt| {
718 let prompt = &context_server_prompt.prompt;
719
720 let should_prefix = prompt_name_counts
721 .get(prompt.name.as_str())
722 .copied()
723 .unwrap_or(0)
724 > 1;
725
726 let name = if should_prefix {
727 format!("{}.{}", context_server_prompt.server_id, prompt.name)
728 } else {
729 prompt.name.clone()
730 };
731
732 let mut command = acp::AvailableCommand::new(
733 name,
734 prompt.description.clone().unwrap_or_default(),
735 );
736
737 match prompt.arguments.as_deref() {
738 Some([arg]) => {
739 let hint = format!("<{}>", arg.name);
740
741 command = command.input(acp::AvailableCommandInput::Unstructured(
742 acp::UnstructuredCommandInput::new(hint),
743 ));
744 }
745 Some([]) | None => {}
746 Some(_) => {
747 // skip >1 argument commands since we don't support them yet
748 return None;
749 }
750 }
751
752 Some(command)
753 })
754 .collect()
755 }
756
757 pub fn load_thread(
758 &mut self,
759 id: acp::SessionId,
760 cx: &mut Context<Self>,
761 ) -> Task<Result<Entity<Thread>>> {
762 let database_future = ThreadsDatabase::connect(cx);
763 cx.spawn(async move |this, cx| {
764 let database = database_future.await.map_err(|err| anyhow!(err))?;
765 let db_thread = database
766 .load_thread(id.clone())
767 .await?
768 .with_context(|| format!("no thread found with ID: {id:?}"))?;
769
770 this.update(cx, |this, cx| {
771 let summarization_model = LanguageModelRegistry::read_global(cx)
772 .thread_summary_model()
773 .map(|c| c.model);
774
775 cx.new(|cx| {
776 let mut thread = Thread::from_db(
777 id.clone(),
778 db_thread,
779 this.project.clone(),
780 this.project_context.clone(),
781 this.context_server_registry.clone(),
782 this.templates.clone(),
783 cx,
784 );
785 thread.set_summarization_model(summarization_model, cx);
786 thread
787 })
788 })
789 })
790 }
791
792 pub fn open_thread(
793 &mut self,
794 id: acp::SessionId,
795 cx: &mut Context<Self>,
796 ) -> Task<Result<Entity<AcpThread>>> {
797 if let Some(session) = self.sessions.get(&id) {
798 return Task::ready(Ok(session.acp_thread.clone()));
799 }
800
801 let task = self.load_thread(id, cx);
802 cx.spawn(async move |this, cx| {
803 let thread = task.await?;
804 let acp_thread =
805 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
806 let events = thread.update(cx, |thread, cx| thread.replay(cx));
807 cx.update(|cx| {
808 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
809 })
810 .await?;
811 Ok(acp_thread)
812 })
813 }
814
815 pub fn thread_summary(
816 &mut self,
817 id: acp::SessionId,
818 cx: &mut Context<Self>,
819 ) -> Task<Result<SharedString>> {
820 let thread = self.open_thread(id.clone(), cx);
821 cx.spawn(async move |this, cx| {
822 let acp_thread = thread.await?;
823 let result = this
824 .update(cx, |this, cx| {
825 this.sessions
826 .get(&id)
827 .unwrap()
828 .thread
829 .update(cx, |thread, cx| thread.summary(cx))
830 })?
831 .await
832 .context("Failed to generate summary")?;
833 drop(acp_thread);
834 Ok(result)
835 })
836 }
837
838 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
839 if thread.read(cx).is_empty() {
840 return;
841 }
842
843 let database_future = ThreadsDatabase::connect(cx);
844 let (id, db_thread) =
845 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
846 let Some(session) = self.sessions.get_mut(&id) else {
847 return;
848 };
849 let thread_store = self.thread_store.clone();
850 session.pending_save = cx.spawn(async move |_, cx| {
851 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
852 return;
853 };
854 let db_thread = db_thread.await;
855 database.save_thread(id, db_thread).await.log_err();
856 thread_store.update(cx, |store, cx| store.reload(cx));
857 });
858 }
859
860 fn send_mcp_prompt(
861 &self,
862 message_id: UserMessageId,
863 session_id: agent_client_protocol::SessionId,
864 prompt_name: String,
865 server_id: ContextServerId,
866 arguments: HashMap<String, String>,
867 original_content: Vec<acp::ContentBlock>,
868 cx: &mut Context<Self>,
869 ) -> Task<Result<acp::PromptResponse>> {
870 let server_store = self.context_server_registry.read(cx).server_store().clone();
871 let path_style = self.project.read(cx).path_style(cx);
872
873 cx.spawn(async move |this, cx| {
874 let prompt =
875 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
876
877 let (acp_thread, thread) = this.update(cx, |this, _cx| {
878 let session = this
879 .sessions
880 .get(&session_id)
881 .context("Failed to get session")?;
882 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
883 })??;
884
885 let mut last_is_user = true;
886
887 thread.update(cx, |thread, cx| {
888 thread.push_acp_user_block(
889 message_id,
890 original_content.into_iter().skip(1),
891 path_style,
892 cx,
893 );
894 });
895
896 for message in prompt.messages {
897 let context_server::types::PromptMessage { role, content } = message;
898 let block = mcp_message_content_to_acp_content_block(content);
899
900 match role {
901 context_server::types::Role::User => {
902 let id = acp_thread::UserMessageId::new();
903
904 acp_thread.update(cx, |acp_thread, cx| {
905 acp_thread.push_user_content_block_with_indent(
906 Some(id.clone()),
907 block.clone(),
908 true,
909 cx,
910 );
911 });
912
913 thread.update(cx, |thread, cx| {
914 thread.push_acp_user_block(id, [block], path_style, cx);
915 });
916 }
917 context_server::types::Role::Assistant => {
918 acp_thread.update(cx, |acp_thread, cx| {
919 acp_thread.push_assistant_content_block_with_indent(
920 block.clone(),
921 false,
922 true,
923 cx,
924 );
925 });
926
927 thread.update(cx, |thread, cx| {
928 thread.push_acp_agent_block(block, cx);
929 });
930 }
931 }
932
933 last_is_user = role == context_server::types::Role::User;
934 }
935
936 let response_stream = thread.update(cx, |thread, cx| {
937 if last_is_user {
938 thread.send_existing(cx)
939 } else {
940 // Resume if MCP prompt did not end with a user message
941 thread.resume(cx)
942 }
943 })?;
944
945 cx.update(|cx| {
946 NativeAgentConnection::handle_thread_events(
947 response_stream,
948 acp_thread.downgrade(),
949 cx,
950 )
951 })
952 .await
953 })
954 }
955}
956
957/// Wrapper struct that implements the AgentConnection trait
958#[derive(Clone)]
959pub struct NativeAgentConnection(pub Entity<NativeAgent>);
960
961impl NativeAgentConnection {
962 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
963 self.0
964 .read(cx)
965 .sessions
966 .get(session_id)
967 .map(|session| session.thread.clone())
968 }
969
970 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
971 self.0.update(cx, |this, cx| this.load_thread(id, cx))
972 }
973
974 fn run_turn(
975 &self,
976 session_id: acp::SessionId,
977 cx: &mut App,
978 f: impl 'static
979 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
980 ) -> Task<Result<acp::PromptResponse>> {
981 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
982 agent
983 .sessions
984 .get_mut(&session_id)
985 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
986 }) else {
987 return Task::ready(Err(anyhow!("Session not found")));
988 };
989 log::debug!("Found session for: {}", session_id);
990
991 let response_stream = match f(thread, cx) {
992 Ok(stream) => stream,
993 Err(err) => return Task::ready(Err(err)),
994 };
995 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
996 }
997
998 fn handle_thread_events(
999 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1000 acp_thread: WeakEntity<AcpThread>,
1001 cx: &App,
1002 ) -> Task<Result<acp::PromptResponse>> {
1003 cx.spawn(async move |cx| {
1004 // Handle response stream and forward to session.acp_thread
1005 while let Some(result) = events.next().await {
1006 match result {
1007 Ok(event) => {
1008 log::trace!("Received completion event: {:?}", event);
1009
1010 match event {
1011 ThreadEvent::UserMessage(message) => {
1012 acp_thread.update(cx, |thread, cx| {
1013 for content in message.content {
1014 thread.push_user_content_block(
1015 Some(message.id.clone()),
1016 content.into(),
1017 cx,
1018 );
1019 }
1020 })?;
1021 }
1022 ThreadEvent::AgentText(text) => {
1023 acp_thread.update(cx, |thread, cx| {
1024 thread.push_assistant_content_block(text.into(), false, cx)
1025 })?;
1026 }
1027 ThreadEvent::AgentThinking(text) => {
1028 acp_thread.update(cx, |thread, cx| {
1029 thread.push_assistant_content_block(text.into(), true, cx)
1030 })?;
1031 }
1032 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1033 tool_call,
1034 options,
1035 response,
1036 context: _,
1037 }) => {
1038 let outcome_task = acp_thread.update(cx, |thread, cx| {
1039 thread.request_tool_call_authorization(tool_call, options, cx)
1040 })??;
1041 cx.background_spawn(async move {
1042 if let acp::RequestPermissionOutcome::Selected(
1043 acp::SelectedPermissionOutcome { option_id, .. },
1044 ) = outcome_task.await
1045 {
1046 response
1047 .send(option_id)
1048 .map(|_| anyhow!("authorization receiver was dropped"))
1049 .log_err();
1050 }
1051 })
1052 .detach();
1053 }
1054 ThreadEvent::ToolCall(tool_call) => {
1055 acp_thread.update(cx, |thread, cx| {
1056 thread.upsert_tool_call(tool_call, cx)
1057 })??;
1058 }
1059 ThreadEvent::ToolCallUpdate(update) => {
1060 acp_thread.update(cx, |thread, cx| {
1061 thread.update_tool_call(update, cx)
1062 })??;
1063 }
1064 ThreadEvent::SubagentSpawned(session_id) => {
1065 acp_thread.update(cx, |thread, cx| {
1066 thread.subagent_spawned(session_id, cx);
1067 })?;
1068 }
1069 ThreadEvent::Retry(status) => {
1070 acp_thread.update(cx, |thread, cx| {
1071 thread.update_retry_status(status, cx)
1072 })?;
1073 }
1074 ThreadEvent::Stop(stop_reason) => {
1075 log::debug!("Assistant message complete: {:?}", stop_reason);
1076 return Ok(acp::PromptResponse::new(stop_reason));
1077 }
1078 }
1079 }
1080 Err(e) => {
1081 log::error!("Error in model response stream: {:?}", e);
1082 return Err(e);
1083 }
1084 }
1085 }
1086
1087 log::debug!("Response stream completed");
1088 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1089 })
1090 }
1091}
1092
/// A slash command parsed out of the first text block of a user prompt.
///
/// All fields borrow from the prompt text they were parsed from.
struct Command<'a> {
    /// Name of the prompt to invoke (the part after an optional `server.` prefix).
    prompt_name: &'a str,
    /// Text following the first whitespace after the command; empty when there is none.
    arg_value: &'a str,
    /// Server id written explicitly as `/server.prompt`, if one was given.
    explicit_server_id: Option<&'a str>,
}
1098
1099impl<'a> Command<'a> {
1100 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1101 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1102 return None;
1103 };
1104 let text = text_content.text.trim();
1105 let command = text.strip_prefix('/')?;
1106 let (command, arg_value) = command
1107 .split_once(char::is_whitespace)
1108 .unwrap_or((command, ""));
1109
1110 if let Some((server_id, prompt_name)) = command.split_once('.') {
1111 Some(Self {
1112 prompt_name,
1113 arg_value,
1114 explicit_server_id: Some(server_id),
1115 })
1116 } else {
1117 Some(Self {
1118 prompt_name: command,
1119 arg_value,
1120 explicit_server_id: None,
1121 })
1122 }
1123 }
1124}
1125
/// Model selector bound to a single session of the native agent.
struct NativeAgentModelSelector {
    /// Session whose thread receives model changes.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model list and filesystem handle.
    connection: NativeAgentConnection,
}
1130
1131impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1132 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1133 log::debug!("NativeAgentConnection::list_models called");
1134 let list = self.connection.0.read(cx).models.model_list.clone();
1135 Task::ready(if list.is_empty() {
1136 Err(anyhow::anyhow!("No models available"))
1137 } else {
1138 Ok(list)
1139 })
1140 }
1141
1142 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1143 log::debug!(
1144 "Setting model for session {}: {}",
1145 self.session_id,
1146 model_id
1147 );
1148 let Some(thread) = self
1149 .connection
1150 .0
1151 .read(cx)
1152 .sessions
1153 .get(&self.session_id)
1154 .map(|session| session.thread.clone())
1155 else {
1156 return Task::ready(Err(anyhow!("Session not found")));
1157 };
1158
1159 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1160 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1161 };
1162
1163 // We want to reset the effort level when switching models, as the currently-selected effort level may
1164 // not be compatible.
1165 let effort = model
1166 .default_effort_level()
1167 .map(|effort_level| effort_level.value.to_string());
1168
1169 thread.update(cx, |thread, cx| {
1170 thread.set_model(model.clone(), cx);
1171 thread.set_thinking_effort(effort.clone(), cx);
1172 thread.set_thinking_enabled(model.supports_thinking(), cx);
1173 });
1174
1175 update_settings_file(
1176 self.connection.0.read(cx).fs.clone(),
1177 cx,
1178 move |settings, cx| {
1179 let provider = model.provider_id().0.to_string();
1180 let model = model.id().0.to_string();
1181 let enable_thinking = thread.read(cx).thinking_enabled();
1182 settings
1183 .agent
1184 .get_or_insert_default()
1185 .set_model(LanguageModelSelection {
1186 provider: provider.into(),
1187 model,
1188 enable_thinking,
1189 effort,
1190 });
1191 },
1192 );
1193
1194 Task::ready(Ok(()))
1195 }
1196
1197 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1198 let Some(thread) = self
1199 .connection
1200 .0
1201 .read(cx)
1202 .sessions
1203 .get(&self.session_id)
1204 .map(|session| session.thread.clone())
1205 else {
1206 return Task::ready(Err(anyhow!("Session not found")));
1207 };
1208 let Some(model) = thread.read(cx).model() else {
1209 return Task::ready(Err(anyhow!("Model not found")));
1210 };
1211 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1212 else {
1213 return Task::ready(Err(anyhow!("Provider not found")));
1214 };
1215 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1216 model, &provider,
1217 )))
1218 }
1219
1220 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1221 Some(self.connection.0.read(cx).models.watch())
1222 }
1223
1224 fn should_render_footer(&self) -> bool {
1225 true
1226 }
1227}
1228
1229impl acp_thread::AgentConnection for NativeAgentConnection {
1230 fn telemetry_id(&self) -> SharedString {
1231 "zed".into()
1232 }
1233
1234 fn new_session(
1235 self: Rc<Self>,
1236 project: Entity<Project>,
1237 cwd: &Path,
1238 cx: &mut App,
1239 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1240 log::debug!("Creating new thread for project at: {cwd:?}");
1241 Task::ready(Ok(self
1242 .0
1243 .update(cx, |agent, cx| agent.new_session(project, cx))))
1244 }
1245
1246 fn supports_load_session(&self) -> bool {
1247 true
1248 }
1249
1250 fn load_session(
1251 self: Rc<Self>,
1252 session: AgentSessionInfo,
1253 _project: Entity<Project>,
1254 _cwd: &Path,
1255 cx: &mut App,
1256 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1257 self.0
1258 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1259 }
1260
1261 fn supports_close_session(&self) -> bool {
1262 true
1263 }
1264
1265 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1266 self.0.update(cx, |agent, _cx| {
1267 agent.sessions.remove(session_id);
1268 });
1269 Task::ready(Ok(()))
1270 }
1271
1272 fn auth_methods(&self) -> &[acp::AuthMethod] {
1273 &[] // No auth for in-process
1274 }
1275
1276 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1277 Task::ready(Ok(()))
1278 }
1279
1280 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1281 Some(Rc::new(NativeAgentModelSelector {
1282 session_id: session_id.clone(),
1283 connection: self.clone(),
1284 }) as Rc<dyn AgentModelSelector>)
1285 }
1286
1287 fn prompt(
1288 &self,
1289 id: Option<acp_thread::UserMessageId>,
1290 params: acp::PromptRequest,
1291 cx: &mut App,
1292 ) -> Task<Result<acp::PromptResponse>> {
1293 let id = id.expect("UserMessageId is required");
1294 let session_id = params.session_id.clone();
1295 log::info!("Received prompt request for session: {}", session_id);
1296 log::debug!("Prompt blocks count: {}", params.prompt.len());
1297
1298 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1299 let registry = self.0.read(cx).context_server_registry.read(cx);
1300
1301 let explicit_server_id = parsed_command
1302 .explicit_server_id
1303 .map(|server_id| ContextServerId(server_id.into()));
1304
1305 if let Some(prompt) =
1306 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1307 {
1308 let arguments = if !parsed_command.arg_value.is_empty()
1309 && let Some(arg_name) = prompt
1310 .prompt
1311 .arguments
1312 .as_ref()
1313 .and_then(|args| args.first())
1314 .map(|arg| arg.name.clone())
1315 {
1316 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1317 } else {
1318 Default::default()
1319 };
1320
1321 let prompt_name = prompt.prompt.name.clone();
1322 let server_id = prompt.server_id.clone();
1323
1324 return self.0.update(cx, |agent, cx| {
1325 agent.send_mcp_prompt(
1326 id,
1327 session_id.clone(),
1328 prompt_name,
1329 server_id,
1330 arguments,
1331 params.prompt,
1332 cx,
1333 )
1334 });
1335 };
1336 };
1337
1338 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1339
1340 self.run_turn(session_id, cx, move |thread, cx| {
1341 let content: Vec<UserMessageContent> = params
1342 .prompt
1343 .into_iter()
1344 .map(|block| UserMessageContent::from_content_block(block, path_style))
1345 .collect::<Vec<_>>();
1346 log::debug!("Converted prompt to message: {} chars", content.len());
1347 log::debug!("Message id: {:?}", id);
1348 log::debug!("Message content: {:?}", content);
1349
1350 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1351 })
1352 }
1353
1354 fn retry(
1355 &self,
1356 session_id: &acp::SessionId,
1357 _cx: &App,
1358 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1359 Some(Rc::new(NativeAgentSessionRetry {
1360 connection: self.clone(),
1361 session_id: session_id.clone(),
1362 }) as _)
1363 }
1364
1365 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1366 log::info!("Cancelling on session: {}", session_id);
1367 self.0.update(cx, |agent, cx| {
1368 if let Some(session) = agent.sessions.get(session_id) {
1369 session
1370 .thread
1371 .update(cx, |thread, cx| thread.cancel(cx))
1372 .detach();
1373 }
1374 });
1375 }
1376
1377 fn truncate(
1378 &self,
1379 session_id: &agent_client_protocol::SessionId,
1380 cx: &App,
1381 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1382 self.0.read_with(cx, |agent, _cx| {
1383 agent.sessions.get(session_id).map(|session| {
1384 Rc::new(NativeAgentSessionTruncate {
1385 thread: session.thread.clone(),
1386 acp_thread: session.acp_thread.downgrade(),
1387 }) as _
1388 })
1389 })
1390 }
1391
1392 fn set_title(
1393 &self,
1394 session_id: &acp::SessionId,
1395 cx: &App,
1396 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1397 self.0.read_with(cx, |agent, _cx| {
1398 agent
1399 .sessions
1400 .get(session_id)
1401 .filter(|s| !s.thread.read(cx).is_subagent())
1402 .map(|session| {
1403 Rc::new(NativeAgentSessionSetTitle {
1404 thread: session.thread.clone(),
1405 }) as _
1406 })
1407 })
1408 }
1409
1410 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1411 let thread_store = self.0.read(cx).thread_store.clone();
1412 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1413 }
1414
1415 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1416 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1417 }
1418
1419 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1420 self
1421 }
1422}
1423
1424impl acp_thread::AgentTelemetry for NativeAgentConnection {
1425 fn thread_data(
1426 &self,
1427 session_id: &acp::SessionId,
1428 cx: &mut App,
1429 ) -> Task<Result<serde_json::Value>> {
1430 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1431 return Task::ready(Err(anyhow!("Session not found")));
1432 };
1433
1434 let task = session.thread.read(cx).to_db(cx);
1435 cx.background_spawn(async move {
1436 serde_json::to_value(task.await).context("Failed to serialize thread")
1437 })
1438 }
1439}
1440
1441pub struct NativeAgentSessionList {
1442 thread_store: Entity<ThreadStore>,
1443 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1444 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1445 _subscription: Subscription,
1446}
1447
1448impl NativeAgentSessionList {
1449 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1450 let (tx, rx) = smol::channel::unbounded();
1451 let this_tx = tx.clone();
1452 let subscription = cx.observe(&thread_store, move |_, _| {
1453 this_tx
1454 .try_send(acp_thread::SessionListUpdate::Refresh)
1455 .ok();
1456 });
1457 Self {
1458 thread_store,
1459 updates_tx: tx,
1460 updates_rx: rx,
1461 _subscription: subscription,
1462 }
1463 }
1464
1465 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1466 AgentSessionInfo {
1467 session_id: entry.id,
1468 cwd: None,
1469 title: Some(entry.title),
1470 updated_at: Some(entry.updated_at),
1471 meta: None,
1472 }
1473 }
1474
1475 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1476 &self.thread_store
1477 }
1478}
1479
1480impl AgentSessionList for NativeAgentSessionList {
1481 fn list_sessions(
1482 &self,
1483 _request: AgentSessionListRequest,
1484 cx: &mut App,
1485 ) -> Task<Result<AgentSessionListResponse>> {
1486 let sessions = self
1487 .thread_store
1488 .read(cx)
1489 .entries()
1490 .map(Self::to_session_info)
1491 .collect();
1492 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1493 }
1494
1495 fn supports_delete(&self) -> bool {
1496 true
1497 }
1498
1499 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1500 self.thread_store
1501 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1502 }
1503
1504 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1505 self.thread_store
1506 .update(cx, |store, cx| store.delete_threads(cx))
1507 }
1508
1509 fn watch(
1510 &self,
1511 _cx: &mut App,
1512 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1513 Some(self.updates_rx.clone())
1514 }
1515
1516 fn notify_refresh(&self) {
1517 self.updates_tx
1518 .try_send(acp_thread::SessionListUpdate::Refresh)
1519 .ok();
1520 }
1521
1522 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1523 self
1524 }
1525}
1526
1527struct NativeAgentSessionTruncate {
1528 thread: Entity<Thread>,
1529 acp_thread: WeakEntity<AcpThread>,
1530}
1531
1532impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1533 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1534 match self.thread.update(cx, |thread, cx| {
1535 thread.truncate(message_id.clone(), cx)?;
1536 Ok(thread.latest_token_usage())
1537 }) {
1538 Ok(usage) => {
1539 self.acp_thread
1540 .update(cx, |thread, cx| {
1541 thread.update_token_usage(usage, cx);
1542 })
1543 .ok();
1544 Task::ready(Ok(()))
1545 }
1546 Err(error) => Task::ready(Err(error)),
1547 }
1548 }
1549}
1550
1551struct NativeAgentSessionRetry {
1552 connection: NativeAgentConnection,
1553 session_id: acp::SessionId,
1554}
1555
1556impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1557 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1558 self.connection
1559 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1560 thread.update(cx, |thread, cx| thread.resume(cx))
1561 })
1562 }
1563}
1564
1565struct NativeAgentSessionSetTitle {
1566 thread: Entity<Thread>,
1567}
1568
1569impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1570 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1571 self.thread
1572 .update(cx, |thread, cx| thread.set_title(title, cx));
1573 Task::ready(Ok(()))
1574 }
1575}
1576
1577pub struct NativeThreadEnvironment {
1578 agent: WeakEntity<NativeAgent>,
1579 acp_thread: WeakEntity<AcpThread>,
1580}
1581
1582impl NativeThreadEnvironment {
1583 pub(crate) fn create_subagent_thread(
1584 agent: WeakEntity<NativeAgent>,
1585 parent_thread_entity: Entity<Thread>,
1586 label: String,
1587 initial_prompt: String,
1588 cx: &mut App,
1589 ) -> Result<Rc<dyn SubagentHandle>> {
1590 let parent_thread = parent_thread_entity.read(cx);
1591 let current_depth = parent_thread.depth();
1592
1593 if current_depth >= MAX_SUBAGENT_DEPTH {
1594 return Err(anyhow!(
1595 "Maximum subagent depth ({}) reached",
1596 MAX_SUBAGENT_DEPTH
1597 ));
1598 }
1599
1600 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1601 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1602 thread.set_title(label.into(), cx);
1603 thread
1604 });
1605
1606 let session_id = subagent_thread.read(cx).id().clone();
1607
1608 let acp_thread = agent.update(cx, |agent, cx| {
1609 agent.register_session(subagent_thread.clone(), cx)
1610 })?;
1611
1612 Self::prompt_subagent(
1613 session_id,
1614 subagent_thread,
1615 acp_thread,
1616 parent_thread_entity,
1617 initial_prompt,
1618 cx,
1619 )
1620 }
1621
1622 pub(crate) fn resume_subagent_thread(
1623 agent: WeakEntity<NativeAgent>,
1624 parent_thread_entity: Entity<Thread>,
1625 session_id: acp::SessionId,
1626 follow_up_prompt: String,
1627 cx: &mut App,
1628 ) -> Result<Rc<dyn SubagentHandle>> {
1629 let (subagent_thread, acp_thread) = agent.update(cx, |agent, _cx| {
1630 let session = agent
1631 .sessions
1632 .get(&session_id)
1633 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1634 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1635 })??;
1636
1637 Self::prompt_subagent(
1638 session_id,
1639 subagent_thread,
1640 acp_thread,
1641 parent_thread_entity,
1642 follow_up_prompt,
1643 cx,
1644 )
1645 }
1646
1647 fn prompt_subagent(
1648 session_id: acp::SessionId,
1649 subagent_thread: Entity<Thread>,
1650 acp_thread: Entity<acp_thread::AcpThread>,
1651 parent_thread_entity: Entity<Thread>,
1652 prompt: String,
1653 cx: &mut App,
1654 ) -> Result<Rc<dyn SubagentHandle>> {
1655 Ok(Rc::new(NativeSubagentHandle::new(
1656 session_id,
1657 subagent_thread,
1658 acp_thread,
1659 parent_thread_entity,
1660 prompt,
1661 cx,
1662 )) as _)
1663 }
1664}
1665
1666impl ThreadEnvironment for NativeThreadEnvironment {
1667 fn create_terminal(
1668 &self,
1669 command: String,
1670 cwd: Option<PathBuf>,
1671 output_byte_limit: Option<u64>,
1672 cx: &mut AsyncApp,
1673 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1674 let task = self.acp_thread.update(cx, |thread, cx| {
1675 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1676 });
1677
1678 let acp_thread = self.acp_thread.clone();
1679 cx.spawn(async move |cx| {
1680 let terminal = task?.await?;
1681
1682 let (drop_tx, drop_rx) = oneshot::channel();
1683 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1684
1685 cx.spawn(async move |cx| {
1686 drop_rx.await.ok();
1687 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1688 })
1689 .detach();
1690
1691 let handle = AcpTerminalHandle {
1692 terminal,
1693 _drop_tx: Some(drop_tx),
1694 };
1695
1696 Ok(Rc::new(handle) as _)
1697 })
1698 }
1699
1700 fn create_subagent(
1701 &self,
1702 parent_thread_entity: Entity<Thread>,
1703 label: String,
1704 initial_prompt: String,
1705 cx: &mut App,
1706 ) -> Result<Rc<dyn SubagentHandle>> {
1707 Self::create_subagent_thread(
1708 self.agent.clone(),
1709 parent_thread_entity,
1710 label,
1711 initial_prompt,
1712 cx,
1713 )
1714 }
1715
1716 fn resume_subagent(
1717 &self,
1718 parent_thread_entity: Entity<Thread>,
1719 session_id: acp::SessionId,
1720 follow_up_prompt: String,
1721 cx: &mut App,
1722 ) -> Result<Rc<dyn SubagentHandle>> {
1723 Self::resume_subagent_thread(
1724 self.agent.clone(),
1725 parent_thread_entity,
1726 session_id,
1727 follow_up_prompt,
1728 cx,
1729 )
1730 }
1731}
1732
/// Outcome of a single prompt sent to a subagent.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The prompt finished with `EndTurn` (or any stop reason not mapped to another variant).
    Completed,
    /// The prompt stopped with `StopReason::Cancelled`.
    Cancelled,
    /// The token-limit channel fired before the prompt finished.
    ContextWindowWarning,
    /// The prompt failed; the string is the message reported to the caller.
    Error(String),
}
1740
/// Handle to a subagent that has been sent a prompt.
pub struct NativeSubagentHandle {
    /// Session id of the subagent.
    session_id: acp::SessionId,
    /// Parent thread, held weakly; used to unregister the subagent when output is collected.
    parent_thread: WeakEntity<Thread>,
    /// The subagent's own thread.
    subagent_thread: Entity<Thread>,
    /// Shared task resolving to the outcome of the prompt sent in `new`.
    wait_for_prompt_to_complete: Shared<Task<SubagentPromptResult>>,
    /// Subscription to the subagent thread's token usage updates.
    _subscription: Subscription,
}

impl NativeSubagentHandle {
    /// Registers the subagent as running on the parent, sends `prompt` to the ACP thread,
    /// and sets up a background task that resolves with the prompt's outcome.
    ///
    /// The outcome is `ContextWindowWarning` if the token-limit channel resolves before the
    /// prompt does; otherwise it is derived from the prompt's stop reason or error.
    fn new(
        session_id: acp::SessionId,
        subagent_thread: Entity<Thread>,
        acp_thread: Entity<acp_thread::AcpThread>,
        parent_thread_entity: Entity<Thread>,
        prompt: String,
        cx: &mut App,
    ) -> Self {
        // Usage ratio before this prompt; compared against later updates below.
        let ratio_before_prompt = subagent_thread
            .read(cx)
            .latest_token_usage()
            .map(|usage| usage.ratio());

        parent_thread_entity.update(cx, |parent_thread, _cx| {
            parent_thread.register_running_subagent(subagent_thread.downgrade())
        });

        let task = acp_thread.update(cx, |acp_thread, cx| {
            acp_thread.send(vec![prompt.into()], cx)
        });

        // The sender sits in an `Option` so the subscription can fire it at most once.
        let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
        let mut token_limit_tx = Some(token_limit_tx);

        let subscription = cx.subscribe(
            &subagent_thread,
            move |_thread, event: &TokenUsageUpdated, _cx| {
                if let Some(usage) = &event.0 {
                    // With no usage recorded before the prompt, treat the old ratio as Normal.
                    let old_ratio = ratio_before_prompt
                        .clone()
                        .unwrap_or(TokenUsageRatio::Normal);
                    let new_ratio = usage.ratio();
                    // Signal only on the Normal -> Warning transition.
                    if old_ratio == TokenUsageRatio::Normal && new_ratio == TokenUsageRatio::Warning
                    {
                        if let Some(tx) = token_limit_tx.take() {
                            tx.send(()).ok();
                        }
                    }
                }
            },
        );

        let wait_for_prompt_to_complete = cx
            .background_spawn(async move {
                futures::select! {
                    response = task.fuse() => match response {
                        Ok(Some(response)) =>{
                            match response.stop_reason {
                                acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                            }

                        }
                        Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                        Err(error) => SubagentPromptResult::Error(error.to_string()),
                    },
                    // NOTE(review): the `_` pattern also matches the receiver's cancellation
                    // result, so if the sender is dropped without sending (it is owned by the
                    // subscription closure) this arm presumably still yields
                    // ContextWindowWarning — confirm that is intended.
                    _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                }
            })
            .shared();

        NativeSubagentHandle {
            session_id,
            subagent_thread,
            parent_thread: parent_thread_entity.downgrade(),
            wait_for_prompt_to_complete,
            _subscription: subscription,
        }
    }
}
1823
1824impl SubagentHandle for NativeSubagentHandle {
1825 fn id(&self) -> acp::SessionId {
1826 self.session_id.clone()
1827 }
1828
1829 fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
1830 let thread = self.subagent_thread.clone();
1831 let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1832
1833 let subagent_session_id = self.session_id.clone();
1834 let parent_thread = self.parent_thread.clone();
1835
1836 cx.spawn(async move |cx| {
1837 let result = match wait_for_prompt.await {
1838 SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1839 thread
1840 .last_message()
1841 .map(|m| m.to_markdown())
1842 .context("No response from subagent")
1843 }),
1844 SubagentPromptResult::Cancelled => Err(anyhow!("User cancelled")),
1845 SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1846 SubagentPromptResult::ContextWindowWarning => {
1847 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1848 Err(anyhow!(
1849 "The agent is nearing the end of its context window and has been \
1850 stopped. You can prompt the thread again to have the agent wrap up \
1851 or hand off its work."
1852 ))
1853 }
1854 };
1855
1856 parent_thread
1857 .update(cx, |parent_thread, cx| {
1858 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1859 })
1860 .ok();
1861
1862 result
1863 })
1864 }
1865}
1866
1867pub struct AcpTerminalHandle {
1868 terminal: Entity<acp_thread::Terminal>,
1869 _drop_tx: Option<oneshot::Sender<()>>,
1870}
1871
1872impl TerminalHandle for AcpTerminalHandle {
1873 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1874 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1875 }
1876
1877 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1878 Ok(self
1879 .terminal
1880 .read_with(cx, |term, _cx| term.wait_for_exit()))
1881 }
1882
1883 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1884 Ok(self
1885 .terminal
1886 .read_with(cx, |term, cx| term.current_output(cx)))
1887 }
1888
1889 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1890 cx.update(|cx| {
1891 self.terminal.update(cx, |terminal, cx| {
1892 terminal.kill(cx);
1893 });
1894 });
1895 Ok(())
1896 }
1897
1898 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1899 Ok(self
1900 .terminal
1901 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1902 }
1903}
1904
1905#[cfg(test)]
1906mod internal_tests {
1907 use super::*;
1908 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1909 use fs::FakeFs;
1910 use gpui::TestAppContext;
1911 use indoc::formatdoc;
1912 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1913 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1914 use serde_json::json;
1915 use settings::SettingsStore;
1916 use util::{path, rel_path::rel_path};
1917
    /// Checks that the agent's project context follows the project: it starts with no
    /// worktrees, gains an entry when a worktree is created, and picks up a `.rules` file
    /// once one appears in that worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context starts empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        // The new worktree shows up, without a rules file yet.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
1983
    /// Checks that the model selector of a fresh session lists the single fake model,
    /// grouped under its provider name.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The list is expected in its grouped form.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2044
    /// Checks that `select_model` updates the session's thread and writes the chosen
    /// provider/model to the settings file, and that selecting a thinking-capable model
    /// persists `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending work (including the settings write) run before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list after registering the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2160
    /// Checks that `select_model` toggles the thread's `thinking_enabled` flag to match
    /// the selected model: on for the thinking-capable model, off again for the plain one.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Start from an empty settings object.
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2256
    /// Verifies that `thinking_enabled` is still true after a session is closed and then
    /// reopened with `open_thread`.
    ///
    /// Flow: register a fake "thinking" model, create a session and select that model,
    /// send one message, close the session, reopen it by session id, and assert the flag
    /// on the reloaded thread.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // A handle to the model is kept (`thinking_model.clone()` below) because the test
        // later drives its completion stream directly.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list after registering the provider, so the model id
        // selected below can be resolved (presumably the list is not refreshed
        // automatically in tests).
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned rather than awaited so the test can feed the fake
        // model's completion stream before waiting for the send to finish.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles from the closed session before checking that the
        // agent no longer tracks any sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The reloaded handle is held until after the assertions and released explicitly.
        drop(reloaded_acp_thread);
    }
2362
    /// Verifies that a thread reopened with `open_thread` still has the model that was
    /// selected before the session was closed, identified by its model id.
    ///
    /// The registered fake model deliberately uses an id that differs from its display
    /// name, and both the pre-close and post-reload assertions compare against the id
    /// `"custom-model-id"`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list after registering the provider, so the model id
        // selected below can be resolved (presumably the list is not refreshed
        // automatically in tests).
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The ACP model id passed to `select_model` is "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the selection took effect before anything is persisted.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned rather than awaited so the test can feed the fake
        // model's completion stream before waiting for the send to finish.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles from the closed session before checking that the
        // agent no longer tracks any sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reloaded handle is held until after the assertions and released explicitly.
        drop(reloaded_acp_thread);
    }
2473
    /// End-to-end save/load round trip for a thread.
    ///
    /// Checks, in order, that:
    /// 1. a thread with no messages produces no entries in the thread store, even after
    ///    its model and summarization model are set;
    /// 2. after one exchange, the thread renders to the expected markdown and the store
    ///    lists it under the title emitted by the summarization model;
    /// 3. after closing the session and reopening it with `open_thread`, the thread
    ///    renders to the same markdown.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are used for completion and summarization so each
        // stream can be driven independently below.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message made of text plus a resource link that mentions `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // The send future is spawned rather than awaited so the test can feed the fake
        // models' streams before waiting for the send to finish.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant reply first, then the summary request that follows it.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // Rebuild the mention URI so it can be interpolated into the expected markdown.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles from the closed session before checking that the
        // agent no longer tracks any sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The stored title is the text emitted by the summarization model above.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread must render exactly as it did before the session was closed.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2612
2613 fn thread_entries(
2614 thread_store: &Entity<ThreadStore>,
2615 cx: &mut TestAppContext,
2616 ) -> Vec<(acp::SessionId, String)> {
2617 thread_store.read_with(cx, |store, _| {
2618 store
2619 .entries()
2620 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2621 .collect::<Vec<_>>()
2622 })
2623 }
2624
2625 fn init_test(cx: &mut TestAppContext) {
2626 env_logger::try_init().ok();
2627 cx.update(|cx| {
2628 let settings_store = SettingsStore::test(cx);
2629 cx.set_global(settings_store);
2630
2631 LanguageModelRegistry::test(cx);
2632 });
2633 }
2634}
2635
2636fn mcp_message_content_to_acp_content_block(
2637 content: context_server::types::MessageContent,
2638) -> acp::ContentBlock {
2639 match content {
2640 context_server::types::MessageContent::Text {
2641 text,
2642 annotations: _,
2643 } => text.into(),
2644 context_server::types::MessageContent::Image {
2645 data,
2646 mime_type,
2647 annotations: _,
2648 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2649 context_server::types::MessageContent::Audio {
2650 data,
2651 mime_type,
2652 annotations: _,
2653 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2654 context_server::types::MessageContent::Resource {
2655 resource,
2656 annotations: _,
2657 } => {
2658 let mut link =
2659 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2660 if let Some(mime_type) = resource.mime_type {
2661 link = link.mime_type(mime_type);
2662 }
2663 acp::ContentBlock::ResourceLink(link)
2664 }
2665 }
2666}