1use crate::{
2 ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates,
3};
4use acp_thread::{MentionUri, UserMessageId};
5use action_log::ActionLog;
6use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
7use agent_client_protocol as acp;
8use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
9use anyhow::{Context as _, Result, anyhow};
10use assistant_tool::adapt_schema_to_format;
11use chrono::{DateTime, Utc};
12use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
13use collections::IndexMap;
14use fs::Fs;
15use futures::{
16 FutureExt,
17 channel::{mpsc, oneshot},
18 future::Shared,
19 stream::FuturesUnordered,
20};
21use git::repository::DiffType;
22use gpui::{App, AppContext, Context, Entity, SharedString, Task};
23use language_model::{
24 LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
25 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
26 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
27 LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
28};
29use project::{
30 Project,
31 git_store::{GitStore, RepositoryState},
32};
33use prompt_store::ProjectContext;
34use schemars::{JsonSchema, Schema};
35use serde::{Deserialize, Serialize};
36use settings::{Settings, update_settings_file};
37use smol::stream::StreamExt;
38use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
39use std::{fmt::Write, ops::Range};
40use util::{ResultExt, markdown::MarkdownCodeBlock};
41use uuid::Uuid;
42
/// Message text describing a tool call that the user canceled.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
44
45/// The ID of the user prompt that initiated a request.
46///
47/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
48#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
49pub struct PromptId(Arc<str>);
50
51impl PromptId {
52 pub fn new() -> Self {
53 Self(Uuid::new_v4().to_string().into())
54 }
55}
56
57impl std::fmt::Display for PromptId {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 write!(f, "{}", self.0)
60 }
61}
62
/// A single entry in a thread's conversation history.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// Output produced by the model, together with the results of the tools it called.
    Agent(AgentMessage),
    /// Marker recorded when the user resumes after the tool use limit was reached.
    Resume,
}
69
70impl Message {
71 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
72 match self {
73 Message::Agent(agent_message) => Some(agent_message),
74 _ => None,
75 }
76 }
77
78 pub fn to_markdown(&self) -> String {
79 match self {
80 Message::User(message) => message.to_markdown(),
81 Message::Agent(message) => message.to_markdown(),
82 Message::Resume => "[resumed after tool use limit was reached]".into(),
83 }
84 }
85}
86
/// A message submitted by the user, made up of text, mentions, and images.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifier of this message (used, e.g., by `Thread::truncate` to locate it).
    pub id: UserMessageId,
    /// The message's content chunks, in order.
    pub content: Vec<UserMessageContent>,
}

