1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use util::ResultExt;
54use util::rel_path::RelPath;
55
/// A timestamped snapshot of a project's worktrees.
///
/// This type derives `Serialize`/`Deserialize`, so field order determines the
/// order of fields in the serialized form; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp associated with the snapshot (presumably capture time — TODO confirm
    /// against the code that builds it).
    pub timestamp: DateTime<Utc>,
}
61
/// Describes a failure to load a worktree's rules file.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt`; at present
/// `NativeAgent::build_project_context` discards it instead of showing it to the user.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the `Display` form of the error).
    pub message: SharedString,
}
65
/// Holds both the internal Thread and the AcpThread for a session
///
/// Field order is also drop order; keep it stable.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// The most recent database save of `thread`. Each new save assigns a fresh
    /// task here, dropping the previous one.
    pending_save: Task<()>,
    /// Subscriptions on `thread` (title updates, token usage, change observation),
    /// stored so they remain active for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
75
/// Cache of the language models offered by authenticated providers, together
/// with a grouped list ("Recommended" first, then one group per provider) for
/// model pickers.
///
/// Field order is also drop order; keep it stable.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`, see `model_id`)
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half of a watch channel that `refresh_list` signals after every refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender half of that watch channel.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; stored so it is not
    /// dropped (and therefore cancelled) while this struct is alive.
    _authenticate_all_providers_task: Task<()>,
}
85
86impl LanguageModels {
87 fn new(cx: &mut App) -> Self {
88 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
89
90 let mut this = Self {
91 models: HashMap::default(),
92 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
93 refresh_models_rx,
94 refresh_models_tx,
95 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
96 };
97 this.refresh_list(cx);
98 this
99 }
100
101 fn refresh_list(&mut self, cx: &App) {
102 let providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .visible_providers()
105 .into_iter()
106 .filter(|provider| provider.is_authenticated(cx))
107 .collect::<Vec<_>>();
108
109 let mut language_model_list = IndexMap::default();
110 let mut recommended_models = HashSet::default();
111
112 let mut recommended = Vec::new();
113 for provider in &providers {
114 for model in provider.recommended_models(cx) {
115 recommended_models.insert((model.provider_id(), model.id()));
116 recommended.push(Self::map_language_model_to_info(&model, provider));
117 }
118 }
119 if !recommended.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName("Recommended".into()),
122 recommended,
123 );
124 }
125
126 let mut models = HashMap::default();
127 for provider in providers {
128 let mut provider_models = Vec::new();
129 for model in provider.provided_models(cx) {
130 let model_info = Self::map_language_model_to_info(&model, &provider);
131 let model_id = model_info.id.clone();
132 provider_models.push(model_info);
133 models.insert(model_id, model);
134 }
135 if !provider_models.is_empty() {
136 language_model_list.insert(
137 acp_thread::AgentModelGroupName(provider.name().0.clone()),
138 provider_models,
139 );
140 }
141 }
142
143 self.models = models;
144 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
145 self.refresh_models_tx.send(()).ok();
146 }
147
148 fn watch(&self) -> watch::Receiver<()> {
149 self.refresh_models_rx.clone()
150 }
151
152 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
153 self.models.get(model_id).cloned()
154 }
155
156 fn map_language_model_to_info(
157 model: &Arc<dyn LanguageModel>,
158 provider: &Arc<dyn LanguageModelProvider>,
159 ) -> acp_thread::AgentModelInfo {
160 acp_thread::AgentModelInfo {
161 id: Self::model_id(model),
162 name: model.name().0,
163 description: None,
164 icon: Some(match provider.icon() {
165 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
166 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
167 }),
168 is_latest: model.is_latest(),
169 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
170 }
171 }
172
173 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
174 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
175 }
176
177 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
178 let authenticate_all_providers = LanguageModelRegistry::global(cx)
179 .read(cx)
180 .visible_providers()
181 .iter()
182 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
183 .collect::<Vec<_>>();
184
185 cx.background_spawn(async move {
186 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
187 if let Err(err) = authenticate_task.await {
188 match err {
189 language_model::AuthenticateError::CredentialsNotFound => {
190 // Since we're authenticating these providers in the
191 // background for the purposes of populating the
192 // language selector, we don't care about providers
193 // where the credentials are not found.
194 }
195 language_model::AuthenticateError::ConnectionRefused => {
196 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
197 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
198 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
199 }
200 _ => {
201 // Some providers have noisy failure states that we
202 // don't want to spam the logs with every time the
203 // language model selector is initialized.
204 //
205 // Ideally these should have more clear failure modes
206 // that we know are safe to ignore here, like what we do
207 // with `CredentialsNotFound` above.
208 match provider_id.0.as_ref() {
209 "lmstudio" | "ollama" => {
210 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
211 //
212 // These fail noisily, so we don't log them.
213 }
214 "copilot_chat" => {
215 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
216 }
217 _ => {
218 log::error!(
219 "Failed to authenticate provider: {}: {err:#}",
220 provider_name.0
221 );
222 }
223 }
224 }
225 }
226 }
227 }
228 })
229 }
230}
231
/// The built-in agent.
///
/// It owns every active session plus the state shared between them: project
/// context, context-server registry, templates and the language-model cache.
/// Field order is also drop order; keep it stable.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store of persisted threads; `save_thread` reloads it after each save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sending on this triggers a rebuild of `project_context` in
    /// `maintain_project_context`. It is signalled on worktree changes,
    /// rules-file changes and prompt-store updates.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that performs those rebuilds; stored to keep it running.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context (MCP) servers; its prompts are exposed as slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store supplying default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry, context-server and prompt-store
    /// events, stored so they stay active for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
