1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17use itertools::Itertools;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use shell_command_parser::extract_commands;
21pub use templates::*;
22pub use thread::*;
23pub use thread_store::*;
24pub use tool_permissions::*;
25pub use tools::*;
26
27use acp_thread::{
28 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
29 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
30};
31use agent_client_protocol as acp;
32use anyhow::{Context as _, Result, anyhow};
33use chrono::{DateTime, Utc};
34use collections::{HashMap, HashSet, IndexMap};
35use fs::Fs;
36use futures::channel::{mpsc, oneshot};
37use futures::future::Shared;
38use futures::{FutureExt as _, StreamExt as _, future};
39use gpui::{
40 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
41};
42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
43use project::{Project, ProjectItem, ProjectPath, Worktree};
44use prompt_store::{
45 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
46 WorktreeContext,
47};
48use serde::{Deserialize, Serialize};
49use settings::{LanguageModelSelection, update_settings_file};
50use std::any::Any;
51use std::path::{Path, PathBuf};
52use std::rc::Rc;
53use std::sync::Arc;
54use util::ResultExt;
55use util::path_list::PathList;
56use util::rel_path::RelPath;
57
58#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
59pub struct ProjectSnapshot {
60 pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
61 pub timestamp: DateTime<Utc>,
62}
63
64pub struct RulesLoadingError {
65 pub message: SharedString,
66}
67
68/// Holds both the internal Thread and the AcpThread for a session
69struct Session {
70 /// The internal thread that processes messages
71 thread: Entity<Thread>,
72 /// The ACP thread that handles protocol communication
73 acp_thread: Entity<acp_thread::AcpThread>,
74 pending_save: Task<()>,
75 _subscriptions: Vec<Subscription>,
76}
77
78pub struct LanguageModels {
79 /// Access language model by ID
80 models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
81 /// Cached list for returning language model information
82 model_list: acp_thread::AgentModelList,
83 refresh_models_rx: watch::Receiver<()>,
84 refresh_models_tx: watch::Sender<()>,
85 _authenticate_all_providers_task: Task<()>,
86}
87
88impl LanguageModels {
89 fn new(cx: &mut App) -> Self {
90 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
91
92 let mut this = Self {
93 models: HashMap::default(),
94 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
95 refresh_models_rx,
96 refresh_models_tx,
97 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
98 };
99 this.refresh_list(cx);
100 this
101 }
102
103 fn refresh_list(&mut self, cx: &App) {
104 let providers = LanguageModelRegistry::global(cx)
105 .read(cx)
106 .visible_providers()
107 .into_iter()
108 .filter(|provider| provider.is_authenticated(cx))
109 .collect::<Vec<_>>();
110
111 let mut language_model_list = IndexMap::default();
112 let mut recommended_models = HashSet::default();
113
114 let mut recommended = Vec::new();
115 for provider in &providers {
116 for model in provider.recommended_models(cx) {
117 recommended_models.insert((model.provider_id(), model.id()));
118 recommended.push(Self::map_language_model_to_info(&model, provider));
119 }
120 }
121 if !recommended.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName("Recommended".into()),
124 recommended,
125 );
126 }
127
128 let mut models = HashMap::default();
129 for provider in providers {
130 let mut provider_models = Vec::new();
131 for model in provider.provided_models(cx) {
132 let model_info = Self::map_language_model_to_info(&model, &provider);
133 let model_id = model_info.id.clone();
134 provider_models.push(model_info);
135 models.insert(model_id, model);
136 }
137 if !provider_models.is_empty() {
138 language_model_list.insert(
139 acp_thread::AgentModelGroupName(provider.name().0.clone()),
140 provider_models,
141 );
142 }
143 }
144
145 self.models = models;
146 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
147 self.refresh_models_tx.send(()).ok();
148 }
149
150 fn watch(&self) -> watch::Receiver<()> {
151 self.refresh_models_rx.clone()
152 }
153
154 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
155 self.models.get(model_id).cloned()
156 }
157
158 fn map_language_model_to_info(
159 model: &Arc<dyn LanguageModel>,
160 provider: &Arc<dyn LanguageModelProvider>,
161 ) -> acp_thread::AgentModelInfo {
162 acp_thread::AgentModelInfo {
163 id: Self::model_id(model),
164 name: model.name().0,
165 description: None,
166 icon: Some(match provider.icon() {
167 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
168 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
169 }),
170 is_latest: model.is_latest(),
171 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
172 }
173 }
174
175 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
176 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
177 }
178
179 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
180 let authenticate_all_providers = LanguageModelRegistry::global(cx)
181 .read(cx)
182 .visible_providers()
183 .iter()
184 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
185 .collect::<Vec<_>>();
186
187 cx.background_spawn(async move {
188 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
189 if let Err(err) = authenticate_task.await {
190 match err {
191 language_model::AuthenticateError::CredentialsNotFound => {
192 // Since we're authenticating these providers in the
193 // background for the purposes of populating the
194 // language selector, we don't care about providers
195 // where the credentials are not found.
196 }
197 language_model::AuthenticateError::ConnectionRefused => {
198 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
199 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
200 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
201 }
202 _ => {
203 // Some providers have noisy failure states that we
204 // don't want to spam the logs with every time the
205 // language model selector is initialized.
206 //
207 // Ideally these should have more clear failure modes
208 // that we know are safe to ignore here, like what we do
209 // with `CredentialsNotFound` above.
210 match provider_id.0.as_ref() {
211 "lmstudio" | "ollama" => {
212 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
213 //
214 // These fail noisily, so we don't log them.
215 }
216 "copilot_chat" => {
217 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
218 }
219 _ => {
220 log::error!(
221 "Failed to authenticate provider: {}: {err:#}",
222 provider_name.0
223 );
224 }
225 }
226 }
227 }
228 }
229 }
230 })
231 }
232}
233
234pub struct NativeAgent {
235 /// Session ID -> Session mapping
236 sessions: HashMap<acp::SessionId, Session>,
237 thread_store: Entity<ThreadStore>,
238 /// Shared project context for all threads
239 project_context: Entity<ProjectContext>,
240 project_context_needs_refresh: watch::Sender<()>,
241 _maintain_project_context: Task<Result<()>>,
242 context_server_registry: Entity<ContextServerRegistry>,
243 /// Shared templates for all threads
244 templates: Arc<Templates>,
245 /// Cached model information
246 models: LanguageModels,
247 project: Entity<Project>,
248 prompt_store: Option<Entity<PromptStore>>,
249 fs: Arc<dyn Fs>,
250 _subscriptions: Vec<Subscription>,
251}
252
253impl NativeAgent {
254 pub async fn new(
255 project: Entity<Project>,
256 thread_store: Entity<ThreadStore>,
257 templates: Arc<Templates>,
258 prompt_store: Option<Entity<PromptStore>>,
259 fs: Arc<dyn Fs>,
260 cx: &mut AsyncApp,
261 ) -> Result<Entity<NativeAgent>> {
262 log::debug!("Creating new NativeAgent");
263
264 let project_context = cx
265 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
266 .await;
267
268 Ok(cx.new(|cx| {
269 let context_server_store = project.read(cx).context_server_store();
270 let context_server_registry =
271 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
272
273 let mut subscriptions = vec![
274 cx.subscribe(&project, Self::handle_project_event),
275 cx.subscribe(
276 &LanguageModelRegistry::global(cx),
277 Self::handle_models_updated_event,
278 ),
279 cx.subscribe(
280 &context_server_store,
281 Self::handle_context_server_store_updated,
282 ),
283 cx.subscribe(
284 &context_server_registry,
285 Self::handle_context_server_registry_event,
286 ),
287 ];
288 if let Some(prompt_store) = prompt_store.as_ref() {
289 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
290 }
291
292 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
293 watch::channel(());
294 Self {
295 sessions: HashMap::default(),
296 thread_store,
297 project_context: cx.new(|_| project_context),
298 project_context_needs_refresh: project_context_needs_refresh_tx,
299 _maintain_project_context: cx.spawn(async move |this, cx| {
300 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
301 }),
302 context_server_registry,
303 templates,
304 models: LanguageModels::new(cx),
305 project,
306 prompt_store,
307 fs,
308 _subscriptions: subscriptions,
309 }
310 }))
311 }
312
313 fn new_session(
314 &mut self,
315 project: Entity<Project>,
316 cx: &mut Context<Self>,
317 ) -> Entity<AcpThread> {
318 // Create Thread
319 // Fetch default model from registry settings
320 let registry = LanguageModelRegistry::read_global(cx);
321 // Log available models for debugging
322 let available_count = registry.available_models(cx).count();
323 log::debug!("Total available models: {}", available_count);
324
325 let default_model = registry.default_model().and_then(|default_model| {
326 self.models
327 .model_from_id(&LanguageModels::model_id(&default_model.model))
328 });
329 let thread = cx.new(|cx| {
330 Thread::new(
331 project.clone(),
332 self.project_context.clone(),
333 self.context_server_registry.clone(),
334 self.templates.clone(),
335 default_model,
336 cx,
337 )
338 });
339
340 self.register_session(thread, cx)
341 }
342
343 fn register_session(
344 &mut self,
345 thread_handle: Entity<Thread>,
346 cx: &mut Context<Self>,
347 ) -> Entity<AcpThread> {
348 let connection = Rc::new(NativeAgentConnection(cx.entity()));
349
350 let thread = thread_handle.read(cx);
351 let session_id = thread.id().clone();
352 let parent_session_id = thread.parent_thread_id();
353 let title = thread.title();
354 let project = thread.project.clone();
355 let action_log = thread.action_log.clone();
356 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
357 let acp_thread = cx.new(|cx| {
358 acp_thread::AcpThread::new(
359 parent_session_id,
360 title,
361 connection,
362 project.clone(),
363 action_log.clone(),
364 session_id.clone(),
365 prompt_capabilities_rx,
366 cx,
367 )
368 });
369
370 let registry = LanguageModelRegistry::read_global(cx);
371 let summarization_model = registry.thread_summary_model().map(|c| c.model);
372
373 let weak = cx.weak_entity();
374 let weak_thread = thread_handle.downgrade();
375 thread_handle.update(cx, |thread, cx| {
376 thread.set_summarization_model(summarization_model, cx);
377 thread.add_default_tools(
378 Rc::new(NativeThreadEnvironment {
379 acp_thread: acp_thread.downgrade(),
380 thread: weak_thread,
381 agent: weak,
382 }) as _,
383 cx,
384 )
385 });
386
387 let subscriptions = vec![
388 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
389 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
390 cx.observe(&thread_handle, move |this, thread, cx| {
391 this.save_thread(thread, cx)
392 }),
393 ];
394
395 self.sessions.insert(
396 session_id,
397 Session {
398 thread: thread_handle,
399 acp_thread: acp_thread.clone(),
400 _subscriptions: subscriptions,
401 pending_save: Task::ready(()),
402 },
403 );
404
405 self.update_available_commands(cx);
406
407 acp_thread
408 }
409
410 pub fn models(&self) -> &LanguageModels {
411 &self.models
412 }
413
414 async fn maintain_project_context(
415 this: WeakEntity<Self>,
416 mut needs_refresh: watch::Receiver<()>,
417 cx: &mut AsyncApp,
418 ) -> Result<()> {
419 while needs_refresh.changed().await.is_ok() {
420 let project_context = this
421 .update(cx, |this, cx| {
422 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
423 })?
424 .await;
425 this.update(cx, |this, cx| {
426 this.project_context = cx.new(|_| project_context);
427 })?;
428 }
429
430 Ok(())
431 }
432
433 fn build_project_context(
434 project: &Entity<Project>,
435 prompt_store: Option<&Entity<PromptStore>>,
436 cx: &mut App,
437 ) -> Task<ProjectContext> {
438 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
439 let worktree_tasks = worktrees
440 .into_iter()
441 .map(|worktree| {
442 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
443 })
444 .collect::<Vec<_>>();
445 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
446 prompt_store.read_with(cx, |prompt_store, cx| {
447 let prompts = prompt_store.default_prompt_metadata();
448 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
449 let contents = prompt_store.load(prompt_metadata.id, cx);
450 async move { (contents.await, prompt_metadata) }
451 });
452 cx.background_spawn(future::join_all(load_tasks))
453 })
454 } else {
455 Task::ready(vec![])
456 };
457
458 cx.spawn(async move |_cx| {
459 let (worktrees, default_user_rules) =
460 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
461
462 let worktrees = worktrees
463 .into_iter()
464 .map(|(worktree, _rules_error)| {
465 // TODO: show error message
466 // if let Some(rules_error) = rules_error {
467 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
468 // }
469 worktree
470 })
471 .collect::<Vec<_>>();
472
473 let default_user_rules = default_user_rules
474 .into_iter()
475 .flat_map(|(contents, prompt_metadata)| match contents {
476 Ok(contents) => Some(UserRulesContext {
477 uuid: prompt_metadata.id.as_user()?,
478 title: prompt_metadata.title.map(|title| title.to_string()),
479 contents,
480 }),
481 Err(_err) => {
482 // TODO: show error message
483 // this.update(cx, |_, cx| {
484 // cx.emit(RulesLoadingError {
485 // message: format!("{err:?}").into(),
486 // });
487 // })
488 // .ok();
489 None
490 }
491 })
492 .collect::<Vec<_>>();
493
494 ProjectContext::new(worktrees, default_user_rules)
495 })
496 }
497
498 fn load_worktree_info_for_system_prompt(
499 worktree: Entity<Worktree>,
500 project: Entity<Project>,
501 cx: &mut App,
502 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
503 let tree = worktree.read(cx);
504 let root_name = tree.root_name_str().into();
505 let abs_path = tree.abs_path();
506
507 let mut context = WorktreeContext {
508 root_name,
509 abs_path,
510 rules_file: None,
511 };
512
513 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
514 let Some(rules_task) = rules_task else {
515 return Task::ready((context, None));
516 };
517
518 cx.spawn(async move |_| {
519 let (rules_file, rules_file_error) = match rules_task.await {
520 Ok(rules_file) => (Some(rules_file), None),
521 Err(err) => (
522 None,
523 Some(RulesLoadingError {
524 message: format!("{err}").into(),
525 }),
526 ),
527 };
528 context.rules_file = rules_file;
529 (context, rules_file_error)
530 })
531 }
532
533 fn load_worktree_rules_file(
534 worktree: Entity<Worktree>,
535 project: Entity<Project>,
536 cx: &mut App,
537 ) -> Option<Task<Result<RulesFileContext>>> {
538 let worktree = worktree.read(cx);
539 let worktree_id = worktree.id();
540 let selected_rules_file = RULES_FILE_NAMES
541 .into_iter()
542 .filter_map(|name| {
543 worktree
544 .entry_for_path(RelPath::unix(name).unwrap())
545 .filter(|entry| entry.is_file())
546 .map(|entry| entry.path.clone())
547 })
548 .next();
549
550 // Note that Cline supports `.clinerules` being a directory, but that is not currently
551 // supported. This doesn't seem to occur often in GitHub repositories.
552 selected_rules_file.map(|path_in_worktree| {
553 let project_path = ProjectPath {
554 worktree_id,
555 path: path_in_worktree.clone(),
556 };
557 let buffer_task =
558 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
559 let rope_task = cx.spawn(async move |cx| {
560 let buffer = buffer_task.await?;
561 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
562 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
563 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
564 })?;
565 anyhow::Ok((project_entry_id, rope))
566 });
567 // Build a string from the rope on a background thread.
568 cx.background_spawn(async move {
569 let (project_entry_id, rope) = rope_task.await?;
570 anyhow::Ok(RulesFileContext {
571 path_in_worktree,
572 text: rope.to_string().trim().to_string(),
573 project_entry_id: project_entry_id.to_usize(),
574 })
575 })
576 })
577 }
578
579 fn handle_thread_title_updated(
580 &mut self,
581 thread: Entity<Thread>,
582 _: &TitleUpdated,
583 cx: &mut Context<Self>,
584 ) {
585 let session_id = thread.read(cx).id();
586 let Some(session) = self.sessions.get(session_id) else {
587 return;
588 };
589 let thread = thread.downgrade();
590 let acp_thread = session.acp_thread.downgrade();
591 cx.spawn(async move |_, cx| {
592 let title = thread.read_with(cx, |thread, _| thread.title())?;
593 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
594 task.await
595 })
596 .detach_and_log_err(cx);
597 }
598
599 fn handle_thread_token_usage_updated(
600 &mut self,
601 thread: Entity<Thread>,
602 usage: &TokenUsageUpdated,
603 cx: &mut Context<Self>,
604 ) {
605 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
606 return;
607 };
608 session.acp_thread.update(cx, |acp_thread, cx| {
609 acp_thread.update_token_usage(usage.0.clone(), cx);
610 });
611 }
612
613 fn handle_project_event(
614 &mut self,
615 _project: Entity<Project>,
616 event: &project::Event,
617 _cx: &mut Context<Self>,
618 ) {
619 match event {
620 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
621 self.project_context_needs_refresh.send(()).ok();
622 }
623 project::Event::WorktreeUpdatedEntries(_, items) => {
624 if items.iter().any(|(path, _, _)| {
625 RULES_FILE_NAMES
626 .iter()
627 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
628 }) {
629 self.project_context_needs_refresh.send(()).ok();
630 }
631 }
632 _ => {}
633 }
634 }
635
636 fn handle_prompts_updated_event(
637 &mut self,
638 _prompt_store: Entity<PromptStore>,
639 _event: &prompt_store::PromptsUpdatedEvent,
640 _cx: &mut Context<Self>,
641 ) {
642 self.project_context_needs_refresh.send(()).ok();
643 }
644
645 fn handle_models_updated_event(
646 &mut self,
647 _registry: Entity<LanguageModelRegistry>,
648 _event: &language_model::Event,
649 cx: &mut Context<Self>,
650 ) {
651 self.models.refresh_list(cx);
652
653 let registry = LanguageModelRegistry::read_global(cx);
654 let default_model = registry.default_model().map(|m| m.model);
655 let summarization_model = registry.thread_summary_model().map(|m| m.model);
656
657 for session in self.sessions.values_mut() {
658 session.thread.update(cx, |thread, cx| {
659 if thread.model().is_none()
660 && let Some(model) = default_model.clone()
661 {
662 thread.set_model(model, cx);
663 cx.notify();
664 }
665 thread.set_summarization_model(summarization_model.clone(), cx);
666 });
667 }
668 }
669
670 fn handle_context_server_store_updated(
671 &mut self,
672 _store: Entity<project::context_server_store::ContextServerStore>,
673 _event: &project::context_server_store::ServerStatusChangedEvent,
674 cx: &mut Context<Self>,
675 ) {
676 self.update_available_commands(cx);
677 }
678
679 fn handle_context_server_registry_event(
680 &mut self,
681 _registry: Entity<ContextServerRegistry>,
682 event: &ContextServerRegistryEvent,
683 cx: &mut Context<Self>,
684 ) {
685 match event {
686 ContextServerRegistryEvent::ToolsChanged => {}
687 ContextServerRegistryEvent::PromptsChanged => {
688 self.update_available_commands(cx);
689 }
690 }
691 }
692
693 fn update_available_commands(&self, cx: &mut Context<Self>) {
694 let available_commands = self.build_available_commands(cx);
695 for session in self.sessions.values() {
696 session.acp_thread.update(cx, |thread, cx| {
697 thread
698 .handle_session_update(
699 acp::SessionUpdate::AvailableCommandsUpdate(
700 acp::AvailableCommandsUpdate::new(available_commands.clone()),
701 ),
702 cx,
703 )
704 .log_err();
705 });
706 }
707 }
708
709 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
710 let registry = self.context_server_registry.read(cx);
711
712 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
713 for context_server_prompt in registry.prompts() {
714 *prompt_name_counts
715 .entry(context_server_prompt.prompt.name.as_str())
716 .or_insert(0) += 1;
717 }
718
719 registry
720 .prompts()
721 .flat_map(|context_server_prompt| {
722 let prompt = &context_server_prompt.prompt;
723
724 let should_prefix = prompt_name_counts
725 .get(prompt.name.as_str())
726 .copied()
727 .unwrap_or(0)
728 > 1;
729
730 let name = if should_prefix {
731 format!("{}.{}", context_server_prompt.server_id, prompt.name)
732 } else {
733 prompt.name.clone()
734 };
735
736 let mut command = acp::AvailableCommand::new(
737 name,
738 prompt.description.clone().unwrap_or_default(),
739 );
740
741 match prompt.arguments.as_deref() {
742 Some([arg]) => {
743 let hint = format!("<{}>", arg.name);
744
745 command = command.input(acp::AvailableCommandInput::Unstructured(
746 acp::UnstructuredCommandInput::new(hint),
747 ));
748 }
749 Some([]) | None => {}
750 Some(_) => {
751 // skip >1 argument commands since we don't support them yet
752 return None;
753 }
754 }
755
756 Some(command)
757 })
758 .collect()
759 }
760
761 pub fn load_thread(
762 &mut self,
763 id: acp::SessionId,
764 cx: &mut Context<Self>,
765 ) -> Task<Result<Entity<Thread>>> {
766 let database_future = ThreadsDatabase::connect(cx);
767 cx.spawn(async move |this, cx| {
768 let database = database_future.await.map_err(|err| anyhow!(err))?;
769 let db_thread = database
770 .load_thread(id.clone())
771 .await?
772 .with_context(|| format!("no thread found with ID: {id:?}"))?;
773
774 this.update(cx, |this, cx| {
775 let summarization_model = LanguageModelRegistry::read_global(cx)
776 .thread_summary_model()
777 .map(|c| c.model);
778
779 cx.new(|cx| {
780 let mut thread = Thread::from_db(
781 id.clone(),
782 db_thread,
783 this.project.clone(),
784 this.project_context.clone(),
785 this.context_server_registry.clone(),
786 this.templates.clone(),
787 cx,
788 );
789 thread.set_summarization_model(summarization_model, cx);
790 thread
791 })
792 })
793 })
794 }
795
796 pub fn open_thread(
797 &mut self,
798 id: acp::SessionId,
799 cx: &mut Context<Self>,
800 ) -> Task<Result<Entity<AcpThread>>> {
801 if let Some(session) = self.sessions.get(&id) {
802 return Task::ready(Ok(session.acp_thread.clone()));
803 }
804
805 let task = self.load_thread(id, cx);
806 cx.spawn(async move |this, cx| {
807 let thread = task.await?;
808 let acp_thread =
809 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
810 let events = thread.update(cx, |thread, cx| thread.replay(cx));
811 cx.update(|cx| {
812 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
813 })
814 .await?;
815 Ok(acp_thread)
816 })
817 }
818
819 pub fn thread_summary(
820 &mut self,
821 id: acp::SessionId,
822 cx: &mut Context<Self>,
823 ) -> Task<Result<SharedString>> {
824 let thread = self.open_thread(id.clone(), cx);
825 cx.spawn(async move |this, cx| {
826 let acp_thread = thread.await?;
827 let result = this
828 .update(cx, |this, cx| {
829 this.sessions
830 .get(&id)
831 .unwrap()
832 .thread
833 .update(cx, |thread, cx| thread.summary(cx))
834 })?
835 .await
836 .context("Failed to generate summary")?;
837 drop(acp_thread);
838 Ok(result)
839 })
840 }
841
842 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
843 if thread.read(cx).is_empty() {
844 return;
845 }
846
847 let database_future = ThreadsDatabase::connect(cx);
848 let (id, db_thread) =
849 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
850 let Some(session) = self.sessions.get_mut(&id) else {
851 return;
852 };
853
854 let folder_paths = PathList::new(
855 &self
856 .project
857 .read(cx)
858 .visible_worktrees(cx)
859 .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
860 .collect::<Vec<_>>(),
861 );
862
863 let thread_store = self.thread_store.clone();
864 session.pending_save = cx.spawn(async move |_, cx| {
865 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
866 return;
867 };
868 let db_thread = db_thread.await;
869 database
870 .save_thread(id, db_thread, folder_paths)
871 .await
872 .log_err();
873 thread_store.update(cx, |store, cx| store.reload(cx));
874 });
875 }
876
877 fn send_mcp_prompt(
878 &self,
879 message_id: UserMessageId,
880 session_id: agent_client_protocol::SessionId,
881 prompt_name: String,
882 server_id: ContextServerId,
883 arguments: HashMap<String, String>,
884 original_content: Vec<acp::ContentBlock>,
885 cx: &mut Context<Self>,
886 ) -> Task<Result<acp::PromptResponse>> {
887 let server_store = self.context_server_registry.read(cx).server_store().clone();
888 let path_style = self.project.read(cx).path_style(cx);
889
890 cx.spawn(async move |this, cx| {
891 let prompt =
892 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
893
894 let (acp_thread, thread) = this.update(cx, |this, _cx| {
895 let session = this
896 .sessions
897 .get(&session_id)
898 .context("Failed to get session")?;
899 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
900 })??;
901
902 let mut last_is_user = true;
903
904 thread.update(cx, |thread, cx| {
905 thread.push_acp_user_block(
906 message_id,
907 original_content.into_iter().skip(1),
908 path_style,
909 cx,
910 );
911 });
912
913 for message in prompt.messages {
914 let context_server::types::PromptMessage { role, content } = message;
915 let block = mcp_message_content_to_acp_content_block(content);
916
917 match role {
918 context_server::types::Role::User => {
919 let id = acp_thread::UserMessageId::new();
920
921 acp_thread.update(cx, |acp_thread, cx| {
922 acp_thread.push_user_content_block_with_indent(
923 Some(id.clone()),
924 block.clone(),
925 true,
926 cx,
927 );
928 });
929
930 thread.update(cx, |thread, cx| {
931 thread.push_acp_user_block(id, [block], path_style, cx);
932 });
933 }
934 context_server::types::Role::Assistant => {
935 acp_thread.update(cx, |acp_thread, cx| {
936 acp_thread.push_assistant_content_block_with_indent(
937 block.clone(),
938 false,
939 true,
940 cx,
941 );
942 });
943
944 thread.update(cx, |thread, cx| {
945 thread.push_acp_agent_block(block, cx);
946 });
947 }
948 }
949
950 last_is_user = role == context_server::types::Role::User;
951 }
952
953 let response_stream = thread.update(cx, |thread, cx| {
954 if last_is_user {
955 thread.send_existing(cx)
956 } else {
957 // Resume if MCP prompt did not end with a user message
958 thread.resume(cx)
959 }
960 })?;
961
962 cx.update(|cx| {
963 NativeAgentConnection::handle_thread_events(
964 response_stream,
965 acp_thread.downgrade(),
966 cx,
967 )
968 })
969 .await
970 })
971 }
972}
973
974/// Wrapper struct that implements the AgentConnection trait
975#[derive(Clone)]
976pub struct NativeAgentConnection(pub Entity<NativeAgent>);
977
978impl NativeAgentConnection {
979 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
980 self.0
981 .read(cx)
982 .sessions
983 .get(session_id)
984 .map(|session| session.thread.clone())
985 }
986
987 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
988 self.0.update(cx, |this, cx| this.load_thread(id, cx))
989 }
990
991 fn run_turn(
992 &self,
993 session_id: acp::SessionId,
994 cx: &mut App,
995 f: impl 'static
996 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
997 ) -> Task<Result<acp::PromptResponse>> {
998 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
999 agent
1000 .sessions
1001 .get_mut(&session_id)
1002 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1003 }) else {
1004 return Task::ready(Err(anyhow!("Session not found")));
1005 };
1006 log::debug!("Found session for: {}", session_id);
1007
1008 let response_stream = match f(thread, cx) {
1009 Ok(stream) => stream,
1010 Err(err) => return Task::ready(Err(err)),
1011 };
1012 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1013 }
1014
1015 fn handle_thread_events(
1016 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1017 acp_thread: WeakEntity<AcpThread>,
1018 cx: &App,
1019 ) -> Task<Result<acp::PromptResponse>> {
1020 cx.spawn(async move |cx| {
1021 // Handle response stream and forward to session.acp_thread
1022 while let Some(result) = events.next().await {
1023 match result {
1024 Ok(event) => {
1025 log::trace!("Received completion event: {:?}", event);
1026
1027 match event {
1028 ThreadEvent::UserMessage(message) => {
1029 acp_thread.update(cx, |thread, cx| {
1030 for content in message.content {
1031 thread.push_user_content_block(
1032 Some(message.id.clone()),
1033 content.into(),
1034 cx,
1035 );
1036 }
1037 })?;
1038 }
1039 ThreadEvent::AgentText(text) => {
1040 acp_thread.update(cx, |thread, cx| {
1041 thread.push_assistant_content_block(text.into(), false, cx)
1042 })?;
1043 }
1044 ThreadEvent::AgentThinking(text) => {
1045 acp_thread.update(cx, |thread, cx| {
1046 thread.push_assistant_content_block(text.into(), true, cx)
1047 })?;
1048 }
1049 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1050 tool_call,
1051 options,
1052 response,
1053 context: _,
1054 }) => {
1055 let outcome_task = acp_thread.update(cx, |thread, cx| {
1056 thread.request_tool_call_authorization(tool_call, options, cx)
1057 })??;
1058 cx.background_spawn(async move {
1059 if let acp::RequestPermissionOutcome::Selected(
1060 acp::SelectedPermissionOutcome { option_id, .. },
1061 ) = outcome_task.await
1062 {
1063 response
1064 .send(option_id)
1065 .map(|_| anyhow!("authorization receiver was dropped"))
1066 .log_err();
1067 }
1068 })
1069 .detach();
1070 }
1071 ThreadEvent::ToolCall(tool_call) => {
1072 acp_thread.update(cx, |thread, cx| {
1073 thread.upsert_tool_call(tool_call, cx)
1074 })??;
1075 }
1076 ThreadEvent::ToolCallUpdate(update) => {
1077 acp_thread.update(cx, |thread, cx| {
1078 thread.update_tool_call(update, cx)
1079 })??;
1080 }
1081 ThreadEvent::SubagentSpawned(session_id) => {
1082 acp_thread.update(cx, |thread, cx| {
1083 thread.subagent_spawned(session_id, cx);
1084 })?;
1085 }
1086 ThreadEvent::Retry(status) => {
1087 acp_thread.update(cx, |thread, cx| {
1088 thread.update_retry_status(status, cx)
1089 })?;
1090 }
1091 ThreadEvent::Stop(stop_reason) => {
1092 log::debug!("Assistant message complete: {:?}", stop_reason);
1093 return Ok(acp::PromptResponse::new(stop_reason));
1094 }
1095 }
1096 }
1097 Err(e) => {
1098 log::error!("Error in model response stream: {:?}", e);
1099 return Err(e);
1100 }
1101 }
1102 }
1103
1104 log::debug!("Response stream completed");
1105 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1106 })
1107 }
1108}
1109
1110struct Command<'a> {
1111 prompt_name: &'a str,
1112 arg_value: &'a str,
1113 explicit_server_id: Option<&'a str>,
1114}
1115
1116impl<'a> Command<'a> {
1117 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1118 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1119 return None;
1120 };
1121 let text = text_content.text.trim();
1122 let command = text.strip_prefix('/')?;
1123 let (command, arg_value) = command
1124 .split_once(char::is_whitespace)
1125 .unwrap_or((command, ""));
1126
1127 if let Some((server_id, prompt_name)) = command.split_once('.') {
1128 Some(Self {
1129 prompt_name,
1130 arg_value,
1131 explicit_server_id: Some(server_id),
1132 })
1133 } else {
1134 Some(Self {
1135 prompt_name: command,
1136 arg_value,
1137 explicit_server_id: None,
1138 })
1139 }
1140 }
1141}
1142
1143struct NativeAgentModelSelector {
1144 session_id: acp::SessionId,
1145 connection: NativeAgentConnection,
1146}
1147
1148impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1149 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1150 log::debug!("NativeAgentConnection::list_models called");
1151 let list = self.connection.0.read(cx).models.model_list.clone();
1152 Task::ready(if list.is_empty() {
1153 Err(anyhow::anyhow!("No models available"))
1154 } else {
1155 Ok(list)
1156 })
1157 }
1158
1159 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1160 log::debug!(
1161 "Setting model for session {}: {}",
1162 self.session_id,
1163 model_id
1164 );
1165 let Some(thread) = self
1166 .connection
1167 .0
1168 .read(cx)
1169 .sessions
1170 .get(&self.session_id)
1171 .map(|session| session.thread.clone())
1172 else {
1173 return Task::ready(Err(anyhow!("Session not found")));
1174 };
1175
1176 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1177 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1178 };
1179
1180 // We want to reset the effort level when switching models, as the currently-selected effort level may
1181 // not be compatible.
1182 let effort = model
1183 .default_effort_level()
1184 .map(|effort_level| effort_level.value.to_string());
1185
1186 thread.update(cx, |thread, cx| {
1187 thread.set_model(model.clone(), cx);
1188 thread.set_thinking_effort(effort.clone(), cx);
1189 thread.set_thinking_enabled(model.supports_thinking(), cx);
1190 });
1191
1192 update_settings_file(
1193 self.connection.0.read(cx).fs.clone(),
1194 cx,
1195 move |settings, cx| {
1196 let provider = model.provider_id().0.to_string();
1197 let model = model.id().0.to_string();
1198 let enable_thinking = thread.read(cx).thinking_enabled();
1199 settings
1200 .agent
1201 .get_or_insert_default()
1202 .set_model(LanguageModelSelection {
1203 provider: provider.into(),
1204 model,
1205 enable_thinking,
1206 effort,
1207 });
1208 },
1209 );
1210
1211 Task::ready(Ok(()))
1212 }
1213
1214 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1215 let Some(thread) = self
1216 .connection
1217 .0
1218 .read(cx)
1219 .sessions
1220 .get(&self.session_id)
1221 .map(|session| session.thread.clone())
1222 else {
1223 return Task::ready(Err(anyhow!("Session not found")));
1224 };
1225 let Some(model) = thread.read(cx).model() else {
1226 return Task::ready(Err(anyhow!("Model not found")));
1227 };
1228 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1229 else {
1230 return Task::ready(Err(anyhow!("Provider not found")));
1231 };
1232 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1233 model, &provider,
1234 )))
1235 }
1236
1237 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1238 Some(self.connection.0.read(cx).models.watch())
1239 }
1240
1241 fn should_render_footer(&self) -> bool {
1242 true
1243 }
1244}
1245
1246impl acp_thread::AgentConnection for NativeAgentConnection {
1247 fn telemetry_id(&self) -> SharedString {
1248 "zed".into()
1249 }
1250
1251 fn new_session(
1252 self: Rc<Self>,
1253 project: Entity<Project>,
1254 cwd: &Path,
1255 cx: &mut App,
1256 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1257 log::debug!("Creating new thread for project at: {cwd:?}");
1258 Task::ready(Ok(self
1259 .0
1260 .update(cx, |agent, cx| agent.new_session(project, cx))))
1261 }
1262
1263 fn supports_load_session(&self) -> bool {
1264 true
1265 }
1266
1267 fn load_session(
1268 self: Rc<Self>,
1269 session: AgentSessionInfo,
1270 _project: Entity<Project>,
1271 _cwd: &Path,
1272 cx: &mut App,
1273 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1274 self.0
1275 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1276 }
1277
1278 fn supports_close_session(&self) -> bool {
1279 true
1280 }
1281
1282 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1283 self.0.update(cx, |agent, _cx| {
1284 agent.sessions.remove(session_id);
1285 });
1286 Task::ready(Ok(()))
1287 }
1288
1289 fn auth_methods(&self) -> &[acp::AuthMethod] {
1290 &[] // No auth for in-process
1291 }
1292
1293 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1294 Task::ready(Ok(()))
1295 }
1296
1297 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1298 Some(Rc::new(NativeAgentModelSelector {
1299 session_id: session_id.clone(),
1300 connection: self.clone(),
1301 }) as Rc<dyn AgentModelSelector>)
1302 }
1303
1304 fn prompt(
1305 &self,
1306 id: Option<acp_thread::UserMessageId>,
1307 params: acp::PromptRequest,
1308 cx: &mut App,
1309 ) -> Task<Result<acp::PromptResponse>> {
1310 let id = id.expect("UserMessageId is required");
1311 let session_id = params.session_id.clone();
1312 log::info!("Received prompt request for session: {}", session_id);
1313 log::debug!("Prompt blocks count: {}", params.prompt.len());
1314
1315 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1316 let registry = self.0.read(cx).context_server_registry.read(cx);
1317
1318 let explicit_server_id = parsed_command
1319 .explicit_server_id
1320 .map(|server_id| ContextServerId(server_id.into()));
1321
1322 if let Some(prompt) =
1323 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1324 {
1325 let arguments = if !parsed_command.arg_value.is_empty()
1326 && let Some(arg_name) = prompt
1327 .prompt
1328 .arguments
1329 .as_ref()
1330 .and_then(|args| args.first())
1331 .map(|arg| arg.name.clone())
1332 {
1333 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1334 } else {
1335 Default::default()
1336 };
1337
1338 let prompt_name = prompt.prompt.name.clone();
1339 let server_id = prompt.server_id.clone();
1340
1341 return self.0.update(cx, |agent, cx| {
1342 agent.send_mcp_prompt(
1343 id,
1344 session_id.clone(),
1345 prompt_name,
1346 server_id,
1347 arguments,
1348 params.prompt,
1349 cx,
1350 )
1351 });
1352 };
1353 };
1354
1355 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1356
1357 self.run_turn(session_id, cx, move |thread, cx| {
1358 let content: Vec<UserMessageContent> = params
1359 .prompt
1360 .into_iter()
1361 .map(|block| UserMessageContent::from_content_block(block, path_style))
1362 .collect::<Vec<_>>();
1363 log::debug!("Converted prompt to message: {} chars", content.len());
1364 log::debug!("Message id: {:?}", id);
1365 log::debug!("Message content: {:?}", content);
1366
1367 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1368 })
1369 }
1370
1371 fn retry(
1372 &self,
1373 session_id: &acp::SessionId,
1374 _cx: &App,
1375 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1376 Some(Rc::new(NativeAgentSessionRetry {
1377 connection: self.clone(),
1378 session_id: session_id.clone(),
1379 }) as _)
1380 }
1381
1382 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1383 log::info!("Cancelling on session: {}", session_id);
1384 self.0.update(cx, |agent, cx| {
1385 if let Some(session) = agent.sessions.get(session_id) {
1386 session
1387 .thread
1388 .update(cx, |thread, cx| thread.cancel(cx))
1389 .detach();
1390 }
1391 });
1392 }
1393
1394 fn truncate(
1395 &self,
1396 session_id: &agent_client_protocol::SessionId,
1397 cx: &App,
1398 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1399 self.0.read_with(cx, |agent, _cx| {
1400 agent.sessions.get(session_id).map(|session| {
1401 Rc::new(NativeAgentSessionTruncate {
1402 thread: session.thread.clone(),
1403 acp_thread: session.acp_thread.downgrade(),
1404 }) as _
1405 })
1406 })
1407 }
1408
1409 fn set_title(
1410 &self,
1411 session_id: &acp::SessionId,
1412 cx: &App,
1413 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1414 self.0.read_with(cx, |agent, _cx| {
1415 agent
1416 .sessions
1417 .get(session_id)
1418 .filter(|s| !s.thread.read(cx).is_subagent())
1419 .map(|session| {
1420 Rc::new(NativeAgentSessionSetTitle {
1421 thread: session.thread.clone(),
1422 }) as _
1423 })
1424 })
1425 }
1426
1427 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1428 let thread_store = self.0.read(cx).thread_store.clone();
1429 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1430 }
1431
1432 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1433 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1434 }
1435
1436 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1437 self
1438 }
1439}
1440
1441impl acp_thread::AgentTelemetry for NativeAgentConnection {
1442 fn thread_data(
1443 &self,
1444 session_id: &acp::SessionId,
1445 cx: &mut App,
1446 ) -> Task<Result<serde_json::Value>> {
1447 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1448 return Task::ready(Err(anyhow!("Session not found")));
1449 };
1450
1451 let task = session.thread.read(cx).to_db(cx);
1452 cx.background_spawn(async move {
1453 serde_json::to_value(task.await).context("Failed to serialize thread")
1454 })
1455 }
1456}
1457
1458pub struct NativeAgentSessionList {
1459 thread_store: Entity<ThreadStore>,
1460 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1461 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1462 _subscription: Subscription,
1463}
1464
1465impl NativeAgentSessionList {
1466 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1467 let (tx, rx) = smol::channel::unbounded();
1468 let this_tx = tx.clone();
1469 let subscription = cx.observe(&thread_store, move |_, _| {
1470 this_tx
1471 .try_send(acp_thread::SessionListUpdate::Refresh)
1472 .ok();
1473 });
1474 Self {
1475 thread_store,
1476 updates_tx: tx,
1477 updates_rx: rx,
1478 _subscription: subscription,
1479 }
1480 }
1481
1482 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1483 AgentSessionInfo {
1484 session_id: entry.id,
1485 cwd: None,
1486 title: Some(entry.title),
1487 updated_at: Some(entry.updated_at),
1488 meta: None,
1489 }
1490 }
1491
1492 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1493 &self.thread_store
1494 }
1495}
1496
1497impl AgentSessionList for NativeAgentSessionList {
1498 fn list_sessions(
1499 &self,
1500 _request: AgentSessionListRequest,
1501 cx: &mut App,
1502 ) -> Task<Result<AgentSessionListResponse>> {
1503 let sessions = self
1504 .thread_store
1505 .read(cx)
1506 .entries()
1507 .map(Self::to_session_info)
1508 .collect();
1509 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1510 }
1511
1512 fn supports_delete(&self) -> bool {
1513 true
1514 }
1515
1516 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1517 self.thread_store
1518 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1519 }
1520
1521 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1522 self.thread_store
1523 .update(cx, |store, cx| store.delete_threads(cx))
1524 }
1525
1526 fn watch(
1527 &self,
1528 _cx: &mut App,
1529 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1530 Some(self.updates_rx.clone())
1531 }
1532
1533 fn notify_refresh(&self) {
1534 self.updates_tx
1535 .try_send(acp_thread::SessionListUpdate::Refresh)
1536 .ok();
1537 }
1538
1539 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1540 self
1541 }
1542}
1543
1544struct NativeAgentSessionTruncate {
1545 thread: Entity<Thread>,
1546 acp_thread: WeakEntity<AcpThread>,
1547}
1548
1549impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1550 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1551 match self.thread.update(cx, |thread, cx| {
1552 thread.truncate(message_id.clone(), cx)?;
1553 Ok(thread.latest_token_usage())
1554 }) {
1555 Ok(usage) => {
1556 self.acp_thread
1557 .update(cx, |thread, cx| {
1558 thread.update_token_usage(usage, cx);
1559 })
1560 .ok();
1561 Task::ready(Ok(()))
1562 }
1563 Err(error) => Task::ready(Err(error)),
1564 }
1565 }
1566}
1567
1568struct NativeAgentSessionRetry {
1569 connection: NativeAgentConnection,
1570 session_id: acp::SessionId,
1571}
1572
1573impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1574 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1575 self.connection
1576 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1577 thread.update(cx, |thread, cx| thread.resume(cx))
1578 })
1579 }
1580}
1581
1582struct NativeAgentSessionSetTitle {
1583 thread: Entity<Thread>,
1584}
1585
1586impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1587 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1588 self.thread
1589 .update(cx, |thread, cx| thread.set_title(title, cx));
1590 Task::ready(Ok(()))
1591 }
1592}
1593
1594pub struct NativeThreadEnvironment {
1595 agent: WeakEntity<NativeAgent>,
1596 thread: WeakEntity<Thread>,
1597 acp_thread: WeakEntity<AcpThread>,
1598}
1599
1600impl NativeThreadEnvironment {
1601 pub(crate) fn create_subagent_thread(
1602 &self,
1603 label: String,
1604 cx: &mut App,
1605 ) -> Result<Rc<dyn SubagentHandle>> {
1606 let Some(parent_thread_entity) = self.thread.upgrade() else {
1607 anyhow::bail!("Parent thread no longer exists".to_string());
1608 };
1609 let parent_thread = parent_thread_entity.read(cx);
1610 let current_depth = parent_thread.depth();
1611
1612 if current_depth >= MAX_SUBAGENT_DEPTH {
1613 return Err(anyhow!(
1614 "Maximum subagent depth ({}) reached",
1615 MAX_SUBAGENT_DEPTH
1616 ));
1617 }
1618
1619 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1620 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1621 thread.set_title(label.into(), cx);
1622 thread
1623 });
1624
1625 let session_id = subagent_thread.read(cx).id().clone();
1626
1627 let acp_thread = self.agent.update(cx, |agent, cx| {
1628 agent.register_session(subagent_thread.clone(), cx)
1629 })?;
1630
1631 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1632 }
1633
1634 pub(crate) fn resume_subagent_thread(
1635 &self,
1636 session_id: acp::SessionId,
1637 cx: &mut App,
1638 ) -> Result<Rc<dyn SubagentHandle>> {
1639 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1640 let session = agent
1641 .sessions
1642 .get(&session_id)
1643 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1644 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1645 })??;
1646
1647 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1648 }
1649
1650 fn prompt_subagent(
1651 &self,
1652 session_id: acp::SessionId,
1653 subagent_thread: Entity<Thread>,
1654 acp_thread: Entity<acp_thread::AcpThread>,
1655 ) -> Result<Rc<dyn SubagentHandle>> {
1656 let Some(parent_thread_entity) = self.thread.upgrade() else {
1657 anyhow::bail!("Parent thread no longer exists".to_string());
1658 };
1659 Ok(Rc::new(NativeSubagentHandle::new(
1660 session_id,
1661 subagent_thread,
1662 acp_thread,
1663 parent_thread_entity,
1664 )) as _)
1665 }
1666}
1667
1668impl ThreadEnvironment for NativeThreadEnvironment {
1669 fn create_terminal(
1670 &self,
1671 command: String,
1672 cwd: Option<PathBuf>,
1673 output_byte_limit: Option<u64>,
1674 cx: &mut AsyncApp,
1675 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1676 let task = self.acp_thread.update(cx, |thread, cx| {
1677 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1678 });
1679
1680 let acp_thread = self.acp_thread.clone();
1681 cx.spawn(async move |cx| {
1682 let terminal = task?.await?;
1683
1684 let (drop_tx, drop_rx) = oneshot::channel();
1685 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1686
1687 cx.spawn(async move |cx| {
1688 drop_rx.await.ok();
1689 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1690 })
1691 .detach();
1692
1693 let handle = AcpTerminalHandle {
1694 terminal,
1695 _drop_tx: Some(drop_tx),
1696 };
1697
1698 Ok(Rc::new(handle) as _)
1699 })
1700 }
1701
1702 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1703 self.create_subagent_thread(label, cx)
1704 }
1705
1706 fn resume_subagent(
1707 &self,
1708 session_id: acp::SessionId,
1709 cx: &mut App,
1710 ) -> Result<Rc<dyn SubagentHandle>> {
1711 self.resume_subagent_thread(session_id, cx)
1712 }
1713}
1714
1715#[derive(Debug, Clone)]
1716enum SubagentPromptResult {
1717 Completed,
1718 Cancelled,
1719 ContextWindowWarning,
1720 Error(String),
1721}
1722
1723pub struct NativeSubagentHandle {
1724 session_id: acp::SessionId,
1725 parent_thread: WeakEntity<Thread>,
1726 subagent_thread: Entity<Thread>,
1727 acp_thread: Entity<acp_thread::AcpThread>,
1728}
1729
1730impl NativeSubagentHandle {
1731 fn new(
1732 session_id: acp::SessionId,
1733 subagent_thread: Entity<Thread>,
1734 acp_thread: Entity<acp_thread::AcpThread>,
1735 parent_thread_entity: Entity<Thread>,
1736 ) -> Self {
1737 NativeSubagentHandle {
1738 session_id,
1739 subagent_thread,
1740 parent_thread: parent_thread_entity.downgrade(),
1741 acp_thread,
1742 }
1743 }
1744}
1745
1746impl SubagentHandle for NativeSubagentHandle {
1747 fn id(&self) -> acp::SessionId {
1748 self.session_id.clone()
1749 }
1750
1751 fn num_entries(&self, cx: &App) -> usize {
1752 self.acp_thread.read(cx).entries().len()
1753 }
1754
1755 fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1756 let thread = self.subagent_thread.clone();
1757 let acp_thread = self.acp_thread.clone();
1758 let subagent_session_id = self.session_id.clone();
1759 let parent_thread = self.parent_thread.clone();
1760
1761 cx.spawn(async move |cx| {
1762 let (task, _subscription) = cx.update(|cx| {
1763 let ratio_before_prompt = thread
1764 .read(cx)
1765 .latest_token_usage()
1766 .map(|usage| usage.ratio());
1767
1768 parent_thread
1769 .update(cx, |parent_thread, _cx| {
1770 parent_thread.register_running_subagent(thread.downgrade())
1771 })
1772 .ok();
1773
1774 let task = acp_thread.update(cx, |acp_thread, cx| {
1775 acp_thread.send(vec![message.into()], cx)
1776 });
1777
1778 let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1779 let mut token_limit_tx = Some(token_limit_tx);
1780
1781 let subscription = cx.subscribe(
1782 &thread,
1783 move |_thread, event: &TokenUsageUpdated, _cx| {
1784 if let Some(usage) = &event.0 {
1785 let old_ratio = ratio_before_prompt
1786 .clone()
1787 .unwrap_or(TokenUsageRatio::Normal);
1788 let new_ratio = usage.ratio();
1789 if old_ratio == TokenUsageRatio::Normal
1790 && new_ratio == TokenUsageRatio::Warning
1791 {
1792 if let Some(tx) = token_limit_tx.take() {
1793 tx.send(()).ok();
1794 }
1795 }
1796 }
1797 },
1798 );
1799
1800 let wait_for_prompt = cx
1801 .background_spawn(async move {
1802 futures::select! {
1803 response = task.fuse() => match response {
1804 Ok(Some(response)) => {
1805 match response.stop_reason {
1806 acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1807 acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1808 acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1809 acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1810 acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1811 }
1812 }
1813 Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1814 Err(error) => SubagentPromptResult::Error(error.to_string()),
1815 },
1816 _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1817 }
1818 });
1819
1820 (wait_for_prompt, subscription)
1821 });
1822
1823 let result = match task.await {
1824 SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1825 thread
1826 .last_message()
1827 .and_then(|message| {
1828 let content = message.as_agent_message()?
1829 .content
1830 .iter()
1831 .filter_map(|c| match c {
1832 AgentMessageContent::Text(text) => Some(text.as_str()),
1833 _ => None,
1834 })
1835 .join("\n\n");
1836 if content.is_empty() {
1837 None
1838 } else {
1839 Some( content)
1840 }
1841 })
1842 .context("No response from subagent")
1843 }),
1844 SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
1845 SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1846 SubagentPromptResult::ContextWindowWarning => {
1847 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1848 Err(anyhow!(
1849 "The agent is nearing the end of its context window and has been \
1850 stopped. You can prompt the thread again to have the agent wrap up \
1851 or hand off its work."
1852 ))
1853 }
1854 };
1855
1856 parent_thread
1857 .update(cx, |parent_thread, cx| {
1858 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1859 })
1860 .ok();
1861
1862 result
1863 })
1864 }
1865}
1866
1867pub struct AcpTerminalHandle {
1868 terminal: Entity<acp_thread::Terminal>,
1869 _drop_tx: Option<oneshot::Sender<()>>,
1870}
1871
1872impl TerminalHandle for AcpTerminalHandle {
1873 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1874 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1875 }
1876
1877 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1878 Ok(self
1879 .terminal
1880 .read_with(cx, |term, _cx| term.wait_for_exit()))
1881 }
1882
1883 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1884 Ok(self
1885 .terminal
1886 .read_with(cx, |term, cx| term.current_output(cx)))
1887 }
1888
1889 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1890 cx.update(|cx| {
1891 self.terminal.update(cx, |terminal, cx| {
1892 terminal.kill(cx);
1893 });
1894 });
1895 Ok(())
1896 }
1897
1898 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1899 Ok(self
1900 .terminal
1901 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1902 }
1903}
1904
1905#[cfg(test)]
1906mod internal_tests {
1907 use super::*;
1908 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1909 use fs::FakeFs;
1910 use gpui::TestAppContext;
1911 use indoc::formatdoc;
1912 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1913 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1914 use serde_json::json;
1915 use settings::SettingsStore;
1916 use util::{path, rel_path::rel_path};
1917
1918 #[gpui::test]
1919 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1920 init_test(cx);
1921 let fs = FakeFs::new(cx.executor());
1922 fs.insert_tree(
1923 "/",
1924 json!({
1925 "a": {}
1926 }),
1927 )
1928 .await;
1929 let project = Project::test(fs.clone(), [], cx).await;
1930 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1931 let agent = NativeAgent::new(
1932 project.clone(),
1933 thread_store,
1934 Templates::new(),
1935 None,
1936 fs.clone(),
1937 &mut cx.to_async(),
1938 )
1939 .await
1940 .unwrap();
1941 agent.read_with(cx, |agent, cx| {
1942 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1943 });
1944
1945 let worktree = project
1946 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1947 .await
1948 .unwrap();
1949 cx.run_until_parked();
1950 agent.read_with(cx, |agent, cx| {
1951 assert_eq!(
1952 agent.project_context.read(cx).worktrees,
1953 vec![WorktreeContext {
1954 root_name: "a".into(),
1955 abs_path: Path::new("/a").into(),
1956 rules_file: None
1957 }]
1958 )
1959 });
1960
1961 // Creating `/a/.rules` updates the project context.
1962 fs.insert_file("/a/.rules", Vec::new()).await;
1963 cx.run_until_parked();
1964 agent.read_with(cx, |agent, cx| {
1965 let rules_entry = worktree
1966 .read(cx)
1967 .entry_for_path(rel_path(".rules"))
1968 .unwrap();
1969 assert_eq!(
1970 agent.project_context.read(cx).worktrees,
1971 vec![WorktreeContext {
1972 root_name: "a".into(),
1973 abs_path: Path::new("/a").into(),
1974 rules_file: Some(RulesFileContext {
1975 path_in_worktree: rel_path(".rules").into(),
1976 text: "".into(),
1977 project_entry_id: rules_entry.id.to_usize()
1978 })
1979 }]
1980 )
1981 });
1982 }
1983
1984 #[gpui::test]
1985 async fn test_listing_models(cx: &mut TestAppContext) {
1986 init_test(cx);
1987 let fs = FakeFs::new(cx.executor());
1988 fs.insert_tree("/", json!({ "a": {} })).await;
1989 let project = Project::test(fs.clone(), [], cx).await;
1990 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1991 let connection = NativeAgentConnection(
1992 NativeAgent::new(
1993 project.clone(),
1994 thread_store,
1995 Templates::new(),
1996 None,
1997 fs.clone(),
1998 &mut cx.to_async(),
1999 )
2000 .await
2001 .unwrap(),
2002 );
2003
2004 // Create a thread/session
2005 let acp_thread = cx
2006 .update(|cx| {
2007 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2008 })
2009 .await
2010 .unwrap();
2011
2012 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2013
2014 let models = cx
2015 .update(|cx| {
2016 connection
2017 .model_selector(&session_id)
2018 .unwrap()
2019 .list_models(cx)
2020 })
2021 .await
2022 .unwrap();
2023
2024 let acp_thread::AgentModelList::Grouped(models) = models else {
2025 panic!("Unexpected model group");
2026 };
2027 assert_eq!(
2028 models,
2029 IndexMap::from_iter([(
2030 AgentModelGroupName("Fake".into()),
2031 vec![AgentModelInfo {
2032 id: acp::ModelId::new("fake/fake"),
2033 name: "Fake".into(),
2034 description: None,
2035 icon: Some(acp_thread::AgentModelIcon::Named(
2036 ui::IconName::ZedAssistant
2037 )),
2038 is_latest: false,
2039 cost: None,
2040 }]
2041 )])
2042 );
2043 }
2044
2045 #[gpui::test]
2046 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2047 init_test(cx);
2048 let fs = FakeFs::new(cx.executor());
2049 fs.create_dir(paths::settings_file().parent().unwrap())
2050 .await
2051 .unwrap();
2052 fs.insert_file(
2053 paths::settings_file(),
2054 json!({
2055 "agent": {
2056 "default_model": {
2057 "provider": "foo",
2058 "model": "bar"
2059 }
2060 }
2061 })
2062 .to_string()
2063 .into_bytes(),
2064 )
2065 .await;
2066 let project = Project::test(fs.clone(), [], cx).await;
2067
2068 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2069
2070 // Create the agent and connection
2071 let agent = NativeAgent::new(
2072 project.clone(),
2073 thread_store,
2074 Templates::new(),
2075 None,
2076 fs.clone(),
2077 &mut cx.to_async(),
2078 )
2079 .await
2080 .unwrap();
2081 let connection = NativeAgentConnection(agent.clone());
2082
2083 // Create a thread/session
2084 let acp_thread = cx
2085 .update(|cx| {
2086 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2087 })
2088 .await
2089 .unwrap();
2090
2091 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2092
2093 // Select a model
2094 let selector = connection.model_selector(&session_id).unwrap();
2095 let model_id = acp::ModelId::new("fake/fake");
2096 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2097 .await
2098 .unwrap();
2099
2100 // Verify the thread has the selected model
2101 agent.read_with(cx, |agent, _| {
2102 let session = agent.sessions.get(&session_id).unwrap();
2103 session.thread.read_with(cx, |thread, _| {
2104 assert_eq!(thread.model().unwrap().id().0, "fake");
2105 });
2106 });
2107
2108 cx.run_until_parked();
2109
2110 // Verify settings file was updated
2111 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2112 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2113
2114 // Check that the agent settings contain the selected model
2115 assert_eq!(
2116 settings_json["agent"]["default_model"]["model"],
2117 json!("fake")
2118 );
2119 assert_eq!(
2120 settings_json["agent"]["default_model"]["provider"],
2121 json!("fake")
2122 );
2123
2124 // Register a thinking model and select it.
2125 cx.update(|cx| {
2126 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2127 "fake-corp",
2128 "fake-thinking",
2129 "Fake Thinking",
2130 true,
2131 ));
2132 let thinking_provider = Arc::new(
2133 FakeLanguageModelProvider::new(
2134 LanguageModelProviderId::from("fake-corp".to_string()),
2135 LanguageModelProviderName::from("Fake Corp".to_string()),
2136 )
2137 .with_models(vec![thinking_model]),
2138 );
2139 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2140 registry.register_provider(thinking_provider, cx);
2141 });
2142 });
2143 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2144
2145 let selector = connection.model_selector(&session_id).unwrap();
2146 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2147 .await
2148 .unwrap();
2149 cx.run_until_parked();
2150
2151 // Verify enable_thinking was written to settings as true.
2152 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2153 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2154 assert_eq!(
2155 settings_json["agent"]["default_model"]["enable_thinking"],
2156 json!(true),
2157 "selecting a thinking model should persist enable_thinking: true to settings"
2158 );
2159 }
2160
2161 #[gpui::test]
2162 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2163 init_test(cx);
2164 let fs = FakeFs::new(cx.executor());
2165 fs.create_dir(paths::settings_file().parent().unwrap())
2166 .await
2167 .unwrap();
2168 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2169 let project = Project::test(fs.clone(), [], cx).await;
2170
2171 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2172 let agent = NativeAgent::new(
2173 project.clone(),
2174 thread_store,
2175 Templates::new(),
2176 None,
2177 fs.clone(),
2178 &mut cx.to_async(),
2179 )
2180 .await
2181 .unwrap();
2182 let connection = NativeAgentConnection(agent.clone());
2183
2184 let acp_thread = cx
2185 .update(|cx| {
2186 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2187 })
2188 .await
2189 .unwrap();
2190 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2191
2192 // Register a second provider with a thinking model.
2193 cx.update(|cx| {
2194 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2195 "fake-corp",
2196 "fake-thinking",
2197 "Fake Thinking",
2198 true,
2199 ));
2200 let thinking_provider = Arc::new(
2201 FakeLanguageModelProvider::new(
2202 LanguageModelProviderId::from("fake-corp".to_string()),
2203 LanguageModelProviderName::from("Fake Corp".to_string()),
2204 )
2205 .with_models(vec![thinking_model]),
2206 );
2207 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2208 registry.register_provider(thinking_provider, cx);
2209 });
2210 });
2211 // Refresh the agent's model list so it picks up the new provider.
2212 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2213
2214 // Thread starts with thinking_enabled = false (the default).
2215 agent.read_with(cx, |agent, _| {
2216 let session = agent.sessions.get(&session_id).unwrap();
2217 session.thread.read_with(cx, |thread, _| {
2218 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2219 });
2220 });
2221
2222 // Select the thinking model via select_model.
2223 let selector = connection.model_selector(&session_id).unwrap();
2224 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2225 .await
2226 .unwrap();
2227
2228 // select_model should have enabled thinking based on the model's supports_thinking().
2229 agent.read_with(cx, |agent, _| {
2230 let session = agent.sessions.get(&session_id).unwrap();
2231 session.thread.read_with(cx, |thread, _| {
2232 assert!(
2233 thread.thinking_enabled(),
2234 "select_model should enable thinking when model supports it"
2235 );
2236 });
2237 });
2238
2239 // Switch back to the non-thinking model.
2240 let selector = connection.model_selector(&session_id).unwrap();
2241 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2242 .await
2243 .unwrap();
2244
2245 // select_model should have disabled thinking.
2246 agent.read_with(cx, |agent, _| {
2247 let session = agent.sessions.get(&session_id).unwrap();
2248 session.thread.read_with(cx, |thread, _| {
2249 assert!(
2250 !thread.thinking_enabled(),
2251 "select_model should disable thinking when model does not support it"
2252 );
2253 });
2254 });
2255 }
2256
    // Round-trip test for the thinking flag.
    //
    // Selects a model registered with thinking support and checks `thinking_enabled()` is true.
    // Then sends one message (completed by the fake model) so the thread is persisted, closes
    // the session, and reopens it with `open_thread`. The reloaded thread must still report
    // `thinking_enabled() == true`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The `Arc` is kept (cloned into the provider) so the test can drive its completion
        // stream later.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Make the agent's model list aware of the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Model ids are addressed as "<provider>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned rather than awaited directly: it cannot finish until the fake
        // model's completion stream below is fed and ended.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        // Drop the test's own handles before asserting that no sessions remain (presumably so
        // nothing here keeps the closed session's entities alive — NOTE(review): confirm).
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The reloaded AcpThread handle is held until the assertions above have run.
        drop(reloaded_acp_thread);
    }
