1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use std::time::Duration;
54use util::ResultExt;
55use util::rel_path::RelPath;
56
/// A timestamped collection of per-worktree telemetry snapshots for a project.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshot of each worktree included in the capture.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken, in UTC.
    pub timestamp: DateTime<Utc>,
}
62
/// Describes why a worktree's rules file could not be loaded.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt`; surfacing it to the
/// user is still a TODO in `NativeAgent::build_project_context`.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
66
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// The most recently started save of `thread` to the threads database.
    /// NOTE(review): assigning a new save drops the previous `Task`; with gpui tasks
    /// that cancels a save still in flight — confirm this is the intended debounce.
    pending_save: Task<()>,
    /// Subscriptions to `thread` (title, token usage, change observation); kept alive
    /// for as long as the session is registered.
    _subscriptions: Vec<Subscription>,
}
76
/// Cache of the language models offered by authenticated providers, together with the
/// grouped list that is handed to the model selector.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half of the refresh signal; clones are handed out by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled every time `refresh_list` rebuilds the cache.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every visible provider; held so it is not
    /// dropped while running.
    _authenticate_all_providers_task: Task<()>,
}
86
87impl LanguageModels {
88 fn new(cx: &mut App) -> Self {
89 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
90
91 let mut this = Self {
92 models: HashMap::default(),
93 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
94 refresh_models_rx,
95 refresh_models_tx,
96 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
97 };
98 this.refresh_list(cx);
99 this
100 }
101
102 fn refresh_list(&mut self, cx: &App) {
103 let providers = LanguageModelRegistry::global(cx)
104 .read(cx)
105 .visible_providers()
106 .into_iter()
107 .filter(|provider| provider.is_authenticated(cx))
108 .collect::<Vec<_>>();
109
110 let mut language_model_list = IndexMap::default();
111 let mut recommended_models = HashSet::default();
112
113 let mut recommended = Vec::new();
114 for provider in &providers {
115 for model in provider.recommended_models(cx) {
116 recommended_models.insert((model.provider_id(), model.id()));
117 recommended.push(Self::map_language_model_to_info(&model, provider));
118 }
119 }
120 if !recommended.is_empty() {
121 language_model_list.insert(
122 acp_thread::AgentModelGroupName("Recommended".into()),
123 recommended,
124 );
125 }
126
127 let mut models = HashMap::default();
128 for provider in providers {
129 let mut provider_models = Vec::new();
130 for model in provider.provided_models(cx) {
131 let model_info = Self::map_language_model_to_info(&model, &provider);
132 let model_id = model_info.id.clone();
133 provider_models.push(model_info);
134 models.insert(model_id, model);
135 }
136 if !provider_models.is_empty() {
137 language_model_list.insert(
138 acp_thread::AgentModelGroupName(provider.name().0.clone()),
139 provider_models,
140 );
141 }
142 }
143
144 self.models = models;
145 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
146 self.refresh_models_tx.send(()).ok();
147 }
148
149 fn watch(&self) -> watch::Receiver<()> {
150 self.refresh_models_rx.clone()
151 }
152
153 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
154 self.models.get(model_id).cloned()
155 }
156
157 fn map_language_model_to_info(
158 model: &Arc<dyn LanguageModel>,
159 provider: &Arc<dyn LanguageModelProvider>,
160 ) -> acp_thread::AgentModelInfo {
161 acp_thread::AgentModelInfo {
162 id: Self::model_id(model),
163 name: model.name().0,
164 description: None,
165 icon: Some(match provider.icon() {
166 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
167 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
168 }),
169 is_latest: model.is_latest(),
170 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
171 }
172 }
173
174 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
175 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
176 }
177
178 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
179 let authenticate_all_providers = LanguageModelRegistry::global(cx)
180 .read(cx)
181 .visible_providers()
182 .iter()
183 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184 .collect::<Vec<_>>();
185
186 cx.background_spawn(async move {
187 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
188 if let Err(err) = authenticate_task.await {
189 match err {
190 language_model::AuthenticateError::CredentialsNotFound => {
191 // Since we're authenticating these providers in the
192 // background for the purposes of populating the
193 // language selector, we don't care about providers
194 // where the credentials are not found.
195 }
196 language_model::AuthenticateError::ConnectionRefused => {
197 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
198 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
199 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
200 }
201 _ => {
202 // Some providers have noisy failure states that we
203 // don't want to spam the logs with every time the
204 // language model selector is initialized.
205 //
206 // Ideally these should have more clear failure modes
207 // that we know are safe to ignore here, like what we do
208 // with `CredentialsNotFound` above.
209 match provider_id.0.as_ref() {
210 "lmstudio" | "ollama" => {
211 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
212 //
213 // These fail noisily, so we don't log them.
214 }
215 "copilot_chat" => {
216 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
217 }
218 _ => {
219 log::error!(
220 "Failed to authenticate provider: {}: {err:#}",
221 provider_name.0
222 );
223 }
224 }
225 }
226 }
227 }
228 }
229 })
230 }
231}
232
/// The built-in agent.
///
/// Owns every session's [`Thread`]/[`AcpThread`] pair plus the state those threads share
/// (project context, context-server registry, templates and cached models).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store that is reloaded after each database save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled when worktrees, rules files or default user prompts change, so that
    /// `maintain_project_context` rebuilds the project context.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds the project context whenever a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of tools and prompts exposed by context (MCP) servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Source of default user rules, when available.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Subscriptions created in `new`; kept alive for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