/// One chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text.
    Text(String),
    /// A reference to another resource (file, symbol, selection, thread, rule, fetched URL)
    /// together with its textual content.
    Mention { uri: MentionUri, content: String },
    /// An attached image.
    Image(LanguageModelImage),
}
99
100impl UserMessage {
101 pub fn to_markdown(&self) -> String {
102 let mut markdown = String::from("## User\n\n");
103
104 for content in &self.content {
105 match content {
106 UserMessageContent::Text(text) => {
107 markdown.push_str(text);
108 markdown.push('\n');
109 }
110 UserMessageContent::Image(_) => {
111 markdown.push_str("<image />\n");
112 }
113 UserMessageContent::Mention { uri, content } => {
114 if !content.is_empty() {
115 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
116 } else {
117 let _ = write!(&mut markdown, "{}\n", uri.as_link());
118 }
119 }
120 }
121 }
122
123 markdown
124 }
125
126 fn to_request(&self) -> LanguageModelRequestMessage {
127 let mut message = LanguageModelRequestMessage {
128 role: Role::User,
129 content: Vec::with_capacity(self.content.len()),
130 cache: false,
131 };
132
133 const OPEN_CONTEXT: &str = "<context>\n\
134 The following items were attached by the user. \
135 They are up-to-date and don't need to be re-read.\n\n";
136
137 const OPEN_FILES_TAG: &str = "<files>";
138 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
139 const OPEN_THREADS_TAG: &str = "<threads>";
140 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
141 const OPEN_RULES_TAG: &str =
142 "<rules>\nThe user has specified the following rules that should be applied:\n";
143
144 let mut file_context = OPEN_FILES_TAG.to_string();
145 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
146 let mut thread_context = OPEN_THREADS_TAG.to_string();
147 let mut fetch_context = OPEN_FETCH_TAG.to_string();
148 let mut rules_context = OPEN_RULES_TAG.to_string();
149
150 for chunk in &self.content {
151 let chunk = match chunk {
152 UserMessageContent::Text(text) => {
153 language_model::MessageContent::Text(text.clone())
154 }
155 UserMessageContent::Image(value) => {
156 language_model::MessageContent::Image(value.clone())
157 }
158 UserMessageContent::Mention { uri, content } => {
159 match uri {
160 MentionUri::File { abs_path, .. } => {
161 write!(
162 &mut symbol_context,
163 "\n{}",
164 MarkdownCodeBlock {
165 tag: &codeblock_tag(&abs_path, None),
166 text: &content.to_string(),
167 }
168 )
169 .ok();
170 }
171 MentionUri::Symbol {
172 path, line_range, ..
173 }
174 | MentionUri::Selection {
175 path, line_range, ..
176 } => {
177 write!(
178 &mut rules_context,
179 "\n{}",
180 MarkdownCodeBlock {
181 tag: &codeblock_tag(&path, Some(line_range)),
182 text: &content
183 }
184 )
185 .ok();
186 }
187 MentionUri::Thread { .. } => {
188 write!(&mut thread_context, "\n{}\n", content).ok();
189 }
190 MentionUri::TextThread { .. } => {
191 write!(&mut thread_context, "\n{}\n", content).ok();
192 }
193 MentionUri::Rule { .. } => {
194 write!(
195 &mut rules_context,
196 "\n{}",
197 MarkdownCodeBlock {
198 tag: "",
199 text: &content
200 }
201 )
202 .ok();
203 }
204 MentionUri::Fetch { url } => {
205 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
206 }
207 }
208
209 language_model::MessageContent::Text(uri.as_link().to_string())
210 }
211 };
212
213 message.content.push(chunk);
214 }
215
216 let len_before_context = message.content.len();
217
218 if file_context.len() > OPEN_FILES_TAG.len() {
219 file_context.push_str("</files>\n");
220 message
221 .content
222 .push(language_model::MessageContent::Text(file_context));
223 }
224
225 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
226 symbol_context.push_str("</symbols>\n");
227 message
228 .content
229 .push(language_model::MessageContent::Text(symbol_context));
230 }
231
232 if thread_context.len() > OPEN_THREADS_TAG.len() {
233 thread_context.push_str("</threads>\n");
234 message
235 .content
236 .push(language_model::MessageContent::Text(thread_context));
237 }
238
239 if fetch_context.len() > OPEN_FETCH_TAG.len() {
240 fetch_context.push_str("</fetched_urls>\n");
241 message
242 .content
243 .push(language_model::MessageContent::Text(fetch_context));
244 }
245
246 if rules_context.len() > OPEN_RULES_TAG.len() {
247 rules_context.push_str("</user_rules>\n");
248 message
249 .content
250 .push(language_model::MessageContent::Text(rules_context));
251 }
252
253 if message.content.len() > len_before_context {
254 message.content.insert(
255 len_before_context,
256 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
257 );
258 message
259 .content
260 .push(language_model::MessageContent::Text("</context>".into()));
261 }
262
263 message
264 }
265}
266
/// Builds the info string for a Markdown code fence that quotes `full_path`.
///
/// The result is `"<ext> <path>"` when the path has a UTF-8 extension, or just `"<path>"`
/// otherwise. When `line_range` is given, a 1-based location suffix is appended:
/// - `":N"` when the range starts and ends on the same row;
/// - `":N-M"` otherwise.
///
/// Both range ends are treated as 0-based rows and printed inclusively.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let mut tag = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{extension} {}", full_path.display()),
        None => full_path.display().to_string(),
    };

    match line_range {
        Some(range) if range.start == range.end => {
            let _ = write!(tag, ":{}", range.start + 1);
        }
        Some(range) => {
            let _ = write!(tag, ":{}-{}", range.start + 1, range.end + 1);
        }
        None => {}
    }

    tag
}
286
287impl AgentMessage {
288 pub fn to_markdown(&self) -> String {
289 let mut markdown = String::from("## Assistant\n\n");
290
291 for content in &self.content {
292 match content {
293 AgentMessageContent::Text(text) => {
294 markdown.push_str(text);
295 markdown.push('\n');
296 }
297 AgentMessageContent::Thinking { text, .. } => {
298 markdown.push_str("<think>");
299 markdown.push_str(text);
300 markdown.push_str("</think>\n");
301 }
302 AgentMessageContent::RedactedThinking(_) => {
303 markdown.push_str("<redacted_thinking />\n")
304 }
305 AgentMessageContent::ToolUse(tool_use) => {
306 markdown.push_str(&format!(
307 "**Tool Use**: {} (ID: {})\n",
308 tool_use.name, tool_use.id
309 ));
310 markdown.push_str(&format!(
311 "{}\n",
312 MarkdownCodeBlock {
313 tag: "json",
314 text: &format!("{:#}", tool_use.input)
315 }
316 ));
317 }
318 }
319 }
320
321 for tool_result in self.tool_results.values() {
322 markdown.push_str(&format!(
323 "**Tool Result**: {} (ID: {})\n\n",
324 tool_result.tool_name, tool_result.tool_use_id
325 ));
326 if tool_result.is_error {
327 markdown.push_str("**ERROR:**\n");
328 }
329
330 match &tool_result.content {
331 LanguageModelToolResultContent::Text(text) => {
332 writeln!(markdown, "{text}\n").ok();
333 }
334 LanguageModelToolResultContent::Image(_) => {
335 writeln!(markdown, "<image />\n").ok();
336 }
337 }
338
339 if let Some(output) = tool_result.output.as_ref() {
340 writeln!(
341 markdown,
342 "**Debug Output**:\n\n```json\n{}\n```\n",
343 serde_json::to_string_pretty(output).unwrap()
344 )
345 .unwrap();
346 }
347 }
348
349 markdown
350 }
351
352 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
353 let mut assistant_message = LanguageModelRequestMessage {
354 role: Role::Assistant,
355 content: Vec::with_capacity(self.content.len()),
356 cache: false,
357 };
358 for chunk in &self.content {
359 let chunk = match chunk {
360 AgentMessageContent::Text(text) => {
361 language_model::MessageContent::Text(text.clone())
362 }
363 AgentMessageContent::Thinking { text, signature } => {
364 language_model::MessageContent::Thinking {
365 text: text.clone(),
366 signature: signature.clone(),
367 }
368 }
369 AgentMessageContent::RedactedThinking(value) => {
370 language_model::MessageContent::RedactedThinking(value.clone())
371 }
372 AgentMessageContent::ToolUse(value) => {
373 language_model::MessageContent::ToolUse(value.clone())
374 }
375 };
376 assistant_message.content.push(chunk);
377 }
378
379 let mut user_message = LanguageModelRequestMessage {
380 role: Role::User,
381 content: Vec::new(),
382 cache: false,
383 };
384
385 for tool_result in self.tool_results.values() {
386 user_message
387 .content
388 .push(language_model::MessageContent::ToolResult(
389 tool_result.clone(),
390 ));
391 }
392
393 let mut messages = Vec::new();
394 if !assistant_message.content.is_empty() {
395 messages.push(assistant_message);
396 }
397 if !user_message.content.is_empty() {
398 messages.push(user_message);
399 }
400 messages
401 }
402}
403
/// Output produced by the model, together with the results of the tools it called.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Streamed content (text, thinking, tool uses) in the order it was received.
    pub content: Vec<AgentMessageContent>,
    /// Tool results keyed by the ID of the tool use they answer, in insertion order.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}

/// One chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Response text.
    Text(String),
    /// Reasoning text, with an optional signature that is passed back in later requests.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted reasoning data, passed back as-is in later requests.
    RedactedThinking(String),
    /// A tool call requested by the model.
    ToolUse(LanguageModelToolUse),
}
420
/// Events delivered on the channels returned by `Thread::send`, `Thread::resume`, and
/// `Thread::replay` (each wrapped in a `Result`).
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message (e.g. when replaying history).
    UserMessage(UserMessage),
    /// A chunk of agent response text.
    AgentText(String),
    /// A chunk of agent reasoning text.
    AgentThinking(String),
    /// A new tool call.
    ToolCall(acp::ToolCall),
    /// An update to a previously reported tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for the user to authorize a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, with the reason.
    Stop(acp::StopReason),
}

/// A pending request for the user to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting authorization.
    pub tool_call: acp::ToolCallUpdate,
    /// Permission options offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the chosen option's ID is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
