1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, TokenUsageRatio, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use util::ResultExt;
54use util::path_list::PathList;
55use util::rel_path::RelPath;
56
/// Serializable snapshot of a project's worktrees, built from the project's
/// telemetry worktree-snapshot type, plus a timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One entry per captured worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Presumably the moment the snapshot was taken (UTC) — confirm at the construction site.
    pub timestamp: DateTime<Utc>,
}
62
/// Error produced when a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// The underlying error rendered with its `Display` implementation.
    pub message: SharedString,
}
66
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// The most recent save task started by `NativeAgent::save_thread`; each new
    /// save overwrites this field.
    pending_save: Task<()>,
    /// Subscriptions to the thread, held so they stay active while the session exists.
    _subscriptions: Vec<Subscription>,
}
76
/// Cache of the language models exposed by authenticated providers, in both
/// id-keyed and grouped-list form.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half handed out by `watch()`; it observes every `refresh_list` call.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers; held for this cache's lifetime.
    _authenticate_all_providers_task: Task<()>,
}
86
87impl LanguageModels {
88 fn new(cx: &mut App) -> Self {
89 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
90
91 let mut this = Self {
92 models: HashMap::default(),
93 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
94 refresh_models_rx,
95 refresh_models_tx,
96 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
97 };
98 this.refresh_list(cx);
99 this
100 }
101
102 fn refresh_list(&mut self, cx: &App) {
103 let providers = LanguageModelRegistry::global(cx)
104 .read(cx)
105 .visible_providers()
106 .into_iter()
107 .filter(|provider| provider.is_authenticated(cx))
108 .collect::<Vec<_>>();
109
110 let mut language_model_list = IndexMap::default();
111 let mut recommended_models = HashSet::default();
112
113 let mut recommended = Vec::new();
114 for provider in &providers {
115 for model in provider.recommended_models(cx) {
116 recommended_models.insert((model.provider_id(), model.id()));
117 recommended.push(Self::map_language_model_to_info(&model, provider));
118 }
119 }
120 if !recommended.is_empty() {
121 language_model_list.insert(
122 acp_thread::AgentModelGroupName("Recommended".into()),
123 recommended,
124 );
125 }
126
127 let mut models = HashMap::default();
128 for provider in providers {
129 let mut provider_models = Vec::new();
130 for model in provider.provided_models(cx) {
131 let model_info = Self::map_language_model_to_info(&model, &provider);
132 let model_id = model_info.id.clone();
133 provider_models.push(model_info);
134 models.insert(model_id, model);
135 }
136 if !provider_models.is_empty() {
137 language_model_list.insert(
138 acp_thread::AgentModelGroupName(provider.name().0.clone()),
139 provider_models,
140 );
141 }
142 }
143
144 self.models = models;
145 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
146 self.refresh_models_tx.send(()).ok();
147 }
148
149 fn watch(&self) -> watch::Receiver<()> {
150 self.refresh_models_rx.clone()
151 }
152
153 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
154 self.models.get(model_id).cloned()
155 }
156
157 fn map_language_model_to_info(
158 model: &Arc<dyn LanguageModel>,
159 provider: &Arc<dyn LanguageModelProvider>,
160 ) -> acp_thread::AgentModelInfo {
161 acp_thread::AgentModelInfo {
162 id: Self::model_id(model),
163 name: model.name().0,
164 description: None,
165 icon: Some(match provider.icon() {
166 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
167 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
168 }),
169 is_latest: model.is_latest(),
170 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
171 }
172 }
173
174 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
175 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
176 }
177
178 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
179 let authenticate_all_providers = LanguageModelRegistry::global(cx)
180 .read(cx)
181 .visible_providers()
182 .iter()
183 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184 .collect::<Vec<_>>();
185
186 cx.background_spawn(async move {
187 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
188 if let Err(err) = authenticate_task.await {
189 match err {
190 language_model::AuthenticateError::CredentialsNotFound => {
191 // Since we're authenticating these providers in the
192 // background for the purposes of populating the
193 // language selector, we don't care about providers
194 // where the credentials are not found.
195 }
196 language_model::AuthenticateError::ConnectionRefused => {
197 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
198 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
199 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
200 }
201 _ => {
202 // Some providers have noisy failure states that we
203 // don't want to spam the logs with every time the
204 // language model selector is initialized.
205 //
206 // Ideally these should have more clear failure modes
207 // that we know are safe to ignore here, like what we do
208 // with `CredentialsNotFound` above.
209 match provider_id.0.as_ref() {
210 "lmstudio" | "ollama" => {
211 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
212 //
213 // These fail noisily, so we don't log them.
214 }
215 "copilot_chat" => {
216 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
217 }
218 _ => {
219 log::error!(
220 "Failed to authenticate provider: {}: {err:#}",
221 provider_name.0
222 );
223 }
224 }
225 }
226 }
227 }
228 }
229 })
230 }
231}
232
/// The native agent: owns every open [`Session`] together with the state
/// shared between them (project context, templates, context servers and the
/// cached language-model list).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store reloaded after each successful save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled to request a rebuild of `project_context`; consumed by
    /// `maintain_project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context` for this agent's lifetime.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools and prompts (the source of slash commands).
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store supplying default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    fs: Arc<dyn Fs>,
    /// Agent-level subscriptions set up in `new`.
    _subscriptions: Vec<Subscription>,
}
251
252impl NativeAgent {
    /// Creates the agent entity for `project`.
    ///
    /// The initial [`ProjectContext`] is built and awaited before the entity is
    /// constructed. The new agent subscribes to project, language-model-registry,
    /// context-server store/registry and (when present) prompt-store events, and
    /// spawns a task that rebuilds the project context whenever a refresh is
    /// signalled.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // The event handlers send on this channel; `maintain_project_context`
            // receives and rebuilds the project context.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
311
312 fn new_session(
313 &mut self,
314 project: Entity<Project>,
315 cx: &mut Context<Self>,
316 ) -> Entity<AcpThread> {
317 // Create Thread
318 // Fetch default model from registry settings
319 let registry = LanguageModelRegistry::read_global(cx);
320 // Log available models for debugging
321 let available_count = registry.available_models(cx).count();
322 log::debug!("Total available models: {}", available_count);
323
324 let default_model = registry.default_model().and_then(|default_model| {
325 self.models
326 .model_from_id(&LanguageModels::model_id(&default_model.model))
327 });
328 let thread = cx.new(|cx| {
329 Thread::new(
330 project.clone(),
331 self.project_context.clone(),
332 self.context_server_registry.clone(),
333 self.templates.clone(),
334 default_model,
335 cx,
336 )
337 });
338
339 self.register_session(thread, cx)
340 }
341
    /// Registers `thread_handle` as a new session and returns the [`AcpThread`]
    /// created to front it.
    ///
    /// Besides creating the `AcpThread`, this method does the following:
    /// - sets the thread's summarization model from the global registry;
    /// - installs the default tools, backed by a `NativeThreadEnvironment` that
    ///   holds weak handles to the `AcpThread`, the thread and this agent;
    /// - subscribes to the thread's title and token-usage events;
    /// - saves the thread whenever it is observed to change;
    /// - pushes the current slash-command list to every session.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy what is needed out of the thread before `cx` is borrowed mutably
        // by `cx.new` below.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The tool environment only gets weak handles, so it does not keep the
        // agent or the threads alive.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread every time it notifies observers.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        // Send the current slash-command list to every session, including the new one.
        self.update_available_commands(cx);

        acp_thread
    }
