1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod shell_parser;
8mod templates;
9#[cfg(test)]
10mod tests;
11mod thread;
12mod thread_store;
13mod tool_permissions;
14mod tools;
15
16use context_server::ContextServerId;
17pub use db::*;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{StreamExt, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use util::ResultExt;
54use util::rel_path::RelPath;
55
/// Snapshot of the project's worktrees paired with a timestamp.
///
/// NOTE(review): the per-worktree payload type lives in `project::telemetry_snapshot`;
/// its exact contents are defined there, not in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}

/// Error describing a worktree rules file that could not be loaded.
///
/// Constructed by `NativeAgent::load_worktree_info_for_system_prompt` from the
/// error returned by the rules-file loading task.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
65
/// Holds both the internal Thread and the AcpThread for a session
///
/// Sessions are created by `NativeAgent::register_session` and removed from
/// `NativeAgent::sessions` when the corresponding `AcpThread` entity is released.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    ///
    /// Held weakly so the session does not keep the `AcpThread` alive; its release is
    /// what removes this session from the agent.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent save task started by `NativeAgent::save_thread`. Assigning a new
    /// task drops the previous one (NOTE(review): assumes dropping a gpui `Task`
    /// cancels it — confirm).
    pending_save: Task<()>,
    /// Observers/subscriptions registered for this session; kept alive for as long
    /// as the session exists.
    _subscriptions: Vec<Subscription>,
}
75
/// Cache of the language models offered by authenticated, visible providers.
///
/// Rebuilt by `LanguageModels::refresh_list`; each rebuild is announced on the
/// `watch` channel returned by `LanguageModels::watch`.
pub struct LanguageModels {
    /// Access language model by ID
    ///
    /// Keys are `"<provider_id>/<model_id>"` strings produced by `LanguageModels::model_id`.
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    ///
    /// Grouped by provider name, with an optional leading "Recommended" group.
    model_list: acp_thread::AgentModelList,
    /// Receiver side of the refresh notification channel; cloned out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender side of the refresh notification channel; signalled after every rebuild.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; held so it is not dropped.
    _authenticate_all_providers_task: Task<()>,
}
85
86impl LanguageModels {
87 fn new(cx: &mut App) -> Self {
88 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
89
90 let mut this = Self {
91 models: HashMap::default(),
92 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
93 refresh_models_rx,
94 refresh_models_tx,
95 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
96 };
97 this.refresh_list(cx);
98 this
99 }
100
101 fn refresh_list(&mut self, cx: &App) {
102 let providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .visible_providers()
105 .into_iter()
106 .filter(|provider| provider.is_authenticated(cx))
107 .collect::<Vec<_>>();
108
109 let mut language_model_list = IndexMap::default();
110 let mut recommended_models = HashSet::default();
111
112 let mut recommended = Vec::new();
113 for provider in &providers {
114 for model in provider.recommended_models(cx) {
115 recommended_models.insert((model.provider_id(), model.id()));
116 recommended.push(Self::map_language_model_to_info(&model, provider));
117 }
118 }
119 if !recommended.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName("Recommended".into()),
122 recommended,
123 );
124 }
125
126 let mut models = HashMap::default();
127 for provider in providers {
128 let mut provider_models = Vec::new();
129 for model in provider.provided_models(cx) {
130 let model_info = Self::map_language_model_to_info(&model, &provider);
131 let model_id = model_info.id.clone();
132 provider_models.push(model_info);
133 models.insert(model_id, model);
134 }
135 if !provider_models.is_empty() {
136 language_model_list.insert(
137 acp_thread::AgentModelGroupName(provider.name().0.clone()),
138 provider_models,
139 );
140 }
141 }
142
143 self.models = models;
144 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
145 self.refresh_models_tx.send(()).ok();
146 }
147
148 fn watch(&self) -> watch::Receiver<()> {
149 self.refresh_models_rx.clone()
150 }
151
152 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
153 self.models.get(model_id).cloned()
154 }
155
156 fn map_language_model_to_info(
157 model: &Arc<dyn LanguageModel>,
158 provider: &Arc<dyn LanguageModelProvider>,
159 ) -> acp_thread::AgentModelInfo {
160 acp_thread::AgentModelInfo {
161 id: Self::model_id(model),
162 name: model.name().0,
163 description: None,
164 icon: Some(match provider.icon() {
165 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
166 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
167 }),
168 is_latest: model.is_latest(),
169 }
170 }
171
172 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
173 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
174 }
175
176 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
177 let authenticate_all_providers = LanguageModelRegistry::global(cx)
178 .read(cx)
179 .visible_providers()
180 .iter()
181 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
182 .collect::<Vec<_>>();
183
184 cx.background_spawn(async move {
185 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
186 if let Err(err) = authenticate_task.await {
187 match err {
188 language_model::AuthenticateError::CredentialsNotFound => {
189 // Since we're authenticating these providers in the
190 // background for the purposes of populating the
191 // language selector, we don't care about providers
192 // where the credentials are not found.
193 }
194 language_model::AuthenticateError::ConnectionRefused => {
195 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
196 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
197 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
198 }
199 _ => {
200 // Some providers have noisy failure states that we
201 // don't want to spam the logs with every time the
202 // language model selector is initialized.
203 //
204 // Ideally these should have more clear failure modes
205 // that we know are safe to ignore here, like what we do
206 // with `CredentialsNotFound` above.
207 match provider_id.0.as_ref() {
208 "lmstudio" | "ollama" => {
209 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
210 //
211 // These fail noisily, so we don't log them.
212 }
213 "copilot_chat" => {
214 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
215 }
216 _ => {
217 log::error!(
218 "Failed to authenticate provider: {}: {err:#}",
219 provider_name.0
220 );
221 }
222 }
223 }
224 }
225 }
226 }
227 })
228 }
229}
230
/// The built-in agent.
///
/// Owns one `Session` per open thread, plus the state every session shares: project
/// context, context-server registry, templates and the cached model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store; `save_thread` reloads it after writing a thread to the database.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled to request a rebuild of `project_context`. Sent on worktree changes,
    /// rules-file updates and prompt-store updates.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds the project context on each refresh signal.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools and prompts; its prompts are exposed as slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project the agent operates on.
    project: Entity<Project>,
    /// Optional source of default user rules included in the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Keeps the event subscriptions registered in `NativeAgent::new` alive.
    _subscriptions: Vec<Subscription>,
}
249
250impl NativeAgent {
251 pub async fn new(
252 project: Entity<Project>,
253 thread_store: Entity<ThreadStore>,
254 templates: Arc<Templates>,
255 prompt_store: Option<Entity<PromptStore>>,
256 fs: Arc<dyn Fs>,
257 cx: &mut AsyncApp,
258 ) -> Result<Entity<NativeAgent>> {
259 log::debug!("Creating new NativeAgent");
260
261 let project_context = cx
262 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
263 .await;
264
265 Ok(cx.new(|cx| {
266 let context_server_store = project.read(cx).context_server_store();
267 let context_server_registry =
268 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
269
270 let mut subscriptions = vec![
271 cx.subscribe(&project, Self::handle_project_event),
272 cx.subscribe(
273 &LanguageModelRegistry::global(cx),
274 Self::handle_models_updated_event,
275 ),
276 cx.subscribe(
277 &context_server_store,
278 Self::handle_context_server_store_updated,
279 ),
280 cx.subscribe(
281 &context_server_registry,
282 Self::handle_context_server_registry_event,
283 ),
284 ];
285 if let Some(prompt_store) = prompt_store.as_ref() {
286 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
287 }
288
289 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
290 watch::channel(());
291 Self {
292 sessions: HashMap::default(),
293 thread_store,
294 project_context: cx.new(|_| project_context),
295 project_context_needs_refresh: project_context_needs_refresh_tx,
296 _maintain_project_context: cx.spawn(async move |this, cx| {
297 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
298 }),
299 context_server_registry,
300 templates,
301 models: LanguageModels::new(cx),
302 project,
303 prompt_store,
304 fs,
305 _subscriptions: subscriptions,
306 }
307 }))
308 }
309
310 fn new_session(
311 &mut self,
312 project: Entity<Project>,
313 cx: &mut Context<Self>,
314 ) -> Entity<AcpThread> {
315 // Create Thread
316 // Fetch default model from registry settings
317 let registry = LanguageModelRegistry::read_global(cx);
318 // Log available models for debugging
319 let available_count = registry.available_models(cx).count();
320 log::debug!("Total available models: {}", available_count);
321
322 let default_model = registry.default_model().and_then(|default_model| {
323 self.models
324 .model_from_id(&LanguageModels::model_id(&default_model.model))
325 });
326 let thread = cx.new(|cx| {
327 Thread::new(
328 project.clone(),
329 self.project_context.clone(),
330 self.context_server_registry.clone(),
331 self.templates.clone(),
332 default_model,
333 cx,
334 )
335 });
336
337 self.register_session(thread, cx)
338 }
339
340 fn register_session(
341 &mut self,
342 thread_handle: Entity<Thread>,
343 cx: &mut Context<Self>,
344 ) -> Entity<AcpThread> {
345 let connection = Rc::new(NativeAgentConnection(cx.entity()));
346
347 let thread = thread_handle.read(cx);
348 let session_id = thread.id().clone();
349 let title = thread.title();
350 let project = thread.project.clone();
351 let action_log = thread.action_log.clone();
352 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
353 let acp_thread = cx.new(|cx| {
354 acp_thread::AcpThread::new(
355 title,
356 connection,
357 project.clone(),
358 action_log.clone(),
359 session_id.clone(),
360 prompt_capabilities_rx,
361 cx,
362 )
363 });
364
365 let registry = LanguageModelRegistry::read_global(cx);
366 let summarization_model = registry.thread_summary_model().map(|c| c.model);
367
368 thread_handle.update(cx, |thread, cx| {
369 thread.set_summarization_model(summarization_model, cx);
370 thread.add_default_tools(
371 Rc::new(AcpThreadEnvironment {
372 acp_thread: acp_thread.downgrade(),
373 }) as _,
374 cx,
375 )
376 });
377
378 let subscriptions = vec![
379 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
380 this.sessions.remove(acp_thread.session_id());
381 }),
382 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
383 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
384 cx.observe(&thread_handle, move |this, thread, cx| {
385 this.save_thread(thread, cx)
386 }),
387 ];
388
389 self.sessions.insert(
390 session_id,
391 Session {
392 thread: thread_handle,
393 acp_thread: acp_thread.downgrade(),
394 _subscriptions: subscriptions,
395 pending_save: Task::ready(()),
396 },
397 );
398
399 self.update_available_commands(cx);
400
401 acp_thread
402 }
403
404 pub fn models(&self) -> &LanguageModels {
405 &self.models
406 }
407
408 async fn maintain_project_context(
409 this: WeakEntity<Self>,
410 mut needs_refresh: watch::Receiver<()>,
411 cx: &mut AsyncApp,
412 ) -> Result<()> {
413 while needs_refresh.changed().await.is_ok() {
414 let project_context = this
415 .update(cx, |this, cx| {
416 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
417 })?
418 .await;
419 this.update(cx, |this, cx| {
420 this.project_context = cx.new(|_| project_context);
421 })?;
422 }
423
424 Ok(())
425 }
426
427 fn build_project_context(
428 project: &Entity<Project>,
429 prompt_store: Option<&Entity<PromptStore>>,
430 cx: &mut App,
431 ) -> Task<ProjectContext> {
432 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
433 let worktree_tasks = worktrees
434 .into_iter()
435 .map(|worktree| {
436 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
437 })
438 .collect::<Vec<_>>();
439 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
440 prompt_store.read_with(cx, |prompt_store, cx| {
441 let prompts = prompt_store.default_prompt_metadata();
442 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
443 let contents = prompt_store.load(prompt_metadata.id, cx);
444 async move { (contents.await, prompt_metadata) }
445 });
446 cx.background_spawn(future::join_all(load_tasks))
447 })
448 } else {
449 Task::ready(vec![])
450 };
451
452 cx.spawn(async move |_cx| {
453 let (worktrees, default_user_rules) =
454 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
455
456 let worktrees = worktrees
457 .into_iter()
458 .map(|(worktree, _rules_error)| {
459 // TODO: show error message
460 // if let Some(rules_error) = rules_error {
461 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
462 // }
463 worktree
464 })
465 .collect::<Vec<_>>();
466
467 let default_user_rules = default_user_rules
468 .into_iter()
469 .flat_map(|(contents, prompt_metadata)| match contents {
470 Ok(contents) => Some(UserRulesContext {
471 uuid: prompt_metadata.id.as_user()?,
472 title: prompt_metadata.title.map(|title| title.to_string()),
473 contents,
474 }),
475 Err(_err) => {
476 // TODO: show error message
477 // this.update(cx, |_, cx| {
478 // cx.emit(RulesLoadingError {
479 // message: format!("{err:?}").into(),
480 // });
481 // })
482 // .ok();
483 None
484 }
485 })
486 .collect::<Vec<_>>();
487
488 ProjectContext::new(worktrees, default_user_rules)
489 })
490 }
491
492 fn load_worktree_info_for_system_prompt(
493 worktree: Entity<Worktree>,
494 project: Entity<Project>,
495 cx: &mut App,
496 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
497 let tree = worktree.read(cx);
498 let root_name = tree.root_name_str().into();
499 let abs_path = tree.abs_path();
500
501 let mut context = WorktreeContext {
502 root_name,
503 abs_path,
504 rules_file: None,
505 };
506
507 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
508 let Some(rules_task) = rules_task else {
509 return Task::ready((context, None));
510 };
511
512 cx.spawn(async move |_| {
513 let (rules_file, rules_file_error) = match rules_task.await {
514 Ok(rules_file) => (Some(rules_file), None),
515 Err(err) => (
516 None,
517 Some(RulesLoadingError {
518 message: format!("{err}").into(),
519 }),
520 ),
521 };
522 context.rules_file = rules_file;
523 (context, rules_file_error)
524 })
525 }
526
527 fn load_worktree_rules_file(
528 worktree: Entity<Worktree>,
529 project: Entity<Project>,
530 cx: &mut App,
531 ) -> Option<Task<Result<RulesFileContext>>> {
532 let worktree = worktree.read(cx);
533 let worktree_id = worktree.id();
534 let selected_rules_file = RULES_FILE_NAMES
535 .into_iter()
536 .filter_map(|name| {
537 worktree
538 .entry_for_path(RelPath::unix(name).unwrap())
539 .filter(|entry| entry.is_file())
540 .map(|entry| entry.path.clone())
541 })
542 .next();
543
544 // Note that Cline supports `.clinerules` being a directory, but that is not currently
545 // supported. This doesn't seem to occur often in GitHub repositories.
546 selected_rules_file.map(|path_in_worktree| {
547 let project_path = ProjectPath {
548 worktree_id,
549 path: path_in_worktree.clone(),
550 };
551 let buffer_task =
552 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
553 let rope_task = cx.spawn(async move |cx| {
554 let buffer = buffer_task.await?;
555 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
556 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
557 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
558 })?;
559 anyhow::Ok((project_entry_id, rope))
560 });
561 // Build a string from the rope on a background thread.
562 cx.background_spawn(async move {
563 let (project_entry_id, rope) = rope_task.await?;
564 anyhow::Ok(RulesFileContext {
565 path_in_worktree,
566 text: rope.to_string().trim().to_string(),
567 project_entry_id: project_entry_id.to_usize(),
568 })
569 })
570 })
571 }
572
573 fn handle_thread_title_updated(
574 &mut self,
575 thread: Entity<Thread>,
576 _: &TitleUpdated,
577 cx: &mut Context<Self>,
578 ) {
579 let session_id = thread.read(cx).id();
580 let Some(session) = self.sessions.get(session_id) else {
581 return;
582 };
583 let thread = thread.downgrade();
584 let acp_thread = session.acp_thread.clone();
585 cx.spawn(async move |_, cx| {
586 let title = thread.read_with(cx, |thread, _| thread.title())?;
587 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
588 task.await
589 })
590 .detach_and_log_err(cx);
591 }
592
593 fn handle_thread_token_usage_updated(
594 &mut self,
595 thread: Entity<Thread>,
596 usage: &TokenUsageUpdated,
597 cx: &mut Context<Self>,
598 ) {
599 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
600 return;
601 };
602 session
603 .acp_thread
604 .update(cx, |acp_thread, cx| {
605 acp_thread.update_token_usage(usage.0.clone(), cx);
606 })
607 .ok();
608 }
609
610 fn handle_project_event(
611 &mut self,
612 _project: Entity<Project>,
613 event: &project::Event,
614 _cx: &mut Context<Self>,
615 ) {
616 match event {
617 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
618 self.project_context_needs_refresh.send(()).ok();
619 }
620 project::Event::WorktreeUpdatedEntries(_, items) => {
621 if items.iter().any(|(path, _, _)| {
622 RULES_FILE_NAMES
623 .iter()
624 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
625 }) {
626 self.project_context_needs_refresh.send(()).ok();
627 }
628 }
629 _ => {}
630 }
631 }
632
633 fn handle_prompts_updated_event(
634 &mut self,
635 _prompt_store: Entity<PromptStore>,
636 _event: &prompt_store::PromptsUpdatedEvent,
637 _cx: &mut Context<Self>,
638 ) {
639 self.project_context_needs_refresh.send(()).ok();
640 }
641
642 fn handle_models_updated_event(
643 &mut self,
644 _registry: Entity<LanguageModelRegistry>,
645 _event: &language_model::Event,
646 cx: &mut Context<Self>,
647 ) {
648 self.models.refresh_list(cx);
649
650 let registry = LanguageModelRegistry::read_global(cx);
651 let default_model = registry.default_model().map(|m| m.model);
652 let summarization_model = registry.thread_summary_model().map(|m| m.model);
653
654 for session in self.sessions.values_mut() {
655 session.thread.update(cx, |thread, cx| {
656 if thread.model().is_none()
657 && let Some(model) = default_model.clone()
658 {
659 thread.set_model(model, cx);
660 cx.notify();
661 }
662 thread.set_summarization_model(summarization_model.clone(), cx);
663 });
664 }
665 }
666
667 fn handle_context_server_store_updated(
668 &mut self,
669 _store: Entity<project::context_server_store::ContextServerStore>,
670 _event: &project::context_server_store::ServerStatusChangedEvent,
671 cx: &mut Context<Self>,
672 ) {
673 self.update_available_commands(cx);
674 }
675
676 fn handle_context_server_registry_event(
677 &mut self,
678 _registry: Entity<ContextServerRegistry>,
679 event: &ContextServerRegistryEvent,
680 cx: &mut Context<Self>,
681 ) {
682 match event {
683 ContextServerRegistryEvent::ToolsChanged => {}
684 ContextServerRegistryEvent::PromptsChanged => {
685 self.update_available_commands(cx);
686 }
687 }
688 }
689
690 fn update_available_commands(&self, cx: &mut Context<Self>) {
691 let available_commands = self.build_available_commands(cx);
692 for session in self.sessions.values() {
693 if let Some(acp_thread) = session.acp_thread.upgrade() {
694 acp_thread.update(cx, |thread, cx| {
695 thread
696 .handle_session_update(
697 acp::SessionUpdate::AvailableCommandsUpdate(
698 acp::AvailableCommandsUpdate::new(available_commands.clone()),
699 ),
700 cx,
701 )
702 .log_err();
703 });
704 }
705 }
706 }
707
708 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
709 let registry = self.context_server_registry.read(cx);
710
711 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
712 for context_server_prompt in registry.prompts() {
713 *prompt_name_counts
714 .entry(context_server_prompt.prompt.name.as_str())
715 .or_insert(0) += 1;
716 }
717
718 registry
719 .prompts()
720 .flat_map(|context_server_prompt| {
721 let prompt = &context_server_prompt.prompt;
722
723 let should_prefix = prompt_name_counts
724 .get(prompt.name.as_str())
725 .copied()
726 .unwrap_or(0)
727 > 1;
728
729 let name = if should_prefix {
730 format!("{}.{}", context_server_prompt.server_id, prompt.name)
731 } else {
732 prompt.name.clone()
733 };
734
735 let mut command = acp::AvailableCommand::new(
736 name,
737 prompt.description.clone().unwrap_or_default(),
738 );
739
740 match prompt.arguments.as_deref() {
741 Some([arg]) => {
742 let hint = format!("<{}>", arg.name);
743
744 command = command.input(acp::AvailableCommandInput::Unstructured(
745 acp::UnstructuredCommandInput::new(hint),
746 ));
747 }
748 Some([]) | None => {}
749 Some(_) => {
750 // skip >1 argument commands since we don't support them yet
751 return None;
752 }
753 }
754
755 Some(command)
756 })
757 .collect()
758 }
759
760 pub fn load_thread(
761 &mut self,
762 id: acp::SessionId,
763 cx: &mut Context<Self>,
764 ) -> Task<Result<Entity<Thread>>> {
765 let database_future = ThreadsDatabase::connect(cx);
766 cx.spawn(async move |this, cx| {
767 let database = database_future.await.map_err(|err| anyhow!(err))?;
768 let db_thread = database
769 .load_thread(id.clone())
770 .await?
771 .with_context(|| format!("no thread found with ID: {id:?}"))?;
772
773 this.update(cx, |this, cx| {
774 let summarization_model = LanguageModelRegistry::read_global(cx)
775 .thread_summary_model()
776 .map(|c| c.model);
777
778 cx.new(|cx| {
779 let mut thread = Thread::from_db(
780 id.clone(),
781 db_thread,
782 this.project.clone(),
783 this.project_context.clone(),
784 this.context_server_registry.clone(),
785 this.templates.clone(),
786 cx,
787 );
788 thread.set_summarization_model(summarization_model, cx);
789 thread
790 })
791 })
792 })
793 }
794
795 pub fn open_thread(
796 &mut self,
797 id: acp::SessionId,
798 cx: &mut Context<Self>,
799 ) -> Task<Result<Entity<AcpThread>>> {
800 let task = self.load_thread(id, cx);
801 cx.spawn(async move |this, cx| {
802 let thread = task.await?;
803 let acp_thread =
804 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
805 let events = thread.update(cx, |thread, cx| thread.replay(cx));
806 cx.update(|cx| {
807 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
808 })
809 .await?;
810 Ok(acp_thread)
811 })
812 }
813
814 pub fn thread_summary(
815 &mut self,
816 id: acp::SessionId,
817 cx: &mut Context<Self>,
818 ) -> Task<Result<SharedString>> {
819 let thread = self.open_thread(id.clone(), cx);
820 cx.spawn(async move |this, cx| {
821 let acp_thread = thread.await?;
822 let result = this
823 .update(cx, |this, cx| {
824 this.sessions
825 .get(&id)
826 .unwrap()
827 .thread
828 .update(cx, |thread, cx| thread.summary(cx))
829 })?
830 .await
831 .context("Failed to generate summary")?;
832 drop(acp_thread);
833 Ok(result)
834 })
835 }
836
837 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
838 if thread.read(cx).is_empty() {
839 return;
840 }
841
842 let database_future = ThreadsDatabase::connect(cx);
843 let (id, db_thread) =
844 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
845 let Some(session) = self.sessions.get_mut(&id) else {
846 return;
847 };
848 let thread_store = self.thread_store.clone();
849 session.pending_save = cx.spawn(async move |_, cx| {
850 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
851 return;
852 };
853 let db_thread = db_thread.await;
854 database.save_thread(id, db_thread).await.log_err();
855 thread_store.update(cx, |store, cx| store.reload(cx));
856 });
857 }
858
859 fn send_mcp_prompt(
860 &self,
861 message_id: UserMessageId,
862 session_id: agent_client_protocol::SessionId,
863 prompt_name: String,
864 server_id: ContextServerId,
865 arguments: HashMap<String, String>,
866 original_content: Vec<acp::ContentBlock>,
867 cx: &mut Context<Self>,
868 ) -> Task<Result<acp::PromptResponse>> {
869 let server_store = self.context_server_registry.read(cx).server_store().clone();
870 let path_style = self.project.read(cx).path_style(cx);
871
872 cx.spawn(async move |this, cx| {
873 let prompt =
874 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
875
876 let (acp_thread, thread) = this.update(cx, |this, _cx| {
877 let session = this
878 .sessions
879 .get(&session_id)
880 .context("Failed to get session")?;
881 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
882 })??;
883
884 let mut last_is_user = true;
885
886 thread.update(cx, |thread, cx| {
887 thread.push_acp_user_block(
888 message_id,
889 original_content.into_iter().skip(1),
890 path_style,
891 cx,
892 );
893 });
894
895 for message in prompt.messages {
896 let context_server::types::PromptMessage { role, content } = message;
897 let block = mcp_message_content_to_acp_content_block(content);
898
899 match role {
900 context_server::types::Role::User => {
901 let id = acp_thread::UserMessageId::new();
902
903 acp_thread.update(cx, |acp_thread, cx| {
904 acp_thread.push_user_content_block_with_indent(
905 Some(id.clone()),
906 block.clone(),
907 true,
908 cx,
909 );
910 })?;
911
912 thread.update(cx, |thread, cx| {
913 thread.push_acp_user_block(id, [block], path_style, cx);
914 });
915 }
916 context_server::types::Role::Assistant => {
917 acp_thread.update(cx, |acp_thread, cx| {
918 acp_thread.push_assistant_content_block_with_indent(
919 block.clone(),
920 false,
921 true,
922 cx,
923 );
924 })?;
925
926 thread.update(cx, |thread, cx| {
927 thread.push_acp_agent_block(block, cx);
928 });
929 }
930 }
931
932 last_is_user = role == context_server::types::Role::User;
933 }
934
935 let response_stream = thread.update(cx, |thread, cx| {
936 if last_is_user {
937 thread.send_existing(cx)
938 } else {
939 // Resume if MCP prompt did not end with a user message
940 thread.resume(cx)
941 }
942 })?;
943
944 cx.update(|cx| {
945 NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
946 })
947 .await
948 })
949 }
950}
951
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds only the `Entity<NativeAgent>` handle. Cloning copies the handle, not the
/// agent state.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
955
956impl NativeAgentConnection {
957 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
958 self.0
959 .read(cx)
960 .sessions
961 .get(session_id)
962 .map(|session| session.thread.clone())
963 }
964
965 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
966 self.0.update(cx, |this, cx| this.load_thread(id, cx))
967 }
968
969 fn run_turn(
970 &self,
971 session_id: acp::SessionId,
972 cx: &mut App,
973 f: impl 'static
974 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
975 ) -> Task<Result<acp::PromptResponse>> {
976 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
977 agent
978 .sessions
979 .get_mut(&session_id)
980 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
981 }) else {
982 return Task::ready(Err(anyhow!("Session not found")));
983 };
984 log::debug!("Found session for: {}", session_id);
985
986 let response_stream = match f(thread, cx) {
987 Ok(stream) => stream,
988 Err(err) => return Task::ready(Err(err)),
989 };
990 Self::handle_thread_events(response_stream, acp_thread, cx)
991 }
992
993 fn handle_thread_events(
994 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
995 acp_thread: WeakEntity<AcpThread>,
996 cx: &App,
997 ) -> Task<Result<acp::PromptResponse>> {
998 cx.spawn(async move |cx| {
999 // Handle response stream and forward to session.acp_thread
1000 while let Some(result) = events.next().await {
1001 match result {
1002 Ok(event) => {
1003 log::trace!("Received completion event: {:?}", event);
1004
1005 match event {
1006 ThreadEvent::UserMessage(message) => {
1007 acp_thread.update(cx, |thread, cx| {
1008 for content in message.content {
1009 thread.push_user_content_block(
1010 Some(message.id.clone()),
1011 content.into(),
1012 cx,
1013 );
1014 }
1015 })?;
1016 }
1017 ThreadEvent::AgentText(text) => {
1018 acp_thread.update(cx, |thread, cx| {
1019 thread.push_assistant_content_block(text.into(), false, cx)
1020 })?;
1021 }
1022 ThreadEvent::AgentThinking(text) => {
1023 acp_thread.update(cx, |thread, cx| {
1024 thread.push_assistant_content_block(text.into(), true, cx)
1025 })?;
1026 }
1027 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1028 tool_call,
1029 options,
1030 response,
1031 context: _,
1032 }) => {
1033 let outcome_task = acp_thread.update(cx, |thread, cx| {
1034 thread.request_tool_call_authorization(
1035 tool_call, options, true, cx,
1036 )
1037 })??;
1038 cx.background_spawn(async move {
1039 if let acp::RequestPermissionOutcome::Selected(
1040 acp::SelectedPermissionOutcome { option_id, .. },
1041 ) = outcome_task.await
1042 {
1043 response
1044 .send(option_id)
1045 .map(|_| anyhow!("authorization receiver was dropped"))
1046 .log_err();
1047 }
1048 })
1049 .detach();
1050 }
1051 ThreadEvent::ToolCall(tool_call) => {
1052 acp_thread.update(cx, |thread, cx| {
1053 thread.upsert_tool_call(tool_call, cx)
1054 })??;
1055 }
1056 ThreadEvent::ToolCallUpdate(update) => {
1057 acp_thread.update(cx, |thread, cx| {
1058 thread.update_tool_call(update, cx)
1059 })??;
1060 }
1061 ThreadEvent::Retry(status) => {
1062 acp_thread.update(cx, |thread, cx| {
1063 thread.update_retry_status(status, cx)
1064 })?;
1065 }
1066 ThreadEvent::Stop(stop_reason) => {
1067 log::debug!("Assistant message complete: {:?}", stop_reason);
1068 return Ok(acp::PromptResponse::new(stop_reason));
1069 }
1070 }
1071 }
1072 Err(e) => {
1073 log::error!("Error in model response stream: {:?}", e);
1074 return Err(e);
1075 }
1076 }
1077 }
1078
1079 log::debug!("Response stream completed");
1080 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1081 })
1082 }
1083}
1084
1085struct Command<'a> {
1086 prompt_name: &'a str,
1087 arg_value: &'a str,
1088 explicit_server_id: Option<&'a str>,
1089}
1090
1091impl<'a> Command<'a> {
1092 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1093 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1094 return None;
1095 };
1096 let text = text_content.text.trim();
1097 let command = text.strip_prefix('/')?;
1098 let (command, arg_value) = command
1099 .split_once(char::is_whitespace)
1100 .unwrap_or((command, ""));
1101
1102 if let Some((server_id, prompt_name)) = command.split_once('.') {
1103 Some(Self {
1104 prompt_name,
1105 arg_value,
1106 explicit_server_id: Some(server_id),
1107 })
1108 } else {
1109 Some(Self {
1110 prompt_name: command,
1111 arg_value,
1112 explicit_server_id: None,
1113 })
1114 }
1115 }
1116}
1117
1118struct NativeAgentModelSelector {
1119 session_id: acp::SessionId,
1120 connection: NativeAgentConnection,
1121}
1122
1123impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1124 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1125 log::debug!("NativeAgentConnection::list_models called");
1126 let list = self.connection.0.read(cx).models.model_list.clone();
1127 Task::ready(if list.is_empty() {
1128 Err(anyhow::anyhow!("No models available"))
1129 } else {
1130 Ok(list)
1131 })
1132 }
1133
1134 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1135 log::debug!(
1136 "Setting model for session {}: {}",
1137 self.session_id,
1138 model_id
1139 );
1140 let Some(thread) = self
1141 .connection
1142 .0
1143 .read(cx)
1144 .sessions
1145 .get(&self.session_id)
1146 .map(|session| session.thread.clone())
1147 else {
1148 return Task::ready(Err(anyhow!("Session not found")));
1149 };
1150
1151 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1152 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1153 };
1154
1155 // We want to reset the effort level when switching models, as the currently-selected effort level may
1156 // not be compatible.
1157 let effort = model
1158 .default_effort_level()
1159 .map(|effort_level| effort_level.value.to_string());
1160
1161 thread.update(cx, |thread, cx| {
1162 thread.set_model(model.clone(), cx);
1163 thread.set_thinking_effort(effort.clone(), cx);
1164 thread.set_thinking_enabled(model.supports_thinking(), cx);
1165 });
1166
1167 update_settings_file(
1168 self.connection.0.read(cx).fs.clone(),
1169 cx,
1170 move |settings, cx| {
1171 let provider = model.provider_id().0.to_string();
1172 let model = model.id().0.to_string();
1173 let enable_thinking = thread.read(cx).thinking_enabled();
1174 settings
1175 .agent
1176 .get_or_insert_default()
1177 .set_model(LanguageModelSelection {
1178 provider: provider.into(),
1179 model,
1180 enable_thinking,
1181 effort,
1182 });
1183 },
1184 );
1185
1186 Task::ready(Ok(()))
1187 }
1188
1189 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1190 let Some(thread) = self
1191 .connection
1192 .0
1193 .read(cx)
1194 .sessions
1195 .get(&self.session_id)
1196 .map(|session| session.thread.clone())
1197 else {
1198 return Task::ready(Err(anyhow!("Session not found")));
1199 };
1200 let Some(model) = thread.read(cx).model() else {
1201 return Task::ready(Err(anyhow!("Model not found")));
1202 };
1203 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1204 else {
1205 return Task::ready(Err(anyhow!("Provider not found")));
1206 };
1207 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1208 model, &provider,
1209 )))
1210 }
1211
1212 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1213 Some(self.connection.0.read(cx).models.watch())
1214 }
1215
1216 fn should_render_footer(&self) -> bool {
1217 true
1218 }
1219}
1220
1221impl acp_thread::AgentConnection for NativeAgentConnection {
1222 fn telemetry_id(&self) -> SharedString {
1223 "zed".into()
1224 }
1225
1226 fn new_thread(
1227 self: Rc<Self>,
1228 project: Entity<Project>,
1229 cwd: &Path,
1230 cx: &mut App,
1231 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1232 log::debug!("Creating new thread for project at: {cwd:?}");
1233 Task::ready(Ok(self
1234 .0
1235 .update(cx, |agent, cx| agent.new_session(project, cx))))
1236 }
1237
1238 fn supports_load_session(&self, _cx: &App) -> bool {
1239 true
1240 }
1241
1242 fn load_session(
1243 self: Rc<Self>,
1244 session: AgentSessionInfo,
1245 _project: Entity<Project>,
1246 _cwd: &Path,
1247 cx: &mut App,
1248 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1249 self.0
1250 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1251 }
1252
1253 fn auth_methods(&self) -> &[acp::AuthMethod] {
1254 &[] // No auth for in-process
1255 }
1256
1257 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1258 Task::ready(Ok(()))
1259 }
1260
1261 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1262 Some(Rc::new(NativeAgentModelSelector {
1263 session_id: session_id.clone(),
1264 connection: self.clone(),
1265 }) as Rc<dyn AgentModelSelector>)
1266 }
1267
1268 fn prompt(
1269 &self,
1270 id: Option<acp_thread::UserMessageId>,
1271 params: acp::PromptRequest,
1272 cx: &mut App,
1273 ) -> Task<Result<acp::PromptResponse>> {
1274 let id = id.expect("UserMessageId is required");
1275 let session_id = params.session_id.clone();
1276 log::info!("Received prompt request for session: {}", session_id);
1277 log::debug!("Prompt blocks count: {}", params.prompt.len());
1278
1279 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1280 let registry = self.0.read(cx).context_server_registry.read(cx);
1281
1282 let explicit_server_id = parsed_command
1283 .explicit_server_id
1284 .map(|server_id| ContextServerId(server_id.into()));
1285
1286 if let Some(prompt) =
1287 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1288 {
1289 let arguments = if !parsed_command.arg_value.is_empty()
1290 && let Some(arg_name) = prompt
1291 .prompt
1292 .arguments
1293 .as_ref()
1294 .and_then(|args| args.first())
1295 .map(|arg| arg.name.clone())
1296 {
1297 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1298 } else {
1299 Default::default()
1300 };
1301
1302 let prompt_name = prompt.prompt.name.clone();
1303 let server_id = prompt.server_id.clone();
1304
1305 return self.0.update(cx, |agent, cx| {
1306 agent.send_mcp_prompt(
1307 id,
1308 session_id.clone(),
1309 prompt_name,
1310 server_id,
1311 arguments,
1312 params.prompt,
1313 cx,
1314 )
1315 });
1316 };
1317 };
1318
1319 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1320
1321 self.run_turn(session_id, cx, move |thread, cx| {
1322 let content: Vec<UserMessageContent> = params
1323 .prompt
1324 .into_iter()
1325 .map(|block| UserMessageContent::from_content_block(block, path_style))
1326 .collect::<Vec<_>>();
1327 log::debug!("Converted prompt to message: {} chars", content.len());
1328 log::debug!("Message id: {:?}", id);
1329 log::debug!("Message content: {:?}", content);
1330
1331 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1332 })
1333 }
1334
1335 fn retry(
1336 &self,
1337 session_id: &acp::SessionId,
1338 _cx: &App,
1339 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1340 Some(Rc::new(NativeAgentSessionRetry {
1341 connection: self.clone(),
1342 session_id: session_id.clone(),
1343 }) as _)
1344 }
1345
1346 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1347 log::info!("Cancelling on session: {}", session_id);
1348 self.0.update(cx, |agent, cx| {
1349 if let Some(agent) = agent.sessions.get(session_id) {
1350 agent
1351 .thread
1352 .update(cx, |thread, cx| thread.cancel(cx))
1353 .detach();
1354 }
1355 });
1356 }
1357
1358 fn truncate(
1359 &self,
1360 session_id: &agent_client_protocol::SessionId,
1361 cx: &App,
1362 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1363 self.0.read_with(cx, |agent, _cx| {
1364 agent.sessions.get(session_id).map(|session| {
1365 Rc::new(NativeAgentSessionTruncate {
1366 thread: session.thread.clone(),
1367 acp_thread: session.acp_thread.clone(),
1368 }) as _
1369 })
1370 })
1371 }
1372
1373 fn set_title(
1374 &self,
1375 session_id: &acp::SessionId,
1376 _cx: &App,
1377 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1378 Some(Rc::new(NativeAgentSessionSetTitle {
1379 connection: self.clone(),
1380 session_id: session_id.clone(),
1381 }) as _)
1382 }
1383
1384 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1385 let thread_store = self.0.read(cx).thread_store.clone();
1386 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1387 }
1388
1389 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1390 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1391 }
1392
1393 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1394 self
1395 }
1396}
1397
1398impl acp_thread::AgentTelemetry for NativeAgentConnection {
1399 fn thread_data(
1400 &self,
1401 session_id: &acp::SessionId,
1402 cx: &mut App,
1403 ) -> Task<Result<serde_json::Value>> {
1404 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1405 return Task::ready(Err(anyhow!("Session not found")));
1406 };
1407
1408 let task = session.thread.read(cx).to_db(cx);
1409 cx.background_spawn(async move {
1410 serde_json::to_value(task.await).context("Failed to serialize thread")
1411 })
1412 }
1413}
1414
1415pub struct NativeAgentSessionList {
1416 thread_store: Entity<ThreadStore>,
1417 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1418 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1419 _subscription: Subscription,
1420}
1421
1422impl NativeAgentSessionList {
1423 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1424 let (tx, rx) = smol::channel::unbounded();
1425 let this_tx = tx.clone();
1426 let subscription = cx.observe(&thread_store, move |_, _| {
1427 this_tx
1428 .try_send(acp_thread::SessionListUpdate::Refresh)
1429 .ok();
1430 });
1431 Self {
1432 thread_store,
1433 updates_tx: tx,
1434 updates_rx: rx,
1435 _subscription: subscription,
1436 }
1437 }
1438
1439 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1440 AgentSessionInfo {
1441 session_id: entry.id,
1442 cwd: None,
1443 title: Some(entry.title),
1444 updated_at: Some(entry.updated_at),
1445 meta: None,
1446 }
1447 }
1448
1449 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1450 &self.thread_store
1451 }
1452}
1453
1454impl AgentSessionList for NativeAgentSessionList {
1455 fn list_sessions(
1456 &self,
1457 _request: AgentSessionListRequest,
1458 cx: &mut App,
1459 ) -> Task<Result<AgentSessionListResponse>> {
1460 let sessions = self
1461 .thread_store
1462 .read(cx)
1463 .entries()
1464 .map(Self::to_session_info)
1465 .collect();
1466 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1467 }
1468
1469 fn supports_delete(&self) -> bool {
1470 true
1471 }
1472
1473 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1474 self.thread_store
1475 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1476 }
1477
1478 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1479 self.thread_store
1480 .update(cx, |store, cx| store.delete_threads(cx))
1481 }
1482
1483 fn watch(
1484 &self,
1485 _cx: &mut App,
1486 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1487 Some(self.updates_rx.clone())
1488 }
1489
1490 fn notify_refresh(&self) {
1491 self.updates_tx
1492 .try_send(acp_thread::SessionListUpdate::Refresh)
1493 .ok();
1494 }
1495
1496 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1497 self
1498 }
1499}
1500
1501struct NativeAgentSessionTruncate {
1502 thread: Entity<Thread>,
1503 acp_thread: WeakEntity<AcpThread>,
1504}
1505
1506impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1507 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1508 match self.thread.update(cx, |thread, cx| {
1509 thread.truncate(message_id.clone(), cx)?;
1510 Ok(thread.latest_token_usage())
1511 }) {
1512 Ok(usage) => {
1513 self.acp_thread
1514 .update(cx, |thread, cx| {
1515 thread.update_token_usage(usage, cx);
1516 })
1517 .ok();
1518 Task::ready(Ok(()))
1519 }
1520 Err(error) => Task::ready(Err(error)),
1521 }
1522 }
1523}
1524
1525struct NativeAgentSessionRetry {
1526 connection: NativeAgentConnection,
1527 session_id: acp::SessionId,
1528}
1529
1530impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1531 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1532 self.connection
1533 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1534 thread.update(cx, |thread, cx| thread.resume(cx))
1535 })
1536 }
1537}
1538
1539struct NativeAgentSessionSetTitle {
1540 connection: NativeAgentConnection,
1541 session_id: acp::SessionId,
1542}
1543
1544impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1545 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1546 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1547 return Task::ready(Err(anyhow!("session not found")));
1548 };
1549 let thread = session.thread.clone();
1550 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1551 Task::ready(Ok(()))
1552 }
1553}
1554
1555pub struct AcpThreadEnvironment {
1556 acp_thread: WeakEntity<AcpThread>,
1557}
1558
1559impl ThreadEnvironment for AcpThreadEnvironment {
1560 fn create_terminal(
1561 &self,
1562 command: String,
1563 cwd: Option<PathBuf>,
1564 output_byte_limit: Option<u64>,
1565 cx: &mut AsyncApp,
1566 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1567 let task = self.acp_thread.update(cx, |thread, cx| {
1568 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1569 });
1570
1571 let acp_thread = self.acp_thread.clone();
1572 cx.spawn(async move |cx| {
1573 let terminal = task?.await?;
1574
1575 let (drop_tx, drop_rx) = oneshot::channel();
1576 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1577
1578 cx.spawn(async move |cx| {
1579 drop_rx.await.ok();
1580 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1581 })
1582 .detach();
1583
1584 let handle = AcpTerminalHandle {
1585 terminal,
1586 _drop_tx: Some(drop_tx),
1587 };
1588
1589 Ok(Rc::new(handle) as _)
1590 })
1591 }
1592}
1593
1594pub struct AcpTerminalHandle {
1595 terminal: Entity<acp_thread::Terminal>,
1596 _drop_tx: Option<oneshot::Sender<()>>,
1597}
1598
1599impl TerminalHandle for AcpTerminalHandle {
1600 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1601 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1602 }
1603
1604 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1605 Ok(self
1606 .terminal
1607 .read_with(cx, |term, _cx| term.wait_for_exit()))
1608 }
1609
1610 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1611 Ok(self
1612 .terminal
1613 .read_with(cx, |term, cx| term.current_output(cx)))
1614 }
1615
1616 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1617 cx.update(|cx| {
1618 self.terminal.update(cx, |terminal, cx| {
1619 terminal.kill(cx);
1620 });
1621 });
1622 Ok(())
1623 }
1624
1625 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1626 Ok(self
1627 .terminal
1628 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1629 }
1630}
1631
1632#[cfg(test)]
1633mod internal_tests {
1634 use super::*;
1635 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1636 use fs::FakeFs;
1637 use gpui::TestAppContext;
1638 use indoc::formatdoc;
1639 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1640 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1641 use serde_json::json;
1642 use settings::SettingsStore;
1643 use util::{path, rel_path::rel_path};
1644
1645 #[gpui::test]
1646 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1647 init_test(cx);
1648 let fs = FakeFs::new(cx.executor());
1649 fs.insert_tree(
1650 "/",
1651 json!({
1652 "a": {}
1653 }),
1654 )
1655 .await;
1656 let project = Project::test(fs.clone(), [], cx).await;
1657 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1658 let agent = NativeAgent::new(
1659 project.clone(),
1660 thread_store,
1661 Templates::new(),
1662 None,
1663 fs.clone(),
1664 &mut cx.to_async(),
1665 )
1666 .await
1667 .unwrap();
1668 agent.read_with(cx, |agent, cx| {
1669 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1670 });
1671
1672 let worktree = project
1673 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1674 .await
1675 .unwrap();
1676 cx.run_until_parked();
1677 agent.read_with(cx, |agent, cx| {
1678 assert_eq!(
1679 agent.project_context.read(cx).worktrees,
1680 vec![WorktreeContext {
1681 root_name: "a".into(),
1682 abs_path: Path::new("/a").into(),
1683 rules_file: None
1684 }]
1685 )
1686 });
1687
1688 // Creating `/a/.rules` updates the project context.
1689 fs.insert_file("/a/.rules", Vec::new()).await;
1690 cx.run_until_parked();
1691 agent.read_with(cx, |agent, cx| {
1692 let rules_entry = worktree
1693 .read(cx)
1694 .entry_for_path(rel_path(".rules"))
1695 .unwrap();
1696 assert_eq!(
1697 agent.project_context.read(cx).worktrees,
1698 vec![WorktreeContext {
1699 root_name: "a".into(),
1700 abs_path: Path::new("/a").into(),
1701 rules_file: Some(RulesFileContext {
1702 path_in_worktree: rel_path(".rules").into(),
1703 text: "".into(),
1704 project_entry_id: rules_entry.id.to_usize()
1705 })
1706 }]
1707 )
1708 });
1709 }
1710
1711 #[gpui::test]
1712 async fn test_listing_models(cx: &mut TestAppContext) {
1713 init_test(cx);
1714 let fs = FakeFs::new(cx.executor());
1715 fs.insert_tree("/", json!({ "a": {} })).await;
1716 let project = Project::test(fs.clone(), [], cx).await;
1717 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1718 let connection = NativeAgentConnection(
1719 NativeAgent::new(
1720 project.clone(),
1721 thread_store,
1722 Templates::new(),
1723 None,
1724 fs.clone(),
1725 &mut cx.to_async(),
1726 )
1727 .await
1728 .unwrap(),
1729 );
1730
1731 // Create a thread/session
1732 let acp_thread = cx
1733 .update(|cx| {
1734 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1735 })
1736 .await
1737 .unwrap();
1738
1739 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1740
1741 let models = cx
1742 .update(|cx| {
1743 connection
1744 .model_selector(&session_id)
1745 .unwrap()
1746 .list_models(cx)
1747 })
1748 .await
1749 .unwrap();
1750
1751 let acp_thread::AgentModelList::Grouped(models) = models else {
1752 panic!("Unexpected model group");
1753 };
1754 assert_eq!(
1755 models,
1756 IndexMap::from_iter([(
1757 AgentModelGroupName("Fake".into()),
1758 vec![AgentModelInfo {
1759 id: acp::ModelId::new("fake/fake"),
1760 name: "Fake".into(),
1761 description: None,
1762 icon: Some(acp_thread::AgentModelIcon::Named(
1763 ui::IconName::ZedAssistant
1764 )),
1765 is_latest: false,
1766 }]
1767 )])
1768 );
1769 }
1770
1771 #[gpui::test]
1772 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1773 init_test(cx);
1774 let fs = FakeFs::new(cx.executor());
1775 fs.create_dir(paths::settings_file().parent().unwrap())
1776 .await
1777 .unwrap();
1778 fs.insert_file(
1779 paths::settings_file(),
1780 json!({
1781 "agent": {
1782 "default_model": {
1783 "provider": "foo",
1784 "model": "bar"
1785 }
1786 }
1787 })
1788 .to_string()
1789 .into_bytes(),
1790 )
1791 .await;
1792 let project = Project::test(fs.clone(), [], cx).await;
1793
1794 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1795
1796 // Create the agent and connection
1797 let agent = NativeAgent::new(
1798 project.clone(),
1799 thread_store,
1800 Templates::new(),
1801 None,
1802 fs.clone(),
1803 &mut cx.to_async(),
1804 )
1805 .await
1806 .unwrap();
1807 let connection = NativeAgentConnection(agent.clone());
1808
1809 // Create a thread/session
1810 let acp_thread = cx
1811 .update(|cx| {
1812 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1813 })
1814 .await
1815 .unwrap();
1816
1817 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1818
1819 // Select a model
1820 let selector = connection.model_selector(&session_id).unwrap();
1821 let model_id = acp::ModelId::new("fake/fake");
1822 cx.update(|cx| selector.select_model(model_id.clone(), cx))
1823 .await
1824 .unwrap();
1825
1826 // Verify the thread has the selected model
1827 agent.read_with(cx, |agent, _| {
1828 let session = agent.sessions.get(&session_id).unwrap();
1829 session.thread.read_with(cx, |thread, _| {
1830 assert_eq!(thread.model().unwrap().id().0, "fake");
1831 });
1832 });
1833
1834 cx.run_until_parked();
1835
1836 // Verify settings file was updated
1837 let settings_content = fs.load(paths::settings_file()).await.unwrap();
1838 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1839
1840 // Check that the agent settings contain the selected model
1841 assert_eq!(
1842 settings_json["agent"]["default_model"]["model"],
1843 json!("fake")
1844 );
1845 assert_eq!(
1846 settings_json["agent"]["default_model"]["provider"],
1847 json!("fake")
1848 );
1849
1850 // Register a thinking model and select it.
1851 cx.update(|cx| {
1852 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
1853 "fake-corp",
1854 "fake-thinking",
1855 "Fake Thinking",
1856 true,
1857 ));
1858 let thinking_provider = Arc::new(
1859 FakeLanguageModelProvider::new(
1860 LanguageModelProviderId::from("fake-corp".to_string()),
1861 LanguageModelProviderName::from("Fake Corp".to_string()),
1862 )
1863 .with_models(vec![thinking_model]),
1864 );
1865 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1866 registry.register_provider(thinking_provider, cx);
1867 });
1868 });
1869 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
1870
1871 let selector = connection.model_selector(&session_id).unwrap();
1872 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
1873 .await
1874 .unwrap();
1875 cx.run_until_parked();
1876
1877 // Verify enable_thinking was written to settings as true.
1878 let settings_content = fs.load(paths::settings_file()).await.unwrap();
1879 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1880 assert_eq!(
1881 settings_json["agent"]["default_model"]["enable_thinking"],
1882 json!(true),
1883 "selecting a thinking model should persist enable_thinking: true to settings"
1884 );
1885 }
1886
1887 #[gpui::test]
1888 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
1889 init_test(cx);
1890 let fs = FakeFs::new(cx.executor());
1891 fs.create_dir(paths::settings_file().parent().unwrap())
1892 .await
1893 .unwrap();
1894 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
1895 let project = Project::test(fs.clone(), [], cx).await;
1896
1897 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1898 let agent = NativeAgent::new(
1899 project.clone(),
1900 thread_store,
1901 Templates::new(),
1902 None,
1903 fs.clone(),
1904 &mut cx.to_async(),
1905 )
1906 .await
1907 .unwrap();
1908 let connection = NativeAgentConnection(agent.clone());
1909
1910 let acp_thread = cx
1911 .update(|cx| {
1912 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1913 })
1914 .await
1915 .unwrap();
1916 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1917
1918 // Register a second provider with a thinking model.
1919 cx.update(|cx| {
1920 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
1921 "fake-corp",
1922 "fake-thinking",
1923 "Fake Thinking",
1924 true,
1925 ));
1926 let thinking_provider = Arc::new(
1927 FakeLanguageModelProvider::new(
1928 LanguageModelProviderId::from("fake-corp".to_string()),
1929 LanguageModelProviderName::from("Fake Corp".to_string()),
1930 )
1931 .with_models(vec![thinking_model]),
1932 );
1933 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1934 registry.register_provider(thinking_provider, cx);
1935 });
1936 });
1937 // Refresh the agent's model list so it picks up the new provider.
1938 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
1939
1940 // Thread starts with thinking_enabled = false (the default).
1941 agent.read_with(cx, |agent, _| {
1942 let session = agent.sessions.get(&session_id).unwrap();
1943 session.thread.read_with(cx, |thread, _| {
1944 assert!(!thread.thinking_enabled(), "thinking defaults to false");
1945 });
1946 });
1947
1948 // Select the thinking model via select_model.
1949 let selector = connection.model_selector(&session_id).unwrap();
1950 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
1951 .await
1952 .unwrap();
1953
1954 // select_model should have enabled thinking based on the model's supports_thinking().
1955 agent.read_with(cx, |agent, _| {
1956 let session = agent.sessions.get(&session_id).unwrap();
1957 session.thread.read_with(cx, |thread, _| {
1958 assert!(
1959 thread.thinking_enabled(),
1960 "select_model should enable thinking when model supports it"
1961 );
1962 });
1963 });
1964
1965 // Switch back to the non-thinking model.
1966 let selector = connection.model_selector(&session_id).unwrap();
1967 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
1968 .await
1969 .unwrap();
1970
1971 // select_model should have disabled thinking.
1972 agent.read_with(cx, |agent, _| {
1973 let session = agent.sessions.get(&session_id).unwrap();
1974 session.thread.read_with(cx, |thread, _| {
1975 assert!(
1976 !thread.thinking_enabled(),
1977 "select_model should disable thinking when model does not support it"
1978 );
1979 });
1980 });
1981 }
1982
1983 #[gpui::test]
1984 async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
1985 init_test(cx);
1986 let fs = FakeFs::new(cx.executor());
1987 fs.insert_tree("/", json!({ "a": {} })).await;
1988 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
1989 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1990 let agent = NativeAgent::new(
1991 project.clone(),
1992 thread_store.clone(),
1993 Templates::new(),
1994 None,
1995 fs.clone(),
1996 &mut cx.to_async(),
1997 )
1998 .await
1999 .unwrap();
2000 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2001
2002 // Register a thinking model.
2003 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2004 "fake-corp",
2005 "fake-thinking",
2006 "Fake Thinking",
2007 true,
2008 ));
2009 let thinking_provider = Arc::new(
2010 FakeLanguageModelProvider::new(
2011 LanguageModelProviderId::from("fake-corp".to_string()),
2012 LanguageModelProviderName::from("Fake Corp".to_string()),
2013 )
2014 .with_models(vec![thinking_model.clone()]),
2015 );
2016 cx.update(|cx| {
2017 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2018 registry.register_provider(thinking_provider, cx);
2019 });
2020 });
2021 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2022
2023 // Create a thread and select the thinking model.
2024 let acp_thread = cx
2025 .update(|cx| {
2026 connection
2027 .clone()
2028 .new_thread(project.clone(), Path::new("/a"), cx)
2029 })
2030 .await
2031 .unwrap();
2032 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2033
2034 let selector = connection.model_selector(&session_id).unwrap();
2035 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2036 .await
2037 .unwrap();
2038
2039 // Verify thinking is enabled after selecting the thinking model.
2040 let thread = agent.read_with(cx, |agent, _| {
2041 agent.sessions.get(&session_id).unwrap().thread.clone()
2042 });
2043 thread.read_with(cx, |thread, _| {
2044 assert!(
2045 thread.thinking_enabled(),
2046 "thinking should be enabled after selecting thinking model"
2047 );
2048 });
2049
2050 // Send a message so the thread gets persisted.
2051 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2052 let send = cx.foreground_executor().spawn(send);
2053 cx.run_until_parked();
2054
2055 thinking_model.send_last_completion_stream_text_chunk("Response.");
2056 thinking_model.end_last_completion_stream();
2057
2058 send.await.unwrap();
2059 cx.run_until_parked();
2060
2061 // Drop the thread so it can be reloaded from disk.
2062 cx.update(|_| {
2063 drop(thread);
2064 drop(acp_thread);
2065 });
2066 agent.read_with(cx, |agent, _| {
2067 assert!(agent.sessions.is_empty());
2068 });
2069
2070 // Reload the thread and verify thinking_enabled is still true.
2071 let reloaded_acp_thread = agent
2072 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2073 .await
2074 .unwrap();
2075 let reloaded_thread = agent.read_with(cx, |agent, _| {
2076 agent.sessions.get(&session_id).unwrap().thread.clone()
2077 });
2078 reloaded_thread.read_with(cx, |thread, _| {
2079 assert!(
2080 thread.thinking_enabled(),
2081 "thinking_enabled should be preserved when reloading a thread with a thinking model"
2082 );
2083 });
2084
2085 drop(reloaded_acp_thread);
2086 }
2087
2088 #[gpui::test]
2089 async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2090 init_test(cx);
2091 let fs = FakeFs::new(cx.executor());
2092 fs.insert_tree("/", json!({ "a": {} })).await;
2093 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2094 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2095 let agent = NativeAgent::new(
2096 project.clone(),
2097 thread_store.clone(),
2098 Templates::new(),
2099 None,
2100 fs.clone(),
2101 &mut cx.to_async(),
2102 )
2103 .await
2104 .unwrap();
2105 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2106
2107 // Register a model where id() != name(), like real Anthropic models
2108 // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2109 let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2110 "fake-corp",
2111 "custom-model-id",
2112 "Custom Model Display Name",
2113 false,
2114 ));
2115 let provider = Arc::new(
2116 FakeLanguageModelProvider::new(
2117 LanguageModelProviderId::from("fake-corp".to_string()),
2118 LanguageModelProviderName::from("Fake Corp".to_string()),
2119 )
2120 .with_models(vec![model.clone()]),
2121 );
2122 cx.update(|cx| {
2123 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2124 registry.register_provider(provider, cx);
2125 });
2126 });
2127 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2128
2129 // Create a thread and select the model.
2130 let acp_thread = cx
2131 .update(|cx| {
2132 connection
2133 .clone()
2134 .new_thread(project.clone(), Path::new("/a"), cx)
2135 })
2136 .await
2137 .unwrap();
2138 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2139
2140 let selector = connection.model_selector(&session_id).unwrap();
2141 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2142 .await
2143 .unwrap();
2144
2145 let thread = agent.read_with(cx, |agent, _| {
2146 agent.sessions.get(&session_id).unwrap().thread.clone()
2147 });
2148 thread.read_with(cx, |thread, _| {
2149 assert_eq!(
2150 thread.model().unwrap().id().0.as_ref(),
2151 "custom-model-id",
2152 "model should be set before persisting"
2153 );
2154 });
2155
2156 // Send a message so the thread gets persisted.
2157 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2158 let send = cx.foreground_executor().spawn(send);
2159 cx.run_until_parked();
2160
2161 model.send_last_completion_stream_text_chunk("Response.");
2162 model.end_last_completion_stream();
2163
2164 send.await.unwrap();
2165 cx.run_until_parked();
2166
2167 // Drop the thread so it can be reloaded from disk.
2168 cx.update(|_| {
2169 drop(thread);
2170 drop(acp_thread);
2171 });
2172 agent.read_with(cx, |agent, _| {
2173 assert!(agent.sessions.is_empty());
2174 });
2175
2176 // Reload the thread and verify the model was preserved.
2177 let reloaded_acp_thread = agent
2178 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2179 .await
2180 .unwrap();
2181 let reloaded_thread = agent.read_with(cx, |agent, _| {
2182 agent.sessions.get(&session_id).unwrap().thread.clone()
2183 });
2184 reloaded_thread.read_with(cx, |thread, _| {
2185 let reloaded_model = thread
2186 .model()
2187 .expect("model should be present after reload");
2188 assert_eq!(
2189 reloaded_model.id().0.as_ref(),
2190 "custom-model-id",
2191 "reloaded thread should have the same model, not fall back to the default"
2192 );
2193 });
2194
2195 drop(reloaded_acp_thread);
2196 }
2197
    /// End-to-end persistence test for a native agent thread.
    ///
    /// Checks that:
    /// - a thread that has only had its models configured is not written to the `ThreadStore`;
    /// - after one user/assistant exchange, the store holds a single entry whose title is the
    ///   text produced by the summarization model;
    /// - once every handle is dropped, the agent no longer tracks the session;
    /// - reopening the session via `open_thread` yields the same markdown transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Create a new session and grab the agent-side `Thread` backing it.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Setting the main and summarization models mutates the thread, but it has no
        // messages yet, so the store must still be empty afterwards.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message that mentions `b.md` via a file resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so it can make progress while the fake models are driven below.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant reply first, then feed the summarization model the text that
        // later shows up as the stored entry's title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // The mention is rendered as a markdown link to the file's mention URI.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // NOTE(review): presumably lets any pending save work finish before the handles
        // are dropped. Confirm.
        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The store holds exactly one entry, titled with the summarization model's output.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        // Reopening the session must reproduce the same transcript as before the drop.
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2335
2336 fn thread_entries(
2337 thread_store: &Entity<ThreadStore>,
2338 cx: &mut TestAppContext,
2339 ) -> Vec<(acp::SessionId, String)> {
2340 thread_store.read_with(cx, |store, _| {
2341 store
2342 .entries()
2343 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2344 .collect::<Vec<_>>()
2345 })
2346 }
2347
2348 fn init_test(cx: &mut TestAppContext) {
2349 env_logger::try_init().ok();
2350 cx.update(|cx| {
2351 let settings_store = SettingsStore::test(cx);
2352 cx.set_global(settings_store);
2353
2354 LanguageModelRegistry::test(cx);
2355 });
2356 }
2357}
2358
2359fn mcp_message_content_to_acp_content_block(
2360 content: context_server::types::MessageContent,
2361) -> acp::ContentBlock {
2362 match content {
2363 context_server::types::MessageContent::Text {
2364 text,
2365 annotations: _,
2366 } => text.into(),
2367 context_server::types::MessageContent::Image {
2368 data,
2369 mime_type,
2370 annotations: _,
2371 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2372 context_server::types::MessageContent::Audio {
2373 data,
2374 mime_type,
2375 annotations: _,
2376 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2377 context_server::types::MessageContent::Resource {
2378 resource,
2379 annotations: _,
2380 } => {
2381 let mut link =
2382 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2383 if let Some(mime_type) = resource.mime_type {
2384 link = link.mime_type(mime_type);
2385 }
2386 acp::ContentBlock::ResourceLink(link)
2387 }
2388 }
2389}