1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod templates;
7#[cfg(test)]
8mod tests;
9mod thread;
10mod thread_store;
11mod tool_permissions;
12mod tools;
13
14use context_server::ContextServerId;
15pub use db::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use thread_store::*;
20pub use tool_permissions::*;
21pub use tools::*;
22
23use acp_thread::{AcpThread, AgentModelSelector, UserMessageId};
24use agent_client_protocol as acp;
25use anyhow::{Context as _, Result, anyhow};
26use chrono::{DateTime, Utc};
27use collections::{HashMap, HashSet, IndexMap};
28use fs::Fs;
29use futures::channel::{mpsc, oneshot};
30use futures::future::Shared;
31use futures::{StreamExt, future};
32use gpui::{
33 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
34};
35use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
36use project::{Project, ProjectItem, ProjectPath, Worktree};
37use prompt_store::{
38 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
39 WorktreeContext,
40};
41use serde::{Deserialize, Serialize};
42use settings::{LanguageModelSelection, update_settings_file};
43use std::any::Any;
44use std::path::{Path, PathBuf};
45use std::rc::Rc;
46use std::sync::Arc;
47use util::ResultExt;
48use util::rel_path::RelPath;
49
/// A point-in-time snapshot of the project's worktrees, tagged with the time it was taken.
///
/// Serializable so it can be persisted or sent elsewhere; the consumers live outside this view.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was captured (UTC).
    pub timestamp: DateTime<Utc>,
}
55
/// Error produced when a worktree's rules file could not be loaded.
///
/// NOTE(review): in the code visible here, `build_project_context` currently discards these
/// errors (see its TODOs) rather than surfacing them to the user.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
59
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly so the session does not keep it alive; `NativeAgent::register_session`
    /// removes the session when this ACP thread is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Handle to the most recent save task spawned by `NativeAgent::save_thread`; each new save
    /// replaces it.
    pending_save: Task<()>,
    /// Subscriptions kept alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
69
/// Cache of the language models offered by authenticated, visible providers.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out by `watch()`; signalled after every list refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used to signal that the model list was refreshed.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; kept alive with the cache.
    _authenticate_all_providers_task: Task<()>,
}
79
80impl LanguageModels {
81 fn new(cx: &mut App) -> Self {
82 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
83
84 let mut this = Self {
85 models: HashMap::default(),
86 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
87 refresh_models_rx,
88 refresh_models_tx,
89 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
90 };
91 this.refresh_list(cx);
92 this
93 }
94
95 fn refresh_list(&mut self, cx: &App) {
96 let providers = LanguageModelRegistry::global(cx)
97 .read(cx)
98 .visible_providers()
99 .into_iter()
100 .filter(|provider| provider.is_authenticated(cx))
101 .collect::<Vec<_>>();
102
103 let mut language_model_list = IndexMap::default();
104 let mut recommended_models = HashSet::default();
105
106 let mut recommended = Vec::new();
107 for provider in &providers {
108 for model in provider.recommended_models(cx) {
109 recommended_models.insert((model.provider_id(), model.id()));
110 recommended.push(Self::map_language_model_to_info(&model, provider));
111 }
112 }
113 if !recommended.is_empty() {
114 language_model_list.insert(
115 acp_thread::AgentModelGroupName("Recommended".into()),
116 recommended,
117 );
118 }
119
120 let mut models = HashMap::default();
121 for provider in providers {
122 let mut provider_models = Vec::new();
123 for model in provider.provided_models(cx) {
124 let model_info = Self::map_language_model_to_info(&model, &provider);
125 let model_id = model_info.id.clone();
126 provider_models.push(model_info);
127 models.insert(model_id, model);
128 }
129 if !provider_models.is_empty() {
130 language_model_list.insert(
131 acp_thread::AgentModelGroupName(provider.name().0.clone()),
132 provider_models,
133 );
134 }
135 }
136
137 self.models = models;
138 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
139 self.refresh_models_tx.send(()).ok();
140 }
141
142 fn watch(&self) -> watch::Receiver<()> {
143 self.refresh_models_rx.clone()
144 }
145
146 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
147 self.models.get(model_id).cloned()
148 }
149
150 fn map_language_model_to_info(
151 model: &Arc<dyn LanguageModel>,
152 provider: &Arc<dyn LanguageModelProvider>,
153 ) -> acp_thread::AgentModelInfo {
154 acp_thread::AgentModelInfo {
155 id: Self::model_id(model),
156 name: model.name().0,
157 description: None,
158 icon: Some(match provider.icon() {
159 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
160 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
161 }),
162 }
163 }
164
165 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
166 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
167 }
168
169 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
170 let authenticate_all_providers = LanguageModelRegistry::global(cx)
171 .read(cx)
172 .visible_providers()
173 .iter()
174 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
175 .collect::<Vec<_>>();
176
177 cx.background_spawn(async move {
178 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
179 if let Err(err) = authenticate_task.await {
180 match err {
181 language_model::AuthenticateError::CredentialsNotFound => {
182 // Since we're authenticating these providers in the
183 // background for the purposes of populating the
184 // language selector, we don't care about providers
185 // where the credentials are not found.
186 }
187 language_model::AuthenticateError::ConnectionRefused => {
188 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
189 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
190 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
191 }
192 _ => {
193 // Some providers have noisy failure states that we
194 // don't want to spam the logs with every time the
195 // language model selector is initialized.
196 //
197 // Ideally these should have more clear failure modes
198 // that we know are safe to ignore here, like what we do
199 // with `CredentialsNotFound` above.
200 match provider_id.0.as_ref() {
201 "lmstudio" | "ollama" => {
202 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
203 //
204 // These fail noisily, so we don't log them.
205 }
206 "copilot_chat" => {
207 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
208 }
209 _ => {
210 log::error!(
211 "Failed to authenticate provider: {}: {err:#}",
212 provider_name.0
213 );
214 }
215 }
216 }
217 }
218 }
219 }
220 })
221 }
222}
223
/// The built-in agent: owns all native sessions and the state they share (project context,
/// context servers, templates, and cached language models).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store that is reloaded after each successful thread save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled when worktrees, rules files, or prompts change and the project context
    /// should be rebuilt.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds the project context whenever a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers; source of MCP prompts exposed as slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store providing default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle, used e.g. to persist model selection in settings.
    fs: Arc<dyn Fs>,
    /// Subscriptions kept alive for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