250
251impl NativeAgent {
252 pub async fn new(
253 project: Entity<Project>,
254 thread_store: Entity<ThreadStore>,
255 templates: Arc<Templates>,
256 prompt_store: Option<Entity<PromptStore>>,
257 fs: Arc<dyn Fs>,
258 cx: &mut AsyncApp,
259 ) -> Result<Entity<NativeAgent>> {
260 log::debug!("Creating new NativeAgent");
261
262 let project_context = cx
263 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
264 .await;
265
266 Ok(cx.new(|cx| {
267 let context_server_store = project.read(cx).context_server_store();
268 let context_server_registry =
269 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
270
271 let mut subscriptions = vec![
272 cx.subscribe(&project, Self::handle_project_event),
273 cx.subscribe(
274 &LanguageModelRegistry::global(cx),
275 Self::handle_models_updated_event,
276 ),
277 cx.subscribe(
278 &context_server_store,
279 Self::handle_context_server_store_updated,
280 ),
281 cx.subscribe(
282 &context_server_registry,
283 Self::handle_context_server_registry_event,
284 ),
285 ];
286 if let Some(prompt_store) = prompt_store.as_ref() {
287 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
288 }
289
290 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
291 watch::channel(());
292 Self {
293 sessions: HashMap::default(),
294 thread_store,
295 project_context: cx.new(|_| project_context),
296 project_context_needs_refresh: project_context_needs_refresh_tx,
297 _maintain_project_context: cx.spawn(async move |this, cx| {
298 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
299 }),
300 context_server_registry,
301 templates,
302 models: LanguageModels::new(cx),
303 project,
304 prompt_store,
305 fs,
306 _subscriptions: subscriptions,
307 }
308 }))
309 }
310
311 fn new_session(
312 &mut self,
313 project: Entity<Project>,
314 cx: &mut Context<Self>,
315 ) -> Entity<AcpThread> {
316 // Create Thread
317 // Fetch default model from registry settings
318 let registry = LanguageModelRegistry::read_global(cx);
319 // Log available models for debugging
320 let available_count = registry.available_models(cx).count();
321 log::debug!("Total available models: {}", available_count);
322
323 let default_model = registry.default_model().and_then(|default_model| {
324 self.models
325 .model_from_id(&LanguageModels::model_id(&default_model.model))
326 });
327 let thread = cx.new(|cx| {
328 Thread::new(
329 project.clone(),
330 self.project_context.clone(),
331 self.context_server_registry.clone(),
332 self.templates.clone(),
333 default_model,
334 cx,
335 )
336 });
337
338 self.register_session(thread, cx)
339 }
340
    /// Wraps an internal `Thread` in a new `AcpThread` and records both as a session.
    ///
    /// Before inserting the session it sets the thread's summarization model,
    /// installs the default tools, and subscribes to the thread's events. The
    /// session is keyed by the thread's id; `HashMap::insert` replaces any existing
    /// session with the same id. Returns the new `AcpThread`.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        // Protocol calls made on the AcpThread are routed back to this agent.
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The tool environment gets weak handles only (presumably so tools don't
        // keep the agent or threads alive — TODO confirm).
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        // Mirror title and token-usage events onto the AcpThread, and persist the
        // thread every time it notifies.
        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        // Re-send the slash-command list to every session, including this new one.
        self.update_available_commands(cx);

        acp_thread
    }
407
408 pub fn models(&self) -> &LanguageModels {
409 &self.models
410 }
411
412 async fn maintain_project_context(
413 this: WeakEntity<Self>,
414 mut needs_refresh: watch::Receiver<()>,
415 cx: &mut AsyncApp,
416 ) -> Result<()> {
417 while needs_refresh.changed().await.is_ok() {
418 let project_context = this
419 .update(cx, |this, cx| {
420 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
421 })?
422 .await;
423 this.update(cx, |this, cx| {
424 this.project_context = cx.new(|_| project_context);
425 })?;
426 }
427
428 Ok(())
429 }
430
431 fn build_project_context(
432 project: &Entity<Project>,
433 prompt_store: Option<&Entity<PromptStore>>,
434 cx: &mut App,
435 ) -> Task<ProjectContext> {
436 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
437 let worktree_tasks = worktrees
438 .into_iter()
439 .map(|worktree| {
440 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
441 })
442 .collect::<Vec<_>>();
443 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
444 prompt_store.read_with(cx, |prompt_store, cx| {
445 let prompts = prompt_store.default_prompt_metadata();
446 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
447 let contents = prompt_store.load(prompt_metadata.id, cx);
448 async move { (contents.await, prompt_metadata) }
449 });
450 cx.background_spawn(future::join_all(load_tasks))
451 })
452 } else {
453 Task::ready(vec![])
454 };
455
456 cx.spawn(async move |_cx| {
457 let (worktrees, default_user_rules) =
458 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
459
460 let worktrees = worktrees
461 .into_iter()
462 .map(|(worktree, _rules_error)| {
463 // TODO: show error message
464 // if let Some(rules_error) = rules_error {
465 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
466 // }
467 worktree
468 })
469 .collect::<Vec<_>>();
470
471 let default_user_rules = default_user_rules
472 .into_iter()
473 .flat_map(|(contents, prompt_metadata)| match contents {
474 Ok(contents) => Some(UserRulesContext {
475 uuid: prompt_metadata.id.as_user()?,
476 title: prompt_metadata.title.map(|title| title.to_string()),
477 contents,
478 }),
479 Err(_err) => {
480 // TODO: show error message
481 // this.update(cx, |_, cx| {
482 // cx.emit(RulesLoadingError {
483 // message: format!("{err:?}").into(),
484 // });
485 // })
486 // .ok();
487 None
488 }
489 })
490 .collect::<Vec<_>>();
491
492 ProjectContext::new(worktrees, default_user_rules)
493 })
494 }
495
496 fn load_worktree_info_for_system_prompt(
497 worktree: Entity<Worktree>,
498 project: Entity<Project>,
499 cx: &mut App,
500 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
501 let tree = worktree.read(cx);
502 let root_name = tree.root_name_str().into();
503 let abs_path = tree.abs_path();
504
505 let mut context = WorktreeContext {
506 root_name,
507 abs_path,
508 rules_file: None,
509 };
510
511 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
512 let Some(rules_task) = rules_task else {
513 return Task::ready((context, None));
514 };
515
516 cx.spawn(async move |_| {
517 let (rules_file, rules_file_error) = match rules_task.await {
518 Ok(rules_file) => (Some(rules_file), None),
519 Err(err) => (
520 None,
521 Some(RulesLoadingError {
522 message: format!("{err}").into(),
523 }),
524 ),
525 };
526 context.rules_file = rules_file;
527 (context, rules_file_error)
528 })
529 }
530
531 fn load_worktree_rules_file(
532 worktree: Entity<Worktree>,
533 project: Entity<Project>,
534 cx: &mut App,
535 ) -> Option<Task<Result<RulesFileContext>>> {
536 let worktree = worktree.read(cx);
537 let worktree_id = worktree.id();
538 let selected_rules_file = RULES_FILE_NAMES
539 .into_iter()
540 .filter_map(|name| {
541 worktree
542 .entry_for_path(RelPath::unix(name).unwrap())
543 .filter(|entry| entry.is_file())
544 .map(|entry| entry.path.clone())
545 })
546 .next();
547
548 // Note that Cline supports `.clinerules` being a directory, but that is not currently
549 // supported. This doesn't seem to occur often in GitHub repositories.
550 selected_rules_file.map(|path_in_worktree| {
551 let project_path = ProjectPath {
552 worktree_id,
553 path: path_in_worktree.clone(),
554 };
555 let buffer_task =
556 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
557 let rope_task = cx.spawn(async move |cx| {
558 let buffer = buffer_task.await?;
559 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
560 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
561 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
562 })?;
563 anyhow::Ok((project_entry_id, rope))
564 });
565 // Build a string from the rope on a background thread.
566 cx.background_spawn(async move {
567 let (project_entry_id, rope) = rope_task.await?;
568 anyhow::Ok(RulesFileContext {
569 path_in_worktree,
570 text: rope.to_string().trim().to_string(),
571 project_entry_id: project_entry_id.to_usize(),
572 })
573 })
574 })
575 }
576
577 fn handle_thread_title_updated(
578 &mut self,
579 thread: Entity<Thread>,
580 _: &TitleUpdated,
581 cx: &mut Context<Self>,
582 ) {
583 let session_id = thread.read(cx).id();
584 let Some(session) = self.sessions.get(session_id) else {
585 return;
586 };
587 let thread = thread.downgrade();
588 let acp_thread = session.acp_thread.downgrade();
589 cx.spawn(async move |_, cx| {
590 let title = thread.read_with(cx, |thread, _| thread.title())?;
591 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
592 task.await
593 })
594 .detach_and_log_err(cx);
595 }
596
597 fn handle_thread_token_usage_updated(
598 &mut self,
599 thread: Entity<Thread>,
600 usage: &TokenUsageUpdated,
601 cx: &mut Context<Self>,
602 ) {
603 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
604 return;
605 };
606 session.acp_thread.update(cx, |acp_thread, cx| {
607 acp_thread.update_token_usage(usage.0.clone(), cx);
608 });
609 }
610
611 fn handle_project_event(
612 &mut self,
613 _project: Entity<Project>,
614 event: &project::Event,
615 _cx: &mut Context<Self>,
616 ) {
617 match event {
618 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
619 self.project_context_needs_refresh.send(()).ok();
620 }
621 project::Event::WorktreeUpdatedEntries(_, items) => {
622 if items.iter().any(|(path, _, _)| {
623 RULES_FILE_NAMES
624 .iter()
625 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
626 }) {
627 self.project_context_needs_refresh.send(()).ok();
628 }
629 }
630 _ => {}
631 }
632 }
633
634 fn handle_prompts_updated_event(
635 &mut self,
636 _prompt_store: Entity<PromptStore>,
637 _event: &prompt_store::PromptsUpdatedEvent,
638 _cx: &mut Context<Self>,
639 ) {
640 self.project_context_needs_refresh.send(()).ok();
641 }
642
643 fn handle_models_updated_event(
644 &mut self,
645 _registry: Entity<LanguageModelRegistry>,
646 _event: &language_model::Event,
647 cx: &mut Context<Self>,
648 ) {
649 self.models.refresh_list(cx);
650
651 let registry = LanguageModelRegistry::read_global(cx);
652 let default_model = registry.default_model().map(|m| m.model);
653 let summarization_model = registry.thread_summary_model().map(|m| m.model);
654
655 for session in self.sessions.values_mut() {
656 session.thread.update(cx, |thread, cx| {
657 if thread.model().is_none()
658 && let Some(model) = default_model.clone()
659 {
660 thread.set_model(model, cx);
661 cx.notify();
662 }
663 thread.set_summarization_model(summarization_model.clone(), cx);
664 });
665 }
666 }
667
668 fn handle_context_server_store_updated(
669 &mut self,
670 _store: Entity<project::context_server_store::ContextServerStore>,
671 _event: &project::context_server_store::ServerStatusChangedEvent,
672 cx: &mut Context<Self>,
673 ) {
674 self.update_available_commands(cx);
675 }
676
677 fn handle_context_server_registry_event(
678 &mut self,
679 _registry: Entity<ContextServerRegistry>,
680 event: &ContextServerRegistryEvent,
681 cx: &mut Context<Self>,
682 ) {
683 match event {
684 ContextServerRegistryEvent::ToolsChanged => {}
685 ContextServerRegistryEvent::PromptsChanged => {
686 self.update_available_commands(cx);
687 }
688 }
689 }
690
691 fn update_available_commands(&self, cx: &mut Context<Self>) {
692 let available_commands = self.build_available_commands(cx);
693 for session in self.sessions.values() {
694 session.acp_thread.update(cx, |thread, cx| {
695 thread
696 .handle_session_update(
697 acp::SessionUpdate::AvailableCommandsUpdate(
698 acp::AvailableCommandsUpdate::new(available_commands.clone()),
699 ),
700 cx,
701 )
702 .log_err();
703 });
704 }
705 }
706
707 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
708 let registry = self.context_server_registry.read(cx);
709
710 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
711 for context_server_prompt in registry.prompts() {
712 *prompt_name_counts
713 .entry(context_server_prompt.prompt.name.as_str())
714 .or_insert(0) += 1;
715 }
716
717 registry
718 .prompts()
719 .flat_map(|context_server_prompt| {
720 let prompt = &context_server_prompt.prompt;
721
722 let should_prefix = prompt_name_counts
723 .get(prompt.name.as_str())
724 .copied()
725 .unwrap_or(0)
726 > 1;
727
728 let name = if should_prefix {
729 format!("{}.{}", context_server_prompt.server_id, prompt.name)
730 } else {
731 prompt.name.clone()
732 };
733
734 let mut command = acp::AvailableCommand::new(
735 name,
736 prompt.description.clone().unwrap_or_default(),
737 );
738
739 match prompt.arguments.as_deref() {
740 Some([arg]) => {
741 let hint = format!("<{}>", arg.name);
742
743 command = command.input(acp::AvailableCommandInput::Unstructured(
744 acp::UnstructuredCommandInput::new(hint),
745 ));
746 }
747 Some([]) | None => {}
748 Some(_) => {
749 // skip >1 argument commands since we don't support them yet
750 return None;
751 }
752 }
753
754 Some(command)
755 })
756 .collect()
757 }
758
759 pub fn load_thread(
760 &mut self,
761 id: acp::SessionId,
762 cx: &mut Context<Self>,
763 ) -> Task<Result<Entity<Thread>>> {
764 let database_future = ThreadsDatabase::connect(cx);
765 cx.spawn(async move |this, cx| {
766 let database = database_future.await.map_err(|err| anyhow!(err))?;
767 let db_thread = database
768 .load_thread(id.clone())
769 .await?
770 .with_context(|| format!("no thread found with ID: {id:?}"))?;
771
772 this.update(cx, |this, cx| {
773 let summarization_model = LanguageModelRegistry::read_global(cx)
774 .thread_summary_model()
775 .map(|c| c.model);
776
777 cx.new(|cx| {
778 let mut thread = Thread::from_db(
779 id.clone(),
780 db_thread,
781 this.project.clone(),
782 this.project_context.clone(),
783 this.context_server_registry.clone(),
784 this.templates.clone(),
785 cx,
786 );
787 thread.set_summarization_model(summarization_model, cx);
788 thread
789 })
790 })
791 })
792 }
793
794 pub fn open_thread(
795 &mut self,
796 id: acp::SessionId,
797 cx: &mut Context<Self>,
798 ) -> Task<Result<Entity<AcpThread>>> {
799 if let Some(session) = self.sessions.get(&id) {
800 return Task::ready(Ok(session.acp_thread.clone()));
801 }
802
803 let task = self.load_thread(id, cx);
804 cx.spawn(async move |this, cx| {
805 let thread = task.await?;
806 let acp_thread =
807 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
808 let events = thread.update(cx, |thread, cx| thread.replay(cx));
809 cx.update(|cx| {
810 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
811 })
812 .await?;
813 Ok(acp_thread)
814 })
815 }
816
817 pub fn thread_summary(
818 &mut self,
819 id: acp::SessionId,
820 cx: &mut Context<Self>,
821 ) -> Task<Result<SharedString>> {
822 let thread = self.open_thread(id.clone(), cx);
823 cx.spawn(async move |this, cx| {
824 let acp_thread = thread.await?;
825 let result = this
826 .update(cx, |this, cx| {
827 this.sessions
828 .get(&id)
829 .unwrap()
830 .thread
831 .update(cx, |thread, cx| thread.summary(cx))
832 })?
833 .await
834 .context("Failed to generate summary")?;
835 drop(acp_thread);
836 Ok(result)
837 })
838 }
839
840 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
841 if thread.read(cx).is_empty() {
842 return;
843 }
844
845 let database_future = ThreadsDatabase::connect(cx);
846 let (id, db_thread) =
847 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
848 let Some(session) = self.sessions.get_mut(&id) else {
849 return;
850 };
851 let thread_store = self.thread_store.clone();
852 session.pending_save = cx.spawn(async move |_, cx| {
853 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
854 return;
855 };
856 let db_thread = db_thread.await;
857 database.save_thread(id, db_thread).await.log_err();
858 thread_store.update(cx, |store, cx| store.reload(cx));
859 });
860 }
861
    /// Runs MCP prompt `prompt_name` from context server `server_id` in session
    /// `session_id`, then continues the turn.
    ///
    /// Every block of `original_content` after the first is pushed to the thread
    /// as user message `message_id`. The first block presumably holds the slash
    /// command itself — TODO confirm against the caller. Each message returned by
    /// the prompt is then mirrored into both the `AcpThread` and the internal
    /// `Thread`. If the last prompt message came from the user, the thread sends
    /// it; otherwise the thread resumes after the assistant content.
    ///
    /// Fails if the prompt cannot be fetched or the session does not exist.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Role of the most recent prompt message. When the prompt has no
            // messages, the user content pushed below is sent as-is.
            let mut last_is_user = true;

            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Each message is shown in the AcpThread (via the `_with_indent`
                // variants) and recorded in the Thread.
                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            // Forward the turn's events to the AcpThread until the turn stops.
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
957}
958
/// Wrapper struct that implements the AgentConnection trait
///
/// A thin handle around the `NativeAgent` entity. Cloning it clones the entity
/// handle, not the agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
962
963impl NativeAgentConnection {
964 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
965 self.0
966 .read(cx)
967 .sessions
968 .get(session_id)
969 .map(|session| session.thread.clone())
970 }
971
972 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
973 self.0.update(cx, |this, cx| this.load_thread(id, cx))
974 }
975
976 fn run_turn(
977 &self,
978 session_id: acp::SessionId,
979 cx: &mut App,
980 f: impl 'static
981 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
982 ) -> Task<Result<acp::PromptResponse>> {
983 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
984 agent
985 .sessions
986 .get_mut(&session_id)
987 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
988 }) else {
989 return Task::ready(Err(anyhow!("Session not found")));
990 };
991 log::debug!("Found session for: {}", session_id);
992
993 let response_stream = match f(thread, cx) {
994 Ok(stream) => stream,
995 Err(err) => return Task::ready(Err(err)),
996 };
997 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
998 }
999
1000 fn handle_thread_events(
1001 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1002 acp_thread: WeakEntity<AcpThread>,
1003 cx: &App,
1004 ) -> Task<Result<acp::PromptResponse>> {
1005 cx.spawn(async move |cx| {
1006 // Handle response stream and forward to session.acp_thread
1007 while let Some(result) = events.next().await {
1008 match result {
1009 Ok(event) => {
1010 log::trace!("Received completion event: {:?}", event);
1011
1012 match event {
1013 ThreadEvent::UserMessage(message) => {
1014 acp_thread.update(cx, |thread, cx| {
1015 for content in message.content {
1016 thread.push_user_content_block(
1017 Some(message.id.clone()),
1018 content.into(),
1019 cx,
1020 );
1021 }
1022 })?;
1023 }
1024 ThreadEvent::AgentText(text) => {
1025 acp_thread.update(cx, |thread, cx| {
1026 thread.push_assistant_content_block(text.into(), false, cx)
1027 })?;
1028 }
1029 ThreadEvent::AgentThinking(text) => {
1030 acp_thread.update(cx, |thread, cx| {
1031 thread.push_assistant_content_block(text.into(), true, cx)
1032 })?;
1033 }
1034 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1035 tool_call,
1036 options,
1037 response,
1038 context: _,
1039 }) => {
1040 let outcome_task = acp_thread.update(cx, |thread, cx| {
1041 thread.request_tool_call_authorization(tool_call, options, cx)
1042 })??;
1043 cx.background_spawn(async move {
1044 if let acp::RequestPermissionOutcome::Selected(
1045 acp::SelectedPermissionOutcome { option_id, .. },
1046 ) = outcome_task.await
1047 {
1048 response
1049 .send(option_id)
1050 .map(|_| anyhow!("authorization receiver was dropped"))
1051 .log_err();
1052 }
1053 })
1054 .detach();
1055 }
1056 ThreadEvent::ToolCall(tool_call) => {
1057 acp_thread.update(cx, |thread, cx| {
1058 thread.upsert_tool_call(tool_call, cx)
1059 })??;
1060 }
1061 ThreadEvent::ToolCallUpdate(update) => {
1062 acp_thread.update(cx, |thread, cx| {
1063 thread.update_tool_call(update, cx)
1064 })??;
1065 }
1066 ThreadEvent::SubagentSpawned(session_id) => {
1067 acp_thread.update(cx, |thread, cx| {
1068 thread.subagent_spawned(session_id, cx);
1069 })?;
1070 }
1071 ThreadEvent::Retry(status) => {
1072 acp_thread.update(cx, |thread, cx| {
1073 thread.update_retry_status(status, cx)
1074 })?;
1075 }
1076 ThreadEvent::Stop(stop_reason) => {
1077 log::debug!("Assistant message complete: {:?}", stop_reason);
1078 return Ok(acp::PromptResponse::new(stop_reason));
1079 }
1080 }
1081 }
1082 Err(e) => {
1083 log::error!("Error in model response stream: {:?}", e);
1084 return Err(e);
1085 }
1086 }
1087 }
1088
1089 log::debug!("Response stream completed");
1090 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1091 })
1092 }
1093}
1094
1095struct Command<'a> {
1096 prompt_name: &'a str,
1097 arg_value: &'a str,
1098 explicit_server_id: Option<&'a str>,
1099}
1100
1101impl<'a> Command<'a> {
1102 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1103 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1104 return None;
1105 };
1106 let text = text_content.text.trim();
1107 let command = text.strip_prefix('/')?;
1108 let (command, arg_value) = command
1109 .split_once(char::is_whitespace)
1110 .unwrap_or((command, ""));
1111
1112 if let Some((server_id, prompt_name)) = command.split_once('.') {
1113 Some(Self {
1114 prompt_name,
1115 arg_value,
1116 explicit_server_id: Some(server_id),
1117 })
1118 } else {
1119 Some(Self {
1120 prompt_name: command,
1121 arg_value,
1122 explicit_server_id: None,
1123 })
1124 }
1125 }
1126}
1127
1128struct NativeAgentModelSelector {
1129 session_id: acp::SessionId,
1130 connection: NativeAgentConnection,
1131}
1132
1133impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1134 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1135 log::debug!("NativeAgentConnection::list_models called");
1136 let list = self.connection.0.read(cx).models.model_list.clone();
1137 Task::ready(if list.is_empty() {
1138 Err(anyhow::anyhow!("No models available"))
1139 } else {
1140 Ok(list)
1141 })
1142 }
1143
1144 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1145 log::debug!(
1146 "Setting model for session {}: {}",
1147 self.session_id,
1148 model_id
1149 );
1150 let Some(thread) = self
1151 .connection
1152 .0
1153 .read(cx)
1154 .sessions
1155 .get(&self.session_id)
1156 .map(|session| session.thread.clone())
1157 else {
1158 return Task::ready(Err(anyhow!("Session not found")));
1159 };
1160
1161 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1162 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1163 };
1164
1165 // We want to reset the effort level when switching models, as the currently-selected effort level may
1166 // not be compatible.
1167 let effort = model
1168 .default_effort_level()
1169 .map(|effort_level| effort_level.value.to_string());
1170
1171 thread.update(cx, |thread, cx| {
1172 thread.set_model(model.clone(), cx);
1173 thread.set_thinking_effort(effort.clone(), cx);
1174 thread.set_thinking_enabled(model.supports_thinking(), cx);
1175 });
1176
1177 update_settings_file(
1178 self.connection.0.read(cx).fs.clone(),
1179 cx,
1180 move |settings, cx| {
1181 let provider = model.provider_id().0.to_string();
1182 let model = model.id().0.to_string();
1183 let enable_thinking = thread.read(cx).thinking_enabled();
1184 settings
1185 .agent
1186 .get_or_insert_default()
1187 .set_model(LanguageModelSelection {
1188 provider: provider.into(),
1189 model,
1190 enable_thinking,
1191 effort,
1192 });
1193 },
1194 );
1195
1196 Task::ready(Ok(()))
1197 }
1198
1199 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1200 let Some(thread) = self
1201 .connection
1202 .0
1203 .read(cx)
1204 .sessions
1205 .get(&self.session_id)
1206 .map(|session| session.thread.clone())
1207 else {
1208 return Task::ready(Err(anyhow!("Session not found")));
1209 };
1210 let Some(model) = thread.read(cx).model() else {
1211 return Task::ready(Err(anyhow!("Model not found")));
1212 };
1213 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1214 else {
1215 return Task::ready(Err(anyhow!("Provider not found")));
1216 };
1217 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1218 model, &provider,
1219 )))
1220 }
1221
1222 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1223 Some(self.connection.0.read(cx).models.watch())
1224 }
1225
1226 fn should_render_footer(&self) -> bool {
1227 true
1228 }
1229}
1230
1231impl acp_thread::AgentConnection for NativeAgentConnection {
1232 fn telemetry_id(&self) -> SharedString {
1233 "zed".into()
1234 }
1235
1236 fn new_session(
1237 self: Rc<Self>,
1238 project: Entity<Project>,
1239 cwd: &Path,
1240 cx: &mut App,
1241 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1242 log::debug!("Creating new thread for project at: {cwd:?}");
1243 Task::ready(Ok(self
1244 .0
1245 .update(cx, |agent, cx| agent.new_session(project, cx))))
1246 }
1247
1248 fn supports_load_session(&self) -> bool {
1249 true
1250 }
1251
1252 fn load_session(
1253 self: Rc<Self>,
1254 session: AgentSessionInfo,
1255 _project: Entity<Project>,
1256 _cwd: &Path,
1257 cx: &mut App,
1258 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1259 self.0
1260 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1261 }
1262
1263 fn supports_close_session(&self) -> bool {
1264 true
1265 }
1266
1267 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1268 self.0.update(cx, |agent, _cx| {
1269 agent.sessions.remove(session_id);
1270 });
1271 Task::ready(Ok(()))
1272 }
1273
1274 fn auth_methods(&self) -> &[acp::AuthMethod] {
1275 &[] // No auth for in-process
1276 }
1277
1278 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1279 Task::ready(Ok(()))
1280 }
1281
1282 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1283 Some(Rc::new(NativeAgentModelSelector {
1284 session_id: session_id.clone(),
1285 connection: self.clone(),
1286 }) as Rc<dyn AgentModelSelector>)
1287 }
1288
1289 fn prompt(
1290 &self,
1291 id: Option<acp_thread::UserMessageId>,
1292 params: acp::PromptRequest,
1293 cx: &mut App,
1294 ) -> Task<Result<acp::PromptResponse>> {
1295 let id = id.expect("UserMessageId is required");
1296 let session_id = params.session_id.clone();
1297 log::info!("Received prompt request for session: {}", session_id);
1298 log::debug!("Prompt blocks count: {}", params.prompt.len());
1299
1300 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1301 let registry = self.0.read(cx).context_server_registry.read(cx);
1302
1303 let explicit_server_id = parsed_command
1304 .explicit_server_id
1305 .map(|server_id| ContextServerId(server_id.into()));
1306
1307 if let Some(prompt) =
1308 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1309 {
1310 let arguments = if !parsed_command.arg_value.is_empty()
1311 && let Some(arg_name) = prompt
1312 .prompt
1313 .arguments
1314 .as_ref()
1315 .and_then(|args| args.first())
1316 .map(|arg| arg.name.clone())
1317 {
1318 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1319 } else {
1320 Default::default()
1321 };
1322
1323 let prompt_name = prompt.prompt.name.clone();
1324 let server_id = prompt.server_id.clone();
1325
1326 return self.0.update(cx, |agent, cx| {
1327 agent.send_mcp_prompt(
1328 id,
1329 session_id.clone(),
1330 prompt_name,
1331 server_id,
1332 arguments,
1333 params.prompt,
1334 cx,
1335 )
1336 });
1337 };
1338 };
1339
1340 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1341
1342 self.run_turn(session_id, cx, move |thread, cx| {
1343 let content: Vec<UserMessageContent> = params
1344 .prompt
1345 .into_iter()
1346 .map(|block| UserMessageContent::from_content_block(block, path_style))
1347 .collect::<Vec<_>>();
1348 log::debug!("Converted prompt to message: {} chars", content.len());
1349 log::debug!("Message id: {:?}", id);
1350 log::debug!("Message content: {:?}", content);
1351
1352 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1353 })
1354 }
1355
1356 fn retry(
1357 &self,
1358 session_id: &acp::SessionId,
1359 _cx: &App,
1360 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1361 Some(Rc::new(NativeAgentSessionRetry {
1362 connection: self.clone(),
1363 session_id: session_id.clone(),
1364 }) as _)
1365 }
1366
1367 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1368 log::info!("Cancelling on session: {}", session_id);
1369 self.0.update(cx, |agent, cx| {
1370 if let Some(session) = agent.sessions.get(session_id) {
1371 session
1372 .thread
1373 .update(cx, |thread, cx| thread.cancel(cx))
1374 .detach();
1375 }
1376 });
1377 }
1378
1379 fn truncate(
1380 &self,
1381 session_id: &agent_client_protocol::SessionId,
1382 cx: &App,
1383 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1384 self.0.read_with(cx, |agent, _cx| {
1385 agent.sessions.get(session_id).map(|session| {
1386 Rc::new(NativeAgentSessionTruncate {
1387 thread: session.thread.clone(),
1388 acp_thread: session.acp_thread.downgrade(),
1389 }) as _
1390 })
1391 })
1392 }
1393
1394 fn set_title(
1395 &self,
1396 session_id: &acp::SessionId,
1397 cx: &App,
1398 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1399 self.0.read_with(cx, |agent, _cx| {
1400 agent
1401 .sessions
1402 .get(session_id)
1403 .filter(|s| !s.thread.read(cx).is_subagent())
1404 .map(|session| {
1405 Rc::new(NativeAgentSessionSetTitle {
1406 thread: session.thread.clone(),
1407 }) as _
1408 })
1409 })
1410 }
1411
1412 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1413 let thread_store = self.0.read(cx).thread_store.clone();
1414 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1415 }
1416
1417 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1418 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1419 }
1420
1421 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1422 self
1423 }
1424}
1425
1426impl acp_thread::AgentTelemetry for NativeAgentConnection {
1427 fn thread_data(
1428 &self,
1429 session_id: &acp::SessionId,
1430 cx: &mut App,
1431 ) -> Task<Result<serde_json::Value>> {
1432 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1433 return Task::ready(Err(anyhow!("Session not found")));
1434 };
1435
1436 let task = session.thread.read(cx).to_db(cx);
1437 cx.background_spawn(async move {
1438 serde_json::to_value(task.await).context("Failed to serialize thread")
1439 })
1440 }
1441}
1442
1443pub struct NativeAgentSessionList {
1444 thread_store: Entity<ThreadStore>,
1445 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1446 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1447 _subscription: Subscription,
1448}
1449
1450impl NativeAgentSessionList {
1451 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1452 let (tx, rx) = smol::channel::unbounded();
1453 let this_tx = tx.clone();
1454 let subscription = cx.observe(&thread_store, move |_, _| {
1455 this_tx
1456 .try_send(acp_thread::SessionListUpdate::Refresh)
1457 .ok();
1458 });
1459 Self {
1460 thread_store,
1461 updates_tx: tx,
1462 updates_rx: rx,
1463 _subscription: subscription,
1464 }
1465 }
1466
1467 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1468 AgentSessionInfo {
1469 session_id: entry.id,
1470 cwd: None,
1471 title: Some(entry.title),
1472 updated_at: Some(entry.updated_at),
1473 meta: None,
1474 }
1475 }
1476
1477 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1478 &self.thread_store
1479 }
1480}
1481
1482impl AgentSessionList for NativeAgentSessionList {
1483 fn list_sessions(
1484 &self,
1485 _request: AgentSessionListRequest,
1486 cx: &mut App,
1487 ) -> Task<Result<AgentSessionListResponse>> {
1488 let sessions = self
1489 .thread_store
1490 .read(cx)
1491 .entries()
1492 .map(Self::to_session_info)
1493 .collect();
1494 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1495 }
1496
1497 fn supports_delete(&self) -> bool {
1498 true
1499 }
1500
1501 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1502 self.thread_store
1503 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1504 }
1505
1506 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1507 self.thread_store
1508 .update(cx, |store, cx| store.delete_threads(cx))
1509 }
1510
1511 fn watch(
1512 &self,
1513 _cx: &mut App,
1514 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1515 Some(self.updates_rx.clone())
1516 }
1517
1518 fn notify_refresh(&self) {
1519 self.updates_tx
1520 .try_send(acp_thread::SessionListUpdate::Refresh)
1521 .ok();
1522 }
1523
1524 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1525 self
1526 }
1527}
1528
1529struct NativeAgentSessionTruncate {
1530 thread: Entity<Thread>,
1531 acp_thread: WeakEntity<AcpThread>,
1532}
1533
1534impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1535 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1536 match self.thread.update(cx, |thread, cx| {
1537 thread.truncate(message_id.clone(), cx)?;
1538 Ok(thread.latest_token_usage())
1539 }) {
1540 Ok(usage) => {
1541 self.acp_thread
1542 .update(cx, |thread, cx| {
1543 thread.update_token_usage(usage, cx);
1544 })
1545 .ok();
1546 Task::ready(Ok(()))
1547 }
1548 Err(error) => Task::ready(Err(error)),
1549 }
1550 }
1551}
1552
1553struct NativeAgentSessionRetry {
1554 connection: NativeAgentConnection,
1555 session_id: acp::SessionId,
1556}
1557
1558impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1559 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1560 self.connection
1561 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1562 thread.update(cx, |thread, cx| thread.resume(cx))
1563 })
1564 }
1565}
1566
1567struct NativeAgentSessionSetTitle {
1568 thread: Entity<Thread>,
1569}
1570
1571impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1572 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1573 self.thread
1574 .update(cx, |thread, cx| thread.set_title(title, cx));
1575 Task::ready(Ok(()))
1576 }
1577}
1578
1579pub struct NativeThreadEnvironment {
1580 agent: WeakEntity<NativeAgent>,
1581 thread: WeakEntity<Thread>,
1582 acp_thread: WeakEntity<AcpThread>,
1583}
1584
1585impl NativeThreadEnvironment {
1586 pub(crate) fn create_subagent_thread(
1587 &self,
1588 label: String,
1589 cx: &mut App,
1590 ) -> Result<Rc<dyn SubagentHandle>> {
1591 let Some(parent_thread_entity) = self.thread.upgrade() else {
1592 anyhow::bail!("Parent thread no longer exists".to_string());
1593 };
1594 let parent_thread = parent_thread_entity.read(cx);
1595 let current_depth = parent_thread.depth();
1596
1597 if current_depth >= MAX_SUBAGENT_DEPTH {
1598 return Err(anyhow!(
1599 "Maximum subagent depth ({}) reached",
1600 MAX_SUBAGENT_DEPTH
1601 ));
1602 }
1603
1604 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1605 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1606 thread.set_title(label.into(), cx);
1607 thread
1608 });
1609
1610 let session_id = subagent_thread.read(cx).id().clone();
1611
1612 let acp_thread = self.agent.update(cx, |agent, cx| {
1613 agent.register_session(subagent_thread.clone(), cx)
1614 })?;
1615
1616 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1617 }
1618
1619 pub(crate) fn resume_subagent_thread(
1620 &self,
1621 session_id: acp::SessionId,
1622 cx: &mut App,
1623 ) -> Result<Rc<dyn SubagentHandle>> {
1624 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1625 let session = agent
1626 .sessions
1627 .get(&session_id)
1628 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1629 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1630 })??;
1631
1632 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1633 }
1634
1635 fn prompt_subagent(
1636 &self,
1637 session_id: acp::SessionId,
1638 subagent_thread: Entity<Thread>,
1639 acp_thread: Entity<acp_thread::AcpThread>,
1640 ) -> Result<Rc<dyn SubagentHandle>> {
1641 let Some(parent_thread_entity) = self.thread.upgrade() else {
1642 anyhow::bail!("Parent thread no longer exists".to_string());
1643 };
1644 Ok(Rc::new(NativeSubagentHandle::new(
1645 session_id,
1646 subagent_thread,
1647 acp_thread,
1648 parent_thread_entity,
1649 )) as _)
1650 }
1651}
1652
1653impl ThreadEnvironment for NativeThreadEnvironment {
1654 fn create_terminal(
1655 &self,
1656 command: String,
1657 cwd: Option<PathBuf>,
1658 output_byte_limit: Option<u64>,
1659 cx: &mut AsyncApp,
1660 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1661 let task = self.acp_thread.update(cx, |thread, cx| {
1662 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1663 });
1664
1665 let acp_thread = self.acp_thread.clone();
1666 cx.spawn(async move |cx| {
1667 let terminal = task?.await?;
1668
1669 let (drop_tx, drop_rx) = oneshot::channel();
1670 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1671
1672 cx.spawn(async move |cx| {
1673 drop_rx.await.ok();
1674 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1675 })
1676 .detach();
1677
1678 let handle = AcpTerminalHandle {
1679 terminal,
1680 _drop_tx: Some(drop_tx),
1681 };
1682
1683 Ok(Rc::new(handle) as _)
1684 })
1685 }
1686
1687 fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1688 self.create_subagent_thread(label, cx)
1689 }
1690
1691 fn resume_subagent(
1692 &self,
1693 session_id: acp::SessionId,
1694 cx: &mut App,
1695 ) -> Result<Rc<dyn SubagentHandle>> {
1696 self.resume_subagent_thread(session_id, cx)
1697 }
1698}
1699
1700#[derive(Debug, Clone)]
1701enum SubagentPromptResult {
1702 Completed,
1703 Cancelled,
1704 ContextWindowWarning,
1705 Error(String),
1706}
1707
1708pub struct NativeSubagentHandle {
1709 session_id: acp::SessionId,
1710 parent_thread: WeakEntity<Thread>,
1711 subagent_thread: Entity<Thread>,
1712 acp_thread: Entity<acp_thread::AcpThread>,
1713}
1714
1715impl NativeSubagentHandle {
1716 fn new(
1717 session_id: acp::SessionId,
1718 subagent_thread: Entity<Thread>,
1719 acp_thread: Entity<acp_thread::AcpThread>,
1720 parent_thread_entity: Entity<Thread>,
1721 ) -> Self {
1722 NativeSubagentHandle {
1723 session_id,
1724 subagent_thread,
1725 parent_thread: parent_thread_entity.downgrade(),
1726 acp_thread,
1727 }
1728 }
1729}
1730
1731impl SubagentHandle for NativeSubagentHandle {
1732 fn id(&self) -> acp::SessionId {
1733 self.session_id.clone()
1734 }
1735
1736 fn run_turn(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1737 let thread = self.subagent_thread.clone();
1738 let acp_thread = self.acp_thread.clone();
1739 let subagent_session_id = self.session_id.clone();
1740 let parent_thread = self.parent_thread.clone();
1741
1742 cx.spawn(async move |cx| {
1743 let (task, _subscription) = cx.update(|cx| {
1744 let ratio_before_prompt = thread
1745 .read(cx)
1746 .latest_token_usage()
1747 .map(|usage| usage.ratio());
1748
1749 parent_thread
1750 .update(cx, |parent_thread, _cx| {
1751 parent_thread.register_running_subagent(thread.downgrade())
1752 })
1753 .ok();
1754
1755 let task = acp_thread.update(cx, |acp_thread, cx| {
1756 acp_thread.send(vec![message.into()], cx)
1757 });
1758
1759 let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1760 let mut token_limit_tx = Some(token_limit_tx);
1761
1762 let subscription = cx.subscribe(
1763 &thread,
1764 move |_thread, event: &TokenUsageUpdated, _cx| {
1765 if let Some(usage) = &event.0 {
1766 let old_ratio = ratio_before_prompt
1767 .clone()
1768 .unwrap_or(TokenUsageRatio::Normal);
1769 let new_ratio = usage.ratio();
1770 if old_ratio == TokenUsageRatio::Normal
1771 && new_ratio == TokenUsageRatio::Warning
1772 {
1773 if let Some(tx) = token_limit_tx.take() {
1774 tx.send(()).ok();
1775 }
1776 }
1777 }
1778 },
1779 );
1780
1781 let wait_for_prompt = cx
1782 .background_spawn(async move {
1783 futures::select! {
1784 response = task.fuse() => match response {
1785 Ok(Some(response)) => {
1786 match response.stop_reason {
1787 acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1788 acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1789 acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1790 acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1791 acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1792 }
1793 }
1794 Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1795 Err(error) => SubagentPromptResult::Error(error.to_string()),
1796 },
1797 _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1798 }
1799 });
1800
1801 (wait_for_prompt, subscription)
1802 });
1803
1804 let result = match task.await {
1805 SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1806 thread
1807 .last_message()
1808 .map(|m| m.to_markdown())
1809 .context("No response from subagent")
1810 }),
1811 SubagentPromptResult::Cancelled => Err(anyhow!("User cancelled")),
1812 SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1813 SubagentPromptResult::ContextWindowWarning => {
1814 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1815 Err(anyhow!(
1816 "The agent is nearing the end of its context window and has been \
1817 stopped. You can prompt the thread again to have the agent wrap up \
1818 or hand off its work."
1819 ))
1820 }
1821 };
1822
1823 parent_thread
1824 .update(cx, |parent_thread, cx| {
1825 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1826 })
1827 .ok();
1828
1829 result
1830 })
1831 }
1832}
1833
1834pub struct AcpTerminalHandle {
1835 terminal: Entity<acp_thread::Terminal>,
1836 _drop_tx: Option<oneshot::Sender<()>>,
1837}
1838
1839impl TerminalHandle for AcpTerminalHandle {
1840 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1841 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1842 }
1843
1844 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1845 Ok(self
1846 .terminal
1847 .read_with(cx, |term, _cx| term.wait_for_exit()))
1848 }
1849
1850 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1851 Ok(self
1852 .terminal
1853 .read_with(cx, |term, cx| term.current_output(cx)))
1854 }
1855
1856 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1857 cx.update(|cx| {
1858 self.terminal.update(cx, |terminal, cx| {
1859 terminal.kill(cx);
1860 });
1861 });
1862 Ok(())
1863 }
1864
1865 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1866 Ok(self
1867 .terminal
1868 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1869 }
1870}
1871
1872#[cfg(test)]
1873mod internal_tests {
1874 use super::*;
1875 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1876 use fs::FakeFs;
1877 use gpui::TestAppContext;
1878 use indoc::formatdoc;
1879 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1880 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1881 use serde_json::json;
1882 use settings::SettingsStore;
1883 use util::{path, rel_path::rel_path};
1884
1885 #[gpui::test]
1886 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1887 init_test(cx);
1888 let fs = FakeFs::new(cx.executor());
1889 fs.insert_tree(
1890 "/",
1891 json!({
1892 "a": {}
1893 }),
1894 )
1895 .await;
1896 let project = Project::test(fs.clone(), [], cx).await;
1897 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1898 let agent = NativeAgent::new(
1899 project.clone(),
1900 thread_store,
1901 Templates::new(),
1902 None,
1903 fs.clone(),
1904 &mut cx.to_async(),
1905 )
1906 .await
1907 .unwrap();
1908 agent.read_with(cx, |agent, cx| {
1909 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1910 });
1911
1912 let worktree = project
1913 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1914 .await
1915 .unwrap();
1916 cx.run_until_parked();
1917 agent.read_with(cx, |agent, cx| {
1918 assert_eq!(
1919 agent.project_context.read(cx).worktrees,
1920 vec![WorktreeContext {
1921 root_name: "a".into(),
1922 abs_path: Path::new("/a").into(),
1923 rules_file: None
1924 }]
1925 )
1926 });
1927
1928 // Creating `/a/.rules` updates the project context.
1929 fs.insert_file("/a/.rules", Vec::new()).await;
1930 cx.run_until_parked();
1931 agent.read_with(cx, |agent, cx| {
1932 let rules_entry = worktree
1933 .read(cx)
1934 .entry_for_path(rel_path(".rules"))
1935 .unwrap();
1936 assert_eq!(
1937 agent.project_context.read(cx).worktrees,
1938 vec![WorktreeContext {
1939 root_name: "a".into(),
1940 abs_path: Path::new("/a").into(),
1941 rules_file: Some(RulesFileContext {
1942 path_in_worktree: rel_path(".rules").into(),
1943 text: "".into(),
1944 project_entry_id: rules_entry.id.to_usize()
1945 })
1946 }]
1947 )
1948 });
1949 }
1950
1951 #[gpui::test]
1952 async fn test_listing_models(cx: &mut TestAppContext) {
1953 init_test(cx);
1954 let fs = FakeFs::new(cx.executor());
1955 fs.insert_tree("/", json!({ "a": {} })).await;
1956 let project = Project::test(fs.clone(), [], cx).await;
1957 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1958 let connection = NativeAgentConnection(
1959 NativeAgent::new(
1960 project.clone(),
1961 thread_store,
1962 Templates::new(),
1963 None,
1964 fs.clone(),
1965 &mut cx.to_async(),
1966 )
1967 .await
1968 .unwrap(),
1969 );
1970
1971 // Create a thread/session
1972 let acp_thread = cx
1973 .update(|cx| {
1974 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1975 })
1976 .await
1977 .unwrap();
1978
1979 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1980
1981 let models = cx
1982 .update(|cx| {
1983 connection
1984 .model_selector(&session_id)
1985 .unwrap()
1986 .list_models(cx)
1987 })
1988 .await
1989 .unwrap();
1990
1991 let acp_thread::AgentModelList::Grouped(models) = models else {
1992 panic!("Unexpected model group");
1993 };
1994 assert_eq!(
1995 models,
1996 IndexMap::from_iter([(
1997 AgentModelGroupName("Fake".into()),
1998 vec![AgentModelInfo {
1999 id: acp::ModelId::new("fake/fake"),
2000 name: "Fake".into(),
2001 description: None,
2002 icon: Some(acp_thread::AgentModelIcon::Named(
2003 ui::IconName::ZedAssistant
2004 )),
2005 is_latest: false,
2006 cost: None,
2007 }]
2008 )])
2009 );
2010 }
2011
2012 #[gpui::test]
2013 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2014 init_test(cx);
2015 let fs = FakeFs::new(cx.executor());
2016 fs.create_dir(paths::settings_file().parent().unwrap())
2017 .await
2018 .unwrap();
2019 fs.insert_file(
2020 paths::settings_file(),
2021 json!({
2022 "agent": {
2023 "default_model": {
2024 "provider": "foo",
2025 "model": "bar"
2026 }
2027 }
2028 })
2029 .to_string()
2030 .into_bytes(),
2031 )
2032 .await;
2033 let project = Project::test(fs.clone(), [], cx).await;
2034
2035 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2036
2037 // Create the agent and connection
2038 let agent = NativeAgent::new(
2039 project.clone(),
2040 thread_store,
2041 Templates::new(),
2042 None,
2043 fs.clone(),
2044 &mut cx.to_async(),
2045 )
2046 .await
2047 .unwrap();
2048 let connection = NativeAgentConnection(agent.clone());
2049
2050 // Create a thread/session
2051 let acp_thread = cx
2052 .update(|cx| {
2053 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2054 })
2055 .await
2056 .unwrap();
2057
2058 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2059
2060 // Select a model
2061 let selector = connection.model_selector(&session_id).unwrap();
2062 let model_id = acp::ModelId::new("fake/fake");
2063 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2064 .await
2065 .unwrap();
2066
2067 // Verify the thread has the selected model
2068 agent.read_with(cx, |agent, _| {
2069 let session = agent.sessions.get(&session_id).unwrap();
2070 session.thread.read_with(cx, |thread, _| {
2071 assert_eq!(thread.model().unwrap().id().0, "fake");
2072 });
2073 });
2074
2075 cx.run_until_parked();
2076
2077 // Verify settings file was updated
2078 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2079 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2080
2081 // Check that the agent settings contain the selected model
2082 assert_eq!(
2083 settings_json["agent"]["default_model"]["model"],
2084 json!("fake")
2085 );
2086 assert_eq!(
2087 settings_json["agent"]["default_model"]["provider"],
2088 json!("fake")
2089 );
2090
2091 // Register a thinking model and select it.
2092 cx.update(|cx| {
2093 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2094 "fake-corp",
2095 "fake-thinking",
2096 "Fake Thinking",
2097 true,
2098 ));
2099 let thinking_provider = Arc::new(
2100 FakeLanguageModelProvider::new(
2101 LanguageModelProviderId::from("fake-corp".to_string()),
2102 LanguageModelProviderName::from("Fake Corp".to_string()),
2103 )
2104 .with_models(vec![thinking_model]),
2105 );
2106 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2107 registry.register_provider(thinking_provider, cx);
2108 });
2109 });
2110 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2111
2112 let selector = connection.model_selector(&session_id).unwrap();
2113 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2114 .await
2115 .unwrap();
2116 cx.run_until_parked();
2117
2118 // Verify enable_thinking was written to settings as true.
2119 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2120 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2121 assert_eq!(
2122 settings_json["agent"]["default_model"]["enable_thinking"],
2123 json!(true),
2124 "selecting a thinking model should persist enable_thinking: true to settings"
2125 );
2126 }
2127
2128 #[gpui::test]
2129 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2130 init_test(cx);
2131 let fs = FakeFs::new(cx.executor());
2132 fs.create_dir(paths::settings_file().parent().unwrap())
2133 .await
2134 .unwrap();
2135 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2136 let project = Project::test(fs.clone(), [], cx).await;
2137
2138 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2139 let agent = NativeAgent::new(
2140 project.clone(),
2141 thread_store,
2142 Templates::new(),
2143 None,
2144 fs.clone(),
2145 &mut cx.to_async(),
2146 )
2147 .await
2148 .unwrap();
2149 let connection = NativeAgentConnection(agent.clone());
2150
2151 let acp_thread = cx
2152 .update(|cx| {
2153 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2154 })
2155 .await
2156 .unwrap();
2157 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2158
2159 // Register a second provider with a thinking model.
2160 cx.update(|cx| {
2161 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2162 "fake-corp",
2163 "fake-thinking",
2164 "Fake Thinking",
2165 true,
2166 ));
2167 let thinking_provider = Arc::new(
2168 FakeLanguageModelProvider::new(
2169 LanguageModelProviderId::from("fake-corp".to_string()),
2170 LanguageModelProviderName::from("Fake Corp".to_string()),
2171 )
2172 .with_models(vec![thinking_model]),
2173 );
2174 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2175 registry.register_provider(thinking_provider, cx);
2176 });
2177 });
2178 // Refresh the agent's model list so it picks up the new provider.
2179 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2180
2181 // Thread starts with thinking_enabled = false (the default).
2182 agent.read_with(cx, |agent, _| {
2183 let session = agent.sessions.get(&session_id).unwrap();
2184 session.thread.read_with(cx, |thread, _| {
2185 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2186 });
2187 });
2188
2189 // Select the thinking model via select_model.
2190 let selector = connection.model_selector(&session_id).unwrap();
2191 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2192 .await
2193 .unwrap();
2194
2195 // select_model should have enabled thinking based on the model's supports_thinking().
2196 agent.read_with(cx, |agent, _| {
2197 let session = agent.sessions.get(&session_id).unwrap();
2198 session.thread.read_with(cx, |thread, _| {
2199 assert!(
2200 thread.thinking_enabled(),
2201 "select_model should enable thinking when model supports it"
2202 );
2203 });
2204 });
2205
2206 // Switch back to the non-thinking model.
2207 let selector = connection.model_selector(&session_id).unwrap();
2208 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2209 .await
2210 .unwrap();
2211
2212 // select_model should have disabled thinking.
2213 agent.read_with(cx, |agent, _| {
2214 let session = agent.sessions.get(&session_id).unwrap();
2215 session.thread.read_with(cx, |thread, _| {
2216 assert!(
2217 !thread.thinking_enabled(),
2218 "select_model should disable thinking when model does not support it"
2219 );
2220 });
2221 });
2222 }
2223
    /// Verifies that a thread's `thinking_enabled` flag survives a save/close/reopen cycle.
    ///
    /// The test selects a thinking-capable model, which enables thinking. It then sends
    /// one message so the thread is persisted, closes the session, and reopens it with
    /// `open_thread`. The reloaded thread must still report `thinking_enabled() == true`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // A handle to the model is kept so the test can drive its completion stream below.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider's model is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        // Run the send on the foreground executor so the test can feed the fake
        // model's completion stream while the send is still in flight.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles to the closed session's entities (presumably so
        // the reload below cannot observe the old in-memory thread).
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The reopened ACP thread handle is held until here; presumably this keeps the
        // session registered in `agent.sessions` while it is inspected above.
        drop(reloaded_acp_thread);
    }
