1use crate::{
2 ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates,
3};
4use acp_thread::{MentionUri, UserMessageId};
5use action_log::ActionLog;
6use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
7use agent_client_protocol as acp;
8use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
9use anyhow::{Context as _, Result, anyhow};
10use assistant_tool::adapt_schema_to_format;
11use chrono::{DateTime, Utc};
12use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
13use collections::IndexMap;
14use fs::Fs;
15use futures::{
16 FutureExt,
17 channel::{mpsc, oneshot},
18 future::Shared,
19 stream::FuturesUnordered,
20};
21use git::repository::DiffType;
22use gpui::{App, AppContext, Context, Entity, SharedString, Task};
23use language_model::{
24 LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
25 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
26 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
27 LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
28};
29use project::{
30 Project,
31 git_store::{GitStore, RepositoryState},
32};
33use prompt_store::ProjectContext;
34use schemars::{JsonSchema, Schema};
35use serde::{Deserialize, Serialize};
36use settings::{Settings, update_settings_file};
37use smol::stream::StreamExt;
38use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
39use std::{fmt::Write, ops::Range};
40use util::{ResultExt, markdown::MarkdownCodeBlock};
41use uuid::Uuid;
42
/// Text describing a tool call that was canceled by the user.
///
/// NOTE(review): not referenced in the visible part of this file; presumably used as
/// the tool result content when a running turn is canceled mid tool call — confirm.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
44
45/// The ID of the user prompt that initiated a request.
46///
47/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
48#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
49pub struct PromptId(Arc<str>);
50
51impl PromptId {
52 pub fn new() -> Self {
53 Self(Uuid::new_v4().to_string().into())
54 }
55}
56
57impl std::fmt::Display for PromptId {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 write!(f, "{}", self.0)
60 }
61}
62
63#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
64pub enum Message {
65 User(UserMessage),
66 Agent(AgentMessage),
67 Resume,
68}
69
70impl Message {
71 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
72 match self {
73 Message::Agent(agent_message) => Some(agent_message),
74 _ => None,
75 }
76 }
77
78 pub fn to_markdown(&self) -> String {
79 match self {
80 Message::User(message) => message.to_markdown(),
81 Message::Agent(message) => message.to_markdown(),
82 Message::Resume => "[resumed after tool use limit was reached]".into(),
83 }
84 }
85}
86
/// A message authored by the user.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifies the message (e.g. `Thread::truncate` looks messages up by it).
    pub id: UserMessageId,
    /// The pieces of content that make up the message, in order.
    pub content: Vec<UserMessageContent>,
}