2362
    // Round-trip test for the selected model.
    //
    // Selects a model whose id ("custom-model-id") differs from its display name, sends one
    // message so the thread is persisted, closes the session, and reopens it with
    // `open_thread`. The reloaded thread must report the same model id. Because id and name
    // differ, this would presumably catch a restore that resolved the model by display name
    // or fell back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // The final `false` registers it without thinking support.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Make the agent's model list aware of the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selection uses "<provider>/<model id>", not the display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Precondition: the live thread uses the selected model before anything is persisted.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned rather than awaited directly: it cannot finish until the fake
        // model's completion stream below is fed and ended.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        // Drop the test's own handles before asserting that no sessions remain (presumably so
        // nothing here keeps the closed session's entities alive — NOTE(review): confirm).
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reloaded AcpThread handle is held until the assertions above have run.
        drop(reloaded_acp_thread);
    }
2473
    // End-to-end persistence test for a thread containing a file mention.
    //
    // 1. A freshly created thread is not saved, even after its model and summarization model
    //    are changed.
    // 2. After one exchange (user message with a resource link to `/a/b.md`, assistant reply
    //    "Lorem."), the thread appears in the store. Its title is the text produced by the
    //    summarization model.
    // 3. After the session is closed, `open_thread` reproduces the same markdown transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        // `thread_store` is cloned because the test inspects the store's entries directly below.
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are used for the reply and for the title summary, so each
        // completion stream can be driven independently.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message mixes plain text with a resource link to `/a/b.md`. It is rendered
        // as `[@b.md](<uri>)` in the markdown checks below.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawned so the send can progress while the test feeds the fake completion streams.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant reply first, then the summarization request.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        // Drop the test's own handles before asserting that no sessions remain.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The stored title is exactly the text streamed by the summarization model above.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2612