438
439enum ThreadTitle {
440 None,
441 Pending(Task<()>),
442 Done(Result<SharedString>),
443}
444
445impl ThreadTitle {
446 pub fn unwrap_or_default(&self) -> SharedString {
447 if let ThreadTitle::Done(Ok(title)) = self {
448 title.clone()
449 } else {
450 "New Thread".into()
451 }
452 }
453}
454
/// A conversation between the user and an agent backed by a language model.
///
/// Holds the message history, the model and tools used to run turns, and metadata that is
/// persisted with `Thread::to_db` and restored with `Thread::from_db`.
pub struct Thread {
    /// Session identifier of this thread.
    id: acp::SessionId,
    /// Identifies the current user prompt; `send` advances it.
    prompt_id: PromptId,
    /// Last-updated timestamp, persisted with the thread.
    updated_at: DateTime<Utc>,
    /// Thread title; `unwrap_or_default` yields "New Thread" until one is available.
    title: ThreadTitle,
    /// Detailed summary state, persisted with the thread.
    summary: DetailedSummaryState,
    /// Committed conversation history.
    messages: Vec<Message>,
    /// Completion mode; `new` and `from_db` start in `CompletionMode::Normal`.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<RunningTurn>,
    /// Agent message currently being streamed.
    /// NOTE(review): presumably moved into `messages` by `flush_pending_message` (not visible here).
    pending_message: Option<AgentMessage>,
    /// Registered tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when the last turn stopped at the tool use limit; `resume` requires it.
    tool_use_limit_reached: bool,
    /// Token usage recorded per request (as the name suggests); persisted with the thread.
    request_token_usage: Vec<TokenUsage>,
    /// Accumulated token usage; persisted with the thread.
    cumulative_token_usage: TokenUsage,
    /// Shared task yielding the project snapshot taken when the thread was created
    /// (or the stored snapshot when loaded via `from_db`).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Context server registry for this thread (NOTE(review): usage not visible here).
    context_server_registry: Entity<ContextServerRegistry>,
    /// Agent profile selected for this thread.
    profile_id: AgentProfileId,
    /// Project context rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Language model used for completions.
    model: Arc<dyn LanguageModel>,
    /// Project this thread operates on.
    project: Entity<Project>,
    /// Action log associated with this thread.
    action_log: Entity<ActionLog>,
}
481
482impl Thread {
483 pub fn new(
484 id: acp::SessionId,
485 project: Entity<Project>,
486 project_context: Rc<RefCell<ProjectContext>>,
487 context_server_registry: Entity<ContextServerRegistry>,
488 action_log: Entity<ActionLog>,
489 templates: Arc<Templates>,
490 model: Arc<dyn LanguageModel>,
491 cx: &mut Context<Self>,
492 ) -> Self {
493 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
494 Self {
495 id,
496 prompt_id: PromptId::new(),
497 updated_at: Utc::now(),
498 title: ThreadTitle::None,
499 summary: DetailedSummaryState::default(),
500 messages: Vec::new(),
501 completion_mode: CompletionMode::Normal,
502 running_turn: None,
503 pending_message: None,
504 tools: BTreeMap::default(),
505 tool_use_limit_reached: false,
506 request_token_usage: Vec::new(),
507 cumulative_token_usage: TokenUsage::default(),
508 initial_project_snapshot: {
509 let project_snapshot = Self::project_snapshot(project.clone(), cx);
510 cx.foreground_executor()
511 .spawn(async move { Some(project_snapshot.await) })
512 .shared()
513 },
514 context_server_registry,
515 profile_id,
516 project_context,
517 templates,
518 model,
519 project,
520 action_log,
521 }
522 }
523
524 pub fn id(&self) -> &acp::SessionId {
525 &self.id
526 }
527
528 pub fn from_db(
529 id: acp::SessionId,
530 db_thread: DbThread,
531 project: Entity<Project>,
532 project_context: Rc<RefCell<ProjectContext>>,
533 context_server_registry: Entity<ContextServerRegistry>,
534 action_log: Entity<ActionLog>,
535 templates: Arc<Templates>,
536 model: Arc<dyn LanguageModel>,
537 cx: &mut Context<Self>,
538 ) -> Self {
539 let profile_id = db_thread
540 .profile
541 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
542 Self {
543 id,
544 prompt_id: PromptId::new(),
545 title: ThreadTitle::Done(Ok(db_thread.title.clone())),
546 summary: db_thread.summary,
547 messages: db_thread.messages,
548 completion_mode: CompletionMode::Normal,
549 running_turn: None,
550 pending_message: None,
551 tools: BTreeMap::default(),
552 tool_use_limit_reached: false,
553 request_token_usage: db_thread.request_token_usage.clone(),
554 cumulative_token_usage: db_thread.cumulative_token_usage.clone(),
555 initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
556 context_server_registry,
557 profile_id,
558 project_context,
559 templates,
560 model,
561 project,
562 action_log,
563 updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
564 }
565 }
566
567 pub fn to_db(&self, cx: &App) -> Task<DbThread> {
568 let initial_project_snapshot = self.initial_project_snapshot.clone();
569 let mut thread = DbThread {
570 title: self.title.unwrap_or_default(),
571 messages: self.messages.clone(),
572 updated_at: self.updated_at.clone(),
573 summary: self.summary.clone(),
574 initial_project_snapshot: None,
575 cumulative_token_usage: self.cumulative_token_usage.clone(),
576 request_token_usage: self.request_token_usage.clone(),
577 model: Some(DbLanguageModel {
578 provider: self.model.provider_id().to_string(),
579 model: self.model.name().0.to_string(),
580 }),
581 completion_mode: Some(self.completion_mode.into()),
582 profile: Some(self.profile_id.clone()),
583 };
584
585 cx.background_spawn(async move {
586 let initial_project_snapshot = initial_project_snapshot.await;
587 thread.initial_project_snapshot = initial_project_snapshot;
588 thread
589 })
590 }
591
592 pub fn replay(
593 &mut self,
594 cx: &mut Context<Self>,
595 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
596 let (tx, rx) = mpsc::unbounded();
597 let stream = ThreadEventStream(tx);
598 for message in &self.messages {
599 match message {
600 Message::User(user_message) => stream.send_user_message(&user_message),
601 Message::Agent(assistant_message) => {
602 for content in &assistant_message.content {
603 match content {
604 AgentMessageContent::Text(text) => stream.send_text(text),
605 AgentMessageContent::Thinking { text, .. } => {
606 stream.send_thinking(text)
607 }
608 AgentMessageContent::RedactedThinking(_) => {}
609 AgentMessageContent::ToolUse(tool_use) => {
610 self.replay_tool_call(
611 tool_use,
612 assistant_message.tool_results.get(&tool_use.id),
613 &stream,
614 cx,
615 );
616 }
617 }
618 }
619 }
620 Message::Resume => {}
621 }
622 }
623 rx
624 }
625
626 fn replay_tool_call(
627 &self,
628 tool_use: &LanguageModelToolUse,
629 tool_result: Option<&LanguageModelToolResult>,
630 stream: &ThreadEventStream,
631 cx: &mut Context<Self>,
632 ) {
633 let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
634 stream
635 .0
636 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
637 id: acp::ToolCallId(tool_use.id.to_string().into()),
638 title: tool_use.name.to_string(),
639 kind: acp::ToolKind::Other,
640 status: acp::ToolCallStatus::Failed,
641 content: Vec::new(),
642 locations: Vec::new(),
643 raw_input: Some(tool_use.input.clone()),
644 raw_output: None,
645 })))
646 .ok();
647 return;
648 };
649
650 let title = tool.initial_title(tool_use.input.clone());
651 let kind = tool.kind();
652 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
653
654 let output = tool_result
655 .as_ref()
656 .and_then(|result| result.output.clone());
657 if let Some(output) = output.clone() {
658 let tool_event_stream = ToolCallEventStream::new(
659 tool_use.id.clone(),
660 stream.clone(),
661 Some(self.project.read(cx).fs().clone()),
662 );
663 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
664 .log_err();
665 }
666
667 stream.update_tool_call_fields(
668 &tool_use.id,
669 acp::ToolCallUpdateFields {
670 status: Some(acp::ToolCallStatus::Completed),
671 raw_output: output,
672 ..Default::default()
673 },
674 );
675 }
676
677 /// Create a snapshot of the current project state including git information and unsaved buffers.
678 fn project_snapshot(
679 project: Entity<Project>,
680 cx: &mut Context<Self>,
681 ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
682 let git_store = project.read(cx).git_store().clone();
683 let worktree_snapshots: Vec<_> = project
684 .read(cx)
685 .visible_worktrees(cx)
686 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
687 .collect();
688
689 cx.spawn(async move |_, cx| {
690 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
691
692 let mut unsaved_buffers = Vec::new();
693 cx.update(|app_cx| {
694 let buffer_store = project.read(app_cx).buffer_store();
695 for buffer_handle in buffer_store.read(app_cx).buffers() {
696 let buffer = buffer_handle.read(app_cx);
697 if buffer.is_dirty() {
698 if let Some(file) = buffer.file() {
699 let path = file.path().to_string_lossy().to_string();
700 unsaved_buffers.push(path);
701 }
702 }
703 }
704 })
705 .ok();
706
707 Arc::new(ProjectSnapshot {
708 worktree_snapshots,
709 unsaved_buffer_paths: unsaved_buffers,
710 timestamp: Utc::now(),
711 })
712 })
713 }
714
/// Captures `worktree`'s absolute path and, if one of `git_store`'s repositories can map that
/// path (via `abs_path_to_repo_path`), its git state.
///
/// If the worktree cannot be read (the initial `cx.update` fails), an empty snapshot is
/// returned, with no path and no git state.
///
/// Git state depends on the repository:
/// - local repositories record the `origin` remote URL, the HEAD SHA, and the HEAD-to-worktree
///   diff (`None` if diffing fails);
/// - for other repository states, only the current branch name is recorded.
///
/// If updating the git store, updating the repository, or running the queued job fails,
/// `git_state` is `None`.
fn worktree_snapshot(
    worktree: Entity<project::Worktree>,
    git_store: Entity<GitStore>,
    cx: &App,
) -> Task<agent::thread::WorktreeSnapshot> {
    cx.spawn(async move |cx| {
        // Get worktree path and snapshot.
        // NOTE(review): the snapshot is bound to `_snapshot` below and never used.
        let worktree_info = cx.update(|app_cx| {
            let worktree = worktree.read(app_cx);
            let path = worktree.abs_path().to_string_lossy().to_string();
            let snapshot = worktree.snapshot();
            (path, snapshot)
        });

        let Ok((worktree_path, _snapshot)) = worktree_info else {
            return WorktreeSnapshot {
                worktree_path: String::new(),
                git_state: None,
            };
        };

        // Pick the first repository that can express the worktree root as a repo path.
        let git_state = git_store
            .update(cx, |git_store, cx| {
                git_store
                    .repositories()
                    .values()
                    .find(|repo| {
                        repo.read(cx)
                            .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                            .is_some()
                    })
                    .cloned()
            })
            .ok()
            .flatten()
            .map(|repo| {
                // Queue a job on the repository to collect its git details.
                repo.update(cx, |repo, _| {
                    let current_branch =
                        repo.branch.as_ref().map(|branch| branch.name().to_owned());
                    repo.send_job(None, |state, _| async move {
                        let RepositoryState::Local { backend, .. } = state else {
                            return GitState {
                                remote_url: None,
                                head_sha: None,
                                current_branch,
                                diff: None,
                            };
                        };

                        let remote_url = backend.remote_url("origin");
                        let head_sha = backend.head_sha().await;
                        let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                        GitState {
                            remote_url,
                            head_sha,
                            current_branch,
                            diff,
                        }
                    })
                })
            });

        // Either the repository update or the queued job can fail; both yield no git state.
        let git_state = match git_state {
            Some(git_state) => match git_state.ok() {
                Some(git_state) => git_state.await.ok(),
                None => None,
            },
            None => None,
        };

        WorktreeSnapshot {
            worktree_path,
            git_state,
        }
    })
}
792
793 pub fn project(&self) -> &Entity<Project> {
794 &self.project
795 }
796
797 pub fn action_log(&self) -> &Entity<ActionLog> {
798 &self.action_log
799 }
800
801 pub fn model(&self) -> &Arc<dyn LanguageModel> {
802 &self.model
803 }
804
805 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
806 self.model = model;
807 }
808
809 pub fn completion_mode(&self) -> CompletionMode {
810 self.completion_mode
811 }
812
813 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
814 self.completion_mode = mode;
815 }
816
/// Test helper. Returns the most recent message, preferring the in-progress agent message (if
/// any) over the last committed one.
#[cfg(any(test, feature = "test-support"))]
pub fn last_message(&self) -> Option<Message> {
    match &self.pending_message {
        Some(pending) => Some(Message::Agent(pending.clone())),
        None => self.messages.last().cloned(),
    }
}
825
826 pub fn add_tool(&mut self, tool: impl AgentTool) {
827 self.tools.insert(tool.name(), tool.erase());
828 }
829
830 pub fn remove_tool(&mut self, name: &str) -> bool {
831 self.tools.remove(name).is_some()
832 }
833
834 pub fn profile(&self) -> &AgentProfileId {
835 &self.profile_id
836 }
837
838 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
839 self.profile_id = profile_id;
840 }
841
842 pub fn cancel(&mut self) {
843 if let Some(running_turn) = self.running_turn.take() {
844 running_turn.cancel();
845 }
846 self.flush_pending_message();
847 }
848
849 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
850 self.cancel();
851 let Some(position) = self.messages.iter().position(
852 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
853 ) else {
854 return Err(anyhow!("Message not found"));
855 };
856 self.messages.truncate(position);
857 Ok(())
858 }
859
860 pub fn resume(
861 &mut self,
862 cx: &mut Context<Self>,
863 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
864 anyhow::ensure!(
865 self.tool_use_limit_reached,
866 "can only resume after tool use limit is reached"
867 );
868
869 self.messages.push(Message::Resume);
870 cx.notify();
871
872 log::info!("Total messages in thread: {}", self.messages.len());
873 Ok(self.run_turn(cx))
874 }
875
876 /// Sending a message results in the model streaming a response, which could include tool calls.
877 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
878 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
879 pub fn send<T>(
880 &mut self,
881 id: UserMessageId,
882 content: impl IntoIterator<Item = T>,
883 cx: &mut Context<Self>,
884 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
885 where
886 T: Into<UserMessageContent>,
887 {
888 log::info!("Thread::send called with model: {:?}", self.model.name());
889 self.advance_prompt_id();
890
891 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
892 log::debug!("Thread::send content: {:?}", content);
893
894 self.messages
895 .push(Message::User(UserMessage { id, content }));
896 cx.notify();
897
898 log::info!("Total messages in thread: {}", self.messages.len());
899 self.run_turn(cx)
900 }
901
902 fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
903 self.cancel();
904
905 let model = self.model.clone();
906 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
907 let event_stream = ThreadEventStream(events_tx);
908 let message_ix = self.messages.len().saturating_sub(1);
909 self.tool_use_limit_reached = false;
910 self.running_turn = Some(RunningTurn {
911 event_stream: event_stream.clone(),
912 _task: cx.spawn(async move |this, cx| {
913 log::info!("Starting agent turn execution");
914 let turn_result: Result<()> = async {
915 let mut completion_intent = CompletionIntent::UserPrompt;
916 loop {
917 log::debug!(
918 "Building completion request with intent: {:?}",
919 completion_intent
920 );
921 let request = this.update(cx, |this, cx| {
922 this.build_completion_request(completion_intent, cx)
923 })?;
924
925 log::info!("Calling model.stream_completion");
926 let mut events = model.stream_completion(request, cx).await?;
927 log::debug!("Stream completion started successfully");
928
929 let mut tool_use_limit_reached = false;
930 let mut tool_uses = FuturesUnordered::new();
931 while let Some(event) = events.next().await {
932 match event? {
933 LanguageModelCompletionEvent::StatusUpdate(
934 CompletionRequestStatus::ToolUseLimitReached,
935 ) => {
936 tool_use_limit_reached = true;
937 }
938 LanguageModelCompletionEvent::Stop(reason) => {
939 event_stream.send_stop(reason);
940 if reason == StopReason::Refusal {
941 this.update(cx, |this, _cx| {
942 this.flush_pending_message();
943 this.messages.truncate(message_ix);
944 })?;
945 return Ok(());
946 }
947 }
948 event => {
949 log::trace!("Received completion event: {:?}", event);
950 this.update(cx, |this, cx| {
951 tool_uses.extend(this.handle_streamed_completion_event(
952 event,
953 &event_stream,
954 cx,
955 ));
956 })
957 .ok();
958 }
959 }
960 }
961
962 let used_tools = tool_uses.is_empty();
963 while let Some(tool_result) = tool_uses.next().await {
964 log::info!("Tool finished {:?}", tool_result);
965
966 event_stream.update_tool_call_fields(
967 &tool_result.tool_use_id,
968 acp::ToolCallUpdateFields {
969 status: Some(if tool_result.is_error {
970 acp::ToolCallStatus::Failed
971 } else {
972 acp::ToolCallStatus::Completed
973 }),
974 raw_output: tool_result.output.clone(),
975 ..Default::default()
976 },
977 );
978 this.update(cx, |this, _cx| {
979 this.pending_message()
980 .tool_results
981 .insert(tool_result.tool_use_id.clone(), tool_result);
982 })
983 .ok();
984 }
985
986 if tool_use_limit_reached {
987 log::info!("Tool use limit reached, completing turn");
988 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
989 return Err(language_model::ToolUseLimitReachedError.into());
990 } else if used_tools {
991 log::info!("No tool uses found, completing turn");
992 return Ok(());
993 } else {
994 this.update(cx, |this, _| this.flush_pending_message())?;
995 completion_intent = CompletionIntent::ToolResults;
996 }
997 }
998 }
999 .await;
1000
1001 if let Err(error) = turn_result {
1002 log::error!("Turn execution failed: {:?}", error);
1003 event_stream.send_error(error);
1004 } else {
1005 log::info!("Turn execution completed successfully");
1006 }
1007
1008 this.update(cx, |this, _| {
1009 this.flush_pending_message();
1010 this.running_turn.take();
1011 })
1012 .ok();
1013 }),
1014 });
1015 events_rx
1016 }
1017
1018 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
1019 log::debug!("Building system message");
1020 let prompt = SystemPromptTemplate {
1021 project: &self.project_context.borrow(),
1022 available_tools: self.tools.keys().cloned().collect(),
1023 }
1024 .render(&self.templates)
1025 .context("failed to build system prompt")
1026 .expect("Invalid template");
1027 log::debug!("System message built");
1028 LanguageModelRequestMessage {
1029 role: Role::System,
1030 content: vec![prompt.into()],
1031 cache: true,
1032 }
1033 }
1034
1035 /// A helper method that's called on every streamed completion event.
1036 /// Returns an optional tool result task, which the main agentic loop in
1037 /// send will send back to the model when it resolves.
1038 fn handle_streamed_completion_event(
1039 &mut self,
1040 event: LanguageModelCompletionEvent,
1041 event_stream: &ThreadEventStream,
1042 cx: &mut Context<Self>,
1043 ) -> Option<Task<LanguageModelToolResult>> {
1044 log::trace!("Handling streamed completion event: {:?}", event);
1045 use LanguageModelCompletionEvent::*;
1046
1047 match event {
1048 StartMessage { .. } => {
1049 self.flush_pending_message();
1050 self.pending_message = Some(AgentMessage::default());
1051 }
1052 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1053 Thinking { text, signature } => {
1054 self.handle_thinking_event(text, signature, event_stream, cx)
1055 }
1056 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1057 ToolUse(tool_use) => {
1058 return self.handle_tool_use_event(tool_use, event_stream, cx);
1059 }
1060 ToolUseJsonParseError {
1061 id,
1062 tool_name,
1063 raw_input,
1064 json_parse_error,
1065 } => {
1066 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1067 id,
1068 tool_name,
1069 raw_input,
1070 json_parse_error,
1071 )));
1072 }
1073 UsageUpdate(_) | StatusUpdate(_) => {}
1074 Stop(_) => unreachable!(),
1075 }
1076
1077 None
1078 }
1079
1080 fn handle_text_event(
1081 &mut self,
1082 new_text: String,
1083 event_stream: &ThreadEventStream,
1084 cx: &mut Context<Self>,
1085 ) {
1086 event_stream.send_text(&new_text);
1087
1088 let last_message = self.pending_message();
1089 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1090 text.push_str(&new_text);
1091 } else {
1092 last_message
1093 .content
1094 .push(AgentMessageContent::Text(new_text));
1095 }
1096
1097 cx.notify();
1098 }
1099
1100 fn handle_thinking_event(
1101 &mut self,
1102 new_text: String,
1103 new_signature: Option<String>,
1104 event_stream: &ThreadEventStream,
1105 cx: &mut Context<Self>,
1106 ) {
1107 event_stream.send_thinking(&new_text);
1108
1109 let last_message = self.pending_message();
1110 if let Some(AgentMessageContent::Thinking { text, signature }) =
1111 last_message.content.last_mut()
1112 {
1113 text.push_str(&new_text);
1114 *signature = new_signature.or(signature.take());
1115 } else {
1116 last_message.content.push(AgentMessageContent::Thinking {
1117 text: new_text,
1118 signature: new_signature,
1119 });
1120 }
1121
1122 cx.notify();
1123 }
1124
1125 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1126 let last_message = self.pending_message();
1127 last_message
1128 .content
1129 .push(AgentMessageContent::RedactedThinking(data));
1130 cx.notify();
1131 }
1132
1133 fn handle_tool_use_event(
1134 &mut self,
1135 tool_use: LanguageModelToolUse,
1136 event_stream: &ThreadEventStream,
1137 cx: &mut Context<Self>,
1138 ) -> Option<Task<LanguageModelToolResult>> {
1139 cx.notify();
1140
1141 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1142 let mut title = SharedString::from(&tool_use.name);
1143 let mut kind = acp::ToolKind::Other;
1144 if let Some(tool) = tool.as_ref() {
1145 title = tool.initial_title(tool_use.input.clone());
1146 kind = tool.kind();
1147 }
1148
1149 // Ensure the last message ends in the current tool use
1150 let last_message = self.pending_message();
1151 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
1152 if let AgentMessageContent::ToolUse(last_tool_use) = content {
1153 if last_tool_use.id == tool_use.id {
1154 *last_tool_use = tool_use.clone();
1155 false
1156 } else {
1157 true
1158 }
1159 } else {
1160 true
1161 }
1162 });
1163
1164 if push_new_tool_use {
1165 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1166 last_message
1167 .content
1168 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1169 } else {
1170 event_stream.update_tool_call_fields(
1171 &tool_use.id,
1172 acp::ToolCallUpdateFields {
1173 title: Some(title.into()),
1174 kind: Some(kind),
1175 raw_input: Some(tool_use.input.clone()),
1176 ..Default::default()
1177 },
1178 );
1179 }
1180
1181 if !tool_use.is_input_complete {
1182 return None;
1183 }
1184
1185 let Some(tool) = tool else {
1186 let content = format!("No tool named {} exists", tool_use.name);
1187 return Some(Task::ready(LanguageModelToolResult {
1188 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1189 tool_use_id: tool_use.id,
1190 tool_name: tool_use.name,
1191 is_error: true,
1192 output: None,
1193 }));
1194 };
1195
1196 let fs = self.project.read(cx).fs().clone();
1197 let tool_event_stream =
1198 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1199 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1200 status: Some(acp::ToolCallStatus::InProgress),
1201 ..Default::default()
1202 });
1203 let supports_images = self.model.supports_images();
1204 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1205 log::info!("Running tool {}", tool_use.name);
1206 Some(cx.foreground_executor().spawn(async move {
1207 let tool_result = tool_result.await.and_then(|output| {
1208 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1209 if !supports_images {
1210 return Err(anyhow!(
1211 "Attempted to read an image, but this model doesn't support it.",
1212 ));
1213 }
1214 }
1215 Ok(output)
1216 });
1217
1218 match tool_result {
1219 Ok(output) => LanguageModelToolResult {
1220 tool_use_id: tool_use.id,
1221 tool_name: tool_use.name,
1222 is_error: false,
1223 content: output.llm_output,
1224 output: Some(output.raw_output),
1225 },
1226 Err(error) => LanguageModelToolResult {
1227 tool_use_id: tool_use.id,
1228 tool_name: tool_use.name,
1229 is_error: true,
1230 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1231 output: None,
1232 },
1233 }
1234 }))
1235 }
1236
1237 fn handle_tool_use_json_parse_error_event(
1238 &mut self,
1239 tool_use_id: LanguageModelToolUseId,
1240 tool_name: Arc<str>,
1241 raw_input: Arc<str>,
1242 json_parse_error: String,
1243 ) -> LanguageModelToolResult {
1244 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1245 LanguageModelToolResult {
1246 tool_use_id,
1247 tool_name,
1248 is_error: true,
1249 content: LanguageModelToolResultContent::Text(tool_output.into()),
1250 output: Some(serde_json::Value::String(raw_input.to_string())),
1251 }
1252 }
1253
1254 fn pending_message(&mut self) -> &mut AgentMessage {
1255 self.pending_message.get_or_insert_default()
1256 }
1257
1258 fn flush_pending_message(&mut self) {
1259 let Some(mut message) = self.pending_message.take() else {
1260 return;
1261 };
1262
1263 for content in &message.content {
1264 let AgentMessageContent::ToolUse(tool_use) = content else {
1265 continue;
1266 };
1267
1268 if !message.tool_results.contains_key(&tool_use.id) {
1269 message.tool_results.insert(
1270 tool_use.id.clone(),
1271 LanguageModelToolResult {
1272 tool_use_id: tool_use.id.clone(),
1273 tool_name: tool_use.name.clone(),
1274 is_error: true,
1275 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1276 output: None,
1277 },
1278 );
1279 }
1280 }
1281
1282 self.messages.push(Message::Agent(message));
1283 }
1284
1285 pub(crate) fn build_completion_request(
1286 &self,
1287 completion_intent: CompletionIntent,
1288 cx: &mut App,
1289 ) -> LanguageModelRequest {
1290 log::debug!("Building completion request");
1291 log::debug!("Completion intent: {:?}", completion_intent);
1292 log::debug!("Completion mode: {:?}", self.completion_mode);
1293
1294 let messages = self.build_request_messages();
1295 log::info!("Request will include {} messages", messages.len());
1296
1297 let tools = if let Some(tools) = self.tools(cx).log_err() {
1298 tools
1299 .filter_map(|tool| {
1300 let tool_name = tool.name().to_string();
1301 log::trace!("Including tool: {}", tool_name);
1302 Some(LanguageModelRequestTool {
1303 name: tool_name,
1304 description: tool.description().to_string(),
1305 input_schema: tool
1306 .input_schema(self.model.tool_input_format())
1307 .log_err()?,
1308 })
1309 })
1310 .collect()
1311 } else {
1312 Vec::new()
1313 };
1314
1315 log::info!("Request includes {} tools", tools.len());
1316
1317 let request = LanguageModelRequest {
1318 thread_id: Some(self.id.to_string()),
1319 prompt_id: Some(self.prompt_id.to_string()),
1320 intent: Some(completion_intent),
1321 mode: Some(self.completion_mode.into()),
1322 messages,
1323 tools,
1324 tool_choice: None,
1325 stop: Vec::new(),
1326 temperature: AgentSettings::temperature_for_model(self.model(), cx),
1327 thinking_allowed: true,
1328 };
1329
1330 log::debug!("Completion request built successfully");
1331 request
1332 }
1333
1334 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1335 let profile = AgentSettings::get_global(cx)
1336 .profiles
1337 .get(&self.profile_id)
1338 .context("profile not found")?;
1339 let provider_id = self.model.provider_id();
1340
1341 Ok(self
1342 .tools
1343 .iter()
1344 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1345 .filter_map(|(tool_name, tool)| {
1346 if profile.is_tool_enabled(tool_name) {
1347 Some(tool)
1348 } else {
1349 None
1350 }
1351 })
1352 .chain(self.context_server_registry.read(cx).servers().flat_map(
1353 |(server_id, tools)| {
1354 tools.iter().filter_map(|(tool_name, tool)| {
1355 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1356 Some(tool)
1357 } else {
1358 None
1359 }
1360 })
1361 },
1362 )))
1363 }
1364
1365 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1366 log::trace!(
1367 "Building request messages from {} thread messages",
1368 self.messages.len()
1369 );
1370 let mut messages = vec![self.build_system_message()];
1371 for message in &self.messages {
1372 match message {
1373 Message::User(message) => messages.push(message.to_request()),
1374 Message::Agent(message) => messages.extend(message.to_request()),
1375 Message::Resume => messages.push(LanguageModelRequestMessage {
1376 role: Role::User,
1377 content: vec!["Continue where you left off".into()],
1378 cache: false,
1379 }),
1380 }
1381 }
1382
1383 if let Some(message) = self.pending_message.as_ref() {
1384 messages.extend(message.to_request());
1385 }
1386
1387 if let Some(last_user_message) = messages
1388 .iter_mut()
1389 .rev()
1390 .find(|message| message.role == Role::User)
1391 {
1392 last_user_message.cache = true;
1393 }
1394
1395 messages
1396 }
1397
1398 pub fn to_markdown(&self) -> String {
1399 let mut markdown = String::new();
1400 for (ix, message) in self.messages.iter().enumerate() {
1401 if ix > 0 {
1402 markdown.push('\n');
1403 }
1404 markdown.push_str(&message.to_markdown());
1405 }
1406
1407 if let Some(message) = self.pending_message.as_ref() {
1408 markdown.push('\n');
1409 markdown.push_str(&message.to_markdown());
1410 }
1411
1412 markdown
1413 }
1414
1415 fn advance_prompt_id(&mut self) {
1416 self.prompt_id = PromptId::new();
1417 }
1418}
1419
/// Bookkeeping for the turn that is currently in progress.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// It survives across multiple requests while the model performs tool calls and
    /// we run the tools and report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1429