408
409 pub fn models(&self) -> &LanguageModels {
410 &self.models
411 }
412
413 async fn maintain_project_context(
414 this: WeakEntity<Self>,
415 mut needs_refresh: watch::Receiver<()>,
416 cx: &mut AsyncApp,
417 ) -> Result<()> {
418 while needs_refresh.changed().await.is_ok() {
419 let project_context = this
420 .update(cx, |this, cx| {
421 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
422 })?
423 .await;
424 this.update(cx, |this, cx| {
425 this.project_context = cx.new(|_| project_context);
426 })?;
427 }
428
429 Ok(())
430 }
431
    /// Builds a [`ProjectContext`] for the system prompt from two sources:
    /// - the project's visible worktrees (root name, absolute path and any rules file);
    /// - the prompt store's default user rules.
    ///
    /// Worktree loads and rule loads are awaited together.
    ///
    /// Some results are currently dropped silently:
    /// - rules-file errors;
    /// - default prompts whose contents fail to load;
    /// - default prompts whose id is not a user prompt id.
    fn build_project_context(
        project: &Entity<Project>,
        prompt_store: Option<&Entity<PromptStore>>,
        cx: &mut App,
    ) -> Task<ProjectContext> {
        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
            })
            .collect::<Vec<_>>();
        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
            prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            })
        } else {
            Task::ready(vec![])
        };

        cx.spawn(async move |_cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, _rules_error)| {
                    // TODO: show `_rules_error` to the user; it is currently discarded.
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        // Prompts without a user id are skipped via `?`.
                        uuid: prompt_metadata.id.as_user()?,
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(_err) => {
                        // TODO: show this load failure to the user (e.g. as a
                        // `RulesLoadingError`); it is currently discarded.
                        None
                    }
                })
                .collect::<Vec<_>>();

            ProjectContext::new(worktrees, default_user_rules)
        })
    }
496
497 fn load_worktree_info_for_system_prompt(
498 worktree: Entity<Worktree>,
499 project: Entity<Project>,
500 cx: &mut App,
501 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
502 let tree = worktree.read(cx);
503 let root_name = tree.root_name_str().into();
504 let abs_path = tree.abs_path();
505
506 let mut context = WorktreeContext {
507 root_name,
508 abs_path,
509 rules_file: None,
510 };
511
512 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
513 let Some(rules_task) = rules_task else {
514 return Task::ready((context, None));
515 };
516
517 cx.spawn(async move |_| {
518 let (rules_file, rules_file_error) = match rules_task.await {
519 Ok(rules_file) => (Some(rules_file), None),
520 Err(err) => (
521 None,
522 Some(RulesLoadingError {
523 message: format!("{err}").into(),
524 }),
525 ),
526 };
527 context.rules_file = rules_file;
528 (context, rules_file_error)
529 })
530 }
531
532 fn load_worktree_rules_file(
533 worktree: Entity<Worktree>,
534 project: Entity<Project>,
535 cx: &mut App,
536 ) -> Option<Task<Result<RulesFileContext>>> {
537 let worktree = worktree.read(cx);
538 let worktree_id = worktree.id();
539 let selected_rules_file = RULES_FILE_NAMES
540 .into_iter()
541 .filter_map(|name| {
542 worktree
543 .entry_for_path(RelPath::unix(name).unwrap())
544 .filter(|entry| entry.is_file())
545 .map(|entry| entry.path.clone())
546 })
547 .next();
548
549 // Note that Cline supports `.clinerules` being a directory, but that is not currently
550 // supported. This doesn't seem to occur often in GitHub repositories.
551 selected_rules_file.map(|path_in_worktree| {
552 let project_path = ProjectPath {
553 worktree_id,
554 path: path_in_worktree.clone(),
555 };
556 let buffer_task =
557 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
558 let rope_task = cx.spawn(async move |cx| {
559 let buffer = buffer_task.await?;
560 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
561 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
562 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
563 })?;
564 anyhow::Ok((project_entry_id, rope))
565 });
566 // Build a string from the rope on a background thread.
567 cx.background_spawn(async move {
568 let (project_entry_id, rope) = rope_task.await?;
569 anyhow::Ok(RulesFileContext {
570 path_in_worktree,
571 text: rope.to_string().trim().to_string(),
572 project_entry_id: project_entry_id.to_usize(),
573 })
574 })
575 })
576 }
577
578 fn handle_thread_title_updated(
579 &mut self,
580 thread: Entity<Thread>,
581 _: &TitleUpdated,
582 cx: &mut Context<Self>,
583 ) {
584 let session_id = thread.read(cx).id();
585 let Some(session) = self.sessions.get(session_id) else {
586 return;
587 };
588 let thread = thread.downgrade();
589 let acp_thread = session.acp_thread.downgrade();
590 cx.spawn(async move |_, cx| {
591 let title = thread.read_with(cx, |thread, _| thread.title())?;
592 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
593 task.await
594 })
595 .detach_and_log_err(cx);
596 }
597
598 fn handle_thread_token_usage_updated(
599 &mut self,
600 thread: Entity<Thread>,
601 usage: &TokenUsageUpdated,
602 cx: &mut Context<Self>,
603 ) {
604 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
605 return;
606 };
607 session.acp_thread.update(cx, |acp_thread, cx| {
608 acp_thread.update_token_usage(usage.0.clone(), cx);
609 });
610 }
611
612 fn handle_project_event(
613 &mut self,
614 _project: Entity<Project>,
615 event: &project::Event,
616 _cx: &mut Context<Self>,
617 ) {
618 match event {
619 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
620 self.project_context_needs_refresh.send(()).ok();
621 }
622 project::Event::WorktreeUpdatedEntries(_, items) => {
623 if items.iter().any(|(path, _, _)| {
624 RULES_FILE_NAMES
625 .iter()
626 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
627 }) {
628 self.project_context_needs_refresh.send(()).ok();
629 }
630 }
631 _ => {}
632 }
633 }
634
635 fn handle_prompts_updated_event(
636 &mut self,
637 _prompt_store: Entity<PromptStore>,
638 _event: &prompt_store::PromptsUpdatedEvent,
639 _cx: &mut Context<Self>,
640 ) {
641 self.project_context_needs_refresh.send(()).ok();
642 }
643
644 fn handle_models_updated_event(
645 &mut self,
646 _registry: Entity<LanguageModelRegistry>,
647 _event: &language_model::Event,
648 cx: &mut Context<Self>,
649 ) {
650 self.models.refresh_list(cx);
651
652 let registry = LanguageModelRegistry::read_global(cx);
653 let default_model = registry.default_model().map(|m| m.model);
654 let summarization_model = registry.thread_summary_model().map(|m| m.model);
655
656 for session in self.sessions.values_mut() {
657 session.thread.update(cx, |thread, cx| {
658 if thread.model().is_none()
659 && let Some(model) = default_model.clone()
660 {
661 thread.set_model(model, cx);
662 cx.notify();
663 }
664 thread.set_summarization_model(summarization_model.clone(), cx);
665 });
666 }
667 }
668
669 fn handle_context_server_store_updated(
670 &mut self,
671 _store: Entity<project::context_server_store::ContextServerStore>,
672 _event: &project::context_server_store::ServerStatusChangedEvent,
673 cx: &mut Context<Self>,
674 ) {
675 self.update_available_commands(cx);
676 }
677
678 fn handle_context_server_registry_event(
679 &mut self,
680 _registry: Entity<ContextServerRegistry>,
681 event: &ContextServerRegistryEvent,
682 cx: &mut Context<Self>,
683 ) {
684 match event {
685 ContextServerRegistryEvent::ToolsChanged => {}
686 ContextServerRegistryEvent::PromptsChanged => {
687 self.update_available_commands(cx);
688 }
689 }
690 }
691
692 fn update_available_commands(&self, cx: &mut Context<Self>) {
693 let available_commands = self.build_available_commands(cx);
694 for session in self.sessions.values() {
695 session.acp_thread.update(cx, |thread, cx| {
696 thread
697 .handle_session_update(
698 acp::SessionUpdate::AvailableCommandsUpdate(
699 acp::AvailableCommandsUpdate::new(available_commands.clone()),
700 ),
701 cx,
702 )
703 .log_err();
704 });
705 }
706 }
707
708 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
709 let registry = self.context_server_registry.read(cx);
710
711 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
712 for context_server_prompt in registry.prompts() {
713 *prompt_name_counts
714 .entry(context_server_prompt.prompt.name.as_str())
715 .or_insert(0) += 1;
716 }
717
718 registry
719 .prompts()
720 .flat_map(|context_server_prompt| {
721 let prompt = &context_server_prompt.prompt;
722
723 let should_prefix = prompt_name_counts
724 .get(prompt.name.as_str())
725 .copied()
726 .unwrap_or(0)
727 > 1;
728
729 let name = if should_prefix {
730 format!("{}.{}", context_server_prompt.server_id, prompt.name)
731 } else {
732 prompt.name.clone()
733 };
734
735 let mut command = acp::AvailableCommand::new(
736 name,
737 prompt.description.clone().unwrap_or_default(),
738 );
739
740 match prompt.arguments.as_deref() {
741 Some([arg]) => {
742 let hint = format!("<{}>", arg.name);
743
744 command = command.input(acp::AvailableCommandInput::Unstructured(
745 acp::UnstructuredCommandInput::new(hint),
746 ));
747 }
748 Some([]) | None => {}
749 Some(_) => {
750 // skip >1 argument commands since we don't support them yet
751 return None;
752 }
753 }
754
755 Some(command)
756 })
757 .collect()
758 }
759
    /// Loads the persisted thread `id` from the threads database into a new
    /// [`Thread`] entity, configured with the registry's summarization model.
    ///
    /// This does not register a session; [`Self::open_thread`] does that.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the database cannot be connected to or queried;
    /// - no thread with `id` exists;
    /// - this agent has been released.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        this.project.clone(),
                        this.project_context.clone(),
                        this.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                })
            })
        })
    }