/// A single piece of user message content.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text typed by the user.
    Text(String),
    /// A mention of a file, symbol, selection, thread, rule or URL, together with
    /// the text content attached for it.
    Mention { uri: MentionUri, content: String },
    /// An attached image.
    Image(LanguageModelImage),
}
99
100impl UserMessage {
101 pub fn to_markdown(&self) -> String {
102 let mut markdown = String::from("## User\n\n");
103
104 for content in &self.content {
105 match content {
106 UserMessageContent::Text(text) => {
107 markdown.push_str(text);
108 markdown.push('\n');
109 }
110 UserMessageContent::Image(_) => {
111 markdown.push_str("<image />\n");
112 }
113 UserMessageContent::Mention { uri, content } => {
114 if !content.is_empty() {
115 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
116 } else {
117 let _ = write!(&mut markdown, "{}\n", uri.as_link());
118 }
119 }
120 }
121 }
122
123 markdown
124 }
125
126 fn to_request(&self) -> LanguageModelRequestMessage {
127 let mut message = LanguageModelRequestMessage {
128 role: Role::User,
129 content: Vec::with_capacity(self.content.len()),
130 cache: false,
131 };
132
133 const OPEN_CONTEXT: &str = "<context>\n\
134 The following items were attached by the user. \
135 They are up-to-date and don't need to be re-read.\n\n";
136
137 const OPEN_FILES_TAG: &str = "<files>";
138 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
139 const OPEN_THREADS_TAG: &str = "<threads>";
140 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
141 const OPEN_RULES_TAG: &str =
142 "<rules>\nThe user has specified the following rules that should be applied:\n";
143
144 let mut file_context = OPEN_FILES_TAG.to_string();
145 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
146 let mut thread_context = OPEN_THREADS_TAG.to_string();
147 let mut fetch_context = OPEN_FETCH_TAG.to_string();
148 let mut rules_context = OPEN_RULES_TAG.to_string();
149
150 for chunk in &self.content {
151 let chunk = match chunk {
152 UserMessageContent::Text(text) => {
153 language_model::MessageContent::Text(text.clone())
154 }
155 UserMessageContent::Image(value) => {
156 language_model::MessageContent::Image(value.clone())
157 }
158 UserMessageContent::Mention { uri, content } => {
159 match uri {
160 MentionUri::File { abs_path, .. } => {
161 write!(
162 &mut symbol_context,
163 "\n{}",
164 MarkdownCodeBlock {
165 tag: &codeblock_tag(&abs_path, None),
166 text: &content.to_string(),
167 }
168 )
169 .ok();
170 }
171 MentionUri::Symbol {
172 path, line_range, ..
173 }
174 | MentionUri::Selection {
175 path, line_range, ..
176 } => {
177 write!(
178 &mut rules_context,
179 "\n{}",
180 MarkdownCodeBlock {
181 tag: &codeblock_tag(&path, Some(line_range)),
182 text: &content
183 }
184 )
185 .ok();
186 }
187 MentionUri::Thread { .. } => {
188 write!(&mut thread_context, "\n{}\n", content).ok();
189 }
190 MentionUri::TextThread { .. } => {
191 write!(&mut thread_context, "\n{}\n", content).ok();
192 }
193 MentionUri::Rule { .. } => {
194 write!(
195 &mut rules_context,
196 "\n{}",
197 MarkdownCodeBlock {
198 tag: "",
199 text: &content
200 }
201 )
202 .ok();
203 }
204 MentionUri::Fetch { url } => {
205 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
206 }
207 }
208
209 language_model::MessageContent::Text(uri.as_link().to_string())
210 }
211 };
212
213 message.content.push(chunk);
214 }
215
216 let len_before_context = message.content.len();
217
218 if file_context.len() > OPEN_FILES_TAG.len() {
219 file_context.push_str("</files>\n");
220 message
221 .content
222 .push(language_model::MessageContent::Text(file_context));
223 }
224
225 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
226 symbol_context.push_str("</symbols>\n");
227 message
228 .content
229 .push(language_model::MessageContent::Text(symbol_context));
230 }
231
232 if thread_context.len() > OPEN_THREADS_TAG.len() {
233 thread_context.push_str("</threads>\n");
234 message
235 .content
236 .push(language_model::MessageContent::Text(thread_context));
237 }
238
239 if fetch_context.len() > OPEN_FETCH_TAG.len() {
240 fetch_context.push_str("</fetched_urls>\n");
241 message
242 .content
243 .push(language_model::MessageContent::Text(fetch_context));
244 }
245
246 if rules_context.len() > OPEN_RULES_TAG.len() {
247 rules_context.push_str("</user_rules>\n");
248 message
249 .content
250 .push(language_model::MessageContent::Text(rules_context));
251 }
252
253 if message.content.len() > len_before_context {
254 message.content.insert(
255 len_before_context,
256 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
257 );
258 message
259 .content
260 .push(language_model::MessageContent::Text("</context>".into()));
261 }
262
263 message
264 }
265}
266
/// Builds the info string of a fenced Markdown code block for `full_path`.
///
/// The result is `"<extension> <path>"`. The extension prefix is omitted when the
/// path has no extension or the extension is not valid UTF-8. When `line_range`
/// is given, its bounds are treated as 0-based line numbers and a 1-based suffix
/// is appended: `:N` if start and end are equal, `:N-M` otherwise.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let extension_prefix = full_path
        .extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| format!("{ext} "))
        .unwrap_or_default();

    let line_suffix = match line_range {
        Some(range) if range.start == range.end => format!(":{}", range.start + 1),
        Some(range) => format!(":{}-{}", range.start + 1, range.end + 1),
        None => String::new(),
    };

    format!("{extension_prefix}{}{line_suffix}", full_path.display())
}
286
287impl AgentMessage {
288 pub fn to_markdown(&self) -> String {
289 let mut markdown = String::from("## Assistant\n\n");
290
291 for content in &self.content {
292 match content {
293 AgentMessageContent::Text(text) => {
294 markdown.push_str(text);
295 markdown.push('\n');
296 }
297 AgentMessageContent::Thinking { text, .. } => {
298 markdown.push_str("<think>");
299 markdown.push_str(text);
300 markdown.push_str("</think>\n");
301 }
302 AgentMessageContent::RedactedThinking(_) => {
303 markdown.push_str("<redacted_thinking />\n")
304 }
305 AgentMessageContent::ToolUse(tool_use) => {
306 markdown.push_str(&format!(
307 "**Tool Use**: {} (ID: {})\n",
308 tool_use.name, tool_use.id
309 ));
310 markdown.push_str(&format!(
311 "{}\n",
312 MarkdownCodeBlock {
313 tag: "json",
314 text: &format!("{:#}", tool_use.input)
315 }
316 ));
317 }
318 }
319 }
320
321 for tool_result in self.tool_results.values() {
322 markdown.push_str(&format!(
323 "**Tool Result**: {} (ID: {})\n\n",
324 tool_result.tool_name, tool_result.tool_use_id
325 ));
326 if tool_result.is_error {
327 markdown.push_str("**ERROR:**\n");
328 }
329
330 match &tool_result.content {
331 LanguageModelToolResultContent::Text(text) => {
332 writeln!(markdown, "{text}\n").ok();
333 }
334 LanguageModelToolResultContent::Image(_) => {
335 writeln!(markdown, "<image />\n").ok();
336 }
337 }
338
339 if let Some(output) = tool_result.output.as_ref() {
340 writeln!(
341 markdown,
342 "**Debug Output**:\n\n```json\n{}\n```\n",
343 serde_json::to_string_pretty(output).unwrap()
344 )
345 .unwrap();
346 }
347 }
348
349 markdown
350 }
351
352 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
353 let mut assistant_message = LanguageModelRequestMessage {
354 role: Role::Assistant,
355 content: Vec::with_capacity(self.content.len()),
356 cache: false,
357 };
358 for chunk in &self.content {
359 let chunk = match chunk {
360 AgentMessageContent::Text(text) => {
361 language_model::MessageContent::Text(text.clone())
362 }
363 AgentMessageContent::Thinking { text, signature } => {
364 language_model::MessageContent::Thinking {
365 text: text.clone(),
366 signature: signature.clone(),
367 }
368 }
369 AgentMessageContent::RedactedThinking(value) => {
370 language_model::MessageContent::RedactedThinking(value.clone())
371 }
372 AgentMessageContent::ToolUse(value) => {
373 language_model::MessageContent::ToolUse(value.clone())
374 }
375 };
376 assistant_message.content.push(chunk);
377 }
378
379 let mut user_message = LanguageModelRequestMessage {
380 role: Role::User,
381 content: Vec::new(),
382 cache: false,
383 };
384
385 for tool_result in self.tool_results.values() {
386 user_message
387 .content
388 .push(language_model::MessageContent::ToolResult(
389 tool_result.clone(),
390 ));
391 }
392
393 let mut messages = Vec::new();
394 if !assistant_message.content.is_empty() {
395 messages.push(assistant_message);
396 }
397 if !user_message.content.is_empty() {
398 messages.push(user_message);
399 }
400 messages
401 }
402}
403
/// A message produced by the model, together with the results of the tool calls
/// it requested.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Streamed content, in the order it was received.
    pub content: Vec<AgentMessageContent>,
    /// Tool results keyed by the ID of the tool use that produced them
    /// (`IndexMap` keeps insertion order).
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}

/// A single piece of agent message content.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Visible assistant text.
    Text(String),
    /// Reasoning text, with an optional signature supplied alongside it by the model.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted reasoning data, passed back to the model unchanged in requests.
    RedactedThinking(String),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
}
420
/// An event emitted on the channel returned by `Thread::send`, `Thread::resume`
/// or `Thread::replay`.
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message (emitted when replaying the thread).
    UserMessage(UserMessage),
    /// A chunk of assistant text.
    AgentText(String),
    /// A chunk of assistant thinking text.
    AgentThinking(String),
    /// A new tool call.
    ToolCall(acp::ToolCall),
    /// An update to a previously announced tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for the user to authorize a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, for the given reason.
    Stop(acp::StopReason),
}

