1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use std::time::Duration;
54use util::ResultExt;
55use util::rel_path::RelPath;
56
/// Point-in-time snapshot of the project's worktrees, built from
/// `project::telemetry_snapshot::TelemetryWorktreeSnapshot` values.
///
/// Derives serde's `Serialize`/`Deserialize`, so it can be stored or transmitted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
62
/// Error describing why a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` output).
    pub message: SharedString,
}
66
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Most recent database-save task for `thread`; overwritten on every save.
    pending_save: Task<()>,
    /// Subscriptions to `thread` (title, token usage, change observation), kept
    /// alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
76
/// Cache of the language models offered by authenticated providers, grouped for
/// display in the model selector.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch()`; signalled after each refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender pinged at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers; held so it keeps running.
    _authenticate_all_providers_task: Task<()>,
}
86
87impl LanguageModels {
88 fn new(cx: &mut App) -> Self {
89 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
90
91 let mut this = Self {
92 models: HashMap::default(),
93 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
94 refresh_models_rx,
95 refresh_models_tx,
96 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
97 };
98 this.refresh_list(cx);
99 this
100 }
101
102 fn refresh_list(&mut self, cx: &App) {
103 let providers = LanguageModelRegistry::global(cx)
104 .read(cx)
105 .visible_providers()
106 .into_iter()
107 .filter(|provider| provider.is_authenticated(cx))
108 .collect::<Vec<_>>();
109
110 let mut language_model_list = IndexMap::default();
111 let mut recommended_models = HashSet::default();
112
113 let mut recommended = Vec::new();
114 for provider in &providers {
115 for model in provider.recommended_models(cx) {
116 recommended_models.insert((model.provider_id(), model.id()));
117 recommended.push(Self::map_language_model_to_info(&model, provider));
118 }
119 }
120 if !recommended.is_empty() {
121 language_model_list.insert(
122 acp_thread::AgentModelGroupName("Recommended".into()),
123 recommended,
124 );
125 }
126
127 let mut models = HashMap::default();
128 for provider in providers {
129 let mut provider_models = Vec::new();
130 for model in provider.provided_models(cx) {
131 let model_info = Self::map_language_model_to_info(&model, &provider);
132 let model_id = model_info.id.clone();
133 provider_models.push(model_info);
134 models.insert(model_id, model);
135 }
136 if !provider_models.is_empty() {
137 language_model_list.insert(
138 acp_thread::AgentModelGroupName(provider.name().0.clone()),
139 provider_models,
140 );
141 }
142 }
143
144 self.models = models;
145 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
146 self.refresh_models_tx.send(()).ok();
147 }
148
149 fn watch(&self) -> watch::Receiver<()> {
150 self.refresh_models_rx.clone()
151 }
152
153 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
154 self.models.get(model_id).cloned()
155 }
156
157 fn map_language_model_to_info(
158 model: &Arc<dyn LanguageModel>,
159 provider: &Arc<dyn LanguageModelProvider>,
160 ) -> acp_thread::AgentModelInfo {
161 acp_thread::AgentModelInfo {
162 id: Self::model_id(model),
163 name: model.name().0,
164 description: None,
165 icon: Some(match provider.icon() {
166 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
167 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
168 }),
169 is_latest: model.is_latest(),
170 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
171 }
172 }
173
174 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
175 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
176 }
177
178 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
179 let authenticate_all_providers = LanguageModelRegistry::global(cx)
180 .read(cx)
181 .visible_providers()
182 .iter()
183 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184 .collect::<Vec<_>>();
185
186 cx.background_spawn(async move {
187 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
188 if let Err(err) = authenticate_task.await {
189 match err {
190 language_model::AuthenticateError::CredentialsNotFound => {
191 // Since we're authenticating these providers in the
192 // background for the purposes of populating the
193 // language selector, we don't care about providers
194 // where the credentials are not found.
195 }
196 language_model::AuthenticateError::ConnectionRefused => {
197 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
198 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
199 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
200 }
201 _ => {
202 // Some providers have noisy failure states that we
203 // don't want to spam the logs with every time the
204 // language model selector is initialized.
205 //
206 // Ideally these should have more clear failure modes
207 // that we know are safe to ignore here, like what we do
208 // with `CredentialsNotFound` above.
209 match provider_id.0.as_ref() {
210 "lmstudio" | "ollama" => {
211 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
212 //
213 // These fail noisily, so we don't log them.
214 }
215 "copilot_chat" => {
216 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
217 }
218 _ => {
219 log::error!(
220 "Failed to authenticate provider: {}: {err:#}",
221 provider_name.0
222 );
223 }
224 }
225 }
226 }
227 }
228 }
229 })
230 }
231}
232
/// The built-in agent: owns every active session and the state shared across
/// them (project context, context servers, templates, and the model cache).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store of saved threads; reloaded after each save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signals `maintain_project_context` that the project context must be rebuilt.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context` for the lifetime of the agent.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools and prompts; its prompts become slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project the agent operates on.
    project: Entity<Project>,
    /// Optional prompt store; when present, its default prompts are loaded as user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, registry, context-server and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