794
    /// Returns the [`AcpThread`] for session `id`, loading and registering the
    /// thread first if it is not already open.
    ///
    /// For a freshly loaded thread, the events produced by `Thread::replay` are
    /// forwarded into the new `AcpThread` before the task resolves.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        // Fast path: the session is already registered.
        if let Some(session) = self.sessions.get(&id) {
            return Task::ready(Ok(session.acp_thread.clone()));
        }

        let task = self.load_thread(id, cx);
        cx.spawn(async move |this, cx| {
            let thread = task.await?;
            let acp_thread =
                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
            let events = thread.update(cx, |thread, cx| thread.replay(cx));
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
            })
            .await?;
            Ok(acp_thread)
        })
    }
817
818 pub fn thread_summary(
819 &mut self,
820 id: acp::SessionId,
821 cx: &mut Context<Self>,
822 ) -> Task<Result<SharedString>> {
823 let thread = self.open_thread(id.clone(), cx);
824 cx.spawn(async move |this, cx| {
825 let acp_thread = thread.await?;
826 let result = this
827 .update(cx, |this, cx| {
828 this.sessions
829 .get(&id)
830 .unwrap()
831 .thread
832 .update(cx, |thread, cx| thread.summary(cx))
833 })?
834 .await
835 .context("Failed to generate summary")?;
836 drop(acp_thread);
837 Ok(result)
838 })
839 }
840
    /// Persists `thread` to the threads database together with the absolute
    /// paths of the project's visible worktrees, then reloads the thread store.
    ///
    /// The method returns early, without saving, in two cases:
    /// - the thread is empty;
    /// - the thread is not a registered session.
    ///
    /// Database connection and save errors are logged, not propagated.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let database_future = ThreadsDatabase::connect(cx);
        let (id, db_thread) =
            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let folder_paths = PathList::new(
            &self
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        let thread_store = self.thread_store.clone();
        // Overwrites any previous pending save. NOTE(review): gpui tasks are
        // presumably cancelled on drop, so an in-flight earlier save would be
        // abandoned here — confirm.
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return;
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            thread_store.update(cx, |store, cx| store.reload(cx));
        });
    }
875
    /// Expands the MCP prompt `prompt_name` from `server_id` and runs a turn with it.
    ///
    /// The steps are:
    /// 1. Fetch the prompt with `arguments`.
    /// 2. Push `original_content`, minus its first block, to the thread as the
    ///    user message `message_id`.
    /// 3. Replay each prompt message into both the `AcpThread` (indented) and
    ///    the internal thread.
    /// 4. If the last message was from the user, send the conversation;
    ///    otherwise resume it.
    ///
    /// # Errors
    /// Fails if the prompt cannot be fetched, the session does not exist, the
    /// agent is released, or the turn itself fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // With no prompt messages this stays `true`, so the turn is sent.
            let mut last_is_user = true;

            // NOTE(review): the skipped first block is presumably the `/command`
            // text that `Command::parse` reads — confirm at the call site.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