/// A pending request for the user to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting authorization.
    pub tool_call: acp::ToolCallUpdate,
    /// The options the user can choose from.
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the ID of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
438
439enum ThreadTitle {
440 None,
441 Pending(Task<()>),
442 Done(Result<SharedString>),
443}
444
445impl ThreadTitle {
446 pub fn unwrap_or_default(&self) -> SharedString {
447 if let ThreadTitle::Done(Ok(title)) = self {
448 title.clone()
449 } else {
450 "New Thread".into()
451 }
452 }
453}
454
/// A conversation between the user and an agent model, including the tools the
/// model may call and the state of the turn currently in progress.
pub struct Thread {
    /// The ACP session this thread belongs to.
    id: acp::SessionId,
    /// ID of the current user prompt; advanced by `send`.
    prompt_id: PromptId,
    updated_at: DateTime<Utc>,
    title: ThreadTitle,
    summary: DetailedSummaryState,
    /// Committed conversation history.
    messages: Vec<Message>,
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
    running_turn: Option<RunningTurn>,
    /// Agent message currently being streamed; not yet part of `messages`.
    pending_message: Option<AgentMessage>,
    /// Tools available to the model, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when the last turn hit the tool use limit; `resume` requires it.
    tool_use_limit_reached: bool,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    /// Project snapshot taken when the thread was created (or restored from the
    /// database), shared so it can be awaited more than once.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    context_server_registry: Entity<ContextServerRegistry>,
    profile_id: AgentProfileId,
    project_context: Rc<RefCell<ProjectContext>>,
    templates: Arc<Templates>,
    /// Model used for new turns; `run_turn` clones it when a turn starts.
    model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
481
482impl Thread {
483 pub fn new(
484 id: acp::SessionId,
485 project: Entity<Project>,
486 project_context: Rc<RefCell<ProjectContext>>,
487 context_server_registry: Entity<ContextServerRegistry>,
488 action_log: Entity<ActionLog>,
489 templates: Arc<Templates>,
490 model: Arc<dyn LanguageModel>,
491 cx: &mut Context<Self>,
492 ) -> Self {
493 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
494 Self {
495 id,
496 prompt_id: PromptId::new(),
497 updated_at: Utc::now(),
498 title: ThreadTitle::None,
499 summary: DetailedSummaryState::default(),
500 messages: Vec::new(),
501 completion_mode: CompletionMode::Normal,
502 running_turn: None,
503 pending_message: None,
504 tools: BTreeMap::default(),
505 tool_use_limit_reached: false,
506 request_token_usage: Vec::new(),
507 cumulative_token_usage: TokenUsage::default(),
508 initial_project_snapshot: {
509 let project_snapshot = Self::project_snapshot(project.clone(), cx);
510 cx.foreground_executor()
511 .spawn(async move { Some(project_snapshot.await) })
512 .shared()
513 },
514 context_server_registry,
515 profile_id,
516 project_context,
517 templates,
518 model,
519 project,
520 action_log,
521 }
522 }
523
524 pub fn id(&self) -> &acp::SessionId {
525 &self.id
526 }
527
528 pub fn from_db(
529 id: acp::SessionId,
530 db_thread: DbThread,
531 project: Entity<Project>,
532 project_context: Rc<RefCell<ProjectContext>>,
533 context_server_registry: Entity<ContextServerRegistry>,
534 action_log: Entity<ActionLog>,
535 templates: Arc<Templates>,
536 model: Arc<dyn LanguageModel>,
537 cx: &mut Context<Self>,
538 ) -> Self {
539 let profile_id = db_thread
540 .profile
541 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
542 Self {
543 id,
544 prompt_id: PromptId::new(),
545 title: ThreadTitle::Done(Ok(db_thread.title.clone())),
546 summary: db_thread.summary,
547 messages: db_thread.messages,
548 completion_mode: CompletionMode::Normal,
549 running_turn: None,
550 pending_message: None,
551 tools: BTreeMap::default(),
552 tool_use_limit_reached: false,
553 request_token_usage: db_thread.request_token_usage.clone(),
554 cumulative_token_usage: db_thread.cumulative_token_usage.clone(),
555 initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
556 context_server_registry,
557 profile_id,
558 project_context,
559 templates,
560 model,
561 project,
562 action_log,
563 updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
564 }
565 }
566
567 pub fn to_db(&self, cx: &App) -> Task<DbThread> {
568 let initial_project_snapshot = self.initial_project_snapshot.clone();
569 let mut thread = DbThread {
570 title: self.title.unwrap_or_default(),
571 messages: self.messages.clone(),
572 updated_at: self.updated_at.clone(),
573 summary: self.summary.clone(),
574 initial_project_snapshot: None,
575 cumulative_token_usage: self.cumulative_token_usage.clone(),
576 request_token_usage: self.request_token_usage.clone(),
577 model: Some(DbLanguageModel {
578 provider: self.model.provider_id().to_string(),
579 model: self.model.name().0.to_string(),
580 }),
581 completion_mode: Some(self.completion_mode.into()),
582 profile: Some(self.profile_id.clone()),
583 };
584
585 cx.background_spawn(async move {
586 let initial_project_snapshot = initial_project_snapshot.await;
587 thread.initial_project_snapshot = initial_project_snapshot;
588 thread
589 })
590 }
591
592 pub fn replay(
593 &mut self,
594 cx: &mut Context<Self>,
595 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
596 let (tx, rx) = mpsc::unbounded();
597 let stream = ThreadEventStream(tx);
598 for message in &self.messages {
599 match message {
600 Message::User(user_message) => stream.send_user_message(&user_message),
601 Message::Agent(assistant_message) => {
602 for content in &assistant_message.content {
603 match content {
604 AgentMessageContent::Text(text) => stream.send_text(text),
605 AgentMessageContent::Thinking { text, .. } => {
606 stream.send_thinking(text)
607 }
608 AgentMessageContent::RedactedThinking(_) => {}
609 AgentMessageContent::ToolUse(tool_use) => {
610 self.replay_tool_call(
611 tool_use,
612 assistant_message.tool_results.get(&tool_use.id),
613 &stream,
614 cx,
615 );
616 }
617 }
618 }
619 }
620 Message::Resume => {}
621 }
622 }
623 rx
624 }
625
626 fn replay_tool_call(
627 &self,
628 tool_use: &LanguageModelToolUse,
629 tool_result: Option<&LanguageModelToolResult>,
630 stream: &ThreadEventStream,
631 cx: &mut Context<Self>,
632 ) {
633 let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
634 stream
635 .0
636 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
637 id: acp::ToolCallId(tool_use.id.to_string().into()),
638 title: tool_use.name.to_string(),
639 kind: acp::ToolKind::Other,
640 status: acp::ToolCallStatus::Failed,
641 content: Vec::new(),
642 locations: Vec::new(),
643 raw_input: Some(tool_use.input.clone()),
644 raw_output: None,
645 })))
646 .ok();
647 return;
648 };
649
650 let title = tool.initial_title(tool_use.input.clone());
651 let kind = tool.kind();
652 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
653
654 let output = tool_result
655 .as_ref()
656 .and_then(|result| result.output.clone());
657 if let Some(output) = output.clone() {
658 let tool_event_stream = ToolCallEventStream::new(
659 tool_use.id.clone(),
660 stream.clone(),
661 Some(self.project.read(cx).fs().clone()),
662 );
663 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
664 .log_err();
665 }
666
667 stream.update_tool_call_fields(
668 &tool_use.id,
669 acp::ToolCallUpdateFields {
670 status: Some(acp::ToolCallStatus::Completed),
671 raw_output: output,
672 ..Default::default()
673 },
674 );
675 }
676
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// One worktree snapshot task is started per visible worktree before the
    /// returned task is spawned. That task awaits all of them, then records the
    /// paths of dirty buffers that have an associated file, and stamps the
    /// snapshot with the current time.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Start every worktree snapshot up front; they are awaited together below.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            // Best-effort: if the app update fails, the snapshot simply lists no
            // unsaved buffers.
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
714
    /// Captures the path and git state of a single worktree.
    ///
    /// If the worktree cannot be read, an empty path and no git state are
    /// returned. Git state comes from the first repository in `git_store` for
    /// which `abs_path_to_repo_path` accepts the worktree's absolute path.
    ///
    /// For local repositories the git state includes the `origin` remote URL, the
    /// HEAD SHA and a HEAD-to-worktree diff. For other repositories only the
    /// current branch is recorded. If the repository lookup, the job submission or
    /// the job itself fails, `git_state` is `None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<agent::thread::WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // NOTE(review): the worktree snapshot is captured but never used below.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree and, if there is one,
            // schedule a job on it that gathers the git state.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Non-local repositories expose no backend; only the
                            // branch name is recorded for them.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Collapse a failed job submission or a failed job result into `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
