1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
16use language_model::{
17 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
18 LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
19 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
20 LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{
29 collections::BTreeMap,
30 path::Path,
31 sync::Arc,
32 time::{Duration, Instant},
33};
34use std::{fmt::Write, ops::Range};
35use util::{ResultExt, markdown::MarkdownCodeBlock};
36use uuid::Uuid;
37
38#[derive(
39 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
40)]
41pub struct ThreadId(Arc<str>);
42
43impl ThreadId {
44 pub fn new() -> Self {
45 Self(Uuid::new_v4().to_string().into())
46 }
47}
48
49impl std::fmt::Display for ThreadId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(f, "{}", self.0)
52 }
53}
54
55impl From<&str> for ThreadId {
56 fn from(value: &str) -> Self {
57 Self(value.into())
58 }
59}
60
61/// The ID of the user prompt that initiated a request.
62///
63/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
64#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
65pub struct PromptId(Arc<str>);
66
67impl PromptId {
68 pub fn new() -> Self {
69 Self(Uuid::new_v4().to_string().into())
70 }
71}
72
73impl std::fmt::Display for PromptId {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 write!(f, "{}", self.0)
76 }
77}
78
79pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
80pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
81
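/// How a failed completion request should be retried.
///
/// With `BASE_RETRY_DELAY` (5s) and `MAX_RETRY_ATTEMPTS` (4), `ExponentialBackoff`
/// waits roughly 5s, 10s, 20s, and 40s between successive attempts, while `Fixed`
/// waits the same delay before every attempt.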
82#[derive(Debug, Clone)]
83enum RetryStrategy {
84 ExponentialBackoff {
85 initial_delay: Duration,
86 max_attempts: u8,
87 },
88 Fixed {
89 delay: Duration,
90 max_attempts: u8,
91 },
92}
93
94#[derive(Debug, Clone, PartialEq, Eq)]
95pub enum Message {
96 User(UserMessage),
97 Agent(AgentMessage),
98 Resume,
99}
100
101impl Message {
102 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
103 match self {
104 Message::Agent(agent_message) => Some(agent_message),
105 _ => None,
106 }
107 }
108
109 pub fn to_markdown(&self) -> String {
110 match self {
111 Message::User(message) => message.to_markdown(),
112 Message::Agent(message) => message.to_markdown(),
113 Message::Resume => "[resumed after tool use limit was reached]".into(),
114 }
115 }
116}
117
118#[derive(Debug, Clone, PartialEq, Eq)]
119pub struct UserMessage {
120 pub id: UserMessageId,
121 pub content: Vec<UserMessageContent>,
122}
123
124#[derive(Debug, Clone, PartialEq, Eq)]
125pub enum UserMessageContent {
126 Text(String),
127 Mention { uri: MentionUri, content: String },
128 Image(LanguageModelImage),
129}
130
131impl UserMessage {
132 pub fn to_markdown(&self) -> String {
133 let mut markdown = String::from("## User\n\n");
134
135 for content in &self.content {
136 match content {
137 UserMessageContent::Text(text) => {
138 markdown.push_str(text);
139 markdown.push('\n');
140 }
141 UserMessageContent::Image(_) => {
142 markdown.push_str("<image />\n");
143 }
144 UserMessageContent::Mention { uri, content } => {
145 if !content.is_empty() {
146 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
147 } else {
148 let _ = write!(&mut markdown, "{}\n", uri.as_link());
149 }
150 }
151 }
152 }
153
154 markdown
155 }
156
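    /// Converts this message into a request message for the model.
    ///
    /// Mentions are emitted inline as links, and the content they reference is appended
    /// afterwards, grouped into files, directories, symbols, threads, fetched-URLs, and
    /// user-rules sections wrapped in a trailing `<context>` block.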
157 fn to_request(&self) -> LanguageModelRequestMessage {
158 let mut message = LanguageModelRequestMessage {
159 role: Role::User,
160 content: Vec::with_capacity(self.content.len()),
161 cache: false,
162 };
163
164 const OPEN_CONTEXT: &str = "<context>\n\
165 The following items were attached by the user. \
166 They are up-to-date and don't need to be re-read.\n\n";
167
168 const OPEN_FILES_TAG: &str = "<files>";
169 const OPEN_DIRECTORIES_TAG: &str = "<directories>";
170 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
171 const OPEN_THREADS_TAG: &str = "<threads>";
172 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
        const OPEN_RULES_TAG: &str =
            "<user_rules>\nThe user has specified the following rules that should be applied:\n";
175
176 let mut file_context = OPEN_FILES_TAG.to_string();
177 let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
178 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
179 let mut thread_context = OPEN_THREADS_TAG.to_string();
180 let mut fetch_context = OPEN_FETCH_TAG.to_string();
181 let mut rules_context = OPEN_RULES_TAG.to_string();
182
183 for chunk in &self.content {
184 let chunk = match chunk {
185 UserMessageContent::Text(text) => {
186 language_model::MessageContent::Text(text.clone())
187 }
188 UserMessageContent::Image(value) => {
189 language_model::MessageContent::Image(value.clone())
190 }
191 UserMessageContent::Mention { uri, content } => {
192 match uri {
193 MentionUri::File { abs_path } => {
194 write!(
                                &mut file_context,
196 "\n{}",
197 MarkdownCodeBlock {
198 tag: &codeblock_tag(abs_path, None),
199 text: &content.to_string(),
200 }
201 )
202 .ok();
203 }
204 MentionUri::Directory { .. } => {
205 write!(&mut directory_context, "\n{}\n", content).ok();
206 }
207 MentionUri::Symbol {
208 path, line_range, ..
209 }
210 | MentionUri::Selection {
211 path, line_range, ..
212 } => {
213 write!(
                                &mut symbol_context,
215 "\n{}",
216 MarkdownCodeBlock {
217 tag: &codeblock_tag(path, Some(line_range)),
218 text: content
219 }
220 )
221 .ok();
222 }
223 MentionUri::Thread { .. } => {
224 write!(&mut thread_context, "\n{}\n", content).ok();
225 }
226 MentionUri::TextThread { .. } => {
227 write!(&mut thread_context, "\n{}\n", content).ok();
228 }
229 MentionUri::Rule { .. } => {
230 write!(
231 &mut rules_context,
232 "\n{}",
233 MarkdownCodeBlock {
234 tag: "",
235 text: content
236 }
237 )
238 .ok();
239 }
240 MentionUri::Fetch { url } => {
241 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
242 }
243 }
244
245 language_model::MessageContent::Text(uri.as_link().to_string())
246 }
247 };
248
249 message.content.push(chunk);
250 }
251
252 let len_before_context = message.content.len();
253
254 if file_context.len() > OPEN_FILES_TAG.len() {
255 file_context.push_str("</files>\n");
256 message
257 .content
258 .push(language_model::MessageContent::Text(file_context));
259 }
260
261 if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
262 directory_context.push_str("</directories>\n");
263 message
264 .content
265 .push(language_model::MessageContent::Text(directory_context));
266 }
267
268 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
269 symbol_context.push_str("</symbols>\n");
270 message
271 .content
272 .push(language_model::MessageContent::Text(symbol_context));
273 }
274
275 if thread_context.len() > OPEN_THREADS_TAG.len() {
276 thread_context.push_str("</threads>\n");
277 message
278 .content
279 .push(language_model::MessageContent::Text(thread_context));
280 }
281
282 if fetch_context.len() > OPEN_FETCH_TAG.len() {
283 fetch_context.push_str("</fetched_urls>\n");
284 message
285 .content
286 .push(language_model::MessageContent::Text(fetch_context));
287 }
288
289 if rules_context.len() > OPEN_RULES_TAG.len() {
290 rules_context.push_str("</user_rules>\n");
291 message
292 .content
293 .push(language_model::MessageContent::Text(rules_context));
294 }
295
296 if message.content.len() > len_before_context {
297 message.content.insert(
298 len_before_context,
299 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
300 );
301 message
302 .content
303 .push(language_model::MessageContent::Text("</context>".into()));
304 }
305
306 message
307 }
308}
309
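/// Builds the info string for a Markdown code block that quotes a file,
/// e.g. `rs src/main.rs:5-10` for a Rust file with the zero-based line range 4..9.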
310fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
311 let mut result = String::new();
312
313 if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
314 let _ = write!(result, "{} ", extension);
315 }
316
317 let _ = write!(result, "{}", full_path.display());
318
319 if let Some(range) = line_range {
320 if range.start == range.end {
321 let _ = write!(result, ":{}", range.start + 1);
322 } else {
323 let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
324 }
325 }
326
327 result
328}
329
330impl AgentMessage {
331 pub fn to_markdown(&self) -> String {
332 let mut markdown = String::from("## Assistant\n\n");
333
334 for content in &self.content {
335 match content {
336 AgentMessageContent::Text(text) => {
337 markdown.push_str(text);
338 markdown.push('\n');
339 }
340 AgentMessageContent::Thinking { text, .. } => {
341 markdown.push_str("<think>");
342 markdown.push_str(text);
343 markdown.push_str("</think>\n");
344 }
345 AgentMessageContent::RedactedThinking(_) => {
346 markdown.push_str("<redacted_thinking />\n")
347 }
348 AgentMessageContent::Image(_) => {
349 markdown.push_str("<image />\n");
350 }
351 AgentMessageContent::ToolUse(tool_use) => {
352 markdown.push_str(&format!(
353 "**Tool Use**: {} (ID: {})\n",
354 tool_use.name, tool_use.id
355 ));
356 markdown.push_str(&format!(
357 "{}\n",
358 MarkdownCodeBlock {
359 tag: "json",
360 text: &format!("{:#}", tool_use.input)
361 }
362 ));
363 }
364 }
365 }
366
367 for tool_result in self.tool_results.values() {
368 markdown.push_str(&format!(
369 "**Tool Result**: {} (ID: {})\n\n",
370 tool_result.tool_name, tool_result.tool_use_id
371 ));
372 if tool_result.is_error {
373 markdown.push_str("**ERROR:**\n");
374 }
375
376 match &tool_result.content {
377 LanguageModelToolResultContent::Text(text) => {
378 writeln!(markdown, "{text}\n").ok();
379 }
380 LanguageModelToolResultContent::Image(_) => {
381 writeln!(markdown, "<image />\n").ok();
382 }
383 }
384
385 if let Some(output) = tool_result.output.as_ref() {
386 writeln!(
387 markdown,
388 "**Debug Output**:\n\n```json\n{}\n```\n",
389 serde_json::to_string_pretty(output).unwrap()
390 )
391 .unwrap();
392 }
393 }
394
395 markdown
396 }
397
398 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
399 let mut assistant_message = LanguageModelRequestMessage {
400 role: Role::Assistant,
401 content: Vec::with_capacity(self.content.len()),
402 cache: false,
403 };
404 for chunk in &self.content {
405 let chunk = match chunk {
406 AgentMessageContent::Text(text) => {
407 language_model::MessageContent::Text(text.clone())
408 }
409 AgentMessageContent::Thinking { text, signature } => {
410 language_model::MessageContent::Thinking {
411 text: text.clone(),
412 signature: signature.clone(),
413 }
414 }
415 AgentMessageContent::RedactedThinking(value) => {
416 language_model::MessageContent::RedactedThinking(value.clone())
417 }
418 AgentMessageContent::ToolUse(value) => {
419 language_model::MessageContent::ToolUse(value.clone())
420 }
421 AgentMessageContent::Image(value) => {
422 language_model::MessageContent::Image(value.clone())
423 }
424 };
425 assistant_message.content.push(chunk);
426 }
427
428 let mut user_message = LanguageModelRequestMessage {
429 role: Role::User,
430 content: Vec::new(),
431 cache: false,
432 };
433
434 for tool_result in self.tool_results.values() {
435 user_message
436 .content
437 .push(language_model::MessageContent::ToolResult(
438 tool_result.clone(),
439 ));
440 }
441
442 let mut messages = Vec::new();
443 if !assistant_message.content.is_empty() {
444 messages.push(assistant_message);
445 }
446 if !user_message.content.is_empty() {
447 messages.push(user_message);
448 }
449 messages
450 }
451}
452
453#[derive(Default, Debug, Clone, PartialEq, Eq)]
454pub struct AgentMessage {
455 pub content: Vec<AgentMessageContent>,
456 pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
457}
458
459#[derive(Debug, Clone, PartialEq, Eq)]
460pub enum AgentMessageContent {
461 Text(String),
462 Thinking {
463 text: String,
464 signature: Option<String>,
465 },
466 RedactedThinking(String),
467 Image(LanguageModelImage),
468 ToolUse(LanguageModelToolUse),
469}
470
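/// Events streamed back to the caller over the channel returned by [`Thread::send`]
/// and [`Thread::resume`] while a turn is running.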
471#[derive(Debug)]
472pub enum AgentResponseEvent {
473 Text(String),
474 Thinking(String),
475 ToolCall(acp::ToolCall),
476 ToolCallUpdate(acp_thread::ToolCallUpdate),
477 ToolCallAuthorization(ToolCallAuthorization),
478 Retry(acp_thread::RetryStatus),
479 Stop(acp::StopReason),
480}
481
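/// A request for the user's permission to run a tool, surfaced as an event on the
/// turn's stream. The receiver answers by sending the chosen [`acp::PermissionOptionId`]
/// back through `response`.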
482#[derive(Debug)]
483pub struct ToolCallAuthorization {
484 pub tool_call: acp::ToolCallUpdate,
485 pub options: Vec<acp::PermissionOption>,
486 pub response: oneshot::Sender<acp::PermissionOptionId>,
487}
488
489pub struct Thread {
490 id: ThreadId,
491 prompt_id: PromptId,
492 messages: Vec<Message>,
493 completion_mode: CompletionMode,
494 /// Holds the task that handles agent interaction until the end of the turn.
495 /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
497 running_turn: Option<RunningTurn>,
498 pending_message: Option<AgentMessage>,
499 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
500 tool_use_limit_reached: bool,
501 context_server_registry: Entity<ContextServerRegistry>,
502 profile_id: AgentProfileId,
503 project_context: Entity<ProjectContext>,
504 templates: Arc<Templates>,
505 model: Option<Arc<dyn LanguageModel>>,
506 project: Entity<Project>,
507 action_log: Entity<ActionLog>,
508}
509
510impl Thread {
511 pub fn new(
512 project: Entity<Project>,
513 project_context: Entity<ProjectContext>,
514 context_server_registry: Entity<ContextServerRegistry>,
515 action_log: Entity<ActionLog>,
516 templates: Arc<Templates>,
517 model: Option<Arc<dyn LanguageModel>>,
518 cx: &mut Context<Self>,
519 ) -> Self {
520 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
521 Self {
522 id: ThreadId::new(),
523 prompt_id: PromptId::new(),
524 messages: Vec::new(),
525 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
526 running_turn: None,
527 pending_message: None,
528 tools: BTreeMap::default(),
529 tool_use_limit_reached: false,
530 context_server_registry,
531 profile_id,
532 project_context,
533 templates,
534 model,
535 project,
536 action_log,
537 }
538 }
539
540 pub fn project(&self) -> &Entity<Project> {
541 &self.project
542 }
543
544 pub fn project_context(&self) -> &Entity<ProjectContext> {
545 &self.project_context
546 }
547
548 pub fn action_log(&self) -> &Entity<ActionLog> {
549 &self.action_log
550 }
551
552 pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
553 self.model.as_ref()
554 }
555
556 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
557 self.model = Some(model);
558 }
559
560 pub fn completion_mode(&self) -> CompletionMode {
561 self.completion_mode
562 }
563
564 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
565 self.completion_mode = mode;
566 }
567
568 #[cfg(any(test, feature = "test-support"))]
569 pub fn last_message(&self) -> Option<Message> {
570 if let Some(message) = self.pending_message.clone() {
571 Some(Message::Agent(message))
572 } else {
573 self.messages.last().cloned()
574 }
575 }
576
577 pub fn add_tool(&mut self, tool: impl AgentTool) {
578 self.tools.insert(tool.name(), tool.erase());
579 }
580
581 pub fn remove_tool(&mut self, name: &str) -> bool {
582 self.tools.remove(name).is_some()
583 }
584
585 pub fn profile(&self) -> &AgentProfileId {
586 &self.profile_id
587 }
588
589 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
590 self.profile_id = profile_id;
591 }
592
593 pub fn cancel(&mut self) {
594 if let Some(running_turn) = self.running_turn.take() {
595 running_turn.cancel();
596 }
597 self.flush_pending_message();
598 }
599
600 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
601 self.cancel();
602 let Some(position) = self.messages.iter().position(
603 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
604 ) else {
605 return Err(anyhow!("Message not found"));
606 };
607 self.messages.truncate(position);
608 Ok(())
609 }
610
611 pub fn resume(
612 &mut self,
613 cx: &mut Context<Self>,
614 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
615 anyhow::ensure!(self.model.is_some(), "Model not set");
616 anyhow::ensure!(
617 self.tool_use_limit_reached,
618 "can only resume after tool use limit is reached"
619 );
620
621 self.messages.push(Message::Resume);
622 cx.notify();
623
624 log::info!("Total messages in thread: {}", self.messages.len());
625 self.run_turn(cx)
626 }
627
    /// Sending a message results in the model streaming a response, which may include tool calls.
    /// After calling tools, the model stops and waits for any outstanding tool calls to be completed and their results sent.
    /// The returned channel reports every stop that occurs before the turn errors or ends.
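    ///
    /// A minimal sketch of driving a turn (it assumes a `user_message_id` is in scope
    /// and elides the `Entity::update` plumbing needed to call `&mut self` methods):
    ///
    /// ```ignore
    /// let mut events = thread.send(user_message_id, ["Summarize the project"], cx)?;
    /// while let Some(event) = events.next().await {
    ///     match event? {
    ///         AgentResponseEvent::Text(chunk) => print!("{chunk}"),
    ///         AgentResponseEvent::Stop(_) => break,
    ///         _ => {}
    ///     }
    /// }
    /// ```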
631 pub fn send<T>(
632 &mut self,
633 id: UserMessageId,
634 content: impl IntoIterator<Item = T>,
635 cx: &mut Context<Self>,
636 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>
637 where
638 T: Into<UserMessageContent>,
639 {
640 let model = self.model().context("No language model configured")?;
641
642 log::info!("Thread::send called with model: {:?}", model.name());
643 self.advance_prompt_id();
644
645 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
646 log::debug!("Thread::send content: {:?}", content);
647
648 self.messages
649 .push(Message::User(UserMessage { id, content }));
650 cx.notify();
651
652 log::info!("Total messages in thread: {}", self.messages.len());
653 self.run_turn(cx)
654 }
655
656 fn run_turn(
657 &mut self,
658 cx: &mut Context<Self>,
659 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
660 self.cancel();
661
662 let model = self
663 .model()
664 .cloned()
665 .context("No language model configured")?;
666 let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
667 let event_stream = AgentResponseEventStream(events_tx);
668 let message_ix = self.messages.len().saturating_sub(1);
669 self.tool_use_limit_reached = false;
670 self.running_turn = Some(RunningTurn {
671 event_stream: event_stream.clone(),
672 _task: cx.spawn(async move |this, cx| {
673 log::info!("Starting agent turn execution");
674 let turn_result: Result<()> = async {
675 let mut completion_intent = CompletionIntent::UserPrompt;
676 loop {
677 log::debug!(
678 "Building completion request with intent: {:?}",
679 completion_intent
680 );
681 let request = this.update(cx, |this, cx| {
682 this.build_completion_request(completion_intent, cx)
683 })??;
684
685 log::info!("Calling model.stream_completion");
686
687 let mut tool_use_limit_reached = false;
688 let mut tool_uses = Self::stream_completion_with_retries(
689 this.clone(),
690 model.clone(),
691 request,
692 message_ix,
693 &event_stream,
694 &mut tool_use_limit_reached,
695 cx,
696 )
697 .await?;
698
                        let no_tool_uses = tool_uses.is_empty();
700 while let Some(tool_result) = tool_uses.next().await {
701 log::info!("Tool finished {:?}", tool_result);
702
703 event_stream.update_tool_call_fields(
704 &tool_result.tool_use_id,
705 acp::ToolCallUpdateFields {
706 status: Some(if tool_result.is_error {
707 acp::ToolCallStatus::Failed
708 } else {
709 acp::ToolCallStatus::Completed
710 }),
711 raw_output: tool_result.output.clone(),
712 ..Default::default()
713 },
714 );
715 this.update(cx, |this, _cx| {
716 this.pending_message()
717 .tool_results
718 .insert(tool_result.tool_use_id.clone(), tool_result);
719 })
720 .ok();
721 }
722
723 if tool_use_limit_reached {
724 log::info!("Tool use limit reached, completing turn");
725 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
726 return Err(language_model::ToolUseLimitReachedError.into());
                        } else if no_tool_uses {
728 log::info!("No tool uses found, completing turn");
729 return Ok(());
730 } else {
731 this.update(cx, |this, _| this.flush_pending_message())?;
732 completion_intent = CompletionIntent::ToolResults;
733 }
734 }
735 }
736 .await;
737
738 if let Err(error) = turn_result {
739 log::error!("Turn execution failed: {:?}", error);
740 event_stream.send_error(error);
741 } else {
742 log::info!("Turn execution completed successfully");
743 }
744
745 this.update(cx, |this, _| {
746 this.flush_pending_message();
747 this.running_turn.take();
748 })
749 .ok();
750 }),
751 });
752 Ok(events_rx)
753 }
754
755 async fn stream_completion_with_retries(
756 this: WeakEntity<Self>,
757 model: Arc<dyn LanguageModel>,
758 request: LanguageModelRequest,
759 message_ix: usize,
760 event_stream: &AgentResponseEventStream,
761 tool_use_limit_reached: &mut bool,
762 cx: &mut AsyncApp,
763 ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
764 log::debug!("Stream completion started successfully");
765
766 let mut attempt = None;
767 'retry: loop {
768 let mut events = model.stream_completion(request.clone(), cx).await?;
769 let mut tool_uses = FuturesUnordered::new();
770 while let Some(event) = events.next().await {
771 match event {
772 Ok(LanguageModelCompletionEvent::StatusUpdate(
773 CompletionRequestStatus::ToolUseLimitReached,
774 )) => {
775 *tool_use_limit_reached = true;
776 }
777 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
778 event_stream.send_stop(reason);
779 if reason == StopReason::Refusal {
780 this.update(cx, |this, _cx| {
781 this.flush_pending_message();
782 this.messages.truncate(message_ix);
783 })?;
784 return Ok(tool_uses);
785 }
786 }
787 Ok(event) => {
788 log::trace!("Received completion event: {:?}", event);
789 this.update(cx, |this, cx| {
790 tool_uses.extend(this.handle_streamed_completion_event(
791 event,
792 event_stream,
793 cx,
794 ));
795 })
796 .ok();
797 }
798 Err(error) => {
799 let completion_mode =
800 this.read_with(cx, |thread, _cx| thread.completion_mode())?;
801 if completion_mode == CompletionMode::Normal {
802 return Err(error.into());
803 }
804
805 let Some(strategy) = Self::retry_strategy_for(&error) else {
806 return Err(error.into());
807 };
808
809 let max_attempts = match &strategy {
810 RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
811 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
812 };
813
814 let attempt = attempt.get_or_insert(0u8);
815
816 *attempt += 1;
817
818 let attempt = *attempt;
819 if attempt > max_attempts {
820 return Err(error.into());
821 }
822
823 let delay = match &strategy {
824 RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
825 let delay_secs =
826 initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
827 Duration::from_secs(delay_secs)
828 }
829 RetryStrategy::Fixed { delay, .. } => *delay,
830 };
831 log::debug!("Retry attempt {attempt} with delay {delay:?}");
832
833 event_stream.send_retry(acp_thread::RetryStatus {
834 last_error: error.to_string().into(),
835 attempt: attempt as usize,
836 max_attempts: max_attempts as usize,
837 started_at: Instant::now(),
838 duration: delay,
839 });
840
841 cx.background_executor().timer(delay).await;
842 continue 'retry;
843 }
844 }
845 }
846 return Ok(tool_uses);
847 }
848 }
849
850 pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
851 log::debug!("Building system message");
852 let prompt = SystemPromptTemplate {
853 project: self.project_context.read(cx),
854 available_tools: self.tools.keys().cloned().collect(),
855 }
856 .render(&self.templates)
857 .context("failed to build system prompt")
858 .expect("Invalid template");
859 log::debug!("System message built");
860 LanguageModelRequestMessage {
861 role: Role::System,
862 content: vec![prompt.into()],
863 cache: true,
864 }
865 }
866
867 /// A helper method that's called on every streamed completion event.
868 /// Returns an optional tool result task, which the main agentic loop in
869 /// send will send back to the model when it resolves.
870 fn handle_streamed_completion_event(
871 &mut self,
872 event: LanguageModelCompletionEvent,
873 event_stream: &AgentResponseEventStream,
874 cx: &mut Context<Self>,
875 ) -> Option<Task<LanguageModelToolResult>> {
876 log::trace!("Handling streamed completion event: {:?}", event);
877 use LanguageModelCompletionEvent::*;
878
879 match event {
880 StartMessage { .. } => {
881 self.flush_pending_message();
882 self.pending_message = Some(AgentMessage::default());
883 }
884 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
885 Thinking { text, signature } => {
886 self.handle_thinking_event(text, signature, event_stream, cx)
887 }
888 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
889 ToolUse(tool_use) => {
890 return self.handle_tool_use_event(tool_use, event_stream, cx);
891 }
892 ToolUseJsonParseError {
893 id,
894 tool_name,
895 raw_input,
896 json_parse_error,
897 } => {
898 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
899 id,
900 tool_name,
901 raw_input,
902 json_parse_error,
903 )));
904 }
905 UsageUpdate(_) | StatusUpdate(_) => {}
906 Stop(_) => unreachable!(),
907 }
908
909 None
910 }
911
912 fn handle_text_event(
913 &mut self,
914 new_text: String,
915 event_stream: &AgentResponseEventStream,
916 cx: &mut Context<Self>,
917 ) {
918 event_stream.send_text(&new_text);
919
920 let last_message = self.pending_message();
921 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
922 text.push_str(&new_text);
923 } else {
924 last_message
925 .content
926 .push(AgentMessageContent::Text(new_text));
927 }
928
929 cx.notify();
930 }
931
932 fn handle_thinking_event(
933 &mut self,
934 new_text: String,
935 new_signature: Option<String>,
936 event_stream: &AgentResponseEventStream,
937 cx: &mut Context<Self>,
938 ) {
939 event_stream.send_thinking(&new_text);
940
941 let last_message = self.pending_message();
942 if let Some(AgentMessageContent::Thinking { text, signature }) =
943 last_message.content.last_mut()
944 {
945 text.push_str(&new_text);
946 *signature = new_signature.or(signature.take());
947 } else {
948 last_message.content.push(AgentMessageContent::Thinking {
949 text: new_text,
950 signature: new_signature,
951 });
952 }
953
954 cx.notify();
955 }
956
957 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
958 let last_message = self.pending_message();
959 last_message
960 .content
961 .push(AgentMessageContent::RedactedThinking(data));
962 cx.notify();
963 }
964
965 fn handle_tool_use_event(
966 &mut self,
967 tool_use: LanguageModelToolUse,
968 event_stream: &AgentResponseEventStream,
969 cx: &mut Context<Self>,
970 ) -> Option<Task<LanguageModelToolResult>> {
971 cx.notify();
972
973 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
974 let mut title = SharedString::from(&tool_use.name);
975 let mut kind = acp::ToolKind::Other;
976 if let Some(tool) = tool.as_ref() {
977 title = tool.initial_title(tool_use.input.clone());
978 kind = tool.kind();
979 }
980
981 // Ensure the last message ends in the current tool use
982 let last_message = self.pending_message();
983 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
984 if let AgentMessageContent::ToolUse(last_tool_use) = content {
985 if last_tool_use.id == tool_use.id {
986 *last_tool_use = tool_use.clone();
987 false
988 } else {
989 true
990 }
991 } else {
992 true
993 }
994 });
995
996 if push_new_tool_use {
997 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
998 last_message
999 .content
1000 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1001 } else {
1002 event_stream.update_tool_call_fields(
1003 &tool_use.id,
1004 acp::ToolCallUpdateFields {
1005 title: Some(title.into()),
1006 kind: Some(kind),
1007 raw_input: Some(tool_use.input.clone()),
1008 ..Default::default()
1009 },
1010 );
1011 }
1012
1013 if !tool_use.is_input_complete {
1014 return None;
1015 }
1016
1017 let Some(tool) = tool else {
1018 let content = format!("No tool named {} exists", tool_use.name);
1019 return Some(Task::ready(LanguageModelToolResult {
1020 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1021 tool_use_id: tool_use.id,
1022 tool_name: tool_use.name,
1023 is_error: true,
1024 output: None,
1025 }));
1026 };
1027
1028 let fs = self.project.read(cx).fs().clone();
1029 let tool_event_stream =
1030 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1031 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1032 status: Some(acp::ToolCallStatus::InProgress),
1033 ..Default::default()
1034 });
1035 let supports_images = self.model().map_or(false, |model| model.supports_images());
1036 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1037 log::info!("Running tool {}", tool_use.name);
1038 Some(cx.foreground_executor().spawn(async move {
1039 let tool_result = tool_result.await.and_then(|output| {
1040 if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1041 && !supports_images {
1042 return Err(anyhow!(
1043 "Attempted to read an image, but this model doesn't support it.",
1044 ));
1045 }
1046 Ok(output)
1047 });
1048
1049 match tool_result {
1050 Ok(output) => LanguageModelToolResult {
1051 tool_use_id: tool_use.id,
1052 tool_name: tool_use.name,
1053 is_error: false,
1054 content: output.llm_output,
1055 output: Some(output.raw_output),
1056 },
1057 Err(error) => LanguageModelToolResult {
1058 tool_use_id: tool_use.id,
1059 tool_name: tool_use.name,
1060 is_error: true,
1061 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1062 output: None,
1063 },
1064 }
1065 }))
1066 }
1067
1068 fn handle_tool_use_json_parse_error_event(
1069 &mut self,
1070 tool_use_id: LanguageModelToolUseId,
1071 tool_name: Arc<str>,
1072 raw_input: Arc<str>,
1073 json_parse_error: String,
1074 ) -> LanguageModelToolResult {
1075 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1076 LanguageModelToolResult {
1077 tool_use_id,
1078 tool_name,
1079 is_error: true,
1080 content: LanguageModelToolResultContent::Text(tool_output.into()),
1081 output: Some(serde_json::Value::String(raw_input.to_string())),
1082 }
1083 }
1084
1085 fn pending_message(&mut self) -> &mut AgentMessage {
1086 self.pending_message.get_or_insert_default()
1087 }
1088
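    /// Moves the in-progress agent message into `messages`, first recording an error
    /// result ("Tool canceled by user") for any tool use that never received a result.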
1089 fn flush_pending_message(&mut self) {
1090 let Some(mut message) = self.pending_message.take() else {
1091 return;
1092 };
1093
1094 for content in &message.content {
1095 let AgentMessageContent::ToolUse(tool_use) = content else {
1096 continue;
1097 };
1098
1099 if !message.tool_results.contains_key(&tool_use.id) {
1100 message.tool_results.insert(
1101 tool_use.id.clone(),
1102 LanguageModelToolResult {
1103 tool_use_id: tool_use.id.clone(),
1104 tool_name: tool_use.name.clone(),
1105 is_error: true,
1106 content: LanguageModelToolResultContent::Text(
1107 "Tool canceled by user".into(),
1108 ),
1109 output: None,
1110 },
1111 );
1112 }
1113 }
1114
1115 self.messages.push(Message::Agent(message));
1116 }
1117
1118 pub(crate) fn build_completion_request(
1119 &self,
1120 completion_intent: CompletionIntent,
1121 cx: &mut App,
1122 ) -> Result<LanguageModelRequest> {
1123 let model = self.model().context("No language model configured")?;
1124
1125 log::debug!("Building completion request");
1126 log::debug!("Completion intent: {:?}", completion_intent);
1127 log::debug!("Completion mode: {:?}", self.completion_mode);
1128
1129 let messages = self.build_request_messages(cx);
1130 log::info!("Request will include {} messages", messages.len());
1131
1132 let tools = if let Some(tools) = self.tools(cx).log_err() {
1133 tools
1134 .filter_map(|tool| {
1135 let tool_name = tool.name().to_string();
1136 log::trace!("Including tool: {}", tool_name);
1137 Some(LanguageModelRequestTool {
1138 name: tool_name,
1139 description: tool.description().to_string(),
1140 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1141 })
1142 })
1143 .collect()
1144 } else {
1145 Vec::new()
1146 };
1147
1148 log::info!("Request includes {} tools", tools.len());
1149
1150 let request = LanguageModelRequest {
1151 thread_id: Some(self.id.to_string()),
1152 prompt_id: Some(self.prompt_id.to_string()),
1153 intent: Some(completion_intent),
1154 mode: Some(self.completion_mode.into()),
1155 messages,
1156 tools,
1157 tool_choice: None,
1158 stop: Vec::new(),
1159 temperature: AgentSettings::temperature_for_model(model, cx),
1160 thinking_allowed: true,
1161 };
1162
1163 log::debug!("Completion request built successfully");
1164 Ok(request)
1165 }
1166
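    /// Returns the tools enabled in the current profile that are supported by the model's
    /// provider, followed by any enabled tools from registered context servers.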
1167 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1168 let model = self.model().context("No language model configured")?;
1169
1170 let profile = AgentSettings::get_global(cx)
1171 .profiles
1172 .get(&self.profile_id)
1173 .context("profile not found")?;
1174 let provider_id = model.provider_id();
1175
1176 Ok(self
1177 .tools
1178 .iter()
1179 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1180 .filter_map(|(tool_name, tool)| {
1181 if profile.is_tool_enabled(tool_name) {
1182 Some(tool)
1183 } else {
1184 None
1185 }
1186 })
1187 .chain(self.context_server_registry.read(cx).servers().flat_map(
1188 |(server_id, tools)| {
1189 tools.iter().filter_map(|(tool_name, tool)| {
1190 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1191 Some(tool)
1192 } else {
1193 None
1194 }
1195 })
1196 },
1197 )))
1198 }
1199
1200 fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1201 log::trace!(
1202 "Building request messages from {} thread messages",
1203 self.messages.len()
1204 );
1205 let mut messages = vec![self.build_system_message(cx)];
1206 for message in &self.messages {
1207 match message {
1208 Message::User(message) => messages.push(message.to_request()),
1209 Message::Agent(message) => messages.extend(message.to_request()),
1210 Message::Resume => messages.push(LanguageModelRequestMessage {
1211 role: Role::User,
1212 content: vec!["Continue where you left off".into()],
1213 cache: false,
1214 }),
1215 }
1216 }
1217
1218 if let Some(message) = self.pending_message.as_ref() {
1219 messages.extend(message.to_request());
1220 }
1221
1222 if let Some(last_user_message) = messages
1223 .iter_mut()
1224 .rev()
1225 .find(|message| message.role == Role::User)
1226 {
1227 last_user_message.cache = true;
1228 }
1229
1230 messages
1231 }
1232
1233 pub fn to_markdown(&self) -> String {
1234 let mut markdown = String::new();
1235 for (ix, message) in self.messages.iter().enumerate() {
1236 if ix > 0 {
1237 markdown.push('\n');
1238 }
1239 markdown.push_str(&message.to_markdown());
1240 }
1241
1242 if let Some(message) = self.pending_message.as_ref() {
1243 markdown.push('\n');
1244 markdown.push_str(&message.to_markdown());
1245 }
1246
1247 markdown
1248 }
1249
1250 fn advance_prompt_id(&mut self) {
1251 self.prompt_id = PromptId::new();
1252 }
1253
1254 fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
1255 use LanguageModelCompletionError::*;
1256 use http_client::StatusCode;
1257
        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times,
        //   backing off exponentially or honoring the server's `retry_after` hint.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
1262 match error {
1263 HttpResponseError {
1264 status_code: StatusCode::TOO_MANY_REQUESTS,
1265 ..
1266 } => Some(RetryStrategy::ExponentialBackoff {
1267 initial_delay: BASE_RETRY_DELAY,
1268 max_attempts: MAX_RETRY_ATTEMPTS,
1269 }),
1270 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
1271 Some(RetryStrategy::Fixed {
1272 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1273 max_attempts: MAX_RETRY_ATTEMPTS,
1274 })
1275 }
1276 UpstreamProviderError {
1277 status,
1278 retry_after,
1279 ..
1280 } => match *status {
1281 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
1282 Some(RetryStrategy::Fixed {
1283 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1284 max_attempts: MAX_RETRY_ATTEMPTS,
1285 })
1286 }
1287 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
1288 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1289 // Internal Server Error could be anything, retry up to 3 times.
1290 max_attempts: 3,
1291 }),
1292 status => {
1293 // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
1294 // but we frequently get them in practice. See https://http.dev/529
1295 if status.as_u16() == 529 {
1296 Some(RetryStrategy::Fixed {
1297 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1298 max_attempts: MAX_RETRY_ATTEMPTS,
1299 })
1300 } else {
1301 Some(RetryStrategy::Fixed {
1302 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1303 max_attempts: 2,
1304 })
1305 }
1306 }
1307 },
1308 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
1309 delay: BASE_RETRY_DELAY,
1310 max_attempts: 3,
1311 }),
1312 ApiReadResponseError { .. }
1313 | HttpSend { .. }
1314 | DeserializeResponse { .. }
1315 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
1316 delay: BASE_RETRY_DELAY,
1317 max_attempts: 3,
1318 }),
1319 // Retrying these errors definitely shouldn't help.
1320 HttpResponseError {
1321 status_code:
1322 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
1323 ..
1324 }
1325 | AuthenticationError { .. }
1326 | PermissionError { .. }
1327 | NoApiKey { .. }
1328 | ApiEndpointNotFound { .. }
1329 | PromptTooLarge { .. } => None,
1330 // These errors might be transient, so retry them
1331 SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
1332 delay: BASE_RETRY_DELAY,
1333 max_attempts: 1,
1334 }),
            // Retry all other 4xx and 5xx errors up to 3 times.
1336 HttpResponseError { status_code, .. }
1337 if status_code.is_client_error() || status_code.is_server_error() =>
1338 {
1339 Some(RetryStrategy::Fixed {
1340 delay: BASE_RETRY_DELAY,
1341 max_attempts: 3,
1342 })
1343 }
1344 Other(err)
1345 if err.is::<language_model::PaymentRequiredError>()
1346 || err.is::<language_model::ModelRequestLimitReachedError>() =>
1347 {
1348 // Retrying won't help for Payment Required or Model Request Limit errors (where
1349 // the user must upgrade to usage-based billing to get more requests, or else wait
1350 // for a significant amount of time for the request limit to reset).
1351 None
1352 }
            // For any other errors, retry conservatively (at most twice).
1354 HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
1355 delay: BASE_RETRY_DELAY,
1356 max_attempts: 2,
1357 }),
1358 }
1359 }
1360}
1361
1362struct RunningTurn {
1363 /// Holds the task that handles agent interaction until the end of the turn.
1364 /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
1366 _task: Task<()>,
1367 /// The current event stream for the running turn. Used to report a final
1368 /// cancellation event if we cancel the turn.
1369 event_stream: AgentResponseEventStream,
1370}
1371
1372impl RunningTurn {
1373 fn cancel(self) {
1374 log::debug!("Cancelling in progress turn");
1375 self.event_stream.send_canceled();
1376 }
1377}
1378
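/// A strongly-typed tool that the agent can call during a turn.
///
/// A minimal sketch of an implementation (the `EchoTool` type and its input are
/// hypothetical, and using `Output = String` assumes
/// `String: Into<LanguageModelToolResultContent>`):
///
/// ```ignore
/// #[derive(Debug, Serialize, Deserialize, JsonSchema)]
/// struct EchoToolInput {
///     /// The text to echo back.
///     text: String,
/// }
///
/// struct EchoTool;
///
/// impl AgentTool for EchoTool {
///     type Input = EchoToolInput;
///     type Output = String;
///
///     fn name(&self) -> SharedString {
///         "echo".into()
///     }
///
///     fn kind(&self) -> acp::ToolKind {
///         acp::ToolKind::Other
///     }
///
///     fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
///         "Echo".into()
///     }
///
///     fn run(
///         self: Arc<Self>,
///         input: Self::Input,
///         _event_stream: ToolCallEventStream,
///         _cx: &mut App,
///     ) -> Task<Result<Self::Output>> {
///         // Resolve immediately with the input text as the tool's output.
///         Task::ready(Ok(input.text))
///     }
/// }
/// ```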
1379pub trait AgentTool
1380where
1381 Self: 'static + Sized,
1382{
1383 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1384 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1385
1386 fn name(&self) -> SharedString;
1387
1388 fn description(&self) -> SharedString {
1389 let schema = schemars::schema_for!(Self::Input);
1390 SharedString::new(
1391 schema
1392 .get("description")
1393 .and_then(|description| description.as_str())
1394 .unwrap_or_default(),
1395 )
1396 }
1397
1398 fn kind(&self) -> acp::ToolKind;
1399
1400 /// The initial tool title to display. Can be updated during the tool run.
1401 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1402
1403 /// Returns the JSON schema that describes the tool's input.
1404 fn input_schema(&self) -> Schema {
1405 schemars::schema_for!(Self::Input)
1406 }
1407
    /// Some tools only work with certain providers (e.g. for billing reasons).
    /// This lets a tool declare whether it is compatible with the given provider so incompatible tools can be filtered out.
1410 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1411 true
1412 }
1413
1414 /// Runs the tool with the provided input.
1415 fn run(
1416 self: Arc<Self>,
1417 input: Self::Input,
1418 event_stream: ToolCallEventStream,
1419 cx: &mut App,
1420 ) -> Task<Result<Self::Output>>;
1421
1422 fn erase(self) -> Arc<dyn AnyAgentTool> {
1423 Arc::new(Erased(Arc::new(self)))
1424 }
1425}
1426
1427pub struct Erased<T>(T);
1428
1429pub struct AgentToolOutput {
1430 pub llm_output: LanguageModelToolResultContent,
1431 pub raw_output: serde_json::Value,
1432}
1433
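/// An object-safe, type-erased counterpart of [`AgentTool`], produced by
/// [`AgentTool::erase`].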
1434pub trait AnyAgentTool {
1435 fn name(&self) -> SharedString;
1436 fn description(&self) -> SharedString;
1437 fn kind(&self) -> acp::ToolKind;
1438 fn initial_title(&self, input: serde_json::Value) -> SharedString;
1439 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
1440 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1441 true
1442 }
1443 fn run(
1444 self: Arc<Self>,
1445 input: serde_json::Value,
1446 event_stream: ToolCallEventStream,
1447 cx: &mut App,
1448 ) -> Task<Result<AgentToolOutput>>;
1449}
1450
1451impl<T> AnyAgentTool for Erased<Arc<T>>
1452where
1453 T: AgentTool,
1454{
1455 fn name(&self) -> SharedString {
1456 self.0.name()
1457 }
1458
1459 fn description(&self) -> SharedString {
1460 self.0.description()
1461 }
1462
1463 fn kind(&self) -> agent_client_protocol::ToolKind {
1464 self.0.kind()
1465 }
1466
1467 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1468 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1469 self.0.initial_title(parsed_input)
1470 }
1471
1472 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1473 let mut json = serde_json::to_value(self.0.input_schema())?;
1474 adapt_schema_to_format(&mut json, format)?;
1475 Ok(json)
1476 }
1477
1478 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1479 self.0.supported_provider(provider)
1480 }
1481
1482 fn run(
1483 self: Arc<Self>,
1484 input: serde_json::Value,
1485 event_stream: ToolCallEventStream,
1486 cx: &mut App,
1487 ) -> Task<Result<AgentToolOutput>> {
1488 cx.spawn(async move |cx| {
1489 let input = serde_json::from_value(input)?;
1490 let output = cx
1491 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1492 .await?;
1493 let raw_output = serde_json::to_value(&output)?;
1494 Ok(AgentToolOutput {
1495 llm_output: output.into(),
1496 raw_output,
1497 })
1498 })
1499 }
1500}
1501
1502#[derive(Clone)]
1503struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
1504
1505impl AgentResponseEventStream {
1506 fn send_text(&self, text: &str) {
1507 self.0
1508 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
1509 .ok();
1510 }
1511
1512 fn send_thinking(&self, text: &str) {
1513 self.0
1514 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
1515 .ok();
1516 }
1517
1518 fn send_tool_call(
1519 &self,
1520 id: &LanguageModelToolUseId,
1521 title: SharedString,
1522 kind: acp::ToolKind,
1523 input: serde_json::Value,
1524 ) {
1525 self.0
1526 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
1527 id,
1528 title.to_string(),
1529 kind,
1530 input,
1531 ))))
1532 .ok();
1533 }
1534
1535 fn initial_tool_call(
1536 id: &LanguageModelToolUseId,
1537 title: String,
1538 kind: acp::ToolKind,
1539 input: serde_json::Value,
1540 ) -> acp::ToolCall {
1541 acp::ToolCall {
1542 id: acp::ToolCallId(id.to_string().into()),
1543 title,
1544 kind,
1545 status: acp::ToolCallStatus::Pending,
1546 content: vec![],
1547 locations: vec![],
1548 raw_input: Some(input),
1549 raw_output: None,
1550 }
1551 }
1552
1553 fn update_tool_call_fields(
1554 &self,
1555 tool_use_id: &LanguageModelToolUseId,
1556 fields: acp::ToolCallUpdateFields,
1557 ) {
1558 self.0
1559 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1560 acp::ToolCallUpdate {
1561 id: acp::ToolCallId(tool_use_id.to_string().into()),
1562 fields,
1563 }
1564 .into(),
1565 )))
1566 .ok();
1567 }
1568
1569 fn send_retry(&self, status: acp_thread::RetryStatus) {
1570 self.0
1571 .unbounded_send(Ok(AgentResponseEvent::Retry(status)))
1572 .ok();
1573 }
1574
1575 fn send_stop(&self, reason: StopReason) {
1576 match reason {
1577 StopReason::EndTurn => {
1578 self.0
1579 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
1580 .ok();
1581 }
1582 StopReason::MaxTokens => {
1583 self.0
1584 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
1585 .ok();
1586 }
1587 StopReason::Refusal => {
1588 self.0
1589 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
1590 .ok();
1591 }
1592 StopReason::ToolUse => {}
1593 }
1594 }
1595
1596 fn send_canceled(&self) {
1597 self.0
1598 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
1599 .ok();
1600 }
1601
1602 fn send_error(&self, error: impl Into<anyhow::Error>) {
1603 self.0.unbounded_send(Err(error.into())).ok();
1604 }
1605}
1606
1607#[derive(Clone)]
1608pub struct ToolCallEventStream {
1609 tool_use_id: LanguageModelToolUseId,
1610 stream: AgentResponseEventStream,
1611 fs: Option<Arc<dyn Fs>>,
1612}
1613
1614impl ToolCallEventStream {
1615 #[cfg(test)]
1616 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1617 let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
1618
1619 let stream =
1620 ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None);
1621
1622 (stream, ToolCallEventStreamReceiver(events_rx))
1623 }
1624
1625 fn new(
1626 tool_use_id: LanguageModelToolUseId,
1627 stream: AgentResponseEventStream,
1628 fs: Option<Arc<dyn Fs>>,
1629 ) -> Self {
1630 Self {
1631 tool_use_id,
1632 stream,
1633 fs,
1634 }
1635 }
1636
1637 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1638 self.stream
1639 .update_tool_call_fields(&self.tool_use_id, fields);
1640 }
1641
1642 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1643 self.stream
1644 .0
1645 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1646 acp_thread::ToolCallUpdateDiff {
1647 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1648 diff,
1649 }
1650 .into(),
1651 )))
1652 .ok();
1653 }
1654
1655 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1656 self.stream
1657 .0
1658 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1659 acp_thread::ToolCallUpdateTerminal {
1660 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1661 terminal,
1662 }
1663 .into(),
1664 )))
1665 .ok();
1666 }
1667
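    /// Asks the user for permission to run the tool, describing the action with `title`.
    ///
    /// Resolves immediately if `always_allow_tool_actions` is enabled; otherwise it emits a
    /// [`ToolCallAuthorization`] event with "Always Allow", "Allow", and "Deny" options and
    /// resolves according to the user's choice, persisting "Always Allow" to the settings file.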
1668 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1669 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1670 return Task::ready(Ok(()));
1671 }
1672
1673 let (response_tx, response_rx) = oneshot::channel();
1674 self.stream
1675 .0
1676 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1677 ToolCallAuthorization {
1678 tool_call: acp::ToolCallUpdate {
1679 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1680 fields: acp::ToolCallUpdateFields {
1681 title: Some(title.into()),
1682 ..Default::default()
1683 },
1684 },
1685 options: vec![
1686 acp::PermissionOption {
1687 id: acp::PermissionOptionId("always_allow".into()),
1688 name: "Always Allow".into(),
1689 kind: acp::PermissionOptionKind::AllowAlways,
1690 },
1691 acp::PermissionOption {
1692 id: acp::PermissionOptionId("allow".into()),
1693 name: "Allow".into(),
1694 kind: acp::PermissionOptionKind::AllowOnce,
1695 },
1696 acp::PermissionOption {
1697 id: acp::PermissionOptionId("deny".into()),
1698 name: "Deny".into(),
1699 kind: acp::PermissionOptionKind::RejectOnce,
1700 },
1701 ],
1702 response: response_tx,
1703 },
1704 )))
1705 .ok();
1706 let fs = self.fs.clone();
1707 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1708 "always_allow" => {
1709 if let Some(fs) = fs.clone() {
1710 cx.update(|cx| {
1711 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1712 settings.set_always_allow_tool_actions(true);
1713 });
1714 })?;
1715 }
1716
1717 Ok(())
1718 }
1719 "allow" => Ok(()),
1720 _ => Err(anyhow!("Permission to run tool denied by user")),
1721 })
1722 }
1723}
1724
1725#[cfg(test)]
1726pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
1727
1728#[cfg(test)]
1729impl ToolCallEventStreamReceiver {
1730 pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
1731 let event = self.0.next().await;
1732 if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
1733 auth
1734 } else {
1735 panic!("Expected ToolCallAuthorization but got: {:?}", event);
1736 }
1737 }
1738
1739 pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
1740 let event = self.0.next().await;
1741 if let Some(Ok(AgentResponseEvent::ToolCallUpdate(
1742 acp_thread::ToolCallUpdate::UpdateTerminal(update),
1743 ))) = event
1744 {
1745 update.terminal
1746 } else {
1747 panic!("Expected terminal but got: {:?}", event);
1748 }
1749 }
1750}
1751
1752#[cfg(test)]
1753impl std::ops::Deref for ToolCallEventStreamReceiver {
1754 type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
1755
1756 fn deref(&self) -> &Self::Target {
1757 &self.0
1758 }
1759}
1760
1761#[cfg(test)]
1762impl std::ops::DerefMut for ToolCallEventStreamReceiver {
1763 fn deref_mut(&mut self) -> &mut Self::Target {
1764 &mut self.0
1765 }
1766}
1767
1768impl From<&str> for UserMessageContent {
1769 fn from(text: &str) -> Self {
1770 Self::Text(text.into())
1771 }
1772}
1773
1774impl From<acp::ContentBlock> for UserMessageContent {
1775 fn from(value: acp::ContentBlock) -> Self {
1776 match value {
1777 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1778 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1779 acp::ContentBlock::Audio(_) => {
1780 // TODO
1781 Self::Text("[audio]".to_string())
1782 }
1783 acp::ContentBlock::ResourceLink(resource_link) => {
1784 match MentionUri::parse(&resource_link.uri) {
1785 Ok(uri) => Self::Mention {
1786 uri,
1787 content: String::new(),
1788 },
1789 Err(err) => {
1790 log::error!("Failed to parse mention link: {}", err);
1791 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1792 }
1793 }
1794 }
1795 acp::ContentBlock::Resource(resource) => match resource.resource {
1796 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1797 match MentionUri::parse(&resource.uri) {
1798 Ok(uri) => Self::Mention {
1799 uri,
1800 content: resource.text,
1801 },
1802 Err(err) => {
1803 log::error!("Failed to parse mention link: {}", err);
1804 Self::Text(
1805 MarkdownCodeBlock {
1806 tag: &resource.uri,
1807 text: &resource.text,
1808 }
1809 .to_string(),
1810 )
1811 }
1812 }
1813 }
1814 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1815 // TODO
1816 Self::Text("[blob]".to_string())
1817 }
1818 },
1819 }
1820 }
1821}
1822
1823fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1824 LanguageModelImage {
1825 source: image_content.data.into(),
1826 // TODO: make this optional?
1827 size: gpui::Size::new(0.into(), 0.into()),
1828 }
1829}