2329
    /// Verifies that the selected model survives a save/close/reopen cycle.
    ///
    /// The test registers a model whose `id()` differs from its display `name()` and
    /// selects it. It then sends one message so the thread is persisted, closes the
    /// session, and reopens it with `open_thread`. The reloaded thread must resolve the
    /// same model id (`custom-model-id`) rather than falling back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider's model is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector is addressed with "<provider id>/<model id>", i.e. the model id,
        // not its display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        // Run the send on the foreground executor so the test can feed the fake
        // model's completion stream while the send is still in flight.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles to the closed session's entities (presumably so
        // the reload below cannot observe the old in-memory thread).
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reopened ACP thread handle is held until here; presumably this keeps the
        // session registered in `agent.sessions` while it is inspected above.
        drop(reloaded_acp_thread);
    }
2440
    /// End-to-end persistence check for a native agent thread:
    ///
    /// 1. A freshly created thread is not written to the thread store, even after its
    ///    model and summarization model are changed.
    /// 2. Once a message containing a file mention has been answered, the thread is
    ///    listed in the store under its session id. Its title is the text produced by
    ///    the summarization model.
    /// 3. After the session is closed, `open_thread` reloads it with identical markdown.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Ask about `b.md` through a resource-link mention, so the mention is checked
        // both in the rendered markdown and after the thread is reloaded.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Run the send on the foreground executor so the test can feed the fake
        // models' completion streams while the send is still in flight.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the main response first, then the summary that becomes the
        // thread's stored title (asserted via `thread_entries` below).
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles to the closed session's entities (presumably so
        // the reload below cannot observe the old in-memory thread).
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2579
2580 fn thread_entries(
2581 thread_store: &Entity<ThreadStore>,
2582 cx: &mut TestAppContext,
2583 ) -> Vec<(acp::SessionId, String)> {
2584 thread_store.read_with(cx, |store, _| {
2585 store
2586 .entries()
2587 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2588 .collect::<Vec<_>>()
2589 })
2590 }
2591
2592 fn init_test(cx: &mut TestAppContext) {
2593 env_logger::try_init().ok();
2594 cx.update(|cx| {
2595 let settings_store = SettingsStore::test(cx);
2596 cx.set_global(settings_store);
2597
2598 LanguageModelRegistry::test(cx);
2599 });
2600 }
2601}
2602
2603fn mcp_message_content_to_acp_content_block(
2604 content: context_server::types::MessageContent,
2605) -> acp::ContentBlock {
2606 match content {
2607 context_server::types::MessageContent::Text {
2608 text,
2609 annotations: _,
2610 } => text.into(),
2611 context_server::types::MessageContent::Image {
2612 data,
2613 mime_type,
2614 annotations: _,
2615 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2616 context_server::types::MessageContent::Audio {
2617 data,
2618 mime_type,
2619 annotations: _,
2620 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2621 context_server::types::MessageContent::Resource {
2622 resource,
2623 annotations: _,
2624 } => {
2625 let mut link =
2626 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2627 if let Some(mime_type) = resource.mime_type {
2628 link = link.mime_type(mime_type);
2629 }
2630 acp::ContentBlock::ResourceLink(link)
2631 }
2632 }
2633}