792
793 pub fn project(&self) -> &Entity<Project> {
794 &self.project
795 }
796
797 pub fn action_log(&self) -> &Entity<ActionLog> {
798 &self.action_log
799 }
800
801 pub fn model(&self) -> &Arc<dyn LanguageModel> {
802 &self.model
803 }
804
805 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
806 self.model = model;
807 cx.notify()
808 }
809
810 pub fn completion_mode(&self) -> CompletionMode {
811 self.completion_mode
812 }
813
814 pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
815 self.completion_mode = mode;
816 cx.notify()
817 }
818
819 #[cfg(any(test, feature = "test-support"))]
820 pub fn last_message(&self) -> Option<Message> {
821 if let Some(message) = self.pending_message.clone() {
822 Some(Message::Agent(message))
823 } else {
824 self.messages.last().cloned()
825 }
826 }
827
828 pub fn add_tool(&mut self, tool: impl AgentTool) {
829 self.tools.insert(tool.name(), tool.erase());
830 }
831
832 pub fn remove_tool(&mut self, name: &str) -> bool {
833 self.tools.remove(name).is_some()
834 }
835
836 pub fn profile(&self) -> &AgentProfileId {
837 &self.profile_id
838 }
839
840 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
841 self.profile_id = profile_id;
842 }
843
844 pub fn cancel(&mut self, cx: &mut Context<Self>) {
845 if let Some(running_turn) = self.running_turn.take() {
846 running_turn.cancel();
847 }
848 self.flush_pending_message(cx);
849 }
850
851 pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
852 self.cancel(cx);
853 let Some(position) = self.messages.iter().position(
854 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
855 ) else {
856 return Err(anyhow!("Message not found"));
857 };
858 self.messages.truncate(position);
859 cx.notify();
860 Ok(())
861 }
862
863 pub fn resume(
864 &mut self,
865 cx: &mut Context<Self>,
866 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
867 anyhow::ensure!(
868 self.tool_use_limit_reached,
869 "can only resume after tool use limit is reached"
870 );
871
872 self.messages.push(Message::Resume);
873 cx.notify();
874
875 log::info!("Total messages in thread: {}", self.messages.len());
876 Ok(self.run_turn(cx))
877 }
878
879 /// Sending a message results in the model streaming a response, which could include tool calls.
880 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
881 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
882 pub fn send<T>(
883 &mut self,
884 id: UserMessageId,
885 content: impl IntoIterator<Item = T>,
886 cx: &mut Context<Self>,
887 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
888 where
889 T: Into<UserMessageContent>,
890 {
891 log::info!("Thread::send called with model: {:?}", self.model.name());
892 self.advance_prompt_id();
893
894 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
895 log::debug!("Thread::send content: {:?}", content);
896
897 self.messages
898 .push(Message::User(UserMessage { id, content }));
899 cx.notify();
900
901 log::info!("Total messages in thread: {}", self.messages.len());
902 self.run_turn(cx)
903 }
904
905 fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
906 self.cancel(cx);
907
908 let model = self.model.clone();
909 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
910 let event_stream = ThreadEventStream(events_tx);
911 let message_ix = self.messages.len().saturating_sub(1);
912 self.tool_use_limit_reached = false;
913 self.running_turn = Some(RunningTurn {
914 event_stream: event_stream.clone(),
915 _task: cx.spawn(async move |this, cx| {
916 log::info!("Starting agent turn execution");
917 let turn_result: Result<()> = async {
918 let mut completion_intent = CompletionIntent::UserPrompt;
919 loop {
920 log::debug!(
921 "Building completion request with intent: {:?}",
922 completion_intent
923 );
924 let request = this.update(cx, |this, cx| {
925 this.build_completion_request(completion_intent, cx)
926 })?;
927
928 log::info!("Calling model.stream_completion");
929 let mut events = model.stream_completion(request, cx).await?;
930 log::debug!("Stream completion started successfully");
931
932 let mut tool_use_limit_reached = false;
933 let mut tool_uses = FuturesUnordered::new();
934 while let Some(event) = events.next().await {
935 match event? {
936 LanguageModelCompletionEvent::StatusUpdate(
937 CompletionRequestStatus::ToolUseLimitReached,
938 ) => {
939 tool_use_limit_reached = true;
940 }
941 LanguageModelCompletionEvent::Stop(reason) => {
942 event_stream.send_stop(reason);
943 if reason == StopReason::Refusal {
944 this.update(cx, |this, cx| {
945 this.flush_pending_message(cx);
946 this.messages.truncate(message_ix);
947 })?;
948 return Ok(());
949 }
950 }
951 event => {
952 log::trace!("Received completion event: {:?}", event);
953 this.update(cx, |this, cx| {
954 tool_uses.extend(this.handle_streamed_completion_event(
955 event,
956 &event_stream,
957 cx,
958 ));
959 })
960 .ok();
961 }
962 }
963 }
964
965 let used_tools = tool_uses.is_empty();
966 while let Some(tool_result) = tool_uses.next().await {
967 log::info!("Tool finished {:?}", tool_result);
968
969 event_stream.update_tool_call_fields(
970 &tool_result.tool_use_id,
971 acp::ToolCallUpdateFields {
972 status: Some(if tool_result.is_error {
973 acp::ToolCallStatus::Failed
974 } else {
975 acp::ToolCallStatus::Completed
976 }),
977 raw_output: tool_result.output.clone(),
978 ..Default::default()
979 },
980 );
981 this.update(cx, |this, _cx| {
982 this.pending_message()
983 .tool_results
984 .insert(tool_result.tool_use_id.clone(), tool_result);
985 })
986 .ok();
987 }
988
989 if tool_use_limit_reached {
990 log::info!("Tool use limit reached, completing turn");
991 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
992 return Err(language_model::ToolUseLimitReachedError.into());
993 } else if used_tools {
994 log::info!("No tool uses found, completing turn");
995 return Ok(());
996 } else {
997 this.update(cx, |this, cx| this.flush_pending_message(cx))?;
998 completion_intent = CompletionIntent::ToolResults;
999 }
1000 }
1001 }
1002 .await;
1003
1004 if let Err(error) = turn_result {
1005 log::error!("Turn execution failed: {:?}", error);
1006 event_stream.send_error(error);
1007 } else {
1008 log::info!("Turn execution completed successfully");
1009 }
1010
1011 this.update(cx, |this, cx| {
1012 this.flush_pending_message(cx);
1013 this.running_turn.take();
1014 })
1015 .ok();
1016 }),
1017 });
1018 events_rx
1019 }
1020
1021 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
1022 log::debug!("Building system message");
1023 let prompt = SystemPromptTemplate {
1024 project: &self.project_context.borrow(),
1025 available_tools: self.tools.keys().cloned().collect(),
1026 }
1027 .render(&self.templates)
1028 .context("failed to build system prompt")
1029 .expect("Invalid template");
1030 log::debug!("System message built");
1031 LanguageModelRequestMessage {
1032 role: Role::System,
1033 content: vec![prompt.into()],
1034 cache: true,
1035 }
1036 }
1037
1038 /// A helper method that's called on every streamed completion event.
1039 /// Returns an optional tool result task, which the main agentic loop in
1040 /// send will send back to the model when it resolves.
1041 fn handle_streamed_completion_event(
1042 &mut self,
1043 event: LanguageModelCompletionEvent,
1044 event_stream: &ThreadEventStream,
1045 cx: &mut Context<Self>,
1046 ) -> Option<Task<LanguageModelToolResult>> {
1047 log::trace!("Handling streamed completion event: {:?}", event);
1048 use LanguageModelCompletionEvent::*;
1049
1050 match event {
1051 StartMessage { .. } => {
1052 self.flush_pending_message(cx);
1053 self.pending_message = Some(AgentMessage::default());
1054 }
1055 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1056 Thinking { text, signature } => {
1057 self.handle_thinking_event(text, signature, event_stream, cx)
1058 }
1059 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1060 ToolUse(tool_use) => {
1061 return self.handle_tool_use_event(tool_use, event_stream, cx);
1062 }
1063 ToolUseJsonParseError {
1064 id,
1065 tool_name,
1066 raw_input,
1067 json_parse_error,
1068 } => {
1069 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1070 id,
1071 tool_name,
1072 raw_input,
1073 json_parse_error,
1074 )));
1075 }
1076 UsageUpdate(_) | StatusUpdate(_) => {}
1077 Stop(_) => unreachable!(),
1078 }
1079
1080 None
1081 }
1082
1083 fn handle_text_event(
1084 &mut self,
1085 new_text: String,
1086 event_stream: &ThreadEventStream,
1087 cx: &mut Context<Self>,
1088 ) {
1089 event_stream.send_text(&new_text);
1090
1091 let last_message = self.pending_message();
1092 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1093 text.push_str(&new_text);
1094 } else {
1095 last_message
1096 .content
1097 .push(AgentMessageContent::Text(new_text));
1098 }
1099
1100 cx.notify();
1101 }
1102
1103 fn handle_thinking_event(
1104 &mut self,
1105 new_text: String,
1106 new_signature: Option<String>,
1107 event_stream: &ThreadEventStream,
1108 cx: &mut Context<Self>,
1109 ) {
1110 event_stream.send_thinking(&new_text);
1111
1112 let last_message = self.pending_message();
1113 if let Some(AgentMessageContent::Thinking { text, signature }) =
1114 last_message.content.last_mut()
1115 {
1116 text.push_str(&new_text);
1117 *signature = new_signature.or(signature.take());
1118 } else {
1119 last_message.content.push(AgentMessageContent::Thinking {
1120 text: new_text,
1121 signature: new_signature,
1122 });
1123 }
1124
1125 cx.notify();
1126 }
1127
1128 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1129 let last_message = self.pending_message();
1130 last_message
1131 .content
1132 .push(AgentMessageContent::RedactedThinking(data));
1133 cx.notify();
1134 }
1135
1136 fn handle_tool_use_event(
1137 &mut self,
1138 tool_use: LanguageModelToolUse,
1139 event_stream: &ThreadEventStream,
1140 cx: &mut Context<Self>,
1141 ) -> Option<Task<LanguageModelToolResult>> {
1142 cx.notify();
1143
1144 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1145 let mut title = SharedString::from(&tool_use.name);
1146 let mut kind = acp::ToolKind::Other;
1147 if let Some(tool) = tool.as_ref() {
1148 title = tool.initial_title(tool_use.input.clone());
1149 kind = tool.kind();
1150 }
1151
1152 // Ensure the last message ends in the current tool use
1153 let last_message = self.pending_message();
1154 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
1155 if let AgentMessageContent::ToolUse(last_tool_use) = content {
1156 if last_tool_use.id == tool_use.id {
1157 *last_tool_use = tool_use.clone();
1158 false
1159 } else {
1160 true
1161 }
1162 } else {
1163 true
1164 }
1165 });
1166
1167 if push_new_tool_use {
1168 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1169 last_message
1170 .content
1171 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1172 } else {
1173 event_stream.update_tool_call_fields(
1174 &tool_use.id,
1175 acp::ToolCallUpdateFields {
1176 title: Some(title.into()),
1177 kind: Some(kind),
1178 raw_input: Some(tool_use.input.clone()),
1179 ..Default::default()
1180 },
1181 );
1182 }
1183
1184 if !tool_use.is_input_complete {
1185 return None;
1186 }
1187
1188 let Some(tool) = tool else {
1189 let content = format!("No tool named {} exists", tool_use.name);
1190 return Some(Task::ready(LanguageModelToolResult {
1191 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1192 tool_use_id: tool_use.id,
1193 tool_name: tool_use.name,
1194 is_error: true,
1195 output: None,
1196 }));
1197 };
1198
1199 let fs = self.project.read(cx).fs().clone();
1200 let tool_event_stream =
1201 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1202 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1203 status: Some(acp::ToolCallStatus::InProgress),
1204 ..Default::default()
1205 });
1206 let supports_images = self.model.supports_images();
1207 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1208 log::info!("Running tool {}", tool_use.name);
1209 Some(cx.foreground_executor().spawn(async move {
1210 let tool_result = tool_result.await.and_then(|output| {
1211 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1212 if !supports_images {
1213 return Err(anyhow!(
1214 "Attempted to read an image, but this model doesn't support it.",
1215 ));
1216 }
1217 }
1218 Ok(output)
1219 });
1220
1221 match tool_result {
1222 Ok(output) => LanguageModelToolResult {
1223 tool_use_id: tool_use.id,
1224 tool_name: tool_use.name,
1225 is_error: false,
1226 content: output.llm_output,
1227 output: Some(output.raw_output),
1228 },
1229 Err(error) => LanguageModelToolResult {
1230 tool_use_id: tool_use.id,
1231 tool_name: tool_use.name,
1232 is_error: true,
1233 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1234 output: None,
1235 },
1236 }
1237 }))
1238 }
1239
1240 fn handle_tool_use_json_parse_error_event(
1241 &mut self,
1242 tool_use_id: LanguageModelToolUseId,
1243 tool_name: Arc<str>,
1244 raw_input: Arc<str>,
1245 json_parse_error: String,
1246 ) -> LanguageModelToolResult {
1247 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1248 LanguageModelToolResult {
1249 tool_use_id,
1250 tool_name,
1251 is_error: true,
1252 content: LanguageModelToolResultContent::Text(tool_output.into()),
1253 output: Some(serde_json::Value::String(raw_input.to_string())),
1254 }
1255 }
1256
1257 fn pending_message(&mut self) -> &mut AgentMessage {
1258 self.pending_message.get_or_insert_default()
1259 }
1260
1261 fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1262 let Some(mut message) = self.pending_message.take() else {
1263 return;
1264 };
1265
1266 for content in &message.content {
1267 let AgentMessageContent::ToolUse(tool_use) = content else {
1268 continue;
1269 };
1270
1271 if !message.tool_results.contains_key(&tool_use.id) {
1272 message.tool_results.insert(
1273 tool_use.id.clone(),
1274 LanguageModelToolResult {
1275 tool_use_id: tool_use.id.clone(),
1276 tool_name: tool_use.name.clone(),
1277 is_error: true,
1278 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1279 output: None,
1280 },
1281 );
1282 }
1283 }
1284
1285 self.messages.push(Message::Agent(message));
1286 dbg!("!!!!!!!!!!!!!!!!!!!!!!!");
1287 cx.notify()
1288 }
1289
1290 pub(crate) fn build_completion_request(
1291 &self,
1292 completion_intent: CompletionIntent,
1293 cx: &mut App,
1294 ) -> LanguageModelRequest {
1295 log::debug!("Building completion request");
1296 log::debug!("Completion intent: {:?}", completion_intent);
1297 log::debug!("Completion mode: {:?}", self.completion_mode);
1298
1299 let messages = self.build_request_messages();
1300 log::info!("Request will include {} messages", messages.len());
1301
1302 let tools = if let Some(tools) = self.tools(cx).log_err() {
1303 tools
1304 .filter_map(|tool| {
1305 let tool_name = tool.name().to_string();
1306 log::trace!("Including tool: {}", tool_name);
1307 Some(LanguageModelRequestTool {
1308 name: tool_name,
1309 description: tool.description().to_string(),
1310 input_schema: tool
1311 .input_schema(self.model.tool_input_format())
1312 .log_err()?,
1313 })
1314 })
1315 .collect()
1316 } else {
1317 Vec::new()
1318 };
1319
1320 log::info!("Request includes {} tools", tools.len());
1321
1322 let request = LanguageModelRequest {
1323 thread_id: Some(self.id.to_string()),
1324 prompt_id: Some(self.prompt_id.to_string()),
1325 intent: Some(completion_intent),
1326 mode: Some(self.completion_mode.into()),
1327 messages,
1328 tools,
1329 tool_choice: None,
1330 stop: Vec::new(),
1331 temperature: AgentSettings::temperature_for_model(self.model(), cx),
1332 thinking_allowed: true,
1333 };
1334
1335 log::debug!("Completion request built successfully");
1336 request
1337 }
1338
1339 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1340 let profile = AgentSettings::get_global(cx)
1341 .profiles
1342 .get(&self.profile_id)
1343 .context("profile not found")?;
1344 let provider_id = self.model.provider_id();
1345
1346 Ok(self
1347 .tools
1348 .iter()
1349 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1350 .filter_map(|(tool_name, tool)| {
1351 if profile.is_tool_enabled(tool_name) {
1352 Some(tool)
1353 } else {
1354 None
1355 }
1356 })
1357 .chain(self.context_server_registry.read(cx).servers().flat_map(
1358 |(server_id, tools)| {
1359 tools.iter().filter_map(|(tool_name, tool)| {
1360 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1361 Some(tool)
1362 } else {
1363 None
1364 }
1365 })
1366 },
1367 )))
1368 }
1369
1370 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1371 log::trace!(
1372 "Building request messages from {} thread messages",
1373 self.messages.len()
1374 );
1375 let mut messages = vec![self.build_system_message()];
1376 for message in &self.messages {
1377 match message {
1378 Message::User(message) => messages.push(message.to_request()),
1379 Message::Agent(message) => messages.extend(message.to_request()),
1380 Message::Resume => messages.push(LanguageModelRequestMessage {
1381 role: Role::User,
1382 content: vec!["Continue where you left off".into()],
1383 cache: false,
1384 }),
1385 }
1386 }
1387
1388 if let Some(message) = self.pending_message.as_ref() {
1389 messages.extend(message.to_request());
1390 }
1391
1392 if let Some(last_user_message) = messages
1393 .iter_mut()
1394 .rev()
1395 .find(|message| message.role == Role::User)
1396 {
1397 last_user_message.cache = true;
1398 }
1399
1400 messages
1401 }
1402
1403 pub fn to_markdown(&self) -> String {
1404 let mut markdown = String::new();
1405 for (ix, message) in self.messages.iter().enumerate() {
1406 if ix > 0 {
1407 markdown.push('\n');
1408 }
1409 markdown.push_str(&message.to_markdown());
1410 }
1411
1412 if let Some(message) = self.pending_message.as_ref() {
1413 markdown.push('\n');
1414 markdown.push_str(&message.to_markdown());
1415 }
1416
1417 markdown
1418 }
1419
1420 fn advance_prompt_id(&mut self) {
1421 self.prompt_id = PromptId::new();
1422 }
1423}
1424
/// Bookkeeping for an agent turn that is currently in progress.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// It survives across multiple requests while the model performs tool calls,
    /// we run the tools, and we report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}

