1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod tools;
12
13use context_server::ContextServerId;
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector, UserMessageId};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashMap, HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
37 WorktreeContext,
38};
39use serde::{Deserialize, Serialize};
40use settings::{LanguageModelSelection, update_settings_file};
41use std::any::Any;
42use std::path::{Path, PathBuf};
43use std::rc::Rc;
44use std::sync::Arc;
45use util::ResultExt;
46use util::rel_path::RelPath;
47
/// Snapshot of the project's worktrees as captured for telemetry, together with
/// the time it was taken.
///
/// Serializable via serde; field order determines serialized order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
53
/// Describes why a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` text).
    pub message: SharedString,
}
57
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly; the agent removes the session when this `AcpThread` is
    /// released (see the `observe_release` subscription in `register_session`).
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recent save of `thread` to the database. Each new save replaces
    /// this task (dropping the previous one — presumably cancelling it if still
    /// running; confirm against gpui `Task` drop semantics).
    pending_save: Task<()>,
    /// Subscriptions kept alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
67
/// Cache of the language models offered by visible, authenticated providers.
///
/// Rebuilt by `refresh_list`; observers obtained through `watch` are signalled
/// after every rebuild.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out to observers; fires after each refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every refresh.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every provider; held so it stays alive
    /// with the cache.
    _authenticate_all_providers_task: Task<()>,
}
77
78impl LanguageModels {
79 fn new(cx: &mut App) -> Self {
80 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
81
82 let mut this = Self {
83 models: HashMap::default(),
84 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
85 refresh_models_rx,
86 refresh_models_tx,
87 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
88 };
89 this.refresh_list(cx);
90 this
91 }
92
93 fn refresh_list(&mut self, cx: &App) {
94 let providers = LanguageModelRegistry::global(cx)
95 .read(cx)
96 .visible_providers()
97 .into_iter()
98 .filter(|provider| provider.is_authenticated(cx))
99 .collect::<Vec<_>>();
100
101 let mut language_model_list = IndexMap::default();
102 let mut recommended_models = HashSet::default();
103
104 let mut recommended = Vec::new();
105 for provider in &providers {
106 for model in provider.recommended_models(cx) {
107 recommended_models.insert((model.provider_id(), model.id()));
108 recommended.push(Self::map_language_model_to_info(&model, provider));
109 }
110 }
111 if !recommended.is_empty() {
112 language_model_list.insert(
113 acp_thread::AgentModelGroupName("Recommended".into()),
114 recommended,
115 );
116 }
117
118 let mut models = HashMap::default();
119 for provider in providers {
120 let mut provider_models = Vec::new();
121 for model in provider.provided_models(cx) {
122 let model_info = Self::map_language_model_to_info(&model, &provider);
123 let model_id = model_info.id.clone();
124 provider_models.push(model_info);
125 models.insert(model_id, model);
126 }
127 if !provider_models.is_empty() {
128 language_model_list.insert(
129 acp_thread::AgentModelGroupName(provider.name().0.clone()),
130 provider_models,
131 );
132 }
133 }
134
135 self.models = models;
136 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
137 self.refresh_models_tx.send(()).ok();
138 }
139
140 fn watch(&self) -> watch::Receiver<()> {
141 self.refresh_models_rx.clone()
142 }
143
144 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
145 self.models.get(model_id).cloned()
146 }
147
148 fn map_language_model_to_info(
149 model: &Arc<dyn LanguageModel>,
150 provider: &Arc<dyn LanguageModelProvider>,
151 ) -> acp_thread::AgentModelInfo {
152 let icon = if let Some(path) = provider.icon_path() {
153 Some(AgentModelIcon::Path(path))
154 } else {
155 Some(AgentModelIcon::Named(provider.icon()))
156 };
157 acp_thread::AgentModelInfo {
158 id: Self::model_id(model),
159 name: model.name().0,
160 description: None,
161 icon,
162 }
163 }
164
165 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
166 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
167 }
168
169 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
170 let authenticate_all_providers = LanguageModelRegistry::global(cx)
171 .read(cx)
172 .providers()
173 .iter()
174 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
175 .collect::<Vec<_>>();
176
177 cx.background_spawn(async move {
178 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
179 if let Err(err) = authenticate_task.await {
180 match err {
181 language_model::AuthenticateError::CredentialsNotFound => {
182 // Since we're authenticating these providers in the
183 // background for the purposes of populating the
184 // language selector, we don't care about providers
185 // where the credentials are not found.
186 }
187 language_model::AuthenticateError::ConnectionRefused => {
188 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
189 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
190 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
191 }
192 _ => {
193 // Some providers have noisy failure states that we
194 // don't want to spam the logs with every time the
195 // language model selector is initialized.
196 //
197 // Ideally these should have more clear failure modes
198 // that we know are safe to ignore here, like what we do
199 // with `CredentialsNotFound` above.
200 match provider_id.0.as_ref() {
201 "lmstudio" | "ollama" => {
202 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
203 //
204 // These fail noisily, so we don't log them.
205 }
206 "copilot_chat" => {
207 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
208 }
209 _ => {
210 log::error!(
211 "Failed to authenticate provider: {}: {err:#}",
212 provider_name.0
213 );
214 }
215 }
216 }
217 }
218 }
219 }
220 })
221 }
222}
223
/// The native agent: owns the active sessions and the state their threads share
/// (project context, context-server registry, templates and model cache).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history store; reloaded after each thread save attempt.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled when the project context should be rebuilt (worktrees added or
    /// removed, rules files changed, prompt store updated).
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds the project context when signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools and prompts; its prompts become slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional source of default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle, used to persist model selection to the settings file.
    fs: Arc<dyn Fs>,
    /// Subscriptions kept alive for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
