1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
2use crate::{
3 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
4 EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
5 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
6 UserMessageContent, WebSearchTool, templates::Templates,
7};
8use crate::{ThreadsDatabase, generate_session_id};
9use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
10use agent_client_protocol as acp;
11use agent_settings::AgentSettings;
12use anyhow::{Context as _, Result, anyhow};
13use collections::{HashSet, IndexMap};
14use fs::Fs;
15use futures::channel::mpsc;
16use futures::{StreamExt, future};
17use gpui::{
18 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
19};
20use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
21use project::{Project, ProjectItem, ProjectPath, Worktree};
22use prompt_store::{
23 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
24};
25use settings::update_settings_file;
26use std::any::Any;
27use std::cell::RefCell;
28use std::collections::HashMap;
29use std::path::Path;
30use std::rc::Rc;
31use std::sync::Arc;
32use std::time::Duration;
33use util::ResultExt;
34
/// Rules files looked up at the root of each visible worktree, in priority order.
///
/// Only the first name in this list that exists as a regular file is loaded into the
/// project context; changes to any of these paths trigger a project-context refresh.
/// (Constants already have a `'static` lifetime, so the element type is plain `&str`.)
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
46
/// Delay (half a second) before a changed thread is written to the threads database.
/// `NativeAgent::save_thread` replaces the pending save task on every change.
const SAVE_THREAD_DEBOUNCE: Duration = Duration::new(0, 500_000_000);
48
/// Error produced when a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` output).
    pub message: SharedString,
}
52
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly: the session is removed from `NativeAgent::sessions` when it is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recently scheduled (debounced) save of `thread` to the threads database;
    /// replaced by `NativeAgent::save_thread` each time the thread changes.
    save_task: Task<Result<()>>,
    /// Keeps the release observer for `acp_thread` and the change observer for `thread` alive.
    _subscriptions: Vec<Subscription>,
}
62
/// Cache of the language models offered by authenticated providers, in the shapes the
/// ACP model selector needs (lookup by id and a grouped display list).
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; signalled after every `refresh_list`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used by `refresh_list` to notify watchers that the list changed.
    refresh_models_tx: watch::Sender<()>,
}
71
72impl LanguageModels {
73 fn new(cx: &App) -> Self {
74 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
75 let mut this = Self {
76 models: HashMap::default(),
77 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
78 refresh_models_rx,
79 refresh_models_tx,
80 };
81 this.refresh_list(cx);
82 this
83 }
84
85 fn refresh_list(&mut self, cx: &App) {
86 let providers = LanguageModelRegistry::global(cx)
87 .read(cx)
88 .providers()
89 .into_iter()
90 .filter(|provider| provider.is_authenticated(cx))
91 .collect::<Vec<_>>();
92
93 let mut language_model_list = IndexMap::default();
94 let mut recommended_models = HashSet::default();
95
96 let mut recommended = Vec::new();
97 for provider in &providers {
98 for model in provider.recommended_models(cx) {
99 recommended_models.insert(model.id());
100 recommended.push(Self::map_language_model_to_info(&model, &provider));
101 }
102 }
103 if !recommended.is_empty() {
104 language_model_list.insert(
105 acp_thread::AgentModelGroupName("Recommended".into()),
106 recommended,
107 );
108 }
109
110 let mut models = HashMap::default();
111 for provider in providers {
112 let mut provider_models = Vec::new();
113 for model in provider.provided_models(cx) {
114 let model_info = Self::map_language_model_to_info(&model, &provider);
115 let model_id = model_info.id.clone();
116 if !recommended_models.contains(&model.id()) {
117 provider_models.push(model_info);
118 }
119 models.insert(model_id, model);
120 }
121 if !provider_models.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName(provider.name().0.clone()),
124 provider_models,
125 );
126 }
127 }
128
129 self.models = models;
130 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
131 self.refresh_models_tx.send(()).ok();
132 }
133
134 fn watch(&self) -> watch::Receiver<()> {
135 self.refresh_models_rx.clone()
136 }
137
138 pub fn model_from_id(
139 &self,
140 model_id: &acp_thread::AgentModelId,
141 ) -> Option<Arc<dyn LanguageModel>> {
142 self.models.get(model_id).cloned()
143 }
144
145 fn map_language_model_to_info(
146 model: &Arc<dyn LanguageModel>,
147 provider: &Arc<dyn LanguageModelProvider>,
148 ) -> acp_thread::AgentModelInfo {
149 acp_thread::AgentModelInfo {
150 id: Self::model_id(model),
151 name: model.name().0,
152 icon: Some(provider.icon()),
153 }
154 }
155
156 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
157 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
158 }
159}
160
/// In-process agent that owns the live [`Thread`] sessions, the shared project context,
/// and the persisted thread history for a project.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled when worktrees, root-level rules files, or prompts change, so that the
    /// project context gets rebuilt.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`, which rebuilds `project_context` on each signal.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store that supplies default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Database where threads are saved and listed.
    thread_database: Arc<ThreadsDatabase>,
    /// Latest list of saved threads; `None` until the first history load succeeds.
    history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
    /// In-flight history reload task (replaced on every reload).
    load_history: Task<()>,
    /// File system used to persist settings changes (e.g. the selected model).
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, language-model registry, and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
181
182impl NativeAgent {
183 pub async fn new(
184 project: Entity<Project>,
185 templates: Arc<Templates>,
186 prompt_store: Option<Entity<PromptStore>>,
187 fs: Arc<dyn Fs>,
188 cx: &mut AsyncApp,
189 ) -> Result<Entity<NativeAgent>> {
190 log::info!("Creating new NativeAgent");
191
192 let project_context = cx
193 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
194 .await;
195
196 let thread_database = cx
197 .update(|cx| ThreadsDatabase::connect(cx))?
198 .await
199 .map_err(|e| anyhow!(e))?;
200
201 cx.new(|cx| {
202 let mut subscriptions = vec![
203 cx.subscribe(&project, Self::handle_project_event),
204 cx.subscribe(
205 &LanguageModelRegistry::global(cx),
206 Self::handle_models_updated_event,
207 ),
208 ];
209 if let Some(prompt_store) = prompt_store.as_ref() {
210 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
211 }
212
213 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
214 watch::channel(());
215 let mut this = Self {
216 sessions: HashMap::new(),
217 project_context: Rc::new(RefCell::new(project_context)),
218 project_context_needs_refresh: project_context_needs_refresh_tx,
219 _maintain_project_context: cx.spawn(async move |this, cx| {
220 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
221 }),
222 context_server_registry: cx.new(|cx| {
223 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
224 }),
225 thread_database,
226 templates,
227 models: LanguageModels::new(cx),
228 project,
229 prompt_store,
230 fs,
231 history: watch::channel(None).0,
232 load_history: Task::ready(()),
233 _subscriptions: subscriptions,
234 };
235 this.reload_history(cx);
236 this
237 })
238 }
239
240 pub fn insert_session(
241 &mut self,
242 thread: Entity<Thread>,
243 acp_thread: Entity<AcpThread>,
244 cx: &mut Context<Self>,
245 ) {
246 let id = thread.read(cx).id().clone();
247 self.sessions.insert(
248 id,
249 Session {
250 thread: thread.clone(),
251 acp_thread: acp_thread.downgrade(),
252 save_task: Task::ready(Ok(())),
253 _subscriptions: vec![
254 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
255 this.sessions.remove(acp_thread.session_id());
256 }),
257 cx.observe(&thread, |this, thread, cx| {
258 thread.update(cx, |thread, cx| {
259 thread.generate_title_if_needed(cx);
260 });
261 this.save_thread(thread.clone(), cx)
262 }),
263 ],
264 },
265 );
266 }
267
268 fn save_thread(&mut self, thread_handle: Entity<Thread>, cx: &mut Context<Self>) {
269 let thread = thread_handle.read(cx);
270 let id = thread.id().clone();
271 let Some(session) = self.sessions.get_mut(&id) else {
272 return;
273 };
274
275 let thread = thread_handle.downgrade();
276 let thread_database = self.thread_database.clone();
277 session.save_task = cx.spawn(async move |this, cx| {
278 cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
279 let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await;
280 thread_database.save_thread(id, db_thread).await?;
281 this.update(cx, |this, cx| this.reload_history(cx))?;
282 Ok(())
283 });
284 }
285
286 fn reload_history(&mut self, cx: &mut Context<Self>) {
287 let thread_database = self.thread_database.clone();
288 self.load_history = cx.spawn(async move |this, cx| {
289 let results = cx
290 .background_spawn(async move {
291 let results = thread_database.list_threads().await?;
292 anyhow::Ok(
293 results
294 .into_iter()
295 .map(|thread| AcpThreadMetadata {
296 agent: NATIVE_AGENT_SERVER_NAME.clone(),
297 id: thread.id.into(),
298 title: thread.title,
299 updated_at: thread.updated_at,
300 })
301 .collect(),
302 )
303 })
304 .await;
305 if let Some(results) = results.log_err() {
306 this.update(cx, |this, _| this.history.send(Some(results)))
307 .ok();
308 }
309 });
310 }
311
312 pub fn models(&self) -> &LanguageModels {
313 &self.models
314 }
315
316 async fn maintain_project_context(
317 this: WeakEntity<Self>,
318 mut needs_refresh: watch::Receiver<()>,
319 cx: &mut AsyncApp,
320 ) -> Result<()> {
321 while needs_refresh.changed().await.is_ok() {
322 let project_context = this
323 .update(cx, |this, cx| {
324 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
325 })?
326 .await;
327 this.update(cx, |this, _| this.project_context.replace(project_context))?;
328 }
329
330 Ok(())
331 }
332
333 fn build_project_context(
334 project: &Entity<Project>,
335 prompt_store: Option<&Entity<PromptStore>>,
336 cx: &mut App,
337 ) -> Task<ProjectContext> {
338 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
339 let worktree_tasks = worktrees
340 .into_iter()
341 .map(|worktree| {
342 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
343 })
344 .collect::<Vec<_>>();
345 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
346 prompt_store.read_with(cx, |prompt_store, cx| {
347 let prompts = prompt_store.default_prompt_metadata();
348 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
349 let contents = prompt_store.load(prompt_metadata.id, cx);
350 async move { (contents.await, prompt_metadata) }
351 });
352 cx.background_spawn(future::join_all(load_tasks))
353 })
354 } else {
355 Task::ready(vec![])
356 };
357
358 cx.spawn(async move |_cx| {
359 let (worktrees, default_user_rules) =
360 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
361
362 let worktrees = worktrees
363 .into_iter()
364 .map(|(worktree, _rules_error)| {
365 // TODO: show error message
366 // if let Some(rules_error) = rules_error {
367 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
368 // }
369 worktree
370 })
371 .collect::<Vec<_>>();
372
373 let default_user_rules = default_user_rules
374 .into_iter()
375 .flat_map(|(contents, prompt_metadata)| match contents {
376 Ok(contents) => Some(UserRulesContext {
377 uuid: match prompt_metadata.id {
378 PromptId::User { uuid } => uuid,
379 PromptId::EditWorkflow => return None,
380 },
381 title: prompt_metadata.title.map(|title| title.to_string()),
382 contents,
383 }),
384 Err(_err) => {
385 // TODO: show error message
386 // this.update(cx, |_, cx| {
387 // cx.emit(RulesLoadingError {
388 // message: format!("{err:?}").into(),
389 // });
390 // })
391 // .ok();
392 None
393 }
394 })
395 .collect::<Vec<_>>();
396
397 ProjectContext::new(worktrees, default_user_rules)
398 })
399 }
400
401 fn load_worktree_info_for_system_prompt(
402 worktree: Entity<Worktree>,
403 project: Entity<Project>,
404 cx: &mut App,
405 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
406 let tree = worktree.read(cx);
407 let root_name = tree.root_name().into();
408 let abs_path = tree.abs_path();
409
410 let mut context = WorktreeContext {
411 root_name,
412 abs_path,
413 rules_file: None,
414 };
415
416 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
417 let Some(rules_task) = rules_task else {
418 return Task::ready((context, None));
419 };
420
421 cx.spawn(async move |_| {
422 let (rules_file, rules_file_error) = match rules_task.await {
423 Ok(rules_file) => (Some(rules_file), None),
424 Err(err) => (
425 None,
426 Some(RulesLoadingError {
427 message: format!("{err}").into(),
428 }),
429 ),
430 };
431 context.rules_file = rules_file;
432 (context, rules_file_error)
433 })
434 }
435
436 fn load_worktree_rules_file(
437 worktree: Entity<Worktree>,
438 project: Entity<Project>,
439 cx: &mut App,
440 ) -> Option<Task<Result<RulesFileContext>>> {
441 let worktree = worktree.read(cx);
442 let worktree_id = worktree.id();
443 let selected_rules_file = RULES_FILE_NAMES
444 .into_iter()
445 .filter_map(|name| {
446 worktree
447 .entry_for_path(name)
448 .filter(|entry| entry.is_file())
449 .map(|entry| entry.path.clone())
450 })
451 .next();
452
453 // Note that Cline supports `.clinerules` being a directory, but that is not currently
454 // supported. This doesn't seem to occur often in GitHub repositories.
455 selected_rules_file.map(|path_in_worktree| {
456 let project_path = ProjectPath {
457 worktree_id,
458 path: path_in_worktree.clone(),
459 };
460 let buffer_task =
461 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
462 let rope_task = cx.spawn(async move |cx| {
463 buffer_task.await?.read_with(cx, |buffer, cx| {
464 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
465 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
466 })?
467 });
468 // Build a string from the rope on a background thread.
469 cx.background_spawn(async move {
470 let (project_entry_id, rope) = rope_task.await?;
471 anyhow::Ok(RulesFileContext {
472 path_in_worktree,
473 text: rope.to_string().trim().to_string(),
474 project_entry_id: project_entry_id.to_usize(),
475 })
476 })
477 })
478 }
479
480 fn handle_project_event(
481 &mut self,
482 _project: Entity<Project>,
483 event: &project::Event,
484 _cx: &mut Context<Self>,
485 ) {
486 match event {
487 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
488 self.project_context_needs_refresh.send(()).ok();
489 }
490 project::Event::WorktreeUpdatedEntries(_, items) => {
491 if items.iter().any(|(path, _, _)| {
492 RULES_FILE_NAMES
493 .iter()
494 .any(|name| path.as_ref() == Path::new(name))
495 }) {
496 self.project_context_needs_refresh.send(()).ok();
497 }
498 }
499 _ => {}
500 }
501 }
502
503 fn handle_prompts_updated_event(
504 &mut self,
505 _prompt_store: Entity<PromptStore>,
506 _event: &prompt_store::PromptsUpdatedEvent,
507 _cx: &mut Context<Self>,
508 ) {
509 self.project_context_needs_refresh.send(()).ok();
510 }
511
512 fn handle_models_updated_event(
513 &mut self,
514 registry: Entity<LanguageModelRegistry>,
515 _event: &language_model::Event,
516 cx: &mut Context<Self>,
517 ) {
518 self.models.refresh_list(cx);
519 for session in self.sessions.values_mut() {
520 session.thread.update(cx, |thread, cx| {
521 let model_id = LanguageModels::model_id(&thread.model());
522 if let Some(model) = self.models.model_from_id(&model_id) {
523 thread.set_model(model.clone(), cx);
524 }
525 let summarization_model = registry
526 .read(cx)
527 .thread_summary_model()
528 .map(|model| model.model.clone());
529 thread.set_summarization_model(summarization_model, cx);
530 });
531 }
532 }
533}
534
/// Wrapper struct that implements the AgentConnection trait
/// (and [`AgentModelSelector`]) on top of a [`NativeAgent`] entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
538
539impl NativeAgentConnection {
540 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
541 self.0
542 .read(cx)
543 .sessions
544 .get(session_id)
545 .map(|session| session.thread.clone())
546 }
547
548 fn run_turn(
549 &self,
550 session_id: acp::SessionId,
551 cx: &mut App,
552 f: impl 'static
553 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
554 ) -> Task<Result<acp::PromptResponse>> {
555 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
556 agent
557 .sessions
558 .get_mut(&session_id)
559 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
560 }) else {
561 return Task::ready(Err(anyhow!("Session not found")));
562 };
563 log::debug!("Found session for: {}", session_id);
564
565 let response_stream = match f(thread, cx) {
566 Ok(stream) => stream,
567 Err(err) => return Task::ready(Err(err)),
568 };
569 Self::handle_thread_events(response_stream, acp_thread, cx)
570 }
571
572 fn handle_thread_events(
573 mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
574 acp_thread: WeakEntity<AcpThread>,
575 cx: &mut App,
576 ) -> Task<Result<acp::PromptResponse>> {
577 cx.spawn(async move |cx| {
578 // Handle response stream and forward to session.acp_thread
579 while let Some(result) = response_stream.next().await {
580 match result {
581 Ok(event) => {
582 log::trace!("Received completion event: {:?}", event);
583
584 match event {
585 ThreadEvent::UserMessage(message) => {
586 acp_thread.update(cx, |thread, cx| {
587 for content in message.content {
588 thread.push_user_content_block(
589 Some(message.id.clone()),
590 content.into(),
591 cx,
592 );
593 }
594 })?;
595 }
596 ThreadEvent::AgentText(text) => {
597 acp_thread.update(cx, |thread, cx| {
598 thread.push_assistant_content_block(
599 acp::ContentBlock::Text(acp::TextContent {
600 text,
601 annotations: None,
602 }),
603 false,
604 cx,
605 )
606 })?;
607 }
608 ThreadEvent::AgentThinking(text) => {
609 acp_thread.update(cx, |thread, cx| {
610 thread.push_assistant_content_block(
611 acp::ContentBlock::Text(acp::TextContent {
612 text,
613 annotations: None,
614 }),
615 true,
616 cx,
617 )
618 })?;
619 }
620 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
621 tool_call,
622 options,
623 response,
624 }) => {
625 let recv = acp_thread.update(cx, |thread, cx| {
626 thread.request_tool_call_authorization(tool_call, options, cx)
627 })?;
628 cx.background_spawn(async move {
629 if let Some(recv) = recv.log_err()
630 && let Some(option) = recv
631 .await
632 .context("authorization sender was dropped")
633 .log_err()
634 {
635 response
636 .send(option)
637 .map(|_| anyhow!("authorization receiver was dropped"))
638 .log_err();
639 }
640 })
641 .detach();
642 }
643 ThreadEvent::ToolCall(tool_call) => {
644 acp_thread.update(cx, |thread, cx| {
645 thread.upsert_tool_call(tool_call, cx)
646 })??;
647 }
648 ThreadEvent::ToolCallUpdate(update) => {
649 acp_thread.update(cx, |thread, cx| {
650 thread.update_tool_call(update, cx)
651 })??;
652 }
653 ThreadEvent::TitleUpdate(title) => {
654 acp_thread
655 .update(cx, |thread, cx| thread.update_title(title, cx))??;
656 }
657 ThreadEvent::Stop(stop_reason) => {
658 log::debug!("Assistant message complete: {:?}", stop_reason);
659 return Ok(acp::PromptResponse { stop_reason });
660 }
661 }
662 }
663 Err(e) => {
664 log::error!("Error in model response stream: {:?}", e);
665 return Err(e);
666 }
667 }
668 }
669
670 log::info!("Response stream completed");
671 anyhow::Ok(acp::PromptResponse {
672 stop_reason: acp::StopReason::EndTurn,
673 })
674 })
675 }
676
677 fn register_tools(
678 thread: &mut Thread,
679 project: Entity<Project>,
680 action_log: Entity<action_log::ActionLog>,
681 cx: &mut Context<Thread>,
682 ) {
683 let language_registry = project.read(cx).languages().clone();
684 thread.add_tool(CopyPathTool::new(project.clone()));
685 thread.add_tool(CreateDirectoryTool::new(project.clone()));
686 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
687 thread.add_tool(DiagnosticsTool::new(project.clone()));
688 thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
689 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
690 thread.add_tool(FindPathTool::new(project.clone()));
691 thread.add_tool(GrepTool::new(project.clone()));
692 thread.add_tool(ListDirectoryTool::new(project.clone()));
693 thread.add_tool(MovePathTool::new(project.clone()));
694 thread.add_tool(NowTool);
695 thread.add_tool(OpenTool::new(project.clone()));
696 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
697 thread.add_tool(TerminalTool::new(project.clone(), cx));
698 thread.add_tool(ThinkingTool);
699 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
700 }
701}
702
703impl AgentModelSelector for NativeAgentConnection {
704 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
705 log::debug!("NativeAgentConnection::list_models called");
706 let list = self.0.read(cx).models.model_list.clone();
707 Task::ready(if list.is_empty() {
708 Err(anyhow::anyhow!("No models available"))
709 } else {
710 Ok(list)
711 })
712 }
713
714 fn select_model(
715 &self,
716 session_id: acp::SessionId,
717 model_id: acp_thread::AgentModelId,
718 cx: &mut App,
719 ) -> Task<Result<()>> {
720 log::info!("Setting model for session {}: {}", session_id, model_id);
721 let Some(thread) = self
722 .0
723 .read(cx)
724 .sessions
725 .get(&session_id)
726 .map(|session| session.thread.clone())
727 else {
728 return Task::ready(Err(anyhow!("Session not found")));
729 };
730
731 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
732 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
733 };
734
735 thread.update(cx, |thread, cx| {
736 thread.set_model(model.clone(), cx);
737 });
738
739 update_settings_file::<AgentSettings>(
740 self.0.read(cx).fs.clone(),
741 cx,
742 move |settings, _cx| {
743 settings.set_model(model);
744 },
745 );
746
747 Task::ready(Ok(()))
748 }
749
750 fn selected_model(
751 &self,
752 session_id: &acp::SessionId,
753 cx: &mut App,
754 ) -> Task<Result<acp_thread::AgentModelInfo>> {
755 let session_id = session_id.clone();
756
757 let Some(thread) = self
758 .0
759 .read(cx)
760 .sessions
761 .get(&session_id)
762 .map(|session| session.thread.clone())
763 else {
764 return Task::ready(Err(anyhow!("Session not found")));
765 };
766 let model = thread.read(cx).model().clone();
767 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
768 else {
769 return Task::ready(Err(anyhow!("Provider not found")));
770 };
771 Task::ready(Ok(LanguageModels::map_language_model_to_info(
772 &model, &provider,
773 )))
774 }
775
776 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
777 self.0.read(cx).models.watch()
778 }
779}
780
781impl acp_thread::AgentConnection for NativeAgentConnection {
782 fn new_thread(
783 self: Rc<Self>,
784 project: Entity<Project>,
785 cwd: &Path,
786 cx: &mut App,
787 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
788 let agent = self.0.clone();
789 log::info!("Creating new thread for project at: {:?}", cwd);
790
791 cx.spawn(async move |cx| {
792 log::debug!("Starting thread creation in async context");
793
794 // Generate session ID
795 let session_id = generate_session_id();
796 log::info!("Created session with ID: {}", session_id);
797
798 // Create AcpThread
799 let acp_thread = cx.update(|cx| {
800 cx.new(|cx| {
801 acp_thread::AcpThread::new(
802 "agent2",
803 self.clone(),
804 project.clone(),
805 session_id.clone(),
806 cx,
807 )
808 })
809 })?;
810 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
811
812 // Create Thread
813 let thread = agent.update(
814 cx,
815 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
816 // Fetch default model from registry settings
817 let registry = LanguageModelRegistry::read_global(cx);
818
819 // Log available models for debugging
820 let available_count = registry.available_models(cx).count();
821 log::debug!("Total available models: {}", available_count);
822
823 let default_model = registry
824 .default_model()
825 .and_then(|default_model| {
826 agent
827 .models
828 .model_from_id(&LanguageModels::model_id(&default_model.model))
829 })
830 .ok_or_else(|| {
831 log::warn!("No default model configured in settings");
832 anyhow!(
833 "No default model. Please configure a default model in settings."
834 )
835 })?;
836
837 let summarization_model = registry.thread_summary_model().map(|c| c.model);
838
839 let thread = cx.new(|cx| {
840 let mut thread = Thread::new(
841 session_id.clone(),
842 project.clone(),
843 agent.project_context.clone(),
844 agent.context_server_registry.clone(),
845 action_log.clone(),
846 agent.templates.clone(),
847 default_model,
848 summarization_model,
849 cx,
850 );
851 Self::register_tools(&mut thread, project, action_log, cx);
852 thread
853 });
854
855 Ok(thread)
856 },
857 )??;
858
859 // Store the session
860 agent.update(cx, |agent, cx| {
861 agent.insert_session(thread, acp_thread.clone(), cx)
862 })?;
863
864 Ok(acp_thread)
865 })
866 }
867
868 fn auth_methods(&self) -> &[acp::AuthMethod] {
869 &[] // No auth for in-process
870 }
871
872 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
873 Task::ready(Ok(()))
874 }
875
876 fn list_threads(
877 &self,
878 cx: &mut App,
879 ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
880 Some(self.0.read(cx).history.receiver())
881 }
882
883 fn load_thread(
884 self: Rc<Self>,
885 project: Entity<Project>,
886 _cwd: &Path,
887 session_id: acp::SessionId,
888 cx: &mut App,
889 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
890 let database = self.0.update(cx, |this, _| this.thread_database.clone());
891 cx.spawn(async move |cx| {
892 let db_thread = database
893 .load_thread(session_id.clone())
894 .await?
895 .context("no such thread found")?;
896
897 let acp_thread = cx.update(|cx| {
898 cx.new(|cx| {
899 acp_thread::AcpThread::new(
900 db_thread.title.clone(),
901 self.clone(),
902 project.clone(),
903 session_id.clone(),
904 cx,
905 )
906 })
907 })?;
908 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
909 let agent = self.0.clone();
910
911 // Create Thread
912 let thread = agent.update(cx, |agent, cx| {
913 let language_model_registry = LanguageModelRegistry::global(cx);
914 let configured_model = language_model_registry
915 .update(cx, |registry, cx| {
916 db_thread
917 .model
918 .as_ref()
919 .and_then(|model| {
920 let model = SelectedModel {
921 provider: model.provider.clone().into(),
922 model: model.model.clone().into(),
923 };
924 registry.select_model(&model, cx)
925 })
926 .or_else(|| registry.default_model())
927 })
928 .context("no default model configured")?;
929
930 let model = agent
931 .models
932 .model_from_id(&LanguageModels::model_id(&configured_model.model))
933 .context("no model by id")?;
934
935 let summarization_model = language_model_registry
936 .read(cx)
937 .thread_summary_model()
938 .map(|c| c.model);
939
940 let thread = cx.new(|cx| {
941 let mut thread = Thread::from_db(
942 session_id,
943 db_thread,
944 project.clone(),
945 agent.project_context.clone(),
946 agent.context_server_registry.clone(),
947 action_log.clone(),
948 agent.templates.clone(),
949 model,
950 summarization_model,
951 cx,
952 );
953 Self::register_tools(&mut thread, project, action_log, cx);
954 thread
955 });
956
957 anyhow::Ok(thread)
958 })??;
959
960 // Store the session
961 agent.update(cx, |agent, cx| {
962 agent.insert_session(thread.clone(), acp_thread.clone(), cx)
963 })?;
964
965 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
966 cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
967 .await?;
968
969 Ok(acp_thread)
970 })
971 }
972
973 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
974 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
975 }
976
977 fn prompt(
978 &self,
979 id: Option<acp_thread::UserMessageId>,
980 params: acp::PromptRequest,
981 cx: &mut App,
982 ) -> Task<Result<acp::PromptResponse>> {
983 let id = id.expect("UserMessageId is required");
984 let session_id = params.session_id.clone();
985 log::info!("Received prompt request for session: {}", session_id);
986 log::debug!("Prompt blocks count: {}", params.prompt.len());
987
988 self.run_turn(session_id, cx, |thread, cx| {
989 let content: Vec<UserMessageContent> = params
990 .prompt
991 .into_iter()
992 .map(Into::into)
993 .collect::<Vec<_>>();
994 log::info!("Converted prompt to message: {} chars", content.len());
995 log::debug!("Message id: {:?}", id);
996 log::debug!("Message content: {:?}", content);
997
998 Ok(thread.update(cx, |thread, cx| {
999 log::info!(
1000 "Sending message to thread with model: {:?}",
1001 thread.model().name()
1002 );
1003 thread.send(id, content, cx)
1004 }))
1005 })
1006 }
1007
1008 fn resume(
1009 &self,
1010 session_id: &acp::SessionId,
1011 _cx: &mut App,
1012 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1013 Some(Rc::new(NativeAgentSessionResume {
1014 connection: self.clone(),
1015 session_id: session_id.clone(),
1016 }) as _)
1017 }
1018
1019 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1020 log::info!("Cancelling on session: {}", session_id);
1021 self.0.update(cx, |agent, cx| {
1022 if let Some(agent) = agent.sessions.get(session_id) {
1023 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1024 }
1025 });
1026 }
1027
1028 fn session_editor(
1029 &self,
1030 session_id: &agent_client_protocol::SessionId,
1031 cx: &mut App,
1032 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
1033 self.0.update(cx, |agent, _cx| {
1034 agent
1035 .sessions
1036 .get(session_id)
1037 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
1038 })
1039 }
1040
1041 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1042 self
1043 }
1044}
1045
/// Session editor for a native session, operating directly on the session's [`Thread`].
struct NativeAgentSessionEditor(Entity<Thread>);
1047
1048impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
1049 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1050 Task::ready(
1051 self.0
1052 .update(cx, |thread, cx| thread.truncate(message_id, cx)),
1053 )
1054 }
1055}
1056
/// Resume handle for one native session.
struct NativeAgentSessionResume {
    /// Connection used to run the resumed turn.
    connection: NativeAgentConnection,
    /// Session whose thread is resumed.
    session_id: acp::SessionId,
}
1061
1062impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1063 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1064 self.connection
1065 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1066 thread.update(cx, |thread, cx| thread.resume(cx))
1067 })
1068 }
1069}
1070
#[cfg(test)]
mod tests {
    use crate::HistoryStore;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// The agent's `project_context` should list the project's worktrees and
    /// pick up a `.rules` file once it is created inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Absolute paths go through `path!` (as in `test_history`) instead of
        // being hard-coded Unix roots; assumes `util::path!` adapts them to the
        // host platform — confirm.
        fs.insert_tree(
            path!("/"),
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // No worktrees have been added yet.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree(path!("/a"), true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new(path!("/a")).into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file(path!("/a/.rules"), Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new(path!("/a")).into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// `list_models` should return the models grouped by provider. In the test
    /// registry set up by `init_test`, that is a single "Fake" group.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/"), json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model for a session should switch that session's thread to
    /// the model and write the choice to `agent.default_model` in the settings
    /// file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model so the update
        // is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new(path!("/a")), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model. Use the `App` handed to the
        // closure rather than re-borrowing the outer test context.
        agent.read_with(cx, |agent, cx| {
            let session = agent.sessions.get(&session_id).unwrap();
            assert_eq!(session.thread.read(cx).model().id().0, "fake");
        });

        // Let the settings-file update finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// After a completed turn and a generated summary, the registered history
    /// store should hold one entry whose title is the summary text.
    #[gpui::test]
    async fn test_history(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;

        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());
        let history_store = cx.new(|cx| {
            let mut store = HistoryStore::new(cx);
            store.register_agent(NATIVE_AGENT_SERVER_NAME.clone(), &connection, cx);
            store
        });

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new(path!("")), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let selector = connection.model_selector().unwrap();

        // A separate fake model produces the thread title, so the test can
        // drive it independently of the main completion model.
        let summarization_model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel::default());

