1use crate::{
2 ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
3 UserMessageContent, templates::Templates,
4};
5use crate::{HistoryStore, TitleUpdated, TokenUsageUpdated};
6use acp_thread::{AcpThread, AgentModelSelector};
7use action_log::ActionLog;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::mpsc;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::any::Any;
25use std::collections::HashMap;
26use std::path::Path;
27use std::rc::Rc;
28use std::sync::Arc;
29use util::ResultExt;
30
/// File names, relative to a worktree root, that are considered project rules files.
///
/// `load_worktree_rules_file` walks this list in order and uses the first name that resolves to
/// a file entry in the worktree, so earlier entries take precedence over later ones.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
42
/// Describes a failure to load a worktree's rules file.
pub struct RulesLoadingError {
    /// Display text of the underlying error.
    pub message: SharedString,
}
46
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Task for the most recently started save of `thread`; `save_thread` overwrites it every
    /// time it starts a new save
    pending_save: Task<()>,
    /// Subscriptions created in `register_session`, stored alongside the session
    _subscriptions: Vec<Subscription>,
}
56
/// Cached view of the language models offered by authenticated providers.
///
/// The cache is rebuilt by `refresh_list`, which also signals the refresh channel.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; `watch` hands out clones of it
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; `refresh_list` sends on it after each rebuild
    refresh_models_tx: watch::Sender<()>,
    /// Task returned by `authenticate_all_language_model_providers`, stored with the cache
    _authenticate_all_providers_task: Task<()>,
}
66
67impl LanguageModels {
68 fn new(cx: &mut App) -> Self {
69 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
70
71 let mut this = Self {
72 models: HashMap::default(),
73 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
74 refresh_models_rx,
75 refresh_models_tx,
76 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
77 };
78 this.refresh_list(cx);
79 this
80 }
81
82 fn refresh_list(&mut self, cx: &App) {
83 let providers = LanguageModelRegistry::global(cx)
84 .read(cx)
85 .providers()
86 .into_iter()
87 .filter(|provider| provider.is_authenticated(cx))
88 .collect::<Vec<_>>();
89
90 let mut language_model_list = IndexMap::default();
91 let mut recommended_models = HashSet::default();
92
93 let mut recommended = Vec::new();
94 for provider in &providers {
95 for model in provider.recommended_models(cx) {
96 recommended_models.insert((model.provider_id(), model.id()));
97 recommended.push(Self::map_language_model_to_info(&model, provider));
98 }
99 }
100 if !recommended.is_empty() {
101 language_model_list.insert(
102 acp_thread::AgentModelGroupName("Recommended".into()),
103 recommended,
104 );
105 }
106
107 let mut models = HashMap::default();
108 for provider in providers {
109 let mut provider_models = Vec::new();
110 for model in provider.provided_models(cx) {
111 let model_info = Self::map_language_model_to_info(&model, &provider);
112 let model_id = model_info.id.clone();
113 if !recommended_models.contains(&(model.provider_id(), model.id())) {
114 provider_models.push(model_info);
115 }
116 models.insert(model_id, model);
117 }
118 if !provider_models.is_empty() {
119 language_model_list.insert(
120 acp_thread::AgentModelGroupName(provider.name().0.clone()),
121 provider_models,
122 );
123 }
124 }
125
126 self.models = models;
127 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
128 self.refresh_models_tx.send(()).ok();
129 }
130
/// Returns a clone of the receiver for the channel that `refresh_list` sends on.
fn watch(&self) -> watch::Receiver<()> {
    self.refresh_models_rx.clone()
}
134
/// Looks up a cached model by its agent model ID.
///
/// Returns `None` when the ID is not present in the cache built by `refresh_list`.
pub fn model_from_id(
    &self,
    model_id: &acp_thread::AgentModelId,
) -> Option<Arc<dyn LanguageModel>> {
    self.models.get(model_id).cloned()
}
141
142 fn map_language_model_to_info(
143 model: &Arc<dyn LanguageModel>,
144 provider: &Arc<dyn LanguageModelProvider>,
145 ) -> acp_thread::AgentModelInfo {
146 acp_thread::AgentModelInfo {
147 id: Self::model_id(model),
148 name: model.name().0,
149 icon: Some(provider.icon()),
150 }
151 }
152
153 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
154 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
155 }
156
157 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
158 let authenticate_all_providers = LanguageModelRegistry::global(cx)
159 .read(cx)
160 .providers()
161 .iter()
162 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
163 .collect::<Vec<_>>();
164
165 cx.background_spawn(async move {
166 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
167 if let Err(err) = authenticate_task.await {
168 if matches!(err, language_model::AuthenticateError::CredentialsNotFound) {
169 // Since we're authenticating these providers in the
170 // background for the purposes of populating the
171 // language selector, we don't care about providers
172 // where the credentials are not found.
173 } else {
174 // Some providers have noisy failure states that we
175 // don't want to spam the logs with every time the
176 // language model selector is initialized.
177 //
178 // Ideally these should have more clear failure modes
179 // that we know are safe to ignore here, like what we do
180 // with `CredentialsNotFound` above.
181 match provider_id.0.as_ref() {
182 "lmstudio" | "ollama" => {
183 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
184 //
185 // These fail noisily, so we don't log them.
186 }
187 "copilot_chat" => {
188 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
189 }
190 _ => {
191 log::error!(
192 "Failed to authenticate provider: {}: {err}",
193 provider_name.0
194 );
195 }
196 }
197 }
198 }
199 }
200 })
201 }
202}
203
/// The in-process agent: owns the sessions, the shared project context, and the model cache
/// for one project.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// History store that is reloaded after each thread save
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sender used by the event handlers to request a project context rebuild
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`
    _maintain_project_context: Task<Result<()>>,
    /// Registry built from the project's context server store
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on
    project: Entity<Project>,
    /// Optional store providing the default user rules
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle passed to `update_settings_file` when a model is selected
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model registry, and prompt store events
    _subscriptions: Vec<Subscription>,
}
222
223impl NativeAgent {
/// Creates a `NativeAgent` entity for `project`.
///
/// The initial project context is built first; the agent is then constructed with
/// subscriptions to project events, language model registry events, and (when a prompt store
/// is supplied) prompt updates.
///
/// # Errors
///
/// Propagates the errors returned by `cx.update` and `cx.new` on the async app context.
pub async fn new(
    project: Entity<Project>,
    history: Entity<HistoryStore>,
    templates: Arc<Templates>,
    prompt_store: Option<Entity<PromptStore>>,
    fs: Arc<dyn Fs>,
    cx: &mut AsyncApp,
) -> Result<Entity<NativeAgent>> {
    log::debug!("Creating new NativeAgent");

    // Build the first project context before the entity exists so the agent never starts
    // without one.
    let project_context = cx
        .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
        .await;

    cx.new(|cx| {
        let mut subscriptions = vec![
            cx.subscribe(&project, Self::handle_project_event),
            cx.subscribe(
                &LanguageModelRegistry::global(cx),
                Self::handle_models_updated_event,
            ),
        ];
        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
        }

        // The sender stays on the agent; the receiver drives `maintain_project_context`.
        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
            watch::channel(());
        Self {
            sessions: HashMap::new(),
            history,
            project_context: cx.new(|_| project_context),
            project_context_needs_refresh: project_context_needs_refresh_tx,
            _maintain_project_context: cx.spawn(async move |this, cx| {
                Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
            }),
            context_server_registry: cx.new(|cx| {
                ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
            }),
            templates,
            models: LanguageModels::new(cx),
            project,
            prompt_store,
            fs,
            _subscriptions: subscriptions,
        }
    })
}
272
/// Registers `thread_handle` as a session and returns the `AcpThread` created for it.
///
/// Configures the thread's summarization model and default tools, creates the ACP thread, and
/// stores both in `sessions` under the thread's ID together with the subscriptions that keep
/// them in sync.
fn register_session(
    &mut self,
    thread_handle: Entity<Thread>,
    cx: &mut Context<Self>,
) -> Entity<AcpThread> {
    let connection = Rc::new(NativeAgentConnection(cx.entity()));
    let registry = LanguageModelRegistry::read_global(cx);
    let summarization_model = registry.thread_summary_model().map(|c| c.model);

    thread_handle.update(cx, |thread, cx| {
        thread.set_summarization_model(summarization_model, cx);
        thread.add_default_tools(cx)
    });

    let thread = thread_handle.read(cx);
    let session_id = thread.id().clone();
    let title = thread.title();
    let project = thread.project.clone();
    let action_log = thread.action_log.clone();
    let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
    let acp_thread = cx.new(|cx| {
        acp_thread::AcpThread::new(
            title,
            connection,
            project.clone(),
            action_log.clone(),
            session_id.clone(),
            prompt_capabilities_rx,
            cx,
        )
    });
    let subscriptions = vec![
        // Drop the session once the ACP thread entity is released.
        cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
            this.sessions.remove(acp_thread.session_id());
        }),
        cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
        cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
        // Persist the thread whenever it is observed to change.
        cx.observe(&thread_handle, move |this, thread, cx| {
            this.save_thread(thread, cx)
        }),
    ];

    // Only a weak handle to the ACP thread is stored in the session.
    self.sessions.insert(
        session_id,
        Session {
            thread: thread_handle,
            acp_thread: acp_thread.downgrade(),
            _subscriptions: subscriptions,
            pending_save: Task::ready(()),
        },
    );
    acp_thread
}
326
/// Returns the agent's cached language model information.
pub fn models(&self) -> &LanguageModels {
    &self.models
}
330
/// Rebuilds the project context each time `needs_refresh` reports a change.
///
/// Loops for as long as `needs_refresh.changed()` succeeds and returns `Ok(())` once it fails.
///
/// # Errors
///
/// Returns early with the error from `this.update` if the weak agent handle can no longer be
/// updated.
async fn maintain_project_context(
    this: WeakEntity<Self>,
    mut needs_refresh: watch::Receiver<()>,
    cx: &mut AsyncApp,
) -> Result<()> {
    while needs_refresh.changed().await.is_ok() {
        let project_context = this
            .update(cx, |this, cx| {
                Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
            })?
            .await;
        // NOTE(review): this stores a brand-new entity instead of updating the existing one.
        // Threads created earlier were handed a clone of the previous entity, so they
        // presumably keep seeing the old context — confirm that this is intended.
        this.update(cx, |this, cx| {
            this.project_context = cx.new(|_| project_context);
        })?;
    }

    Ok(())
}
349
350 fn build_project_context(
351 project: &Entity<Project>,
352 prompt_store: Option<&Entity<PromptStore>>,
353 cx: &mut App,
354 ) -> Task<ProjectContext> {
355 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
356 let worktree_tasks = worktrees
357 .into_iter()
358 .map(|worktree| {
359 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
360 })
361 .collect::<Vec<_>>();
362 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
363 prompt_store.read_with(cx, |prompt_store, cx| {
364 let prompts = prompt_store.default_prompt_metadata();
365 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
366 let contents = prompt_store.load(prompt_metadata.id, cx);
367 async move { (contents.await, prompt_metadata) }
368 });
369 cx.background_spawn(future::join_all(load_tasks))
370 })
371 } else {
372 Task::ready(vec![])
373 };
374
375 cx.spawn(async move |_cx| {
376 let (worktrees, default_user_rules) =
377 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
378
379 let worktrees = worktrees
380 .into_iter()
381 .map(|(worktree, _rules_error)| {
382 // TODO: show error message
383 // if let Some(rules_error) = rules_error {
384 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
385 // }
386 worktree
387 })
388 .collect::<Vec<_>>();
389
390 let default_user_rules = default_user_rules
391 .into_iter()
392 .flat_map(|(contents, prompt_metadata)| match contents {
393 Ok(contents) => Some(UserRulesContext {
394 uuid: match prompt_metadata.id {
395 PromptId::User { uuid } => uuid,
396 PromptId::EditWorkflow => return None,
397 },
398 title: prompt_metadata.title.map(|title| title.to_string()),
399 contents,
400 }),
401 Err(_err) => {
402 // TODO: show error message
403 // this.update(cx, |_, cx| {
404 // cx.emit(RulesLoadingError {
405 // message: format!("{err:?}").into(),
406 // });
407 // })
408 // .ok();
409 None
410 }
411 })
412 .collect::<Vec<_>>();
413
414 ProjectContext::new(worktrees, default_user_rules)
415 })
416 }
417
418 fn load_worktree_info_for_system_prompt(
419 worktree: Entity<Worktree>,
420 project: Entity<Project>,
421 cx: &mut App,
422 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
423 let tree = worktree.read(cx);
424 let root_name = tree.root_name().into();
425 let abs_path = tree.abs_path();
426
427 let mut context = WorktreeContext {
428 root_name,
429 abs_path,
430 rules_file: None,
431 };
432
433 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
434 let Some(rules_task) = rules_task else {
435 return Task::ready((context, None));
436 };
437
438 cx.spawn(async move |_| {
439 let (rules_file, rules_file_error) = match rules_task.await {
440 Ok(rules_file) => (Some(rules_file), None),
441 Err(err) => (
442 None,
443 Some(RulesLoadingError {
444 message: format!("{err}").into(),
445 }),
446 ),
447 };
448 context.rules_file = rules_file;
449 (context, rules_file_error)
450 })
451 }
452
453 fn load_worktree_rules_file(
454 worktree: Entity<Worktree>,
455 project: Entity<Project>,
456 cx: &mut App,
457 ) -> Option<Task<Result<RulesFileContext>>> {
458 let worktree = worktree.read(cx);
459 let worktree_id = worktree.id();
460 let selected_rules_file = RULES_FILE_NAMES
461 .into_iter()
462 .filter_map(|name| {
463 worktree
464 .entry_for_path(name)
465 .filter(|entry| entry.is_file())
466 .map(|entry| entry.path.clone())
467 })
468 .next();
469
470 // Note that Cline supports `.clinerules` being a directory, but that is not currently
471 // supported. This doesn't seem to occur often in GitHub repositories.
472 selected_rules_file.map(|path_in_worktree| {
473 let project_path = ProjectPath {
474 worktree_id,
475 path: path_in_worktree.clone(),
476 };
477 let buffer_task =
478 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
479 let rope_task = cx.spawn(async move |cx| {
480 buffer_task.await?.read_with(cx, |buffer, cx| {
481 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
482 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
483 })?
484 });
485 // Build a string from the rope on a background thread.
486 cx.background_spawn(async move {
487 let (project_entry_id, rope) = rope_task.await?;
488 anyhow::Ok(RulesFileContext {
489 path_in_worktree,
490 text: rope.to_string().trim().to_string(),
491 project_entry_id: project_entry_id.to_usize(),
492 })
493 })
494 })
495 }
496
497 fn handle_thread_title_updated(
498 &mut self,
499 thread: Entity<Thread>,
500 _: &TitleUpdated,
501 cx: &mut Context<Self>,
502 ) {
503 let session_id = thread.read(cx).id();
504 let Some(session) = self.sessions.get(session_id) else {
505 return;
506 };
507 let thread = thread.downgrade();
508 let acp_thread = session.acp_thread.clone();
509 cx.spawn(async move |_, cx| {
510 let title = thread.read_with(cx, |thread, _| thread.title())?;
511 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
512 task.await
513 })
514 .detach_and_log_err(cx);
515 }
516
/// Forwards a thread's updated token usage to its ACP thread.
///
/// Does nothing when the thread has no registered session; a failure to update the ACP thread
/// through its weak handle is ignored.
fn handle_thread_token_usage_updated(
    &mut self,
    thread: Entity<Thread>,
    usage: &TokenUsageUpdated,
    cx: &mut Context<Self>,
) {
    let Some(session) = self.sessions.get(thread.read(cx).id()) else {
        return;
    };
    session
        .acp_thread
        .update(cx, |acp_thread, cx| {
            acp_thread.update_token_usage(usage.0.clone(), cx);
        })
        .ok();
}
533
534 fn handle_project_event(
535 &mut self,
536 _project: Entity<Project>,
537 event: &project::Event,
538 _cx: &mut Context<Self>,
539 ) {
540 match event {
541 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
542 self.project_context_needs_refresh.send(()).ok();
543 }
544 project::Event::WorktreeUpdatedEntries(_, items) => {
545 if items.iter().any(|(path, _, _)| {
546 RULES_FILE_NAMES
547 .iter()
548 .any(|name| path.as_ref() == Path::new(name))
549 }) {
550 self.project_context_needs_refresh.send(()).ok();
551 }
552 }
553 _ => {}
554 }
555 }
556
/// Requests a project context rebuild whenever the prompt store reports updated prompts.
fn handle_prompts_updated_event(
    &mut self,
    _prompt_store: Entity<PromptStore>,
    _event: &prompt_store::PromptsUpdatedEvent,
    _cx: &mut Context<Self>,
) {
    self.project_context_needs_refresh.send(()).ok();
}
565
/// Reacts to a language model registry event.
///
/// Refreshes the model cache, then walks every session: threads without a model receive the
/// registry's default model (when there is one), and every thread's summarization model is
/// reset to the registry's current thread summary model.
fn handle_models_updated_event(
    &mut self,
    _registry: Entity<LanguageModelRegistry>,
    _event: &language_model::Event,
    cx: &mut Context<Self>,
) {
    self.models.refresh_list(cx);

    let registry = LanguageModelRegistry::read_global(cx);
    let default_model = registry.default_model().map(|m| m.model);
    let summarization_model = registry.thread_summary_model().map(|m| m.model);

    for session in self.sessions.values_mut() {
        session.thread.update(cx, |thread, cx| {
            // Only fill in a model for threads that do not have one yet.
            if thread.model().is_none()
                && let Some(model) = default_model.clone()
            {
                thread.set_model(model, cx);
                cx.notify();
            }
            thread.set_summarization_model(summarization_model.clone(), cx);
        });
    }
}
590
/// Loads the thread stored under `id`, registers it as a session, and replays its events into
/// the newly created ACP thread.
///
/// # Errors
///
/// The returned task fails when the database connection fails, when no thread is stored under
/// `id`, when the agent can no longer be updated, or when replaying the events fails.
pub fn open_thread(
    &mut self,
    id: acp::SessionId,
    cx: &mut Context<Self>,
) -> Task<Result<Entity<AcpThread>>> {
    let database_future = ThreadsDatabase::connect(cx);
    cx.spawn(async move |this, cx| {
        let database = database_future.await.map_err(|err| anyhow!(err))?;
        let db_thread = database
            .load_thread(id.clone())
            .await?
            .with_context(|| format!("no thread found with ID: {id:?}"))?;

        // Rebuild the in-memory thread from its stored form.
        let thread = this.update(cx, |this, cx| {
            let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
            cx.new(|cx| {
                Thread::from_db(
                    id.clone(),
                    db_thread,
                    this.project.clone(),
                    this.project_context.clone(),
                    this.context_server_registry.clone(),
                    action_log.clone(),
                    this.templates.clone(),
                    cx,
                )
            })
        })?;
        let acp_thread =
            this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
        // Feed the thread's replayed events through the same path live events take.
        let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
        cx.update(|cx| {
            NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
        })?
        .await?;
        Ok(acp_thread)
    })
}
629
630 pub fn thread_summary(
631 &mut self,
632 id: acp::SessionId,
633 cx: &mut Context<Self>,
634 ) -> Task<Result<SharedString>> {
635 let thread = self.open_thread(id.clone(), cx);
636 cx.spawn(async move |this, cx| {
637 let acp_thread = thread.await?;
638 let result = this
639 .update(cx, |this, cx| {
640 this.sessions
641 .get(&id)
642 .unwrap()
643 .thread
644 .update(cx, |thread, cx| thread.summary(cx))
645 })?
646 .await?;
647 drop(acp_thread);
648 Ok(result)
649 })
650 }
651
/// Starts saving `thread` to the threads database and reloads the history afterwards.
///
/// Empty threads and threads without a registered session are skipped. The save task is
/// stored in the session's `pending_save`, replacing whatever task was stored there before.
/// Connection and save failures are logged rather than returned.
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
    if thread.read(cx).is_empty() {
        return;
    }

    let database_future = ThreadsDatabase::connect(cx);
    let (id, db_thread) =
        thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
    let Some(session) = self.sessions.get_mut(&id) else {
        return;
    };
    let history = self.history.clone();
    session.pending_save = cx.spawn(async move |_, cx| {
        let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
            return;
        };
        let db_thread = db_thread.await;
        database.save_thread(id, db_thread).await.log_err();
        // The history reload runs even when the save itself logged an error.
        history.update(cx, |history, cx| history.reload(cx)).ok();
    });
}
673}
674
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds the `NativeAgent` entity whose sessions the connection operates on.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
678
679impl NativeAgentConnection {
/// Returns the internal thread of the session with `session_id`, or `None` when no such
/// session is registered.
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
    self.0
        .read(cx)
        .sessions
        .get(session_id)
        .map(|session| session.thread.clone())
}
687
/// Runs one turn on the session's thread and forwards the resulting events to its ACP thread.
///
/// `f` starts the turn on the thread and returns the event stream, which is then handed to
/// `handle_thread_events`.
///
/// # Errors
///
/// The returned task fails immediately when the session does not exist or when `f` returns an
/// error; otherwise it resolves with the outcome of `handle_thread_events`.
fn run_turn(
    &self,
    session_id: acp::SessionId,
    cx: &mut App,
    f: impl 'static
    + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
) -> Task<Result<acp::PromptResponse>> {
    let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
        agent
            .sessions
            .get_mut(&session_id)
            .map(|s| (s.thread.clone(), s.acp_thread.clone()))
    }) else {
        return Task::ready(Err(anyhow!("Session not found")));
    };
    log::debug!("Found session for: {}", session_id);

    let response_stream = match f(thread, cx) {
        Ok(stream) => stream,
        Err(err) => return Task::ready(Err(err)),
    };
    Self::handle_thread_events(response_stream, acp_thread, cx)
}
711
712 fn handle_thread_events(
713 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
714 acp_thread: WeakEntity<AcpThread>,
715 cx: &App,
716 ) -> Task<Result<acp::PromptResponse>> {
717 cx.spawn(async move |cx| {
718 // Handle response stream and forward to session.acp_thread
719 while let Some(result) = events.next().await {
720 match result {
721 Ok(event) => {
722 log::trace!("Received completion event: {:?}", event);
723
724 match event {
725 ThreadEvent::UserMessage(message) => {
726 acp_thread.update(cx, |thread, cx| {
727 for content in message.content {
728 thread.push_user_content_block(
729 Some(message.id.clone()),
730 content.into(),
731 cx,
732 );
733 }
734 })?;
735 }
736 ThreadEvent::AgentText(text) => {
737 acp_thread.update(cx, |thread, cx| {
738 thread.push_assistant_content_block(
739 acp::ContentBlock::Text(acp::TextContent {
740 text,
741 annotations: None,
742 }),
743 false,
744 cx,
745 )
746 })?;
747 }
748 ThreadEvent::AgentThinking(text) => {
749 acp_thread.update(cx, |thread, cx| {
750 thread.push_assistant_content_block(
751 acp::ContentBlock::Text(acp::TextContent {
752 text,
753 annotations: None,
754 }),
755 true,
756 cx,
757 )
758 })?;
759 }
760 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
761 tool_call,
762 options,
763 response,
764 }) => {
765 let outcome_task = acp_thread.update(cx, |thread, cx| {
766 thread.request_tool_call_authorization(tool_call, options, cx)
767 })??;
768 cx.background_spawn(async move {
769 if let acp::RequestPermissionOutcome::Selected { option_id } =
770 outcome_task.await
771 {
772 response
773 .send(option_id)
774 .map(|_| anyhow!("authorization receiver was dropped"))
775 .log_err();
776 }
777 })
778 .detach();
779 }
780 ThreadEvent::ToolCall(tool_call) => {
781 acp_thread.update(cx, |thread, cx| {
782 thread.upsert_tool_call(tool_call, cx)
783 })??;
784 }
785 ThreadEvent::ToolCallUpdate(update) => {
786 acp_thread.update(cx, |thread, cx| {
787 thread.update_tool_call(update, cx)
788 })??;
789 }
790 ThreadEvent::Retry(status) => {
791 acp_thread.update(cx, |thread, cx| {
792 thread.update_retry_status(status, cx)
793 })?;
794 }
795 ThreadEvent::Stop(stop_reason) => {
796 log::debug!("Assistant message complete: {:?}", stop_reason);
797 return Ok(acp::PromptResponse { stop_reason });
798 }
799 }
800 }
801 Err(e) => {
802 log::error!("Error in model response stream: {:?}", e);
803 return Err(e);
804 }
805 }
806 }
807
808 log::debug!("Response stream completed");
809 anyhow::Ok(acp::PromptResponse {
810 stop_reason: acp::StopReason::EndTurn,
811 })
812 })
813 }
814}
815
816impl AgentModelSelector for NativeAgentConnection {
817 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
818 log::debug!("NativeAgentConnection::list_models called");
819 let list = self.0.read(cx).models.model_list.clone();
820 Task::ready(if list.is_empty() {
821 Err(anyhow::anyhow!("No models available"))
822 } else {
823 Ok(list)
824 })
825 }
826
827 fn select_model(
828 &self,
829 session_id: acp::SessionId,
830 model_id: acp_thread::AgentModelId,
831 cx: &mut App,
832 ) -> Task<Result<()>> {
833 log::debug!("Setting model for session {}: {}", session_id, model_id);
834 let Some(thread) = self
835 .0
836 .read(cx)
837 .sessions
838 .get(&session_id)
839 .map(|session| session.thread.clone())
840 else {
841 return Task::ready(Err(anyhow!("Session not found")));
842 };
843
844 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
845 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
846 };
847
848 thread.update(cx, |thread, cx| {
849 thread.set_model(model.clone(), cx);
850 });
851
852 update_settings_file::<AgentSettings>(
853 self.0.read(cx).fs.clone(),
854 cx,
855 move |settings, _cx| {
856 settings.set_model(model);
857 },
858 );
859
860 Task::ready(Ok(()))
861 }
862
863 fn selected_model(
864 &self,
865 session_id: &acp::SessionId,
866 cx: &mut App,
867 ) -> Task<Result<acp_thread::AgentModelInfo>> {
868 let session_id = session_id.clone();
869
870 let Some(thread) = self
871 .0
872 .read(cx)
873 .sessions
874 .get(&session_id)
875 .map(|session| session.thread.clone())
876 else {
877 return Task::ready(Err(anyhow!("Session not found")));
878 };
879 let Some(model) = thread.read(cx).model() else {
880 return Task::ready(Err(anyhow!("Model not found")));
881 };
882 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
883 else {
884 return Task::ready(Err(anyhow!("Provider not found")));
885 };
886 Task::ready(Ok(LanguageModels::map_language_model_to_info(
887 model, &provider,
888 )))
889 }
890
/// Returns a receiver for the agent's model refresh channel (see `LanguageModels::watch`).
fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
    self.0.read(cx).models.watch()
}
894}
895
896impl acp_thread::AgentConnection for NativeAgentConnection {
897 fn new_thread(
898 self: Rc<Self>,
899 project: Entity<Project>,
900 cwd: &Path,
901 cx: &mut App,
902 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
903 let agent = self.0.clone();
904 log::debug!("Creating new thread for project at: {:?}", cwd);
905
906 cx.spawn(async move |cx| {
907 log::debug!("Starting thread creation in async context");
908
909 // Create Thread
910 let thread = agent.update(
911 cx,
912 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
913 // Fetch default model from registry settings
914 let registry = LanguageModelRegistry::read_global(cx);
915 // Log available models for debugging
916 let available_count = registry.available_models(cx).count();
917 log::debug!("Total available models: {}", available_count);
918
919 let default_model = registry.default_model().and_then(|default_model| {
920 agent
921 .models
922 .model_from_id(&LanguageModels::model_id(&default_model.model))
923 });
924 Ok(cx.new(|cx| {
925 Thread::new(
926 project.clone(),
927 agent.project_context.clone(),
928 agent.context_server_registry.clone(),
929 agent.templates.clone(),
930 default_model,
931 cx,
932 )
933 }))
934 },
935 )??;
936 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
937 })
938 }
939
/// Returns the supported authentication methods, which is always an empty slice here.
fn auth_methods(&self) -> &[acp::AuthMethod] {
    &[] // No auth for in-process
}
943
/// Always succeeds immediately; the requested method is ignored.
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
    Task::ready(Ok(()))
}
947
/// Returns a clone of this connection as the model selector.
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
    Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
}
951
952 fn prompt(
953 &self,
954 id: Option<acp_thread::UserMessageId>,
955 params: acp::PromptRequest,
956 cx: &mut App,
957 ) -> Task<Result<acp::PromptResponse>> {
958 let id = id.expect("UserMessageId is required");
959 let session_id = params.session_id.clone();
960 log::info!("Received prompt request for session: {}", session_id);
961 log::debug!("Prompt blocks count: {}", params.prompt.len());
962
963 self.run_turn(session_id, cx, |thread, cx| {
964 let content: Vec<UserMessageContent> = params
965 .prompt
966 .into_iter()
967 .map(Into::into)
968 .collect::<Vec<_>>();
969 log::debug!("Converted prompt to message: {} chars", content.len());
970 log::debug!("Message id: {:?}", id);
971 log::debug!("Message content: {:?}", content);
972
973 thread.update(cx, |thread, cx| thread.send(id, content, cx))
974 })
975 }
976
/// Returns a handle that resumes the given session.
///
/// The handle is always returned; whether the session exists is only checked when the handle
/// is run.
fn resume(
    &self,
    session_id: &acp::SessionId,
    _cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
    Some(Rc::new(NativeAgentSessionResume {
        connection: self.clone(),
        session_id: session_id.clone(),
    }) as _)
}
987
988 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
989 log::info!("Cancelling on session: {}", session_id);
990 self.0.update(cx, |agent, cx| {
991 if let Some(agent) = agent.sessions.get(session_id) {
992 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
993 }
994 });
995 }
996
/// Returns a handle for truncating the given session, or `None` when the session is not
/// registered.
fn truncate(
    &self,
    session_id: &agent_client_protocol::SessionId,
    cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
    self.0.read_with(cx, |agent, _cx| {
        agent.sessions.get(session_id).map(|session| {
            Rc::new(NativeAgentSessionEditor {
                thread: session.thread.clone(),
                acp_thread: session.acp_thread.clone(),
            }) as _
        })
    })
}
1011
/// Returns a handle that sets the title of the given session.
///
/// The handle is always returned; whether the session exists is only checked when the handle
/// is run.
fn set_title(
    &self,
    session_id: &acp::SessionId,
    _cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
    Some(Rc::new(NativeAgentSessionSetTitle {
        connection: self.clone(),
        session_id: session_id.clone(),
    }) as _)
}
1022
/// Returns a clone of this connection as the telemetry provider.
fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
    Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
}
1026
/// Converts the connection into an `Rc<dyn Any>`.
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
    self
}
1030}
1031
impl acp_thread::AgentTelemetry for NativeAgentConnection {
    /// Returns the agent name reported for telemetry, `"Zed"`.
    fn agent_name(&self) -> String {
        "Zed".into()
    }

    /// Serializes the session's thread, in the form produced by `to_db`, to a JSON value.
    ///
    /// # Errors
    ///
    /// The returned task fails when the session does not exist or when serialization fails.
    fn thread_data(
        &self,
        session_id: &acp::SessionId,
        cx: &mut App,
    ) -> Task<Result<serde_json::Value>> {
        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
            return Task::ready(Err(anyhow!("Session not found")));
        };

        let task = session.thread.read(cx).to_db(cx);
        // Await the `to_db` task and serialize its result in a background task.
        cx.background_spawn(async move {
            serde_json::to_value(task.await).context("Failed to serialize thread")
        })
    }
}
1052
/// Handle returned by `NativeAgentConnection::truncate` for truncating one session.
struct NativeAgentSessionEditor {
    /// The internal thread to truncate
    thread: Entity<Thread>,
    /// The ACP thread that receives the token usage after truncation
    acp_thread: WeakEntity<AcpThread>,
}
1057
1058impl acp_thread::AgentSessionTruncate for NativeAgentSessionEditor {
1059 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1060 match self.thread.update(cx, |thread, cx| {
1061 thread.truncate(message_id.clone(), cx)?;
1062 Ok(thread.latest_token_usage())
1063 }) {
1064 Ok(usage) => {
1065 self.acp_thread
1066 .update(cx, |thread, cx| {
1067 thread.update_token_usage(usage, cx);
1068 })
1069 .ok();
1070 Task::ready(Ok(()))
1071 }
1072 Err(error) => Task::ready(Err(error)),
1073 }
1074 }
1075}
1076
/// Handle returned by `NativeAgentConnection::resume` for resuming one session.
struct NativeAgentSessionResume {
    /// Connection used to run the resumed turn
    connection: NativeAgentConnection,
    /// ID of the session to resume
    session_id: acp::SessionId,
}
1081
impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
    /// Resumes the session's thread and runs the resulting turn through `run_turn`, so the
    /// returned task fails the same way `run_turn` does (for example when the session does
    /// not exist).
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
        self.connection
            .run_turn(self.session_id.clone(), cx, |thread, cx| {
                thread.update(cx, |thread, cx| thread.resume(cx))
            })
    }
}
1090
/// Handle returned by `NativeAgentConnection::set_title` for retitling one session.
struct NativeAgentSessionSetTitle {
    /// Connection used to look up the session
    connection: NativeAgentConnection,
    /// ID of the session whose title is set
    session_id: acp::SessionId,
}
1095
1096impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1097 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1098 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1099 return Task::ready(Err(anyhow!("session not found")));
1100 };
1101 let thread = session.thread.clone();
1102 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1103 Task::ready(Ok(()))
1104 }
1105}
1106
1107#[cfg(test)]
1108mod tests {
1109 use crate::HistoryEntryId;
1110
1111 use super::*;
1112 use acp_thread::{
1113 AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
1114 };
1115 use fs::FakeFs;
1116 use gpui::TestAppContext;
1117 use indoc::indoc;
1118 use language_model::fake_provider::FakeLanguageModel;
1119 use serde_json::json;
1120 use settings::SettingsStore;
1121 use util::path;
1122
1123 #[gpui::test]
1124 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1125 init_test(cx);
1126 let fs = FakeFs::new(cx.executor());
1127 fs.insert_tree(
1128 "/",
1129 json!({
1130 "a": {}
1131 }),
1132 )
1133 .await;
1134 let project = Project::test(fs.clone(), [], cx).await;
1135 let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1136 let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1137 let agent = NativeAgent::new(
1138 project.clone(),
1139 history_store,
1140 Templates::new(),
1141 None,
1142 fs.clone(),
1143 &mut cx.to_async(),
1144 )
1145 .await
1146 .unwrap();
1147 agent.read_with(cx, |agent, cx| {
1148 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1149 });
1150
1151 let worktree = project
1152 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1153 .await
1154 .unwrap();
1155 cx.run_until_parked();
1156 agent.read_with(cx, |agent, cx| {
1157 assert_eq!(
1158 agent.project_context.read(cx).worktrees,
1159 vec![WorktreeContext {
1160 root_name: "a".into(),
1161 abs_path: Path::new("/a").into(),
1162 rules_file: None
1163 }]
1164 )
1165 });
1166
1167 // Creating `/a/.rules` updates the project context.
1168 fs.insert_file("/a/.rules", Vec::new()).await;
1169 cx.run_until_parked();
1170 agent.read_with(cx, |agent, cx| {
1171 let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
1172 assert_eq!(
1173 agent.project_context.read(cx).worktrees,
1174 vec![WorktreeContext {
1175 root_name: "a".into(),
1176 abs_path: Path::new("/a").into(),
1177 rules_file: Some(RulesFileContext {
1178 path_in_worktree: Path::new(".rules").into(),
1179 text: "".into(),
1180 project_entry_id: rules_entry.id.to_usize()
1181 })
1182 }]
1183 )
1184 });
1185 }
1186
1187 #[gpui::test]
1188 async fn test_listing_models(cx: &mut TestAppContext) {
1189 init_test(cx);
1190 let fs = FakeFs::new(cx.executor());
1191 fs.insert_tree("/", json!({ "a": {} })).await;
1192 let project = Project::test(fs.clone(), [], cx).await;
1193 let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1194 let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1195 let connection = NativeAgentConnection(
1196 NativeAgent::new(
1197 project.clone(),
1198 history_store,
1199 Templates::new(),
1200 None,
1201 fs.clone(),
1202 &mut cx.to_async(),
1203 )
1204 .await
1205 .unwrap(),
1206 );
1207
1208 let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
1209
1210 let acp_thread::AgentModelList::Grouped(models) = models else {
1211 panic!("Unexpected model group");
1212 };
1213 assert_eq!(
1214 models,
1215 IndexMap::from_iter([(
1216 AgentModelGroupName("Fake".into()),
1217 vec![AgentModelInfo {
1218 id: AgentModelId("fake/fake".into()),
1219 name: "Fake".into(),
1220 icon: Some(ui::IconName::ZedAssistant),
1221 }]
1222 )])
1223 );
1224 }
1225
1226 #[gpui::test]
1227 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1228 init_test(cx);
1229 let fs = FakeFs::new(cx.executor());
1230 fs.create_dir(paths::settings_file().parent().unwrap())
1231 .await
1232 .unwrap();
1233 fs.insert_file(
1234 paths::settings_file(),
1235 json!({
1236 "agent": {
1237 "default_model": {
1238 "provider": "foo",
1239 "model": "bar"
1240 }
1241 }
1242 })
1243 .to_string()
1244 .into_bytes(),
1245 )
1246 .await;
1247 let project = Project::test(fs.clone(), [], cx).await;
1248
1249 let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1250 let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1251
1252 // Create the agent and connection
1253 let agent = NativeAgent::new(
1254 project.clone(),
1255 history_store,
1256 Templates::new(),
1257 None,
1258 fs.clone(),
1259 &mut cx.to_async(),
1260 )
1261 .await
1262 .unwrap();
1263 let connection = NativeAgentConnection(agent.clone());
1264
1265 // Create a thread/session
1266 let acp_thread = cx
1267 .update(|cx| {
1268 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1269 })
1270 .await
1271 .unwrap();
1272
1273 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1274
1275 // Select a model
1276 let model_id = AgentModelId("fake/fake".into());
1277 cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
1278 .await
1279 .unwrap();
1280
1281 // Verify the thread has the selected model
1282 agent.read_with(cx, |agent, _| {
1283 let session = agent.sessions.get(&session_id).unwrap();
1284 session.thread.read_with(cx, |thread, _| {
1285 assert_eq!(thread.model().unwrap().id().0, "fake");
1286 });
1287 });
1288
1289 cx.run_until_parked();
1290
1291 // Verify settings file was updated
1292 let settings_content = fs.load(paths::settings_file()).await.unwrap();
1293 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1294
1295 // Check that the agent settings contain the selected model
1296 assert_eq!(
1297 settings_json["agent"]["default_model"]["model"],
1298 json!("fake")
1299 );
1300 assert_eq!(
1301 settings_json["agent"]["default_model"]["provider"],
1302 json!("fake")
1303 );
1304 }
1305
1306 #[gpui::test]
1307 #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
1308 async fn test_save_load_thread(cx: &mut TestAppContext) {
1309 init_test(cx);
1310 let fs = FakeFs::new(cx.executor());
1311 fs.insert_tree(
1312 "/",
1313 json!({
1314 "a": {
1315 "b.md": "Lorem"
1316 }
1317 }),
1318 )
1319 .await;
1320 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
1321 let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1322 let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1323 let agent = NativeAgent::new(
1324 project.clone(),
1325 history_store.clone(),
1326 Templates::new(),
1327 None,
1328 fs.clone(),
1329 &mut cx.to_async(),
1330 )
1331 .await
1332 .unwrap();
1333 let connection = Rc::new(NativeAgentConnection(agent.clone()));
1334
1335 let acp_thread = cx
1336 .update(|cx| {
1337 connection
1338 .clone()
1339 .new_thread(project.clone(), Path::new(""), cx)
1340 })
1341 .await
1342 .unwrap();
1343 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1344 let thread = agent.read_with(cx, |agent, _| {
1345 agent.sessions.get(&session_id).unwrap().thread.clone()
1346 });
1347
1348 // Ensure empty threads are not saved, even if they get mutated.
1349 let model = Arc::new(FakeLanguageModel::default());
1350 let summary_model = Arc::new(FakeLanguageModel::default());
1351 thread.update(cx, |thread, cx| {
1352 thread.set_model(model.clone(), cx);
1353 thread.set_summarization_model(Some(summary_model.clone()), cx);
1354 });
1355 cx.run_until_parked();
1356 assert_eq!(history_entries(&history_store, cx), vec![]);
1357
1358 let send = acp_thread.update(cx, |thread, cx| {
1359 thread.send(
1360 vec![
1361 "What does ".into(),
1362 acp::ContentBlock::ResourceLink(acp::ResourceLink {
1363 name: "b.md".into(),
1364 uri: MentionUri::File {
1365 abs_path: path!("/a/b.md").into(),
1366 }
1367 .to_uri()
1368 .to_string(),
1369 annotations: None,
1370 description: None,
1371 mime_type: None,
1372 size: None,
1373 title: None,
1374 }),
1375 " mean?".into(),
1376 ],
1377 cx,
1378 )
1379 });
1380 let send = cx.foreground_executor().spawn(send);
1381 cx.run_until_parked();
1382
1383 model.send_last_completion_stream_text_chunk("Lorem.");
1384 model.end_last_completion_stream();
1385 cx.run_until_parked();
1386 summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
1387 summary_model.end_last_completion_stream();
1388
1389 send.await.unwrap();
1390 acp_thread.read_with(cx, |thread, cx| {
1391 assert_eq!(
1392 thread.to_markdown(cx),
1393 indoc! {"
1394 ## User
1395
1396 What does [@b.md](file:///a/b.md) mean?
1397
1398 ## Assistant
1399
1400 Lorem.
1401
1402 "}
1403 )
1404 });
1405
1406 cx.run_until_parked();
1407
1408 // Drop the ACP thread, which should cause the session to be dropped as well.
1409 cx.update(|_| {
1410 drop(thread);
1411 drop(acp_thread);
1412 });
1413 agent.read_with(cx, |agent, _| {
1414 assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
1415 });
1416
1417 // Ensure the thread can be reloaded from disk.
1418 assert_eq!(
1419 history_entries(&history_store, cx),
1420 vec![(
1421 HistoryEntryId::AcpThread(session_id.clone()),
1422 "Explaining /a/b.md".into()
1423 )]
1424 );
1425 let acp_thread = agent
1426 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
1427 .await
1428 .unwrap();
1429 acp_thread.read_with(cx, |thread, cx| {
1430 assert_eq!(
1431 thread.to_markdown(cx),
1432 indoc! {"
1433 ## User
1434
1435 What does [@b.md](file:///a/b.md) mean?
1436
1437 ## Assistant
1438
1439 Lorem.
1440
1441 "}
1442 )
1443 });
1444 }
1445
1446 fn history_entries(
1447 history: &Entity<HistoryStore>,
1448 cx: &mut TestAppContext,
1449 ) -> Vec<(HistoryEntryId, String)> {
1450 history.read_with(cx, |history, _| {
1451 history
1452 .entries()
1453 .map(|e| (e.id(), e.title().to_string()))
1454 .collect::<Vec<_>>()
1455 })
1456 }
1457
1458 fn init_test(cx: &mut TestAppContext) {
1459 env_logger::try_init().ok();
1460 cx.update(|cx| {
1461 let settings_store = SettingsStore::test(cx);
1462 cx.set_global(settings_store);
1463 Project::init_settings(cx);
1464 agent_settings::init(cx);
1465 language::init(cx);
1466 LanguageModelRegistry::test(cx);
1467 });
1468 }
1469}