242
243impl NativeAgent {
    /// Creates a new `NativeAgent` for `project`.
    ///
    /// Builds the initial project context (worktree info, rules files and default
    /// user rules from `prompt_store`) before creating the entity. It then
    /// subscribes to project, language-model registry, context-server store,
    /// context-server registry and (if present) prompt-store events, and starts
    /// the background loop that rebuilds the project context on demand.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `cx.update` / `cx.new` on the async app
    /// context (presumably when the app is no longer available).
    pub async fn new(
        project: Entity<Project>,
        history: Entity<HistoryStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        // Build the initial project context up front so the agent starts with it.
        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
            .await;

        cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // Event handlers signal the sender; the maintenance task below rebuilds
            // the project context each time the receiver observes a change.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                history,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        })
    }
302
    /// Registers `thread_handle` as a session and returns the `AcpThread` created for it.
    ///
    /// This function:
    /// - gives the thread the registry's current thread-summary model and its
    ///   default tools;
    /// - subscribes so title and token-usage updates are forwarded to the
    ///   `AcpThread`, and every change notified by the thread triggers a save;
    /// - removes the session when the returned `AcpThread` is released;
    /// - finally pushes the current slash commands to all sessions.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            // Tools reach the AcpThread through a weak handle, so they do not keep it alive.
            thread.add_default_tools(
                Rc::new(AcpThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            // Releasing the AcpThread ends the session.
            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                this.sessions.remove(acp_thread.session_id());
            }),
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread whenever it notifies observers of a change.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        // NOTE(review): an existing session with the same ID is overwritten here;
        // the old AcpThread's release observer would later remove this new entry
        // by ID. Confirm callers never register an already-open session twice.
        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.downgrade(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands(cx);

        acp_thread
    }
366
367 pub fn models(&self) -> &LanguageModels {
368 &self.models
369 }
370
371 async fn maintain_project_context(
372 this: WeakEntity<Self>,
373 mut needs_refresh: watch::Receiver<()>,
374 cx: &mut AsyncApp,
375 ) -> Result<()> {
376 while needs_refresh.changed().await.is_ok() {
377 let project_context = this
378 .update(cx, |this, cx| {
379 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
380 })?
381 .await;
382 this.update(cx, |this, cx| {
383 this.project_context = cx.new(|_| project_context);
384 })?;
385 }
386
387 Ok(())
388 }
389
390 fn build_project_context(
391 project: &Entity<Project>,
392 prompt_store: Option<&Entity<PromptStore>>,
393 cx: &mut App,
394 ) -> Task<ProjectContext> {
395 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
396 let worktree_tasks = worktrees
397 .into_iter()
398 .map(|worktree| {
399 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
400 })
401 .collect::<Vec<_>>();
402 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
403 prompt_store.read_with(cx, |prompt_store, cx| {
404 let prompts = prompt_store.default_prompt_metadata();
405 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
406 let contents = prompt_store.load(prompt_metadata.id, cx);
407 async move { (contents.await, prompt_metadata) }
408 });
409 cx.background_spawn(future::join_all(load_tasks))
410 })
411 } else {
412 Task::ready(vec![])
413 };
414
415 cx.spawn(async move |_cx| {
416 let (worktrees, default_user_rules) =
417 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
418
419 let worktrees = worktrees
420 .into_iter()
421 .map(|(worktree, _rules_error)| {
422 // TODO: show error message
423 // if let Some(rules_error) = rules_error {
424 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
425 // }
426 worktree
427 })
428 .collect::<Vec<_>>();
429
430 let default_user_rules = default_user_rules
431 .into_iter()
432 .flat_map(|(contents, prompt_metadata)| match contents {
433 Ok(contents) => Some(UserRulesContext {
434 uuid: prompt_metadata.id.user_id()?,
435 title: prompt_metadata.title.map(|title| title.to_string()),
436 contents,
437 }),
438 Err(_err) => {
439 // TODO: show error message
440 // this.update(cx, |_, cx| {
441 // cx.emit(RulesLoadingError {
442 // message: format!("{err:?}").into(),
443 // });
444 // })
445 // .ok();
446 None
447 }
448 })
449 .collect::<Vec<_>>();
450
451 ProjectContext::new(worktrees, default_user_rules)
452 })
453 }
454
455 fn load_worktree_info_for_system_prompt(
456 worktree: Entity<Worktree>,
457 project: Entity<Project>,
458 cx: &mut App,
459 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
460 let tree = worktree.read(cx);
461 let root_name = tree.root_name_str().into();
462 let abs_path = tree.abs_path();
463
464 let mut context = WorktreeContext {
465 root_name,
466 abs_path,
467 rules_file: None,
468 };
469
470 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
471 let Some(rules_task) = rules_task else {
472 return Task::ready((context, None));
473 };
474
475 cx.spawn(async move |_| {
476 let (rules_file, rules_file_error) = match rules_task.await {
477 Ok(rules_file) => (Some(rules_file), None),
478 Err(err) => (
479 None,
480 Some(RulesLoadingError {
481 message: format!("{err}").into(),
482 }),
483 ),
484 };
485 context.rules_file = rules_file;
486 (context, rules_file_error)
487 })
488 }
489
490 fn load_worktree_rules_file(
491 worktree: Entity<Worktree>,
492 project: Entity<Project>,
493 cx: &mut App,
494 ) -> Option<Task<Result<RulesFileContext>>> {
495 let worktree = worktree.read(cx);
496 let worktree_id = worktree.id();
497 let selected_rules_file = RULES_FILE_NAMES
498 .into_iter()
499 .filter_map(|name| {
500 worktree
501 .entry_for_path(RelPath::unix(name).unwrap())
502 .filter(|entry| entry.is_file())
503 .map(|entry| entry.path.clone())
504 })
505 .next();
506
507 // Note that Cline supports `.clinerules` being a directory, but that is not currently
508 // supported. This doesn't seem to occur often in GitHub repositories.
509 selected_rules_file.map(|path_in_worktree| {
510 let project_path = ProjectPath {
511 worktree_id,
512 path: path_in_worktree.clone(),
513 };
514 let buffer_task =
515 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
516 let rope_task = cx.spawn(async move |cx| {
517 buffer_task.await?.read_with(cx, |buffer, cx| {
518 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
519 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
520 })?
521 });
522 // Build a string from the rope on a background thread.
523 cx.background_spawn(async move {
524 let (project_entry_id, rope) = rope_task.await?;
525 anyhow::Ok(RulesFileContext {
526 path_in_worktree,
527 text: rope.to_string().trim().to_string(),
528 project_entry_id: project_entry_id.to_usize(),
529 })
530 })
531 })
532 }
533
534 fn handle_thread_title_updated(
535 &mut self,
536 thread: Entity<Thread>,
537 _: &TitleUpdated,
538 cx: &mut Context<Self>,
539 ) {
540 let session_id = thread.read(cx).id();
541 let Some(session) = self.sessions.get(session_id) else {
542 return;
543 };
544 let thread = thread.downgrade();
545 let acp_thread = session.acp_thread.clone();
546 cx.spawn(async move |_, cx| {
547 let title = thread.read_with(cx, |thread, _| thread.title())?;
548 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
549 task.await
550 })
551 .detach_and_log_err(cx);
552 }
553
554 fn handle_thread_token_usage_updated(
555 &mut self,
556 thread: Entity<Thread>,
557 usage: &TokenUsageUpdated,
558 cx: &mut Context<Self>,
559 ) {
560 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
561 return;
562 };
563 session
564 .acp_thread
565 .update(cx, |acp_thread, cx| {
566 acp_thread.update_token_usage(usage.0.clone(), cx);
567 })
568 .ok();
569 }
570
571 fn handle_project_event(
572 &mut self,
573 _project: Entity<Project>,
574 event: &project::Event,
575 _cx: &mut Context<Self>,
576 ) {
577 match event {
578 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
579 self.project_context_needs_refresh.send(()).ok();
580 }
581 project::Event::WorktreeUpdatedEntries(_, items) => {
582 if items.iter().any(|(path, _, _)| {
583 RULES_FILE_NAMES
584 .iter()
585 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
586 }) {
587 self.project_context_needs_refresh.send(()).ok();
588 }
589 }
590 _ => {}
591 }
592 }
593
594 fn handle_prompts_updated_event(
595 &mut self,
596 _prompt_store: Entity<PromptStore>,
597 _event: &prompt_store::PromptsUpdatedEvent,
598 _cx: &mut Context<Self>,
599 ) {
600 self.project_context_needs_refresh.send(()).ok();
601 }
602
603 fn handle_models_updated_event(
604 &mut self,
605 _registry: Entity<LanguageModelRegistry>,
606 _event: &language_model::Event,
607 cx: &mut Context<Self>,
608 ) {
609 self.models.refresh_list(cx);
610
611 let registry = LanguageModelRegistry::read_global(cx);
612 let default_model = registry.default_model().map(|m| m.model);
613 let summarization_model = registry.thread_summary_model().map(|m| m.model);
614
615 for session in self.sessions.values_mut() {
616 session.thread.update(cx, |thread, cx| {
617 if thread.model().is_none()
618 && let Some(model) = default_model.clone()
619 {
620 thread.set_model(model, cx);
621 cx.notify();
622 }
623 thread.set_summarization_model(summarization_model.clone(), cx);
624 });
625 }
626 }
627
628 fn handle_context_server_store_updated(
629 &mut self,
630 _store: Entity<project::context_server_store::ContextServerStore>,
631 _event: &project::context_server_store::Event,
632 cx: &mut Context<Self>,
633 ) {
634 self.update_available_commands(cx);
635 }
636
637 fn handle_context_server_registry_event(
638 &mut self,
639 _registry: Entity<ContextServerRegistry>,
640 event: &ContextServerRegistryEvent,
641 cx: &mut Context<Self>,
642 ) {
643 match event {
644 ContextServerRegistryEvent::ToolsChanged => {}
645 ContextServerRegistryEvent::PromptsChanged => {
646 self.update_available_commands(cx);
647 }
648 }
649 }
650
651 fn update_available_commands(&self, cx: &mut Context<Self>) {
652 let available_commands = self.build_available_commands(cx);
653 for session in self.sessions.values() {
654 if let Some(acp_thread) = session.acp_thread.upgrade() {
655 acp_thread.update(cx, |thread, cx| {
656 thread
657 .handle_session_update(
658 acp::SessionUpdate::AvailableCommandsUpdate(
659 acp::AvailableCommandsUpdate::new(available_commands.clone()),
660 ),
661 cx,
662 )
663 .log_err();
664 });
665 }
666 }
667 }
668
669 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
670 let registry = self.context_server_registry.read(cx);
671
672 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
673 for context_server_prompt in registry.prompts() {
674 *prompt_name_counts
675 .entry(context_server_prompt.prompt.name.as_str())
676 .or_insert(0) += 1;
677 }
678
679 registry
680 .prompts()
681 .flat_map(|context_server_prompt| {
682 let prompt = &context_server_prompt.prompt;
683
684 let should_prefix = prompt_name_counts
685 .get(prompt.name.as_str())
686 .copied()
687 .unwrap_or(0)
688 > 1;
689
690 let name = if should_prefix {
691 format!("{}.{}", context_server_prompt.server_id, prompt.name)
692 } else {
693 prompt.name.clone()
694 };
695
696 let mut command = acp::AvailableCommand::new(
697 name,
698 prompt.description.clone().unwrap_or_default(),
699 );
700
701 match prompt.arguments.as_deref() {
702 Some([arg]) => {
703 let hint = format!("<{}>", arg.name);
704
705 command = command.input(acp::AvailableCommandInput::Unstructured(
706 acp::UnstructuredCommandInput::new(hint),
707 ));
708 }
709 Some([]) | None => {}
710 Some(_) => {
711 // skip >1 argument commands since we don't support them yet
712 return None;
713 }
714 }
715
716 Some(command)
717 })
718 .collect()
719 }
720
    /// Loads the thread with `id` from the threads database into a new `Thread` entity.
    ///
    /// The thread is built with this agent's project, shared project context,
    /// context-server registry and templates. It is given the registry's current
    /// thread-summary model. It is NOT registered as a session
    /// (see `open_thread` for that).
    ///
    /// # Errors
    ///
    /// Fails if the database cannot be opened, the load fails, no thread with `id`
    /// exists, or the agent can no longer be updated.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        this.project.clone(),
                        this.project_context.clone(),
                        this.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                })
            })
        })
    }