251
252impl NativeAgent {
253 pub async fn new(
254 project: Entity<Project>,
255 thread_store: Entity<ThreadStore>,
256 templates: Arc<Templates>,
257 prompt_store: Option<Entity<PromptStore>>,
258 fs: Arc<dyn Fs>,
259 cx: &mut AsyncApp,
260 ) -> Result<Entity<NativeAgent>> {
261 log::debug!("Creating new NativeAgent");
262
263 let project_context = cx
264 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
265 .await;
266
267 Ok(cx.new(|cx| {
268 let context_server_store = project.read(cx).context_server_store();
269 let context_server_registry =
270 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
271
272 let mut subscriptions = vec![
273 cx.subscribe(&project, Self::handle_project_event),
274 cx.subscribe(
275 &LanguageModelRegistry::global(cx),
276 Self::handle_models_updated_event,
277 ),
278 cx.subscribe(
279 &context_server_store,
280 Self::handle_context_server_store_updated,
281 ),
282 cx.subscribe(
283 &context_server_registry,
284 Self::handle_context_server_registry_event,
285 ),
286 ];
287 if let Some(prompt_store) = prompt_store.as_ref() {
288 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
289 }
290
291 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
292 watch::channel(());
293 Self {
294 sessions: HashMap::default(),
295 thread_store,
296 project_context: cx.new(|_| project_context),
297 project_context_needs_refresh: project_context_needs_refresh_tx,
298 _maintain_project_context: cx.spawn(async move |this, cx| {
299 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
300 }),
301 context_server_registry,
302 templates,
303 models: LanguageModels::new(cx),
304 project,
305 prompt_store,
306 fs,
307 _subscriptions: subscriptions,
308 }
309 }))
310 }
311
312 fn new_session(
313 &mut self,
314 project: Entity<Project>,
315 cx: &mut Context<Self>,
316 ) -> Entity<AcpThread> {
317 // Create Thread
318 // Fetch default model from registry settings
319 let registry = LanguageModelRegistry::read_global(cx);
320 // Log available models for debugging
321 let available_count = registry.available_models(cx).count();
322 log::debug!("Total available models: {}", available_count);
323
324 let default_model = registry.default_model().and_then(|default_model| {
325 self.models
326 .model_from_id(&LanguageModels::model_id(&default_model.model))
327 });
328 let thread = cx.new(|cx| {
329 Thread::new(
330 project.clone(),
331 self.project_context.clone(),
332 self.context_server_registry.clone(),
333 self.templates.clone(),
334 default_model,
335 cx,
336 )
337 });
338
339 self.register_session(thread, cx)
340 }
341
342 fn register_session(
343 &mut self,
344 thread_handle: Entity<Thread>,
345 cx: &mut Context<Self>,
346 ) -> Entity<AcpThread> {
347 let connection = Rc::new(NativeAgentConnection(cx.entity()));
348
349 let thread = thread_handle.read(cx);
350 let session_id = thread.id().clone();
351 let parent_session_id = thread.parent_thread_id();
352 let title = thread.title();
353 let project = thread.project.clone();
354 let action_log = thread.action_log.clone();
355 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
356 let acp_thread = cx.new(|cx| {
357 acp_thread::AcpThread::new(
358 parent_session_id,
359 title,
360 connection,
361 project.clone(),
362 action_log.clone(),
363 session_id.clone(),
364 prompt_capabilities_rx,
365 cx,
366 )
367 });
368
369 let registry = LanguageModelRegistry::read_global(cx);
370 let summarization_model = registry.thread_summary_model().map(|c| c.model);
371
372 let weak = cx.weak_entity();
373 thread_handle.update(cx, |thread, cx| {
374 thread.set_summarization_model(summarization_model, cx);
375 thread.add_default_tools(
376 Rc::new(NativeThreadEnvironment {
377 acp_thread: acp_thread.downgrade(),
378 agent: weak,
379 }) as _,
380 cx,
381 )
382 });
383
384 let subscriptions = vec![
385 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
386 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
387 cx.observe(&thread_handle, move |this, thread, cx| {
388 this.save_thread(thread, cx)
389 }),
390 ];
391
392 self.sessions.insert(
393 session_id,
394 Session {
395 thread: thread_handle,
396 acp_thread: acp_thread.clone(),
397 _subscriptions: subscriptions,
398 pending_save: Task::ready(()),
399 },
400 );
401
402 self.update_available_commands(cx);
403
404 acp_thread
405 }
406
407 pub fn models(&self) -> &LanguageModels {
408 &self.models
409 }
410
411 async fn maintain_project_context(
412 this: WeakEntity<Self>,
413 mut needs_refresh: watch::Receiver<()>,
414 cx: &mut AsyncApp,
415 ) -> Result<()> {
416 while needs_refresh.changed().await.is_ok() {
417 let project_context = this
418 .update(cx, |this, cx| {
419 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
420 })?
421 .await;
422 this.update(cx, |this, cx| {
423 this.project_context = cx.new(|_| project_context);
424 })?;
425 }
426
427 Ok(())
428 }
429
430 fn build_project_context(
431 project: &Entity<Project>,
432 prompt_store: Option<&Entity<PromptStore>>,
433 cx: &mut App,
434 ) -> Task<ProjectContext> {
435 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
436 let worktree_tasks = worktrees
437 .into_iter()
438 .map(|worktree| {
439 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
440 })
441 .collect::<Vec<_>>();
442 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
443 prompt_store.read_with(cx, |prompt_store, cx| {
444 let prompts = prompt_store.default_prompt_metadata();
445 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
446 let contents = prompt_store.load(prompt_metadata.id, cx);
447 async move { (contents.await, prompt_metadata) }
448 });
449 cx.background_spawn(future::join_all(load_tasks))
450 })
451 } else {
452 Task::ready(vec![])
453 };
454
455 cx.spawn(async move |_cx| {
456 let (worktrees, default_user_rules) =
457 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
458
459 let worktrees = worktrees
460 .into_iter()
461 .map(|(worktree, _rules_error)| {
462 // TODO: show error message
463 // if let Some(rules_error) = rules_error {
464 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
465 // }
466 worktree
467 })
468 .collect::<Vec<_>>();
469
470 let default_user_rules = default_user_rules
471 .into_iter()
472 .flat_map(|(contents, prompt_metadata)| match contents {
473 Ok(contents) => Some(UserRulesContext {
474 uuid: prompt_metadata.id.as_user()?,
475 title: prompt_metadata.title.map(|title| title.to_string()),
476 contents,
477 }),
478 Err(_err) => {
479 // TODO: show error message
480 // this.update(cx, |_, cx| {
481 // cx.emit(RulesLoadingError {
482 // message: format!("{err:?}").into(),
483 // });
484 // })
485 // .ok();
486 None
487 }
488 })
489 .collect::<Vec<_>>();
490
491 ProjectContext::new(worktrees, default_user_rules)
492 })
493 }
494
495 fn load_worktree_info_for_system_prompt(
496 worktree: Entity<Worktree>,
497 project: Entity<Project>,
498 cx: &mut App,
499 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
500 let tree = worktree.read(cx);
501 let root_name = tree.root_name_str().into();
502 let abs_path = tree.abs_path();
503
504 let mut context = WorktreeContext {
505 root_name,
506 abs_path,
507 rules_file: None,
508 };
509
510 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
511 let Some(rules_task) = rules_task else {
512 return Task::ready((context, None));
513 };
514
515 cx.spawn(async move |_| {
516 let (rules_file, rules_file_error) = match rules_task.await {
517 Ok(rules_file) => (Some(rules_file), None),
518 Err(err) => (
519 None,
520 Some(RulesLoadingError {
521 message: format!("{err}").into(),
522 }),
523 ),
524 };
525 context.rules_file = rules_file;
526 (context, rules_file_error)
527 })
528 }
529
530 fn load_worktree_rules_file(
531 worktree: Entity<Worktree>,
532 project: Entity<Project>,
533 cx: &mut App,
534 ) -> Option<Task<Result<RulesFileContext>>> {
535 let worktree = worktree.read(cx);
536 let worktree_id = worktree.id();
537 let selected_rules_file = RULES_FILE_NAMES
538 .into_iter()
539 .filter_map(|name| {
540 worktree
541 .entry_for_path(RelPath::unix(name).unwrap())
542 .filter(|entry| entry.is_file())
543 .map(|entry| entry.path.clone())
544 })
545 .next();
546
547 // Note that Cline supports `.clinerules` being a directory, but that is not currently
548 // supported. This doesn't seem to occur often in GitHub repositories.
549 selected_rules_file.map(|path_in_worktree| {
550 let project_path = ProjectPath {
551 worktree_id,
552 path: path_in_worktree.clone(),
553 };
554 let buffer_task =
555 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
556 let rope_task = cx.spawn(async move |cx| {
557 let buffer = buffer_task.await?;
558 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
559 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
560 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
561 })?;
562 anyhow::Ok((project_entry_id, rope))
563 });
564 // Build a string from the rope on a background thread.
565 cx.background_spawn(async move {
566 let (project_entry_id, rope) = rope_task.await?;
567 anyhow::Ok(RulesFileContext {
568 path_in_worktree,
569 text: rope.to_string().trim().to_string(),
570 project_entry_id: project_entry_id.to_usize(),
571 })
572 })
573 })
574 }
575
576 fn handle_thread_title_updated(
577 &mut self,
578 thread: Entity<Thread>,
579 _: &TitleUpdated,
580 cx: &mut Context<Self>,
581 ) {
582 let session_id = thread.read(cx).id();
583 let Some(session) = self.sessions.get(session_id) else {
584 return;
585 };
586 let thread = thread.downgrade();
587 let acp_thread = session.acp_thread.downgrade();
588 cx.spawn(async move |_, cx| {
589 let title = thread.read_with(cx, |thread, _| thread.title())?;
590 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
591 task.await
592 })
593 .detach_and_log_err(cx);
594 }
595
596 fn handle_thread_token_usage_updated(
597 &mut self,
598 thread: Entity<Thread>,
599 usage: &TokenUsageUpdated,
600 cx: &mut Context<Self>,
601 ) {
602 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
603 return;
604 };
605 session.acp_thread.update(cx, |acp_thread, cx| {
606 acp_thread.update_token_usage(usage.0.clone(), cx);
607 });
608 }
609
610 fn handle_project_event(
611 &mut self,
612 _project: Entity<Project>,
613 event: &project::Event,
614 _cx: &mut Context<Self>,
615 ) {
616 match event {
617 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
618 self.project_context_needs_refresh.send(()).ok();
619 }
620 project::Event::WorktreeUpdatedEntries(_, items) => {
621 if items.iter().any(|(path, _, _)| {
622 RULES_FILE_NAMES
623 .iter()
624 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
625 }) {
626 self.project_context_needs_refresh.send(()).ok();
627 }
628 }
629 _ => {}
630 }
631 }
632
633 fn handle_prompts_updated_event(
634 &mut self,
635 _prompt_store: Entity<PromptStore>,
636 _event: &prompt_store::PromptsUpdatedEvent,
637 _cx: &mut Context<Self>,
638 ) {
639 self.project_context_needs_refresh.send(()).ok();
640 }
641
642 fn handle_models_updated_event(
643 &mut self,
644 _registry: Entity<LanguageModelRegistry>,
645 _event: &language_model::Event,
646 cx: &mut Context<Self>,
647 ) {
648 self.models.refresh_list(cx);
649
650 let registry = LanguageModelRegistry::read_global(cx);
651 let default_model = registry.default_model().map(|m| m.model);
652 let summarization_model = registry.thread_summary_model().map(|m| m.model);
653
654 for session in self.sessions.values_mut() {
655 session.thread.update(cx, |thread, cx| {
656 if thread.model().is_none()
657 && let Some(model) = default_model.clone()
658 {
659 thread.set_model(model, cx);
660 cx.notify();
661 }
662 thread.set_summarization_model(summarization_model.clone(), cx);
663 });
664 }
665 }
666
667 fn handle_context_server_store_updated(
668 &mut self,
669 _store: Entity<project::context_server_store::ContextServerStore>,
670 _event: &project::context_server_store::ServerStatusChangedEvent,
671 cx: &mut Context<Self>,
672 ) {
673 self.update_available_commands(cx);
674 }
675
676 fn handle_context_server_registry_event(
677 &mut self,
678 _registry: Entity<ContextServerRegistry>,
679 event: &ContextServerRegistryEvent,
680 cx: &mut Context<Self>,
681 ) {
682 match event {
683 ContextServerRegistryEvent::ToolsChanged => {}
684 ContextServerRegistryEvent::PromptsChanged => {
685 self.update_available_commands(cx);
686 }
687 }
688 }
689
690 fn update_available_commands(&self, cx: &mut Context<Self>) {
691 let available_commands = self.build_available_commands(cx);
692 for session in self.sessions.values() {
693 session.acp_thread.update(cx, |thread, cx| {
694 thread
695 .handle_session_update(
696 acp::SessionUpdate::AvailableCommandsUpdate(
697 acp::AvailableCommandsUpdate::new(available_commands.clone()),
698 ),
699 cx,
700 )
701 .log_err();
702 });
703 }
704 }
705
706 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
707 let registry = self.context_server_registry.read(cx);
708
709 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
710 for context_server_prompt in registry.prompts() {
711 *prompt_name_counts
712 .entry(context_server_prompt.prompt.name.as_str())
713 .or_insert(0) += 1;
714 }
715
716 registry
717 .prompts()
718 .flat_map(|context_server_prompt| {
719 let prompt = &context_server_prompt.prompt;
720
721 let should_prefix = prompt_name_counts
722 .get(prompt.name.as_str())
723 .copied()
724 .unwrap_or(0)
725 > 1;
726
727 let name = if should_prefix {
728 format!("{}.{}", context_server_prompt.server_id, prompt.name)
729 } else {
730 prompt.name.clone()
731 };
732
733 let mut command = acp::AvailableCommand::new(
734 name,
735 prompt.description.clone().unwrap_or_default(),
736 );
737
738 match prompt.arguments.as_deref() {
739 Some([arg]) => {
740 let hint = format!("<{}>", arg.name);
741
742 command = command.input(acp::AvailableCommandInput::Unstructured(
743 acp::UnstructuredCommandInput::new(hint),
744 ));
745 }
746 Some([]) | None => {}
747 Some(_) => {
748 // skip >1 argument commands since we don't support them yet
749 return None;
750 }
751 }
752
753 Some(command)
754 })
755 .collect()
756 }
757
758 pub fn load_thread(
759 &mut self,
760 id: acp::SessionId,
761 cx: &mut Context<Self>,
762 ) -> Task<Result<Entity<Thread>>> {
763 let database_future = ThreadsDatabase::connect(cx);
764 cx.spawn(async move |this, cx| {
765 let database = database_future.await.map_err(|err| anyhow!(err))?;
766 let db_thread = database
767 .load_thread(id.clone())
768 .await?
769 .with_context(|| format!("no thread found with ID: {id:?}"))?;
770
771 this.update(cx, |this, cx| {
772 let summarization_model = LanguageModelRegistry::read_global(cx)
773 .thread_summary_model()
774 .map(|c| c.model);
775
776 cx.new(|cx| {
777 let mut thread = Thread::from_db(
778 id.clone(),
779 db_thread,
780 this.project.clone(),
781 this.project_context.clone(),
782 this.context_server_registry.clone(),
783 this.templates.clone(),
784 cx,
785 );
786 thread.set_summarization_model(summarization_model, cx);
787 thread
788 })
789 })
790 })
791 }
792
793 pub fn open_thread(
794 &mut self,
795 id: acp::SessionId,
796 cx: &mut Context<Self>,
797 ) -> Task<Result<Entity<AcpThread>>> {
798 if let Some(session) = self.sessions.get(&id) {
799 return Task::ready(Ok(session.acp_thread.clone()));
800 }
801
802 let task = self.load_thread(id, cx);
803 cx.spawn(async move |this, cx| {
804 let thread = task.await?;
805 let acp_thread =
806 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
807 let events = thread.update(cx, |thread, cx| thread.replay(cx));
808 cx.update(|cx| {
809 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
810 })
811 .await?;
812 Ok(acp_thread)
813 })
814 }
815
816 pub fn thread_summary(
817 &mut self,
818 id: acp::SessionId,
819 cx: &mut Context<Self>,
820 ) -> Task<Result<SharedString>> {
821 let thread = self.open_thread(id.clone(), cx);
822 cx.spawn(async move |this, cx| {
823 let acp_thread = thread.await?;
824 let result = this
825 .update(cx, |this, cx| {
826 this.sessions
827 .get(&id)
828 .unwrap()
829 .thread
830 .update(cx, |thread, cx| thread.summary(cx))
831 })?
832 .await
833 .context("Failed to generate summary")?;
834 drop(acp_thread);
835 Ok(result)
836 })
837 }
838
839 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
840 if thread.read(cx).is_empty() {
841 return;
842 }
843
844 let database_future = ThreadsDatabase::connect(cx);
845 let (id, db_thread) =
846 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
847 let Some(session) = self.sessions.get_mut(&id) else {
848 return;
849 };
850 let thread_store = self.thread_store.clone();
851 session.pending_save = cx.spawn(async move |_, cx| {
852 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
853 return;
854 };
855 let db_thread = db_thread.await;
856 database.save_thread(id, db_thread).await.log_err();
857 thread_store.update(cx, |store, cx| store.reload(cx));
858 });
859 }
860
861 fn send_mcp_prompt(
862 &self,
863 message_id: UserMessageId,
864 session_id: agent_client_protocol::SessionId,
865 prompt_name: String,
866 server_id: ContextServerId,
867 arguments: HashMap<String, String>,
868 original_content: Vec<acp::ContentBlock>,
869 cx: &mut Context<Self>,
870 ) -> Task<Result<acp::PromptResponse>> {
871 let server_store = self.context_server_registry.read(cx).server_store().clone();
872 let path_style = self.project.read(cx).path_style(cx);
873
874 cx.spawn(async move |this, cx| {
875 let prompt =
876 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
877
878 let (acp_thread, thread) = this.update(cx, |this, _cx| {
879 let session = this
880 .sessions
881 .get(&session_id)
882 .context("Failed to get session")?;
883 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
884 })??;
885
886 let mut last_is_user = true;
887
888 thread.update(cx, |thread, cx| {
889 thread.push_acp_user_block(
890 message_id,
891 original_content.into_iter().skip(1),
892 path_style,
893 cx,
894 );
895 });
896
897 for message in prompt.messages {
898 let context_server::types::PromptMessage { role, content } = message;
899 let block = mcp_message_content_to_acp_content_block(content);
900
901 match role {
902 context_server::types::Role::User => {
903 let id = acp_thread::UserMessageId::new();
904
905 acp_thread.update(cx, |acp_thread, cx| {
906 acp_thread.push_user_content_block_with_indent(
907 Some(id.clone()),
908 block.clone(),
909 true,
910 cx,
911 );
912 });
913
914 thread.update(cx, |thread, cx| {
915 thread.push_acp_user_block(id, [block], path_style, cx);
916 });
917 }
918 context_server::types::Role::Assistant => {
919 acp_thread.update(cx, |acp_thread, cx| {
920 acp_thread.push_assistant_content_block_with_indent(
921 block.clone(),
922 false,
923 true,
924 cx,
925 );
926 });
927
928 thread.update(cx, |thread, cx| {
929 thread.push_acp_agent_block(block, cx);
930 });
931 }
932 }
933
934 last_is_user = role == context_server::types::Role::User;
935 }
936
937 let response_stream = thread.update(cx, |thread, cx| {
938 if last_is_user {
939 thread.send_existing(cx)
940 } else {
941 // Resume if MCP prompt did not end with a user message
942 thread.resume(cx)
943 }
944 })?;
945
946 cx.update(|cx| {
947 NativeAgentConnection::handle_thread_events(
948 response_stream,
949 acp_thread.downgrade(),
950 cx,
951 )
952 })
953 .await
954 })
955 }
956}
957
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds a handle to the shared [`NativeAgent`]; cloning the connection clones the
/// handle, not the agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
961
962impl NativeAgentConnection {
963 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
964 self.0
965 .read(cx)
966 .sessions
967 .get(session_id)
968 .map(|session| session.thread.clone())
969 }
970
971 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
972 self.0.update(cx, |this, cx| this.load_thread(id, cx))
973 }
974
975 fn run_turn(
976 &self,
977 session_id: acp::SessionId,
978 cx: &mut App,
979 f: impl 'static
980 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
981 ) -> Task<Result<acp::PromptResponse>> {
982 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
983 agent
984 .sessions
985 .get_mut(&session_id)
986 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
987 }) else {
988 return Task::ready(Err(anyhow!("Session not found")));
989 };
990 log::debug!("Found session for: {}", session_id);
991
992 let response_stream = match f(thread, cx) {
993 Ok(stream) => stream,
994 Err(err) => return Task::ready(Err(err)),
995 };
996 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
997 }
998
999 fn handle_thread_events(
1000 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1001 acp_thread: WeakEntity<AcpThread>,
1002 cx: &App,
1003 ) -> Task<Result<acp::PromptResponse>> {
1004 cx.spawn(async move |cx| {
1005 // Handle response stream and forward to session.acp_thread
1006 while let Some(result) = events.next().await {
1007 match result {
1008 Ok(event) => {
1009 log::trace!("Received completion event: {:?}", event);
1010
1011 match event {
1012 ThreadEvent::UserMessage(message) => {
1013 acp_thread.update(cx, |thread, cx| {
1014 for content in message.content {
1015 thread.push_user_content_block(
1016 Some(message.id.clone()),
1017 content.into(),
1018 cx,
1019 );
1020 }
1021 })?;
1022 }
1023 ThreadEvent::AgentText(text) => {
1024 acp_thread.update(cx, |thread, cx| {
1025 thread.push_assistant_content_block(text.into(), false, cx)
1026 })?;
1027 }
1028 ThreadEvent::AgentThinking(text) => {
1029 acp_thread.update(cx, |thread, cx| {
1030 thread.push_assistant_content_block(text.into(), true, cx)
1031 })?;
1032 }
1033 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1034 tool_call,
1035 options,
1036 response,
1037 context: _,
1038 }) => {
1039 let outcome_task = acp_thread.update(cx, |thread, cx| {
1040 thread.request_tool_call_authorization(tool_call, options, cx)
1041 })??;
1042 cx.background_spawn(async move {
1043 if let acp::RequestPermissionOutcome::Selected(
1044 acp::SelectedPermissionOutcome { option_id, .. },
1045 ) = outcome_task.await
1046 {
1047 response
1048 .send(option_id)
1049 .map(|_| anyhow!("authorization receiver was dropped"))
1050 .log_err();
1051 }
1052 })
1053 .detach();
1054 }
1055 ThreadEvent::ToolCall(tool_call) => {
1056 acp_thread.update(cx, |thread, cx| {
1057 thread.upsert_tool_call(tool_call, cx)
1058 })??;
1059 }
1060 ThreadEvent::ToolCallUpdate(update) => {
1061 acp_thread.update(cx, |thread, cx| {
1062 thread.update_tool_call(update, cx)
1063 })??;
1064 }
1065 ThreadEvent::SubagentSpawned(session_id) => {
1066 acp_thread.update(cx, |thread, cx| {
1067 thread.subagent_spawned(session_id, cx);
1068 })?;
1069 }
1070 ThreadEvent::Retry(status) => {
1071 acp_thread.update(cx, |thread, cx| {
1072 thread.update_retry_status(status, cx)
1073 })?;
1074 }
1075 ThreadEvent::Stop(stop_reason) => {
1076 log::debug!("Assistant message complete: {:?}", stop_reason);
1077 return Ok(acp::PromptResponse::new(stop_reason));
1078 }
1079 }
1080 }
1081 Err(e) => {
1082 log::error!("Error in model response stream: {:?}", e);
1083 return Err(e);
1084 }
1085 }
1086 }
1087
1088 log::debug!("Response stream completed");
1089 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1090 })
1091 }
1092}
1093
1094struct Command<'a> {
1095 prompt_name: &'a str,
1096 arg_value: &'a str,
1097 explicit_server_id: Option<&'a str>,
1098}
1099
1100impl<'a> Command<'a> {
1101 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1102 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1103 return None;
1104 };
1105 let text = text_content.text.trim();
1106 let command = text.strip_prefix('/')?;
1107 let (command, arg_value) = command
1108 .split_once(char::is_whitespace)
1109 .unwrap_or((command, ""));
1110
1111 if let Some((server_id, prompt_name)) = command.split_once('.') {
1112 Some(Self {
1113 prompt_name,
1114 arg_value,
1115 explicit_server_id: Some(server_id),
1116 })
1117 } else {
1118 Some(Self {
1119 prompt_name: command,
1120 arg_value,
1121 explicit_server_id: None,
1122 })
1123 }
1124 }
1125}
1126
1127struct NativeAgentModelSelector {
1128 session_id: acp::SessionId,
1129 connection: NativeAgentConnection,
1130}
1131
1132impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1133 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1134 log::debug!("NativeAgentConnection::list_models called");
1135 let list = self.connection.0.read(cx).models.model_list.clone();
1136 Task::ready(if list.is_empty() {
1137 Err(anyhow::anyhow!("No models available"))
1138 } else {
1139 Ok(list)
1140 })
1141 }
1142
1143 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1144 log::debug!(
1145 "Setting model for session {}: {}",
1146 self.session_id,
1147 model_id
1148 );
1149 let Some(thread) = self
1150 .connection
1151 .0
1152 .read(cx)
1153 .sessions
1154 .get(&self.session_id)
1155 .map(|session| session.thread.clone())
1156 else {
1157 return Task::ready(Err(anyhow!("Session not found")));
1158 };
1159
1160 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1161 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1162 };
1163
1164 // We want to reset the effort level when switching models, as the currently-selected effort level may
1165 // not be compatible.
1166 let effort = model
1167 .default_effort_level()
1168 .map(|effort_level| effort_level.value.to_string());
1169
1170 thread.update(cx, |thread, cx| {
1171 thread.set_model(model.clone(), cx);
1172 thread.set_thinking_effort(effort.clone(), cx);
1173 thread.set_thinking_enabled(model.supports_thinking(), cx);
1174 });
1175
1176 update_settings_file(
1177 self.connection.0.read(cx).fs.clone(),
1178 cx,
1179 move |settings, cx| {
1180 let provider = model.provider_id().0.to_string();
1181 let model = model.id().0.to_string();
1182 let enable_thinking = thread.read(cx).thinking_enabled();
1183 settings
1184 .agent
1185 .get_or_insert_default()
1186 .set_model(LanguageModelSelection {
1187 provider: provider.into(),
1188 model,
1189 enable_thinking,
1190 effort,
1191 });
1192 },
1193 );
1194
1195 Task::ready(Ok(()))
1196 }
1197
1198 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1199 let Some(thread) = self
1200 .connection
1201 .0
1202 .read(cx)
1203 .sessions
1204 .get(&self.session_id)
1205 .map(|session| session.thread.clone())
1206 else {
1207 return Task::ready(Err(anyhow!("Session not found")));
1208 };
1209 let Some(model) = thread.read(cx).model() else {
1210 return Task::ready(Err(anyhow!("Model not found")));
1211 };
1212 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1213 else {
1214 return Task::ready(Err(anyhow!("Provider not found")));
1215 };
1216 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1217 model, &provider,
1218 )))
1219 }
1220
1221 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1222 Some(self.connection.0.read(cx).models.watch())
1223 }
1224
1225 fn should_render_footer(&self) -> bool {
1226 true
1227 }
1228}
1229
1230impl acp_thread::AgentConnection for NativeAgentConnection {
1231 fn telemetry_id(&self) -> SharedString {
1232 "zed".into()
1233 }
1234
1235 fn new_session(
1236 self: Rc<Self>,
1237 project: Entity<Project>,
1238 cwd: &Path,
1239 cx: &mut App,
1240 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1241 log::debug!("Creating new thread for project at: {cwd:?}");
1242 Task::ready(Ok(self
1243 .0
1244 .update(cx, |agent, cx| agent.new_session(project, cx))))
1245 }
1246
1247 fn supports_load_session(&self) -> bool {
1248 true
1249 }
1250
1251 fn load_session(
1252 self: Rc<Self>,
1253 session: AgentSessionInfo,
1254 _project: Entity<Project>,
1255 _cwd: &Path,
1256 cx: &mut App,
1257 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1258 self.0
1259 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1260 }
1261
1262 fn supports_close_session(&self) -> bool {
1263 true
1264 }
1265
1266 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1267 self.0.update(cx, |agent, _cx| {
1268 agent.sessions.remove(session_id);
1269 });
1270 Task::ready(Ok(()))
1271 }
1272
1273 fn auth_methods(&self) -> &[acp::AuthMethod] {
1274 &[] // No auth for in-process
1275 }
1276
1277 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1278 Task::ready(Ok(()))
1279 }
1280
1281 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1282 Some(Rc::new(NativeAgentModelSelector {
1283 session_id: session_id.clone(),
1284 connection: self.clone(),
1285 }) as Rc<dyn AgentModelSelector>)
1286 }
1287
1288 fn prompt(
1289 &self,
1290 id: Option<acp_thread::UserMessageId>,
1291 params: acp::PromptRequest,
1292 cx: &mut App,
1293 ) -> Task<Result<acp::PromptResponse>> {
1294 let id = id.expect("UserMessageId is required");
1295 let session_id = params.session_id.clone();
1296 log::info!("Received prompt request for session: {}", session_id);
1297 log::debug!("Prompt blocks count: {}", params.prompt.len());
1298
1299 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1300 let registry = self.0.read(cx).context_server_registry.read(cx);
1301
1302 let explicit_server_id = parsed_command
1303 .explicit_server_id
1304 .map(|server_id| ContextServerId(server_id.into()));
1305
1306 if let Some(prompt) =
1307 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1308 {
1309 let arguments = if !parsed_command.arg_value.is_empty()
1310 && let Some(arg_name) = prompt
1311 .prompt
1312 .arguments
1313 .as_ref()
1314 .and_then(|args| args.first())
1315 .map(|arg| arg.name.clone())
1316 {
1317 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1318 } else {
1319 Default::default()
1320 };
1321
1322 let prompt_name = prompt.prompt.name.clone();
1323 let server_id = prompt.server_id.clone();
1324
1325 return self.0.update(cx, |agent, cx| {
1326 agent.send_mcp_prompt(
1327 id,
1328 session_id.clone(),
1329 prompt_name,
1330 server_id,
1331 arguments,
1332 params.prompt,
1333 cx,
1334 )
1335 });
1336 };
1337 };
1338
1339 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1340
1341 self.run_turn(session_id, cx, move |thread, cx| {
1342 let content: Vec<UserMessageContent> = params
1343 .prompt
1344 .into_iter()
1345 .map(|block| UserMessageContent::from_content_block(block, path_style))
1346 .collect::<Vec<_>>();
1347 log::debug!("Converted prompt to message: {} chars", content.len());
1348 log::debug!("Message id: {:?}", id);
1349 log::debug!("Message content: {:?}", content);
1350
1351 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1352 })
1353 }
1354
1355 fn retry(
1356 &self,
1357 session_id: &acp::SessionId,
1358 _cx: &App,
1359 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1360 Some(Rc::new(NativeAgentSessionRetry {
1361 connection: self.clone(),
1362 session_id: session_id.clone(),
1363 }) as _)
1364 }
1365
1366 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1367 log::info!("Cancelling on session: {}", session_id);
1368 self.0.update(cx, |agent, cx| {
1369 if let Some(session) = agent.sessions.get(session_id) {
1370 session
1371 .thread
1372 .update(cx, |thread, cx| thread.cancel(cx))
1373 .detach();
1374 }
1375 });
1376 }
1377
1378 fn truncate(
1379 &self,
1380 session_id: &agent_client_protocol::SessionId,
1381 cx: &App,
1382 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1383 self.0.read_with(cx, |agent, _cx| {
1384 agent.sessions.get(session_id).map(|session| {
1385 Rc::new(NativeAgentSessionTruncate {
1386 thread: session.thread.clone(),
1387 acp_thread: session.acp_thread.downgrade(),
1388 }) as _
1389 })
1390 })
1391 }
1392
1393 fn set_title(
1394 &self,
1395 session_id: &acp::SessionId,
1396 cx: &App,
1397 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1398 self.0.read_with(cx, |agent, _cx| {
1399 agent
1400 .sessions
1401 .get(session_id)
1402 .filter(|s| !s.thread.read(cx).is_subagent())
1403 .map(|session| {
1404 Rc::new(NativeAgentSessionSetTitle {
1405 thread: session.thread.clone(),
1406 }) as _
1407 })
1408 })
1409 }
1410
1411 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1412 let thread_store = self.0.read(cx).thread_store.clone();
1413 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1414 }
1415
1416 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1417 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1418 }
1419
1420 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1421 self
1422 }
1423}
1424
1425impl acp_thread::AgentTelemetry for NativeAgentConnection {
1426 fn thread_data(
1427 &self,
1428 session_id: &acp::SessionId,
1429 cx: &mut App,
1430 ) -> Task<Result<serde_json::Value>> {
1431 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1432 return Task::ready(Err(anyhow!("Session not found")));
1433 };
1434
1435 let task = session.thread.read(cx).to_db(cx);
1436 cx.background_spawn(async move {
1437 serde_json::to_value(task.await).context("Failed to serialize thread")
1438 })
1439 }
1440}
1441
1442pub struct NativeAgentSessionList {
1443 thread_store: Entity<ThreadStore>,
1444 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1445 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1446 _subscription: Subscription,
1447}
1448
1449impl NativeAgentSessionList {
1450 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1451 let (tx, rx) = smol::channel::unbounded();
1452 let this_tx = tx.clone();
1453 let subscription = cx.observe(&thread_store, move |_, _| {
1454 this_tx
1455 .try_send(acp_thread::SessionListUpdate::Refresh)
1456 .ok();
1457 });
1458 Self {
1459 thread_store,
1460 updates_tx: tx,
1461 updates_rx: rx,
1462 _subscription: subscription,
1463 }
1464 }
1465
1466 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1467 AgentSessionInfo {
1468 session_id: entry.id,
1469 cwd: None,
1470 title: Some(entry.title),
1471 updated_at: Some(entry.updated_at),
1472 meta: None,
1473 }
1474 }
1475
1476 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1477 &self.thread_store
1478 }
1479}
1480
1481impl AgentSessionList for NativeAgentSessionList {
1482 fn list_sessions(
1483 &self,
1484 _request: AgentSessionListRequest,
1485 cx: &mut App,
1486 ) -> Task<Result<AgentSessionListResponse>> {
1487 let sessions = self
1488 .thread_store
1489 .read(cx)
1490 .entries()
1491 .map(Self::to_session_info)
1492 .collect();
1493 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1494 }
1495
1496 fn supports_delete(&self) -> bool {
1497 true
1498 }
1499
1500 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1501 self.thread_store
1502 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1503 }
1504
1505 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1506 self.thread_store
1507 .update(cx, |store, cx| store.delete_threads(cx))
1508 }
1509
1510 fn watch(
1511 &self,
1512 _cx: &mut App,
1513 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1514 Some(self.updates_rx.clone())
1515 }
1516
1517 fn notify_refresh(&self) {
1518 self.updates_tx
1519 .try_send(acp_thread::SessionListUpdate::Refresh)
1520 .ok();
1521 }
1522
1523 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1524 self
1525 }
1526}
1527
1528struct NativeAgentSessionTruncate {
1529 thread: Entity<Thread>,
1530 acp_thread: WeakEntity<AcpThread>,
1531}
1532
1533impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1534 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1535 match self.thread.update(cx, |thread, cx| {
1536 thread.truncate(message_id.clone(), cx)?;
1537 Ok(thread.latest_token_usage())
1538 }) {
1539 Ok(usage) => {
1540 self.acp_thread
1541 .update(cx, |thread, cx| {
1542 thread.update_token_usage(usage, cx);
1543 })
1544 .ok();
1545 Task::ready(Ok(()))
1546 }
1547 Err(error) => Task::ready(Err(error)),
1548 }
1549 }
1550}
1551
1552struct NativeAgentSessionRetry {
1553 connection: NativeAgentConnection,
1554 session_id: acp::SessionId,
1555}
1556
1557impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1558 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1559 self.connection
1560 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1561 thread.update(cx, |thread, cx| thread.resume(cx))
1562 })
1563 }
1564}
1565
1566struct NativeAgentSessionSetTitle {
1567 thread: Entity<Thread>,
1568}
1569
1570impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1571 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1572 self.thread
1573 .update(cx, |thread, cx| thread.set_title(title, cx));
1574 Task::ready(Ok(()))
1575 }
1576}
1577
1578pub struct NativeThreadEnvironment {
1579 agent: WeakEntity<NativeAgent>,
1580 acp_thread: WeakEntity<AcpThread>,
1581}
1582
1583impl NativeThreadEnvironment {
1584 pub(crate) fn create_subagent_thread(
1585 agent: WeakEntity<NativeAgent>,
1586 parent_thread_entity: Entity<Thread>,
1587 label: String,
1588 initial_prompt: String,
1589 timeout: Option<Duration>,
1590 cx: &mut App,
1591 ) -> Result<Rc<dyn SubagentHandle>> {
1592 let parent_thread = parent_thread_entity.read(cx);
1593 let current_depth = parent_thread.depth();
1594
1595 if current_depth >= MAX_SUBAGENT_DEPTH {
1596 return Err(anyhow!(
1597 "Maximum subagent depth ({}) reached",
1598 MAX_SUBAGENT_DEPTH
1599 ));
1600 }
1601
1602 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1603 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1604 thread.set_title(label.into(), cx);
1605 thread
1606 });
1607
1608 let session_id = subagent_thread.read(cx).id().clone();
1609
1610 let acp_thread = agent.update(cx, |agent, cx| {
1611 agent.register_session(subagent_thread.clone(), cx)
1612 })?;
1613
1614 Self::prompt_subagent(
1615 session_id,
1616 subagent_thread,
1617 acp_thread,
1618 parent_thread_entity,
1619 initial_prompt,
1620 timeout,
1621 cx,
1622 )
1623 }
1624
1625 pub(crate) fn resume_subagent_thread(
1626 agent: WeakEntity<NativeAgent>,
1627 parent_thread_entity: Entity<Thread>,
1628 session_id: acp::SessionId,
1629 follow_up_prompt: String,
1630 timeout: Option<Duration>,
1631 cx: &mut App,
1632 ) -> Result<Rc<dyn SubagentHandle>> {
1633 let (subagent_thread, acp_thread) = agent.update(cx, |agent, _cx| {
1634 let session = agent
1635 .sessions
1636 .get(&session_id)
1637 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1638 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1639 })??;
1640
1641 Self::prompt_subagent(
1642 session_id,
1643 subagent_thread,
1644 acp_thread,
1645 parent_thread_entity,
1646 follow_up_prompt,
1647 timeout,
1648 cx,
1649 )
1650 }
1651
1652 fn prompt_subagent(
1653 session_id: acp::SessionId,
1654 subagent_thread: Entity<Thread>,
1655 acp_thread: Entity<acp_thread::AcpThread>,
1656 parent_thread_entity: Entity<Thread>,
1657 prompt: String,
1658 timeout: Option<Duration>,
1659 cx: &mut App,
1660 ) -> Result<Rc<dyn SubagentHandle>> {
1661 parent_thread_entity.update(cx, |parent_thread, _cx| {
1662 parent_thread.register_running_subagent(subagent_thread.downgrade())
1663 });
1664
1665 let task = acp_thread.update(cx, |acp_thread, cx| {
1666 acp_thread.send(vec![prompt.into()], cx)
1667 });
1668
1669 let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1670 let wait_for_prompt_to_complete = cx
1671 .background_spawn(async move {
1672 if let Some(timer) = timeout_timer {
1673 futures::select! {
1674 _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1675 response = task.fuse() => {
1676 let response = response.log_err().flatten();
1677 if response.is_some_and(|response| {
1678 response.stop_reason == acp::StopReason::Cancelled
1679 })
1680 {
1681 SubagentInitialPromptResult::Cancelled
1682 } else {
1683 SubagentInitialPromptResult::Completed
1684 }
1685 },
1686 }
1687 } else {
1688 let response = task.await.log_err().flatten();
1689 if response
1690 .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
1691 {
1692 SubagentInitialPromptResult::Cancelled
1693 } else {
1694 SubagentInitialPromptResult::Completed
1695 }
1696 }
1697 })
1698 .shared();
1699
1700 Ok(Rc::new(NativeSubagentHandle {
1701 session_id,
1702 subagent_thread,
1703 parent_thread: parent_thread_entity.downgrade(),
1704 wait_for_prompt_to_complete,
1705 }) as _)
1706 }
1707}
1708
1709impl ThreadEnvironment for NativeThreadEnvironment {
1710 fn create_terminal(
1711 &self,
1712 command: String,
1713 cwd: Option<PathBuf>,
1714 output_byte_limit: Option<u64>,
1715 cx: &mut AsyncApp,
1716 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1717 let task = self.acp_thread.update(cx, |thread, cx| {
1718 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1719 });
1720
1721 let acp_thread = self.acp_thread.clone();
1722 cx.spawn(async move |cx| {
1723 let terminal = task?.await?;
1724
1725 let (drop_tx, drop_rx) = oneshot::channel();
1726 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1727
1728 cx.spawn(async move |cx| {
1729 drop_rx.await.ok();
1730 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1731 })
1732 .detach();
1733
1734 let handle = AcpTerminalHandle {
1735 terminal,
1736 _drop_tx: Some(drop_tx),
1737 };
1738
1739 Ok(Rc::new(handle) as _)
1740 })
1741 }
1742
1743 fn create_subagent(
1744 &self,
1745 parent_thread_entity: Entity<Thread>,
1746 label: String,
1747 initial_prompt: String,
1748 timeout: Option<Duration>,
1749 cx: &mut App,
1750 ) -> Result<Rc<dyn SubagentHandle>> {
1751 Self::create_subagent_thread(
1752 self.agent.clone(),
1753 parent_thread_entity,
1754 label,
1755 initial_prompt,
1756 timeout,
1757 cx,
1758 )
1759 }
1760
1761 fn resume_subagent(
1762 &self,
1763 parent_thread_entity: Entity<Thread>,
1764 session_id: acp::SessionId,
1765 follow_up_prompt: String,
1766 timeout: Option<Duration>,
1767 cx: &mut App,
1768 ) -> Result<Rc<dyn SubagentHandle>> {
1769 Self::resume_subagent_thread(
1770 self.agent.clone(),
1771 parent_thread_entity,
1772 session_id,
1773 follow_up_prompt,
1774 timeout,
1775 cx,
1776 )
1777 }
1778}
1779
1780#[derive(Debug, Clone, Copy)]
1781enum SubagentInitialPromptResult {
1782 Completed,
1783 Timeout,
1784 Cancelled,
1785}
1786
1787pub struct NativeSubagentHandle {
1788 session_id: acp::SessionId,
1789 parent_thread: WeakEntity<Thread>,
1790 subagent_thread: Entity<Thread>,
1791 wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
1792}
1793
1794impl SubagentHandle for NativeSubagentHandle {
1795 fn id(&self) -> acp::SessionId {
1796 self.session_id.clone()
1797 }
1798
1799 fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
1800 let thread = self.subagent_thread.clone();
1801 let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1802
1803 let subagent_session_id = self.session_id.clone();
1804 let parent_thread = self.parent_thread.clone();
1805
1806 cx.spawn(async move |cx| {
1807 let result = match wait_for_prompt.await {
1808 SubagentInitialPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1809 thread
1810 .last_message()
1811 .map(|m| m.to_markdown())
1812 .context("No response from subagent")
1813 }),
1814 SubagentInitialPromptResult::Timeout => {
1815 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1816 Err(anyhow!("The time to complete the task was exceeded."))
1817 }
1818 SubagentInitialPromptResult::Cancelled => Err(anyhow!("User cancelled")),
1819 };
1820
1821 parent_thread
1822 .update(cx, |parent_thread, cx| {
1823 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1824 })
1825 .ok();
1826
1827 result
1828 })
1829 }
1830}
1831
1832pub struct AcpTerminalHandle {
1833 terminal: Entity<acp_thread::Terminal>,
1834 _drop_tx: Option<oneshot::Sender<()>>,
1835}
1836
1837impl TerminalHandle for AcpTerminalHandle {
1838 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1839 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1840 }
1841
1842 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1843 Ok(self
1844 .terminal
1845 .read_with(cx, |term, _cx| term.wait_for_exit()))
1846 }
1847
1848 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1849 Ok(self
1850 .terminal
1851 .read_with(cx, |term, cx| term.current_output(cx)))
1852 }
1853
1854 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1855 cx.update(|cx| {
1856 self.terminal.update(cx, |terminal, cx| {
1857 terminal.kill(cx);
1858 });
1859 });
1860 Ok(())
1861 }
1862
1863 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1864 Ok(self
1865 .terminal
1866 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1867 }
1868}
1869
1870#[cfg(test)]
1871mod internal_tests {
1872 use super::*;
1873 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1874 use fs::FakeFs;
1875 use gpui::TestAppContext;
1876 use indoc::formatdoc;
1877 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1878 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1879 use serde_json::json;
1880 use settings::SettingsStore;
1881 use util::{path, rel_path::rel_path};
1882
1883 #[gpui::test]
1884 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1885 init_test(cx);
1886 let fs = FakeFs::new(cx.executor());
1887 fs.insert_tree(
1888 "/",
1889 json!({
1890 "a": {}
1891 }),
1892 )
1893 .await;
1894 let project = Project::test(fs.clone(), [], cx).await;
1895 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1896 let agent = NativeAgent::new(
1897 project.clone(),
1898 thread_store,
1899 Templates::new(),
1900 None,
1901 fs.clone(),
1902 &mut cx.to_async(),
1903 )
1904 .await
1905 .unwrap();
1906 agent.read_with(cx, |agent, cx| {
1907 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1908 });
1909
1910 let worktree = project
1911 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1912 .await
1913 .unwrap();
1914 cx.run_until_parked();
1915 agent.read_with(cx, |agent, cx| {
1916 assert_eq!(
1917 agent.project_context.read(cx).worktrees,
1918 vec![WorktreeContext {
1919 root_name: "a".into(),
1920 abs_path: Path::new("/a").into(),
1921 rules_file: None
1922 }]
1923 )
1924 });
1925
1926 // Creating `/a/.rules` updates the project context.
1927 fs.insert_file("/a/.rules", Vec::new()).await;
1928 cx.run_until_parked();
1929 agent.read_with(cx, |agent, cx| {
1930 let rules_entry = worktree
1931 .read(cx)
1932 .entry_for_path(rel_path(".rules"))
1933 .unwrap();
1934 assert_eq!(
1935 agent.project_context.read(cx).worktrees,
1936 vec![WorktreeContext {
1937 root_name: "a".into(),
1938 abs_path: Path::new("/a").into(),
1939 rules_file: Some(RulesFileContext {
1940 path_in_worktree: rel_path(".rules").into(),
1941 text: "".into(),
1942 project_entry_id: rules_entry.id.to_usize()
1943 })
1944 }]
1945 )
1946 });
1947 }
1948
1949 #[gpui::test]
1950 async fn test_listing_models(cx: &mut TestAppContext) {
1951 init_test(cx);
1952 let fs = FakeFs::new(cx.executor());
1953 fs.insert_tree("/", json!({ "a": {} })).await;
1954 let project = Project::test(fs.clone(), [], cx).await;
1955 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1956 let connection = NativeAgentConnection(
1957 NativeAgent::new(
1958 project.clone(),
1959 thread_store,
1960 Templates::new(),
1961 None,
1962 fs.clone(),
1963 &mut cx.to_async(),
1964 )
1965 .await
1966 .unwrap(),
1967 );
1968
1969 // Create a thread/session
1970 let acp_thread = cx
1971 .update(|cx| {
1972 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1973 })
1974 .await
1975 .unwrap();
1976
1977 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1978
1979 let models = cx
1980 .update(|cx| {
1981 connection
1982 .model_selector(&session_id)
1983 .unwrap()
1984 .list_models(cx)
1985 })
1986 .await
1987 .unwrap();
1988
1989 let acp_thread::AgentModelList::Grouped(models) = models else {
1990 panic!("Unexpected model group");
1991 };
1992 assert_eq!(
1993 models,
1994 IndexMap::from_iter([(
1995 AgentModelGroupName("Fake".into()),
1996 vec![AgentModelInfo {
1997 id: acp::ModelId::new("fake/fake"),
1998 name: "Fake".into(),
1999 description: None,
2000 icon: Some(acp_thread::AgentModelIcon::Named(
2001 ui::IconName::ZedAssistant
2002 )),
2003 is_latest: false,
2004 cost: None,
2005 }]
2006 )])
2007 );
2008 }
2009
2010 #[gpui::test]
2011 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2012 init_test(cx);
2013 let fs = FakeFs::new(cx.executor());
2014 fs.create_dir(paths::settings_file().parent().unwrap())
2015 .await
2016 .unwrap();
2017 fs.insert_file(
2018 paths::settings_file(),
2019 json!({
2020 "agent": {
2021 "default_model": {
2022 "provider": "foo",
2023 "model": "bar"
2024 }
2025 }
2026 })
2027 .to_string()
2028 .into_bytes(),
2029 )
2030 .await;
2031 let project = Project::test(fs.clone(), [], cx).await;
2032
2033 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2034
2035 // Create the agent and connection
2036 let agent = NativeAgent::new(
2037 project.clone(),
2038 thread_store,
2039 Templates::new(),
2040 None,
2041 fs.clone(),
2042 &mut cx.to_async(),
2043 )
2044 .await
2045 .unwrap();
2046 let connection = NativeAgentConnection(agent.clone());
2047
2048 // Create a thread/session
2049 let acp_thread = cx
2050 .update(|cx| {
2051 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2052 })
2053 .await
2054 .unwrap();
2055
2056 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2057
2058 // Select a model
2059 let selector = connection.model_selector(&session_id).unwrap();
2060 let model_id = acp::ModelId::new("fake/fake");
2061 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2062 .await
2063 .unwrap();
2064
2065 // Verify the thread has the selected model
2066 agent.read_with(cx, |agent, _| {
2067 let session = agent.sessions.get(&session_id).unwrap();
2068 session.thread.read_with(cx, |thread, _| {
2069 assert_eq!(thread.model().unwrap().id().0, "fake");
2070 });
2071 });
2072
2073 cx.run_until_parked();
2074
2075 // Verify settings file was updated
2076 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2077 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2078
2079 // Check that the agent settings contain the selected model
2080 assert_eq!(
2081 settings_json["agent"]["default_model"]["model"],
2082 json!("fake")
2083 );
2084 assert_eq!(
2085 settings_json["agent"]["default_model"]["provider"],
2086 json!("fake")
2087 );
2088
2089 // Register a thinking model and select it.
2090 cx.update(|cx| {
2091 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2092 "fake-corp",
2093 "fake-thinking",
2094 "Fake Thinking",
2095 true,
2096 ));
2097 let thinking_provider = Arc::new(
2098 FakeLanguageModelProvider::new(
2099 LanguageModelProviderId::from("fake-corp".to_string()),
2100 LanguageModelProviderName::from("Fake Corp".to_string()),
2101 )
2102 .with_models(vec![thinking_model]),
2103 );
2104 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2105 registry.register_provider(thinking_provider, cx);
2106 });
2107 });
2108 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2109
2110 let selector = connection.model_selector(&session_id).unwrap();
2111 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2112 .await
2113 .unwrap();
2114 cx.run_until_parked();
2115
2116 // Verify enable_thinking was written to settings as true.
2117 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2118 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2119 assert_eq!(
2120 settings_json["agent"]["default_model"]["enable_thinking"],
2121 json!(true),
2122 "selecting a thinking model should persist enable_thinking: true to settings"
2123 );
2124 }
2125
2126 #[gpui::test]
2127 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2128 init_test(cx);
2129 let fs = FakeFs::new(cx.executor());
2130 fs.create_dir(paths::settings_file().parent().unwrap())
2131 .await
2132 .unwrap();
2133 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2134 let project = Project::test(fs.clone(), [], cx).await;
2135
2136 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2137 let agent = NativeAgent::new(
2138 project.clone(),
2139 thread_store,
2140 Templates::new(),
2141 None,
2142 fs.clone(),
2143 &mut cx.to_async(),
2144 )
2145 .await
2146 .unwrap();
2147 let connection = NativeAgentConnection(agent.clone());
2148
2149 let acp_thread = cx
2150 .update(|cx| {
2151 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2152 })
2153 .await
2154 .unwrap();
2155 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2156
2157 // Register a second provider with a thinking model.
2158 cx.update(|cx| {
2159 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2160 "fake-corp",
2161 "fake-thinking",
2162 "Fake Thinking",
2163 true,
2164 ));
2165 let thinking_provider = Arc::new(
2166 FakeLanguageModelProvider::new(
2167 LanguageModelProviderId::from("fake-corp".to_string()),
2168 LanguageModelProviderName::from("Fake Corp".to_string()),
2169 )
2170 .with_models(vec![thinking_model]),
2171 );
2172 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2173 registry.register_provider(thinking_provider, cx);
2174 });
2175 });
2176 // Refresh the agent's model list so it picks up the new provider.
2177 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2178
2179 // Thread starts with thinking_enabled = false (the default).
2180 agent.read_with(cx, |agent, _| {
2181 let session = agent.sessions.get(&session_id).unwrap();
2182 session.thread.read_with(cx, |thread, _| {
2183 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2184 });
2185 });
2186
2187 // Select the thinking model via select_model.
2188 let selector = connection.model_selector(&session_id).unwrap();
2189 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2190 .await
2191 .unwrap();
2192
2193 // select_model should have enabled thinking based on the model's supports_thinking().
2194 agent.read_with(cx, |agent, _| {
2195 let session = agent.sessions.get(&session_id).unwrap();
2196 session.thread.read_with(cx, |thread, _| {
2197 assert!(
2198 thread.thinking_enabled(),
2199 "select_model should enable thinking when model supports it"
2200 );
2201 });
2202 });
2203
2204 // Switch back to the non-thinking model.
2205 let selector = connection.model_selector(&session_id).unwrap();
2206 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2207 .await
2208 .unwrap();
2209
2210 // select_model should have disabled thinking.
2211 agent.read_with(cx, |agent, _| {
2212 let session = agent.sessions.get(&session_id).unwrap();
2213 session.thread.read_with(cx, |thread, _| {
2214 assert!(
2215 !thread.thinking_enabled(),
2216 "select_model should disable thinking when model does not support it"
2217 );
2218 });
2219 });
2220 }
2221
    /// Verifies that `thinking_enabled` survives a close/reopen round trip.
    ///
    /// A thinking-capable model from a second fake provider ("fake-corp") is
    /// selected for a new session. One message exchange is completed so the
    /// thread gets persisted, and the session is closed. The thread is then
    /// reopened via `NativeAgent::open_thread`, and the reloaded thread must
    /// still report `thinking_enabled() == true`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        // The provider gets a clone so the test keeps its own handle to
        // `thinking_model` and can drive its completion stream later.
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the newly registered provider is visible
        // to the model selector.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Model ids passed to the selector take the form "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned rather than awaited so the test can feed the
        // fake model's response while the send is still in flight.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        // Run until idle so the completion request reaches the fake model before its
        // stream is fed below.
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles to the old entities before reopening.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // Keep the reopened session's handle alive until every assertion has run.
        drop(reloaded_acp_thread);
    }