251
252impl NativeAgent {
253 pub async fn new(
254 project: Entity<Project>,
255 thread_store: Entity<ThreadStore>,
256 templates: Arc<Templates>,
257 prompt_store: Option<Entity<PromptStore>>,
258 fs: Arc<dyn Fs>,
259 cx: &mut AsyncApp,
260 ) -> Result<Entity<NativeAgent>> {
261 log::debug!("Creating new NativeAgent");
262
263 let project_context = cx
264 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
265 .await;
266
267 Ok(cx.new(|cx| {
268 let context_server_store = project.read(cx).context_server_store();
269 let context_server_registry =
270 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
271
272 let mut subscriptions = vec![
273 cx.subscribe(&project, Self::handle_project_event),
274 cx.subscribe(
275 &LanguageModelRegistry::global(cx),
276 Self::handle_models_updated_event,
277 ),
278 cx.subscribe(
279 &context_server_store,
280 Self::handle_context_server_store_updated,
281 ),
282 cx.subscribe(
283 &context_server_registry,
284 Self::handle_context_server_registry_event,
285 ),
286 ];
287 if let Some(prompt_store) = prompt_store.as_ref() {
288 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
289 }
290
291 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
292 watch::channel(());
293 Self {
294 sessions: HashMap::default(),
295 thread_store,
296 project_context: cx.new(|_| project_context),
297 project_context_needs_refresh: project_context_needs_refresh_tx,
298 _maintain_project_context: cx.spawn(async move |this, cx| {
299 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
300 }),
301 context_server_registry,
302 templates,
303 models: LanguageModels::new(cx),
304 project,
305 prompt_store,
306 fs,
307 _subscriptions: subscriptions,
308 }
309 }))
310 }
311
312 fn new_session(
313 &mut self,
314 project: Entity<Project>,
315 cx: &mut Context<Self>,
316 ) -> Entity<AcpThread> {
317 // Create Thread
318 // Fetch default model from registry settings
319 let registry = LanguageModelRegistry::read_global(cx);
320 // Log available models for debugging
321 let available_count = registry.available_models(cx).count();
322 log::debug!("Total available models: {}", available_count);
323
324 let default_model = registry.default_model().and_then(|default_model| {
325 self.models
326 .model_from_id(&LanguageModels::model_id(&default_model.model))
327 });
328 let thread = cx.new(|cx| {
329 Thread::new(
330 project.clone(),
331 self.project_context.clone(),
332 self.context_server_registry.clone(),
333 self.templates.clone(),
334 default_model,
335 cx,
336 )
337 });
338
339 self.register_session(thread, None, cx)
340 }
341
342 fn register_session(
343 &mut self,
344 thread_handle: Entity<Thread>,
345 allowed_tool_names: Option<Vec<SharedString>>,
346 cx: &mut Context<Self>,
347 ) -> Entity<AcpThread> {
348 let connection = Rc::new(NativeAgentConnection(cx.entity()));
349
350 let thread = thread_handle.read(cx);
351 let session_id = thread.id().clone();
352 let parent_session_id = thread.parent_thread_id();
353 let title = thread.title();
354 let project = thread.project.clone();
355 let action_log = thread.action_log.clone();
356 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
357 let acp_thread = cx.new(|cx| {
358 acp_thread::AcpThread::new(
359 parent_session_id,
360 title,
361 connection,
362 project.clone(),
363 action_log.clone(),
364 session_id.clone(),
365 prompt_capabilities_rx,
366 cx,
367 )
368 });
369
370 let registry = LanguageModelRegistry::read_global(cx);
371 let summarization_model = registry.thread_summary_model().map(|c| c.model);
372
373 let weak = cx.weak_entity();
374 thread_handle.update(cx, |thread, cx| {
375 thread.set_summarization_model(summarization_model, cx);
376 thread.add_default_tools(
377 allowed_tool_names,
378 Rc::new(NativeThreadEnvironment {
379 acp_thread: acp_thread.downgrade(),
380 agent: weak,
381 }) as _,
382 cx,
383 )
384 });
385
386 let subscriptions = vec![
387 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
388 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
389 cx.observe(&thread_handle, move |this, thread, cx| {
390 this.save_thread(thread, cx)
391 }),
392 ];
393
394 self.sessions.insert(
395 session_id,
396 Session {
397 thread: thread_handle,
398 acp_thread: acp_thread.clone(),
399 _subscriptions: subscriptions,
400 pending_save: Task::ready(()),
401 },
402 );
403
404 self.update_available_commands(cx);
405
406 acp_thread
407 }
408
409 pub fn models(&self) -> &LanguageModels {
410 &self.models
411 }
412
413 async fn maintain_project_context(
414 this: WeakEntity<Self>,
415 mut needs_refresh: watch::Receiver<()>,
416 cx: &mut AsyncApp,
417 ) -> Result<()> {
418 while needs_refresh.changed().await.is_ok() {
419 let project_context = this
420 .update(cx, |this, cx| {
421 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
422 })?
423 .await;
424 this.update(cx, |this, cx| {
425 this.project_context = cx.new(|_| project_context);
426 })?;
427 }
428
429 Ok(())
430 }
431
432 fn build_project_context(
433 project: &Entity<Project>,
434 prompt_store: Option<&Entity<PromptStore>>,
435 cx: &mut App,
436 ) -> Task<ProjectContext> {
437 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
438 let worktree_tasks = worktrees
439 .into_iter()
440 .map(|worktree| {
441 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
442 })
443 .collect::<Vec<_>>();
444 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
445 prompt_store.read_with(cx, |prompt_store, cx| {
446 let prompts = prompt_store.default_prompt_metadata();
447 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
448 let contents = prompt_store.load(prompt_metadata.id, cx);
449 async move { (contents.await, prompt_metadata) }
450 });
451 cx.background_spawn(future::join_all(load_tasks))
452 })
453 } else {
454 Task::ready(vec![])
455 };
456
457 cx.spawn(async move |_cx| {
458 let (worktrees, default_user_rules) =
459 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
460
461 let worktrees = worktrees
462 .into_iter()
463 .map(|(worktree, _rules_error)| {
464 // TODO: show error message
465 // if let Some(rules_error) = rules_error {
466 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
467 // }
468 worktree
469 })
470 .collect::<Vec<_>>();
471
472 let default_user_rules = default_user_rules
473 .into_iter()
474 .flat_map(|(contents, prompt_metadata)| match contents {
475 Ok(contents) => Some(UserRulesContext {
476 uuid: prompt_metadata.id.as_user()?,
477 title: prompt_metadata.title.map(|title| title.to_string()),
478 contents,
479 }),
480 Err(_err) => {
481 // TODO: show error message
482 // this.update(cx, |_, cx| {
483 // cx.emit(RulesLoadingError {
484 // message: format!("{err:?}").into(),
485 // });
486 // })
487 // .ok();
488 None
489 }
490 })
491 .collect::<Vec<_>>();
492
493 ProjectContext::new(worktrees, default_user_rules)
494 })
495 }
496
497 fn load_worktree_info_for_system_prompt(
498 worktree: Entity<Worktree>,
499 project: Entity<Project>,
500 cx: &mut App,
501 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
502 let tree = worktree.read(cx);
503 let root_name = tree.root_name_str().into();
504 let abs_path = tree.abs_path();
505
506 let mut context = WorktreeContext {
507 root_name,
508 abs_path,
509 rules_file: None,
510 };
511
512 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
513 let Some(rules_task) = rules_task else {
514 return Task::ready((context, None));
515 };
516
517 cx.spawn(async move |_| {
518 let (rules_file, rules_file_error) = match rules_task.await {
519 Ok(rules_file) => (Some(rules_file), None),
520 Err(err) => (
521 None,
522 Some(RulesLoadingError {
523 message: format!("{err}").into(),
524 }),
525 ),
526 };
527 context.rules_file = rules_file;
528 (context, rules_file_error)
529 })
530 }
531
532 fn load_worktree_rules_file(
533 worktree: Entity<Worktree>,
534 project: Entity<Project>,
535 cx: &mut App,
536 ) -> Option<Task<Result<RulesFileContext>>> {
537 let worktree = worktree.read(cx);
538 let worktree_id = worktree.id();
539 let selected_rules_file = RULES_FILE_NAMES
540 .into_iter()
541 .filter_map(|name| {
542 worktree
543 .entry_for_path(RelPath::unix(name).unwrap())
544 .filter(|entry| entry.is_file())
545 .map(|entry| entry.path.clone())
546 })
547 .next();
548
549 // Note that Cline supports `.clinerules` being a directory, but that is not currently
550 // supported. This doesn't seem to occur often in GitHub repositories.
551 selected_rules_file.map(|path_in_worktree| {
552 let project_path = ProjectPath {
553 worktree_id,
554 path: path_in_worktree.clone(),
555 };
556 let buffer_task =
557 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
558 let rope_task = cx.spawn(async move |cx| {
559 let buffer = buffer_task.await?;
560 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
561 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
562 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
563 })?;
564 anyhow::Ok((project_entry_id, rope))
565 });
566 // Build a string from the rope on a background thread.
567 cx.background_spawn(async move {
568 let (project_entry_id, rope) = rope_task.await?;
569 anyhow::Ok(RulesFileContext {
570 path_in_worktree,
571 text: rope.to_string().trim().to_string(),
572 project_entry_id: project_entry_id.to_usize(),
573 })
574 })
575 })
576 }
577
578 fn handle_thread_title_updated(
579 &mut self,
580 thread: Entity<Thread>,
581 _: &TitleUpdated,
582 cx: &mut Context<Self>,
583 ) {
584 let session_id = thread.read(cx).id();
585 let Some(session) = self.sessions.get(session_id) else {
586 return;
587 };
588 let thread = thread.downgrade();
589 let acp_thread = session.acp_thread.downgrade();
590 cx.spawn(async move |_, cx| {
591 let title = thread.read_with(cx, |thread, _| thread.title())?;
592 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
593 task.await
594 })
595 .detach_and_log_err(cx);
596 }
597
598 fn handle_thread_token_usage_updated(
599 &mut self,
600 thread: Entity<Thread>,
601 usage: &TokenUsageUpdated,
602 cx: &mut Context<Self>,
603 ) {
604 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
605 return;
606 };
607 session.acp_thread.update(cx, |acp_thread, cx| {
608 acp_thread.update_token_usage(usage.0.clone(), cx);
609 });
610 }
611
612 fn handle_project_event(
613 &mut self,
614 _project: Entity<Project>,
615 event: &project::Event,
616 _cx: &mut Context<Self>,
617 ) {
618 match event {
619 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
620 self.project_context_needs_refresh.send(()).ok();
621 }
622 project::Event::WorktreeUpdatedEntries(_, items) => {
623 if items.iter().any(|(path, _, _)| {
624 RULES_FILE_NAMES
625 .iter()
626 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
627 }) {
628 self.project_context_needs_refresh.send(()).ok();
629 }
630 }
631 _ => {}
632 }
633 }
634
635 fn handle_prompts_updated_event(
636 &mut self,
637 _prompt_store: Entity<PromptStore>,
638 _event: &prompt_store::PromptsUpdatedEvent,
639 _cx: &mut Context<Self>,
640 ) {
641 self.project_context_needs_refresh.send(()).ok();
642 }
643
644 fn handle_models_updated_event(
645 &mut self,
646 _registry: Entity<LanguageModelRegistry>,
647 _event: &language_model::Event,
648 cx: &mut Context<Self>,
649 ) {
650 self.models.refresh_list(cx);
651
652 let registry = LanguageModelRegistry::read_global(cx);
653 let default_model = registry.default_model().map(|m| m.model);
654 let summarization_model = registry.thread_summary_model().map(|m| m.model);
655
656 for session in self.sessions.values_mut() {
657 session.thread.update(cx, |thread, cx| {
658 if thread.model().is_none()
659 && let Some(model) = default_model.clone()
660 {
661 thread.set_model(model, cx);
662 cx.notify();
663 }
664 thread.set_summarization_model(summarization_model.clone(), cx);
665 });
666 }
667 }
668
669 fn handle_context_server_store_updated(
670 &mut self,
671 _store: Entity<project::context_server_store::ContextServerStore>,
672 _event: &project::context_server_store::ServerStatusChangedEvent,
673 cx: &mut Context<Self>,
674 ) {
675 self.update_available_commands(cx);
676 }
677
678 fn handle_context_server_registry_event(
679 &mut self,
680 _registry: Entity<ContextServerRegistry>,
681 event: &ContextServerRegistryEvent,
682 cx: &mut Context<Self>,
683 ) {
684 match event {
685 ContextServerRegistryEvent::ToolsChanged => {}
686 ContextServerRegistryEvent::PromptsChanged => {
687 self.update_available_commands(cx);
688 }
689 }
690 }
691
692 fn update_available_commands(&self, cx: &mut Context<Self>) {
693 let available_commands = self.build_available_commands(cx);
694 for session in self.sessions.values() {
695 session.acp_thread.update(cx, |thread, cx| {
696 thread
697 .handle_session_update(
698 acp::SessionUpdate::AvailableCommandsUpdate(
699 acp::AvailableCommandsUpdate::new(available_commands.clone()),
700 ),
701 cx,
702 )
703 .log_err();
704 });
705 }
706 }
707
708 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
709 let registry = self.context_server_registry.read(cx);
710
711 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
712 for context_server_prompt in registry.prompts() {
713 *prompt_name_counts
714 .entry(context_server_prompt.prompt.name.as_str())
715 .or_insert(0) += 1;
716 }
717
718 registry
719 .prompts()
720 .flat_map(|context_server_prompt| {
721 let prompt = &context_server_prompt.prompt;
722
723 let should_prefix = prompt_name_counts
724 .get(prompt.name.as_str())
725 .copied()
726 .unwrap_or(0)
727 > 1;
728
729 let name = if should_prefix {
730 format!("{}.{}", context_server_prompt.server_id, prompt.name)
731 } else {
732 prompt.name.clone()
733 };
734
735 let mut command = acp::AvailableCommand::new(
736 name,
737 prompt.description.clone().unwrap_or_default(),
738 );
739
740 match prompt.arguments.as_deref() {
741 Some([arg]) => {
742 let hint = format!("<{}>", arg.name);
743
744 command = command.input(acp::AvailableCommandInput::Unstructured(
745 acp::UnstructuredCommandInput::new(hint),
746 ));
747 }
748 Some([]) | None => {}
749 Some(_) => {
750 // skip >1 argument commands since we don't support them yet
751 return None;
752 }
753 }
754
755 Some(command)
756 })
757 .collect()
758 }
759
760 pub fn load_thread(
761 &mut self,
762 id: acp::SessionId,
763 cx: &mut Context<Self>,
764 ) -> Task<Result<Entity<Thread>>> {
765 let database_future = ThreadsDatabase::connect(cx);
766 cx.spawn(async move |this, cx| {
767 let database = database_future.await.map_err(|err| anyhow!(err))?;
768 let db_thread = database
769 .load_thread(id.clone())
770 .await?
771 .with_context(|| format!("no thread found with ID: {id:?}"))?;
772
773 this.update(cx, |this, cx| {
774 let summarization_model = LanguageModelRegistry::read_global(cx)
775 .thread_summary_model()
776 .map(|c| c.model);
777
778 cx.new(|cx| {
779 let mut thread = Thread::from_db(
780 id.clone(),
781 db_thread,
782 this.project.clone(),
783 this.project_context.clone(),
784 this.context_server_registry.clone(),
785 this.templates.clone(),
786 cx,
787 );
788 thread.set_summarization_model(summarization_model, cx);
789 thread
790 })
791 })
792 })
793 }
794
795 pub fn open_thread(
796 &mut self,
797 id: acp::SessionId,
798 cx: &mut Context<Self>,
799 ) -> Task<Result<Entity<AcpThread>>> {
800 if let Some(session) = self.sessions.get(&id) {
801 return Task::ready(Ok(session.acp_thread.clone()));
802 }
803
804 let task = self.load_thread(id, cx);
805 cx.spawn(async move |this, cx| {
806 let thread = task.await?;
807 let acp_thread = this.update(cx, |this, cx| {
808 this.register_session(thread.clone(), None, cx)
809 })?;
810 let events = thread.update(cx, |thread, cx| thread.replay(cx));
811 cx.update(|cx| {
812 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
813 })
814 .await?;
815 Ok(acp_thread)
816 })
817 }
818
819 pub fn thread_summary(
820 &mut self,
821 id: acp::SessionId,
822 cx: &mut Context<Self>,
823 ) -> Task<Result<SharedString>> {
824 let thread = self.open_thread(id.clone(), cx);
825 cx.spawn(async move |this, cx| {
826 let acp_thread = thread.await?;
827 let result = this
828 .update(cx, |this, cx| {
829 this.sessions
830 .get(&id)
831 .unwrap()
832 .thread
833 .update(cx, |thread, cx| thread.summary(cx))
834 })?
835 .await
836 .context("Failed to generate summary")?;
837 drop(acp_thread);
838 Ok(result)
839 })
840 }
841
842 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
843 if thread.read(cx).is_empty() {
844 return;
845 }
846
847 let database_future = ThreadsDatabase::connect(cx);
848 let (id, db_thread) =
849 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
850 let Some(session) = self.sessions.get_mut(&id) else {
851 return;
852 };
853 let thread_store = self.thread_store.clone();
854 session.pending_save = cx.spawn(async move |_, cx| {
855 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
856 return;
857 };
858 let db_thread = db_thread.await;
859 database.save_thread(id, db_thread).await.log_err();
860 thread_store.update(cx, |store, cx| store.reload(cx));
861 });
862 }
863
864 fn send_mcp_prompt(
865 &self,
866 message_id: UserMessageId,
867 session_id: agent_client_protocol::SessionId,
868 prompt_name: String,
869 server_id: ContextServerId,
870 arguments: HashMap<String, String>,
871 original_content: Vec<acp::ContentBlock>,
872 cx: &mut Context<Self>,
873 ) -> Task<Result<acp::PromptResponse>> {
874 let server_store = self.context_server_registry.read(cx).server_store().clone();
875 let path_style = self.project.read(cx).path_style(cx);
876
877 cx.spawn(async move |this, cx| {
878 let prompt =
879 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
880
881 let (acp_thread, thread) = this.update(cx, |this, _cx| {
882 let session = this
883 .sessions
884 .get(&session_id)
885 .context("Failed to get session")?;
886 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
887 })??;
888
889 let mut last_is_user = true;
890
891 thread.update(cx, |thread, cx| {
892 thread.push_acp_user_block(
893 message_id,
894 original_content.into_iter().skip(1),
895 path_style,
896 cx,
897 );
898 });
899
900 for message in prompt.messages {
901 let context_server::types::PromptMessage { role, content } = message;
902 let block = mcp_message_content_to_acp_content_block(content);
903
904 match role {
905 context_server::types::Role::User => {
906 let id = acp_thread::UserMessageId::new();
907
908 acp_thread.update(cx, |acp_thread, cx| {
909 acp_thread.push_user_content_block_with_indent(
910 Some(id.clone()),
911 block.clone(),
912 true,
913 cx,
914 );
915 });
916
917 thread.update(cx, |thread, cx| {
918 thread.push_acp_user_block(id, [block], path_style, cx);
919 });
920 }
921 context_server::types::Role::Assistant => {
922 acp_thread.update(cx, |acp_thread, cx| {
923 acp_thread.push_assistant_content_block_with_indent(
924 block.clone(),
925 false,
926 true,
927 cx,
928 );
929 });
930
931 thread.update(cx, |thread, cx| {
932 thread.push_acp_agent_block(block, cx);
933 });
934 }
935 }
936
937 last_is_user = role == context_server::types::Role::User;
938 }
939
940 let response_stream = thread.update(cx, |thread, cx| {
941 if last_is_user {
942 thread.send_existing(cx)
943 } else {
944 // Resume if MCP prompt did not end with a user message
945 thread.resume(cx)
946 }
947 })?;
948
949 cx.update(|cx| {
950 NativeAgentConnection::handle_thread_events(
951 response_stream,
952 acp_thread.downgrade(),
953 cx,
954 )
955 })
956 .await
957 })
958 }
959}
960
/// Wrapper struct that implements the AgentConnection trait
///
/// Wraps a handle to the `NativeAgent` entity; cloning the wrapper clones the
/// handle, so all clones refer to the same agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
964
965impl NativeAgentConnection {
966 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
967 self.0
968 .read(cx)
969 .sessions
970 .get(session_id)
971 .map(|session| session.thread.clone())
972 }
973
974 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
975 self.0.update(cx, |this, cx| this.load_thread(id, cx))
976 }
977
978 fn run_turn(
979 &self,
980 session_id: acp::SessionId,
981 cx: &mut App,
982 f: impl 'static
983 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
984 ) -> Task<Result<acp::PromptResponse>> {
985 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
986 agent
987 .sessions
988 .get_mut(&session_id)
989 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
990 }) else {
991 return Task::ready(Err(anyhow!("Session not found")));
992 };
993 log::debug!("Found session for: {}", session_id);
994
995 let response_stream = match f(thread, cx) {
996 Ok(stream) => stream,
997 Err(err) => return Task::ready(Err(err)),
998 };
999 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1000 }
1001
1002 fn handle_thread_events(
1003 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1004 acp_thread: WeakEntity<AcpThread>,
1005 cx: &App,
1006 ) -> Task<Result<acp::PromptResponse>> {
1007 cx.spawn(async move |cx| {
1008 // Handle response stream and forward to session.acp_thread
1009 while let Some(result) = events.next().await {
1010 match result {
1011 Ok(event) => {
1012 log::trace!("Received completion event: {:?}", event);
1013
1014 match event {
1015 ThreadEvent::UserMessage(message) => {
1016 acp_thread.update(cx, |thread, cx| {
1017 for content in message.content {
1018 thread.push_user_content_block(
1019 Some(message.id.clone()),
1020 content.into(),
1021 cx,
1022 );
1023 }
1024 })?;
1025 }
1026 ThreadEvent::AgentText(text) => {
1027 acp_thread.update(cx, |thread, cx| {
1028 thread.push_assistant_content_block(text.into(), false, cx)
1029 })?;
1030 }
1031 ThreadEvent::AgentThinking(text) => {
1032 acp_thread.update(cx, |thread, cx| {
1033 thread.push_assistant_content_block(text.into(), true, cx)
1034 })?;
1035 }
1036 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1037 tool_call,
1038 options,
1039 response,
1040 context: _,
1041 }) => {
1042 let outcome_task = acp_thread.update(cx, |thread, cx| {
1043 thread.request_tool_call_authorization(tool_call, options, cx)
1044 })??;
1045 cx.background_spawn(async move {
1046 if let acp::RequestPermissionOutcome::Selected(
1047 acp::SelectedPermissionOutcome { option_id, .. },
1048 ) = outcome_task.await
1049 {
1050 response
1051 .send(option_id)
1052 .map(|_| anyhow!("authorization receiver was dropped"))
1053 .log_err();
1054 }
1055 })
1056 .detach();
1057 }
1058 ThreadEvent::ToolCall(tool_call) => {
1059 acp_thread.update(cx, |thread, cx| {
1060 thread.upsert_tool_call(tool_call, cx)
1061 })??;
1062 }
1063 ThreadEvent::ToolCallUpdate(update) => {
1064 acp_thread.update(cx, |thread, cx| {
1065 thread.update_tool_call(update, cx)
1066 })??;
1067 }
1068 ThreadEvent::SubagentSpawned(session_id) => {
1069 acp_thread.update(cx, |thread, cx| {
1070 thread.subagent_spawned(session_id, cx);
1071 })?;
1072 }
1073 ThreadEvent::Retry(status) => {
1074 acp_thread.update(cx, |thread, cx| {
1075 thread.update_retry_status(status, cx)
1076 })?;
1077 }
1078 ThreadEvent::Stop(stop_reason) => {
1079 log::debug!("Assistant message complete: {:?}", stop_reason);
1080 return Ok(acp::PromptResponse::new(stop_reason));
1081 }
1082 }
1083 }
1084 Err(e) => {
1085 log::error!("Error in model response stream: {:?}", e);
1086 return Err(e);
1087 }
1088 }
1089 }
1090
1091 log::debug!("Response stream completed");
1092 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1093 })
1094 }
1095}
1096
1097struct Command<'a> {
1098 prompt_name: &'a str,
1099 arg_value: &'a str,
1100 explicit_server_id: Option<&'a str>,
1101}
1102
1103impl<'a> Command<'a> {
1104 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1105 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1106 return None;
1107 };
1108 let text = text_content.text.trim();
1109 let command = text.strip_prefix('/')?;
1110 let (command, arg_value) = command
1111 .split_once(char::is_whitespace)
1112 .unwrap_or((command, ""));
1113
1114 if let Some((server_id, prompt_name)) = command.split_once('.') {
1115 Some(Self {
1116 prompt_name,
1117 arg_value,
1118 explicit_server_id: Some(server_id),
1119 })
1120 } else {
1121 Some(Self {
1122 prompt_name: command,
1123 arg_value,
1124 explicit_server_id: None,
1125 })
1126 }
1127 }
1128}
1129
1130struct NativeAgentModelSelector {
1131 session_id: acp::SessionId,
1132 connection: NativeAgentConnection,
1133}
1134
1135impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1136 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1137 log::debug!("NativeAgentConnection::list_models called");
1138 let list = self.connection.0.read(cx).models.model_list.clone();
1139 Task::ready(if list.is_empty() {
1140 Err(anyhow::anyhow!("No models available"))
1141 } else {
1142 Ok(list)
1143 })
1144 }
1145
1146 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1147 log::debug!(
1148 "Setting model for session {}: {}",
1149 self.session_id,
1150 model_id
1151 );
1152 let Some(thread) = self
1153 .connection
1154 .0
1155 .read(cx)
1156 .sessions
1157 .get(&self.session_id)
1158 .map(|session| session.thread.clone())
1159 else {
1160 return Task::ready(Err(anyhow!("Session not found")));
1161 };
1162
1163 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1164 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1165 };
1166
1167 // We want to reset the effort level when switching models, as the currently-selected effort level may
1168 // not be compatible.
1169 let effort = model
1170 .default_effort_level()
1171 .map(|effort_level| effort_level.value.to_string());
1172
1173 thread.update(cx, |thread, cx| {
1174 thread.set_model(model.clone(), cx);
1175 thread.set_thinking_effort(effort.clone(), cx);
1176 thread.set_thinking_enabled(model.supports_thinking(), cx);
1177 });
1178
1179 update_settings_file(
1180 self.connection.0.read(cx).fs.clone(),
1181 cx,
1182 move |settings, cx| {
1183 let provider = model.provider_id().0.to_string();
1184 let model = model.id().0.to_string();
1185 let enable_thinking = thread.read(cx).thinking_enabled();
1186 settings
1187 .agent
1188 .get_or_insert_default()
1189 .set_model(LanguageModelSelection {
1190 provider: provider.into(),
1191 model,
1192 enable_thinking,
1193 effort,
1194 });
1195 },
1196 );
1197
1198 Task::ready(Ok(()))
1199 }
1200
1201 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1202 let Some(thread) = self
1203 .connection
1204 .0
1205 .read(cx)
1206 .sessions
1207 .get(&self.session_id)
1208 .map(|session| session.thread.clone())
1209 else {
1210 return Task::ready(Err(anyhow!("Session not found")));
1211 };
1212 let Some(model) = thread.read(cx).model() else {
1213 return Task::ready(Err(anyhow!("Model not found")));
1214 };
1215 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1216 else {
1217 return Task::ready(Err(anyhow!("Provider not found")));
1218 };
1219 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1220 model, &provider,
1221 )))
1222 }
1223
1224 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1225 Some(self.connection.0.read(cx).models.watch())
1226 }
1227
1228 fn should_render_footer(&self) -> bool {
1229 true
1230 }
1231}
1232
1233impl acp_thread::AgentConnection for NativeAgentConnection {
1234 fn telemetry_id(&self) -> SharedString {
1235 "zed".into()
1236 }
1237
1238 fn new_session(
1239 self: Rc<Self>,
1240 project: Entity<Project>,
1241 cwd: &Path,
1242 cx: &mut App,
1243 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1244 log::debug!("Creating new thread for project at: {cwd:?}");
1245 Task::ready(Ok(self
1246 .0
1247 .update(cx, |agent, cx| agent.new_session(project, cx))))
1248 }
1249
1250 fn supports_load_session(&self) -> bool {
1251 true
1252 }
1253
1254 fn load_session(
1255 self: Rc<Self>,
1256 session: AgentSessionInfo,
1257 _project: Entity<Project>,
1258 _cwd: &Path,
1259 cx: &mut App,
1260 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1261 self.0
1262 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1263 }
1264
1265 fn supports_close_session(&self) -> bool {
1266 true
1267 }
1268
1269 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1270 self.0.update(cx, |agent, _cx| {
1271 agent.sessions.remove(session_id);
1272 });
1273 Task::ready(Ok(()))
1274 }
1275
1276 fn auth_methods(&self) -> &[acp::AuthMethod] {
1277 &[] // No auth for in-process
1278 }
1279
1280 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1281 Task::ready(Ok(()))
1282 }
1283
1284 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1285 Some(Rc::new(NativeAgentModelSelector {
1286 session_id: session_id.clone(),
1287 connection: self.clone(),
1288 }) as Rc<dyn AgentModelSelector>)
1289 }
1290
1291 fn prompt(
1292 &self,
1293 id: Option<acp_thread::UserMessageId>,
1294 params: acp::PromptRequest,
1295 cx: &mut App,
1296 ) -> Task<Result<acp::PromptResponse>> {
1297 let id = id.expect("UserMessageId is required");
1298 let session_id = params.session_id.clone();
1299 log::info!("Received prompt request for session: {}", session_id);
1300 log::debug!("Prompt blocks count: {}", params.prompt.len());
1301
1302 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1303 let registry = self.0.read(cx).context_server_registry.read(cx);
1304
1305 let explicit_server_id = parsed_command
1306 .explicit_server_id
1307 .map(|server_id| ContextServerId(server_id.into()));
1308
1309 if let Some(prompt) =
1310 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1311 {
1312 let arguments = if !parsed_command.arg_value.is_empty()
1313 && let Some(arg_name) = prompt
1314 .prompt
1315 .arguments
1316 .as_ref()
1317 .and_then(|args| args.first())
1318 .map(|arg| arg.name.clone())
1319 {
1320 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1321 } else {
1322 Default::default()
1323 };
1324
1325 let prompt_name = prompt.prompt.name.clone();
1326 let server_id = prompt.server_id.clone();
1327
1328 return self.0.update(cx, |agent, cx| {
1329 agent.send_mcp_prompt(
1330 id,
1331 session_id.clone(),
1332 prompt_name,
1333 server_id,
1334 arguments,
1335 params.prompt,
1336 cx,
1337 )
1338 });
1339 };
1340 };
1341
1342 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1343
1344 self.run_turn(session_id, cx, move |thread, cx| {
1345 let content: Vec<UserMessageContent> = params
1346 .prompt
1347 .into_iter()
1348 .map(|block| UserMessageContent::from_content_block(block, path_style))
1349 .collect::<Vec<_>>();
1350 log::debug!("Converted prompt to message: {} chars", content.len());
1351 log::debug!("Message id: {:?}", id);
1352 log::debug!("Message content: {:?}", content);
1353
1354 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1355 })
1356 }
1357
1358 fn retry(
1359 &self,
1360 session_id: &acp::SessionId,
1361 _cx: &App,
1362 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1363 Some(Rc::new(NativeAgentSessionRetry {
1364 connection: self.clone(),
1365 session_id: session_id.clone(),
1366 }) as _)
1367 }
1368
1369 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1370 log::info!("Cancelling on session: {}", session_id);
1371 self.0.update(cx, |agent, cx| {
1372 if let Some(session) = agent.sessions.get(session_id) {
1373 session
1374 .thread
1375 .update(cx, |thread, cx| thread.cancel(cx))
1376 .detach();
1377 }
1378 });
1379 }
1380
1381 fn truncate(
1382 &self,
1383 session_id: &agent_client_protocol::SessionId,
1384 cx: &App,
1385 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1386 self.0.read_with(cx, |agent, _cx| {
1387 agent.sessions.get(session_id).map(|session| {
1388 Rc::new(NativeAgentSessionTruncate {
1389 thread: session.thread.clone(),
1390 acp_thread: session.acp_thread.downgrade(),
1391 }) as _
1392 })
1393 })
1394 }
1395
1396 fn set_title(
1397 &self,
1398 session_id: &acp::SessionId,
1399 cx: &App,
1400 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1401 self.0.read_with(cx, |agent, _cx| {
1402 agent
1403 .sessions
1404 .get(session_id)
1405 .filter(|s| !s.thread.read(cx).is_subagent())
1406 .map(|session| {
1407 Rc::new(NativeAgentSessionSetTitle {
1408 thread: session.thread.clone(),
1409 }) as _
1410 })
1411 })
1412 }
1413
1414 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1415 let thread_store = self.0.read(cx).thread_store.clone();
1416 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1417 }
1418
1419 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1420 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1421 }
1422
1423 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1424 self
1425 }
1426}
1427
1428impl acp_thread::AgentTelemetry for NativeAgentConnection {
1429 fn thread_data(
1430 &self,
1431 session_id: &acp::SessionId,
1432 cx: &mut App,
1433 ) -> Task<Result<serde_json::Value>> {
1434 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1435 return Task::ready(Err(anyhow!("Session not found")));
1436 };
1437
1438 let task = session.thread.read(cx).to_db(cx);
1439 cx.background_spawn(async move {
1440 serde_json::to_value(task.await).context("Failed to serialize thread")
1441 })
1442 }
1443}
1444
1445pub struct NativeAgentSessionList {
1446 thread_store: Entity<ThreadStore>,
1447 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1448 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1449 _subscription: Subscription,
1450}
1451
1452impl NativeAgentSessionList {
1453 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1454 let (tx, rx) = smol::channel::unbounded();
1455 let this_tx = tx.clone();
1456 let subscription = cx.observe(&thread_store, move |_, _| {
1457 this_tx
1458 .try_send(acp_thread::SessionListUpdate::Refresh)
1459 .ok();
1460 });
1461 Self {
1462 thread_store,
1463 updates_tx: tx,
1464 updates_rx: rx,
1465 _subscription: subscription,
1466 }
1467 }
1468
1469 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1470 AgentSessionInfo {
1471 session_id: entry.id,
1472 cwd: None,
1473 title: Some(entry.title),
1474 updated_at: Some(entry.updated_at),
1475 meta: None,
1476 }
1477 }
1478
1479 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1480 &self.thread_store
1481 }
1482}
1483
1484impl AgentSessionList for NativeAgentSessionList {
1485 fn list_sessions(
1486 &self,
1487 _request: AgentSessionListRequest,
1488 cx: &mut App,
1489 ) -> Task<Result<AgentSessionListResponse>> {
1490 let sessions = self
1491 .thread_store
1492 .read(cx)
1493 .entries()
1494 .map(Self::to_session_info)
1495 .collect();
1496 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1497 }
1498
1499 fn supports_delete(&self) -> bool {
1500 true
1501 }
1502
1503 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1504 self.thread_store
1505 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1506 }
1507
1508 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1509 self.thread_store
1510 .update(cx, |store, cx| store.delete_threads(cx))
1511 }
1512
1513 fn watch(
1514 &self,
1515 _cx: &mut App,
1516 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1517 Some(self.updates_rx.clone())
1518 }
1519
1520 fn notify_refresh(&self) {
1521 self.updates_tx
1522 .try_send(acp_thread::SessionListUpdate::Refresh)
1523 .ok();
1524 }
1525
1526 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1527 self
1528 }
1529}
1530
1531struct NativeAgentSessionTruncate {
1532 thread: Entity<Thread>,
1533 acp_thread: WeakEntity<AcpThread>,
1534}
1535
1536impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1537 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1538 match self.thread.update(cx, |thread, cx| {
1539 thread.truncate(message_id.clone(), cx)?;
1540 Ok(thread.latest_token_usage())
1541 }) {
1542 Ok(usage) => {
1543 self.acp_thread
1544 .update(cx, |thread, cx| {
1545 thread.update_token_usage(usage, cx);
1546 })
1547 .ok();
1548 Task::ready(Ok(()))
1549 }
1550 Err(error) => Task::ready(Err(error)),
1551 }
1552 }
1553}
1554
1555struct NativeAgentSessionRetry {
1556 connection: NativeAgentConnection,
1557 session_id: acp::SessionId,
1558}
1559
1560impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1561 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1562 self.connection
1563 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1564 thread.update(cx, |thread, cx| thread.resume(cx))
1565 })
1566 }
1567}
1568
1569struct NativeAgentSessionSetTitle {
1570 thread: Entity<Thread>,
1571}
1572
1573impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1574 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1575 self.thread
1576 .update(cx, |thread, cx| thread.set_title(title, cx));
1577 Task::ready(Ok(()))
1578 }
1579}
1580
1581pub struct NativeThreadEnvironment {
1582 agent: WeakEntity<NativeAgent>,
1583 acp_thread: WeakEntity<AcpThread>,
1584}
1585
1586impl NativeThreadEnvironment {
1587 pub(crate) fn create_subagent_thread(
1588 agent: WeakEntity<NativeAgent>,
1589 parent_thread_entity: Entity<Thread>,
1590 label: String,
1591 initial_prompt: String,
1592 timeout: Option<Duration>,
1593 cx: &mut App,
1594 ) -> Result<Rc<dyn SubagentHandle>> {
1595 let parent_thread = parent_thread_entity.read(cx);
1596 let current_depth = parent_thread.depth();
1597
1598 if current_depth >= MAX_SUBAGENT_DEPTH {
1599 return Err(anyhow!(
1600 "Maximum subagent depth ({}) reached",
1601 MAX_SUBAGENT_DEPTH
1602 ));
1603 }
1604 let allowed_tool_names = Some(parent_thread.tools.keys().cloned().collect::<Vec<_>>());
1605
1606 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1607 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1608 thread.set_title(label.into(), cx);
1609 thread
1610 });
1611
1612 let session_id = subagent_thread.read(cx).id().clone();
1613
1614 let acp_thread = agent.update(cx, |agent, cx| {
1615 agent.register_session(subagent_thread.clone(), allowed_tool_names, cx)
1616 })?;
1617
1618 parent_thread_entity.update(cx, |parent_thread, _cx| {
1619 parent_thread.register_running_subagent(subagent_thread.downgrade())
1620 });
1621
1622 let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));
1623
1624 let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1625 let wait_for_prompt_to_complete = cx
1626 .background_spawn(async move {
1627 if let Some(timer) = timeout_timer {
1628 futures::select! {
1629 _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1630 response = task.fuse() => {
1631 let response = response.log_err().flatten();
1632 if response.is_some_and(|response| {
1633 response.stop_reason == acp::StopReason::Cancelled
1634 })
1635 {
1636 SubagentInitialPromptResult::Cancelled
1637 } else {
1638 SubagentInitialPromptResult::Completed
1639 }
1640 },
1641 }
1642 } else {
1643 let response = task.await.log_err().flatten();
1644 if response
1645 .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
1646 {
1647 SubagentInitialPromptResult::Cancelled
1648 } else {
1649 SubagentInitialPromptResult::Completed
1650 }
1651 }
1652 })
1653 .shared();
1654
1655 Ok(Rc::new(NativeSubagentHandle {
1656 session_id,
1657 subagent_thread,
1658 parent_thread: parent_thread_entity.downgrade(),
1659 wait_for_prompt_to_complete,
1660 }) as _)
1661 }
1662}
1663
1664impl ThreadEnvironment for NativeThreadEnvironment {
1665 fn create_terminal(
1666 &self,
1667 command: String,
1668 cwd: Option<PathBuf>,
1669 output_byte_limit: Option<u64>,
1670 cx: &mut AsyncApp,
1671 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1672 let task = self.acp_thread.update(cx, |thread, cx| {
1673 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1674 });
1675
1676 let acp_thread = self.acp_thread.clone();
1677 cx.spawn(async move |cx| {
1678 let terminal = task?.await?;
1679
1680 let (drop_tx, drop_rx) = oneshot::channel();
1681 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1682
1683 cx.spawn(async move |cx| {
1684 drop_rx.await.ok();
1685 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1686 })
1687 .detach();
1688
1689 let handle = AcpTerminalHandle {
1690 terminal,
1691 _drop_tx: Some(drop_tx),
1692 };
1693
1694 Ok(Rc::new(handle) as _)
1695 })
1696 }
1697
1698 fn create_subagent(
1699 &self,
1700 parent_thread_entity: Entity<Thread>,
1701 label: String,
1702 initial_prompt: String,
1703 timeout: Option<Duration>,
1704 cx: &mut App,
1705 ) -> Result<Rc<dyn SubagentHandle>> {
1706 Self::create_subagent_thread(
1707 self.agent.clone(),
1708 parent_thread_entity,
1709 label,
1710 initial_prompt,
1711 timeout,
1712 cx,
1713 )
1714 }
1715}
1716
1717#[derive(Debug, Clone, Copy)]
1718enum SubagentInitialPromptResult {
1719 Completed,
1720 Timeout,
1721 Cancelled,
1722}
1723
1724pub struct NativeSubagentHandle {
1725 session_id: acp::SessionId,
1726 parent_thread: WeakEntity<Thread>,
1727 subagent_thread: Entity<Thread>,
1728 wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
1729}
1730
1731impl SubagentHandle for NativeSubagentHandle {
1732 fn id(&self) -> acp::SessionId {
1733 self.session_id.clone()
1734 }
1735
1736 fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
1737 let thread = self.subagent_thread.clone();
1738 let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1739
1740 let subagent_session_id = self.session_id.clone();
1741 let parent_thread = self.parent_thread.clone();
1742
1743 cx.spawn(async move |cx| {
1744 match wait_for_prompt.await {
1745 SubagentInitialPromptResult::Completed => {}
1746 SubagentInitialPromptResult::Timeout => {
1747 return Err(anyhow!("The time to complete the task was exceeded."));
1748 }
1749 SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")),
1750 };
1751
1752 let result = thread.read_with(cx, |thread, _cx| {
1753 thread
1754 .last_message()
1755 .map(|m| m.to_markdown())
1756 .context("No response from subagent")
1757 });
1758
1759 parent_thread
1760 .update(cx, |parent_thread, cx| {
1761 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1762 })
1763 .ok();
1764
1765 result
1766 })
1767 }
1768}
1769
1770pub struct AcpTerminalHandle {
1771 terminal: Entity<acp_thread::Terminal>,
1772 _drop_tx: Option<oneshot::Sender<()>>,
1773}
1774
1775impl TerminalHandle for AcpTerminalHandle {
1776 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1777 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1778 }
1779
1780 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1781 Ok(self
1782 .terminal
1783 .read_with(cx, |term, _cx| term.wait_for_exit()))
1784 }
1785
1786 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1787 Ok(self
1788 .terminal
1789 .read_with(cx, |term, cx| term.current_output(cx)))
1790 }
1791
1792 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1793 cx.update(|cx| {
1794 self.terminal.update(cx, |terminal, cx| {
1795 terminal.kill(cx);
1796 });
1797 });
1798 Ok(())
1799 }
1800
1801 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1802 Ok(self
1803 .terminal
1804 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1805 }
1806}
1807
1808#[cfg(test)]
1809mod internal_tests {
1810 use super::*;
1811 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1812 use fs::FakeFs;
1813 use gpui::TestAppContext;
1814 use indoc::formatdoc;
1815 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1816 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1817 use serde_json::json;
1818 use settings::SettingsStore;
1819 use util::{path, rel_path::rel_path};
1820
1821 #[gpui::test]
1822 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1823 init_test(cx);
1824 let fs = FakeFs::new(cx.executor());
1825 fs.insert_tree(
1826 "/",
1827 json!({
1828 "a": {}
1829 }),
1830 )
1831 .await;
1832 let project = Project::test(fs.clone(), [], cx).await;
1833 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1834 let agent = NativeAgent::new(
1835 project.clone(),
1836 thread_store,
1837 Templates::new(),
1838 None,
1839 fs.clone(),
1840 &mut cx.to_async(),
1841 )
1842 .await
1843 .unwrap();
1844 agent.read_with(cx, |agent, cx| {
1845 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1846 });
1847
1848 let worktree = project
1849 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1850 .await
1851 .unwrap();
1852 cx.run_until_parked();
1853 agent.read_with(cx, |agent, cx| {
1854 assert_eq!(
1855 agent.project_context.read(cx).worktrees,
1856 vec![WorktreeContext {
1857 root_name: "a".into(),
1858 abs_path: Path::new("/a").into(),
1859 rules_file: None
1860 }]
1861 )
1862 });
1863
1864 // Creating `/a/.rules` updates the project context.
1865 fs.insert_file("/a/.rules", Vec::new()).await;
1866 cx.run_until_parked();
1867 agent.read_with(cx, |agent, cx| {
1868 let rules_entry = worktree
1869 .read(cx)
1870 .entry_for_path(rel_path(".rules"))
1871 .unwrap();
1872 assert_eq!(
1873 agent.project_context.read(cx).worktrees,
1874 vec![WorktreeContext {
1875 root_name: "a".into(),
1876 abs_path: Path::new("/a").into(),
1877 rules_file: Some(RulesFileContext {
1878 path_in_worktree: rel_path(".rules").into(),
1879 text: "".into(),
1880 project_entry_id: rules_entry.id.to_usize()
1881 })
1882 }]
1883 )
1884 });
1885 }
1886
1887 #[gpui::test]
1888 async fn test_listing_models(cx: &mut TestAppContext) {
1889 init_test(cx);
1890 let fs = FakeFs::new(cx.executor());
1891 fs.insert_tree("/", json!({ "a": {} })).await;
1892 let project = Project::test(fs.clone(), [], cx).await;
1893 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1894 let connection = NativeAgentConnection(
1895 NativeAgent::new(
1896 project.clone(),
1897 thread_store,
1898 Templates::new(),
1899 None,
1900 fs.clone(),
1901 &mut cx.to_async(),
1902 )
1903 .await
1904 .unwrap(),
1905 );
1906
1907 // Create a thread/session
1908 let acp_thread = cx
1909 .update(|cx| {
1910 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1911 })
1912 .await
1913 .unwrap();
1914
1915 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1916
1917 let models = cx
1918 .update(|cx| {
1919 connection
1920 .model_selector(&session_id)
1921 .unwrap()
1922 .list_models(cx)
1923 })
1924 .await
1925 .unwrap();
1926
1927 let acp_thread::AgentModelList::Grouped(models) = models else {
1928 panic!("Unexpected model group");
1929 };
1930 assert_eq!(
1931 models,
1932 IndexMap::from_iter([(
1933 AgentModelGroupName("Fake".into()),
1934 vec![AgentModelInfo {
1935 id: acp::ModelId::new("fake/fake"),
1936 name: "Fake".into(),
1937 description: None,
1938 icon: Some(acp_thread::AgentModelIcon::Named(
1939 ui::IconName::ZedAssistant
1940 )),
1941 is_latest: false,
1942 cost: None,
1943 }]
1944 )])
1945 );
1946 }
1947
1948 #[gpui::test]
1949 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1950 init_test(cx);
1951 let fs = FakeFs::new(cx.executor());
1952 fs.create_dir(paths::settings_file().parent().unwrap())
1953 .await
1954 .unwrap();
1955 fs.insert_file(
1956 paths::settings_file(),
1957 json!({
1958 "agent": {
1959 "default_model": {
1960 "provider": "foo",
1961 "model": "bar"
1962 }
1963 }
1964 })
1965 .to_string()
1966 .into_bytes(),
1967 )
1968 .await;
1969 let project = Project::test(fs.clone(), [], cx).await;
1970
1971 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1972
1973 // Create the agent and connection
1974 let agent = NativeAgent::new(
1975 project.clone(),
1976 thread_store,
1977 Templates::new(),
1978 None,
1979 fs.clone(),
1980 &mut cx.to_async(),
1981 )
1982 .await
1983 .unwrap();
1984 let connection = NativeAgentConnection(agent.clone());
1985
1986 // Create a thread/session
1987 let acp_thread = cx
1988 .update(|cx| {
1989 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1990 })
1991 .await
1992 .unwrap();
1993
1994 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1995
1996 // Select a model
1997 let selector = connection.model_selector(&session_id).unwrap();
1998 let model_id = acp::ModelId::new("fake/fake");
1999 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2000 .await
2001 .unwrap();
2002
2003 // Verify the thread has the selected model
2004 agent.read_with(cx, |agent, _| {
2005 let session = agent.sessions.get(&session_id).unwrap();
2006 session.thread.read_with(cx, |thread, _| {
2007 assert_eq!(thread.model().unwrap().id().0, "fake");
2008 });
2009 });
2010
2011 cx.run_until_parked();
2012
2013 // Verify settings file was updated
2014 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2015 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2016
2017 // Check that the agent settings contain the selected model
2018 assert_eq!(
2019 settings_json["agent"]["default_model"]["model"],
2020 json!("fake")
2021 );
2022 assert_eq!(
2023 settings_json["agent"]["default_model"]["provider"],
2024 json!("fake")
2025 );
2026
2027 // Register a thinking model and select it.
2028 cx.update(|cx| {
2029 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2030 "fake-corp",
2031 "fake-thinking",
2032 "Fake Thinking",
2033 true,
2034 ));
2035 let thinking_provider = Arc::new(
2036 FakeLanguageModelProvider::new(
2037 LanguageModelProviderId::from("fake-corp".to_string()),
2038 LanguageModelProviderName::from("Fake Corp".to_string()),
2039 )
2040 .with_models(vec![thinking_model]),
2041 );
2042 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2043 registry.register_provider(thinking_provider, cx);
2044 });
2045 });
2046 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2047
2048 let selector = connection.model_selector(&session_id).unwrap();
2049 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2050 .await
2051 .unwrap();
2052 cx.run_until_parked();
2053
2054 // Verify enable_thinking was written to settings as true.
2055 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2056 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2057 assert_eq!(
2058 settings_json["agent"]["default_model"]["enable_thinking"],
2059 json!(true),
2060 "selecting a thinking model should persist enable_thinking: true to settings"
2061 );
2062 }
2063
2064 #[gpui::test]
2065 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2066 init_test(cx);
2067 let fs = FakeFs::new(cx.executor());
2068 fs.create_dir(paths::settings_file().parent().unwrap())
2069 .await
2070 .unwrap();
2071 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2072 let project = Project::test(fs.clone(), [], cx).await;
2073
2074 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2075 let agent = NativeAgent::new(
2076 project.clone(),
2077 thread_store,
2078 Templates::new(),
2079 None,
2080 fs.clone(),
2081 &mut cx.to_async(),
2082 )
2083 .await
2084 .unwrap();
2085 let connection = NativeAgentConnection(agent.clone());
2086
2087 let acp_thread = cx
2088 .update(|cx| {
2089 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2090 })
2091 .await
2092 .unwrap();
2093 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2094
2095 // Register a second provider with a thinking model.
2096 cx.update(|cx| {
2097 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2098 "fake-corp",
2099 "fake-thinking",
2100 "Fake Thinking",
2101 true,
2102 ));
2103 let thinking_provider = Arc::new(
2104 FakeLanguageModelProvider::new(
2105 LanguageModelProviderId::from("fake-corp".to_string()),
2106 LanguageModelProviderName::from("Fake Corp".to_string()),
2107 )
2108 .with_models(vec![thinking_model]),
2109 );
2110 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2111 registry.register_provider(thinking_provider, cx);
2112 });
2113 });
2114 // Refresh the agent's model list so it picks up the new provider.
2115 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2116
2117 // Thread starts with thinking_enabled = false (the default).
2118 agent.read_with(cx, |agent, _| {
2119 let session = agent.sessions.get(&session_id).unwrap();
2120 session.thread.read_with(cx, |thread, _| {
2121 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2122 });
2123 });
2124
2125 // Select the thinking model via select_model.
2126 let selector = connection.model_selector(&session_id).unwrap();
2127 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2128 .await
2129 .unwrap();
2130
2131 // select_model should have enabled thinking based on the model's supports_thinking().
2132 agent.read_with(cx, |agent, _| {
2133 let session = agent.sessions.get(&session_id).unwrap();
2134 session.thread.read_with(cx, |thread, _| {
2135 assert!(
2136 thread.thinking_enabled(),
2137 "select_model should enable thinking when model supports it"
2138 );
2139 });
2140 });
2141
2142 // Switch back to the non-thinking model.
2143 let selector = connection.model_selector(&session_id).unwrap();
2144 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2145 .await
2146 .unwrap();
2147
2148 // select_model should have disabled thinking.
2149 agent.read_with(cx, |agent, _| {
2150 let session = agent.sessions.get(&session_id).unwrap();
2151 session.thread.read_with(cx, |thread, _| {
2152 assert!(
2153 !thread.thinking_enabled(),
2154 "select_model should disable thinking when model does not support it"
2155 );
2156 });
2157 });
2158 }
2159
    /// Round-trip test: a session whose thread had thinking enabled (by selecting a model that
    /// was registered with thinking support) must still report `thinking_enabled()` after the
    /// session is closed and the thread is reopened through `open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model (the trailing `true` passed to `with_id_and_thinking`)
        // under a separate "fake-corp" provider.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list after registering the provider; presumably this is
        // what lets the selector below resolve "fake-corp/fake-thinking" — confirm.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selector model ids are "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it can stay pending while the test feeds the fake
        // model's completion stream below; parking first presumably lets the completion
        // request reach the fake model.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread entities, then confirm that closing
        // removed the session from the agent.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // Hold the reloaded AcpThread until every assertion above has run.
        drop(reloaded_acp_thread);
    }