2613 fn thread_entries(
2614 thread_store: &Entity<ThreadStore>,
2615 cx: &mut TestAppContext,
2616 ) -> Vec<(acp::SessionId, String)> {
2617 thread_store.read_with(cx, |store, _| {
2618 store
2619 .entries()
2620 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2621 .collect::<Vec<_>>()
2622 })
2623 }
2624
2625 fn init_test(cx: &mut TestAppContext) {
2626 env_logger::try_init().ok();
2627 cx.update(|cx| {
2628 let settings_store = SettingsStore::test(cx);
2629 cx.set_global(settings_store);
2630
2631 LanguageModelRegistry::test(cx);
2632 });
2633 }
2634}
2635
2636fn mcp_message_content_to_acp_content_block(
2637 content: context_server::types::MessageContent,
2638) -> acp::ContentBlock {
2639 match content {
2640 context_server::types::MessageContent::Text {
2641 text,
2642 annotations: _,
2643 } => text.into(),
2644 context_server::types::MessageContent::Image {
2645 data,
2646 mime_type,
2647 annotations: _,
2648 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2649 context_server::types::MessageContent::Audio {
2650 data,
2651 mime_type,
2652 annotations: _,
2653 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2654 context_server::types::MessageContent::Resource {
2655 resource,
2656 annotations: _,
2657 } => {
2658 let mut link =
2659 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2660 if let Some(mime_type) = resource.mime_type {
2661 link = link.mime_type(mime_type);
2662 }
2663 acp::ContentBlock::ResourceLink(link)
2664 }
2665 }
2666}