2327
    /// Verifies that the selected model itself is restored when a persisted thread
    /// is reopened, rather than falling back to the default model.
    ///
    /// It uses a fake model whose id ("custom-model-id") differs from its display
    /// name, so the test would catch a save/restore path that keys on the name
    /// instead of the id.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // The last argument is `false`: this model does not support thinking.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        // The provider gets a clone so the test can still drive `model`'s
        // completion stream.
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so the newly registered provider is visible.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selection is by "<provider id>/<model id>", not by display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // Spawn the send, then run until idle so the completion request reaches the
        // fake model before its response is streamed.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles to the old entities before reopening.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Keep the reopened session's handle alive until every assertion has run.
        drop(reloaded_acp_thread);
    }
2438
    /// End-to-end persistence test for native agent threads.
    ///
    /// It checks three things:
    /// 1. A thread with no messages is not written to the thread store, even
    ///    after its model and summarization model are changed.
    /// 2. After one user/assistant exchange (the user message contains a file
    ///    mention), the thread is saved under its session id. Its title is the
    ///    text produced by the summarization model.
    /// 3. After the session is closed, `open_thread` reloads the thread and it
    ///    renders to the same markdown as before.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // `model` answers the user message. `summary_model` produces the thread
        // title that is asserted against the store below.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message embeds a file mention as an ACP resource link,
        // sandwiched between two text chunks.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send and run until idle so the request reaches `model` before
        // its response is streamed.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        // Let the turn finish so the summarization request is issued before the
        // summary stream is fed.
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles to the old entities before checking that the
        // agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The stored entry's title is the summarization model's output.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reloaded thread must render to exactly the same markdown as before.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2577