1430impl RunningTurn {
1431 fn cancel(self) {
1432 log::debug!("Cancelling in progress turn");
1433 self.event_stream.send_canceled();
1434 }
1435}
1436
1437pub trait AgentTool
1438where
1439 Self: 'static + Sized,
1440{
1441 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1442 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1443
1444 fn name(&self) -> SharedString;
1445
1446 fn description(&self) -> SharedString {
1447 let schema = schemars::schema_for!(Self::Input);
1448 SharedString::new(
1449 schema
1450 .get("description")
1451 .and_then(|description| description.as_str())
1452 .unwrap_or_default(),
1453 )
1454 }
1455
1456 fn kind(&self) -> acp::ToolKind;
1457
1458 /// The initial tool title to display. Can be updated during the tool run.
1459 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1460
1461 /// Returns the JSON schema that describes the tool's input.
1462 fn input_schema(&self) -> Schema {
1463 schemars::schema_for!(Self::Input)
1464 }
1465
1466 /// Some tools rely on a provider for the underlying billing or other reasons.
1467 /// Allow the tool to check if they are compatible, or should be filtered out.
1468 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1469 true
1470 }
1471
1472 /// Runs the tool with the provided input.
1473 fn run(
1474 self: Arc<Self>,
1475 input: Self::Input,
1476 event_stream: ToolCallEventStream,
1477 cx: &mut App,
1478 ) -> Task<Result<Self::Output>>;
1479
1480 /// Emits events for a previous execution of the tool.
1481 fn replay(
1482 &self,
1483 _input: Self::Input,
1484 _output: Self::Output,
1485 _event_stream: ToolCallEventStream,
1486 _cx: &mut App,
1487 ) -> Result<()> {
1488 Ok(())
1489 }
1490
1491 fn erase(self) -> Arc<dyn AnyAgentTool> {
1492 Arc::new(Erased(Arc::new(self)))
1493 }
1494}
1495
/// Type-erasing wrapper produced by [`AgentTool::erase`]; `Erased<Arc<T>>` implements
/// [`AnyAgentTool`] for any `T: AgentTool`.
pub struct Erased<T>(T);
1497
/// Result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content returned to the model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON, kept alongside the tool result.
    pub raw_output: serde_json::Value,
}
1502
/// Object-safe, JSON-based interface to an agent tool, so heterogeneous tools can be
/// stored as `Arc<dyn AnyAgentTool>`. [`AgentTool`]s are adapted via [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// The tool's name as advertised to the model.
    fn name(&self) -> SharedString;
    /// The tool's description as advertised to the model.
    fn description(&self) -> SharedString;
    /// The ACP kind reported for this tool's calls.
    fn kind(&self) -> acp::ToolKind;
    /// Title for a tool call, computed from raw JSON input that may be incomplete
    /// (it is called while input is still streaming).
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The tool's input JSON schema, adapted to the given schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered for the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on raw JSON input, producing model-facing and raw JSON output.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Re-emits events for a previous run from its recorded JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1526
1527impl<T> AnyAgentTool for Erased<Arc<T>>
1528where
1529 T: AgentTool,
1530{
1531 fn name(&self) -> SharedString {
1532 self.0.name()
1533 }
1534
1535 fn description(&self) -> SharedString {
1536 self.0.description()
1537 }
1538
1539 fn kind(&self) -> agent_client_protocol::ToolKind {
1540 self.0.kind()
1541 }
1542
1543 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1544 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1545 self.0.initial_title(parsed_input)
1546 }
1547
1548 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1549 let mut json = serde_json::to_value(self.0.input_schema())?;
1550 adapt_schema_to_format(&mut json, format)?;
1551 Ok(json)
1552 }
1553
1554 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1555 self.0.supported_provider(provider)
1556 }
1557
1558 fn run(
1559 self: Arc<Self>,
1560 input: serde_json::Value,
1561 event_stream: ToolCallEventStream,
1562 cx: &mut App,
1563 ) -> Task<Result<AgentToolOutput>> {
1564 cx.spawn(async move |cx| {
1565 let input = serde_json::from_value(input)?;
1566 let output = cx
1567 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1568 .await?;
1569 let raw_output = serde_json::to_value(&output)?;
1570 Ok(AgentToolOutput {
1571 llm_output: output.into(),
1572 raw_output,
1573 })
1574 })
1575 }
1576
1577 fn replay(
1578 &self,
1579 input: serde_json::Value,
1580 output: serde_json::Value,
1581 event_stream: ToolCallEventStream,
1582 cx: &mut App,
1583 ) -> Result<()> {
1584 let input = serde_json::from_value(input)?;
1585 let output = serde_json::from_value(output)?;
1586 self.0.replay(input, output, event_stream, cx)
1587 }
1588}
1589
/// Cloneable sender for a thread's events.
///
/// `Ok` items carry [`ThreadEvent`]s and `Err` items carry errors. Its methods ignore
/// send failures (e.g. when the receiving side has gone away).
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1592
1593impl ThreadEventStream {
1594 fn send_user_message(&self, message: &UserMessage) {
1595 self.0
1596 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1597 .ok();
1598 }
1599
1600 fn send_text(&self, text: &str) {
1601 self.0
1602 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1603 .ok();
1604 }
1605
1606 fn send_thinking(&self, text: &str) {
1607 self.0
1608 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1609 .ok();
1610 }
1611
1612 fn send_tool_call(
1613 &self,
1614 id: &LanguageModelToolUseId,
1615 title: SharedString,
1616 kind: acp::ToolKind,
1617 input: serde_json::Value,
1618 ) {
1619 self.0
1620 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1621 id,
1622 title.to_string(),
1623 kind,
1624 input,
1625 ))))
1626 .ok();
1627 }
1628
1629 fn initial_tool_call(
1630 id: &LanguageModelToolUseId,
1631 title: String,
1632 kind: acp::ToolKind,
1633 input: serde_json::Value,
1634 ) -> acp::ToolCall {
1635 acp::ToolCall {
1636 id: acp::ToolCallId(id.to_string().into()),
1637 title,
1638 kind,
1639 status: acp::ToolCallStatus::Pending,
1640 content: vec![],
1641 locations: vec![],
1642 raw_input: Some(input),
1643 raw_output: None,
1644 }
1645 }
1646
1647 fn update_tool_call_fields(
1648 &self,
1649 tool_use_id: &LanguageModelToolUseId,
1650 fields: acp::ToolCallUpdateFields,
1651 ) {
1652 self.0
1653 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1654 acp::ToolCallUpdate {
1655 id: acp::ToolCallId(tool_use_id.to_string().into()),
1656 fields,
1657 }
1658 .into(),
1659 )))
1660 .ok();
1661 }
1662
1663 fn send_stop(&self, reason: StopReason) {
1664 match reason {
1665 StopReason::EndTurn => {
1666 self.0
1667 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1668 .ok();
1669 }
1670 StopReason::MaxTokens => {
1671 self.0
1672 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1673 .ok();
1674 }
1675 StopReason::Refusal => {
1676 self.0
1677 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1678 .ok();
1679 }
1680 StopReason::ToolUse => {}
1681 }
1682 }
1683
1684 fn send_canceled(&self) {
1685 self.0
1686 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1687 .ok();
1688 }
1689
1690 fn send_error(&self, error: impl Into<anyhow::Error>) {
1691 self.0.unbounded_send(Err(error.into())).ok();
1692 }
1693}
1694
/// A thread event stream scoped to a single tool call.
///
/// Tools use it to report progress (field updates, diffs, terminals) and to request
/// authorization.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// The tool use whose ACP tool call this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    stream: ThreadEventStream,
    /// Used by `authorize` to persist an "always allow" choice; `None` in the test
    /// constructor.
    fs: Option<Arc<dyn Fs>>,
}
1701
1702impl ToolCallEventStream {
1703 #[cfg(test)]
1704 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1705 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1706
1707 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1708
1709 (stream, ToolCallEventStreamReceiver(events_rx))
1710 }
1711
1712 fn new(
1713 tool_use_id: LanguageModelToolUseId,
1714 stream: ThreadEventStream,
1715 fs: Option<Arc<dyn Fs>>,
1716 ) -> Self {
1717 Self {
1718 tool_use_id,
1719 stream,
1720 fs,
1721 }
1722 }
1723
1724 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1725 self.stream
1726 .update_tool_call_fields(&self.tool_use_id, fields);
1727 }
1728
1729 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1730 self.stream
1731 .0
1732 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1733 acp_thread::ToolCallUpdateDiff {
1734 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1735 diff,
1736 }
1737 .into(),
1738 )))
1739 .ok();
1740 }
1741
1742 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1743 self.stream
1744 .0
1745 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1746 acp_thread::ToolCallUpdateTerminal {
1747 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1748 terminal,
1749 }
1750 .into(),
1751 )))
1752 .ok();
1753 }
1754
1755 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1756 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1757 return Task::ready(Ok(()));
1758 }
1759
1760 let (response_tx, response_rx) = oneshot::channel();
1761 self.stream
1762 .0
1763 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1764 ToolCallAuthorization {
1765 tool_call: acp::ToolCallUpdate {
1766 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1767 fields: acp::ToolCallUpdateFields {
1768 title: Some(title.into()),
1769 ..Default::default()
1770 },
1771 },
1772 options: vec![
1773 acp::PermissionOption {
1774 id: acp::PermissionOptionId("always_allow".into()),
1775 name: "Always Allow".into(),
1776 kind: acp::PermissionOptionKind::AllowAlways,
1777 },
1778 acp::PermissionOption {
1779 id: acp::PermissionOptionId("allow".into()),
1780 name: "Allow".into(),
1781 kind: acp::PermissionOptionKind::AllowOnce,
1782 },
1783 acp::PermissionOption {
1784 id: acp::PermissionOptionId("deny".into()),
1785 name: "Deny".into(),
1786 kind: acp::PermissionOptionKind::RejectOnce,
1787 },
1788 ],
1789 response: response_tx,
1790 },
1791 )))
1792 .ok();
1793 let fs = self.fs.clone();
1794 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1795 "always_allow" => {
1796 if let Some(fs) = fs.clone() {
1797 cx.update(|cx| {
1798 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1799 settings.set_always_allow_tool_actions(true);
1800 });
1801 })?;
1802 }
1803
1804 Ok(())
1805 }
1806 "allow" => Ok(()),
1807 _ => Err(anyhow!("Permission to run tool denied by user")),
1808 })
1809 }
1810}
1811
/// Test-only receiving end of the stream created by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
1814
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as a tool call authorization request.
    ///
    /// # Panics
    /// Panics if the stream ends or yields anything else.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal from a terminal tool call update.
    ///
    /// # Panics
    /// Panics if the stream ends or yields anything else.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1838