971}
972
973/// Wrapper struct that implements the AgentConnection trait
974#[derive(Clone)]
975pub struct NativeAgentConnection(pub Entity<NativeAgent>);
976
977impl NativeAgentConnection {
978 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
979 self.0
980 .read(cx)
981 .sessions
982 .get(session_id)
983 .map(|session| session.thread.clone())
984 }
985
986 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
987 self.0.update(cx, |this, cx| this.load_thread(id, cx))
988 }
989
990 fn run_turn(
991 &self,
992 session_id: acp::SessionId,
993 cx: &mut App,
994 f: impl 'static
995 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
996 ) -> Task<Result<acp::PromptResponse>> {
997 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
998 agent
999 .sessions
1000 .get_mut(&session_id)
1001 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1002 }) else {
1003 return Task::ready(Err(anyhow!("Session not found")));
1004 };
1005 log::debug!("Found session for: {}", session_id);
1006
1007 let response_stream = match f(thread, cx) {
1008 Ok(stream) => stream,
1009 Err(err) => return Task::ready(Err(err)),
1010 };
1011 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1012 }
1013
1014 fn handle_thread_events(
1015 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1016 acp_thread: WeakEntity<AcpThread>,
1017 cx: &App,
1018 ) -> Task<Result<acp::PromptResponse>> {
1019 cx.spawn(async move |cx| {
1020 // Handle response stream and forward to session.acp_thread
1021 while let Some(result) = events.next().await {
1022 match result {
1023 Ok(event) => {
1024 log::trace!("Received completion event: {:?}", event);
1025
1026 match event {
1027 ThreadEvent::UserMessage(message) => {
1028 acp_thread.update(cx, |thread, cx| {
1029 for content in message.content {
1030 thread.push_user_content_block(
1031 Some(message.id.clone()),
1032 content.into(),
1033 cx,
1034 );
1035 }
1036 })?;
1037 }
1038 ThreadEvent::AgentText(text) => {
1039 acp_thread.update(cx, |thread, cx| {
1040 thread.push_assistant_content_block(text.into(), false, cx)
1041 })?;
1042 }
1043 ThreadEvent::AgentThinking(text) => {
1044 acp_thread.update(cx, |thread, cx| {
1045 thread.push_assistant_content_block(text.into(), true, cx)
1046 })?;
1047 }
1048 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1049 tool_call,
1050 options,
1051 response,
1052 context: _,
1053 }) => {
1054 let outcome_task = acp_thread.update(cx, |thread, cx| {
1055 thread.request_tool_call_authorization(tool_call, options, cx)
1056 })??;
1057 cx.background_spawn(async move {
1058 if let acp::RequestPermissionOutcome::Selected(
1059 acp::SelectedPermissionOutcome { option_id, .. },
1060 ) = outcome_task.await
1061 {
1062 response
1063 .send(option_id)
1064 .map(|_| anyhow!("authorization receiver was dropped"))
1065 .log_err();
1066 }
1067 })
1068 .detach();
1069 }
1070 ThreadEvent::ToolCall(tool_call) => {
1071 acp_thread.update(cx, |thread, cx| {
1072 thread.upsert_tool_call(tool_call, cx)
1073 })??;
1074 }
1075 ThreadEvent::ToolCallUpdate(update) => {
1076 acp_thread.update(cx, |thread, cx| {
1077 thread.update_tool_call(update, cx)
1078 })??;
1079 }
1080 ThreadEvent::SubagentSpawned(session_id) => {
1081 acp_thread.update(cx, |thread, cx| {
1082 thread.subagent_spawned(session_id, cx);
1083 })?;
1084 }
1085 ThreadEvent::Retry(status) => {
1086 acp_thread.update(cx, |thread, cx| {
1087 thread.update_retry_status(status, cx)
1088 })?;
1089 }
1090 ThreadEvent::Stop(stop_reason) => {
1091 log::debug!("Assistant message complete: {:?}", stop_reason);
1092 return Ok(acp::PromptResponse::new(stop_reason));
1093 }
1094 }
1095 }
1096 Err(e) => {
1097 log::error!("Error in model response stream: {:?}", e);
1098 return Err(e);
1099 }
1100 }
1101 }
1102
1103 log::debug!("Response stream completed");
1104 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1105 })
1106 }
1107}
1108
1109struct Command<'a> {
1110 prompt_name: &'a str,
1111 arg_value: &'a str,
1112 explicit_server_id: Option<&'a str>,
1113}
1114
1115impl<'a> Command<'a> {
1116 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1117 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1118 return None;
1119 };
1120 let text = text_content.text.trim();
1121 let command = text.strip_prefix('/')?;
1122 let (command, arg_value) = command
1123 .split_once(char::is_whitespace)
1124 .unwrap_or((command, ""));
1125
1126 if let Some((server_id, prompt_name)) = command.split_once('.') {
1127 Some(Self {
1128 prompt_name,
1129 arg_value,
1130 explicit_server_id: Some(server_id),
1131 })
1132 } else {
1133 Some(Self {
1134 prompt_name: command,
1135 arg_value,
1136 explicit_server_id: None,
1137 })
1138 }
1139 }
1140}
1141
1142struct NativeAgentModelSelector {
1143 session_id: acp::SessionId,
1144 connection: NativeAgentConnection,
1145}
1146
1147impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1148 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1149 log::debug!("NativeAgentConnection::list_models called");
1150 let list = self.connection.0.read(cx).models.model_list.clone();
1151 Task::ready(if list.is_empty() {
1152 Err(anyhow::anyhow!("No models available"))
1153 } else {
1154 Ok(list)
1155 })
1156 }
1157
1158 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1159 log::debug!(
1160 "Setting model for session {}: {}",
1161 self.session_id,
1162 model_id
1163 );
1164 let Some(thread) = self
1165 .connection
1166 .0
1167 .read(cx)
1168 .sessions
1169 .get(&self.session_id)
1170 .map(|session| session.thread.clone())
1171 else {
1172 return Task::ready(Err(anyhow!("Session not found")));
1173 };
1174
1175 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1176 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1177 };
1178
1179 // We want to reset the effort level when switching models, as the currently-selected effort level may
1180 // not be compatible.
1181 let effort = model
1182 .default_effort_level()
1183 .map(|effort_level| effort_level.value.to_string());
1184
1185 thread.update(cx, |thread, cx| {
1186 thread.set_model(model.clone(), cx);
1187 thread.set_thinking_effort(effort.clone(), cx);
1188 thread.set_thinking_enabled(model.supports_thinking(), cx);
1189 });
1190
1191 update_settings_file(
1192 self.connection.0.read(cx).fs.clone(),
1193 cx,
1194 move |settings, cx| {
1195 let provider = model.provider_id().0.to_string();
1196 let model = model.id().0.to_string();
1197 let enable_thinking = thread.read(cx).thinking_enabled();
1198 settings
1199 .agent
1200 .get_or_insert_default()
1201 .set_model(LanguageModelSelection {
1202 provider: provider.into(),
1203 model,
1204 enable_thinking,
1205 effort,
1206 });
1207 },
1208 );
1209
1210 Task::ready(Ok(()))
1211 }
1212
1213 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1214 let Some(thread) = self
1215 .connection
1216 .0
1217 .read(cx)
1218 .sessions
1219 .get(&self.session_id)
1220 .map(|session| session.thread.clone())
1221 else {
1222 return Task::ready(Err(anyhow!("Session not found")));
1223 };
1224 let Some(model) = thread.read(cx).model() else {
1225 return Task::ready(Err(anyhow!("Model not found")));
1226 };
1227 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1228 else {
1229 return Task::ready(Err(anyhow!("Provider not found")));
1230 };
1231 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1232 model, &provider,
1233 )))
1234 }
1235
1236 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1237 Some(self.connection.0.read(cx).models.watch())
1238 }
1239
1240 fn should_render_footer(&self) -> bool {
1241 true
1242 }
1243}
1244
1245impl acp_thread::AgentConnection for NativeAgentConnection {
1246 fn telemetry_id(&self) -> SharedString {
1247 "zed".into()
1248 }
1249
1250 fn new_session(
1251 self: Rc<Self>,
1252 project: Entity<Project>,
1253 cwd: &Path,
1254 cx: &mut App,
1255 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1256 log::debug!("Creating new thread for project at: {cwd:?}");
1257 Task::ready(Ok(self
1258 .0
1259 .update(cx, |agent, cx| agent.new_session(project, cx))))
1260 }
1261
1262 fn supports_load_session(&self) -> bool {
1263 true
1264 }
1265
1266 fn load_session(
1267 self: Rc<Self>,
1268 session: AgentSessionInfo,
1269 _project: Entity<Project>,
1270 _cwd: &Path,
1271 cx: &mut App,
1272 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1273 self.0
1274 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1275 }
1276
1277 fn supports_close_session(&self) -> bool {
1278 true
1279 }
1280
1281 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1282 self.0.update(cx, |agent, _cx| {
1283 agent.sessions.remove(session_id);
1284 });
1285 Task::ready(Ok(()))
1286 }
1287
1288 fn auth_methods(&self) -> &[acp::AuthMethod] {
1289 &[] // No auth for in-process
1290 }
1291
1292 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1293 Task::ready(Ok(()))
1294 }
1295
1296 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1297 Some(Rc::new(NativeAgentModelSelector {
1298 session_id: session_id.clone(),
1299 connection: self.clone(),
1300 }) as Rc<dyn AgentModelSelector>)
1301 }
1302
1303 fn prompt(
1304 &self,
1305 id: Option<acp_thread::UserMessageId>,
1306 params: acp::PromptRequest,
1307 cx: &mut App,
1308 ) -> Task<Result<acp::PromptResponse>> {
1309 let id = id.expect("UserMessageId is required");
1310 let session_id = params.session_id.clone();
1311 log::info!("Received prompt request for session: {}", session_id);
1312 log::debug!("Prompt blocks count: {}", params.prompt.len());
1313
1314 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1315 let registry = self.0.read(cx).context_server_registry.read(cx);
1316
1317 let explicit_server_id = parsed_command
1318 .explicit_server_id
1319 .map(|server_id| ContextServerId(server_id.into()));
1320
1321 if let Some(prompt) =
1322 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1323 {
1324 let arguments = if !parsed_command.arg_value.is_empty()
1325 && let Some(arg_name) = prompt
1326 .prompt
1327 .arguments
1328 .as_ref()
1329 .and_then(|args| args.first())
1330 .map(|arg| arg.name.clone())
1331 {
1332 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1333 } else {
1334 Default::default()
1335 };
1336
1337 let prompt_name = prompt.prompt.name.clone();
1338 let server_id = prompt.server_id.clone();
1339
1340 return self.0.update(cx, |agent, cx| {
1341 agent.send_mcp_prompt(
1342 id,
1343 session_id.clone(),
1344 prompt_name,
1345 server_id,
1346 arguments,
1347 params.prompt,
1348 cx,
1349 )
1350 });
1351 };
1352 };
1353
1354 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1355
1356 self.run_turn(session_id, cx, move |thread, cx| {
1357 let content: Vec<UserMessageContent> = params
1358 .prompt
1359 .into_iter()
1360 .map(|block| UserMessageContent::from_content_block(block, path_style))
1361 .collect::<Vec<_>>();
1362 log::debug!("Converted prompt to message: {} chars", content.len());
1363 log::debug!("Message id: {:?}", id);
1364 log::debug!("Message content: {:?}", content);
1365
1366 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1367 })
1368 }
1369
1370 fn retry(
1371 &self,
1372 session_id: &acp::SessionId,
1373 _cx: &App,
1374 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1375 Some(Rc::new(NativeAgentSessionRetry {
1376 connection: self.clone(),
1377 session_id: session_id.clone(),
1378 }) as _)
1379 }
1380
1381 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1382 log::info!("Cancelling on session: {}", session_id);
1383 self.0.update(cx, |agent, cx| {
1384 if let Some(session) = agent.sessions.get(session_id) {
1385 session
1386 .thread
1387 .update(cx, |thread, cx| thread.cancel(cx))
1388 .detach();
1389 }
1390 });
1391 }
1392
1393 fn truncate(
1394 &self,
1395 session_id: &agent_client_protocol::SessionId,
1396 cx: &App,
1397 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1398 self.0.read_with(cx, |agent, _cx| {
1399 agent.sessions.get(session_id).map(|session| {
1400 Rc::new(NativeAgentSessionTruncate {
1401 thread: session.thread.clone(),
1402 acp_thread: session.acp_thread.downgrade(),
1403 }) as _
1404 })
1405 })
1406 }
1407
1408 fn set_title(
1409 &self,
1410 session_id: &acp::SessionId,
1411 cx: &App,
1412 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1413 self.0.read_with(cx, |agent, _cx| {
1414 agent
1415 .sessions
1416 .get(session_id)
1417 .filter(|s| !s.thread.read(cx).is_subagent())
1418 .map(|session| {
1419 Rc::new(NativeAgentSessionSetTitle {
1420 thread: session.thread.clone(),
1421 }) as _
1422 })
1423 })
1424 }
1425
1426 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1427 let thread_store = self.0.read(cx).thread_store.clone();
1428 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1429 }
1430
1431 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1432 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1433 }
1434
1435 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1436 self
1437 }
1438}
1439
1440impl acp_thread::AgentTelemetry for NativeAgentConnection {
1441 fn thread_data(
1442 &self,
1443 session_id: &acp::SessionId,
1444 cx: &mut App,
1445 ) -> Task<Result<serde_json::Value>> {
1446 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1447 return Task::ready(Err(anyhow!("Session not found")));
1448 };
1449
1450 let task = session.thread.read(cx).to_db(cx);
1451 cx.background_spawn(async move {
1452 serde_json::to_value(task.await).context("Failed to serialize thread")
1453 })
1454 }
1455}
1456
1457pub struct NativeAgentSessionList {
1458 thread_store: Entity<ThreadStore>,
1459 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1460 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1461 _subscription: Subscription,
1462}
1463
1464impl NativeAgentSessionList {
1465 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1466 let (tx, rx) = smol::channel::unbounded();
1467 let this_tx = tx.clone();
1468 let subscription = cx.observe(&thread_store, move |_, _| {
1469 this_tx
1470 .try_send(acp_thread::SessionListUpdate::Refresh)
1471 .ok();
1472 });
1473 Self {
1474 thread_store,
1475 updates_tx: tx,
1476 updates_rx: rx,
1477 _subscription: subscription,
1478 }
1479 }
1480
1481 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1482 AgentSessionInfo {
1483 session_id: entry.id,
1484 cwd: None,
1485 title: Some(entry.title),
1486 updated_at: Some(entry.updated_at),
1487 meta: None,
1488 }
1489 }
1490
1491 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1492 &self.thread_store
1493 }
1494}
1495
1496impl AgentSessionList for NativeAgentSessionList {
1497 fn list_sessions(
1498 &self,
1499 _request: AgentSessionListRequest,
1500 cx: &mut App,
1501 ) -> Task<Result<AgentSessionListResponse>> {
1502 let sessions = self
1503 .thread_store
1504 .read(cx)
1505 .entries()
1506 .map(Self::to_session_info)
1507 .collect();
1508 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1509 }
1510
1511 fn supports_delete(&self) -> bool {
1512 true
1513 }
1514
1515 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1516 self.thread_store
1517 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1518 }
1519
1520 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1521 self.thread_store
1522 .update(cx, |store, cx| store.delete_threads(cx))
1523 }
1524
1525 fn watch(
1526 &self,
1527 _cx: &mut App,
1528 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1529 Some(self.updates_rx.clone())
1530 }
1531
1532 fn notify_refresh(&self) {
1533 self.updates_tx
1534 .try_send(acp_thread::SessionListUpdate::Refresh)
1535 .ok();
1536 }
1537
1538 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1539 self
1540 }
1541}
1542
1543struct NativeAgentSessionTruncate {
1544 thread: Entity<Thread>,
1545 acp_thread: WeakEntity<AcpThread>,
1546}
1547
1548impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1549 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1550 match self.thread.update(cx, |thread, cx| {
1551 thread.truncate(message_id.clone(), cx)?;
1552 Ok(thread.latest_token_usage())
1553 }) {
1554 Ok(usage) => {
1555 self.acp_thread
1556 .update(cx, |thread, cx| {
1557 thread.update_token_usage(usage, cx);
1558 })
1559 .ok();
1560 Task::ready(Ok(()))
1561 }
1562 Err(error) => Task::ready(Err(error)),
1563 }
1564 }
1565}
1566
1567struct NativeAgentSessionRetry {
1568 connection: NativeAgentConnection,
1569 session_id: acp::SessionId,
1570}
1571
1572impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1573 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1574 self.connection
1575 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1576 thread.update(cx, |thread, cx| thread.resume(cx))
1577 })
1578 }
1579}
1580
1581struct NativeAgentSessionSetTitle {
1582 thread: Entity<Thread>,
1583}
1584
1585impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1586 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1587 self.thread
1588 .update(cx, |thread, cx| thread.set_title(title, cx));
1589 Task::ready(Ok(()))
1590 }
1591}
1592
1593pub struct NativeThreadEnvironment {
1594 agent: WeakEntity<NativeAgent>,
1595 thread: WeakEntity<Thread>,
1596 acp_thread: WeakEntity<AcpThread>,
1597}
1598
1599impl NativeThreadEnvironment {
1600 pub(crate) fn create_subagent_thread(
1601 &self,
1602 label: String,
1603 cx: &mut App,
1604 ) -> Result<Rc<dyn SubagentHandle>> {
1605 let Some(parent_thread_entity) = self.thread.upgrade() else {
1606 anyhow::bail!("Parent thread no longer exists".to_string());
1607 };
1608 let parent_thread = parent_thread_entity.read(cx);
1609 let current_depth = parent_thread.depth();
1610
1611 if current_depth >= MAX_SUBAGENT_DEPTH {
1612 return Err(anyhow!(
1613 "Maximum subagent depth ({}) reached",
1614 MAX_SUBAGENT_DEPTH
1615 ));
1616 }
1617
1618 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1619 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1620 thread.set_title(label.into(), cx);
1621 thread
1622 });
1623
1624 let session_id = subagent_thread.read(cx).id().clone();
1625
1626 let acp_thread = self.agent.update(cx, |agent, cx| {
1627 agent.register_session(subagent_thread.clone(), cx)
1628 })?;
1629
1630 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1631 }
1632
1633 pub(crate) fn resume_subagent_thread(
1634 &self,
1635 session_id: acp::SessionId,
1636 cx: &mut App,
1637 ) -> Result<Rc<dyn SubagentHandle>> {
1638 let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1639 let session = agent
1640 .sessions
1641 .get(&session_id)
1642 .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1643 anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1644 })??;
1645
1646 self.prompt_subagent(session_id, subagent_thread, acp_thread)
1647 }
1648
1649 fn prompt_subagent(
1650 &self,
1651 session_id: acp::SessionId,
1652 subagent_thread: Entity<Thread>,
1653 acp_thread: Entity<acp_thread::AcpThread>,
1654 ) -> Result<Rc<dyn SubagentHandle>> {
1655 let Some(parent_thread_entity) = self.thread.upgrade() else {
1656 anyhow::bail!("Parent thread no longer exists".to_string());
1657 };
1658 Ok(Rc::new(NativeSubagentHandle::new(
1659 session_id,
1660 subagent_thread,
1661 acp_thread,
1662 parent_thread_entity,
1663 )) as _)
1664 }
1665}
1666
impl ThreadEnvironment for NativeThreadEnvironment {
    /// Creates a terminal through the ACP thread and wraps it in an [`AcpTerminalHandle`].
    ///
    /// The handle owns the sending half of a oneshot channel. When the handle is
    /// dropped, a spawned task observes the closed channel and releases the terminal
    /// from the ACP thread.
    fn create_terminal(
        &self,
        command: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
        // `acp_thread` is weak, so this yields a `Result`; its error is surfaced by
        // `task?` below.
        // NOTE(review): the two empty vectors are presumably args and env — confirm
        // against `AcpThread::create_terminal`.
        let task = self.acp_thread.update(cx, |thread, cx| {
            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
        });

        let acp_thread = self.acp_thread.clone();
        cx.spawn(async move |cx| {
            let terminal = task?.await?;

            let (drop_tx, drop_rx) = oneshot::channel();
            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());

            // Resolves once `drop_tx` is dropped along with the handle, then
            // releases the terminal on the ACP thread.
            cx.spawn(async move |cx| {
                drop_rx.await.ok();
                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
            })
            .detach();

            let handle = AcpTerminalHandle {
                terminal,
                _drop_tx: Some(drop_tx),
            };

            Ok(Rc::new(handle) as _)
        })
    }

    /// Delegates to [`NativeThreadEnvironment::create_subagent_thread`].
    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
        self.create_subagent_thread(label, cx)
    }

    /// Delegates to [`NativeThreadEnvironment::resume_subagent_thread`].
    fn resume_subagent(
        &self,
        session_id: acp::SessionId,
        cx: &mut App,
    ) -> Result<Rc<dyn SubagentHandle>> {
        self.resume_subagent_thread(session_id, cx)
    }
}
1713
/// Outcome of waiting on a single subagent prompt in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended normally (`EndTurn` or any stop reason not handled explicitly).
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage crossed from `Normal` to `Warning` while the prompt was running.
    ContextWindowWarning,
    /// The turn failed; the string is the user-facing error message.
    Error(String),
}
1721
1722pub struct NativeSubagentHandle {
1723 session_id: acp::SessionId,
1724 parent_thread: WeakEntity<Thread>,
1725 subagent_thread: Entity<Thread>,
1726 acp_thread: Entity<acp_thread::AcpThread>,
1727}
1728
1729impl NativeSubagentHandle {
1730 fn new(
1731 session_id: acp::SessionId,
1732 subagent_thread: Entity<Thread>,
1733 acp_thread: Entity<acp_thread::AcpThread>,
1734 parent_thread_entity: Entity<Thread>,
1735 ) -> Self {
1736 NativeSubagentHandle {
1737 session_id,
1738 subagent_thread,
1739 parent_thread: parent_thread_entity.downgrade(),
1740 acp_thread,
1741 }
1742 }
1743}
1744
impl SubagentHandle for NativeSubagentHandle {
    /// Session id of the subagent thread.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Sends `message` to the subagent and waits for its turn to finish.
    ///
    /// While waiting, the subagent is registered as running on the parent thread; it
    /// is unregistered once a result is available.
    ///
    /// Resolves to the markdown of the subagent's last message when the turn
    /// completes. Resolves to an error in these cases:
    /// - the turn is cancelled;
    /// - it stops with `MaxTokens`, `MaxTurnRequests` or `Refusal`;
    /// - there is no response;
    /// - token usage crosses from `Normal` into `Warning` during this prompt. In that
    ///   case the subagent thread is also cancelled.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            let (task, _subscription) = cx.update(|cx| {
                // Usage ratio before this prompt; `None` is treated as `Normal` below.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                // Best-effort: the parent thread may already be gone.
                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // Signalled at most once, when this prompt pushes usage from Normal
                // to Warning.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                // NOTE(review): only an exact transition to `Warning` is detected. If
                // `TokenUsageRatio` has a higher level and usage jumps past `Warning`
                // in a single update, this never fires — confirm that is intended.
                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the context-window warning. If the warning
                // wins, the prompt future is dropped here and the thread is cancelled
                // explicitly below.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        // `_` also covers any other stop reasons.
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                // The subscription is returned (and bound to a named variable by the
                // caller) so it stays alive until the prompt resolves.
                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .map(|m| m.to_markdown())
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User cancelled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                        stopped. You can prompt the thread again to have the agent wrap up \
                        or hand off its work."
                    ))
                }
            };

            // NOTE(review): not reached if this task is dropped before `task`
            // resolves, which would leave the subagent registered on the parent —
            // confirm whether that can happen.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