2578 fn thread_entries(
2579 thread_store: &Entity<ThreadStore>,
2580 cx: &mut TestAppContext,
2581 ) -> Vec<(acp::SessionId, String)> {
2582 thread_store.read_with(cx, |store, _| {
2583 store
2584 .entries()
2585 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2586 .collect::<Vec<_>>()
2587 })
2588 }
2589
2590 fn init_test(cx: &mut TestAppContext) {
2591 env_logger::try_init().ok();
2592 cx.update(|cx| {
2593 let settings_store = SettingsStore::test(cx);
2594 cx.set_global(settings_store);
2595
2596 LanguageModelRegistry::test(cx);
2597 });
2598 }
2599}
2600
2601fn mcp_message_content_to_acp_content_block(
2602 content: context_server::types::MessageContent,
2603) -> acp::ContentBlock {
2604 match content {
2605 context_server::types::MessageContent::Text {
2606 text,
2607 annotations: _,
2608 } => text.into(),
2609 context_server::types::MessageContent::Image {
2610 data,
2611 mime_type,
2612 annotations: _,
2613 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2614 context_server::types::MessageContent::Audio {
2615 data,
2616 mime_type,
2617 annotations: _,
2618 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2619 context_server::types::MessageContent::Resource {
2620 resource,
2621 annotations: _,
2622 } => {
2623 let mut link =
2624 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2625 if let Some(mime_type) = resource.mime_type {
2626 link = link.mime_type(mime_type);
2627 }
2628 acp::ContentBlock::ResourceLink(link)
2629 }
2630 }
2631}