impl RunningTurn {
    /// Cancels the turn.
    ///
    /// Sends a `Canceled` stop event on the turn's event stream, then drops
    /// `self`, and with it `_task`.
    // NOTE(review): presumably dropping `_task` is what actually stops the
    // in-flight work (gpui tasks are cancelled on drop); confirm. The event is
    // sent before the drop, and the fields drop in declaration order, so avoid
    // restructuring this into a destructure that would change that order.
    fn cancel(self) {
        log::debug!("Cancelling in progress turn");
        self.event_stream.send_canceled();
    }
}
1441
1442pub trait AgentTool
1443where
1444 Self: 'static + Sized,
1445{
1446 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1447 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1448
1449 fn name(&self) -> SharedString;
1450
1451 fn description(&self) -> SharedString {
1452 let schema = schemars::schema_for!(Self::Input);
1453 SharedString::new(
1454 schema
1455 .get("description")
1456 .and_then(|description| description.as_str())
1457 .unwrap_or_default(),
1458 )
1459 }
1460
1461 fn kind(&self) -> acp::ToolKind;
1462
1463 /// The initial tool title to display. Can be updated during the tool run.
1464 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1465
1466 /// Returns the JSON schema that describes the tool's input.
1467 fn input_schema(&self) -> Schema {
1468 schemars::schema_for!(Self::Input)
1469 }
1470
1471 /// Some tools rely on a provider for the underlying billing or other reasons.
1472 /// Allow the tool to check if they are compatible, or should be filtered out.
1473 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1474 true
1475 }
1476
1477 /// Runs the tool with the provided input.
1478 fn run(
1479 self: Arc<Self>,
1480 input: Self::Input,
1481 event_stream: ToolCallEventStream,
1482 cx: &mut App,
1483 ) -> Task<Result<Self::Output>>;
1484
1485 /// Emits events for a previous execution of the tool.
1486 fn replay(
1487 &self,
1488 _input: Self::Input,
1489 _output: Self::Output,
1490 _event_stream: ToolCallEventStream,
1491 _cx: &mut App,
1492 ) -> Result<()> {
1493 Ok(())
1494 }
1495
1496 fn erase(self) -> Arc<dyn AnyAgentTool> {
1497 Arc::new(Erased(Arc::new(self)))
1498 }
1499}
1500
/// Type-erasing wrapper. `Erased<Arc<T>>` implements [`AnyAgentTool`] for any
/// [`AgentTool`] `T` (see `AgentTool::erase`).
pub struct Erased<T>(T);
1502
/// The result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content returned to the language model (the tool's `Output` converted via `Into`).
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed `Output`, serialized to JSON.
    pub raw_output: serde_json::Value,
}
1507
/// Object-safe, type-erased counterpart of [`AgentTool`].
///
/// It works with raw JSON input and output, so that tools with different
/// `Input`/`Output` types can be held together as `Arc<dyn AnyAgentTool>`.
pub trait AnyAgentTool {
    /// The name the model uses to refer to this tool.
    fn name(&self) -> SharedString;
    /// The description shown to the model.
    fn description(&self) -> SharedString;
    /// The kind of tool call, reported alongside the call.
    fn kind(&self) -> acp::ToolKind;
    /// Title to display for a call with the given (possibly incomplete) JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input, adapted to the requested schema format.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered with the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Emits events for a previous execution, given its stored JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1531
1532impl<T> AnyAgentTool for Erased<Arc<T>>
1533where
1534 T: AgentTool,
1535{
1536 fn name(&self) -> SharedString {
1537 self.0.name()
1538 }
1539
1540 fn description(&self) -> SharedString {
1541 self.0.description()
1542 }
1543
1544 fn kind(&self) -> agent_client_protocol::ToolKind {
1545 self.0.kind()
1546 }
1547
1548 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1549 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1550 self.0.initial_title(parsed_input)
1551 }
1552
1553 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1554 let mut json = serde_json::to_value(self.0.input_schema())?;
1555 adapt_schema_to_format(&mut json, format)?;
1556 Ok(json)
1557 }
1558
1559 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1560 self.0.supported_provider(provider)
1561 }
1562
1563 fn run(
1564 self: Arc<Self>,
1565 input: serde_json::Value,
1566 event_stream: ToolCallEventStream,
1567 cx: &mut App,
1568 ) -> Task<Result<AgentToolOutput>> {
1569 cx.spawn(async move |cx| {
1570 let input = serde_json::from_value(input)?;
1571 let output = cx
1572 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1573 .await?;
1574 let raw_output = serde_json::to_value(&output)?;
1575 Ok(AgentToolOutput {
1576 llm_output: output.into(),
1577 raw_output,
1578 })
1579 })
1580 }
1581
1582 fn replay(
1583 &self,
1584 input: serde_json::Value,
1585 output: serde_json::Value,
1586 event_stream: ToolCallEventStream,
1587 cx: &mut App,
1588 ) -> Result<()> {
1589 let input = serde_json::from_value(input)?;
1590 let output = serde_json::from_value(output)?;
1591 self.0.replay(input, output, event_stream, cx)
1592 }
1593}
1594
/// Sending half of a thread's event channel. Events are `Ok(ThreadEvent)`;
/// errors are sent as `Err`.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1597
1598impl ThreadEventStream {
1599 fn send_user_message(&self, message: &UserMessage) {
1600 self.0
1601 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1602 .ok();
1603 }
1604
1605 fn send_text(&self, text: &str) {
1606 self.0
1607 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1608 .ok();
1609 }
1610
1611 fn send_thinking(&self, text: &str) {
1612 self.0
1613 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1614 .ok();
1615 }
1616
1617 fn send_tool_call(
1618 &self,
1619 id: &LanguageModelToolUseId,
1620 title: SharedString,
1621 kind: acp::ToolKind,
1622 input: serde_json::Value,
1623 ) {
1624 self.0
1625 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1626 id,
1627 title.to_string(),
1628 kind,
1629 input,
1630 ))))
1631 .ok();
1632 }
1633
1634 fn initial_tool_call(
1635 id: &LanguageModelToolUseId,
1636 title: String,
1637 kind: acp::ToolKind,
1638 input: serde_json::Value,
1639 ) -> acp::ToolCall {
1640 acp::ToolCall {
1641 id: acp::ToolCallId(id.to_string().into()),
1642 title,
1643 kind,
1644 status: acp::ToolCallStatus::Pending,
1645 content: vec![],
1646 locations: vec![],
1647 raw_input: Some(input),
1648 raw_output: None,
1649 }
1650 }
1651
1652 fn update_tool_call_fields(
1653 &self,
1654 tool_use_id: &LanguageModelToolUseId,
1655 fields: acp::ToolCallUpdateFields,
1656 ) {
1657 self.0
1658 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1659 acp::ToolCallUpdate {
1660 id: acp::ToolCallId(tool_use_id.to_string().into()),
1661 fields,
1662 }
1663 .into(),
1664 )))
1665 .ok();
1666 }
1667
1668 fn send_stop(&self, reason: StopReason) {
1669 match reason {
1670 StopReason::EndTurn => {
1671 self.0
1672 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1673 .ok();
1674 }
1675 StopReason::MaxTokens => {
1676 self.0
1677 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1678 .ok();
1679 }
1680 StopReason::Refusal => {
1681 self.0
1682 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1683 .ok();
1684 }
1685 StopReason::ToolUse => {}
1686 }
1687 }
1688
1689 fn send_canceled(&self) {
1690 self.0
1691 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1692 .ok();
1693 }
1694
1695 fn send_error(&self, error: impl Into<anyhow::Error>) {
1696 self.0.unbounded_send(Err(error.into())).ok();
1697 }
1698}
1699
/// Event stream scoped to a single tool call.
///
/// It pairs the thread's event stream with the id of the tool use being run,
/// so that a tool can report progress (field updates, diffs, terminals) and
/// request authorization. `fs` is used to persist an "Always Allow" answer to
/// the settings file; the test constructor passes `None`.
#[derive(Clone)]
pub struct ToolCallEventStream {
    tool_use_id: LanguageModelToolUseId,
    stream: ThreadEventStream,
    fs: Option<Arc<dyn Fs>>,
}
1706
1707impl ToolCallEventStream {
1708 #[cfg(test)]
1709 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1710 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1711
1712 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1713
1714 (stream, ToolCallEventStreamReceiver(events_rx))
1715 }
1716
1717 fn new(
1718 tool_use_id: LanguageModelToolUseId,
1719 stream: ThreadEventStream,
1720 fs: Option<Arc<dyn Fs>>,
1721 ) -> Self {
1722 Self {
1723 tool_use_id,
1724 stream,
1725 fs,
1726 }
1727 }
1728
1729 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1730 self.stream
1731 .update_tool_call_fields(&self.tool_use_id, fields);
1732 }
1733
1734 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1735 self.stream
1736 .0
1737 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1738 acp_thread::ToolCallUpdateDiff {
1739 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1740 diff,
1741 }
1742 .into(),
1743 )))
1744 .ok();
1745 }
1746
1747 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1748 self.stream
1749 .0
1750 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1751 acp_thread::ToolCallUpdateTerminal {
1752 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1753 terminal,
1754 }
1755 .into(),
1756 )))
1757 .ok();
1758 }
1759
1760 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1761 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1762 return Task::ready(Ok(()));
1763 }
1764
1765 let (response_tx, response_rx) = oneshot::channel();
1766 self.stream
1767 .0
1768 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1769 ToolCallAuthorization {
1770 tool_call: acp::ToolCallUpdate {
1771 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1772 fields: acp::ToolCallUpdateFields {
1773 title: Some(title.into()),
1774 ..Default::default()
1775 },
1776 },
1777 options: vec![
1778 acp::PermissionOption {
1779 id: acp::PermissionOptionId("always_allow".into()),
1780 name: "Always Allow".into(),
1781 kind: acp::PermissionOptionKind::AllowAlways,
1782 },
1783 acp::PermissionOption {
1784 id: acp::PermissionOptionId("allow".into()),
1785 name: "Allow".into(),
1786 kind: acp::PermissionOptionKind::AllowOnce,
1787 },
1788 acp::PermissionOption {
1789 id: acp::PermissionOptionId("deny".into()),
1790 name: "Deny".into(),
1791 kind: acp::PermissionOptionKind::RejectOnce,
1792 },
1793 ],
1794 response: response_tx,
1795 },
1796 )))
1797 .ok();
1798 let fs = self.fs.clone();
1799 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1800 "always_allow" => {
1801 if let Some(fs) = fs.clone() {
1802 cx.update(|cx| {
1803 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1804 settings.set_always_allow_tool_actions(true);
1805 });
1806 })?;
1807 }
1808
1809 Ok(())
1810 }
1811 "allow" => Ok(()),
1812 _ => Err(anyhow!("Permission to run tool denied by user")),
1813 })
1814 }
1815}
1816
/// Test-only receiving end of a [`ToolCallEventStream`], with helpers for
/// asserting on the next event.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);

