1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
16use language_model::{
17 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
18 LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
19 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
20 LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{
29 collections::BTreeMap,
30 path::Path,
31 sync::Arc,
32 time::{Duration, Instant},
33};
34use std::{fmt::Write, ops::Range};
35use util::{ResultExt, markdown::MarkdownCodeBlock};
36use uuid::Uuid;
37
38#[derive(
39 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
40)]
41pub struct ThreadId(Arc<str>);
42
43impl ThreadId {
44 pub fn new() -> Self {
45 Self(Uuid::new_v4().to_string().into())
46 }
47}
48
49impl std::fmt::Display for ThreadId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(f, "{}", self.0)
52 }
53}
54
55impl From<&str> for ThreadId {
56 fn from(value: &str) -> Self {
57 Self(value.into())
58 }
59}
60
61/// The ID of the user prompt that initiated a request.
62///
63/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
64#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
65pub struct PromptId(Arc<str>);
66
67impl PromptId {
68 pub fn new() -> Self {
69 Self(Uuid::new_v4().to_string().into())
70 }
71}
72
73impl std::fmt::Display for PromptId {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 write!(f, "{}", self.0)
76 }
77}
78
79pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
80pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
81
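/// How a failed completion attempt should be retried.
///
/// With `BASE_RETRY_DELAY` of 5 seconds, `ExponentialBackoff` waits roughly 5s, 10s, 20s, and
/// 40s across `MAX_RETRY_ATTEMPTS` (`initial_delay * 2^(attempt - 1)`), while `Fixed` waits the
/// same delay (the server-provided `retry_after`, where available) before every attempt.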
82#[derive(Debug, Clone)]
83enum RetryStrategy {
84 ExponentialBackoff {
85 initial_delay: Duration,
86 max_attempts: u8,
87 },
88 Fixed {
89 delay: Duration,
90 max_attempts: u8,
91 },
92}
93
94#[derive(Debug, Clone, PartialEq, Eq)]
95pub enum Message {
96 User(UserMessage),
97 Agent(AgentMessage),
98 Resume,
99}
100
101impl Message {
102 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
103 match self {
104 Message::Agent(agent_message) => Some(agent_message),
105 _ => None,
106 }
107 }
108
109 pub fn to_markdown(&self) -> String {
110 match self {
111 Message::User(message) => message.to_markdown(),
112 Message::Agent(message) => message.to_markdown(),
113 Message::Resume => "[resumed after tool use limit was reached]".into(),
114 }
115 }
116}
117
118#[derive(Debug, Clone, PartialEq, Eq)]
119pub struct UserMessage {
120 pub id: UserMessageId,
121 pub content: Vec<UserMessageContent>,
122}
123
124#[derive(Debug, Clone, PartialEq, Eq)]
125pub enum UserMessageContent {
126 Text(String),
127 Mention { uri: MentionUri, content: String },
128 Image(LanguageModelImage),
129}
130
131impl UserMessage {
132 pub fn to_markdown(&self) -> String {
133 let mut markdown = String::from("## User\n\n");
134
135 for content in &self.content {
136 match content {
137 UserMessageContent::Text(text) => {
138 markdown.push_str(text);
139 markdown.push('\n');
140 }
141 UserMessageContent::Image(_) => {
142 markdown.push_str("<image />\n");
143 }
144 UserMessageContent::Mention { uri, content } => {
145 if !content.is_empty() {
146 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
147 } else {
148 let _ = write!(&mut markdown, "{}\n", uri.as_link());
149 }
150 }
151 }
152 }
153
154 markdown
155 }
156
157 fn to_request(&self) -> LanguageModelRequestMessage {
158 let mut message = LanguageModelRequestMessage {
159 role: Role::User,
160 content: Vec::with_capacity(self.content.len()),
161 cache: false,
162 };
163
164 const OPEN_CONTEXT: &str = "<context>\n\
165 The following items were attached by the user. \
166 They are up-to-date and don't need to be re-read.\n\n";
167
168 const OPEN_FILES_TAG: &str = "<files>";
169 const OPEN_DIRECTORIES_TAG: &str = "<directories>";
170 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
171 const OPEN_THREADS_TAG: &str = "<threads>";
172 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
173 const OPEN_RULES_TAG: &str =
174 "<rules>\nThe user has specified the following rules that should be applied:\n";
175
176 let mut file_context = OPEN_FILES_TAG.to_string();
177 let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
178 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
179 let mut thread_context = OPEN_THREADS_TAG.to_string();
180 let mut fetch_context = OPEN_FETCH_TAG.to_string();
181 let mut rules_context = OPEN_RULES_TAG.to_string();
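        // Each mention is routed into one of the tagged sections above. If any section ends up
        // non-empty, the code below appends it to the message, wrapped between an opening
        // "<context>" preamble and a final "</context>" marker.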
182
183 for chunk in &self.content {
184 let chunk = match chunk {
185 UserMessageContent::Text(text) => {
186 language_model::MessageContent::Text(text.clone())
187 }
188 UserMessageContent::Image(value) => {
189 language_model::MessageContent::Image(value.clone())
190 }
191 UserMessageContent::Mention { uri, content } => {
192 match uri {
193 MentionUri::File { abs_path } => {
194 write!(
                                &mut file_context,
196 "\n{}",
197 MarkdownCodeBlock {
198 tag: &codeblock_tag(abs_path, None),
                                    text: content,
200 }
201 )
202 .ok();
203 }
204 MentionUri::Directory { .. } => {
205 write!(&mut directory_context, "\n{}\n", content).ok();
206 }
207 MentionUri::Symbol {
208 path, line_range, ..
209 }
210 | MentionUri::Selection {
211 path, line_range, ..
212 } => {
213 write!(
                                &mut symbol_context,
215 "\n{}",
216 MarkdownCodeBlock {
217 tag: &codeblock_tag(path, Some(line_range)),
218 text: content
219 }
220 )
221 .ok();
222 }
223 MentionUri::Thread { .. } => {
224 write!(&mut thread_context, "\n{}\n", content).ok();
225 }
226 MentionUri::TextThread { .. } => {
227 write!(&mut thread_context, "\n{}\n", content).ok();
228 }
229 MentionUri::Rule { .. } => {
230 write!(
231 &mut rules_context,
232 "\n{}",
233 MarkdownCodeBlock {
234 tag: "",
235 text: content
236 }
237 )
238 .ok();
239 }
240 MentionUri::Fetch { url } => {
241 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
242 }
243 }
244
245 language_model::MessageContent::Text(uri.as_link().to_string())
246 }
247 };
248
249 message.content.push(chunk);
250 }
251
252 let len_before_context = message.content.len();
253
254 if file_context.len() > OPEN_FILES_TAG.len() {
255 file_context.push_str("</files>\n");
256 message
257 .content
258 .push(language_model::MessageContent::Text(file_context));
259 }
260
261 if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
262 directory_context.push_str("</directories>\n");
263 message
264 .content
265 .push(language_model::MessageContent::Text(directory_context));
266 }
267
268 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
269 symbol_context.push_str("</symbols>\n");
270 message
271 .content
272 .push(language_model::MessageContent::Text(symbol_context));
273 }
274
275 if thread_context.len() > OPEN_THREADS_TAG.len() {
276 thread_context.push_str("</threads>\n");
277 message
278 .content
279 .push(language_model::MessageContent::Text(thread_context));
280 }
281
282 if fetch_context.len() > OPEN_FETCH_TAG.len() {
283 fetch_context.push_str("</fetched_urls>\n");
284 message
285 .content
286 .push(language_model::MessageContent::Text(fetch_context));
287 }
288
289 if rules_context.len() > OPEN_RULES_TAG.len() {
            rules_context.push_str("</rules>\n");
291 message
292 .content
293 .push(language_model::MessageContent::Text(rules_context));
294 }
295
296 if message.content.len() > len_before_context {
297 message.content.insert(
298 len_before_context,
299 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
300 );
301 message
302 .content
303 .push(language_model::MessageContent::Text("</context>".into()));
304 }
305
306 message
307 }
308}
309
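/// Builds the fence tag for a mention code block: the file extension (if any), the full path,
/// and an optional 1-based line or line-range suffix.
///
/// For example, `codeblock_tag(Path::new("src/main.rs"), Some(&(9..19)))` produces
/// `"rs src/main.rs:10-20"`.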
310fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
311 let mut result = String::new();
312
313 if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
314 let _ = write!(result, "{} ", extension);
315 }
316
317 let _ = write!(result, "{}", full_path.display());
318
319 if let Some(range) = line_range {
320 if range.start == range.end {
321 let _ = write!(result, ":{}", range.start + 1);
322 } else {
323 let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
324 }
325 }
326
327 result
328}
329
330impl AgentMessage {
331 pub fn to_markdown(&self) -> String {
332 let mut markdown = String::from("## Assistant\n\n");
333
334 for content in &self.content {
335 match content {
336 AgentMessageContent::Text(text) => {
337 markdown.push_str(text);
338 markdown.push('\n');
339 }
340 AgentMessageContent::Thinking { text, .. } => {
341 markdown.push_str("<think>");
342 markdown.push_str(text);
343 markdown.push_str("</think>\n");
344 }
345 AgentMessageContent::RedactedThinking(_) => {
346 markdown.push_str("<redacted_thinking />\n")
347 }
348 AgentMessageContent::Image(_) => {
349 markdown.push_str("<image />\n");
350 }
351 AgentMessageContent::ToolUse(tool_use) => {
352 markdown.push_str(&format!(
353 "**Tool Use**: {} (ID: {})\n",
354 tool_use.name, tool_use.id
355 ));
356 markdown.push_str(&format!(
357 "{}\n",
358 MarkdownCodeBlock {
359 tag: "json",
360 text: &format!("{:#}", tool_use.input)
361 }
362 ));
363 }
364 }
365 }
366
367 for tool_result in self.tool_results.values() {
368 markdown.push_str(&format!(
369 "**Tool Result**: {} (ID: {})\n\n",
370 tool_result.tool_name, tool_result.tool_use_id
371 ));
372 if tool_result.is_error {
373 markdown.push_str("**ERROR:**\n");
374 }
375
376 match &tool_result.content {
377 LanguageModelToolResultContent::Text(text) => {
378 writeln!(markdown, "{text}\n").ok();
379 }
380 LanguageModelToolResultContent::Image(_) => {
381 writeln!(markdown, "<image />\n").ok();
382 }
383 }
384
385 if let Some(output) = tool_result.output.as_ref() {
386 writeln!(
387 markdown,
388 "**Debug Output**:\n\n```json\n{}\n```\n",
389 serde_json::to_string_pretty(output).unwrap()
390 )
391 .unwrap();
392 }
393 }
394
395 markdown
396 }
397
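    /// Converts this agent message into request messages: an assistant message carrying the
    /// streamed content, followed by a user message carrying any tool results. Messages that
    /// would be empty are omitted.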
398 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
399 let mut assistant_message = LanguageModelRequestMessage {
400 role: Role::Assistant,
401 content: Vec::with_capacity(self.content.len()),
402 cache: false,
403 };
404 for chunk in &self.content {
405 let chunk = match chunk {
406 AgentMessageContent::Text(text) => {
407 language_model::MessageContent::Text(text.clone())
408 }
409 AgentMessageContent::Thinking { text, signature } => {
410 language_model::MessageContent::Thinking {
411 text: text.clone(),
412 signature: signature.clone(),
413 }
414 }
415 AgentMessageContent::RedactedThinking(value) => {
416 language_model::MessageContent::RedactedThinking(value.clone())
417 }
418 AgentMessageContent::ToolUse(value) => {
419 language_model::MessageContent::ToolUse(value.clone())
420 }
421 AgentMessageContent::Image(value) => {
422 language_model::MessageContent::Image(value.clone())
423 }
424 };
425 assistant_message.content.push(chunk);
426 }
427
428 let mut user_message = LanguageModelRequestMessage {
429 role: Role::User,
430 content: Vec::new(),
431 cache: false,
432 };
433
434 for tool_result in self.tool_results.values() {
435 user_message
436 .content
437 .push(language_model::MessageContent::ToolResult(
438 tool_result.clone(),
439 ));
440 }
441
442 let mut messages = Vec::new();
443 if !assistant_message.content.is_empty() {
444 messages.push(assistant_message);
445 }
446 if !user_message.content.is_empty() {
447 messages.push(user_message);
448 }
449 messages
450 }
451}
452
453#[derive(Default, Debug, Clone, PartialEq, Eq)]
454pub struct AgentMessage {
455 pub content: Vec<AgentMessageContent>,
456 pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
457}
458
459#[derive(Debug, Clone, PartialEq, Eq)]
460pub enum AgentMessageContent {
461 Text(String),
462 Thinking {
463 text: String,
464 signature: Option<String>,
465 },
466 RedactedThinking(String),
467 Image(LanguageModelImage),
468 ToolUse(LanguageModelToolUse),
469}
470
471#[derive(Debug)]
472pub enum AgentResponseEvent {
473 Text(String),
474 Thinking(String),
475 ToolCall(acp::ToolCall),
476 ToolCallUpdate(acp_thread::ToolCallUpdate),
477 ToolCallAuthorization(ToolCallAuthorization),
478 Retry(acp_thread::RetryStatus),
479 Stop(acp::StopReason),
480}
481
482#[derive(Debug)]
483pub struct ToolCallAuthorization {
484 pub tool_call: acp::ToolCallUpdate,
485 pub options: Vec<acp::PermissionOption>,
486 pub response: oneshot::Sender<acp::PermissionOptionId>,
487}
488
489pub struct Thread {
490 id: ThreadId,
491 prompt_id: PromptId,
492 messages: Vec<Message>,
493 completion_mode: CompletionMode,
494 /// Holds the task that handles agent interaction until the end of the turn.
495 /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
497 running_turn: Option<RunningTurn>,
498 pending_message: Option<AgentMessage>,
499 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
500 tool_use_limit_reached: bool,
501 context_server_registry: Entity<ContextServerRegistry>,
502 profile_id: AgentProfileId,
503 project_context: Entity<ProjectContext>,
504 templates: Arc<Templates>,
505 model: Option<Arc<dyn LanguageModel>>,
506 project: Entity<Project>,
507 action_log: Entity<ActionLog>,
508}
509
510impl Thread {
511 pub fn new(
512 project: Entity<Project>,
513 project_context: Entity<ProjectContext>,
514 context_server_registry: Entity<ContextServerRegistry>,
515 action_log: Entity<ActionLog>,
516 templates: Arc<Templates>,
517 model: Option<Arc<dyn LanguageModel>>,
518 cx: &mut Context<Self>,
519 ) -> Self {
520 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
521 Self {
522 id: ThreadId::new(),
523 prompt_id: PromptId::new(),
524 messages: Vec::new(),
525 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
526 running_turn: None,
527 pending_message: None,
528 tools: BTreeMap::default(),
529 tool_use_limit_reached: false,
530 context_server_registry,
531 profile_id,
532 project_context,
533 templates,
534 model,
535 project,
536 action_log,
537 }
538 }
539
540 pub fn project(&self) -> &Entity<Project> {
541 &self.project
542 }
543
544 pub fn project_context(&self) -> &Entity<ProjectContext> {
545 &self.project_context
546 }
547
548 pub fn action_log(&self) -> &Entity<ActionLog> {
549 &self.action_log
550 }
551
552 pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
553 self.model.as_ref()
554 }
555
556 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
557 self.model = Some(model);
558 }
559
560 pub fn completion_mode(&self) -> CompletionMode {
561 self.completion_mode
562 }
563
564 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
565 self.completion_mode = mode;
566 }
567
568 #[cfg(any(test, feature = "test-support"))]
569 pub fn last_message(&self) -> Option<Message> {
570 if let Some(message) = self.pending_message.clone() {
571 Some(Message::Agent(message))
572 } else {
573 self.messages.last().cloned()
574 }
575 }
576
577 pub fn add_tool(&mut self, tool: impl AgentTool) {
578 self.tools.insert(tool.name(), tool.erase());
579 }
580
581 pub fn remove_tool(&mut self, name: &str) -> bool {
582 self.tools.remove(name).is_some()
583 }
584
585 pub fn profile(&self) -> &AgentProfileId {
586 &self.profile_id
587 }
588
589 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
590 self.profile_id = profile_id;
591 }
592
593 pub fn cancel(&mut self) {
594 if let Some(running_turn) = self.running_turn.take() {
595 running_turn.cancel();
596 }
597 self.flush_pending_message();
598 }
599
600 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
601 self.cancel();
602 let Some(position) = self.messages.iter().position(
603 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
604 ) else {
605 return Err(anyhow!("Message not found"));
606 };
607 self.messages.truncate(position);
608 Ok(())
609 }
610
611 pub fn resume(
612 &mut self,
613 cx: &mut Context<Self>,
614 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
615 anyhow::ensure!(self.model.is_some(), "Model not set");
616 anyhow::ensure!(
617 self.tool_use_limit_reached,
618 "can only resume after tool use limit is reached"
619 );
620
621 self.messages.push(Message::Resume);
622 cx.notify();
623
624 log::info!("Total messages in thread: {}", self.messages.len());
625 self.run_turn(cx)
626 }
627
    /// Sending a message results in the model streaming a response, which may include tool calls.
    /// After requesting tool calls, the model stops and waits for the outstanding calls to be
    /// completed and their results sent back. The returned channel reports every stop that
    /// occurs before the model errors or ends its turn.
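    ///
    /// A minimal consumption sketch (the `message_id` value is assumed to be built by the
    /// caller; illustrative only):
    ///
    /// ```ignore
    /// let mut events = thread.send(message_id, ["Summarize the open buffer"], cx)?;
    /// while let Some(event) = events.next().await {
    ///     match event? {
    ///         AgentResponseEvent::Text(chunk) => print!("{chunk}"),
    ///         AgentResponseEvent::Stop(_reason) => break,
    ///         _ => {}
    ///     }
    /// }
    /// ```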
631 pub fn send<T>(
632 &mut self,
633 id: UserMessageId,
634 content: impl IntoIterator<Item = T>,
635 cx: &mut Context<Self>,
636 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>
637 where
638 T: Into<UserMessageContent>,
639 {
640 let model = self.model().context("No language model configured")?;
641
642 log::info!("Thread::send called with model: {:?}", model.name());
643 self.advance_prompt_id();
644
645 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
646 log::debug!("Thread::send content: {:?}", content);
647
648 self.messages
649 .push(Message::User(UserMessage { id, content }));
650 cx.notify();
651
652 log::info!("Total messages in thread: {}", self.messages.len());
653 self.run_turn(cx)
654 }
655
656 fn run_turn(
657 &mut self,
658 cx: &mut Context<Self>,
659 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
660 self.cancel();
661
662 let model = self
663 .model()
664 .cloned()
665 .context("No language model configured")?;
666 let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
667 let event_stream = AgentResponseEventStream(events_tx);
668 let message_ix = self.messages.len().saturating_sub(1);
669 self.tool_use_limit_reached = false;
670 self.running_turn = Some(RunningTurn {
671 event_stream: event_stream.clone(),
672 _task: cx.spawn(async move |this, cx| {
673 log::info!("Starting agent turn execution");
674 let turn_result: Result<()> = async {
675 let mut completion_intent = CompletionIntent::UserPrompt;
676 loop {
677 log::debug!(
678 "Building completion request with intent: {:?}",
679 completion_intent
680 );
681 let request = this.update(cx, |this, cx| {
682 this.build_completion_request(completion_intent, cx)
683 })??;
684
685 log::info!("Calling model.stream_completion");
686
687 let mut tool_use_limit_reached = false;
688 let mut tool_uses = Self::stream_completion_with_retries(
689 this.clone(),
690 model.clone(),
691 request,
692 message_ix,
693 &event_stream,
694 &mut tool_use_limit_reached,
695 cx,
696 )
697 .await?;
698
                        let used_tools = !tool_uses.is_empty();
700 while let Some(tool_result) = tool_uses.next().await {
701 log::info!("Tool finished {:?}", tool_result);
702
703 event_stream.update_tool_call_fields(
704 &tool_result.tool_use_id,
705 acp::ToolCallUpdateFields {
706 status: Some(if tool_result.is_error {
707 acp::ToolCallStatus::Failed
708 } else {
709 acp::ToolCallStatus::Completed
710 }),
711 raw_output: tool_result.output.clone(),
712 ..Default::default()
713 },
714 );
715 this.update(cx, |this, _cx| {
716 this.pending_message()
717 .tool_results
718 .insert(tool_result.tool_use_id.clone(), tool_result);
719 })
720 .ok();
721 }
722
723 if tool_use_limit_reached {
724 log::info!("Tool use limit reached, completing turn");
725 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
726 return Err(language_model::ToolUseLimitReachedError.into());
                        } else if !used_tools {
728 log::info!("No tool uses found, completing turn");
729 return Ok(());
730 } else {
731 this.update(cx, |this, _| this.flush_pending_message())?;
732 completion_intent = CompletionIntent::ToolResults;
733 }
734 }
735 }
736 .await;
737
738 if let Err(error) = turn_result {
739 log::error!("Turn execution failed: {:?}", error);
740 event_stream.send_error(error);
741 } else {
742 log::info!("Turn execution completed successfully");
743 }
744
745 this.update(cx, |this, _| {
746 this.flush_pending_message();
747 this.running_turn.take();
748 })
749 .ok();
750 }),
751 });
752 Ok(events_rx)
753 }
754
755 async fn stream_completion_with_retries(
756 this: WeakEntity<Self>,
757 model: Arc<dyn LanguageModel>,
758 request: LanguageModelRequest,
759 message_ix: usize,
760 event_stream: &AgentResponseEventStream,
761 tool_use_limit_reached: &mut bool,
762 cx: &mut AsyncApp,
763 ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
        log::debug!("Starting completion stream");
765
766 let mut attempt = None;
767 'retry: loop {
768 let mut events = model.stream_completion(request.clone(), cx).await?;
769 let mut tool_uses = FuturesUnordered::new();
770 while let Some(event) = events.next().await {
771 match event {
772 Ok(LanguageModelCompletionEvent::StatusUpdate(
773 CompletionRequestStatus::ToolUseLimitReached,
774 )) => {
775 *tool_use_limit_reached = true;
776 }
777 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
778 event_stream.send_stop(reason);
779 if reason == StopReason::Refusal {
780 this.update(cx, |this, _cx| {
781 this.flush_pending_message();
782 this.messages.truncate(message_ix);
783 })?;
784 return Ok(tool_uses);
785 }
786 }
787 Ok(event) => {
788 log::trace!("Received completion event: {:?}", event);
789 this.update(cx, |this, cx| {
790 tool_uses.extend(this.handle_streamed_completion_event(
791 event,
792 event_stream,
793 cx,
794 ));
795 })
796 .ok();
797 }
798 Err(error) => {
799 let completion_mode =
800 this.read_with(cx, |thread, _cx| thread.completion_mode())?;
801 if completion_mode == CompletionMode::Normal {
802 return Err(error.into());
803 }
804
805 let Some(strategy) = Self::retry_strategy_for(&error) else {
806 return Err(error.into());
807 };
808
809 let max_attempts = match &strategy {
810 RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
811 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
812 };
813
814 let attempt = attempt.get_or_insert(0u8);
815
816 *attempt += 1;
817
818 let attempt = *attempt;
819 if attempt > max_attempts {
820 return Err(error.into());
821 }
822
823 let delay = match &strategy {
824 RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
825 let delay_secs =
826 initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
827 Duration::from_secs(delay_secs)
828 }
829 RetryStrategy::Fixed { delay, .. } => *delay,
830 };
831 log::debug!("Retry attempt {attempt} with delay {delay:?}");
832
833 event_stream.send_retry(acp_thread::RetryStatus {
834 last_error: error.to_string().into(),
835 attempt: attempt as usize,
836 max_attempts: max_attempts as usize,
837 started_at: Instant::now(),
838 duration: delay,
839 });
840
841 cx.background_executor().timer(delay).await;
842 continue 'retry;
843 }
844 }
845 }
846 return Ok(tool_uses);
847 }
848 }
849
850 pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
851 log::debug!("Building system message");
852 let prompt = SystemPromptTemplate {
853 project: self.project_context.read(cx),
854 available_tools: self.tools.keys().cloned().collect(),
855 }
856 .render(&self.templates)
857 .context("failed to build system prompt")
858 .expect("Invalid template");
859 log::debug!("System message built");
860 LanguageModelRequestMessage {
861 role: Role::System,
862 content: vec![prompt.into()],
863 cache: true,
864 }
865 }
866
    /// A helper method that's called on every streamed completion event.
    /// Returns an optional tool result task, which the main agentic loop in
    /// `run_turn` will send back to the model when it resolves.
870 fn handle_streamed_completion_event(
871 &mut self,
872 event: LanguageModelCompletionEvent,
873 event_stream: &AgentResponseEventStream,
874 cx: &mut Context<Self>,
875 ) -> Option<Task<LanguageModelToolResult>> {
876 log::trace!("Handling streamed completion event: {:?}", event);
877 use LanguageModelCompletionEvent::*;
878
879 match event {
880 StartMessage { .. } => {
881 self.flush_pending_message();
882 self.pending_message = Some(AgentMessage::default());
883 }
884 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
885 Thinking { text, signature } => {
886 self.handle_thinking_event(text, signature, event_stream, cx)
887 }
888 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
889 ToolUse(tool_use) => {
890 return self.handle_tool_use_event(tool_use, event_stream, cx);
891 }
892 ToolUseJsonParseError {
893 id,
894 tool_name,
895 raw_input,
896 json_parse_error,
897 } => {
898 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
899 id,
900 tool_name,
901 raw_input,
902 json_parse_error,
903 )));
904 }
905 UsageUpdate(_) | StatusUpdate(_) => {}
906 Stop(_) => unreachable!(),
907 }
908
909 None
910 }
911
912 fn handle_text_event(
913 &mut self,
914 new_text: String,
915 event_stream: &AgentResponseEventStream,
916 cx: &mut Context<Self>,
917 ) {
918 event_stream.send_text(&new_text);
919
920 let last_message = self.pending_message();
921 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
922 text.push_str(&new_text);
923 } else {
924 last_message
925 .content
926 .push(AgentMessageContent::Text(new_text));
927 }
928
929 cx.notify();
930 }
931
932 fn handle_thinking_event(
933 &mut self,
934 new_text: String,
935 new_signature: Option<String>,
936 event_stream: &AgentResponseEventStream,
937 cx: &mut Context<Self>,
938 ) {
939 event_stream.send_thinking(&new_text);
940
941 let last_message = self.pending_message();
942 if let Some(AgentMessageContent::Thinking { text, signature }) =
943 last_message.content.last_mut()
944 {
945 text.push_str(&new_text);
946 *signature = new_signature.or(signature.take());
947 } else {
948 last_message.content.push(AgentMessageContent::Thinking {
949 text: new_text,
950 signature: new_signature,
951 });
952 }
953
954 cx.notify();
955 }
956
957 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
958 let last_message = self.pending_message();
959 last_message
960 .content
961 .push(AgentMessageContent::RedactedThinking(data));
962 cx.notify();
963 }
964
965 fn handle_tool_use_event(
966 &mut self,
967 tool_use: LanguageModelToolUse,
968 event_stream: &AgentResponseEventStream,
969 cx: &mut Context<Self>,
970 ) -> Option<Task<LanguageModelToolResult>> {
971 cx.notify();
972
973 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
974 let mut title = SharedString::from(&tool_use.name);
975 let mut kind = acp::ToolKind::Other;
976 if let Some(tool) = tool.as_ref() {
977 title = tool.initial_title(tool_use.input.clone());
978 kind = tool.kind();
979 }
980
981 // Ensure the last message ends in the current tool use
982 let last_message = self.pending_message();
983 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
984 if let AgentMessageContent::ToolUse(last_tool_use) = content {
985 if last_tool_use.id == tool_use.id {
986 *last_tool_use = tool_use.clone();
987 false
988 } else {
989 true
990 }
991 } else {
992 true
993 }
994 });
995
996 if push_new_tool_use {
997 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
998 last_message
999 .content
1000 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1001 } else {
1002 event_stream.update_tool_call_fields(
1003 &tool_use.id,
1004 acp::ToolCallUpdateFields {
1005 title: Some(title.into()),
1006 kind: Some(kind),
1007 raw_input: Some(tool_use.input.clone()),
1008 ..Default::default()
1009 },
1010 );
1011 }
1012
1013 if !tool_use.is_input_complete {
1014 return None;
1015 }
1016
1017 let Some(tool) = tool else {
1018 let content = format!("No tool named {} exists", tool_use.name);
1019 return Some(Task::ready(LanguageModelToolResult {
1020 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1021 tool_use_id: tool_use.id,
1022 tool_name: tool_use.name,
1023 is_error: true,
1024 output: None,
1025 }));
1026 };
1027
1028 let fs = self.project.read(cx).fs().clone();
1029 let tool_event_stream =
1030 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1031 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1032 status: Some(acp::ToolCallStatus::InProgress),
1033 ..Default::default()
1034 });
1035 let supports_images = self.model().map_or(false, |model| model.supports_images());
1036 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1037 log::info!("Running tool {}", tool_use.name);
1038 Some(cx.foreground_executor().spawn(async move {
1039 let tool_result = tool_result.await.and_then(|output| {
1040 if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1041 && !supports_images
1042 {
1043 return Err(anyhow!(
1044 "Attempted to read an image, but this model doesn't support it.",
1045 ));
1046 }
1047 Ok(output)
1048 });
1049
1050 match tool_result {
1051 Ok(output) => LanguageModelToolResult {
1052 tool_use_id: tool_use.id,
1053 tool_name: tool_use.name,
1054 is_error: false,
1055 content: output.llm_output,
1056 output: Some(output.raw_output),
1057 },
1058 Err(error) => LanguageModelToolResult {
1059 tool_use_id: tool_use.id,
1060 tool_name: tool_use.name,
1061 is_error: true,
1062 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1063 output: None,
1064 },
1065 }
1066 }))
1067 }
1068
1069 fn handle_tool_use_json_parse_error_event(
1070 &mut self,
1071 tool_use_id: LanguageModelToolUseId,
1072 tool_name: Arc<str>,
1073 raw_input: Arc<str>,
1074 json_parse_error: String,
1075 ) -> LanguageModelToolResult {
1076 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1077 LanguageModelToolResult {
1078 tool_use_id,
1079 tool_name,
1080 is_error: true,
1081 content: LanguageModelToolResultContent::Text(tool_output.into()),
1082 output: Some(serde_json::Value::String(raw_input.to_string())),
1083 }
1084 }
1085
1086 fn pending_message(&mut self) -> &mut AgentMessage {
1087 self.pending_message.get_or_insert_default()
1088 }
1089
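    /// Moves the pending agent message into the message history, synthesizing an error tool
    /// result ("Tool canceled by user") for any tool use that never produced one, so the next
    /// request still contains a result for every tool call.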
1090 fn flush_pending_message(&mut self) {
1091 let Some(mut message) = self.pending_message.take() else {
1092 return;
1093 };
1094
1095 for content in &message.content {
1096 let AgentMessageContent::ToolUse(tool_use) = content else {
1097 continue;
1098 };
1099
1100 if !message.tool_results.contains_key(&tool_use.id) {
1101 message.tool_results.insert(
1102 tool_use.id.clone(),
1103 LanguageModelToolResult {
1104 tool_use_id: tool_use.id.clone(),
1105 tool_name: tool_use.name.clone(),
1106 is_error: true,
1107 content: LanguageModelToolResultContent::Text(
1108 "Tool canceled by user".into(),
1109 ),
1110 output: None,
1111 },
1112 );
1113 }
1114 }
1115
1116 self.messages.push(Message::Agent(message));
1117 }
1118
1119 pub(crate) fn build_completion_request(
1120 &self,
1121 completion_intent: CompletionIntent,
1122 cx: &mut App,
1123 ) -> Result<LanguageModelRequest> {
1124 let model = self.model().context("No language model configured")?;
1125
1126 log::debug!("Building completion request");
1127 log::debug!("Completion intent: {:?}", completion_intent);
1128 log::debug!("Completion mode: {:?}", self.completion_mode);
1129
1130 let messages = self.build_request_messages(cx);
1131 log::info!("Request will include {} messages", messages.len());
1132
1133 let tools = if let Some(tools) = self.tools(cx).log_err() {
1134 tools
1135 .filter_map(|tool| {
1136 let tool_name = tool.name().to_string();
1137 log::trace!("Including tool: {}", tool_name);
1138 Some(LanguageModelRequestTool {
1139 name: tool_name,
1140 description: tool.description().to_string(),
1141 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1142 })
1143 })
1144 .collect()
1145 } else {
1146 Vec::new()
1147 };
1148
1149 log::info!("Request includes {} tools", tools.len());
1150
1151 let request = LanguageModelRequest {
1152 thread_id: Some(self.id.to_string()),
1153 prompt_id: Some(self.prompt_id.to_string()),
1154 intent: Some(completion_intent),
1155 mode: Some(self.completion_mode.into()),
1156 messages,
1157 tools,
1158 tool_choice: None,
1159 stop: Vec::new(),
1160 temperature: AgentSettings::temperature_for_model(model, cx),
1161 thinking_allowed: true,
1162 };
1163
1164 log::debug!("Completion request built successfully");
1165 Ok(request)
1166 }
1167
1168 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1169 let model = self.model().context("No language model configured")?;
1170
1171 let profile = AgentSettings::get_global(cx)
1172 .profiles
1173 .get(&self.profile_id)
1174 .context("profile not found")?;
1175 let provider_id = model.provider_id();
1176
1177 Ok(self
1178 .tools
1179 .iter()
1180 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1181 .filter_map(|(tool_name, tool)| {
1182 if profile.is_tool_enabled(tool_name) {
1183 Some(tool)
1184 } else {
1185 None
1186 }
1187 })
1188 .chain(self.context_server_registry.read(cx).servers().flat_map(
1189 |(server_id, tools)| {
1190 tools.iter().filter_map(|(tool_name, tool)| {
1191 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1192 Some(tool)
1193 } else {
1194 None
1195 }
1196 })
1197 },
1198 )))
1199 }
1200
1201 fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1202 log::trace!(
1203 "Building request messages from {} thread messages",
1204 self.messages.len()
1205 );
1206 let mut messages = vec![self.build_system_message(cx)];
1207 for message in &self.messages {
1208 match message {
1209 Message::User(message) => messages.push(message.to_request()),
1210 Message::Agent(message) => messages.extend(message.to_request()),
1211 Message::Resume => messages.push(LanguageModelRequestMessage {
1212 role: Role::User,
1213 content: vec!["Continue where you left off".into()],
1214 cache: false,
1215 }),
1216 }
1217 }
1218
1219 if let Some(message) = self.pending_message.as_ref() {
1220 messages.extend(message.to_request());
1221 }
1222
1223 if let Some(last_user_message) = messages
1224 .iter_mut()
1225 .rev()
1226 .find(|message| message.role == Role::User)
1227 {
1228 last_user_message.cache = true;
1229 }
1230
1231 messages
1232 }
1233
1234 pub fn to_markdown(&self) -> String {
1235 let mut markdown = String::new();
1236 for (ix, message) in self.messages.iter().enumerate() {
1237 if ix > 0 {
1238 markdown.push('\n');
1239 }
1240 markdown.push_str(&message.to_markdown());
1241 }
1242
1243 if let Some(message) = self.pending_message.as_ref() {
1244 markdown.push('\n');
1245 markdown.push_str(&message.to_markdown());
1246 }
1247
1248 markdown
1249 }
1250
1251 fn advance_prompt_id(&mut self) {
1252 self.prompt_id = PromptId::new();
1253 }
1254
1255 fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
1256 use LanguageModelCompletionError::*;
1257 use http_client::StatusCode;
1258
1259 // General strategy here:
1260 // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
1261 // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
1262 // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
1263 match error {
1264 HttpResponseError {
1265 status_code: StatusCode::TOO_MANY_REQUESTS,
1266 ..
1267 } => Some(RetryStrategy::ExponentialBackoff {
1268 initial_delay: BASE_RETRY_DELAY,
1269 max_attempts: MAX_RETRY_ATTEMPTS,
1270 }),
1271 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
1272 Some(RetryStrategy::Fixed {
1273 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1274 max_attempts: MAX_RETRY_ATTEMPTS,
1275 })
1276 }
1277 UpstreamProviderError {
1278 status,
1279 retry_after,
1280 ..
1281 } => match *status {
1282 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
1283 Some(RetryStrategy::Fixed {
1284 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1285 max_attempts: MAX_RETRY_ATTEMPTS,
1286 })
1287 }
1288 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
1289 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1290 // Internal Server Error could be anything, retry up to 3 times.
1291 max_attempts: 3,
1292 }),
1293 status => {
1294 // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
1295 // but we frequently get them in practice. See https://http.dev/529
1296 if status.as_u16() == 529 {
1297 Some(RetryStrategy::Fixed {
1298 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1299 max_attempts: MAX_RETRY_ATTEMPTS,
1300 })
1301 } else {
1302 Some(RetryStrategy::Fixed {
1303 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1304 max_attempts: 2,
1305 })
1306 }
1307 }
1308 },
1309 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
1310 delay: BASE_RETRY_DELAY,
1311 max_attempts: 3,
1312 }),
1313 ApiReadResponseError { .. }
1314 | HttpSend { .. }
1315 | DeserializeResponse { .. }
1316 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
1317 delay: BASE_RETRY_DELAY,
1318 max_attempts: 3,
1319 }),
1320 // Retrying these errors definitely shouldn't help.
1321 HttpResponseError {
1322 status_code:
1323 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
1324 ..
1325 }
1326 | AuthenticationError { .. }
1327 | PermissionError { .. }
1328 | NoApiKey { .. }
1329 | ApiEndpointNotFound { .. }
1330 | PromptTooLarge { .. } => None,
1331 // These errors might be transient, so retry them
1332 SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
1333 delay: BASE_RETRY_DELAY,
1334 max_attempts: 1,
1335 }),
            // Retry all other 4xx and 5xx errors up to 3 times.
1337 HttpResponseError { status_code, .. }
1338 if status_code.is_client_error() || status_code.is_server_error() =>
1339 {
1340 Some(RetryStrategy::Fixed {
1341 delay: BASE_RETRY_DELAY,
1342 max_attempts: 3,
1343 })
1344 }
1345 Other(err)
1346 if err.is::<language_model::PaymentRequiredError>()
1347 || err.is::<language_model::ModelRequestLimitReachedError>() =>
1348 {
1349 // Retrying won't help for Payment Required or Model Request Limit errors (where
1350 // the user must upgrade to usage-based billing to get more requests, or else wait
1351 // for a significant amount of time for the request limit to reset).
1352 None
1353 }
            // Conservatively assume that any other errors might be transient, and retry them a couple of times.
1355 HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
1356 delay: BASE_RETRY_DELAY,
1357 max_attempts: 2,
1358 }),
1359 }
1360 }
1361}
1362
1363struct RunningTurn {
1364 /// Holds the task that handles agent interaction until the end of the turn.
1365 /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
1367 _task: Task<()>,
1368 /// The current event stream for the running turn. Used to report a final
1369 /// cancellation event if we cancel the turn.
1370 event_stream: AgentResponseEventStream,
1371}
1372
1373impl RunningTurn {
1374 fn cancel(self) {
1375 log::debug!("Cancelling in progress turn");
1376 self.event_stream.send_canceled();
1377 }
1378}
1379
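/// A strongly typed tool that the agent can call while taking a turn.
///
/// A minimal sketch of an implementation (the `EchoTool` type and its input are hypothetical,
/// shown only to illustrate the shape of the trait):
///
/// ```ignore
/// #[derive(Deserialize, Serialize, JsonSchema)]
/// struct EchoToolInput {
///     /// The text to echo back to the model.
///     text: String,
/// }
///
/// struct EchoTool;
///
/// impl AgentTool for EchoTool {
///     type Input = EchoToolInput;
///     type Output = String;
///
///     fn name(&self) -> SharedString {
///         "echo".into()
///     }
///
///     fn kind(&self) -> acp::ToolKind {
///         acp::ToolKind::Other
///     }
///
///     fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
///         "Echo".into()
///     }
///
///     fn run(
///         self: Arc<Self>,
///         input: Self::Input,
///         _event_stream: ToolCallEventStream,
///         _cx: &mut App,
///     ) -> Task<Result<Self::Output>> {
///         Task::ready(Ok(input.text))
///     }
/// }
/// ```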
1380pub trait AgentTool
1381where
1382 Self: 'static + Sized,
1383{
1384 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1385 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1386
1387 fn name(&self) -> SharedString;
1388
1389 fn description(&self) -> SharedString {
1390 let schema = schemars::schema_for!(Self::Input);
1391 SharedString::new(
1392 schema
1393 .get("description")
1394 .and_then(|description| description.as_str())
1395 .unwrap_or_default(),
1396 )
1397 }
1398
1399 fn kind(&self) -> acp::ToolKind;
1400
1401 /// The initial tool title to display. Can be updated during the tool run.
1402 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1403
1404 /// Returns the JSON schema that describes the tool's input.
1405 fn input_schema(&self) -> Schema {
1406 schemars::schema_for!(Self::Input)
1407 }
1408
    /// Some tools depend on a specific provider (e.g. for billing or other reasons).
    /// This lets a tool declare whether it is compatible with the given provider, or
    /// whether it should be filtered out.
1411 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1412 true
1413 }
1414
1415 /// Runs the tool with the provided input.
1416 fn run(
1417 self: Arc<Self>,
1418 input: Self::Input,
1419 event_stream: ToolCallEventStream,
1420 cx: &mut App,
1421 ) -> Task<Result<Self::Output>>;
1422
1423 fn erase(self) -> Arc<dyn AnyAgentTool> {
1424 Arc::new(Erased(Arc::new(self)))
1425 }
1426}
1427
1428pub struct Erased<T>(T);
1429
1430pub struct AgentToolOutput {
1431 pub llm_output: LanguageModelToolResultContent,
1432 pub raw_output: serde_json::Value,
1433}
1434
1435pub trait AnyAgentTool {
1436 fn name(&self) -> SharedString;
1437 fn description(&self) -> SharedString;
1438 fn kind(&self) -> acp::ToolKind;
1439 fn initial_title(&self, input: serde_json::Value) -> SharedString;
1440 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
1441 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1442 true
1443 }
1444 fn run(
1445 self: Arc<Self>,
1446 input: serde_json::Value,
1447 event_stream: ToolCallEventStream,
1448 cx: &mut App,
1449 ) -> Task<Result<AgentToolOutput>>;
1450}
1451
1452impl<T> AnyAgentTool for Erased<Arc<T>>
1453where
1454 T: AgentTool,
1455{
1456 fn name(&self) -> SharedString {
1457 self.0.name()
1458 }
1459
1460 fn description(&self) -> SharedString {
1461 self.0.description()
1462 }
1463
1464 fn kind(&self) -> agent_client_protocol::ToolKind {
1465 self.0.kind()
1466 }
1467
1468 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1469 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1470 self.0.initial_title(parsed_input)
1471 }
1472
1473 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1474 let mut json = serde_json::to_value(self.0.input_schema())?;
1475 adapt_schema_to_format(&mut json, format)?;
1476 Ok(json)
1477 }
1478
1479 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1480 self.0.supported_provider(provider)
1481 }
1482
1483 fn run(
1484 self: Arc<Self>,
1485 input: serde_json::Value,
1486 event_stream: ToolCallEventStream,
1487 cx: &mut App,
1488 ) -> Task<Result<AgentToolOutput>> {
1489 cx.spawn(async move |cx| {
1490 let input = serde_json::from_value(input)?;
1491 let output = cx
1492 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1493 .await?;
1494 let raw_output = serde_json::to_value(&output)?;
1495 Ok(AgentToolOutput {
1496 llm_output: output.into(),
1497 raw_output,
1498 })
1499 })
1500 }
1501}
1502
1503#[derive(Clone)]
1504struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
1505
1506impl AgentResponseEventStream {
1507 fn send_text(&self, text: &str) {
1508 self.0
1509 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
1510 .ok();
1511 }
1512
1513 fn send_thinking(&self, text: &str) {
1514 self.0
1515 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
1516 .ok();
1517 }
1518
1519 fn send_tool_call(
1520 &self,
1521 id: &LanguageModelToolUseId,
1522 title: SharedString,
1523 kind: acp::ToolKind,
1524 input: serde_json::Value,
1525 ) {
1526 self.0
1527 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
1528 id,
1529 title.to_string(),
1530 kind,
1531 input,
1532 ))))
1533 .ok();
1534 }
1535
1536 fn initial_tool_call(
1537 id: &LanguageModelToolUseId,
1538 title: String,
1539 kind: acp::ToolKind,
1540 input: serde_json::Value,
1541 ) -> acp::ToolCall {
1542 acp::ToolCall {
1543 id: acp::ToolCallId(id.to_string().into()),
1544 title,
1545 kind,
1546 status: acp::ToolCallStatus::Pending,
1547 content: vec![],
1548 locations: vec![],
1549 raw_input: Some(input),
1550 raw_output: None,
1551 }
1552 }
1553
1554 fn update_tool_call_fields(
1555 &self,
1556 tool_use_id: &LanguageModelToolUseId,
1557 fields: acp::ToolCallUpdateFields,
1558 ) {
1559 self.0
1560 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1561 acp::ToolCallUpdate {
1562 id: acp::ToolCallId(tool_use_id.to_string().into()),
1563 fields,
1564 }
1565 .into(),
1566 )))
1567 .ok();
1568 }
1569
1570 fn send_retry(&self, status: acp_thread::RetryStatus) {
1571 self.0
1572 .unbounded_send(Ok(AgentResponseEvent::Retry(status)))
1573 .ok();
1574 }
1575
1576 fn send_stop(&self, reason: StopReason) {
1577 match reason {
1578 StopReason::EndTurn => {
1579 self.0
1580 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
1581 .ok();
1582 }
1583 StopReason::MaxTokens => {
1584 self.0
1585 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
1586 .ok();
1587 }
1588 StopReason::Refusal => {
1589 self.0
1590 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
1591 .ok();
1592 }
1593 StopReason::ToolUse => {}
1594 }
1595 }
1596
1597 fn send_canceled(&self) {
1598 self.0
1599 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
1600 .ok();
1601 }
1602
1603 fn send_error(&self, error: impl Into<anyhow::Error>) {
1604 self.0.unbounded_send(Err(error.into())).ok();
1605 }
1606}
1607
1608#[derive(Clone)]
1609pub struct ToolCallEventStream {
1610 tool_use_id: LanguageModelToolUseId,
1611 stream: AgentResponseEventStream,
1612 fs: Option<Arc<dyn Fs>>,
1613}
1614
1615impl ToolCallEventStream {
1616 #[cfg(test)]
1617 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1618 let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
1619
1620 let stream =
1621 ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None);
1622
1623 (stream, ToolCallEventStreamReceiver(events_rx))
1624 }
1625
1626 fn new(
1627 tool_use_id: LanguageModelToolUseId,
1628 stream: AgentResponseEventStream,
1629 fs: Option<Arc<dyn Fs>>,
1630 ) -> Self {
1631 Self {
1632 tool_use_id,
1633 stream,
1634 fs,
1635 }
1636 }
1637
1638 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1639 self.stream
1640 .update_tool_call_fields(&self.tool_use_id, fields);
1641 }
1642
1643 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1644 self.stream
1645 .0
1646 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1647 acp_thread::ToolCallUpdateDiff {
1648 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1649 diff,
1650 }
1651 .into(),
1652 )))
1653 .ok();
1654 }
1655
1656 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1657 self.stream
1658 .0
1659 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1660 acp_thread::ToolCallUpdateTerminal {
1661 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1662 terminal,
1663 }
1664 .into(),
1665 )))
1666 .ok();
1667 }
1668
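    /// Requests the user's permission to run the tool by emitting a `ToolCallAuthorization`
    /// event with "Always Allow", "Allow", and "Deny" options. The task resolves to `Ok(())`
    /// when the action is allowed (persisting the setting if "Always Allow" was chosen) and to
    /// an error when it is denied. If `always_allow_tool_actions` is already enabled in the
    /// settings, it resolves immediately without prompting.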
1669 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1670 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1671 return Task::ready(Ok(()));
1672 }
1673
1674 let (response_tx, response_rx) = oneshot::channel();
1675 self.stream
1676 .0
1677 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1678 ToolCallAuthorization {
1679 tool_call: acp::ToolCallUpdate {
1680 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1681 fields: acp::ToolCallUpdateFields {
1682 title: Some(title.into()),
1683 ..Default::default()
1684 },
1685 },
1686 options: vec![
1687 acp::PermissionOption {
1688 id: acp::PermissionOptionId("always_allow".into()),
1689 name: "Always Allow".into(),
1690 kind: acp::PermissionOptionKind::AllowAlways,
1691 },
1692 acp::PermissionOption {
1693 id: acp::PermissionOptionId("allow".into()),
1694 name: "Allow".into(),
1695 kind: acp::PermissionOptionKind::AllowOnce,
1696 },
1697 acp::PermissionOption {
1698 id: acp::PermissionOptionId("deny".into()),
1699 name: "Deny".into(),
1700 kind: acp::PermissionOptionKind::RejectOnce,
1701 },
1702 ],
1703 response: response_tx,
1704 },
1705 )))
1706 .ok();
1707 let fs = self.fs.clone();
1708 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1709 "always_allow" => {
1710 if let Some(fs) = fs.clone() {
1711 cx.update(|cx| {
1712 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1713 settings.set_always_allow_tool_actions(true);
1714 });
1715 })?;
1716 }
1717
1718 Ok(())
1719 }
1720 "allow" => Ok(()),
1721 _ => Err(anyhow!("Permission to run tool denied by user")),
1722 })
1723 }
1724}
1725
1726#[cfg(test)]
1727pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
1728
1729#[cfg(test)]
1730impl ToolCallEventStreamReceiver {
1731 pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
1732 let event = self.0.next().await;
1733 if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
1734 auth
1735 } else {
1736 panic!("Expected ToolCallAuthorization but got: {:?}", event);
1737 }
1738 }
1739
1740 pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
1741 let event = self.0.next().await;
1742 if let Some(Ok(AgentResponseEvent::ToolCallUpdate(
1743 acp_thread::ToolCallUpdate::UpdateTerminal(update),
1744 ))) = event
1745 {
1746 update.terminal
1747 } else {
1748 panic!("Expected terminal but got: {:?}", event);
1749 }
1750 }
1751}
1752
1753#[cfg(test)]
1754impl std::ops::Deref for ToolCallEventStreamReceiver {
1755 type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
1756
1757 fn deref(&self) -> &Self::Target {
1758 &self.0
1759 }
1760}
1761
1762#[cfg(test)]
1763impl std::ops::DerefMut for ToolCallEventStreamReceiver {
1764 fn deref_mut(&mut self) -> &mut Self::Target {
1765 &mut self.0
1766 }
1767}
1768
1769impl From<&str> for UserMessageContent {
1770 fn from(text: &str) -> Self {
1771 Self::Text(text.into())
1772 }
1773}
1774
1775impl From<acp::ContentBlock> for UserMessageContent {
1776 fn from(value: acp::ContentBlock) -> Self {
1777 match value {
1778 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1779 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1780 acp::ContentBlock::Audio(_) => {
1781 // TODO
1782 Self::Text("[audio]".to_string())
1783 }
1784 acp::ContentBlock::ResourceLink(resource_link) => {
1785 match MentionUri::parse(&resource_link.uri) {
1786 Ok(uri) => Self::Mention {
1787 uri,
1788 content: String::new(),
1789 },
1790 Err(err) => {
1791 log::error!("Failed to parse mention link: {}", err);
1792 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1793 }
1794 }
1795 }
1796 acp::ContentBlock::Resource(resource) => match resource.resource {
1797 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1798 match MentionUri::parse(&resource.uri) {
1799 Ok(uri) => Self::Mention {
1800 uri,
1801 content: resource.text,
1802 },
1803 Err(err) => {
1804 log::error!("Failed to parse mention link: {}", err);
1805 Self::Text(
1806 MarkdownCodeBlock {
1807 tag: &resource.uri,
1808 text: &resource.text,
1809 }
1810 .to_string(),
1811 )
1812 }
1813 }
1814 }
1815 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1816 // TODO
1817 Self::Text("[blob]".to_string())
1818 }
1819 },
1820 }
1821 }
1822}
1823
1824fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1825 LanguageModelImage {
1826 source: image_content.data.into(),
1827 // TODO: make this optional?
1828 size: gpui::Size::new(0.into(), 0.into()),
1829 }
1830}