2265
    /// Round-trip test: the model selected for a session must be restored when the session is
    /// closed and the thread is reopened through `open_thread`, instead of falling back to the
    /// default model. The registered model's id intentionally differs from its display name.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // Presumably this catches a restore path that keys on the display name instead of the
        // id. The trailing `false` registers it without thinking support.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list after registering the provider; presumably needed so
        // the selector can resolve the new model id — confirm.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selector model ids are "<provider id>/<model id>", not the display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the selection took effect before anything is persisted.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it can stay pending while the test drives the fake
        // model's completion stream below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread entities, then confirm that closing
        // removed the session from the agent.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Hold the reloaded AcpThread until every assertion above has run.
        drop(reloaded_acp_thread);
    }
2376
    /// Persistence round trip for a whole conversation.
    ///
    /// Checks three things: a thread with no messages is not written to the thread store even
    /// after its models are changed; after one user/assistant exchange the store holds a single
    /// entry whose title is the summarization model's output; and reopening the thread with
    /// `open_thread` reproduces the same markdown, including the file mention link.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are used for the conversation and for summarization so each
        // completion stream can be driven independently below.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message embeds a file mention as a resource link so the markdown round trip
        // below also covers mention URIs.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant reply first, then answer the summarization request; the
        // summary text is what later appears as the stored thread title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread entities, then confirm that no
        // sessions remain open.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread must render the same markdown as before it was closed.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2515