1847
1848pub struct AcpTerminalHandle {
1849 terminal: Entity<acp_thread::Terminal>,
1850 _drop_tx: Option<oneshot::Sender<()>>,
1851}
1852
1853impl TerminalHandle for AcpTerminalHandle {
1854 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1855 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1856 }
1857
1858 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1859 Ok(self
1860 .terminal
1861 .read_with(cx, |term, _cx| term.wait_for_exit()))
1862 }
1863
1864 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1865 Ok(self
1866 .terminal
1867 .read_with(cx, |term, cx| term.current_output(cx)))
1868 }
1869
1870 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1871 cx.update(|cx| {
1872 self.terminal.update(cx, |terminal, cx| {
1873 terminal.kill(cx);
1874 });
1875 });
1876 Ok(())
1877 }
1878
1879 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1880 Ok(self
1881 .terminal
1882 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1883 }
1884}
1885
1886#[cfg(test)]
1887mod internal_tests {
1888 use super::*;
1889 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1890 use fs::FakeFs;
1891 use gpui::TestAppContext;
1892 use indoc::formatdoc;
1893 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1894 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1895 use serde_json::json;
1896 use settings::SettingsStore;
1897 use util::{path, rel_path::rel_path};
1898
1899 #[gpui::test]
1900 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1901 init_test(cx);
1902 let fs = FakeFs::new(cx.executor());
1903 fs.insert_tree(
1904 "/",
1905 json!({
1906 "a": {}
1907 }),
1908 )
1909 .await;
1910 let project = Project::test(fs.clone(), [], cx).await;
1911 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1912 let agent = NativeAgent::new(
1913 project.clone(),
1914 thread_store,
1915 Templates::new(),
1916 None,
1917 fs.clone(),
1918 &mut cx.to_async(),
1919 )
1920 .await
1921 .unwrap();
1922 agent.read_with(cx, |agent, cx| {
1923 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1924 });
1925
1926 let worktree = project
1927 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1928 .await
1929 .unwrap();
1930 cx.run_until_parked();
1931 agent.read_with(cx, |agent, cx| {
1932 assert_eq!(
1933 agent.project_context.read(cx).worktrees,
1934 vec![WorktreeContext {
1935 root_name: "a".into(),
1936 abs_path: Path::new("/a").into(),
1937 rules_file: None
1938 }]
1939 )
1940 });
1941
1942 // Creating `/a/.rules` updates the project context.
1943 fs.insert_file("/a/.rules", Vec::new()).await;
1944 cx.run_until_parked();
1945 agent.read_with(cx, |agent, cx| {
1946 let rules_entry = worktree
1947 .read(cx)
1948 .entry_for_path(rel_path(".rules"))
1949 .unwrap();
1950 assert_eq!(
1951 agent.project_context.read(cx).worktrees,
1952 vec![WorktreeContext {
1953 root_name: "a".into(),
1954 abs_path: Path::new("/a").into(),
1955 rules_file: Some(RulesFileContext {
1956 path_in_worktree: rel_path(".rules").into(),
1957 text: "".into(),
1958 project_entry_id: rules_entry.id.to_usize()
1959 })
1960 }]
1961 )
1962 });
1963 }
1964
1965 #[gpui::test]
1966 async fn test_listing_models(cx: &mut TestAppContext) {
1967 init_test(cx);
1968 let fs = FakeFs::new(cx.executor());
1969 fs.insert_tree("/", json!({ "a": {} })).await;
1970 let project = Project::test(fs.clone(), [], cx).await;
1971 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1972 let connection = NativeAgentConnection(
1973 NativeAgent::new(
1974 project.clone(),
1975 thread_store,
1976 Templates::new(),
1977 None,
1978 fs.clone(),
1979 &mut cx.to_async(),
1980 )
1981 .await
1982 .unwrap(),
1983 );
1984
1985 // Create a thread/session
1986 let acp_thread = cx
1987 .update(|cx| {
1988 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1989 })
1990 .await
1991 .unwrap();
1992
1993 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1994
1995 let models = cx
1996 .update(|cx| {
1997 connection
1998 .model_selector(&session_id)
1999 .unwrap()
2000 .list_models(cx)
2001 })
2002 .await
2003 .unwrap();
2004
2005 let acp_thread::AgentModelList::Grouped(models) = models else {
2006 panic!("Unexpected model group");
2007 };
2008 assert_eq!(
2009 models,
2010 IndexMap::from_iter([(
2011 AgentModelGroupName("Fake".into()),
2012 vec![AgentModelInfo {
2013 id: acp::ModelId::new("fake/fake"),
2014 name: "Fake".into(),
2015 description: None,
2016 icon: Some(acp_thread::AgentModelIcon::Named(
2017 ui::IconName::ZedAssistant
2018 )),
2019 is_latest: false,
2020 cost: None,
2021 }]
2022 )])
2023 );
2024 }
2025
2026 #[gpui::test]
2027 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2028 init_test(cx);
2029 let fs = FakeFs::new(cx.executor());
2030 fs.create_dir(paths::settings_file().parent().unwrap())
2031 .await
2032 .unwrap();
2033 fs.insert_file(
2034 paths::settings_file(),
2035 json!({
2036 "agent": {
2037 "default_model": {
2038 "provider": "foo",
2039 "model": "bar"
2040 }
2041 }
2042 })
2043 .to_string()
2044 .into_bytes(),
2045 )
2046 .await;
2047 let project = Project::test(fs.clone(), [], cx).await;
2048
2049 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2050
2051 // Create the agent and connection
2052 let agent = NativeAgent::new(
2053 project.clone(),
2054 thread_store,
2055 Templates::new(),
2056 None,
2057 fs.clone(),
2058 &mut cx.to_async(),
2059 )
2060 .await
2061 .unwrap();
2062 let connection = NativeAgentConnection(agent.clone());
2063
2064 // Create a thread/session
2065 let acp_thread = cx
2066 .update(|cx| {
2067 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2068 })
2069 .await
2070 .unwrap();
2071
2072 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2073
2074 // Select a model
2075 let selector = connection.model_selector(&session_id).unwrap();
2076 let model_id = acp::ModelId::new("fake/fake");
2077 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2078 .await
2079 .unwrap();
2080
2081 // Verify the thread has the selected model
2082 agent.read_with(cx, |agent, _| {
2083 let session = agent.sessions.get(&session_id).unwrap();
2084 session.thread.read_with(cx, |thread, _| {
2085 assert_eq!(thread.model().unwrap().id().0, "fake");
2086 });
2087 });
2088
2089 cx.run_until_parked();
2090
2091 // Verify settings file was updated
2092 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2093 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2094
2095 // Check that the agent settings contain the selected model
2096 assert_eq!(
2097 settings_json["agent"]["default_model"]["model"],
2098 json!("fake")
2099 );
2100 assert_eq!(
2101 settings_json["agent"]["default_model"]["provider"],
2102 json!("fake")
2103 );
2104
2105 // Register a thinking model and select it.
2106 cx.update(|cx| {
2107 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2108 "fake-corp",
2109 "fake-thinking",
2110 "Fake Thinking",
2111 true,
2112 ));
2113 let thinking_provider = Arc::new(
2114 FakeLanguageModelProvider::new(
2115 LanguageModelProviderId::from("fake-corp".to_string()),
2116 LanguageModelProviderName::from("Fake Corp".to_string()),
2117 )
2118 .with_models(vec![thinking_model]),
2119 );
2120 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2121 registry.register_provider(thinking_provider, cx);
2122 });
2123 });
2124 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2125
2126 let selector = connection.model_selector(&session_id).unwrap();
2127 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2128 .await
2129 .unwrap();
2130 cx.run_until_parked();
2131
2132 // Verify enable_thinking was written to settings as true.
2133 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2134 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2135 assert_eq!(
2136 settings_json["agent"]["default_model"]["enable_thinking"],
2137 json!(true),
2138 "selecting a thinking model should persist enable_thinking: true to settings"
2139 );
2140 }
2141
    /// Checks that `select_model` keeps the thread's `thinking_enabled` flag in sync with
    /// the selected model.
    ///
    /// - Choosing a model registered with thinking support turns thinking on.
    /// - Switching back to the non-thinking `fake/fake` model turns it off again.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed an empty `{}` settings file. NOTE(review): presumably required because
        // `select_model` persists the selection to the settings file; confirm.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        // No worktrees are needed for this test.
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model. The id is "<provider id>/<model id>",
        // matching the provider and model registered above.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2237
    /// Checks that `thinking_enabled` survives a save/reload round trip.
    ///
    /// The test selects a thinking-capable model and persists the thread by completing
    /// one exchange. It then closes the session and reopens it via `open_thread`. The
    /// reloaded thread must still report `thinking_enabled() == true`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model. A handle to it is kept (`.clone()` below) so the
        // test can drive its completion stream later on.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned rather than awaited directly, so the test can feed
        // the fake model's response while the send is still pending.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream a short text reply to the fake model's latest request, then end the stream.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread before asserting the
        // session map is empty and reopening.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // Explicitly release the reloaded thread handle at the end of the test.
        drop(reloaded_acp_thread);
    }