755
    /// Loads the thread with `id` from the database and registers it as a session.
    ///
    /// The thread's stored history is replayed into the newly created `AcpThread`,
    /// and the function waits for the replay to finish before returning the
    /// `AcpThread`.
    ///
    /// # Errors
    ///
    /// Fails if loading, registration, or forwarding the replayed events fails.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        let task = self.load_thread(id, cx);
        cx.spawn(async move |this, cx| {
            let thread = task.await?;
            let acp_thread =
                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
            // Replay the stored conversation so the AcpThread reflects it.
            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
            })?
            .await?;
            Ok(acp_thread)
        })
    }
774
775 pub fn thread_summary(
776 &mut self,
777 id: acp::SessionId,
778 cx: &mut Context<Self>,
779 ) -> Task<Result<SharedString>> {
780 let thread = self.open_thread(id.clone(), cx);
781 cx.spawn(async move |this, cx| {
782 let acp_thread = thread.await?;
783 let result = this
784 .update(cx, |this, cx| {
785 this.sessions
786 .get(&id)
787 .unwrap()
788 .thread
789 .update(cx, |thread, cx| thread.summary(cx))
790 })?
791 .await
792 .context("Failed to generate summary")?;
793 drop(acp_thread);
794 Ok(result)
795 })
796 }
797
    /// Persists `thread` to the threads database, then reloads the history store.
    ///
    /// Empty threads, and threads without a registered session, are not saved.
    /// The save runs asynchronously and is stored as the session's
    /// `pending_save`, replacing any earlier save task. Database errors are
    /// logged, not returned.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let database_future = ThreadsDatabase::connect(cx);
        let (id, db_thread) =
            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };
        let history = self.history.clone();
        // Overwrites (drops) the previous save task; presumably this cancels an
        // older save that is still running — confirm against gpui `Task` semantics.
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return;
            };
            let db_thread = db_thread.await;
            database.save_thread(id, db_thread).await.log_err();
            // Reload even if the save failed (its error was only logged above).
            history.update(cx, |history, cx| history.reload(cx)).ok();
        });
    }