2516 fn thread_entries(
2517 thread_store: &Entity<ThreadStore>,
2518 cx: &mut TestAppContext,
2519 ) -> Vec<(acp::SessionId, String)> {
2520 thread_store.read_with(cx, |store, _| {
2521 store
2522 .entries()
2523 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2524 .collect::<Vec<_>>()
2525 })
2526 }
2527
2528 fn init_test(cx: &mut TestAppContext) {
2529 env_logger::try_init().ok();
2530 cx.update(|cx| {
2531 let settings_store = SettingsStore::test(cx);
2532 cx.set_global(settings_store);
2533
2534 LanguageModelRegistry::test(cx);
2535 });
2536 }
2537}
2538
2539fn mcp_message_content_to_acp_content_block(
2540 content: context_server::types::MessageContent,
2541) -> acp::ContentBlock {
2542 match content {
2543 context_server::types::MessageContent::Text {
2544 text,
2545 annotations: _,
2546 } => text.into(),
2547 context_server::types::MessageContent::Image {
2548 data,
2549 mime_type,
2550 annotations: _,
2551 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2552 context_server::types::MessageContent::Audio {
2553 data,
2554 mime_type,
2555 annotations: _,
2556 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2557 context_server::types::MessageContent::Resource {
2558 resource,
2559 annotations: _,
2560 } => {
2561 let mut link =
2562 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2563 if let Some(mime_type) = resource.mime_type {
2564 link = link.mime_type(mime_type);
2565 }
2566 acp::ContentBlock::ResourceLink(link)
2567 }
2568 }
2569}