/// Exposes the underlying receiver so tests can use its methods directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
1847
/// Mutable access to the underlying receiver (e.g. for polling it directly in tests).
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
1854
1855impl From<&str> for UserMessageContent {
1856 fn from(text: &str) -> Self {
1857 Self::Text(text.into())
1858 }
1859}
1860
1861impl From<acp::ContentBlock> for UserMessageContent {
1862 fn from(value: acp::ContentBlock) -> Self {
1863 match value {
1864 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1865 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1866 acp::ContentBlock::Audio(_) => {
1867 // TODO
1868 Self::Text("[audio]".to_string())
1869 }
1870 acp::ContentBlock::ResourceLink(resource_link) => {
1871 match MentionUri::parse(&resource_link.uri) {
1872 Ok(uri) => Self::Mention {
1873 uri,
1874 content: String::new(),
1875 },
1876 Err(err) => {
1877 log::error!("Failed to parse mention link: {}", err);
1878 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1879 }
1880 }
1881 }
1882 acp::ContentBlock::Resource(resource) => match resource.resource {
1883 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1884 match MentionUri::parse(&resource.uri) {
1885 Ok(uri) => Self::Mention {
1886 uri,
1887 content: resource.text,
1888 },
1889 Err(err) => {
1890 log::error!("Failed to parse mention link: {}", err);
1891 Self::Text(
1892 MarkdownCodeBlock {
1893 tag: &resource.uri,
1894 text: &resource.text,
1895 }
1896 .to_string(),
1897 )
1898 }
1899 }
1900 }
1901 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1902 // TODO
1903 Self::Text("[blob]".to_string())
1904 }
1905 },
1906 }
1907 }
1908}
1909
1910impl From<UserMessageContent> for acp::ContentBlock {
1911 fn from(content: UserMessageContent) -> Self {
1912 match content {
1913 UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
1914 text,
1915 annotations: None,
1916 }),
1917 UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
1918 data: image.source.to_string(),
1919 mime_type: "image/png".to_string(),
1920 annotations: None,
1921 uri: None,
1922 }),
1923 UserMessageContent::Mention { uri, content } => {
1924 todo!()
1925 }
1926 }
1927 }
1928}
1929
1930fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1931 LanguageModelImage {
1932 source: image_content.data.into(),
1933 // TODO: make this optional?
1934 size: gpui::Size::new(0.into(), 0.into()),
1935 }
1936}