242
243impl NativeAgent {
244 pub async fn new(
245 project: Entity<Project>,
246 thread_store: Entity<ThreadStore>,
247 templates: Arc<Templates>,
248 prompt_store: Option<Entity<PromptStore>>,
249 fs: Arc<dyn Fs>,
250 cx: &mut AsyncApp,
251 ) -> Result<Entity<NativeAgent>> {
252 log::debug!("Creating new NativeAgent");
253
254 let project_context = cx
255 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
256 .await;
257
258 Ok(cx.new(|cx| {
259 let context_server_store = project.read(cx).context_server_store();
260 let context_server_registry =
261 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
262
263 let mut subscriptions = vec![
264 cx.subscribe(&project, Self::handle_project_event),
265 cx.subscribe(
266 &LanguageModelRegistry::global(cx),
267 Self::handle_models_updated_event,
268 ),
269 cx.subscribe(
270 &context_server_store,
271 Self::handle_context_server_store_updated,
272 ),
273 cx.subscribe(
274 &context_server_registry,
275 Self::handle_context_server_registry_event,
276 ),
277 ];
278 if let Some(prompt_store) = prompt_store.as_ref() {
279 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
280 }
281
282 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
283 watch::channel(());
284 Self {
285 sessions: HashMap::default(),
286 thread_store,
287 project_context: cx.new(|_| project_context),
288 project_context_needs_refresh: project_context_needs_refresh_tx,
289 _maintain_project_context: cx.spawn(async move |this, cx| {
290 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
291 }),
292 context_server_registry,
293 templates,
294 models: LanguageModels::new(cx),
295 project,
296 prompt_store,
297 fs,
298 _subscriptions: subscriptions,
299 }
300 }))
301 }
302
303 fn register_session(
304 &mut self,
305 thread_handle: Entity<Thread>,
306 cx: &mut Context<Self>,
307 ) -> Entity<AcpThread> {
308 let connection = Rc::new(NativeAgentConnection(cx.entity()));
309
310 let thread = thread_handle.read(cx);
311 let session_id = thread.id().clone();
312 let title = thread.title();
313 let project = thread.project.clone();
314 let action_log = thread.action_log.clone();
315 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
316 let acp_thread = cx.new(|cx| {
317 acp_thread::AcpThread::new(
318 title,
319 connection,
320 project.clone(),
321 action_log.clone(),
322 session_id.clone(),
323 prompt_capabilities_rx,
324 cx,
325 )
326 });
327
328 let registry = LanguageModelRegistry::read_global(cx);
329 let summarization_model = registry.thread_summary_model().map(|c| c.model);
330
331 thread_handle.update(cx, |thread, cx| {
332 thread.set_summarization_model(summarization_model, cx);
333 thread.add_default_tools(
334 Rc::new(AcpThreadEnvironment {
335 acp_thread: acp_thread.downgrade(),
336 }) as _,
337 cx,
338 )
339 });
340
341 let subscriptions = vec![
342 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
343 this.sessions.remove(acp_thread.session_id());
344 }),
345 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
346 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
347 cx.observe(&thread_handle, move |this, thread, cx| {
348 this.save_thread(thread, cx)
349 }),
350 ];
351
352 self.sessions.insert(
353 session_id,
354 Session {
355 thread: thread_handle,
356 acp_thread: acp_thread.downgrade(),
357 _subscriptions: subscriptions,
358 pending_save: Task::ready(()),
359 },
360 );
361
362 self.update_available_commands(cx);
363
364 acp_thread
365 }
366
367 pub fn models(&self) -> &LanguageModels {
368 &self.models
369 }
370
371 async fn maintain_project_context(
372 this: WeakEntity<Self>,
373 mut needs_refresh: watch::Receiver<()>,
374 cx: &mut AsyncApp,
375 ) -> Result<()> {
376 while needs_refresh.changed().await.is_ok() {
377 let project_context = this
378 .update(cx, |this, cx| {
379 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
380 })?
381 .await;
382 this.update(cx, |this, cx| {
383 this.project_context = cx.new(|_| project_context);
384 })?;
385 }
386
387 Ok(())
388 }
389
390 fn build_project_context(
391 project: &Entity<Project>,
392 prompt_store: Option<&Entity<PromptStore>>,
393 cx: &mut App,
394 ) -> Task<ProjectContext> {
395 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
396 let worktree_tasks = worktrees
397 .into_iter()
398 .map(|worktree| {
399 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
400 })
401 .collect::<Vec<_>>();
402 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
403 prompt_store.read_with(cx, |prompt_store, cx| {
404 let prompts = prompt_store.default_prompt_metadata();
405 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
406 let contents = prompt_store.load(prompt_metadata.id, cx);
407 async move { (contents.await, prompt_metadata) }
408 });
409 cx.background_spawn(future::join_all(load_tasks))
410 })
411 } else {
412 Task::ready(vec![])
413 };
414
415 cx.spawn(async move |_cx| {
416 let (worktrees, default_user_rules) =
417 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
418
419 let worktrees = worktrees
420 .into_iter()
421 .map(|(worktree, _rules_error)| {
422 // TODO: show error message
423 // if let Some(rules_error) = rules_error {
424 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
425 // }
426 worktree
427 })
428 .collect::<Vec<_>>();
429
430 let default_user_rules = default_user_rules
431 .into_iter()
432 .flat_map(|(contents, prompt_metadata)| match contents {
433 Ok(contents) => Some(UserRulesContext {
434 uuid: prompt_metadata.id.as_user()?,
435 title: prompt_metadata.title.map(|title| title.to_string()),
436 contents,
437 }),
438 Err(_err) => {
439 // TODO: show error message
440 // this.update(cx, |_, cx| {
441 // cx.emit(RulesLoadingError {
442 // message: format!("{err:?}").into(),
443 // });
444 // })
445 // .ok();
446 None
447 }
448 })
449 .collect::<Vec<_>>();
450
451 ProjectContext::new(worktrees, default_user_rules)
452 })
453 }
454
455 fn load_worktree_info_for_system_prompt(
456 worktree: Entity<Worktree>,
457 project: Entity<Project>,
458 cx: &mut App,
459 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
460 let tree = worktree.read(cx);
461 let root_name = tree.root_name_str().into();
462 let abs_path = tree.abs_path();
463
464 let mut context = WorktreeContext {
465 root_name,
466 abs_path,
467 rules_file: None,
468 };
469
470 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
471 let Some(rules_task) = rules_task else {
472 return Task::ready((context, None));
473 };
474
475 cx.spawn(async move |_| {
476 let (rules_file, rules_file_error) = match rules_task.await {
477 Ok(rules_file) => (Some(rules_file), None),
478 Err(err) => (
479 None,
480 Some(RulesLoadingError {
481 message: format!("{err}").into(),
482 }),
483 ),
484 };
485 context.rules_file = rules_file;
486 (context, rules_file_error)
487 })
488 }
489
490 fn load_worktree_rules_file(
491 worktree: Entity<Worktree>,
492 project: Entity<Project>,
493 cx: &mut App,
494 ) -> Option<Task<Result<RulesFileContext>>> {
495 let worktree = worktree.read(cx);
496 let worktree_id = worktree.id();
497 let selected_rules_file = RULES_FILE_NAMES
498 .into_iter()
499 .filter_map(|name| {
500 worktree
501 .entry_for_path(RelPath::unix(name).unwrap())
502 .filter(|entry| entry.is_file())
503 .map(|entry| entry.path.clone())
504 })
505 .next();
506
507 // Note that Cline supports `.clinerules` being a directory, but that is not currently
508 // supported. This doesn't seem to occur often in GitHub repositories.
509 selected_rules_file.map(|path_in_worktree| {
510 let project_path = ProjectPath {
511 worktree_id,
512 path: path_in_worktree.clone(),
513 };
514 let buffer_task =
515 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
516 let rope_task = cx.spawn(async move |cx| {
517 let buffer = buffer_task.await?;
518 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
519 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
520 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
521 })?;
522 anyhow::Ok((project_entry_id, rope))
523 });
524 // Build a string from the rope on a background thread.
525 cx.background_spawn(async move {
526 let (project_entry_id, rope) = rope_task.await?;
527 anyhow::Ok(RulesFileContext {
528 path_in_worktree,
529 text: rope.to_string().trim().to_string(),
530 project_entry_id: project_entry_id.to_usize(),
531 })
532 })
533 })
534 }
535
536 fn handle_thread_title_updated(
537 &mut self,
538 thread: Entity<Thread>,
539 _: &TitleUpdated,
540 cx: &mut Context<Self>,
541 ) {
542 let session_id = thread.read(cx).id();
543 let Some(session) = self.sessions.get(session_id) else {
544 return;
545 };
546 let thread = thread.downgrade();
547 let acp_thread = session.acp_thread.clone();
548 cx.spawn(async move |_, cx| {
549 let title = thread.read_with(cx, |thread, _| thread.title())?;
550 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
551 task.await
552 })
553 .detach_and_log_err(cx);
554 }
555
556 fn handle_thread_token_usage_updated(
557 &mut self,
558 thread: Entity<Thread>,
559 usage: &TokenUsageUpdated,
560 cx: &mut Context<Self>,
561 ) {
562 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
563 return;
564 };
565 session
566 .acp_thread
567 .update(cx, |acp_thread, cx| {
568 acp_thread.update_token_usage(usage.0.clone(), cx);
569 })
570 .ok();
571 }
572
573 fn handle_project_event(
574 &mut self,
575 _project: Entity<Project>,
576 event: &project::Event,
577 _cx: &mut Context<Self>,
578 ) {
579 match event {
580 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
581 self.project_context_needs_refresh.send(()).ok();
582 }
583 project::Event::WorktreeUpdatedEntries(_, items) => {
584 if items.iter().any(|(path, _, _)| {
585 RULES_FILE_NAMES
586 .iter()
587 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
588 }) {
589 self.project_context_needs_refresh.send(()).ok();
590 }
591 }
592 _ => {}
593 }
594 }
595
596 fn handle_prompts_updated_event(
597 &mut self,
598 _prompt_store: Entity<PromptStore>,
599 _event: &prompt_store::PromptsUpdatedEvent,
600 _cx: &mut Context<Self>,
601 ) {
602 self.project_context_needs_refresh.send(()).ok();
603 }
604
605 fn handle_models_updated_event(
606 &mut self,
607 _registry: Entity<LanguageModelRegistry>,
608 _event: &language_model::Event,
609 cx: &mut Context<Self>,
610 ) {
611 self.models.refresh_list(cx);
612
613 let registry = LanguageModelRegistry::read_global(cx);
614 let default_model = registry.default_model().map(|m| m.model);
615 let summarization_model = registry.thread_summary_model().map(|m| m.model);
616
617 for session in self.sessions.values_mut() {
618 session.thread.update(cx, |thread, cx| {
619 if thread.model().is_none()
620 && let Some(model) = default_model.clone()
621 {
622 thread.set_model(model, cx);
623 cx.notify();
624 }
625 thread.set_summarization_model(summarization_model.clone(), cx);
626 });
627 }
628 }
629
630 fn handle_context_server_store_updated(
631 &mut self,
632 _store: Entity<project::context_server_store::ContextServerStore>,
633 _event: &project::context_server_store::Event,
634 cx: &mut Context<Self>,
635 ) {
636 self.update_available_commands(cx);
637 }
638
639 fn handle_context_server_registry_event(
640 &mut self,
641 _registry: Entity<ContextServerRegistry>,
642 event: &ContextServerRegistryEvent,
643 cx: &mut Context<Self>,
644 ) {
645 match event {
646 ContextServerRegistryEvent::ToolsChanged => {}
647 ContextServerRegistryEvent::PromptsChanged => {
648 self.update_available_commands(cx);
649 }
650 }
651 }
652
653 fn update_available_commands(&self, cx: &mut Context<Self>) {
654 let available_commands = self.build_available_commands(cx);
655 for session in self.sessions.values() {
656 if let Some(acp_thread) = session.acp_thread.upgrade() {
657 acp_thread.update(cx, |thread, cx| {
658 thread
659 .handle_session_update(
660 acp::SessionUpdate::AvailableCommandsUpdate(
661 acp::AvailableCommandsUpdate::new(available_commands.clone()),
662 ),
663 cx,
664 )
665 .log_err();
666 });
667 }
668 }
669 }
670
671 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
672 let registry = self.context_server_registry.read(cx);
673
674 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
675 for context_server_prompt in registry.prompts() {
676 *prompt_name_counts
677 .entry(context_server_prompt.prompt.name.as_str())
678 .or_insert(0) += 1;
679 }
680
681 registry
682 .prompts()
683 .flat_map(|context_server_prompt| {
684 let prompt = &context_server_prompt.prompt;
685
686 let should_prefix = prompt_name_counts
687 .get(prompt.name.as_str())
688 .copied()
689 .unwrap_or(0)
690 > 1;
691
692 let name = if should_prefix {
693 format!("{}.{}", context_server_prompt.server_id, prompt.name)
694 } else {
695 prompt.name.clone()
696 };
697
698 let mut command = acp::AvailableCommand::new(
699 name,
700 prompt.description.clone().unwrap_or_default(),
701 );
702
703 match prompt.arguments.as_deref() {
704 Some([arg]) => {
705 let hint = format!("<{}>", arg.name);
706
707 command = command.input(acp::AvailableCommandInput::Unstructured(
708 acp::UnstructuredCommandInput::new(hint),
709 ));
710 }
711 Some([]) | None => {}
712 Some(_) => {
713 // skip >1 argument commands since we don't support them yet
714 return None;
715 }
716 }
717
718 Some(command)
719 })
720 .collect()
721 }
722
723 pub fn load_thread(
724 &mut self,
725 id: acp::SessionId,
726 cx: &mut Context<Self>,
727 ) -> Task<Result<Entity<Thread>>> {
728 let database_future = ThreadsDatabase::connect(cx);
729 cx.spawn(async move |this, cx| {
730 let database = database_future.await.map_err(|err| anyhow!(err))?;
731 let db_thread = database
732 .load_thread(id.clone())
733 .await?
734 .with_context(|| format!("no thread found with ID: {id:?}"))?;
735
736 this.update(cx, |this, cx| {
737 let summarization_model = LanguageModelRegistry::read_global(cx)
738 .thread_summary_model()
739 .map(|c| c.model);
740
741 cx.new(|cx| {
742 let mut thread = Thread::from_db(
743 id.clone(),
744 db_thread,
745 this.project.clone(),
746 this.project_context.clone(),
747 this.context_server_registry.clone(),
748 this.templates.clone(),
749 cx,
750 );
751 thread.set_summarization_model(summarization_model, cx);
752 thread
753 })
754 })
755 })
756 }
757
758 pub fn open_thread(
759 &mut self,
760 id: acp::SessionId,
761 cx: &mut Context<Self>,
762 ) -> Task<Result<Entity<AcpThread>>> {
763 let task = self.load_thread(id, cx);
764 cx.spawn(async move |this, cx| {
765 let thread = task.await?;
766 let acp_thread =
767 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
768 let events = thread.update(cx, |thread, cx| thread.replay(cx));
769 cx.update(|cx| {
770 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
771 })
772 .await?;
773 Ok(acp_thread)
774 })
775 }
776
777 pub fn thread_summary(
778 &mut self,
779 id: acp::SessionId,
780 cx: &mut Context<Self>,
781 ) -> Task<Result<SharedString>> {
782 let thread = self.open_thread(id.clone(), cx);
783 cx.spawn(async move |this, cx| {
784 let acp_thread = thread.await?;
785 let result = this
786 .update(cx, |this, cx| {
787 this.sessions
788 .get(&id)
789 .unwrap()
790 .thread
791 .update(cx, |thread, cx| thread.summary(cx))
792 })?
793 .await
794 .context("Failed to generate summary")?;
795 drop(acp_thread);
796 Ok(result)
797 })
798 }
799
800 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
801 if thread.read(cx).is_empty() {
802 return;
803 }
804
805 let database_future = ThreadsDatabase::connect(cx);
806 let (id, db_thread) =
807 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
808 let Some(session) = self.sessions.get_mut(&id) else {
809 return;
810 };
811 let thread_store = self.thread_store.clone();
812 session.pending_save = cx.spawn(async move |_, cx| {
813 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
814 return;
815 };
816 let db_thread = db_thread.await;
817 database.save_thread(id, db_thread).await.log_err();
818 thread_store.update(cx, |store, cx| store.reload(cx));
819 });
820 }
821
822 fn send_mcp_prompt(
823 &self,
824 message_id: UserMessageId,
825 session_id: agent_client_protocol::SessionId,
826 prompt_name: String,
827 server_id: ContextServerId,
828 arguments: HashMap<String, String>,
829 original_content: Vec<acp::ContentBlock>,
830 cx: &mut Context<Self>,
831 ) -> Task<Result<acp::PromptResponse>> {
832 let server_store = self.context_server_registry.read(cx).server_store().clone();
833 let path_style = self.project.read(cx).path_style(cx);
834
835 cx.spawn(async move |this, cx| {
836 let prompt =
837 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
838
839 let (acp_thread, thread) = this.update(cx, |this, _cx| {
840 let session = this
841 .sessions
842 .get(&session_id)
843 .context("Failed to get session")?;
844 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
845 })??;
846
847 let mut last_is_user = true;
848
849 thread.update(cx, |thread, cx| {
850 thread.push_acp_user_block(
851 message_id,
852 original_content.into_iter().skip(1),
853 path_style,
854 cx,
855 );
856 });
857
858 for message in prompt.messages {
859 let context_server::types::PromptMessage { role, content } = message;
860 let block = mcp_message_content_to_acp_content_block(content);
861
862 match role {
863 context_server::types::Role::User => {
864 let id = acp_thread::UserMessageId::new();
865
866 acp_thread.update(cx, |acp_thread, cx| {
867 acp_thread.push_user_content_block_with_indent(
868 Some(id.clone()),
869 block.clone(),
870 true,
871 cx,
872 );
873 })?;
874
875 thread.update(cx, |thread, cx| {
876 thread.push_acp_user_block(id, [block], path_style, cx);
877 });
878 }
879 context_server::types::Role::Assistant => {
880 acp_thread.update(cx, |acp_thread, cx| {
881 acp_thread.push_assistant_content_block_with_indent(
882 block.clone(),
883 false,
884 true,
885 cx,
886 );
887 })?;
888
889 thread.update(cx, |thread, cx| {
890 thread.push_acp_agent_block(block, cx);
891 });
892 }
893 }
894
895 last_is_user = role == context_server::types::Role::User;
896 }
897
898 let response_stream = thread.update(cx, |thread, cx| {
899 if last_is_user {
900 thread.send_existing(cx)
901 } else {
902 // Resume if MCP prompt did not end with a user message
903 thread.resume(cx)
904 }
905 })?;
906
907 cx.update(|cx| {
908 NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
909 })
910 .await
911 })
912 }
913}
914
/// Wrapper struct that implements the AgentConnection trait
///
/// Cloning is cheap: it only clones the handle to the shared [`NativeAgent`] entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
918
919impl NativeAgentConnection {
920 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
921 self.0
922 .read(cx)
923 .sessions
924 .get(session_id)
925 .map(|session| session.thread.clone())
926 }
927
928 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
929 self.0.update(cx, |this, cx| this.load_thread(id, cx))
930 }
931
932 fn run_turn(
933 &self,
934 session_id: acp::SessionId,
935 cx: &mut App,
936 f: impl 'static
937 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
938 ) -> Task<Result<acp::PromptResponse>> {
939 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
940 agent
941 .sessions
942 .get_mut(&session_id)
943 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
944 }) else {
945 return Task::ready(Err(anyhow!("Session not found")));
946 };
947 log::debug!("Found session for: {}", session_id);
948
949 let response_stream = match f(thread, cx) {
950 Ok(stream) => stream,
951 Err(err) => return Task::ready(Err(err)),
952 };
953 Self::handle_thread_events(response_stream, acp_thread, cx)
954 }
955
956 fn handle_thread_events(
957 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
958 acp_thread: WeakEntity<AcpThread>,
959 cx: &App,
960 ) -> Task<Result<acp::PromptResponse>> {
961 cx.spawn(async move |cx| {
962 // Handle response stream and forward to session.acp_thread
963 while let Some(result) = events.next().await {
964 match result {
965 Ok(event) => {
966 log::trace!("Received completion event: {:?}", event);
967
968 match event {
969 ThreadEvent::UserMessage(message) => {
970 acp_thread.update(cx, |thread, cx| {
971 for content in message.content {
972 thread.push_user_content_block(
973 Some(message.id.clone()),
974 content.into(),
975 cx,
976 );
977 }
978 })?;
979 }
980 ThreadEvent::AgentText(text) => {
981 acp_thread.update(cx, |thread, cx| {
982 thread.push_assistant_content_block(text.into(), false, cx)
983 })?;
984 }
985 ThreadEvent::AgentThinking(text) => {
986 acp_thread.update(cx, |thread, cx| {
987 thread.push_assistant_content_block(text.into(), true, cx)
988 })?;
989 }
990 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
991 tool_call,
992 options,
993 response,
994 }) => {
995 let outcome_task = acp_thread.update(cx, |thread, cx| {
996 thread.request_tool_call_authorization(
997 tool_call, options, true, cx,
998 )
999 })??;
1000 cx.background_spawn(async move {
1001 if let acp::RequestPermissionOutcome::Selected(
1002 acp::SelectedPermissionOutcome { option_id, .. },
1003 ) = outcome_task.await
1004 {
1005 response
1006 .send(option_id)
1007 .map(|_| anyhow!("authorization receiver was dropped"))
1008 .log_err();
1009 }
1010 })
1011 .detach();
1012 }
1013 ThreadEvent::ToolCall(tool_call) => {
1014 acp_thread.update(cx, |thread, cx| {
1015 thread.upsert_tool_call(tool_call, cx)
1016 })??;
1017 }
1018 ThreadEvent::ToolCallUpdate(update) => {
1019 acp_thread.update(cx, |thread, cx| {
1020 thread.update_tool_call(update, cx)
1021 })??;
1022 }
1023 ThreadEvent::Retry(status) => {
1024 acp_thread.update(cx, |thread, cx| {
1025 thread.update_retry_status(status, cx)
1026 })?;
1027 }
1028 ThreadEvent::Stop(stop_reason) => {
1029 log::debug!("Assistant message complete: {:?}", stop_reason);
1030 return Ok(acp::PromptResponse::new(stop_reason));
1031 }
1032 }
1033 }
1034 Err(e) => {
1035 log::error!("Error in model response stream: {:?}", e);
1036 return Err(e);
1037 }
1038 }
1039 }
1040
1041 log::debug!("Response stream completed");
1042 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1043 })
1044 }
1045}
1046
/// A parsed slash command of the form `/[server_id.]prompt_name [argument text]`, borrowing
/// from the prompt's first text block.
struct Command<'a> {
    /// Prompt name (the part after the first `.`, if a server id was given).
    prompt_name: &'a str,
    /// Everything after the first whitespace following the command; empty if there is none.
    arg_value: &'a str,
    /// Server id given before the first `.`, if any.
    explicit_server_id: Option<&'a str>,
}
1052
1053impl<'a> Command<'a> {
1054 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1055 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1056 return None;
1057 };
1058 let text = text_content.text.trim();
1059 let command = text.strip_prefix('/')?;
1060 let (command, arg_value) = command
1061 .split_once(char::is_whitespace)
1062 .unwrap_or((command, ""));
1063
1064 if let Some((server_id, prompt_name)) = command.split_once('.') {
1065 Some(Self {
1066 prompt_name,
1067 arg_value,
1068 explicit_server_id: Some(server_id),
1069 })
1070 } else {
1071 Some(Self {
1072 prompt_name: command,
1073 arg_value,
1074 explicit_server_id: None,
1075 })
1076 }
1077 }
1078}
1079
/// Model selector bound to a single native-agent session.
struct NativeAgentModelSelector {
    /// Session whose thread model is listed/selected.
    session_id: acp::SessionId,
    /// Connection to the agent that owns the session.
    connection: NativeAgentConnection,
}
1084
1085impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1086 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1087 log::debug!("NativeAgentConnection::list_models called");
1088 let list = self.connection.0.read(cx).models.model_list.clone();
1089 Task::ready(if list.is_empty() {
1090 Err(anyhow::anyhow!("No models available"))
1091 } else {
1092 Ok(list)
1093 })
1094 }
1095
1096 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1097 log::debug!(
1098 "Setting model for session {}: {}",
1099 self.session_id,
1100 model_id
1101 );
1102 let Some(thread) = self
1103 .connection
1104 .0
1105 .read(cx)
1106 .sessions
1107 .get(&self.session_id)
1108 .map(|session| session.thread.clone())
1109 else {
1110 return Task::ready(Err(anyhow!("Session not found")));
1111 };
1112
1113 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1114 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1115 };
1116
1117 thread.update(cx, |thread, cx| {
1118 thread.set_model(model.clone(), cx);
1119 });
1120
1121 update_settings_file(
1122 self.connection.0.read(cx).fs.clone(),
1123 cx,
1124 move |settings, _cx| {
1125 let provider = model.provider_id().0.to_string();
1126 let model = model.id().0.to_string();
1127 settings
1128 .agent
1129 .get_or_insert_default()
1130 .set_model(LanguageModelSelection {
1131 provider: provider.into(),
1132 model,
1133 });
1134 },
1135 );
1136
1137 Task::ready(Ok(()))
1138 }
1139
1140 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1141 let Some(thread) = self
1142 .connection
1143 .0
1144 .read(cx)
1145 .sessions
1146 .get(&self.session_id)
1147 .map(|session| session.thread.clone())
1148 else {
1149 return Task::ready(Err(anyhow!("Session not found")));
1150 };
1151 let Some(model) = thread.read(cx).model() else {
1152 return Task::ready(Err(anyhow!("Model not found")));
1153 };
1154 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1155 else {
1156 return Task::ready(Err(anyhow!("Provider not found")));
1157 };
1158 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1159 model, &provider,
1160 )))
1161 }
1162
1163 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1164 Some(self.connection.0.read(cx).models.watch())
1165 }
1166
1167 fn should_render_footer(&self) -> bool {
1168 true
1169 }
1170}
1171
1172impl acp_thread::AgentConnection for NativeAgentConnection {
1173 fn telemetry_id(&self) -> SharedString {
1174 "zed".into()
1175 }
1176
1177 fn new_thread(
1178 self: Rc<Self>,
1179 project: Entity<Project>,
1180 cwd: &Path,
1181 cx: &mut App,
1182 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1183 let agent = self.0.clone();
1184 log::debug!("Creating new thread for project at: {:?}", cwd);
1185
1186 cx.spawn(async move |cx| {
1187 log::debug!("Starting thread creation in async context");
1188
1189 // Create Thread
1190 let thread = agent.update(cx, |agent, cx| {
1191 // Fetch default model from registry settings
1192 let registry = LanguageModelRegistry::read_global(cx);
1193 // Log available models for debugging
1194 let available_count = registry.available_models(cx).count();
1195 log::debug!("Total available models: {}", available_count);
1196
1197 let default_model = registry.default_model().and_then(|default_model| {
1198 agent
1199 .models
1200 .model_from_id(&LanguageModels::model_id(&default_model.model))
1201 });
1202 cx.new(|cx| {
1203 Thread::new(
1204 project.clone(),
1205 agent.project_context.clone(),
1206 agent.context_server_registry.clone(),
1207 agent.templates.clone(),
1208 default_model,
1209 cx,
1210 )
1211 })
1212 });
1213 Ok(agent.update(cx, |agent, cx| agent.register_session(thread, cx)))
1214 })
1215 }
1216
1217 fn auth_methods(&self) -> &[acp::AuthMethod] {
1218 &[] // No auth for in-process
1219 }
1220
1221 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1222 Task::ready(Ok(()))
1223 }
1224
1225 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1226 Some(Rc::new(NativeAgentModelSelector {
1227 session_id: session_id.clone(),
1228 connection: self.clone(),
1229 }) as Rc<dyn AgentModelSelector>)
1230 }
1231
1232 fn prompt(
1233 &self,
1234 id: Option<acp_thread::UserMessageId>,
1235 params: acp::PromptRequest,
1236 cx: &mut App,
1237 ) -> Task<Result<acp::PromptResponse>> {
1238 let id = id.expect("UserMessageId is required");
1239 let session_id = params.session_id.clone();
1240 log::info!("Received prompt request for session: {}", session_id);
1241 log::debug!("Prompt blocks count: {}", params.prompt.len());
1242
1243 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1244 let registry = self.0.read(cx).context_server_registry.read(cx);
1245
1246 let explicit_server_id = parsed_command
1247 .explicit_server_id
1248 .map(|server_id| ContextServerId(server_id.into()));
1249
1250 if let Some(prompt) =
1251 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1252 {
1253 let arguments = if !parsed_command.arg_value.is_empty()
1254 && let Some(arg_name) = prompt
1255 .prompt
1256 .arguments
1257 .as_ref()
1258 .and_then(|args| args.first())
1259 .map(|arg| arg.name.clone())
1260 {
1261 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1262 } else {
1263 Default::default()
1264 };
1265
1266 let prompt_name = prompt.prompt.name.clone();
1267 let server_id = prompt.server_id.clone();
1268
1269 return self.0.update(cx, |agent, cx| {
1270 agent.send_mcp_prompt(
1271 id,
1272 session_id.clone(),
1273 prompt_name,
1274 server_id,
1275 arguments,
1276 params.prompt,
1277 cx,
1278 )
1279 });
1280 };
1281 };
1282
1283 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1284
1285 self.run_turn(session_id, cx, move |thread, cx| {
1286 let content: Vec<UserMessageContent> = params
1287 .prompt
1288 .into_iter()
1289 .map(|block| UserMessageContent::from_content_block(block, path_style))
1290 .collect::<Vec<_>>();
1291 log::debug!("Converted prompt to message: {} chars", content.len());
1292 log::debug!("Message id: {:?}", id);
1293 log::debug!("Message content: {:?}", content);
1294
1295 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1296 })
1297 }
1298
1299 fn resume(
1300 &self,
1301 session_id: &acp::SessionId,
1302 _cx: &App,
1303 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1304 Some(Rc::new(NativeAgentSessionResume {
1305 connection: self.clone(),
1306 session_id: session_id.clone(),
1307 }) as _)
1308 }
1309
1310 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1311 log::info!("Cancelling on session: {}", session_id);
1312 self.0.update(cx, |agent, cx| {
1313 if let Some(agent) = agent.sessions.get(session_id) {
1314 agent
1315 .thread
1316 .update(cx, |thread, cx| thread.cancel(cx))
1317 .detach();
1318 }
1319 });
1320 }
1321
1322 fn truncate(
1323 &self,
1324 session_id: &agent_client_protocol::SessionId,
1325 cx: &App,
1326 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1327 self.0.read_with(cx, |agent, _cx| {
1328 agent.sessions.get(session_id).map(|session| {
1329 Rc::new(NativeAgentSessionTruncate {
1330 thread: session.thread.clone(),
1331 acp_thread: session.acp_thread.clone(),
1332 }) as _
1333 })
1334 })
1335 }
1336
1337 fn set_title(
1338 &self,
1339 session_id: &acp::SessionId,
1340 _cx: &App,
1341 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1342 Some(Rc::new(NativeAgentSessionSetTitle {
1343 connection: self.clone(),
1344 session_id: session_id.clone(),
1345 }) as _)
1346 }
1347
1348 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1349 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1350 }
1351
1352 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1353 self
1354 }
1355}
1356
1357impl acp_thread::AgentTelemetry for NativeAgentConnection {
1358 fn thread_data(
1359 &self,
1360 session_id: &acp::SessionId,
1361 cx: &mut App,
1362 ) -> Task<Result<serde_json::Value>> {
1363 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1364 return Task::ready(Err(anyhow!("Session not found")));
1365 };
1366
1367 let task = session.thread.read(cx).to_db(cx);
1368 cx.background_spawn(async move {
1369 serde_json::to_value(task.await).context("Failed to serialize thread")
1370 })
1371 }
1372}
1373
1374struct NativeAgentSessionTruncate {
1375 thread: Entity<Thread>,
1376 acp_thread: WeakEntity<AcpThread>,
1377}
1378
1379impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1380 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1381 match self.thread.update(cx, |thread, cx| {
1382 thread.truncate(message_id.clone(), cx)?;
1383 Ok(thread.latest_token_usage())
1384 }) {
1385 Ok(usage) => {
1386 self.acp_thread
1387 .update(cx, |thread, cx| {
1388 thread.update_token_usage(usage, cx);
1389 })
1390 .ok();
1391 Task::ready(Ok(()))
1392 }
1393 Err(error) => Task::ready(Err(error)),
1394 }
1395 }
1396}
1397
1398struct NativeAgentSessionResume {
1399 connection: NativeAgentConnection,
1400 session_id: acp::SessionId,
1401}
1402
1403impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1404 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1405 self.connection
1406 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1407 thread.update(cx, |thread, cx| thread.resume(cx))
1408 })
1409 }
1410}
1411
1412struct NativeAgentSessionSetTitle {
1413 connection: NativeAgentConnection,
1414 session_id: acp::SessionId,
1415}
1416
1417impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1418 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1419 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1420 return Task::ready(Err(anyhow!("session not found")));
1421 };
1422 let thread = session.thread.clone();
1423 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1424 Task::ready(Ok(()))
1425 }
1426}
1427
1428pub struct AcpThreadEnvironment {
1429 acp_thread: WeakEntity<AcpThread>,
1430}
1431
1432impl ThreadEnvironment for AcpThreadEnvironment {
1433 fn create_terminal(
1434 &self,
1435 command: String,
1436 cwd: Option<PathBuf>,
1437 output_byte_limit: Option<u64>,
1438 cx: &mut AsyncApp,
1439 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1440 let task = self.acp_thread.update(cx, |thread, cx| {
1441 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1442 });
1443
1444 let acp_thread = self.acp_thread.clone();
1445 cx.spawn(async move |cx| {
1446 let terminal = task?.await?;
1447
1448 let (drop_tx, drop_rx) = oneshot::channel();
1449 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1450
1451 cx.spawn(async move |cx| {
1452 drop_rx.await.ok();
1453 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1454 })
1455 .detach();
1456
1457 let handle = AcpTerminalHandle {
1458 terminal,
1459 _drop_tx: Some(drop_tx),
1460 };
1461
1462 Ok(Rc::new(handle) as _)
1463 })
1464 }
1465}
1466
1467pub struct AcpTerminalHandle {
1468 terminal: Entity<acp_thread::Terminal>,
1469 _drop_tx: Option<oneshot::Sender<()>>,
1470}
1471
1472impl TerminalHandle for AcpTerminalHandle {
1473 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1474 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1475 }
1476
1477 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1478 Ok(self
1479 .terminal
1480 .read_with(cx, |term, _cx| term.wait_for_exit()))
1481 }
1482
1483 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1484 Ok(self
1485 .terminal
1486 .read_with(cx, |term, cx| term.current_output(cx)))
1487 }
1488
1489 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1490 cx.update(|cx| {
1491 self.terminal.update(cx, |terminal, cx| {
1492 terminal.kill(cx);
1493 });
1494 });
1495 Ok(())
1496 }
1497
1498 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1499 Ok(self
1500 .terminal
1501 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1502 }
1503}
1504
// Tests that exercise `NativeAgent` / `NativeAgentConnection` through crate-private items
// (e.g. `agent.sessions`, `agent.project_context`).
#[cfg(test)]
mod internal_tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// The agent's project context gains an entry when a worktree is added. It then picks up
    /// a `.rules` file created inside that worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context lists none.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        // Adding `/a` adds a matching entry that has no rules file yet.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// A session's model selector lists a single "Fake" group containing the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                }]
            )])
        );
    }

    /// Selecting a model changes the session thread's model and rewrites
    /// `agent.default_model` in the settings file. That field starts out as `foo`/`bar`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model so the update is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write queued by `select_model` finish before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Round-trip of a thread through the thread store. The steps checked are:
    /// - An empty thread is not persisted.
    /// - After one exchange, the thread is saved under the summary model's title.
    /// - Dropping the ACP thread removes the session.
    /// - Reopening the session reproduces the same transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a prompt that mixes plain text with a file mention.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake models: first the assistant reply, then the summary used as the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns `(id, title)` for every entry currently in the thread store.
    fn thread_entries(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(acp::SessionId, String)> {
        thread_store.read_with(cx, |store, _| {
            store
                .entries()
                .map(|entry| (entry.id.clone(), entry.title.to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Common test setup. It does three things:
    /// - Initializes logging, ignoring the error if a logger is already installed.
    /// - Installs a test `SettingsStore` global.
    /// - Initializes the test `LanguageModelRegistry`. That presumably registers the fake
    ///   provider expected by `test_listing_models`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1882
1883fn mcp_message_content_to_acp_content_block(
1884 content: context_server::types::MessageContent,
1885) -> acp::ContentBlock {
1886 match content {
1887 context_server::types::MessageContent::Text {
1888 text,
1889 annotations: _,
1890 } => text.into(),
1891 context_server::types::MessageContent::Image {
1892 data,
1893 mime_type,
1894 annotations: _,
1895 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
1896 context_server::types::MessageContent::Audio {
1897 data,
1898 mime_type,
1899 annotations: _,
1900 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
1901 context_server::types::MessageContent::Resource {
1902 resource,
1903 annotations: _,
1904 } => {
1905 let mut link =
1906 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
1907 if let Some(mime_type) = resource.mime_type {
1908 link = link.mime_type(mime_type);
1909 }
1910 acp::ContentBlock::ResourceLink(link)
1911 }
1912 }
1913}