2343
    /// Checks that a persisted thread reloads with the model it was saved with.
    ///
    /// The registered model deliberately has a model id that differs from its display
    /// name. After reloading, the thread's model id must still be `custom-model-id`,
    /// i.e. the thread must not fall back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selection uses "<provider id>/<model id>", not the display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so the fake model's reply can be fed while it is pending.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream a short text reply to the fake model's latest request, then end the stream.
        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread before reopening it.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Explicitly release the reloaded thread handle at the end of the test.
        drop(reloaded_acp_thread);
    }
2454
    /// End-to-end persistence test for native agent threads. It checks three things:
    ///
    /// - A thread with no messages is not saved, even after its models are changed.
    /// - A thread with a completed exchange is saved under the title produced by the
    ///   summarization model.
    /// - Reopening the thread reproduces the same markdown transcript, including the
    ///   `@b.md` file mention link.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message consisting of text, a file mention of `/a/b.md`, and more text.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawned so the fake models' replies can be fed while the send is pending.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Reply to the user message first. Then reply to the summarization request,
        // whose text is expected below as the saved thread's title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread before reopening it.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The store should list exactly this session, titled with the summary text.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread must render the same transcript as before it was closed.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2593
2594 fn thread_entries(
2595 thread_store: &Entity<ThreadStore>,
2596 cx: &mut TestAppContext,
2597 ) -> Vec<(acp::SessionId, String)> {
2598 thread_store.read_with(cx, |store, _| {
2599 store
2600 .entries()
2601 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2602 .collect::<Vec<_>>()
2603 })
2604 }
2605
2606 fn init_test(cx: &mut TestAppContext) {
2607 env_logger::try_init().ok();
2608 cx.update(|cx| {
2609 let settings_store = SettingsStore::test(cx);
2610 cx.set_global(settings_store);
2611
2612 LanguageModelRegistry::test(cx);
2613 });
2614 }
2615}
2616
2617fn mcp_message_content_to_acp_content_block(
2618 content: context_server::types::MessageContent,
2619) -> acp::ContentBlock {
2620 match content {
2621 context_server::types::MessageContent::Text {
2622 text,
2623 annotations: _,
2624 } => text.into(),
2625 context_server::types::MessageContent::Image {
2626 data,
2627 mime_type,
2628 annotations: _,
2629 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2630 context_server::types::MessageContent::Audio {
2631 data,
2632 mime_type,
2633 annotations: _,
2634 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2635 context_server::types::MessageContent::Resource {
2636 resource,
2637 annotations: _,
2638 } => {
2639 let mut link =
2640 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2641 if let Some(mime_type) = resource.mime_type {
2642 link = link.mime_type(mime_type);
2643 }
2644 acp::ContentBlock::ResourceLink(link)
2645 }
2646 }
2647}