        agent.update(cx, |agent, cx| {
            let thread = agent.sessions.get(&session_id).unwrap().thread.clone();
            thread.update(cx, |thread, cx| {
                thread.set_summarization_model(Some(summarization_model.clone()), cx);
            })
        });

        // Resolve the session's selected model to the underlying fake model so
        // the test can stream its response by hand.
        let model = cx
            .update(|cx| selector.selected_model(&session_id, cx))
            .await
            .expect("selected_model should succeed");
        let model = cx
            .update(|cx| agent.read(cx).models().model_from_id(&model.id))
            .unwrap();
        let model = model.as_fake();

        let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Hi", cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();
        model.send_last_completion_stream_text_chunk("Hey");
        model.end_last_completion_stream();
        send.await.unwrap();

        summarization_model
            .as_fake()
            .send_last_completion_stream_text_chunk("Saying Hello");
        summarization_model.as_fake().end_last_completion_stream();
        // Fire the save debounce timer, then let any work it unblocks finish
        // before inspecting the history.
        cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);
        cx.run_until_parked();

        let history = history_store.update(cx, |store, cx| store.entries(cx));
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].title(), "Saying Hello");
    }

    /// Shared setup: installs a test `SettingsStore` global, registers the
    /// project, agent, and language settings, and initializes the test
    /// `LanguageModelRegistry`. Presumably that registry provides the "Fake"
    /// provider the model tests rely on — confirm.
    fn init_test(cx: &mut TestAppContext) {
        // Ignore the error if a logger was already installed by another test.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}