1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod templates;
7#[cfg(test)]
8mod tests;
9mod thread;
10mod thread_store;
11mod tool_permissions;
12mod tools;
13
14use context_server::ContextServerId;
15pub use db::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use thread_store::*;
20pub use tool_permissions::*;
21pub use tools::*;
22
23use acp_thread::{
24 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
25 AgentSessionListResponse, UserMessageId,
26};
27use agent_client_protocol as acp;
28use anyhow::{Context as _, Result, anyhow};
29use chrono::{DateTime, Utc};
30use collections::{HashMap, HashSet, IndexMap};
31use fs::Fs;
32use futures::channel::{mpsc, oneshot};
33use futures::future::Shared;
34use futures::{StreamExt, future};
35use gpui::{
36 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
37};
38use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
39use project::{Project, ProjectItem, ProjectPath, Worktree};
40use prompt_store::{
41 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
42 WorktreeContext,
43};
44use serde::{Deserialize, Serialize};
45use settings::{LanguageModelSelection, update_settings_file};
46use std::any::Any;
47use std::path::{Path, PathBuf};
48use std::rc::Rc;
49use std::sync::Arc;
50use util::ResultExt;
51use util::rel_path::RelPath;
52
/// A serializable record of a project's worktree telemetry snapshots paired
/// with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp stored with the snapshot (presumably the capture time — confirm with callers).
    pub timestamp: DateTime<Utc>,
}
58
/// Describes a failure to load a rules file; built from the formatted
/// underlying error when loading a worktree's rules file fails.
pub struct RulesLoadingError {
    /// Text of the failure, suitable for display.
    pub message: SharedString,
}
62
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recently spawned save task for this session (assigned by
    /// `NativeAgent::save_thread`); starts out as an already-completed task.
    pending_save: Task<()>,
    /// Subscriptions and observers created in `NativeAgent::register_session`,
    /// stored alongside the session.
    _subscriptions: Vec<Subscription>,
}
72
/// Cache of the language models offered by authenticated providers, together
/// with a watch channel that is signalled every time the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; `watch` hands out clones of it.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; `refresh_list` sends on it after rebuilding.
    refresh_models_tx: watch::Sender<()>,
    /// Background task started in `new` that authenticates every visible provider.
    _authenticate_all_providers_task: Task<()>,
}
82
83impl LanguageModels {
84 fn new(cx: &mut App) -> Self {
85 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
86
87 let mut this = Self {
88 models: HashMap::default(),
89 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
90 refresh_models_rx,
91 refresh_models_tx,
92 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
93 };
94 this.refresh_list(cx);
95 this
96 }
97
98 fn refresh_list(&mut self, cx: &App) {
99 let providers = LanguageModelRegistry::global(cx)
100 .read(cx)
101 .visible_providers()
102 .into_iter()
103 .filter(|provider| provider.is_authenticated(cx))
104 .collect::<Vec<_>>();
105
106 let mut language_model_list = IndexMap::default();
107 let mut recommended_models = HashSet::default();
108
109 let mut recommended = Vec::new();
110 for provider in &providers {
111 for model in provider.recommended_models(cx) {
112 recommended_models.insert((model.provider_id(), model.id()));
113 recommended.push(Self::map_language_model_to_info(&model, provider));
114 }
115 }
116 if !recommended.is_empty() {
117 language_model_list.insert(
118 acp_thread::AgentModelGroupName("Recommended".into()),
119 recommended,
120 );
121 }
122
123 let mut models = HashMap::default();
124 for provider in providers {
125 let mut provider_models = Vec::new();
126 for model in provider.provided_models(cx) {
127 let model_info = Self::map_language_model_to_info(&model, &provider);
128 let model_id = model_info.id.clone();
129 provider_models.push(model_info);
130 models.insert(model_id, model);
131 }
132 if !provider_models.is_empty() {
133 language_model_list.insert(
134 acp_thread::AgentModelGroupName(provider.name().0.clone()),
135 provider_models,
136 );
137 }
138 }
139
140 self.models = models;
141 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
142 self.refresh_models_tx.send(()).ok();
143 }
144
145 fn watch(&self) -> watch::Receiver<()> {
146 self.refresh_models_rx.clone()
147 }
148
149 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
150 self.models.get(model_id).cloned()
151 }
152
153 fn map_language_model_to_info(
154 model: &Arc<dyn LanguageModel>,
155 provider: &Arc<dyn LanguageModelProvider>,
156 ) -> acp_thread::AgentModelInfo {
157 acp_thread::AgentModelInfo {
158 id: Self::model_id(model),
159 name: model.name().0,
160 description: None,
161 icon: Some(match provider.icon() {
162 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
163 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
164 }),
165 }
166 }
167
168 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
169 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
170 }
171
172 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
173 let authenticate_all_providers = LanguageModelRegistry::global(cx)
174 .read(cx)
175 .visible_providers()
176 .iter()
177 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
178 .collect::<Vec<_>>();
179
180 cx.background_spawn(async move {
181 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
182 if let Err(err) = authenticate_task.await {
183 match err {
184 language_model::AuthenticateError::CredentialsNotFound => {
185 // Since we're authenticating these providers in the
186 // background for the purposes of populating the
187 // language selector, we don't care about providers
188 // where the credentials are not found.
189 }
190 language_model::AuthenticateError::ConnectionRefused => {
191 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
192 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
193 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
194 }
195 _ => {
196 // Some providers have noisy failure states that we
197 // don't want to spam the logs with every time the
198 // language model selector is initialized.
199 //
200 // Ideally these should have more clear failure modes
201 // that we know are safe to ignore here, like what we do
202 // with `CredentialsNotFound` above.
203 match provider_id.0.as_ref() {
204 "lmstudio" | "ollama" => {
205 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
206 //
207 // These fail noisily, so we don't log them.
208 }
209 "copilot_chat" => {
210 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
211 }
212 _ => {
213 log::error!(
214 "Failed to authenticate provider: {}: {err:#}",
215 provider_name.0
216 );
217 }
218 }
219 }
220 }
221 }
222 }
223 })
224 }
225}
226
/// The native agent: owns the open sessions and the state they share
/// (project context, templates, cached models, context server registry).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store that is asked to reload after a thread has been saved.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sender used by event handlers to request a rebuild of `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`, which performs the requested rebuilds.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built in `new` from the project's context server store; its
    /// prompts are exposed as available commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store; when present, its default prompts are loaded as
    /// user rules and its update events trigger a project context refresh.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle, passed to `update_settings_file` when a model is selected.
    fs: Arc<dyn Fs>,
    /// Event subscriptions created in `new`.
    _subscriptions: Vec<Subscription>,
}
245
246impl NativeAgent {
247 pub async fn new(
248 project: Entity<Project>,
249 thread_store: Entity<ThreadStore>,
250 templates: Arc<Templates>,
251 prompt_store: Option<Entity<PromptStore>>,
252 fs: Arc<dyn Fs>,
253 cx: &mut AsyncApp,
254 ) -> Result<Entity<NativeAgent>> {
255 log::debug!("Creating new NativeAgent");
256
257 let project_context = cx
258 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
259 .await;
260
261 Ok(cx.new(|cx| {
262 let context_server_store = project.read(cx).context_server_store();
263 let context_server_registry =
264 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
265
266 let mut subscriptions = vec![
267 cx.subscribe(&project, Self::handle_project_event),
268 cx.subscribe(
269 &LanguageModelRegistry::global(cx),
270 Self::handle_models_updated_event,
271 ),
272 cx.subscribe(
273 &context_server_store,
274 Self::handle_context_server_store_updated,
275 ),
276 cx.subscribe(
277 &context_server_registry,
278 Self::handle_context_server_registry_event,
279 ),
280 ];
281 if let Some(prompt_store) = prompt_store.as_ref() {
282 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
283 }
284
285 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
286 watch::channel(());
287 Self {
288 sessions: HashMap::default(),
289 thread_store,
290 project_context: cx.new(|_| project_context),
291 project_context_needs_refresh: project_context_needs_refresh_tx,
292 _maintain_project_context: cx.spawn(async move |this, cx| {
293 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
294 }),
295 context_server_registry,
296 templates,
297 models: LanguageModels::new(cx),
298 project,
299 prompt_store,
300 fs,
301 _subscriptions: subscriptions,
302 }
303 }))
304 }
305
306 fn register_session(
307 &mut self,
308 thread_handle: Entity<Thread>,
309 cx: &mut Context<Self>,
310 ) -> Entity<AcpThread> {
311 let connection = Rc::new(NativeAgentConnection(cx.entity()));
312
313 let thread = thread_handle.read(cx);
314 let session_id = thread.id().clone();
315 let title = thread.title();
316 let project = thread.project.clone();
317 let action_log = thread.action_log.clone();
318 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
319 let acp_thread = cx.new(|cx| {
320 acp_thread::AcpThread::new(
321 title,
322 connection,
323 project.clone(),
324 action_log.clone(),
325 session_id.clone(),
326 prompt_capabilities_rx,
327 cx,
328 )
329 });
330
331 let registry = LanguageModelRegistry::read_global(cx);
332 let summarization_model = registry.thread_summary_model().map(|c| c.model);
333
334 thread_handle.update(cx, |thread, cx| {
335 thread.set_summarization_model(summarization_model, cx);
336 thread.add_default_tools(
337 Rc::new(AcpThreadEnvironment {
338 acp_thread: acp_thread.downgrade(),
339 }) as _,
340 cx,
341 )
342 });
343
344 let subscriptions = vec![
345 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
346 this.sessions.remove(acp_thread.session_id());
347 }),
348 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
349 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
350 cx.observe(&thread_handle, move |this, thread, cx| {
351 this.save_thread(thread, cx)
352 }),
353 ];
354
355 self.sessions.insert(
356 session_id,
357 Session {
358 thread: thread_handle,
359 acp_thread: acp_thread.downgrade(),
360 _subscriptions: subscriptions,
361 pending_save: Task::ready(()),
362 },
363 );
364
365 self.update_available_commands(cx);
366
367 acp_thread
368 }
369
370 pub fn models(&self) -> &LanguageModels {
371 &self.models
372 }
373
374 async fn maintain_project_context(
375 this: WeakEntity<Self>,
376 mut needs_refresh: watch::Receiver<()>,
377 cx: &mut AsyncApp,
378 ) -> Result<()> {
379 while needs_refresh.changed().await.is_ok() {
380 let project_context = this
381 .update(cx, |this, cx| {
382 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
383 })?
384 .await;
385 this.update(cx, |this, cx| {
386 this.project_context = cx.new(|_| project_context);
387 })?;
388 }
389
390 Ok(())
391 }
392
393 fn build_project_context(
394 project: &Entity<Project>,
395 prompt_store: Option<&Entity<PromptStore>>,
396 cx: &mut App,
397 ) -> Task<ProjectContext> {
398 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
399 let worktree_tasks = worktrees
400 .into_iter()
401 .map(|worktree| {
402 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
403 })
404 .collect::<Vec<_>>();
405 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
406 prompt_store.read_with(cx, |prompt_store, cx| {
407 let prompts = prompt_store.default_prompt_metadata();
408 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
409 let contents = prompt_store.load(prompt_metadata.id, cx);
410 async move { (contents.await, prompt_metadata) }
411 });
412 cx.background_spawn(future::join_all(load_tasks))
413 })
414 } else {
415 Task::ready(vec![])
416 };
417
418 cx.spawn(async move |_cx| {
419 let (worktrees, default_user_rules) =
420 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
421
422 let worktrees = worktrees
423 .into_iter()
424 .map(|(worktree, _rules_error)| {
425 // TODO: show error message
426 // if let Some(rules_error) = rules_error {
427 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
428 // }
429 worktree
430 })
431 .collect::<Vec<_>>();
432
433 let default_user_rules = default_user_rules
434 .into_iter()
435 .flat_map(|(contents, prompt_metadata)| match contents {
436 Ok(contents) => Some(UserRulesContext {
437 uuid: prompt_metadata.id.as_user()?,
438 title: prompt_metadata.title.map(|title| title.to_string()),
439 contents,
440 }),
441 Err(_err) => {
442 // TODO: show error message
443 // this.update(cx, |_, cx| {
444 // cx.emit(RulesLoadingError {
445 // message: format!("{err:?}").into(),
446 // });
447 // })
448 // .ok();
449 None
450 }
451 })
452 .collect::<Vec<_>>();
453
454 ProjectContext::new(worktrees, default_user_rules)
455 })
456 }
457
458 fn load_worktree_info_for_system_prompt(
459 worktree: Entity<Worktree>,
460 project: Entity<Project>,
461 cx: &mut App,
462 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
463 let tree = worktree.read(cx);
464 let root_name = tree.root_name_str().into();
465 let abs_path = tree.abs_path();
466
467 let mut context = WorktreeContext {
468 root_name,
469 abs_path,
470 rules_file: None,
471 };
472
473 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
474 let Some(rules_task) = rules_task else {
475 return Task::ready((context, None));
476 };
477
478 cx.spawn(async move |_| {
479 let (rules_file, rules_file_error) = match rules_task.await {
480 Ok(rules_file) => (Some(rules_file), None),
481 Err(err) => (
482 None,
483 Some(RulesLoadingError {
484 message: format!("{err}").into(),
485 }),
486 ),
487 };
488 context.rules_file = rules_file;
489 (context, rules_file_error)
490 })
491 }
492
493 fn load_worktree_rules_file(
494 worktree: Entity<Worktree>,
495 project: Entity<Project>,
496 cx: &mut App,
497 ) -> Option<Task<Result<RulesFileContext>>> {
498 let worktree = worktree.read(cx);
499 let worktree_id = worktree.id();
500 let selected_rules_file = RULES_FILE_NAMES
501 .into_iter()
502 .filter_map(|name| {
503 worktree
504 .entry_for_path(RelPath::unix(name).unwrap())
505 .filter(|entry| entry.is_file())
506 .map(|entry| entry.path.clone())
507 })
508 .next();
509
510 // Note that Cline supports `.clinerules` being a directory, but that is not currently
511 // supported. This doesn't seem to occur often in GitHub repositories.
512 selected_rules_file.map(|path_in_worktree| {
513 let project_path = ProjectPath {
514 worktree_id,
515 path: path_in_worktree.clone(),
516 };
517 let buffer_task =
518 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
519 let rope_task = cx.spawn(async move |cx| {
520 let buffer = buffer_task.await?;
521 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
522 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
523 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
524 })?;
525 anyhow::Ok((project_entry_id, rope))
526 });
527 // Build a string from the rope on a background thread.
528 cx.background_spawn(async move {
529 let (project_entry_id, rope) = rope_task.await?;
530 anyhow::Ok(RulesFileContext {
531 path_in_worktree,
532 text: rope.to_string().trim().to_string(),
533 project_entry_id: project_entry_id.to_usize(),
534 })
535 })
536 })
537 }
538
539 fn handle_thread_title_updated(
540 &mut self,
541 thread: Entity<Thread>,
542 _: &TitleUpdated,
543 cx: &mut Context<Self>,
544 ) {
545 let session_id = thread.read(cx).id();
546 let Some(session) = self.sessions.get(session_id) else {
547 return;
548 };
549 let thread = thread.downgrade();
550 let acp_thread = session.acp_thread.clone();
551 cx.spawn(async move |_, cx| {
552 let title = thread.read_with(cx, |thread, _| thread.title())?;
553 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
554 task.await
555 })
556 .detach_and_log_err(cx);
557 }
558
559 fn handle_thread_token_usage_updated(
560 &mut self,
561 thread: Entity<Thread>,
562 usage: &TokenUsageUpdated,
563 cx: &mut Context<Self>,
564 ) {
565 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
566 return;
567 };
568 session
569 .acp_thread
570 .update(cx, |acp_thread, cx| {
571 acp_thread.update_token_usage(usage.0.clone(), cx);
572 })
573 .ok();
574 }
575
576 fn handle_project_event(
577 &mut self,
578 _project: Entity<Project>,
579 event: &project::Event,
580 _cx: &mut Context<Self>,
581 ) {
582 match event {
583 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
584 self.project_context_needs_refresh.send(()).ok();
585 }
586 project::Event::WorktreeUpdatedEntries(_, items) => {
587 if items.iter().any(|(path, _, _)| {
588 RULES_FILE_NAMES
589 .iter()
590 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
591 }) {
592 self.project_context_needs_refresh.send(()).ok();
593 }
594 }
595 _ => {}
596 }
597 }
598
599 fn handle_prompts_updated_event(
600 &mut self,
601 _prompt_store: Entity<PromptStore>,
602 _event: &prompt_store::PromptsUpdatedEvent,
603 _cx: &mut Context<Self>,
604 ) {
605 self.project_context_needs_refresh.send(()).ok();
606 }
607
608 fn handle_models_updated_event(
609 &mut self,
610 _registry: Entity<LanguageModelRegistry>,
611 _event: &language_model::Event,
612 cx: &mut Context<Self>,
613 ) {
614 self.models.refresh_list(cx);
615
616 let registry = LanguageModelRegistry::read_global(cx);
617 let default_model = registry.default_model().map(|m| m.model);
618 let summarization_model = registry.thread_summary_model().map(|m| m.model);
619
620 for session in self.sessions.values_mut() {
621 session.thread.update(cx, |thread, cx| {
622 if thread.model().is_none()
623 && let Some(model) = default_model.clone()
624 {
625 thread.set_model(model, cx);
626 cx.notify();
627 }
628 thread.set_summarization_model(summarization_model.clone(), cx);
629 });
630 }
631 }
632
633 fn handle_context_server_store_updated(
634 &mut self,
635 _store: Entity<project::context_server_store::ContextServerStore>,
636 _event: &project::context_server_store::Event,
637 cx: &mut Context<Self>,
638 ) {
639 self.update_available_commands(cx);
640 }
641
642 fn handle_context_server_registry_event(
643 &mut self,
644 _registry: Entity<ContextServerRegistry>,
645 event: &ContextServerRegistryEvent,
646 cx: &mut Context<Self>,
647 ) {
648 match event {
649 ContextServerRegistryEvent::ToolsChanged => {}
650 ContextServerRegistryEvent::PromptsChanged => {
651 self.update_available_commands(cx);
652 }
653 }
654 }
655
656 fn update_available_commands(&self, cx: &mut Context<Self>) {
657 let available_commands = self.build_available_commands(cx);
658 for session in self.sessions.values() {
659 if let Some(acp_thread) = session.acp_thread.upgrade() {
660 acp_thread.update(cx, |thread, cx| {
661 thread
662 .handle_session_update(
663 acp::SessionUpdate::AvailableCommandsUpdate(
664 acp::AvailableCommandsUpdate::new(available_commands.clone()),
665 ),
666 cx,
667 )
668 .log_err();
669 });
670 }
671 }
672 }
673
674 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
675 let registry = self.context_server_registry.read(cx);
676
677 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
678 for context_server_prompt in registry.prompts() {
679 *prompt_name_counts
680 .entry(context_server_prompt.prompt.name.as_str())
681 .or_insert(0) += 1;
682 }
683
684 registry
685 .prompts()
686 .flat_map(|context_server_prompt| {
687 let prompt = &context_server_prompt.prompt;
688
689 let should_prefix = prompt_name_counts
690 .get(prompt.name.as_str())
691 .copied()
692 .unwrap_or(0)
693 > 1;
694
695 let name = if should_prefix {
696 format!("{}.{}", context_server_prompt.server_id, prompt.name)
697 } else {
698 prompt.name.clone()
699 };
700
701 let mut command = acp::AvailableCommand::new(
702 name,
703 prompt.description.clone().unwrap_or_default(),
704 );
705
706 match prompt.arguments.as_deref() {
707 Some([arg]) => {
708 let hint = format!("<{}>", arg.name);
709
710 command = command.input(acp::AvailableCommandInput::Unstructured(
711 acp::UnstructuredCommandInput::new(hint),
712 ));
713 }
714 Some([]) | None => {}
715 Some(_) => {
716 // skip >1 argument commands since we don't support them yet
717 return None;
718 }
719 }
720
721 Some(command)
722 })
723 .collect()
724 }
725
726 pub fn load_thread(
727 &mut self,
728 id: acp::SessionId,
729 cx: &mut Context<Self>,
730 ) -> Task<Result<Entity<Thread>>> {
731 let database_future = ThreadsDatabase::connect(cx);
732 cx.spawn(async move |this, cx| {
733 let database = database_future.await.map_err(|err| anyhow!(err))?;
734 let db_thread = database
735 .load_thread(id.clone())
736 .await?
737 .with_context(|| format!("no thread found with ID: {id:?}"))?;
738
739 this.update(cx, |this, cx| {
740 let summarization_model = LanguageModelRegistry::read_global(cx)
741 .thread_summary_model()
742 .map(|c| c.model);
743
744 cx.new(|cx| {
745 let mut thread = Thread::from_db(
746 id.clone(),
747 db_thread,
748 this.project.clone(),
749 this.project_context.clone(),
750 this.context_server_registry.clone(),
751 this.templates.clone(),
752 cx,
753 );
754 thread.set_summarization_model(summarization_model, cx);
755 thread
756 })
757 })
758 })
759 }
760
761 pub fn open_thread(
762 &mut self,
763 id: acp::SessionId,
764 cx: &mut Context<Self>,
765 ) -> Task<Result<Entity<AcpThread>>> {
766 let task = self.load_thread(id, cx);
767 cx.spawn(async move |this, cx| {
768 let thread = task.await?;
769 let acp_thread =
770 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
771 let events = thread.update(cx, |thread, cx| thread.replay(cx));
772 cx.update(|cx| {
773 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
774 })
775 .await?;
776 Ok(acp_thread)
777 })
778 }
779
780 pub fn thread_summary(
781 &mut self,
782 id: acp::SessionId,
783 cx: &mut Context<Self>,
784 ) -> Task<Result<SharedString>> {
785 let thread = self.open_thread(id.clone(), cx);
786 cx.spawn(async move |this, cx| {
787 let acp_thread = thread.await?;
788 let result = this
789 .update(cx, |this, cx| {
790 this.sessions
791 .get(&id)
792 .unwrap()
793 .thread
794 .update(cx, |thread, cx| thread.summary(cx))
795 })?
796 .await
797 .context("Failed to generate summary")?;
798 drop(acp_thread);
799 Ok(result)
800 })
801 }
802
803 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
804 if thread.read(cx).is_empty() {
805 return;
806 }
807
808 let database_future = ThreadsDatabase::connect(cx);
809 let (id, db_thread) =
810 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
811 let Some(session) = self.sessions.get_mut(&id) else {
812 return;
813 };
814 let thread_store = self.thread_store.clone();
815 session.pending_save = cx.spawn(async move |_, cx| {
816 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
817 return;
818 };
819 let db_thread = db_thread.await;
820 database.save_thread(id, db_thread).await.log_err();
821 thread_store.update(cx, |store, cx| store.reload(cx));
822 });
823 }
824
825 fn send_mcp_prompt(
826 &self,
827 message_id: UserMessageId,
828 session_id: agent_client_protocol::SessionId,
829 prompt_name: String,
830 server_id: ContextServerId,
831 arguments: HashMap<String, String>,
832 original_content: Vec<acp::ContentBlock>,
833 cx: &mut Context<Self>,
834 ) -> Task<Result<acp::PromptResponse>> {
835 let server_store = self.context_server_registry.read(cx).server_store().clone();
836 let path_style = self.project.read(cx).path_style(cx);
837
838 cx.spawn(async move |this, cx| {
839 let prompt =
840 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
841
842 let (acp_thread, thread) = this.update(cx, |this, _cx| {
843 let session = this
844 .sessions
845 .get(&session_id)
846 .context("Failed to get session")?;
847 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
848 })??;
849
850 let mut last_is_user = true;
851
852 thread.update(cx, |thread, cx| {
853 thread.push_acp_user_block(
854 message_id,
855 original_content.into_iter().skip(1),
856 path_style,
857 cx,
858 );
859 });
860
861 for message in prompt.messages {
862 let context_server::types::PromptMessage { role, content } = message;
863 let block = mcp_message_content_to_acp_content_block(content);
864
865 match role {
866 context_server::types::Role::User => {
867 let id = acp_thread::UserMessageId::new();
868
869 acp_thread.update(cx, |acp_thread, cx| {
870 acp_thread.push_user_content_block_with_indent(
871 Some(id.clone()),
872 block.clone(),
873 true,
874 cx,
875 );
876 })?;
877
878 thread.update(cx, |thread, cx| {
879 thread.push_acp_user_block(id, [block], path_style, cx);
880 });
881 }
882 context_server::types::Role::Assistant => {
883 acp_thread.update(cx, |acp_thread, cx| {
884 acp_thread.push_assistant_content_block_with_indent(
885 block.clone(),
886 false,
887 true,
888 cx,
889 );
890 })?;
891
892 thread.update(cx, |thread, cx| {
893 thread.push_acp_agent_block(block, cx);
894 });
895 }
896 }
897
898 last_is_user = role == context_server::types::Role::User;
899 }
900
901 let response_stream = thread.update(cx, |thread, cx| {
902 if last_is_user {
903 thread.send_existing(cx)
904 } else {
905 // Resume if MCP prompt did not end with a user message
906 thread.resume(cx)
907 }
908 })?;
909
910 cx.update(|cx| {
911 NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
912 })
913 .await
914 })
915 }
916}
917
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds the `NativeAgent` entity that owns the sessions this connection
/// operates on.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
921
922impl NativeAgentConnection {
923 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
924 self.0
925 .read(cx)
926 .sessions
927 .get(session_id)
928 .map(|session| session.thread.clone())
929 }
930
931 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
932 self.0.update(cx, |this, cx| this.load_thread(id, cx))
933 }
934
935 fn run_turn(
936 &self,
937 session_id: acp::SessionId,
938 cx: &mut App,
939 f: impl 'static
940 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
941 ) -> Task<Result<acp::PromptResponse>> {
942 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
943 agent
944 .sessions
945 .get_mut(&session_id)
946 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
947 }) else {
948 return Task::ready(Err(anyhow!("Session not found")));
949 };
950 log::debug!("Found session for: {}", session_id);
951
952 let response_stream = match f(thread, cx) {
953 Ok(stream) => stream,
954 Err(err) => return Task::ready(Err(err)),
955 };
956 Self::handle_thread_events(response_stream, acp_thread, cx)
957 }
958
959 fn handle_thread_events(
960 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
961 acp_thread: WeakEntity<AcpThread>,
962 cx: &App,
963 ) -> Task<Result<acp::PromptResponse>> {
964 cx.spawn(async move |cx| {
965 // Handle response stream and forward to session.acp_thread
966 while let Some(result) = events.next().await {
967 match result {
968 Ok(event) => {
969 log::trace!("Received completion event: {:?}", event);
970
971 match event {
972 ThreadEvent::UserMessage(message) => {
973 acp_thread.update(cx, |thread, cx| {
974 for content in message.content {
975 thread.push_user_content_block(
976 Some(message.id.clone()),
977 content.into(),
978 cx,
979 );
980 }
981 })?;
982 }
983 ThreadEvent::AgentText(text) => {
984 acp_thread.update(cx, |thread, cx| {
985 thread.push_assistant_content_block(text.into(), false, cx)
986 })?;
987 }
988 ThreadEvent::AgentThinking(text) => {
989 acp_thread.update(cx, |thread, cx| {
990 thread.push_assistant_content_block(text.into(), true, cx)
991 })?;
992 }
993 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
994 tool_call,
995 options,
996 response,
997 }) => {
998 let outcome_task = acp_thread.update(cx, |thread, cx| {
999 thread.request_tool_call_authorization(
1000 tool_call, options, true, cx,
1001 )
1002 })??;
1003 cx.background_spawn(async move {
1004 if let acp::RequestPermissionOutcome::Selected(
1005 acp::SelectedPermissionOutcome { option_id, .. },
1006 ) = outcome_task.await
1007 {
1008 response
1009 .send(option_id)
1010 .map(|_| anyhow!("authorization receiver was dropped"))
1011 .log_err();
1012 }
1013 })
1014 .detach();
1015 }
1016 ThreadEvent::ToolCall(tool_call) => {
1017 acp_thread.update(cx, |thread, cx| {
1018 thread.upsert_tool_call(tool_call, cx)
1019 })??;
1020 }
1021 ThreadEvent::ToolCallUpdate(update) => {
1022 acp_thread.update(cx, |thread, cx| {
1023 thread.update_tool_call(update, cx)
1024 })??;
1025 }
1026 ThreadEvent::Retry(status) => {
1027 acp_thread.update(cx, |thread, cx| {
1028 thread.update_retry_status(status, cx)
1029 })?;
1030 }
1031 ThreadEvent::Stop(stop_reason) => {
1032 log::debug!("Assistant message complete: {:?}", stop_reason);
1033 return Ok(acp::PromptResponse::new(stop_reason));
1034 }
1035 }
1036 }
1037 Err(e) => {
1038 log::error!("Error in model response stream: {:?}", e);
1039 return Err(e);
1040 }
1041 }
1042 }
1043
1044 log::debug!("Response stream completed");
1045 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1046 })
1047 }
1048}
1049
/// A slash command parsed from the first text block of a prompt, borrowing
/// from that text.
struct Command<'a> {
    /// Name of the prompt to run (the part after the first `.` when the
    /// command contains one).
    prompt_name: &'a str,
    /// Everything after the first whitespace following the command; empty if
    /// there is none.
    arg_value: &'a str,
    /// The part before the first `.` in the command, if the command contains one.
    explicit_server_id: Option<&'a str>,
}
1055
1056impl<'a> Command<'a> {
1057 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1058 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1059 return None;
1060 };
1061 let text = text_content.text.trim();
1062 let command = text.strip_prefix('/')?;
1063 let (command, arg_value) = command
1064 .split_once(char::is_whitespace)
1065 .unwrap_or((command, ""));
1066
1067 if let Some((server_id, prompt_name)) = command.split_once('.') {
1068 Some(Self {
1069 prompt_name,
1070 arg_value,
1071 explicit_server_id: Some(server_id),
1072 })
1073 } else {
1074 Some(Self {
1075 prompt_name: command,
1076 arg_value,
1077 explicit_server_id: None,
1078 })
1079 }
1080 }
1081}
1082
/// Model selector for a single session of the native agent.
struct NativeAgentModelSelector {
    /// ID of the session whose model is read or changed.
    session_id: acp::SessionId,
    /// Connection giving access to the agent's sessions and cached models.
    connection: NativeAgentConnection,
}
1087
1088impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1089 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1090 log::debug!("NativeAgentConnection::list_models called");
1091 let list = self.connection.0.read(cx).models.model_list.clone();
1092 Task::ready(if list.is_empty() {
1093 Err(anyhow::anyhow!("No models available"))
1094 } else {
1095 Ok(list)
1096 })
1097 }
1098
1099 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1100 log::debug!(
1101 "Setting model for session {}: {}",
1102 self.session_id,
1103 model_id
1104 );
1105 let Some(thread) = self
1106 .connection
1107 .0
1108 .read(cx)
1109 .sessions
1110 .get(&self.session_id)
1111 .map(|session| session.thread.clone())
1112 else {
1113 return Task::ready(Err(anyhow!("Session not found")));
1114 };
1115
1116 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1117 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1118 };
1119
1120 thread.update(cx, |thread, cx| {
1121 thread.set_model(model.clone(), cx);
1122 });
1123
1124 update_settings_file(
1125 self.connection.0.read(cx).fs.clone(),
1126 cx,
1127 move |settings, _cx| {
1128 let provider = model.provider_id().0.to_string();
1129 let model = model.id().0.to_string();
1130 settings
1131 .agent
1132 .get_or_insert_default()
1133 .set_model(LanguageModelSelection {
1134 provider: provider.into(),
1135 model,
1136 });
1137 },
1138 );
1139
1140 Task::ready(Ok(()))
1141 }
1142
1143 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1144 let Some(thread) = self
1145 .connection
1146 .0
1147 .read(cx)
1148 .sessions
1149 .get(&self.session_id)
1150 .map(|session| session.thread.clone())
1151 else {
1152 return Task::ready(Err(anyhow!("Session not found")));
1153 };
1154 let Some(model) = thread.read(cx).model() else {
1155 return Task::ready(Err(anyhow!("Model not found")));
1156 };
1157 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1158 else {
1159 return Task::ready(Err(anyhow!("Provider not found")));
1160 };
1161 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1162 model, &provider,
1163 )))
1164 }
1165
1166 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1167 Some(self.connection.0.read(cx).models.watch())
1168 }
1169
1170 fn should_render_footer(&self) -> bool {
1171 true
1172 }
1173}
1174
1175impl acp_thread::AgentConnection for NativeAgentConnection {
1176 fn telemetry_id(&self) -> SharedString {
1177 "zed".into()
1178 }
1179
1180 fn new_thread(
1181 self: Rc<Self>,
1182 project: Entity<Project>,
1183 cwd: &Path,
1184 cx: &mut App,
1185 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1186 let agent = self.0.clone();
1187 log::debug!("Creating new thread for project at: {:?}", cwd);
1188
1189 cx.spawn(async move |cx| {
1190 log::debug!("Starting thread creation in async context");
1191
1192 // Create Thread
1193 let thread = agent.update(cx, |agent, cx| {
1194 // Fetch default model from registry settings
1195 let registry = LanguageModelRegistry::read_global(cx);
1196 // Log available models for debugging
1197 let available_count = registry.available_models(cx).count();
1198 log::debug!("Total available models: {}", available_count);
1199
1200 let default_model = registry.default_model().and_then(|default_model| {
1201 agent
1202 .models
1203 .model_from_id(&LanguageModels::model_id(&default_model.model))
1204 });
1205 cx.new(|cx| {
1206 Thread::new(
1207 project.clone(),
1208 agent.project_context.clone(),
1209 agent.context_server_registry.clone(),
1210 agent.templates.clone(),
1211 default_model,
1212 cx,
1213 )
1214 })
1215 });
1216 Ok(agent.update(cx, |agent, cx| agent.register_session(thread, cx)))
1217 })
1218 }
1219
1220 fn auth_methods(&self) -> &[acp::AuthMethod] {
1221 &[] // No auth for in-process
1222 }
1223
1224 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1225 Task::ready(Ok(()))
1226 }
1227
1228 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1229 Some(Rc::new(NativeAgentModelSelector {
1230 session_id: session_id.clone(),
1231 connection: self.clone(),
1232 }) as Rc<dyn AgentModelSelector>)
1233 }
1234
1235 fn prompt(
1236 &self,
1237 id: Option<acp_thread::UserMessageId>,
1238 params: acp::PromptRequest,
1239 cx: &mut App,
1240 ) -> Task<Result<acp::PromptResponse>> {
1241 let id = id.expect("UserMessageId is required");
1242 let session_id = params.session_id.clone();
1243 log::info!("Received prompt request for session: {}", session_id);
1244 log::debug!("Prompt blocks count: {}", params.prompt.len());
1245
1246 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1247 let registry = self.0.read(cx).context_server_registry.read(cx);
1248
1249 let explicit_server_id = parsed_command
1250 .explicit_server_id
1251 .map(|server_id| ContextServerId(server_id.into()));
1252
1253 if let Some(prompt) =
1254 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1255 {
1256 let arguments = if !parsed_command.arg_value.is_empty()
1257 && let Some(arg_name) = prompt
1258 .prompt
1259 .arguments
1260 .as_ref()
1261 .and_then(|args| args.first())
1262 .map(|arg| arg.name.clone())
1263 {
1264 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1265 } else {
1266 Default::default()
1267 };
1268
1269 let prompt_name = prompt.prompt.name.clone();
1270 let server_id = prompt.server_id.clone();
1271
1272 return self.0.update(cx, |agent, cx| {
1273 agent.send_mcp_prompt(
1274 id,
1275 session_id.clone(),
1276 prompt_name,
1277 server_id,
1278 arguments,
1279 params.prompt,
1280 cx,
1281 )
1282 });
1283 };
1284 };
1285
1286 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1287
1288 self.run_turn(session_id, cx, move |thread, cx| {
1289 let content: Vec<UserMessageContent> = params
1290 .prompt
1291 .into_iter()
1292 .map(|block| UserMessageContent::from_content_block(block, path_style))
1293 .collect::<Vec<_>>();
1294 log::debug!("Converted prompt to message: {} chars", content.len());
1295 log::debug!("Message id: {:?}", id);
1296 log::debug!("Message content: {:?}", content);
1297
1298 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1299 })
1300 }
1301
1302 fn resume(
1303 &self,
1304 session_id: &acp::SessionId,
1305 _cx: &App,
1306 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1307 Some(Rc::new(NativeAgentSessionResume {
1308 connection: self.clone(),
1309 session_id: session_id.clone(),
1310 }) as _)
1311 }
1312
1313 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1314 log::info!("Cancelling on session: {}", session_id);
1315 self.0.update(cx, |agent, cx| {
1316 if let Some(agent) = agent.sessions.get(session_id) {
1317 agent
1318 .thread
1319 .update(cx, |thread, cx| thread.cancel(cx))
1320 .detach();
1321 }
1322 });
1323 }
1324
1325 fn truncate(
1326 &self,
1327 session_id: &agent_client_protocol::SessionId,
1328 cx: &App,
1329 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1330 self.0.read_with(cx, |agent, _cx| {
1331 agent.sessions.get(session_id).map(|session| {
1332 Rc::new(NativeAgentSessionTruncate {
1333 thread: session.thread.clone(),
1334 acp_thread: session.acp_thread.clone(),
1335 }) as _
1336 })
1337 })
1338 }
1339
1340 fn set_title(
1341 &self,
1342 session_id: &acp::SessionId,
1343 _cx: &App,
1344 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1345 Some(Rc::new(NativeAgentSessionSetTitle {
1346 connection: self.clone(),
1347 session_id: session_id.clone(),
1348 }) as _)
1349 }
1350
1351 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1352 let thread_store = self.0.read(cx).thread_store.clone();
1353 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1354 }
1355
1356 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1357 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1358 }
1359
1360 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1361 self
1362 }
1363}
1364
1365impl acp_thread::AgentTelemetry for NativeAgentConnection {
1366 fn thread_data(
1367 &self,
1368 session_id: &acp::SessionId,
1369 cx: &mut App,
1370 ) -> Task<Result<serde_json::Value>> {
1371 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1372 return Task::ready(Err(anyhow!("Session not found")));
1373 };
1374
1375 let task = session.thread.read(cx).to_db(cx);
1376 cx.background_spawn(async move {
1377 serde_json::to_value(task.await).context("Failed to serialize thread")
1378 })
1379 }
1380}
1381
1382pub struct NativeAgentSessionList {
1383 thread_store: Entity<ThreadStore>,
1384 updates_rx: watch::Receiver<()>,
1385 _subscription: Subscription,
1386}
1387
1388impl NativeAgentSessionList {
1389 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1390 let (mut tx, rx) = watch::channel(());
1391 let subscription = cx.observe(&thread_store, move |_, _| {
1392 tx.send(()).ok();
1393 });
1394 Self {
1395 thread_store,
1396 updates_rx: rx,
1397 _subscription: subscription,
1398 }
1399 }
1400
1401 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1402 AgentSessionInfo {
1403 session_id: entry.id,
1404 cwd: None,
1405 title: Some(entry.title),
1406 updated_at: Some(entry.updated_at),
1407 meta: None,
1408 }
1409 }
1410
1411 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1412 &self.thread_store
1413 }
1414}
1415
1416impl AgentSessionList for NativeAgentSessionList {
1417 fn list_sessions(
1418 &self,
1419 _request: AgentSessionListRequest,
1420 cx: &mut App,
1421 ) -> Task<Result<AgentSessionListResponse>> {
1422 let sessions = self
1423 .thread_store
1424 .read(cx)
1425 .entries()
1426 .map(Self::to_session_info)
1427 .collect();
1428 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1429 }
1430
1431 fn supports_delete(&self) -> bool {
1432 true
1433 }
1434
1435 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1436 self.thread_store
1437 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1438 }
1439
1440 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1441 self.thread_store
1442 .update(cx, |store, cx| store.delete_threads(cx))
1443 }
1444
1445 fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
1446 Some(self.updates_rx.clone())
1447 }
1448
1449 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1450 self
1451 }
1452}
1453
1454struct NativeAgentSessionTruncate {
1455 thread: Entity<Thread>,
1456 acp_thread: WeakEntity<AcpThread>,
1457}
1458
1459impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1460 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1461 match self.thread.update(cx, |thread, cx| {
1462 thread.truncate(message_id.clone(), cx)?;
1463 Ok(thread.latest_token_usage())
1464 }) {
1465 Ok(usage) => {
1466 self.acp_thread
1467 .update(cx, |thread, cx| {
1468 thread.update_token_usage(usage, cx);
1469 })
1470 .ok();
1471 Task::ready(Ok(()))
1472 }
1473 Err(error) => Task::ready(Err(error)),
1474 }
1475 }
1476}
1477
1478struct NativeAgentSessionResume {
1479 connection: NativeAgentConnection,
1480 session_id: acp::SessionId,
1481}
1482
1483impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1484 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1485 self.connection
1486 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1487 thread.update(cx, |thread, cx| thread.resume(cx))
1488 })
1489 }
1490}
1491
1492struct NativeAgentSessionSetTitle {
1493 connection: NativeAgentConnection,
1494 session_id: acp::SessionId,
1495}
1496
1497impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1498 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1499 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1500 return Task::ready(Err(anyhow!("session not found")));
1501 };
1502 let thread = session.thread.clone();
1503 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1504 Task::ready(Ok(()))
1505 }
1506}
1507
1508pub struct AcpThreadEnvironment {
1509 acp_thread: WeakEntity<AcpThread>,
1510}
1511
1512impl ThreadEnvironment for AcpThreadEnvironment {
1513 fn create_terminal(
1514 &self,
1515 command: String,
1516 cwd: Option<PathBuf>,
1517 output_byte_limit: Option<u64>,
1518 cx: &mut AsyncApp,
1519 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1520 let task = self.acp_thread.update(cx, |thread, cx| {
1521 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1522 });
1523
1524 let acp_thread = self.acp_thread.clone();
1525 cx.spawn(async move |cx| {
1526 let terminal = task?.await?;
1527
1528 let (drop_tx, drop_rx) = oneshot::channel();
1529 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1530
1531 cx.spawn(async move |cx| {
1532 drop_rx.await.ok();
1533 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1534 })
1535 .detach();
1536
1537 let handle = AcpTerminalHandle {
1538 terminal,
1539 _drop_tx: Some(drop_tx),
1540 };
1541
1542 Ok(Rc::new(handle) as _)
1543 })
1544 }
1545}
1546
1547pub struct AcpTerminalHandle {
1548 terminal: Entity<acp_thread::Terminal>,
1549 _drop_tx: Option<oneshot::Sender<()>>,
1550}
1551
1552impl TerminalHandle for AcpTerminalHandle {
1553 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1554 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1555 }
1556
1557 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1558 Ok(self
1559 .terminal
1560 .read_with(cx, |term, _cx| term.wait_for_exit()))
1561 }
1562
1563 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1564 Ok(self
1565 .terminal
1566 .read_with(cx, |term, cx| term.current_output(cx)))
1567 }
1568
1569 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1570 cx.update(|cx| {
1571 self.terminal.update(cx, |terminal, cx| {
1572 terminal.kill(cx);
1573 });
1574 });
1575 Ok(())
1576 }
1577
1578 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1579 Ok(self
1580 .terminal
1581 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1582 }
1583}
1584
#[cfg(test)]
mod internal_tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// The agent's project context tracks added worktrees and newly created rules files.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        // No worktrees yet, so the context starts out empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        let without_rules = vec![WorktreeContext {
            root_name: "a".into(),
            abs_path: Path::new("/a").into(),
            rules_file: None,
        }];
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, without_rules)
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let with_rules = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(agent.project_context.read(cx).worktrees, with_rules)
        });
    }

    /// The model selector lists the single fake model, grouped under its provider.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent);

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let listed = cx
            .update(|cx| {
                let selector = connection.model_selector(&session_id).unwrap();
                selector.list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(grouped) = listed else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            grouped,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                }]
            )])
        );
    }

    /// Selecting a model updates the session's thread and is written to the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        let initial_settings = json!({
            "agent": {
                "default_model": {
                    "provider": "foo",
                    "model": "bar"
                }
            }
        });
        fs.insert_file(
            paths::settings_file(),
            initial_settings.to_string().into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        let persisted = &settings_json["agent"]["default_model"];
        assert_eq!(persisted["model"], json!("fake"));
        assert_eq!(persisted["provider"], json!("fake"));
    }

    /// A thread with content is saved, its session is dropped with the ACP thread, and it can be
    /// reopened from the store with the same transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        let send_future = acp_thread.update(cx, |thread, cx| {
            let mention = acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                "b.md",
                MentionUri::File {
                    abs_path: path!("/a/b.md").into(),
                }
                .to_uri()
                .to_string(),
            ));
            thread.send(vec!["What does ".into(), mention, " mean?".into()], cx)
        });
        let send_task = cx.foreground_executor().spawn(send_future);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        let summary_title = format!("Explaining {}", path!("/a/b.md"));
        summary_model.send_last_completion_stream_text_chunk(&summary_title);
        summary_model.end_last_completion_stream();

        send_task.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        let expected_markdown = formatdoc! {"
            ## User

            What does [@b.md]({uri}) mean?

            ## Assistant

            Lorem.

        "};
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(thread.to_markdown(cx), expected_markdown)
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(session_id.clone(), summary_title.clone())]
        );
        let reopened = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        reopened.read_with(cx, |thread, cx| {
            assert_eq!(thread.to_markdown(cx), expected_markdown)
        });
    }

    /// Collects `(id, title)` for every entry currently in the thread store.
    fn thread_entries(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(acp::SessionId, String)> {
        thread_store.read_with(cx, |store, _| {
            let mut entries = Vec::new();
            for entry in store.entries() {
                entries.push((entry.id.clone(), entry.title.to_string()));
            }
            entries
        })
    }

    /// Installs a test settings store and the test language model registry; also tries to
    /// initialise logging, ignoring failure.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1962
1963fn mcp_message_content_to_acp_content_block(
1964 content: context_server::types::MessageContent,
1965) -> acp::ContentBlock {
1966 match content {
1967 context_server::types::MessageContent::Text {
1968 text,
1969 annotations: _,
1970 } => text.into(),
1971 context_server::types::MessageContent::Image {
1972 data,
1973 mime_type,
1974 annotations: _,
1975 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
1976 context_server::types::MessageContent::Audio {
1977 data,
1978 mime_type,
1979 annotations: _,
1980 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
1981 context_server::types::MessageContent::Resource {
1982 resource,
1983 annotations: _,
1984 } => {
1985 let mut link =
1986 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
1987 if let Some(mime_type) = resource.mime_type {
1988 link = link.mime_type(mime_type);
1989 }
1990 acp::ContentBlock::ResourceLink(link)
1991 }
1992 }
1993}