819
    /// Runs the MCP prompt `prompt_name` from context server `server_id` as a turn
    /// of session `session_id`.
    ///
    /// Steps:
    /// 1. Fetch the prompt (with `arguments`) from the context server.
    /// 2. Record `original_content` minus its first block as user message
    ///    `message_id` on the internal `Thread`. The first block presumably holds
    ///    the slash-command text itself; confirm with the caller.
    /// 3. Append every prompt message to both the `AcpThread` (indented) and the
    ///    `Thread`, as a user or assistant block according to its role.
    /// 4. Continue the thread with `send_existing` if the last prompt message came
    ///    from the user (or there were none), otherwise with `resume`. Forward the
    ///    resulting events to the `AcpThread`.
    ///
    /// # Errors
    ///
    /// Fails if the prompt cannot be fetched, the session is not registered,
    /// either thread cannot be updated, or the turn itself fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Whether the most recent prompt message came from the user; stays
            // `true` when the prompt contains no messages.
            let mut last_is_user = true;

            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            })?;

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own message id,
                        // shared between the AcpThread and the Thread.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                            anyhow::Ok(())
                        })??;

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                            anyhow::Ok(())
                        })??;
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                            anyhow::Ok(())
                        })??;

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                            anyhow::Ok(())
                        })??;
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })??;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
            })?
            .await
        })
    }