#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as an authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        let event = self.0.next().await;
        match event {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal it attaches.
    ///
    /// # Panics
    ///
    /// Panics if the next event is not a terminal update, or if the stream has
    /// ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        let event = self.0.next().await;
        match event {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}

/// Exposes the underlying receiver directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}

/// Exposes the underlying receiver mutably.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1859
1860impl From<&str> for UserMessageContent {
1861 fn from(text: &str) -> Self {
1862 Self::Text(text.into())
1863 }
1864}
1865
1866impl From<acp::ContentBlock> for UserMessageContent {
1867 fn from(value: acp::ContentBlock) -> Self {
1868 match value {
1869 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1870 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1871 acp::ContentBlock::Audio(_) => {
1872 // TODO
1873 Self::Text("[audio]".to_string())
1874 }
1875 acp::ContentBlock::ResourceLink(resource_link) => {
1876 match MentionUri::parse(&resource_link.uri) {
1877 Ok(uri) => Self::Mention {
1878 uri,
1879 content: String::new(),
1880 },
1881 Err(err) => {
1882 log::error!("Failed to parse mention link: {}", err);
1883 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1884 }
1885 }
1886 }
1887 acp::ContentBlock::Resource(resource) => match resource.resource {
1888 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1889 match MentionUri::parse(&resource.uri) {
1890 Ok(uri) => Self::Mention {
1891 uri,
1892 content: resource.text,
1893 },
1894 Err(err) => {
1895 log::error!("Failed to parse mention link: {}", err);
1896 Self::Text(
1897 MarkdownCodeBlock {
1898 tag: &resource.uri,
1899 text: &resource.text,
1900 }
1901 .to_string(),
1902 )
1903 }
1904 }
1905 }
1906 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1907 // TODO
1908 Self::Text("[blob]".to_string())
1909 }
1910 },
1911 }
1912 }
1913}
1914
1915impl From<UserMessageContent> for acp::ContentBlock {
1916 fn from(content: UserMessageContent) -> Self {
1917 match content {
1918 UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
1919 text,
1920 annotations: None,
1921 }),
1922 UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
1923 data: image.source.to_string(),
1924 mime_type: "image/png".to_string(),
1925 annotations: None,
1926 uri: None,
1927 }),
1928 UserMessageContent::Mention { uri, content } => {
1929 todo!()
1930 }
1931 }
1932 }
1933}
1934
1935fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1936 LanguageModelImage {
1937 source: image_content.data.into(),
1938 // TODO: make this optional?
1939 size: gpui::Size::new(0.into(), 0.into()),
1940 }
1941}