1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod tools;
12
13use context_server::ContextServerId;
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelSelector, UserMessageId};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashMap, HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
37 WorktreeContext,
38};
39use serde::{Deserialize, Serialize};
40use settings::{LanguageModelSelection, update_settings_file};
41use std::any::Any;
42use std::path::{Path, PathBuf};
43use std::rc::Rc;
44use std::sync::Arc;
45use util::ResultExt;
46use util::rel_path::RelPath;
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub struct ProjectSnapshot {
50 pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
51 pub timestamp: DateTime<Utc>,
52}
53
54pub struct RulesLoadingError {
55 pub message: SharedString,
56}
57
58/// Holds both the internal Thread and the AcpThread for a session
59struct Session {
60 /// The internal thread that processes messages
61 thread: Entity<Thread>,
62 /// The ACP thread that handles protocol communication
63 acp_thread: WeakEntity<acp_thread::AcpThread>,
64 pending_save: Task<()>,
65 _subscriptions: Vec<Subscription>,
66}
67
68pub struct LanguageModels {
69 /// Access language model by ID
70 models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
71 /// Cached list for returning language model information
72 model_list: acp_thread::AgentModelList,
73 refresh_models_rx: watch::Receiver<()>,
74 refresh_models_tx: watch::Sender<()>,
75 _authenticate_all_providers_task: Task<()>,
76}
77
78impl LanguageModels {
79 fn new(cx: &mut App) -> Self {
80 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
81
82 let mut this = Self {
83 models: HashMap::default(),
84 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
85 refresh_models_rx,
86 refresh_models_tx,
87 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
88 };
89 this.refresh_list(cx);
90 this
91 }
92
93 fn refresh_list(&mut self, cx: &App) {
94 let providers = LanguageModelRegistry::global(cx)
95 .read(cx)
96 .visible_providers()
97 .into_iter()
98 .filter(|provider| provider.is_authenticated(cx))
99 .collect::<Vec<_>>();
100
101 let mut language_model_list = IndexMap::default();
102 let mut recommended_models = HashSet::default();
103
104 let mut recommended = Vec::new();
105 for provider in &providers {
106 for model in provider.recommended_models(cx) {
107 recommended_models.insert((model.provider_id(), model.id()));
108 recommended.push(Self::map_language_model_to_info(&model, provider));
109 }
110 }
111 if !recommended.is_empty() {
112 language_model_list.insert(
113 acp_thread::AgentModelGroupName("Recommended".into()),
114 recommended,
115 );
116 }
117
118 let mut models = HashMap::default();
119 for provider in providers {
120 let mut provider_models = Vec::new();
121 for model in provider.provided_models(cx) {
122 let model_info = Self::map_language_model_to_info(&model, &provider);
123 let model_id = model_info.id.clone();
124 provider_models.push(model_info);
125 models.insert(model_id, model);
126 }
127 if !provider_models.is_empty() {
128 language_model_list.insert(
129 acp_thread::AgentModelGroupName(provider.name().0.clone()),
130 provider_models,
131 );
132 }
133 }
134
135 self.models = models;
136 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
137 self.refresh_models_tx.send(()).ok();
138 }
139
140 fn watch(&self) -> watch::Receiver<()> {
141 self.refresh_models_rx.clone()
142 }
143
144 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
145 self.models.get(model_id).cloned()
146 }
147
148 fn map_language_model_to_info(
149 model: &Arc<dyn LanguageModel>,
150 provider: &Arc<dyn LanguageModelProvider>,
151 ) -> acp_thread::AgentModelInfo {
152 acp_thread::AgentModelInfo {
153 id: Self::model_id(model),
154 name: model.name().0,
155 description: None,
156 icon: Some(match provider.icon() {
157 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
158 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
159 }),
160 }
161 }
162
163 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
164 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
165 }
166
167 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
168 let authenticate_all_providers = LanguageModelRegistry::global(cx)
169 .read(cx)
170 .visible_providers()
171 .iter()
172 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
173 .collect::<Vec<_>>();
174
175 cx.background_spawn(async move {
176 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
177 if let Err(err) = authenticate_task.await {
178 match err {
179 language_model::AuthenticateError::CredentialsNotFound => {
180 // Since we're authenticating these providers in the
181 // background for the purposes of populating the
182 // language selector, we don't care about providers
183 // where the credentials are not found.
184 }
185 language_model::AuthenticateError::ConnectionRefused => {
186 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
187 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
188 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
189 }
190 _ => {
191 // Some providers have noisy failure states that we
192 // don't want to spam the logs with every time the
193 // language model selector is initialized.
194 //
195 // Ideally these should have more clear failure modes
196 // that we know are safe to ignore here, like what we do
197 // with `CredentialsNotFound` above.
198 match provider_id.0.as_ref() {
199 "lmstudio" | "ollama" => {
200 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
201 //
202 // These fail noisily, so we don't log them.
203 }
204 "copilot_chat" => {
205 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
206 }
207 _ => {
208 log::error!(
209 "Failed to authenticate provider: {}: {err:#}",
210 provider_name.0
211 );
212 }
213 }
214 }
215 }
216 }
217 }
218 })
219 }
220}
221
222pub struct NativeAgent {
223 /// Session ID -> Session mapping
224 sessions: HashMap<acp::SessionId, Session>,
225 history: Entity<HistoryStore>,
226 /// Shared project context for all threads
227 project_context: Entity<ProjectContext>,
228 project_context_needs_refresh: watch::Sender<()>,
229 _maintain_project_context: Task<Result<()>>,
230 context_server_registry: Entity<ContextServerRegistry>,
231 /// Shared templates for all threads
232 templates: Arc<Templates>,
233 /// Cached model information
234 models: LanguageModels,
235 project: Entity<Project>,
236 prompt_store: Option<Entity<PromptStore>>,
237 fs: Arc<dyn Fs>,
238 _subscriptions: Vec<Subscription>,
239}
240
241impl NativeAgent {
242 pub async fn new(
243 project: Entity<Project>,
244 history: Entity<HistoryStore>,
245 templates: Arc<Templates>,
246 prompt_store: Option<Entity<PromptStore>>,
247 fs: Arc<dyn Fs>,
248 cx: &mut AsyncApp,
249 ) -> Result<Entity<NativeAgent>> {
250 log::debug!("Creating new NativeAgent");
251
252 let project_context = cx
253 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
254 .await;
255
256 cx.new(|cx| {
257 let context_server_store = project.read(cx).context_server_store();
258 let context_server_registry =
259 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
260
261 let mut subscriptions = vec![
262 cx.subscribe(&project, Self::handle_project_event),
263 cx.subscribe(
264 &LanguageModelRegistry::global(cx),
265 Self::handle_models_updated_event,
266 ),
267 cx.subscribe(
268 &context_server_store,
269 Self::handle_context_server_store_updated,
270 ),
271 cx.subscribe(
272 &context_server_registry,
273 Self::handle_context_server_registry_event,
274 ),
275 ];
276 if let Some(prompt_store) = prompt_store.as_ref() {
277 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
278 }
279
280 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
281 watch::channel(());
282 Self {
283 sessions: HashMap::default(),
284 history,
285 project_context: cx.new(|_| project_context),
286 project_context_needs_refresh: project_context_needs_refresh_tx,
287 _maintain_project_context: cx.spawn(async move |this, cx| {
288 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
289 }),
290 context_server_registry,
291 templates,
292 models: LanguageModels::new(cx),
293 project,
294 prompt_store,
295 fs,
296 _subscriptions: subscriptions,
297 }
298 })
299 }
300
301 fn register_session(
302 &mut self,
303 thread_handle: Entity<Thread>,
304 cx: &mut Context<Self>,
305 ) -> Entity<AcpThread> {
306 let connection = Rc::new(NativeAgentConnection(cx.entity()));
307
308 let thread = thread_handle.read(cx);
309 let session_id = thread.id().clone();
310 let title = thread.title();
311 let project = thread.project.clone();
312 let action_log = thread.action_log.clone();
313 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
314 let acp_thread = cx.new(|cx| {
315 acp_thread::AcpThread::new(
316 title,
317 connection,
318 project.clone(),
319 action_log.clone(),
320 session_id.clone(),
321 prompt_capabilities_rx,
322 cx,
323 )
324 });
325
326 let registry = LanguageModelRegistry::read_global(cx);
327 let summarization_model = registry.thread_summary_model().map(|c| c.model);
328
329 thread_handle.update(cx, |thread, cx| {
330 thread.set_summarization_model(summarization_model, cx);
331 thread.add_default_tools(
332 Rc::new(AcpThreadEnvironment {
333 acp_thread: acp_thread.downgrade(),
334 }) as _,
335 cx,
336 )
337 });
338
339 let subscriptions = vec![
340 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
341 this.sessions.remove(acp_thread.session_id());
342 }),
343 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
344 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
345 cx.observe(&thread_handle, move |this, thread, cx| {
346 this.save_thread(thread, cx)
347 }),
348 ];
349
350 self.sessions.insert(
351 session_id,
352 Session {
353 thread: thread_handle,
354 acp_thread: acp_thread.downgrade(),
355 _subscriptions: subscriptions,
356 pending_save: Task::ready(()),
357 },
358 );
359
360 self.update_available_commands(cx);
361
362 acp_thread
363 }
364
365 pub fn models(&self) -> &LanguageModels {
366 &self.models
367 }
368
369 async fn maintain_project_context(
370 this: WeakEntity<Self>,
371 mut needs_refresh: watch::Receiver<()>,
372 cx: &mut AsyncApp,
373 ) -> Result<()> {
374 while needs_refresh.changed().await.is_ok() {
375 let project_context = this
376 .update(cx, |this, cx| {
377 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
378 })?
379 .await;
380 this.update(cx, |this, cx| {
381 this.project_context = cx.new(|_| project_context);
382 })?;
383 }
384
385 Ok(())
386 }
387
388 fn build_project_context(
389 project: &Entity<Project>,
390 prompt_store: Option<&Entity<PromptStore>>,
391 cx: &mut App,
392 ) -> Task<ProjectContext> {
393 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
394 let worktree_tasks = worktrees
395 .into_iter()
396 .map(|worktree| {
397 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
398 })
399 .collect::<Vec<_>>();
400 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
401 prompt_store.read_with(cx, |prompt_store, cx| {
402 let prompts = prompt_store.default_prompt_metadata();
403 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
404 let contents = prompt_store.load(prompt_metadata.id, cx);
405 async move { (contents.await, prompt_metadata) }
406 });
407 cx.background_spawn(future::join_all(load_tasks))
408 })
409 } else {
410 Task::ready(vec![])
411 };
412
413 cx.spawn(async move |_cx| {
414 let (worktrees, default_user_rules) =
415 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
416
417 let worktrees = worktrees
418 .into_iter()
419 .map(|(worktree, _rules_error)| {
420 // TODO: show error message
421 // if let Some(rules_error) = rules_error {
422 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
423 // }
424 worktree
425 })
426 .collect::<Vec<_>>();
427
428 let default_user_rules = default_user_rules
429 .into_iter()
430 .flat_map(|(contents, prompt_metadata)| match contents {
431 Ok(contents) => Some(UserRulesContext {
432 uuid: prompt_metadata.id.as_user()?,
433 title: prompt_metadata.title.map(|title| title.to_string()),
434 contents,
435 }),
436 Err(_err) => {
437 // TODO: show error message
438 // this.update(cx, |_, cx| {
439 // cx.emit(RulesLoadingError {
440 // message: format!("{err:?}").into(),
441 // });
442 // })
443 // .ok();
444 None
445 }
446 })
447 .collect::<Vec<_>>();
448
449 ProjectContext::new(worktrees, default_user_rules)
450 })
451 }
452
453 fn load_worktree_info_for_system_prompt(
454 worktree: Entity<Worktree>,
455 project: Entity<Project>,
456 cx: &mut App,
457 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
458 let tree = worktree.read(cx);
459 let root_name = tree.root_name_str().into();
460 let abs_path = tree.abs_path();
461
462 let mut context = WorktreeContext {
463 root_name,
464 abs_path,
465 rules_file: None,
466 };
467
468 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
469 let Some(rules_task) = rules_task else {
470 return Task::ready((context, None));
471 };
472
473 cx.spawn(async move |_| {
474 let (rules_file, rules_file_error) = match rules_task.await {
475 Ok(rules_file) => (Some(rules_file), None),
476 Err(err) => (
477 None,
478 Some(RulesLoadingError {
479 message: format!("{err}").into(),
480 }),
481 ),
482 };
483 context.rules_file = rules_file;
484 (context, rules_file_error)
485 })
486 }
487
488 fn load_worktree_rules_file(
489 worktree: Entity<Worktree>,
490 project: Entity<Project>,
491 cx: &mut App,
492 ) -> Option<Task<Result<RulesFileContext>>> {
493 let worktree = worktree.read(cx);
494 let worktree_id = worktree.id();
495 let selected_rules_file = RULES_FILE_NAMES
496 .into_iter()
497 .filter_map(|name| {
498 worktree
499 .entry_for_path(RelPath::unix(name).unwrap())
500 .filter(|entry| entry.is_file())
501 .map(|entry| entry.path.clone())
502 })
503 .next();
504
505 // Note that Cline supports `.clinerules` being a directory, but that is not currently
506 // supported. This doesn't seem to occur often in GitHub repositories.
507 selected_rules_file.map(|path_in_worktree| {
508 let project_path = ProjectPath {
509 worktree_id,
510 path: path_in_worktree.clone(),
511 };
512 let buffer_task =
513 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
514 let rope_task = cx.spawn(async move |cx| {
515 buffer_task.await?.read_with(cx, |buffer, cx| {
516 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
517 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
518 })?
519 });
520 // Build a string from the rope on a background thread.
521 cx.background_spawn(async move {
522 let (project_entry_id, rope) = rope_task.await?;
523 anyhow::Ok(RulesFileContext {
524 path_in_worktree,
525 text: rope.to_string().trim().to_string(),
526 project_entry_id: project_entry_id.to_usize(),
527 })
528 })
529 })
530 }
531
532 fn handle_thread_title_updated(
533 &mut self,
534 thread: Entity<Thread>,
535 _: &TitleUpdated,
536 cx: &mut Context<Self>,
537 ) {
538 let session_id = thread.read(cx).id();
539 let Some(session) = self.sessions.get(session_id) else {
540 return;
541 };
542 let thread = thread.downgrade();
543 let acp_thread = session.acp_thread.clone();
544 cx.spawn(async move |_, cx| {
545 let title = thread.read_with(cx, |thread, _| thread.title())?;
546 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
547 task.await
548 })
549 .detach_and_log_err(cx);
550 }
551
552 fn handle_thread_token_usage_updated(
553 &mut self,
554 thread: Entity<Thread>,
555 usage: &TokenUsageUpdated,
556 cx: &mut Context<Self>,
557 ) {
558 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
559 return;
560 };
561 session
562 .acp_thread
563 .update(cx, |acp_thread, cx| {
564 acp_thread.update_token_usage(usage.0.clone(), cx);
565 })
566 .ok();
567 }
568
569 fn handle_project_event(
570 &mut self,
571 _project: Entity<Project>,
572 event: &project::Event,
573 _cx: &mut Context<Self>,
574 ) {
575 match event {
576 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
577 self.project_context_needs_refresh.send(()).ok();
578 }
579 project::Event::WorktreeUpdatedEntries(_, items) => {
580 if items.iter().any(|(path, _, _)| {
581 RULES_FILE_NAMES
582 .iter()
583 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
584 }) {
585 self.project_context_needs_refresh.send(()).ok();
586 }
587 }
588 _ => {}
589 }
590 }
591
592 fn handle_prompts_updated_event(
593 &mut self,
594 _prompt_store: Entity<PromptStore>,
595 _event: &prompt_store::PromptsUpdatedEvent,
596 _cx: &mut Context<Self>,
597 ) {
598 self.project_context_needs_refresh.send(()).ok();
599 }
600
601 fn handle_models_updated_event(
602 &mut self,
603 _registry: Entity<LanguageModelRegistry>,
604 _event: &language_model::Event,
605 cx: &mut Context<Self>,
606 ) {
607 self.models.refresh_list(cx);
608
609 let registry = LanguageModelRegistry::read_global(cx);
610 let default_model = registry.default_model().map(|m| m.model);
611 let summarization_model = registry.thread_summary_model().map(|m| m.model);
612
613 for session in self.sessions.values_mut() {
614 session.thread.update(cx, |thread, cx| {
615 if thread.model().is_none()
616 && let Some(model) = default_model.clone()
617 {
618 thread.set_model(model, cx);
619 cx.notify();
620 }
621 thread.set_summarization_model(summarization_model.clone(), cx);
622 });
623 }
624 }
625
626 fn handle_context_server_store_updated(
627 &mut self,
628 _store: Entity<project::context_server_store::ContextServerStore>,
629 _event: &project::context_server_store::Event,
630 cx: &mut Context<Self>,
631 ) {
632 self.update_available_commands(cx);
633 }
634
635 fn handle_context_server_registry_event(
636 &mut self,
637 _registry: Entity<ContextServerRegistry>,
638 event: &ContextServerRegistryEvent,
639 cx: &mut Context<Self>,
640 ) {
641 match event {
642 ContextServerRegistryEvent::ToolsChanged => {}
643 ContextServerRegistryEvent::PromptsChanged => {
644 self.update_available_commands(cx);
645 }
646 }
647 }
648
649 fn update_available_commands(&self, cx: &mut Context<Self>) {
650 let available_commands = self.build_available_commands(cx);
651 for session in self.sessions.values() {
652 if let Some(acp_thread) = session.acp_thread.upgrade() {
653 acp_thread.update(cx, |thread, cx| {
654 thread
655 .handle_session_update(
656 acp::SessionUpdate::AvailableCommandsUpdate(
657 acp::AvailableCommandsUpdate::new(available_commands.clone()),
658 ),
659 cx,
660 )
661 .log_err();
662 });
663 }
664 }
665 }
666
667 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
668 let registry = self.context_server_registry.read(cx);
669
670 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
671 for context_server_prompt in registry.prompts() {
672 *prompt_name_counts
673 .entry(context_server_prompt.prompt.name.as_str())
674 .or_insert(0) += 1;
675 }
676
677 registry
678 .prompts()
679 .flat_map(|context_server_prompt| {
680 let prompt = &context_server_prompt.prompt;
681
682 let should_prefix = prompt_name_counts
683 .get(prompt.name.as_str())
684 .copied()
685 .unwrap_or(0)
686 > 1;
687
688 let name = if should_prefix {
689 format!("{}.{}", context_server_prompt.server_id, prompt.name)
690 } else {
691 prompt.name.clone()
692 };
693
694 let mut command = acp::AvailableCommand::new(
695 name,
696 prompt.description.clone().unwrap_or_default(),
697 );
698
699 match prompt.arguments.as_deref() {
700 Some([arg]) => {
701 let hint = format!("<{}>", arg.name);
702
703 command = command.input(acp::AvailableCommandInput::Unstructured(
704 acp::UnstructuredCommandInput::new(hint),
705 ));
706 }
707 Some([]) | None => {}
708 Some(_) => {
709 // skip >1 argument commands since we don't support them yet
710 return None;
711 }
712 }
713
714 Some(command)
715 })
716 .collect()
717 }
718
719 pub fn load_thread(
720 &mut self,
721 id: acp::SessionId,
722 cx: &mut Context<Self>,
723 ) -> Task<Result<Entity<Thread>>> {
724 let database_future = ThreadsDatabase::connect(cx);
725 cx.spawn(async move |this, cx| {
726 let database = database_future.await.map_err(|err| anyhow!(err))?;
727 let db_thread = database
728 .load_thread(id.clone())
729 .await?
730 .with_context(|| format!("no thread found with ID: {id:?}"))?;
731
732 this.update(cx, |this, cx| {
733 let summarization_model = LanguageModelRegistry::read_global(cx)
734 .thread_summary_model()
735 .map(|c| c.model);
736
737 cx.new(|cx| {
738 let mut thread = Thread::from_db(
739 id.clone(),
740 db_thread,
741 this.project.clone(),
742 this.project_context.clone(),
743 this.context_server_registry.clone(),
744 this.templates.clone(),
745 cx,
746 );
747 thread.set_summarization_model(summarization_model, cx);
748 thread
749 })
750 })
751 })
752 }
753
754 pub fn open_thread(
755 &mut self,
756 id: acp::SessionId,
757 cx: &mut Context<Self>,
758 ) -> Task<Result<Entity<AcpThread>>> {
759 let task = self.load_thread(id, cx);
760 cx.spawn(async move |this, cx| {
761 let thread = task.await?;
762 let acp_thread =
763 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
764 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
765 cx.update(|cx| {
766 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
767 })?
768 .await?;
769 Ok(acp_thread)
770 })
771 }
772
773 pub fn thread_summary(
774 &mut self,
775 id: acp::SessionId,
776 cx: &mut Context<Self>,
777 ) -> Task<Result<SharedString>> {
778 let thread = self.open_thread(id.clone(), cx);
779 cx.spawn(async move |this, cx| {
780 let acp_thread = thread.await?;
781 let result = this
782 .update(cx, |this, cx| {
783 this.sessions
784 .get(&id)
785 .unwrap()
786 .thread
787 .update(cx, |thread, cx| thread.summary(cx))
788 })?
789 .await
790 .context("Failed to generate summary")?;
791 drop(acp_thread);
792 Ok(result)
793 })
794 }
795
796 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
797 if thread.read(cx).is_empty() {
798 return;
799 }
800
801 let database_future = ThreadsDatabase::connect(cx);
802 let (id, db_thread) =
803 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
804 let Some(session) = self.sessions.get_mut(&id) else {
805 return;
806 };
807 let history = self.history.clone();
808 session.pending_save = cx.spawn(async move |_, cx| {
809 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
810 return;
811 };
812 let db_thread = db_thread.await;
813 database.save_thread(id, db_thread).await.log_err();
814 history.update(cx, |history, cx| history.reload(cx)).ok();
815 });
816 }
817
818 fn send_mcp_prompt(
819 &self,
820 message_id: UserMessageId,
821 session_id: agent_client_protocol::SessionId,
822 prompt_name: String,
823 server_id: ContextServerId,
824 arguments: HashMap<String, String>,
825 original_content: Vec<acp::ContentBlock>,
826 cx: &mut Context<Self>,
827 ) -> Task<Result<acp::PromptResponse>> {
828 let server_store = self.context_server_registry.read(cx).server_store().clone();
829 let path_style = self.project.read(cx).path_style(cx);
830
831 cx.spawn(async move |this, cx| {
832 let prompt =
833 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
834
835 let (acp_thread, thread) = this.update(cx, |this, _cx| {
836 let session = this
837 .sessions
838 .get(&session_id)
839 .context("Failed to get session")?;
840 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
841 })??;
842
843 let mut last_is_user = true;
844
845 thread.update(cx, |thread, cx| {
846 thread.push_acp_user_block(
847 message_id,
848 original_content.into_iter().skip(1),
849 path_style,
850 cx,
851 );
852 })?;
853
854 for message in prompt.messages {
855 let context_server::types::PromptMessage { role, content } = message;
856 let block = mcp_message_content_to_acp_content_block(content);
857
858 match role {
859 context_server::types::Role::User => {
860 let id = acp_thread::UserMessageId::new();
861
862 acp_thread.update(cx, |acp_thread, cx| {
863 acp_thread.push_user_content_block_with_indent(
864 Some(id.clone()),
865 block.clone(),
866 true,
867 cx,
868 );
869 anyhow::Ok(())
870 })??;
871
872 thread.update(cx, |thread, cx| {
873 thread.push_acp_user_block(id, [block], path_style, cx);
874 anyhow::Ok(())
875 })??;
876 }
877 context_server::types::Role::Assistant => {
878 acp_thread.update(cx, |acp_thread, cx| {
879 acp_thread.push_assistant_content_block_with_indent(
880 block.clone(),
881 false,
882 true,
883 cx,
884 );
885 anyhow::Ok(())
886 })??;
887
888 thread.update(cx, |thread, cx| {
889 thread.push_acp_agent_block(block, cx);
890 anyhow::Ok(())
891 })??;
892 }
893 }
894
895 last_is_user = role == context_server::types::Role::User;
896 }
897
898 let response_stream = thread.update(cx, |thread, cx| {
899 if last_is_user {
900 thread.send_existing(cx)
901 } else {
902 // Resume if MCP prompt did not end with a user message
903 thread.resume(cx)
904 }
905 })??;
906
907 cx.update(|cx| {
908 NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
909 })?
910 .await
911 })
912 }
913}
914
915/// Wrapper struct that implements the AgentConnection trait
916#[derive(Clone)]
917pub struct NativeAgentConnection(pub Entity<NativeAgent>);
918
919impl NativeAgentConnection {
920 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
921 self.0
922 .read(cx)
923 .sessions
924 .get(session_id)
925 .map(|session| session.thread.clone())
926 }
927
928 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
929 self.0.update(cx, |this, cx| this.load_thread(id, cx))
930 }
931
932 fn run_turn(
933 &self,
934 session_id: acp::SessionId,
935 cx: &mut App,
936 f: impl 'static
937 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
938 ) -> Task<Result<acp::PromptResponse>> {
939 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
940 agent
941 .sessions
942 .get_mut(&session_id)
943 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
944 }) else {
945 return Task::ready(Err(anyhow!("Session not found")));
946 };
947 log::debug!("Found session for: {}", session_id);
948
949 let response_stream = match f(thread, cx) {
950 Ok(stream) => stream,
951 Err(err) => return Task::ready(Err(err)),
952 };
953 Self::handle_thread_events(response_stream, acp_thread, cx)
954 }
955
956 fn handle_thread_events(
957 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
958 acp_thread: WeakEntity<AcpThread>,
959 cx: &App,
960 ) -> Task<Result<acp::PromptResponse>> {
961 cx.spawn(async move |cx| {
962 // Handle response stream and forward to session.acp_thread
963 while let Some(result) = events.next().await {
964 match result {
965 Ok(event) => {
966 log::trace!("Received completion event: {:?}", event);
967
968 match event {
969 ThreadEvent::UserMessage(message) => {
970 acp_thread.update(cx, |thread, cx| {
971 for content in message.content {
972 thread.push_user_content_block(
973 Some(message.id.clone()),
974 content.into(),
975 cx,
976 );
977 }
978 })?;
979 }
980 ThreadEvent::AgentText(text) => {
981 acp_thread.update(cx, |thread, cx| {
982 thread.push_assistant_content_block(text.into(), false, cx)
983 })?;
984 }
985 ThreadEvent::AgentThinking(text) => {
986 acp_thread.update(cx, |thread, cx| {
987 thread.push_assistant_content_block(text.into(), true, cx)
988 })?;
989 }
990 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
991 tool_call,
992 options,
993 response,
994 }) => {
995 let outcome_task = acp_thread.update(cx, |thread, cx| {
996 thread.request_tool_call_authorization(
997 tool_call, options, true, cx,
998 )
999 })??;
1000 cx.background_spawn(async move {
1001 if let acp::RequestPermissionOutcome::Selected(
1002 acp::SelectedPermissionOutcome { option_id, .. },
1003 ) = outcome_task.await
1004 {
1005 response
1006 .send(option_id)
1007 .map(|_| anyhow!("authorization receiver was dropped"))
1008 .log_err();
1009 }
1010 })
1011 .detach();
1012 }
1013 ThreadEvent::ToolCall(tool_call) => {
1014 acp_thread.update(cx, |thread, cx| {
1015 thread.upsert_tool_call(tool_call, cx)
1016 })??;
1017 }
1018 ThreadEvent::ToolCallUpdate(update) => {
1019 acp_thread.update(cx, |thread, cx| {
1020 thread.update_tool_call(update, cx)
1021 })??;
1022 }
1023 ThreadEvent::Retry(status) => {
1024 acp_thread.update(cx, |thread, cx| {
1025 thread.update_retry_status(status, cx)
1026 })?;
1027 }
1028 ThreadEvent::Stop(stop_reason) => {
1029 log::debug!("Assistant message complete: {:?}", stop_reason);
1030 return Ok(acp::PromptResponse::new(stop_reason));
1031 }
1032 }
1033 }
1034 Err(e) => {
1035 log::error!("Error in model response stream: {:?}", e);
1036 return Err(e);
1037 }
1038 }
1039 }
1040
1041 log::debug!("Response stream completed");
1042 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1043 })
1044 }
1045}
1046
/// A slash command parsed from the start of a prompt, borrowing its text.
struct Command<'a> {
    /// Name of the prompt to invoke.
    prompt_name: &'a str,
    /// Text after the first whitespace following the command; may be empty.
    arg_value: &'a str,
    /// Server ID when written as `<server>.<prompt>`.
    explicit_server_id: Option<&'a str>,
}
1052
1053impl<'a> Command<'a> {
1054 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1055 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1056 return None;
1057 };
1058 let text = text_content.text.trim();
1059 let command = text.strip_prefix('/')?;
1060 let (command, arg_value) = command
1061 .split_once(char::is_whitespace)
1062 .unwrap_or((command, ""));
1063
1064 if let Some((server_id, prompt_name)) = command.split_once('.') {
1065 Some(Self {
1066 prompt_name,
1067 arg_value,
1068 explicit_server_id: Some(server_id),
1069 })
1070 } else {
1071 Some(Self {
1072 prompt_name: command,
1073 arg_value,
1074 explicit_server_id: None,
1075 })
1076 }
1077 }
1078}
1079
1080struct NativeAgentModelSelector {
1081 session_id: acp::SessionId,
1082 connection: NativeAgentConnection,
1083}
1084
1085impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1086 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1087 log::debug!("NativeAgentConnection::list_models called");
1088 let list = self.connection.0.read(cx).models.model_list.clone();
1089 Task::ready(if list.is_empty() {
1090 Err(anyhow::anyhow!("No models available"))
1091 } else {
1092 Ok(list)
1093 })
1094 }
1095
1096 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1097 log::debug!(
1098 "Setting model for session {}: {}",
1099 self.session_id,
1100 model_id
1101 );
1102 let Some(thread) = self
1103 .connection
1104 .0
1105 .read(cx)
1106 .sessions
1107 .get(&self.session_id)
1108 .map(|session| session.thread.clone())
1109 else {
1110 return Task::ready(Err(anyhow!("Session not found")));
1111 };
1112
1113 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1114 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1115 };
1116
1117 thread.update(cx, |thread, cx| {
1118 thread.set_model(model.clone(), cx);
1119 });
1120
1121 update_settings_file(
1122 self.connection.0.read(cx).fs.clone(),
1123 cx,
1124 move |settings, _cx| {
1125 let provider = model.provider_id().0.to_string();
1126 let model = model.id().0.to_string();
1127 settings
1128 .agent
1129 .get_or_insert_default()
1130 .set_model(LanguageModelSelection {
1131 provider: provider.into(),
1132 model,
1133 });
1134 },
1135 );
1136
1137 Task::ready(Ok(()))
1138 }
1139
1140 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1141 let Some(thread) = self
1142 .connection
1143 .0
1144 .read(cx)
1145 .sessions
1146 .get(&self.session_id)
1147 .map(|session| session.thread.clone())
1148 else {
1149 return Task::ready(Err(anyhow!("Session not found")));
1150 };
1151 let Some(model) = thread.read(cx).model() else {
1152 return Task::ready(Err(anyhow!("Model not found")));
1153 };
1154 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1155 else {
1156 return Task::ready(Err(anyhow!("Provider not found")));
1157 };
1158 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1159 model, &provider,
1160 )))
1161 }
1162
1163 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1164 Some(self.connection.0.read(cx).models.watch())
1165 }
1166
1167 fn should_render_footer(&self) -> bool {
1168 true
1169 }
1170
1171 fn supports_favorites(&self) -> bool {
1172 true
1173 }
1174}
1175
1176impl acp_thread::AgentConnection for NativeAgentConnection {
1177 fn telemetry_id(&self) -> SharedString {
1178 "zed".into()
1179 }
1180
1181 fn new_thread(
1182 self: Rc<Self>,
1183 project: Entity<Project>,
1184 cwd: &Path,
1185 cx: &mut App,
1186 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1187 let agent = self.0.clone();
1188 log::debug!("Creating new thread for project at: {:?}", cwd);
1189
1190 cx.spawn(async move |cx| {
1191 log::debug!("Starting thread creation in async context");
1192
1193 // Create Thread
1194 let thread = agent.update(
1195 cx,
1196 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
1197 // Fetch default model from registry settings
1198 let registry = LanguageModelRegistry::read_global(cx);
1199 // Log available models for debugging
1200 let available_count = registry.available_models(cx).count();
1201 log::debug!("Total available models: {}", available_count);
1202
1203 let default_model = registry.default_model().and_then(|default_model| {
1204 agent
1205 .models
1206 .model_from_id(&LanguageModels::model_id(&default_model.model))
1207 });
1208 Ok(cx.new(|cx| {
1209 Thread::new(
1210 project.clone(),
1211 agent.project_context.clone(),
1212 agent.context_server_registry.clone(),
1213 agent.templates.clone(),
1214 default_model,
1215 cx,
1216 )
1217 }))
1218 },
1219 )??;
1220 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
1221 })
1222 }
1223
1224 fn auth_methods(&self) -> &[acp::AuthMethod] {
1225 &[] // No auth for in-process
1226 }
1227
1228 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1229 Task::ready(Ok(()))
1230 }
1231
1232 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1233 Some(Rc::new(NativeAgentModelSelector {
1234 session_id: session_id.clone(),
1235 connection: self.clone(),
1236 }) as Rc<dyn AgentModelSelector>)
1237 }
1238
1239 fn prompt(
1240 &self,
1241 id: Option<acp_thread::UserMessageId>,
1242 params: acp::PromptRequest,
1243 cx: &mut App,
1244 ) -> Task<Result<acp::PromptResponse>> {
1245 let id = id.expect("UserMessageId is required");
1246 let session_id = params.session_id.clone();
1247 log::info!("Received prompt request for session: {}", session_id);
1248 log::debug!("Prompt blocks count: {}", params.prompt.len());
1249
1250 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1251 let registry = self.0.read(cx).context_server_registry.read(cx);
1252
1253 let explicit_server_id = parsed_command
1254 .explicit_server_id
1255 .map(|server_id| ContextServerId(server_id.into()));
1256
1257 if let Some(prompt) =
1258 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1259 {
1260 let arguments = if !parsed_command.arg_value.is_empty()
1261 && let Some(arg_name) = prompt
1262 .prompt
1263 .arguments
1264 .as_ref()
1265 .and_then(|args| args.first())
1266 .map(|arg| arg.name.clone())
1267 {
1268 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1269 } else {
1270 Default::default()
1271 };
1272
1273 let prompt_name = prompt.prompt.name.clone();
1274 let server_id = prompt.server_id.clone();
1275
1276 return self.0.update(cx, |agent, cx| {
1277 agent.send_mcp_prompt(
1278 id,
1279 session_id.clone(),
1280 prompt_name,
1281 server_id,
1282 arguments,
1283 params.prompt,
1284 cx,
1285 )
1286 });
1287 };
1288 };
1289
1290 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1291
1292 self.run_turn(session_id, cx, move |thread, cx| {
1293 let content: Vec<UserMessageContent> = params
1294 .prompt
1295 .into_iter()
1296 .map(|block| UserMessageContent::from_content_block(block, path_style))
1297 .collect::<Vec<_>>();
1298 log::debug!("Converted prompt to message: {} chars", content.len());
1299 log::debug!("Message id: {:?}", id);
1300 log::debug!("Message content: {:?}", content);
1301
1302 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1303 })
1304 }
1305
1306 fn resume(
1307 &self,
1308 session_id: &acp::SessionId,
1309 _cx: &App,
1310 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1311 Some(Rc::new(NativeAgentSessionResume {
1312 connection: self.clone(),
1313 session_id: session_id.clone(),
1314 }) as _)
1315 }
1316
1317 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1318 log::info!("Cancelling on session: {}", session_id);
1319 self.0.update(cx, |agent, cx| {
1320 if let Some(agent) = agent.sessions.get(session_id) {
1321 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1322 }
1323 });
1324 }
1325
1326 fn truncate(
1327 &self,
1328 session_id: &agent_client_protocol::SessionId,
1329 cx: &App,
1330 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1331 self.0.read_with(cx, |agent, _cx| {
1332 agent.sessions.get(session_id).map(|session| {
1333 Rc::new(NativeAgentSessionTruncate {
1334 thread: session.thread.clone(),
1335 acp_thread: session.acp_thread.clone(),
1336 }) as _
1337 })
1338 })
1339 }
1340
1341 fn set_title(
1342 &self,
1343 session_id: &acp::SessionId,
1344 _cx: &App,
1345 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1346 Some(Rc::new(NativeAgentSessionSetTitle {
1347 connection: self.clone(),
1348 session_id: session_id.clone(),
1349 }) as _)
1350 }
1351
1352 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1353 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1354 }
1355
1356 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1357 self
1358 }
1359}
1360
1361impl acp_thread::AgentTelemetry for NativeAgentConnection {
1362 fn thread_data(
1363 &self,
1364 session_id: &acp::SessionId,
1365 cx: &mut App,
1366 ) -> Task<Result<serde_json::Value>> {
1367 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1368 return Task::ready(Err(anyhow!("Session not found")));
1369 };
1370
1371 let task = session.thread.read(cx).to_db(cx);
1372 cx.background_spawn(async move {
1373 serde_json::to_value(task.await).context("Failed to serialize thread")
1374 })
1375 }
1376}
1377
1378struct NativeAgentSessionTruncate {
1379 thread: Entity<Thread>,
1380 acp_thread: WeakEntity<AcpThread>,
1381}
1382
1383impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1384 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1385 match self.thread.update(cx, |thread, cx| {
1386 thread.truncate(message_id.clone(), cx)?;
1387 Ok(thread.latest_token_usage())
1388 }) {
1389 Ok(usage) => {
1390 self.acp_thread
1391 .update(cx, |thread, cx| {
1392 thread.update_token_usage(usage, cx);
1393 })
1394 .ok();
1395 Task::ready(Ok(()))
1396 }
1397 Err(error) => Task::ready(Err(error)),
1398 }
1399 }
1400}
1401
1402struct NativeAgentSessionResume {
1403 connection: NativeAgentConnection,
1404 session_id: acp::SessionId,
1405}
1406
1407impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1408 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1409 self.connection
1410 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1411 thread.update(cx, |thread, cx| thread.resume(cx))
1412 })
1413 }
1414}
1415
1416struct NativeAgentSessionSetTitle {
1417 connection: NativeAgentConnection,
1418 session_id: acp::SessionId,
1419}
1420
1421impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1422 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1423 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1424 return Task::ready(Err(anyhow!("session not found")));
1425 };
1426 let thread = session.thread.clone();
1427 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1428 Task::ready(Ok(()))
1429 }
1430}
1431
1432pub struct AcpThreadEnvironment {
1433 acp_thread: WeakEntity<AcpThread>,
1434}
1435
1436impl ThreadEnvironment for AcpThreadEnvironment {
1437 fn create_terminal(
1438 &self,
1439 command: String,
1440 cwd: Option<PathBuf>,
1441 output_byte_limit: Option<u64>,
1442 cx: &mut AsyncApp,
1443 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1444 let task = self.acp_thread.update(cx, |thread, cx| {
1445 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1446 });
1447
1448 let acp_thread = self.acp_thread.clone();
1449 cx.spawn(async move |cx| {
1450 let terminal = task?.await?;
1451
1452 let (drop_tx, drop_rx) = oneshot::channel();
1453 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1454
1455 cx.spawn(async move |cx| {
1456 drop_rx.await.ok();
1457 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1458 })
1459 .detach();
1460
1461 let handle = AcpTerminalHandle {
1462 terminal,
1463 _drop_tx: Some(drop_tx),
1464 };
1465
1466 Ok(Rc::new(handle) as _)
1467 })
1468 }
1469}
1470
1471pub struct AcpTerminalHandle {
1472 terminal: Entity<acp_thread::Terminal>,
1473 _drop_tx: Option<oneshot::Sender<()>>,
1474}
1475
1476impl TerminalHandle for AcpTerminalHandle {
1477 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1478 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1479 }
1480
1481 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1482 self.terminal
1483 .read_with(cx, |term, _cx| term.wait_for_exit())
1484 }
1485
1486 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1487 self.terminal
1488 .read_with(cx, |term, cx| term.current_output(cx))
1489 }
1490
1491 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1492 cx.update(|cx| {
1493 self.terminal.update(cx, |terminal, cx| {
1494 terminal.kill(cx);
1495 });
1496 })?;
1497 Ok(())
1498 }
1499}
1500
/// Tests for the native agent. They cover project-context tracking, model listing
/// and selection through the ACP model selector, and saving and reloading threads
/// through the history store.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // The agent's project context starts with no worktrees. It gains an entry when
    // worktree `/a` is added, and then picks up a `.rules` file created inside it.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created with no worktrees, so the context has none.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        // Adding `/a` produces a worktree entry with no rules file yet.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected entry id comes from the worktree's own entry for `.rules`.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // With the test language model registry, the session's model selector lists a
    // single "Fake" group that contains the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                }]
            )])
        );
    }

    // Selecting `fake/fake` through the model selector updates the thread's model and
    // rewrites `agent.default_model` in the settings file (which starts as foo/bar).
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // The settings write started by `select_model` is not awaited by it, so let
        // pending work finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // End-to-end persistence check:
    // - an empty thread is not saved, even after its models are changed;
    // - after one turn, the history store holds an entry titled by the summary model;
    // - dropping the ACP thread removes the session;
    // - reopening from history restores the same conversation markdown.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a prompt that mixes text with a file mention of `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake models: first the reply, then the summary used as the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns the `(id, title)` pairs of every entry currently in `history`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Common test setup. It initializes logging, ignoring the error when a logger
    /// is already installed by an earlier test. It also installs a test
    /// `SettingsStore` global and a test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1888
1889fn mcp_message_content_to_acp_content_block(
1890 content: context_server::types::MessageContent,
1891) -> acp::ContentBlock {
1892 match content {
1893 context_server::types::MessageContent::Text {
1894 text,
1895 annotations: _,
1896 } => text.into(),
1897 context_server::types::MessageContent::Image {
1898 data,
1899 mime_type,
1900 annotations: _,
1901 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
1902 context_server::types::MessageContent::Audio {
1903 data,
1904 mime_type,
1905 annotations: _,
1906 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
1907 context_server::types::MessageContent::Resource {
1908 resource,
1909 annotations: _,
1910 } => {
1911 let mut link =
1912 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
1913 if let Some(mime_type) = resource.mime_type {
1914 link = link.mime_type(mime_type);
1915 }
1916 acp::ContentBlock::ResourceLink(link)
1917 }
1918 }
1919}