915}
916
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds a handle to the shared `NativeAgent` entity; cloning the connection
/// clones that handle, not the agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
920
921impl NativeAgentConnection {
922 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
923 self.0
924 .read(cx)
925 .sessions
926 .get(session_id)
927 .map(|session| session.thread.clone())
928 }
929
930 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
931 self.0.update(cx, |this, cx| this.load_thread(id, cx))
932 }
933
934 fn run_turn(
935 &self,
936 session_id: acp::SessionId,
937 cx: &mut App,
938 f: impl 'static
939 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
940 ) -> Task<Result<acp::PromptResponse>> {
941 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
942 agent
943 .sessions
944 .get_mut(&session_id)
945 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
946 }) else {
947 return Task::ready(Err(anyhow!("Session not found")));
948 };
949 log::debug!("Found session for: {}", session_id);
950
951 let response_stream = match f(thread, cx) {
952 Ok(stream) => stream,
953 Err(err) => return Task::ready(Err(err)),
954 };
955 Self::handle_thread_events(response_stream, acp_thread, cx)
956 }
957
958 fn handle_thread_events(
959 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
960 acp_thread: WeakEntity<AcpThread>,
961 cx: &App,
962 ) -> Task<Result<acp::PromptResponse>> {
963 cx.spawn(async move |cx| {
964 // Handle response stream and forward to session.acp_thread
965 while let Some(result) = events.next().await {
966 match result {
967 Ok(event) => {
968 log::trace!("Received completion event: {:?}", event);
969
970 match event {
971 ThreadEvent::UserMessage(message) => {
972 acp_thread.update(cx, |thread, cx| {
973 for content in message.content {
974 thread.push_user_content_block(
975 Some(message.id.clone()),
976 content.into(),
977 cx,
978 );
979 }
980 })?;
981 }
982 ThreadEvent::AgentText(text) => {
983 acp_thread.update(cx, |thread, cx| {
984 thread.push_assistant_content_block(text.into(), false, cx)
985 })?;
986 }
987 ThreadEvent::AgentThinking(text) => {
988 acp_thread.update(cx, |thread, cx| {
989 thread.push_assistant_content_block(text.into(), true, cx)
990 })?;
991 }
992 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
993 tool_call,
994 options,
995 response,
996 }) => {
997 let outcome_task = acp_thread.update(cx, |thread, cx| {
998 thread.request_tool_call_authorization(
999 tool_call, options, true, cx,
1000 )
1001 })??;
1002 cx.background_spawn(async move {
1003 if let acp::RequestPermissionOutcome::Selected(
1004 acp::SelectedPermissionOutcome { option_id, .. },
1005 ) = outcome_task.await
1006 {
1007 response
1008 .send(option_id)
1009 .map(|_| anyhow!("authorization receiver was dropped"))
1010 .log_err();
1011 }
1012 })
1013 .detach();
1014 }
1015 ThreadEvent::ToolCall(tool_call) => {
1016 acp_thread.update(cx, |thread, cx| {
1017 thread.upsert_tool_call(tool_call, cx)
1018 })??;
1019 }
1020 ThreadEvent::ToolCallUpdate(update) => {
1021 acp_thread.update(cx, |thread, cx| {
1022 thread.update_tool_call(update, cx)
1023 })??;
1024 }
1025 ThreadEvent::Retry(status) => {
1026 acp_thread.update(cx, |thread, cx| {
1027 thread.update_retry_status(status, cx)
1028 })?;
1029 }
1030 ThreadEvent::Stop(stop_reason) => {
1031 log::debug!("Assistant message complete: {:?}", stop_reason);
1032 return Ok(acp::PromptResponse::new(stop_reason));
1033 }
1034 }
1035 }
1036 Err(e) => {
1037 log::error!("Error in model response stream: {:?}", e);
1038 return Err(e);
1039 }
1040 }
1041 }
1042
1043 log::debug!("Response stream completed");
1044 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1045 })
1046 }
1047}
1048
1049struct Command<'a> {
1050 prompt_name: &'a str,
1051 arg_value: &'a str,
1052 explicit_server_id: Option<&'a str>,
1053}
1054
1055impl<'a> Command<'a> {
1056 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1057 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1058 return None;
1059 };
1060 let text = text_content.text.trim();
1061 let command = text.strip_prefix('/')?;
1062 let (command, arg_value) = command
1063 .split_once(char::is_whitespace)
1064 .unwrap_or((command, ""));
1065
1066 if let Some((server_id, prompt_name)) = command.split_once('.') {
1067 Some(Self {
1068 prompt_name,
1069 arg_value,
1070 explicit_server_id: Some(server_id),
1071 })
1072 } else {
1073 Some(Self {
1074 prompt_name: command,
1075 arg_value,
1076 explicit_server_id: None,
1077 })
1078 }
1079 }
1080}
1081
/// Model selector for a single session of the native agent.
struct NativeAgentModelSelector {
    /// The session whose thread's model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model cache and filesystem.
    connection: NativeAgentConnection,
}
1086
1087impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1088 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1089 log::debug!("NativeAgentConnection::list_models called");
1090 let list = self.connection.0.read(cx).models.model_list.clone();
1091 Task::ready(if list.is_empty() {
1092 Err(anyhow::anyhow!("No models available"))
1093 } else {
1094 Ok(list)
1095 })
1096 }
1097
1098 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1099 log::debug!(
1100 "Setting model for session {}: {}",
1101 self.session_id,
1102 model_id
1103 );
1104 let Some(thread) = self
1105 .connection
1106 .0
1107 .read(cx)
1108 .sessions
1109 .get(&self.session_id)
1110 .map(|session| session.thread.clone())
1111 else {
1112 return Task::ready(Err(anyhow!("Session not found")));
1113 };
1114
1115 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1116 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1117 };
1118
1119 thread.update(cx, |thread, cx| {
1120 thread.set_model(model.clone(), cx);
1121 });
1122
1123 update_settings_file(
1124 self.connection.0.read(cx).fs.clone(),
1125 cx,
1126 move |settings, _cx| {
1127 let provider = model.provider_id().0.to_string();
1128 let model = model.id().0.to_string();
1129 settings
1130 .agent
1131 .get_or_insert_default()
1132 .set_model(LanguageModelSelection {
1133 provider: provider.into(),
1134 model,
1135 });
1136 },
1137 );
1138
1139 Task::ready(Ok(()))
1140 }
1141
1142 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1143 let Some(thread) = self
1144 .connection
1145 .0
1146 .read(cx)
1147 .sessions
1148 .get(&self.session_id)
1149 .map(|session| session.thread.clone())
1150 else {
1151 return Task::ready(Err(anyhow!("Session not found")));
1152 };
1153 let Some(model) = thread.read(cx).model() else {
1154 return Task::ready(Err(anyhow!("Model not found")));
1155 };
1156 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1157 else {
1158 return Task::ready(Err(anyhow!("Provider not found")));
1159 };
1160 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1161 model, &provider,
1162 )))
1163 }
1164
1165 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1166 Some(self.connection.0.read(cx).models.watch())
1167 }
1168
1169 fn should_render_footer(&self) -> bool {
1170 true
1171 }
1172
1173 fn supports_favorites(&self) -> bool {
1174 true
1175 }
1176}
1177
1178impl acp_thread::AgentConnection for NativeAgentConnection {
1179 fn telemetry_id(&self) -> SharedString {
1180 "zed".into()
1181 }
1182
1183 fn new_thread(
1184 self: Rc<Self>,
1185 project: Entity<Project>,
1186 cwd: &Path,
1187 cx: &mut App,
1188 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1189 let agent = self.0.clone();
1190 log::debug!("Creating new thread for project at: {:?}", cwd);
1191
1192 cx.spawn(async move |cx| {
1193 log::debug!("Starting thread creation in async context");
1194
1195 // Create Thread
1196 let thread = agent.update(
1197 cx,
1198 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
1199 // Fetch default model from registry settings
1200 let registry = LanguageModelRegistry::read_global(cx);
1201 // Log available models for debugging
1202 let available_count = registry.available_models(cx).count();
1203 log::debug!("Total available models: {}", available_count);
1204
1205 let default_model = registry.default_model().and_then(|default_model| {
1206 agent
1207 .models
1208 .model_from_id(&LanguageModels::model_id(&default_model.model))
1209 });
1210 Ok(cx.new(|cx| {
1211 Thread::new(
1212 project.clone(),
1213 agent.project_context.clone(),
1214 agent.context_server_registry.clone(),
1215 agent.templates.clone(),
1216 default_model,
1217 cx,
1218 )
1219 }))
1220 },
1221 )??;
1222 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
1223 })
1224 }
1225
1226 fn auth_methods(&self) -> &[acp::AuthMethod] {
1227 &[] // No auth for in-process
1228 }
1229
1230 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1231 Task::ready(Ok(()))
1232 }
1233
1234 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1235 Some(Rc::new(NativeAgentModelSelector {
1236 session_id: session_id.clone(),
1237 connection: self.clone(),
1238 }) as Rc<dyn AgentModelSelector>)
1239 }
1240
1241 fn prompt(
1242 &self,
1243 id: Option<acp_thread::UserMessageId>,
1244 params: acp::PromptRequest,
1245 cx: &mut App,
1246 ) -> Task<Result<acp::PromptResponse>> {
1247 let id = id.expect("UserMessageId is required");
1248 let session_id = params.session_id.clone();
1249 log::info!("Received prompt request for session: {}", session_id);
1250 log::debug!("Prompt blocks count: {}", params.prompt.len());
1251
1252 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1253 let registry = self.0.read(cx).context_server_registry.read(cx);
1254
1255 let explicit_server_id = parsed_command
1256 .explicit_server_id
1257 .map(|server_id| ContextServerId(server_id.into()));
1258
1259 if let Some(prompt) =
1260 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1261 {
1262 let arguments = if !parsed_command.arg_value.is_empty()
1263 && let Some(arg_name) = prompt
1264 .prompt
1265 .arguments
1266 .as_ref()
1267 .and_then(|args| args.first())
1268 .map(|arg| arg.name.clone())
1269 {
1270 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1271 } else {
1272 Default::default()
1273 };
1274
1275 let prompt_name = prompt.prompt.name.clone();
1276 let server_id = prompt.server_id.clone();
1277
1278 return self.0.update(cx, |agent, cx| {
1279 agent.send_mcp_prompt(
1280 id,
1281 session_id.clone(),
1282 prompt_name,
1283 server_id,
1284 arguments,
1285 params.prompt,
1286 cx,
1287 )
1288 });
1289 };
1290 };
1291
1292 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1293
1294 self.run_turn(session_id, cx, move |thread, cx| {
1295 let content: Vec<UserMessageContent> = params
1296 .prompt
1297 .into_iter()
1298 .map(|block| UserMessageContent::from_content_block(block, path_style))
1299 .collect::<Vec<_>>();
1300 log::debug!("Converted prompt to message: {} chars", content.len());
1301 log::debug!("Message id: {:?}", id);
1302 log::debug!("Message content: {:?}", content);
1303
1304 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1305 })
1306 }
1307
1308 fn resume(
1309 &self,
1310 session_id: &acp::SessionId,
1311 _cx: &App,
1312 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1313 Some(Rc::new(NativeAgentSessionResume {
1314 connection: self.clone(),
1315 session_id: session_id.clone(),
1316 }) as _)
1317 }
1318
1319 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1320 log::info!("Cancelling on session: {}", session_id);
1321 self.0.update(cx, |agent, cx| {
1322 if let Some(agent) = agent.sessions.get(session_id) {
1323 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1324 }
1325 });
1326 }
1327
1328 fn truncate(
1329 &self,
1330 session_id: &agent_client_protocol::SessionId,
1331 cx: &App,
1332 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1333 self.0.read_with(cx, |agent, _cx| {
1334 agent.sessions.get(session_id).map(|session| {
1335 Rc::new(NativeAgentSessionTruncate {
1336 thread: session.thread.clone(),
1337 acp_thread: session.acp_thread.clone(),
1338 }) as _
1339 })
1340 })
1341 }
1342
1343 fn set_title(
1344 &self,
1345 session_id: &acp::SessionId,
1346 _cx: &App,
1347 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1348 Some(Rc::new(NativeAgentSessionSetTitle {
1349 connection: self.clone(),
1350 session_id: session_id.clone(),
1351 }) as _)
1352 }
1353
1354 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1355 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1356 }
1357
1358 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1359 self
1360 }
1361}
1362
1363impl acp_thread::AgentTelemetry for NativeAgentConnection {
1364 fn thread_data(
1365 &self,
1366 session_id: &acp::SessionId,
1367 cx: &mut App,
1368 ) -> Task<Result<serde_json::Value>> {
1369 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1370 return Task::ready(Err(anyhow!("Session not found")));
1371 };
1372
1373 let task = session.thread.read(cx).to_db(cx);
1374 cx.background_spawn(async move {
1375 serde_json::to_value(task.await).context("Failed to serialize thread")
1376 })
1377 }
1378}
1379
1380struct NativeAgentSessionTruncate {
1381 thread: Entity<Thread>,
1382 acp_thread: WeakEntity<AcpThread>,
1383}
1384
1385impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1386 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1387 match self.thread.update(cx, |thread, cx| {
1388 thread.truncate(message_id.clone(), cx)?;
1389 Ok(thread.latest_token_usage())
1390 }) {
1391 Ok(usage) => {
1392 self.acp_thread
1393 .update(cx, |thread, cx| {
1394 thread.update_token_usage(usage, cx);
1395 })
1396 .ok();
1397 Task::ready(Ok(()))
1398 }
1399 Err(error) => Task::ready(Err(error)),
1400 }
1401 }
1402}
1403
1404struct NativeAgentSessionResume {
1405 connection: NativeAgentConnection,
1406 session_id: acp::SessionId,
1407}
1408
1409impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1410 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1411 self.connection
1412 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1413 thread.update(cx, |thread, cx| thread.resume(cx))
1414 })
1415 }
1416}
1417
1418struct NativeAgentSessionSetTitle {
1419 connection: NativeAgentConnection,
1420 session_id: acp::SessionId,
1421}
1422
1423impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1424 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1425 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1426 return Task::ready(Err(anyhow!("session not found")));
1427 };
1428 let thread = session.thread.clone();
1429 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1430 Task::ready(Ok(()))
1431 }
1432}
1433
1434pub struct AcpThreadEnvironment {
1435 acp_thread: WeakEntity<AcpThread>,
1436}
1437
1438impl ThreadEnvironment for AcpThreadEnvironment {
1439 fn create_terminal(
1440 &self,
1441 command: String,
1442 cwd: Option<PathBuf>,
1443 output_byte_limit: Option<u64>,
1444 cx: &mut AsyncApp,
1445 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1446 let task = self.acp_thread.update(cx, |thread, cx| {
1447 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1448 });
1449
1450 let acp_thread = self.acp_thread.clone();
1451 cx.spawn(async move |cx| {
1452 let terminal = task?.await?;
1453
1454 let (drop_tx, drop_rx) = oneshot::channel();
1455 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1456
1457 cx.spawn(async move |cx| {
1458 drop_rx.await.ok();
1459 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1460 })
1461 .detach();
1462
1463 let handle = AcpTerminalHandle {
1464 terminal,
1465 _drop_tx: Some(drop_tx),
1466 };
1467
1468 Ok(Rc::new(handle) as _)
1469 })
1470 }
1471}
1472
1473pub struct AcpTerminalHandle {
1474 terminal: Entity<acp_thread::Terminal>,
1475 _drop_tx: Option<oneshot::Sender<()>>,
1476}
1477
1478impl TerminalHandle for AcpTerminalHandle {
1479 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1480 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1481 }
1482
1483 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1484 self.terminal
1485 .read_with(cx, |term, _cx| term.wait_for_exit())
1486 }
1487
1488 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1489 self.terminal
1490 .read_with(cx, |term, cx| term.current_output(cx))
1491 }
1492
1493 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1494 cx.update(|cx| {
1495 self.terminal.update(cx, |terminal, cx| {
1496 terminal.kill(cx);
1497 });
1498 })?;
1499 Ok(())
1500 }
1501}
1502
// In-crate tests for the native agent that inspect its internal state
// (`sessions`, `project_context`) alongside the public connection API.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // The project context starts with no worktrees, picks up a newly created
    // worktree, and records a `.rules` file once it appears in that worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // A session's model selector lists the fake provider's single model, grouped
    // under the provider name "Fake".
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
                }]
            )])
        );
    }

    // Selecting a model sets it on the session's thread and overwrites the
    // pre-existing `agent.default_model` entry in the user's settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model so the test can
        // observe it being replaced.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // `select_model` returns before the settings write finishes; let it complete.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // End-to-end persistence:
    // - an empty thread is not saved, even after it is mutated;
    // - a completed exchange is saved under its generated summary title;
    // - dropping the ACP thread removes the session;
    // - the thread can be reopened from history with the same content.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a message that mentions `b.md`, then drive the fake models: the main
        // model answers, and the summary model produces the thread title.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Returns `(id, title)` for every history entry, in the order yielded by
    // `HistoryStore::entries`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Shared test setup.
    // - Initializes logging; `try_init` tolerates repeated calls across tests.
    // - Installs a test `SettingsStore` global.
    // - Initializes the test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1888
1889fn mcp_message_content_to_acp_content_block(
1890 content: context_server::types::MessageContent,
1891) -> acp::ContentBlock {
1892 match content {
1893 context_server::types::MessageContent::Text {
1894 text,
1895 annotations: _,
1896 } => text.into(),
1897 context_server::types::MessageContent::Image {
1898 data,
1899 mime_type,
1900 annotations: _,
1901 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
1902 context_server::types::MessageContent::Audio {
1903 data,
1904 mime_type,
1905 annotations: _,
1906 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
1907 context_server::types::MessageContent::Resource {
1908 resource,
1909 annotations: _,
1910 } => {
1911 let mut link =
1912 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
1913 if let Some(mime_type) = resource.mime_type {
1914 link = link.mime_type(mime_type);
1915 }
1916 acp::ContentBlock::ResourceLink(link)
1917 }
1918 }
1919}