1use crate::{
2 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
3 DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
4 ListDirectoryTool, MovePathTool, NowTool, OpenTool, ProjectSnapshot, ReadFileTool,
5 SystemPromptTemplate, Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
6};
7use acp_thread::{MentionUri, UserMessageId};
8use action_log::ActionLog;
9
10use agent_client_protocol as acp;
11use agent_settings::{
12 AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
13 SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
14};
15use anyhow::{Context as _, Result, anyhow};
16use chrono::{DateTime, Utc};
17use client::{ModelRequestUsage, RequestUsage, UserStore};
18use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
19use collections::{HashMap, HashSet, IndexMap};
20use fs::Fs;
21use futures::stream;
22use futures::{
23 FutureExt,
24 channel::{mpsc, oneshot},
25 future::Shared,
26 stream::FuturesUnordered,
27};
28use gpui::{
29 App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
30};
31use language_model::{
32 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
33 LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
34 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
35 LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
36 LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
37};
38use project::Project;
39use prompt_store::ProjectContext;
40use schemars::{JsonSchema, Schema};
41use serde::{Deserialize, Serialize};
42use settings::{Settings, update_settings_file};
43use smol::stream::StreamExt;
44use std::{
45 collections::BTreeMap,
46 ops::RangeInclusive,
47 path::Path,
48 rc::Rc,
49 sync::Arc,
50 time::{Duration, Instant},
51};
52use std::{fmt::Write, path::PathBuf};
53use util::{ResultExt, debug_panic, markdown::MarkdownCodeBlock};
54use uuid::Uuid;
55
56const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
57pub const MAX_TOOL_NAME_LENGTH: usize = 64;
58
59/// The ID of the user prompt that initiated a request.
60///
61/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
62#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
63pub struct PromptId(Arc<str>);
64
65impl PromptId {
66 pub fn new() -> Self {
67 Self(Uuid::new_v4().to_string().into())
68 }
69}
70
71impl std::fmt::Display for PromptId {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 write!(f, "{}", self.0)
74 }
75}
76
77pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
78pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
79
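/// How a failed completion request is retried.
///
/// With `ExponentialBackoff` starting at `BASE_RETRY_DELAY` (5s), successive
/// attempts wait 5s, 10s, 20s, 40s; `Fixed` waits the same `delay` before
/// every attempt.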
80#[derive(Debug, Clone)]
81enum RetryStrategy {
82 ExponentialBackoff {
83 initial_delay: Duration,
84 max_attempts: u8,
85 },
86 Fixed {
87 delay: Duration,
88 max_attempts: u8,
89 },
90}
91
92#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
93pub enum Message {
94 User(UserMessage),
95 Agent(AgentMessage),
96 Resume,
97}
98
99impl Message {
100 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
101 match self {
102 Message::Agent(agent_message) => Some(agent_message),
103 _ => None,
104 }
105 }
106
107 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
108 match self {
109 Message::User(message) => vec![message.to_request()],
110 Message::Agent(message) => message.to_request(),
111 Message::Resume => vec![LanguageModelRequestMessage {
112 role: Role::User,
113 content: vec!["Continue where you left off".into()],
114 cache: false,
115 }],
116 }
117 }
118
119 pub fn to_markdown(&self) -> String {
120 match self {
121 Message::User(message) => message.to_markdown(),
122 Message::Agent(message) => message.to_markdown(),
123 Message::Resume => "[resume]\n".into(),
124 }
125 }
126
127 pub fn role(&self) -> Role {
128 match self {
129 Message::User(_) | Message::Resume => Role::User,
130 Message::Agent(_) => Role::Assistant,
131 }
132 }
133}
134
135#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
136pub struct UserMessage {
137 pub id: UserMessageId,
138 pub content: Vec<UserMessageContent>,
139}
140
141#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
142pub enum UserMessageContent {
143 Text(String),
144 Mention { uri: MentionUri, content: String },
145 Image(LanguageModelImage),
146}
147
148impl UserMessage {
149 pub fn to_markdown(&self) -> String {
150 let mut markdown = String::from("## User\n\n");
151
152 for content in &self.content {
153 match content {
154 UserMessageContent::Text(text) => {
155 markdown.push_str(text);
156 markdown.push('\n');
157 }
158 UserMessageContent::Image(_) => {
159 markdown.push_str("<image />\n");
160 }
161 UserMessageContent::Mention { uri, content } => {
162 if !content.is_empty() {
163 let _ = writeln!(&mut markdown, "{}\n\n{}", uri.as_link(), content);
164 } else {
165 let _ = writeln!(&mut markdown, "{}", uri.as_link());
166 }
167 }
168 }
169 }
170
171 markdown
172 }
173
174 fn to_request(&self) -> LanguageModelRequestMessage {
175 let mut message = LanguageModelRequestMessage {
176 role: Role::User,
177 content: Vec::with_capacity(self.content.len()),
178 cache: false,
179 };
180
181 const OPEN_CONTEXT: &str = "<context>\n\
182 The following items were attached by the user. \
183 They are up-to-date and don't need to be re-read.\n\n";
184
185 const OPEN_FILES_TAG: &str = "<files>";
186 const OPEN_DIRECTORIES_TAG: &str = "<directories>";
187 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
188 const OPEN_SELECTIONS_TAG: &str = "<selections>";
189 const OPEN_THREADS_TAG: &str = "<threads>";
190 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
191 const OPEN_RULES_TAG: &str =
192 "<rules>\nThe user has specified the following rules that should be applied:\n";
193
194 let mut file_context = OPEN_FILES_TAG.to_string();
195 let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
196 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
197 let mut selection_context = OPEN_SELECTIONS_TAG.to_string();
198 let mut thread_context = OPEN_THREADS_TAG.to_string();
199 let mut fetch_context = OPEN_FETCH_TAG.to_string();
200 let mut rules_context = OPEN_RULES_TAG.to_string();
201
202 for chunk in &self.content {
203 let chunk = match chunk {
204 UserMessageContent::Text(text) => {
205 language_model::MessageContent::Text(text.clone())
206 }
207 UserMessageContent::Image(value) => {
208 language_model::MessageContent::Image(value.clone())
209 }
210 UserMessageContent::Mention { uri, content } => {
211 match uri {
212 MentionUri::File { abs_path } => {
213 write!(
214 &mut file_context,
215 "\n{}",
216 MarkdownCodeBlock {
217 tag: &codeblock_tag(abs_path, None),
218 text: &content.to_string(),
219 }
220 )
221 .ok();
222 }
223 MentionUri::PastedImage => {
224 debug_panic!("pasted image URI should not be used in mention content")
225 }
226 MentionUri::Directory { .. } => {
227 write!(&mut directory_context, "\n{}\n", content).ok();
228 }
229 MentionUri::Symbol {
230 abs_path: path,
231 line_range,
232 ..
233 } => {
234 write!(
235 &mut symbol_context,
236 "\n{}",
237 MarkdownCodeBlock {
238 tag: &codeblock_tag(path, Some(line_range)),
239 text: content
240 }
241 )
242 .ok();
243 }
244 MentionUri::Selection {
245 abs_path: path,
246 line_range,
247 ..
248 } => {
249 write!(
250 &mut selection_context,
251 "\n{}",
252 MarkdownCodeBlock {
253 tag: &codeblock_tag(
254 path.as_deref().unwrap_or("Untitled".as_ref()),
255 Some(line_range)
256 ),
257 text: content
258 }
259 )
260 .ok();
261 }
262 MentionUri::Thread { .. } => {
263 write!(&mut thread_context, "\n{}\n", content).ok();
264 }
265 MentionUri::TextThread { .. } => {
266 write!(&mut thread_context, "\n{}\n", content).ok();
267 }
268 MentionUri::Rule { .. } => {
269 write!(
270 &mut rules_context,
271 "\n{}",
272 MarkdownCodeBlock {
273 tag: "",
274 text: content
275 }
276 )
277 .ok();
278 }
279 MentionUri::Fetch { url } => {
280 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
281 }
282 }
283
284 language_model::MessageContent::Text(uri.as_link().to_string())
285 }
286 };
287
288 message.content.push(chunk);
289 }
290
291 let len_before_context = message.content.len();
292
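        // Each context section is appended only if something was written into it,
        // and all populated sections are then wrapped in a single <context>...</context> block.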
293 if file_context.len() > OPEN_FILES_TAG.len() {
294 file_context.push_str("</files>\n");
295 message
296 .content
297 .push(language_model::MessageContent::Text(file_context));
298 }
299
300 if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
301 directory_context.push_str("</directories>\n");
302 message
303 .content
304 .push(language_model::MessageContent::Text(directory_context));
305 }
306
307 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
308 symbol_context.push_str("</symbols>\n");
309 message
310 .content
311 .push(language_model::MessageContent::Text(symbol_context));
312 }
313
314 if selection_context.len() > OPEN_SELECTIONS_TAG.len() {
315 selection_context.push_str("</selections>\n");
316 message
317 .content
318 .push(language_model::MessageContent::Text(selection_context));
319 }
320
321 if thread_context.len() > OPEN_THREADS_TAG.len() {
322 thread_context.push_str("</threads>\n");
323 message
324 .content
325 .push(language_model::MessageContent::Text(thread_context));
326 }
327
328 if fetch_context.len() > OPEN_FETCH_TAG.len() {
329 fetch_context.push_str("</fetched_urls>\n");
330 message
331 .content
332 .push(language_model::MessageContent::Text(fetch_context));
333 }
334
335 if rules_context.len() > OPEN_RULES_TAG.len() {
            rules_context.push_str("</rules>\n");
337 message
338 .content
339 .push(language_model::MessageContent::Text(rules_context));
340 }
341
342 if message.content.len() > len_before_context {
343 message.content.insert(
344 len_before_context,
345 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
346 );
347 message
348 .content
349 .push(language_model::MessageContent::Text("</context>".into()));
350 }
351
352 message
353 }
354}
355
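/// Builds the tag for a fenced code block, e.g. `rs /path/to/file.rs:10-12`:
/// the file extension, the full path, and an optional 1-based line or line range.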
356fn codeblock_tag(full_path: &Path, line_range: Option<&RangeInclusive<u32>>) -> String {
357 let mut result = String::new();
358
359 if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
360 let _ = write!(result, "{} ", extension);
361 }
362
363 let _ = write!(result, "{}", full_path.display());
364
365 if let Some(range) = line_range {
366 if range.start() == range.end() {
367 let _ = write!(result, ":{}", range.start() + 1);
368 } else {
369 let _ = write!(result, ":{}-{}", range.start() + 1, range.end() + 1);
370 }
371 }
372
373 result
374}
375
376impl AgentMessage {
377 pub fn to_markdown(&self) -> String {
378 let mut markdown = String::from("## Assistant\n\n");
379
380 for content in &self.content {
381 match content {
382 AgentMessageContent::Text(text) => {
383 markdown.push_str(text);
384 markdown.push('\n');
385 }
386 AgentMessageContent::Thinking { text, .. } => {
387 markdown.push_str("<think>");
388 markdown.push_str(text);
389 markdown.push_str("</think>\n");
390 }
391 AgentMessageContent::RedactedThinking(_) => {
392 markdown.push_str("<redacted_thinking />\n")
393 }
394 AgentMessageContent::ToolUse(tool_use) => {
395 markdown.push_str(&format!(
396 "**Tool Use**: {} (ID: {})\n",
397 tool_use.name, tool_use.id
398 ));
399 markdown.push_str(&format!(
400 "{}\n",
401 MarkdownCodeBlock {
402 tag: "json",
403 text: &format!("{:#}", tool_use.input)
404 }
405 ));
406 }
407 }
408 }
409
410 for tool_result in self.tool_results.values() {
411 markdown.push_str(&format!(
412 "**Tool Result**: {} (ID: {})\n\n",
413 tool_result.tool_name, tool_result.tool_use_id
414 ));
415 if tool_result.is_error {
416 markdown.push_str("**ERROR:**\n");
417 }
418
419 match &tool_result.content {
420 LanguageModelToolResultContent::Text(text) => {
421 writeln!(markdown, "{text}\n").ok();
422 }
423 LanguageModelToolResultContent::Image(_) => {
424 writeln!(markdown, "<image />\n").ok();
425 }
426 }
427
428 if let Some(output) = tool_result.output.as_ref() {
429 writeln!(
430 markdown,
431 "**Debug Output**:\n\n```json\n{}\n```\n",
432 serde_json::to_string_pretty(output).unwrap()
433 )
434 .unwrap();
435 }
436 }
437
438 markdown
439 }
440
441 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
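        // An agent message expands into at most two request messages: an Assistant
        // message carrying text, thinking, and tool uses, followed by a User message
        // carrying the corresponding tool results.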
442 let mut assistant_message = LanguageModelRequestMessage {
443 role: Role::Assistant,
444 content: Vec::with_capacity(self.content.len()),
445 cache: false,
446 };
447 for chunk in &self.content {
448 match chunk {
449 AgentMessageContent::Text(text) => {
450 assistant_message
451 .content
452 .push(language_model::MessageContent::Text(text.clone()));
453 }
454 AgentMessageContent::Thinking { text, signature } => {
455 assistant_message
456 .content
457 .push(language_model::MessageContent::Thinking {
458 text: text.clone(),
459 signature: signature.clone(),
460 });
461 }
462 AgentMessageContent::RedactedThinking(value) => {
463 assistant_message.content.push(
464 language_model::MessageContent::RedactedThinking(value.clone()),
465 );
466 }
467 AgentMessageContent::ToolUse(tool_use) => {
468 if self.tool_results.contains_key(&tool_use.id) {
469 assistant_message
470 .content
471 .push(language_model::MessageContent::ToolUse(tool_use.clone()));
472 }
473 }
474 };
475 }
476
477 let mut user_message = LanguageModelRequestMessage {
478 role: Role::User,
479 content: Vec::new(),
480 cache: false,
481 };
482
483 for tool_result in self.tool_results.values() {
484 let mut tool_result = tool_result.clone();
485 // Surprisingly, the API fails if we return an empty string here.
486 // It thinks we are sending a tool use without a tool result.
487 if tool_result.content.is_empty() {
488 tool_result.content = "<Tool returned an empty string>".into();
489 }
490 user_message
491 .content
492 .push(language_model::MessageContent::ToolResult(tool_result));
493 }
494
495 let mut messages = Vec::new();
496 if !assistant_message.content.is_empty() {
497 messages.push(assistant_message);
498 }
499 if !user_message.content.is_empty() {
500 messages.push(user_message);
501 }
502 messages
503 }
504}
505
506#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
507pub struct AgentMessage {
508 pub content: Vec<AgentMessageContent>,
509 pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
510}
511
512#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
513pub enum AgentMessageContent {
514 Text(String),
515 Thinking {
516 text: String,
517 signature: Option<String>,
518 },
519 RedactedThinking(String),
520 ToolUse(LanguageModelToolUse),
521}
522
523pub trait TerminalHandle {
524 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId>;
525 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse>;
526 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>>;
527}
528
529pub trait ThreadEnvironment {
530 fn create_terminal(
531 &self,
532 command: String,
533 cwd: Option<PathBuf>,
534 output_byte_limit: Option<u64>,
535 cx: &mut AsyncApp,
536 ) -> Task<Result<Rc<dyn TerminalHandle>>>;
537}
538
539#[derive(Debug)]
540pub enum ThreadEvent {
541 UserMessage(UserMessage),
542 AgentText(String),
543 AgentThinking(String),
544 ToolCall(acp::ToolCall),
545 ToolCallUpdate(acp_thread::ToolCallUpdate),
546 ToolCallAuthorization(ToolCallAuthorization),
547 Retry(acp_thread::RetryStatus),
548 Stop(acp::StopReason),
549}
550
551#[derive(Debug)]
552pub struct NewTerminal {
553 pub command: String,
554 pub output_byte_limit: Option<u64>,
555 pub cwd: Option<PathBuf>,
556 pub response: oneshot::Sender<Result<Entity<acp_thread::Terminal>>>,
557}
558
559#[derive(Debug)]
560pub struct ToolCallAuthorization {
561 pub tool_call: acp::ToolCallUpdate,
562 pub options: Vec<acp::PermissionOption>,
563 pub response: oneshot::Sender<acp::PermissionOptionId>,
564}
565
566#[derive(Debug, thiserror::Error)]
567enum CompletionError {
568 #[error("max tokens")]
569 MaxTokens,
570 #[error("refusal")]
571 Refusal,
572 #[error(transparent)]
573 Other(#[from] anyhow::Error),
574}
575
576pub struct Thread {
577 id: acp::SessionId,
578 prompt_id: PromptId,
579 updated_at: DateTime<Utc>,
580 title: Option<SharedString>,
581 pending_title_generation: Option<Task<()>>,
582 pending_summary_generation: Option<Shared<Task<Option<SharedString>>>>,
583 summary: Option<SharedString>,
584 messages: Vec<Message>,
585 user_store: Entity<UserStore>,
586 completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results back to it.
590 running_turn: Option<RunningTurn>,
591 pending_message: Option<AgentMessage>,
592 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
593 tool_use_limit_reached: bool,
594 request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
595 #[allow(unused)]
596 cumulative_token_usage: TokenUsage,
597 #[allow(unused)]
598 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
599 context_server_registry: Entity<ContextServerRegistry>,
600 profile_id: AgentProfileId,
601 project_context: Entity<ProjectContext>,
602 templates: Arc<Templates>,
603 model: Option<Arc<dyn LanguageModel>>,
604 summarization_model: Option<Arc<dyn LanguageModel>>,
605 prompt_capabilities_tx: watch::Sender<acp::PromptCapabilities>,
606 pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
607 pub(crate) project: Entity<Project>,
608 pub(crate) action_log: Entity<ActionLog>,
609}
610
611impl Thread {
612 fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities {
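        // Without a selected model, optimistically advertise image support; the
        // capabilities are re-sent whenever the model changes.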
613 let image = model.map_or(true, |model| model.supports_images());
614 acp::PromptCapabilities {
615 meta: None,
616 image,
617 audio: false,
618 embedded_context: true,
619 }
620 }
621
622 pub fn new(
623 project: Entity<Project>,
624 project_context: Entity<ProjectContext>,
625 context_server_registry: Entity<ContextServerRegistry>,
626 templates: Arc<Templates>,
627 model: Option<Arc<dyn LanguageModel>>,
628 cx: &mut Context<Self>,
629 ) -> Self {
630 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
631 let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
632 let (prompt_capabilities_tx, prompt_capabilities_rx) =
633 watch::channel(Self::prompt_capabilities(model.as_deref()));
634 Self {
635 id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
636 prompt_id: PromptId::new(),
637 updated_at: Utc::now(),
638 title: None,
639 pending_title_generation: None,
640 pending_summary_generation: None,
641 summary: None,
642 messages: Vec::new(),
643 user_store: project.read(cx).user_store(),
644 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
645 running_turn: None,
646 pending_message: None,
647 tools: BTreeMap::default(),
648 tool_use_limit_reached: false,
649 request_token_usage: HashMap::default(),
650 cumulative_token_usage: TokenUsage::default(),
651 initial_project_snapshot: {
652 let project_snapshot = Self::project_snapshot(project.clone(), cx);
653 cx.foreground_executor()
654 .spawn(async move { Some(project_snapshot.await) })
655 .shared()
656 },
657 context_server_registry,
658 profile_id,
659 project_context,
660 templates,
661 model,
662 summarization_model: None,
663 prompt_capabilities_tx,
664 prompt_capabilities_rx,
665 project,
666 action_log,
667 }
668 }
669
670 pub fn id(&self) -> &acp::SessionId {
671 &self.id
672 }
673
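    /// Replays the stored conversation as a stream of `ThreadEvent`s so a client
    /// can rebuild its view of the thread (e.g. when resuming a saved session).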
674 pub fn replay(
675 &mut self,
676 cx: &mut Context<Self>,
677 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
678 let (tx, rx) = mpsc::unbounded();
679 let stream = ThreadEventStream(tx);
680 for message in &self.messages {
681 match message {
682 Message::User(user_message) => stream.send_user_message(user_message),
683 Message::Agent(assistant_message) => {
684 for content in &assistant_message.content {
685 match content {
686 AgentMessageContent::Text(text) => stream.send_text(text),
687 AgentMessageContent::Thinking { text, .. } => {
688 stream.send_thinking(text)
689 }
690 AgentMessageContent::RedactedThinking(_) => {}
691 AgentMessageContent::ToolUse(tool_use) => {
692 self.replay_tool_call(
693 tool_use,
694 assistant_message.tool_results.get(&tool_use.id),
695 &stream,
696 cx,
697 );
698 }
699 }
700 }
701 }
702 Message::Resume => {}
703 }
704 }
705 rx
706 }
707
708 fn replay_tool_call(
709 &self,
710 tool_use: &LanguageModelToolUse,
711 tool_result: Option<&LanguageModelToolResult>,
712 stream: &ThreadEventStream,
713 cx: &mut Context<Self>,
714 ) {
715 let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
716 self.context_server_registry
717 .read(cx)
718 .servers()
719 .find_map(|(_, tools)| {
720 if let Some(tool) = tools.get(tool_use.name.as_ref()) {
721 Some(tool.clone())
722 } else {
723 None
724 }
725 })
726 });
727
728 let Some(tool) = tool else {
729 stream
730 .0
731 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
732 meta: None,
733 id: acp::ToolCallId(tool_use.id.to_string().into()),
734 title: tool_use.name.to_string(),
735 kind: acp::ToolKind::Other,
736 status: acp::ToolCallStatus::Failed,
737 content: Vec::new(),
738 locations: Vec::new(),
739 raw_input: Some(tool_use.input.clone()),
740 raw_output: None,
741 })))
742 .ok();
743 return;
744 };
745
746 let title = tool.initial_title(tool_use.input.clone(), cx);
747 let kind = tool.kind();
748 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
749
750 let output = tool_result
751 .as_ref()
752 .and_then(|result| result.output.clone());
753 if let Some(output) = output.clone() {
754 let tool_event_stream = ToolCallEventStream::new(
755 tool_use.id.clone(),
756 stream.clone(),
757 Some(self.project.read(cx).fs().clone()),
758 );
759 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
760 .log_err();
761 }
762
763 stream.update_tool_call_fields(
764 &tool_use.id,
765 acp::ToolCallUpdateFields {
766 status: Some(
767 tool_result
768 .as_ref()
769 .map_or(acp::ToolCallStatus::Failed, |result| {
770 if result.is_error {
771 acp::ToolCallStatus::Failed
772 } else {
773 acp::ToolCallStatus::Completed
774 }
775 }),
776 ),
777 raw_output: output,
778 ..Default::default()
779 },
780 );
781 }
782
783 pub fn from_db(
784 id: acp::SessionId,
785 db_thread: DbThread,
786 project: Entity<Project>,
787 project_context: Entity<ProjectContext>,
788 context_server_registry: Entity<ContextServerRegistry>,
789 templates: Arc<Templates>,
790 cx: &mut Context<Self>,
791 ) -> Self {
792 let profile_id = db_thread
793 .profile
794 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
795 let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
796 db_thread
797 .model
798 .and_then(|model| {
799 let model = SelectedModel {
800 provider: model.provider.clone().into(),
801 model: model.model.into(),
802 };
803 registry.select_model(&model, cx)
804 })
805 .or_else(|| registry.default_model())
806 .map(|model| model.model)
807 });
808 let (prompt_capabilities_tx, prompt_capabilities_rx) =
809 watch::channel(Self::prompt_capabilities(model.as_deref()));
810
811 let action_log = cx.new(|_| ActionLog::new(project.clone()));
812
813 Self {
814 id,
815 prompt_id: PromptId::new(),
816 title: if db_thread.title.is_empty() {
817 None
818 } else {
819 Some(db_thread.title.clone())
820 },
821 pending_title_generation: None,
822 pending_summary_generation: None,
823 summary: db_thread.detailed_summary,
824 messages: db_thread.messages,
825 user_store: project.read(cx).user_store(),
826 completion_mode: db_thread.completion_mode.unwrap_or_default(),
827 running_turn: None,
828 pending_message: None,
829 tools: BTreeMap::default(),
830 tool_use_limit_reached: false,
831 request_token_usage: db_thread.request_token_usage.clone(),
832 cumulative_token_usage: db_thread.cumulative_token_usage,
833 initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
834 context_server_registry,
835 profile_id,
836 project_context,
837 templates,
838 model,
839 summarization_model: None,
840 project,
841 action_log,
842 updated_at: db_thread.updated_at,
843 prompt_capabilities_tx,
844 prompt_capabilities_rx,
845 }
846 }
847
848 pub fn to_db(&self, cx: &App) -> Task<DbThread> {
849 let initial_project_snapshot = self.initial_project_snapshot.clone();
850 let mut thread = DbThread {
851 title: self.title(),
852 messages: self.messages.clone(),
853 updated_at: self.updated_at,
854 detailed_summary: self.summary.clone(),
855 initial_project_snapshot: None,
856 cumulative_token_usage: self.cumulative_token_usage,
857 request_token_usage: self.request_token_usage.clone(),
858 model: self.model.as_ref().map(|model| DbLanguageModel {
859 provider: model.provider_id().to_string(),
860 model: model.name().0.to_string(),
861 }),
862 completion_mode: Some(self.completion_mode),
863 profile: Some(self.profile_id.clone()),
864 };
865
866 cx.background_spawn(async move {
867 let initial_project_snapshot = initial_project_snapshot.await;
868 thread.initial_project_snapshot = initial_project_snapshot;
869 thread
870 })
871 }
872
873 /// Create a snapshot of the current project state including git information and unsaved buffers.
874 fn project_snapshot(
875 project: Entity<Project>,
876 cx: &mut Context<Self>,
877 ) -> Task<Arc<ProjectSnapshot>> {
878 let task = project::telemetry_snapshot::TelemetrySnapshot::new(&project, cx);
879 cx.spawn(async move |_, _| {
880 let snapshot = task.await;
881
882 Arc::new(ProjectSnapshot {
883 worktree_snapshots: snapshot.worktree_snapshots,
884 timestamp: Utc::now(),
885 })
886 })
887 }
888
889 pub fn project_context(&self) -> &Entity<ProjectContext> {
890 &self.project_context
891 }
892
893 pub fn project(&self) -> &Entity<Project> {
894 &self.project
895 }
896
897 pub fn action_log(&self) -> &Entity<ActionLog> {
898 &self.action_log
899 }
900
901 pub fn is_empty(&self) -> bool {
902 self.messages.is_empty() && self.title.is_none()
903 }
904
905 pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
906 self.model.as_ref()
907 }
908
909 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
910 let old_usage = self.latest_token_usage();
911 self.model = Some(model);
912 let new_caps = Self::prompt_capabilities(self.model.as_deref());
913 let new_usage = self.latest_token_usage();
914 if old_usage != new_usage {
915 cx.emit(TokenUsageUpdated(new_usage));
916 }
917 self.prompt_capabilities_tx.send(new_caps).log_err();
918 cx.notify()
919 }
920
921 pub fn summarization_model(&self) -> Option<&Arc<dyn LanguageModel>> {
922 self.summarization_model.as_ref()
923 }
924
925 pub fn set_summarization_model(
926 &mut self,
927 model: Option<Arc<dyn LanguageModel>>,
928 cx: &mut Context<Self>,
929 ) {
930 self.summarization_model = model;
931 cx.notify()
932 }
933
934 pub fn completion_mode(&self) -> CompletionMode {
935 self.completion_mode
936 }
937
938 pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
939 let old_usage = self.latest_token_usage();
940 self.completion_mode = mode;
941 let new_usage = self.latest_token_usage();
942 if old_usage != new_usage {
943 cx.emit(TokenUsageUpdated(new_usage));
944 }
945 cx.notify()
946 }
947
948 #[cfg(any(test, feature = "test-support"))]
949 pub fn last_message(&self) -> Option<Message> {
950 if let Some(message) = self.pending_message.clone() {
951 Some(Message::Agent(message))
952 } else {
953 self.messages.last().cloned()
954 }
955 }
956
957 pub fn add_default_tools(
958 &mut self,
959 environment: Rc<dyn ThreadEnvironment>,
960 cx: &mut Context<Self>,
961 ) {
962 let language_registry = self.project.read(cx).languages().clone();
963 self.add_tool(CopyPathTool::new(self.project.clone()));
964 self.add_tool(CreateDirectoryTool::new(self.project.clone()));
965 self.add_tool(DeletePathTool::new(
966 self.project.clone(),
967 self.action_log.clone(),
968 ));
969 self.add_tool(DiagnosticsTool::new(self.project.clone()));
970 self.add_tool(EditFileTool::new(
971 self.project.clone(),
972 cx.weak_entity(),
973 language_registry,
974 Templates::new(),
975 ));
976 self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
977 self.add_tool(FindPathTool::new(self.project.clone()));
978 self.add_tool(GrepTool::new(self.project.clone()));
979 self.add_tool(ListDirectoryTool::new(self.project.clone()));
980 self.add_tool(MovePathTool::new(self.project.clone()));
981 self.add_tool(NowTool);
982 self.add_tool(OpenTool::new(self.project.clone()));
983 self.add_tool(ReadFileTool::new(
984 self.project.clone(),
985 self.action_log.clone(),
986 ));
987 self.add_tool(TerminalTool::new(self.project.clone(), environment));
988 self.add_tool(ThinkingTool);
989 self.add_tool(WebSearchTool);
990 }
991
992 pub fn add_tool<T: AgentTool>(&mut self, tool: T) {
993 self.tools.insert(T::name().into(), tool.erase());
994 }
995
996 pub fn remove_tool(&mut self, name: &str) -> bool {
997 self.tools.remove(name).is_some()
998 }
999
1000 pub fn profile(&self) -> &AgentProfileId {
1001 &self.profile_id
1002 }
1003
1004 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
1005 self.profile_id = profile_id;
1006 }
1007
1008 pub fn cancel(&mut self, cx: &mut Context<Self>) {
1009 if let Some(running_turn) = self.running_turn.take() {
1010 running_turn.cancel();
1011 }
1012 self.flush_pending_message(cx);
1013 }
1014
1015 fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
1016 let Some(last_user_message) = self.last_user_message() else {
1017 return;
1018 };
1019
1020 self.request_token_usage
1021 .insert(last_user_message.id.clone(), update);
1022 cx.emit(TokenUsageUpdated(self.latest_token_usage()));
1023 cx.notify();
1024 }
1025
1026 pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
1027 self.cancel(cx);
1028 let Some(position) = self.messages.iter().position(
1029 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
1030 ) else {
1031 return Err(anyhow!("Message not found"));
1032 };
1033
1034 for message in self.messages.drain(position..) {
1035 match message {
1036 Message::User(message) => {
1037 self.request_token_usage.remove(&message.id);
1038 }
1039 Message::Agent(_) | Message::Resume => {}
1040 }
1041 }
1042 self.clear_summary();
1043 cx.notify();
1044 Ok(())
1045 }
1046
1047 pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
1048 let last_user_message = self.last_user_message()?;
1049 let tokens = self.request_token_usage.get(&last_user_message.id)?;
1050 let model = self.model.clone()?;
1051
1052 Some(acp_thread::TokenUsage {
1053 max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
1054 used_tokens: tokens.total_tokens(),
1055 })
1056 }
1057
1058 pub fn resume(
1059 &mut self,
1060 cx: &mut Context<Self>,
1061 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1062 self.messages.push(Message::Resume);
1063 cx.notify();
1064
1065 log::debug!("Total messages in thread: {}", self.messages.len());
1066 self.run_turn(cx)
1067 }
1068
    /// Sending a message results in the model streaming a response, which could include tool calls.
    /// After requesting tool calls, the model stops and waits for the outstanding calls to complete and for their results to be sent back.
    /// The returned channel reports every stop the model makes before erroring or ending its turn.
1072 pub fn send<T>(
1073 &mut self,
1074 id: UserMessageId,
1075 content: impl IntoIterator<Item = T>,
1076 cx: &mut Context<Self>,
1077 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
1078 where
1079 T: Into<UserMessageContent>,
1080 {
1081 let model = self.model().context("No language model configured")?;
1082
1083 log::info!("Thread::send called with model: {}", model.name().0);
1084 self.advance_prompt_id();
1085
1086 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
1087 log::debug!("Thread::send content: {:?}", content);
1088
1089 self.messages
1090 .push(Message::User(UserMessage { id, content }));
1091 cx.notify();
1092
1093 log::debug!("Total messages in thread: {}", self.messages.len());
1094 self.run_turn(cx)
1095 }
1096
1097 fn run_turn(
1098 &mut self,
1099 cx: &mut Context<Self>,
1100 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1101 self.cancel(cx);
1102
1103 let model = self.model.clone().context("No language model configured")?;
1104 let profile = AgentSettings::get_global(cx)
1105 .profiles
1106 .get(&self.profile_id)
1107 .context("Profile not found")?;
1108 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1109 let event_stream = ThreadEventStream(events_tx);
1110 let message_ix = self.messages.len().saturating_sub(1);
1111 self.tool_use_limit_reached = false;
1112 self.clear_summary();
1113 self.running_turn = Some(RunningTurn {
1114 event_stream: event_stream.clone(),
1115 tools: self.enabled_tools(profile, &model, cx),
1116 _task: cx.spawn(async move |this, cx| {
1117 log::debug!("Starting agent turn execution");
1118
1119 let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await;
1120 _ = this.update(cx, |this, cx| this.flush_pending_message(cx));
1121
1122 match turn_result {
1123 Ok(()) => {
1124 log::debug!("Turn execution completed");
1125 event_stream.send_stop(acp::StopReason::EndTurn);
1126 }
1127 Err(error) => {
1128 log::error!("Turn execution failed: {:?}", error);
1129 match error.downcast::<CompletionError>() {
1130 Ok(CompletionError::Refusal) => {
1131 event_stream.send_stop(acp::StopReason::Refusal);
1132 _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
1133 }
1134 Ok(CompletionError::MaxTokens) => {
1135 event_stream.send_stop(acp::StopReason::MaxTokens);
1136 }
1137 Ok(CompletionError::Other(error)) | Err(error) => {
1138 event_stream.send_error(error);
1139 }
1140 }
1141 }
1142 }
1143
1144 _ = this.update(cx, |this, _| this.running_turn.take());
1145 }),
1146 });
1147 Ok(events_rx)
1148 }
1149
1150 async fn run_turn_internal(
1151 this: &WeakEntity<Self>,
1152 model: Arc<dyn LanguageModel>,
1153 event_stream: &ThreadEventStream,
1154 cx: &mut AsyncApp,
1155 ) -> Result<()> {
1156 let mut attempt = 0;
1157 let mut intent = CompletionIntent::UserPrompt;
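        // Each loop iteration issues one completion request. Tool results feed the
        // next request until the model ends its turn, the tool-use limit is reached,
        // or an error exhausts its retries.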
1158 loop {
1159 let request =
1160 this.update(cx, |this, cx| this.build_completion_request(intent, cx))??;
1161
1162 telemetry::event!(
1163 "Agent Thread Completion",
1164 thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
1165 prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
1166 model = model.telemetry_id(),
1167 model_provider = model.provider_id().to_string(),
1168 attempt
1169 );
1170
1171 log::debug!("Calling model.stream_completion, attempt {}", attempt);
1172
1173 let (mut events, mut error) = match model.stream_completion(request, cx).await {
1174 Ok(events) => (events, None),
1175 Err(err) => (stream::empty().boxed(), Some(err)),
1176 };
1177 let mut tool_results = FuturesUnordered::new();
1178 while let Some(event) = events.next().await {
1179 log::trace!("Received completion event: {:?}", event);
1180 match event {
1181 Ok(event) => {
1182 tool_results.extend(this.update(cx, |this, cx| {
1183 this.handle_completion_event(event, event_stream, cx)
1184 })??);
1185 }
1186 Err(err) => {
1187 error = Some(err);
1188 break;
1189 }
1190 }
1191 }
1192
1193 let end_turn = tool_results.is_empty();
1194 while let Some(tool_result) = tool_results.next().await {
1195 log::debug!("Tool finished {:?}", tool_result);
1196
1197 event_stream.update_tool_call_fields(
1198 &tool_result.tool_use_id,
1199 acp::ToolCallUpdateFields {
1200 status: Some(if tool_result.is_error {
1201 acp::ToolCallStatus::Failed
1202 } else {
1203 acp::ToolCallStatus::Completed
1204 }),
1205 raw_output: tool_result.output.clone(),
1206 ..Default::default()
1207 },
1208 );
1209 this.update(cx, |this, _cx| {
1210 this.pending_message()
1211 .tool_results
1212 .insert(tool_result.tool_use_id.clone(), tool_result);
1213 })?;
1214 }
1215
1216 this.update(cx, |this, cx| {
1217 this.flush_pending_message(cx);
1218 if this.title.is_none() && this.pending_title_generation.is_none() {
1219 this.generate_title(cx);
1220 }
1221 })?;
1222
1223 if let Some(error) = error {
1224 attempt += 1;
1225 let retry = this.update(cx, |this, cx| {
1226 let user_store = this.user_store.read(cx);
1227 this.handle_completion_error(error, attempt, user_store.plan())
1228 })??;
1229 let timer = cx.background_executor().timer(retry.duration);
1230 event_stream.send_retry(retry);
1231 timer.await;
1232 this.update(cx, |this, _cx| {
1233 if let Some(Message::Agent(message)) = this.messages.last() {
1234 if message.tool_results.is_empty() {
1235 intent = CompletionIntent::UserPrompt;
1236 this.messages.push(Message::Resume);
1237 }
1238 }
1239 })?;
1240 } else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
1241 return Err(language_model::ToolUseLimitReachedError.into());
1242 } else if end_turn {
1243 return Ok(());
1244 } else {
1245 intent = CompletionIntent::ToolResults;
1246 attempt = 0;
1247 }
1248 }
1249 }
1250
1251 fn handle_completion_error(
1252 &mut self,
1253 error: LanguageModelCompletionError,
1254 attempt: u8,
1255 plan: Option<Plan>,
1256 ) -> Result<acp_thread::RetryStatus> {
1257 let Some(model) = self.model.as_ref() else {
1258 return Err(anyhow!(error));
1259 };
1260
1261 let auto_retry = if model.provider_id() == ZED_CLOUD_PROVIDER_ID {
1262 match plan {
1263 Some(Plan::V2(_)) => true,
1264 Some(Plan::V1(_)) => self.completion_mode == CompletionMode::Burn,
1265 None => false,
1266 }
1267 } else {
1268 true
1269 };
1270
1271 if !auto_retry {
1272 return Err(anyhow!(error));
1273 }
1274
1275 let Some(strategy) = Self::retry_strategy_for(&error) else {
1276 return Err(anyhow!(error));
1277 };
1278
1279 let max_attempts = match &strategy {
1280 RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
1281 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
1282 };
1283
1284 if attempt > max_attempts {
1285 return Err(anyhow!(error));
1286 }
1287
1288 let delay = match &strategy {
1289 RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
1290 let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
1291 Duration::from_secs(delay_secs)
1292 }
1293 RetryStrategy::Fixed { delay, .. } => *delay,
1294 };
1295 log::debug!("Retry attempt {attempt} with delay {delay:?}");
1296
1297 Ok(acp_thread::RetryStatus {
1298 last_error: error.to_string().into(),
1299 attempt: attempt as usize,
1300 max_attempts: max_attempts as usize,
1301 started_at: Instant::now(),
1302 duration: delay,
1303 })
1304 }
1305
    /// A helper method that's called on every streamed completion event.
    /// Returns an optional task producing a tool result, which the main agentic
    /// loop sends back to the model once it resolves.
1309 fn handle_completion_event(
1310 &mut self,
1311 event: LanguageModelCompletionEvent,
1312 event_stream: &ThreadEventStream,
1313 cx: &mut Context<Self>,
1314 ) -> Result<Option<Task<LanguageModelToolResult>>> {
1315 log::trace!("Handling streamed completion event: {:?}", event);
1316 use LanguageModelCompletionEvent::*;
1317
1318 match event {
1319 StartMessage { .. } => {
1320 self.flush_pending_message(cx);
1321 self.pending_message = Some(AgentMessage::default());
1322 }
1323 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1324 Thinking { text, signature } => {
1325 self.handle_thinking_event(text, signature, event_stream, cx)
1326 }
1327 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1328 ToolUse(tool_use) => {
1329 return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
1330 }
1331 ToolUseJsonParseError {
1332 id,
1333 tool_name,
1334 raw_input,
1335 json_parse_error,
1336 } => {
1337 return Ok(Some(Task::ready(
1338 self.handle_tool_use_json_parse_error_event(
1339 id,
1340 tool_name,
1341 raw_input,
1342 json_parse_error,
1343 ),
1344 )));
1345 }
1346 UsageUpdate(usage) => {
1347 telemetry::event!(
1348 "Agent Thread Completion Usage Updated",
1349 thread_id = self.id.to_string(),
1350 prompt_id = self.prompt_id.to_string(),
1351 model = self.model.as_ref().map(|m| m.telemetry_id()),
1352 model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
1353 input_tokens = usage.input_tokens,
1354 output_tokens = usage.output_tokens,
1355 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1356 cache_read_input_tokens = usage.cache_read_input_tokens,
1357 );
1358 self.update_token_usage(usage, cx);
1359 }
1360 StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
1361 self.update_model_request_usage(amount, limit, cx);
1362 }
1363 StatusUpdate(
1364 CompletionRequestStatus::Started
1365 | CompletionRequestStatus::Queued { .. }
1366 | CompletionRequestStatus::Failed { .. },
1367 ) => {}
1368 StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
1369 self.tool_use_limit_reached = true;
1370 }
1371 Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
1372 Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
1373 Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
1374 }
1375
1376 Ok(None)
1377 }
1378
1379 fn handle_text_event(
1380 &mut self,
1381 new_text: String,
1382 event_stream: &ThreadEventStream,
1383 cx: &mut Context<Self>,
1384 ) {
1385 event_stream.send_text(&new_text);
1386
1387 let last_message = self.pending_message();
1388 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1389 text.push_str(&new_text);
1390 } else {
1391 last_message
1392 .content
1393 .push(AgentMessageContent::Text(new_text));
1394 }
1395
1396 cx.notify();
1397 }
1398
1399 fn handle_thinking_event(
1400 &mut self,
1401 new_text: String,
1402 new_signature: Option<String>,
1403 event_stream: &ThreadEventStream,
1404 cx: &mut Context<Self>,
1405 ) {
1406 event_stream.send_thinking(&new_text);
1407
1408 let last_message = self.pending_message();
1409 if let Some(AgentMessageContent::Thinking { text, signature }) =
1410 last_message.content.last_mut()
1411 {
1412 text.push_str(&new_text);
1413 *signature = new_signature.or(signature.take());
1414 } else {
1415 last_message.content.push(AgentMessageContent::Thinking {
1416 text: new_text,
1417 signature: new_signature,
1418 });
1419 }
1420
1421 cx.notify();
1422 }
1423
1424 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1425 let last_message = self.pending_message();
1426 last_message
1427 .content
1428 .push(AgentMessageContent::RedactedThinking(data));
1429 cx.notify();
1430 }
1431
1432 fn handle_tool_use_event(
1433 &mut self,
1434 tool_use: LanguageModelToolUse,
1435 event_stream: &ThreadEventStream,
1436 cx: &mut Context<Self>,
1437 ) -> Option<Task<LanguageModelToolResult>> {
1438 cx.notify();
1439
1440 let tool = self.tool(tool_use.name.as_ref());
1441 let mut title = SharedString::from(&tool_use.name);
1442 let mut kind = acp::ToolKind::Other;
1443 if let Some(tool) = tool.as_ref() {
1444 title = tool.initial_title(tool_use.input.clone(), cx);
1445 kind = tool.kind();
1446 }
1447
1448 // Ensure the last message ends in the current tool use
1449 let last_message = self.pending_message();
1450 let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
1451 if let AgentMessageContent::ToolUse(last_tool_use) = content {
1452 if last_tool_use.id == tool_use.id {
1453 *last_tool_use = tool_use.clone();
1454 false
1455 } else {
1456 true
1457 }
1458 } else {
1459 true
1460 }
1461 });
1462
1463 if push_new_tool_use {
1464 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1465 last_message
1466 .content
1467 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1468 } else {
1469 event_stream.update_tool_call_fields(
1470 &tool_use.id,
1471 acp::ToolCallUpdateFields {
1472 title: Some(title.into()),
1473 kind: Some(kind),
1474 raw_input: Some(tool_use.input.clone()),
1475 ..Default::default()
1476 },
1477 );
1478 }
1479
1480 if !tool_use.is_input_complete {
1481 return None;
1482 }
1483
1484 let Some(tool) = tool else {
1485 let content = format!("No tool named {} exists", tool_use.name);
1486 return Some(Task::ready(LanguageModelToolResult {
1487 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1488 tool_use_id: tool_use.id,
1489 tool_name: tool_use.name,
1490 is_error: true,
1491 output: None,
1492 }));
1493 };
1494
1495 let fs = self.project.read(cx).fs().clone();
1496 let tool_event_stream =
1497 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1498 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1499 status: Some(acp::ToolCallStatus::InProgress),
1500 ..Default::default()
1501 });
1502 let supports_images = self.model().is_some_and(|model| model.supports_images());
1503 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1504 log::debug!("Running tool {}", tool_use.name);
1505 Some(cx.foreground_executor().spawn(async move {
1506 let tool_result = tool_result.await.and_then(|output| {
1507 if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1508 && !supports_images
1509 {
1510 return Err(anyhow!(
1511 "Attempted to read an image, but this model doesn't support it.",
1512 ));
1513 }
1514 Ok(output)
1515 });
1516
1517 match tool_result {
1518 Ok(output) => LanguageModelToolResult {
1519 tool_use_id: tool_use.id,
1520 tool_name: tool_use.name,
1521 is_error: false,
1522 content: output.llm_output,
1523 output: Some(output.raw_output),
1524 },
1525 Err(error) => LanguageModelToolResult {
1526 tool_use_id: tool_use.id,
1527 tool_name: tool_use.name,
1528 is_error: true,
1529 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1530 output: Some(error.to_string().into()),
1531 },
1532 }
1533 }))
1534 }
1535
1536 fn handle_tool_use_json_parse_error_event(
1537 &mut self,
1538 tool_use_id: LanguageModelToolUseId,
1539 tool_name: Arc<str>,
1540 raw_input: Arc<str>,
1541 json_parse_error: String,
1542 ) -> LanguageModelToolResult {
1543 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1544 LanguageModelToolResult {
1545 tool_use_id,
1546 tool_name,
1547 is_error: true,
1548 content: LanguageModelToolResultContent::Text(tool_output.into()),
1549 output: Some(serde_json::Value::String(raw_input.to_string())),
1550 }
1551 }
1552
1553 fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
1554 self.project
1555 .read(cx)
1556 .user_store()
1557 .update(cx, |user_store, cx| {
1558 user_store.update_model_request_usage(
1559 ModelRequestUsage(RequestUsage {
1560 amount: amount as i32,
1561 limit,
1562 }),
1563 cx,
1564 )
1565 });
1566 }
1567
1568 pub fn title(&self) -> SharedString {
1569 self.title.clone().unwrap_or("New Thread".into())
1570 }
1571
1572 pub fn is_generating_summary(&self) -> bool {
1573 self.pending_summary_generation.is_some()
1574 }
1575
1576 pub fn summary(&mut self, cx: &mut Context<Self>) -> Shared<Task<Option<SharedString>>> {
1577 if let Some(summary) = self.summary.as_ref() {
1578 return Task::ready(Some(summary.clone())).shared();
1579 }
1580 if let Some(task) = self.pending_summary_generation.clone() {
1581 return task;
1582 }
1583 let Some(model) = self.summarization_model.clone() else {
1584 log::error!("No summarization model available");
1585 return Task::ready(None).shared();
1586 };
1587 let mut request = LanguageModelRequest {
1588 intent: Some(CompletionIntent::ThreadContextSummarization),
1589 temperature: AgentSettings::temperature_for_model(&model, cx),
1590 ..Default::default()
1591 };
1592
1593 for message in &self.messages {
1594 request.messages.extend(message.to_request());
1595 }
1596
1597 request.messages.push(LanguageModelRequestMessage {
1598 role: Role::User,
1599 content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()],
1600 cache: false,
1601 });
1602
1603 let task = cx
1604 .spawn(async move |this, cx| {
1605 let mut summary = String::new();
1606 let mut messages = model.stream_completion(request, cx).await.log_err()?;
1607 while let Some(event) = messages.next().await {
1608 let event = event.log_err()?;
1609 let text = match event {
1610 LanguageModelCompletionEvent::Text(text) => text,
1611 LanguageModelCompletionEvent::StatusUpdate(
1612 CompletionRequestStatus::UsageUpdated { amount, limit },
1613 ) => {
1614 this.update(cx, |thread, cx| {
1615 thread.update_model_request_usage(amount, limit, cx);
1616 })
1617 .ok()?;
1618 continue;
1619 }
1620 _ => continue,
1621 };
1622
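                    // Keep only the text up to the first newline of each streamed chunk.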
1623 let mut lines = text.lines();
1624 summary.extend(lines.next());
1625 }
1626
1627 log::debug!("Setting summary: {}", summary);
1628 let summary = SharedString::from(summary);
1629
1630 this.update(cx, |this, cx| {
1631 this.summary = Some(summary.clone());
1632 this.pending_summary_generation = None;
1633 cx.notify()
1634 })
1635 .ok()?;
1636
1637 Some(summary)
1638 })
1639 .shared();
1640 self.pending_summary_generation = Some(task.clone());
1641 task
1642 }
1643
1644 fn generate_title(&mut self, cx: &mut Context<Self>) {
1645 let Some(model) = self.summarization_model.clone() else {
1646 return;
1647 };
1648
1649 log::debug!(
1650 "Generating title with model: {:?}",
1651 self.summarization_model.as_ref().map(|model| model.name())
1652 );
1653 let mut request = LanguageModelRequest {
1654 intent: Some(CompletionIntent::ThreadSummarization),
1655 temperature: AgentSettings::temperature_for_model(&model, cx),
1656 ..Default::default()
1657 };
1658
1659 for message in &self.messages {
1660 request.messages.extend(message.to_request());
1661 }
1662
1663 request.messages.push(LanguageModelRequestMessage {
1664 role: Role::User,
1665 content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1666 cache: false,
1667 });
1668 self.pending_title_generation = Some(cx.spawn(async move |this, cx| {
1669 let mut title = String::new();
1670
1671 let generate = async {
1672 let mut messages = model.stream_completion(request, cx).await?;
1673 while let Some(event) = messages.next().await {
1674 let event = event?;
1675 let text = match event {
1676 LanguageModelCompletionEvent::Text(text) => text,
1677 LanguageModelCompletionEvent::StatusUpdate(
1678 CompletionRequestStatus::UsageUpdated { amount, limit },
1679 ) => {
1680 this.update(cx, |thread, cx| {
1681 thread.update_model_request_usage(amount, limit, cx);
1682 })?;
1683 continue;
1684 }
1685 _ => continue,
1686 };
1687
1688 let mut lines = text.lines();
1689 title.extend(lines.next());
1690
1691 // Stop if the LLM generated multiple lines.
1692 if lines.next().is_some() {
1693 break;
1694 }
1695 }
1696 anyhow::Ok(())
1697 };
1698
1699 if generate.await.context("failed to generate title").is_ok() {
1700 _ = this.update(cx, |this, cx| this.set_title(title.into(), cx));
1701 }
1702 _ = this.update(cx, |this, _| this.pending_title_generation = None);
1703 }));
1704 }
1705
1706 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) {
1707 self.pending_title_generation = None;
1708 if Some(&title) != self.title.as_ref() {
1709 self.title = Some(title);
1710 cx.emit(TitleUpdated);
1711 cx.notify();
1712 }
1713 }
1714
1715 fn clear_summary(&mut self) {
1716 self.summary = None;
1717 self.pending_summary_generation = None;
1718 }
1719
1720 fn last_user_message(&self) -> Option<&UserMessage> {
1721 self.messages
1722 .iter()
1723 .rev()
1724 .find_map(|message| match message {
1725 Message::User(user_message) => Some(user_message),
1726 Message::Agent(_) => None,
1727 Message::Resume => None,
1728 })
1729 }
1730
1731 fn pending_message(&mut self) -> &mut AgentMessage {
1732 self.pending_message.get_or_insert_default()
1733 }
1734
1735 fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1736 let Some(mut message) = self.pending_message.take() else {
1737 return;
1738 };
1739
1740 if message.content.is_empty() {
1741 return;
1742 }
1743
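        // Any tool use that never received a result (e.g. because the turn was
        // canceled) gets a synthetic error result so the transcript stays valid
        // for the model.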
1744 for content in &message.content {
1745 let AgentMessageContent::ToolUse(tool_use) = content else {
1746 continue;
1747 };
1748
1749 if !message.tool_results.contains_key(&tool_use.id) {
1750 message.tool_results.insert(
1751 tool_use.id.clone(),
1752 LanguageModelToolResult {
1753 tool_use_id: tool_use.id.clone(),
1754 tool_name: tool_use.name.clone(),
1755 is_error: true,
1756 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1757 output: None,
1758 },
1759 );
1760 }
1761 }
1762
1763 self.messages.push(Message::Agent(message));
1764 self.updated_at = Utc::now();
1765 self.clear_summary();
1766 cx.notify()
1767 }
1768
1769 pub(crate) fn build_completion_request(
1770 &self,
1771 completion_intent: CompletionIntent,
1772 cx: &App,
1773 ) -> Result<LanguageModelRequest> {
1774 let model = self.model().context("No language model configured")?;
1775 let tools = if let Some(turn) = self.running_turn.as_ref() {
1776 turn.tools
1777 .iter()
1778 .filter_map(|(tool_name, tool)| {
1779 log::trace!("Including tool: {}", tool_name);
1780 Some(LanguageModelRequestTool {
1781 name: tool_name.to_string(),
1782 description: tool.description().to_string(),
1783 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1784 })
1785 })
1786 .collect::<Vec<_>>()
1787 } else {
1788 Vec::new()
1789 };
1790
1791 log::debug!("Building completion request");
1792 log::debug!("Completion intent: {:?}", completion_intent);
1793 log::debug!("Completion mode: {:?}", self.completion_mode);
1794
1795 let messages = self.build_request_messages(cx);
1796 log::debug!("Request will include {} messages", messages.len());
1797 log::debug!("Request includes {} tools", tools.len());
1798
1799 let request = LanguageModelRequest {
1800 thread_id: Some(self.id.to_string()),
1801 prompt_id: Some(self.prompt_id.to_string()),
1802 intent: Some(completion_intent),
1803 mode: Some(self.completion_mode.into()),
1804 messages,
1805 tools,
1806 tool_choice: None,
1807 stop: Vec::new(),
1808 temperature: AgentSettings::temperature_for_model(model, cx),
1809 thinking_allowed: true,
1810 };
1811
1812 log::debug!("Completion request built successfully");
1813 Ok(request)
1814 }
1815
1816 fn enabled_tools(
1817 &self,
1818 profile: &AgentProfileSettings,
1819 model: &Arc<dyn LanguageModel>,
1820 cx: &App,
1821 ) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
1822 fn truncate(tool_name: &SharedString) -> SharedString {
1823 if tool_name.len() > MAX_TOOL_NAME_LENGTH {
1824 let mut truncated = tool_name.to_string();
1825 truncated.truncate(MAX_TOOL_NAME_LENGTH);
1826 truncated.into()
1827 } else {
1828 tool_name.clone()
1829 }
1830 }
1831
1832 let mut tools = self
1833 .tools
1834 .iter()
1835 .filter_map(|(tool_name, tool)| {
1836 if tool.supported_provider(&model.provider_id())
1837 && profile.is_tool_enabled(tool_name)
1838 {
1839 Some((truncate(tool_name), tool.clone()))
1840 } else {
1841 None
1842 }
1843 })
1844 .collect::<BTreeMap<_, _>>();
1845
1846 let mut context_server_tools = Vec::new();
1847 let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
1848 let mut duplicate_tool_names = HashSet::default();
1849 for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
1850 for (tool_name, tool) in server_tools {
1851 if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
1852 let tool_name = truncate(tool_name);
1853 if !seen_tools.insert(tool_name.clone()) {
1854 duplicate_tool_names.insert(tool_name.clone());
1855 }
1856 context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
1857 }
1858 }
1859 }
1860
1861 // When there are duplicate tool names, disambiguate by prefixing them
1862 // with the server ID. In the rare case there isn't enough space for the
1863 // disambiguated tool name, keep only the last tool with this name.
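        // For example, two servers that both expose a `search` tool might be surfaced
        // as `serverA_search` and `serverB_search`.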
1864 for (server_id, tool_name, tool) in context_server_tools {
1865 if duplicate_tool_names.contains(&tool_name) {
1866 let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
1867 if available >= 2 {
1868 let mut disambiguated = server_id.0.to_string();
1869 disambiguated.truncate(available - 1);
1870 disambiguated.push('_');
1871 disambiguated.push_str(&tool_name);
1872 tools.insert(disambiguated.into(), tool.clone());
1873 } else {
1874 tools.insert(tool_name, tool.clone());
1875 }
1876 } else {
1877 tools.insert(tool_name, tool.clone());
1878 }
1879 }
1880
1881 tools
1882 }
1883
1884 fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
1885 self.running_turn.as_ref()?.tools.get(name).cloned()
1886 }
1887
1888 fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1889 log::trace!(
1890 "Building request messages from {} thread messages",
1891 self.messages.len()
1892 );
1893
1894 let system_prompt = SystemPromptTemplate {
1895 project: self.project_context.read(cx),
1896 available_tools: self.tools.keys().cloned().collect(),
1897 }
1898 .render(&self.templates)
1899 .context("failed to build system prompt")
1900 .expect("Invalid template");
1901 let mut messages = vec![LanguageModelRequestMessage {
1902 role: Role::System,
1903 content: vec![system_prompt.into()],
1904 cache: false,
1905 }];
1906 for message in &self.messages {
1907 messages.extend(message.to_request());
1908 }
1909
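        // Mark the last message of the committed history as a cache point so providers
        // that support prompt caching can reuse the shared prefix; the pending message,
        // if any, is appended after it.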
1910 if let Some(last_message) = messages.last_mut() {
1911 last_message.cache = true;
1912 }
1913
1914 if let Some(message) = self.pending_message.as_ref() {
1915 messages.extend(message.to_request());
1916 }
1917
1918 messages
1919 }
1920
1921 pub fn to_markdown(&self) -> String {
1922 let mut markdown = String::new();
1923 for (ix, message) in self.messages.iter().enumerate() {
1924 if ix > 0 {
1925 markdown.push('\n');
1926 }
1927 markdown.push_str(&message.to_markdown());
1928 }
1929
1930 if let Some(message) = self.pending_message.as_ref() {
1931 markdown.push('\n');
1932 markdown.push_str(&message.to_markdown());
1933 }
1934
1935 markdown
1936 }
1937
1938 fn advance_prompt_id(&mut self) {
1939 self.prompt_id = PromptId::new();
1940 }
1941
1942 fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
1943 use LanguageModelCompletionError::*;
1944 use http_client::StatusCode;
1945
1946        // General strategy here:
1947        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
1948        // - If it's a load- or rate-limiting issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times, backing off exponentially or honoring the server-provided retry-after delay.
1949        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
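        //
        // For example, an HTTP 429 (Too Many Requests) maps to exponential backoff starting
        // at BASE_RETRY_DELAY, while an authentication or permission error returns None and
        // is surfaced to the user instead of being retried.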
1950 match error {
1951 HttpResponseError {
1952 status_code: StatusCode::TOO_MANY_REQUESTS,
1953 ..
1954 } => Some(RetryStrategy::ExponentialBackoff {
1955 initial_delay: BASE_RETRY_DELAY,
1956 max_attempts: MAX_RETRY_ATTEMPTS,
1957 }),
1958 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
1959 Some(RetryStrategy::Fixed {
1960 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1961 max_attempts: MAX_RETRY_ATTEMPTS,
1962 })
1963 }
1964 UpstreamProviderError {
1965 status,
1966 retry_after,
1967 ..
1968 } => match *status {
1969 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
1970 Some(RetryStrategy::Fixed {
1971 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1972 max_attempts: MAX_RETRY_ATTEMPTS,
1973 })
1974 }
1975 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
1976 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1977 // Internal Server Error could be anything, retry up to 3 times.
1978 max_attempts: 3,
1979 }),
1980 status => {
1981                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
1982                    // but we frequently encounter it in practice. See https://http.dev/529
1983 if status.as_u16() == 529 {
1984 Some(RetryStrategy::Fixed {
1985 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1986 max_attempts: MAX_RETRY_ATTEMPTS,
1987 })
1988 } else {
1989 Some(RetryStrategy::Fixed {
1990 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1991 max_attempts: 2,
1992 })
1993 }
1994 }
1995 },
1996 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
1997 delay: BASE_RETRY_DELAY,
1998 max_attempts: 3,
1999 }),
2000 ApiReadResponseError { .. }
2001 | HttpSend { .. }
2002 | DeserializeResponse { .. }
2003 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
2004 delay: BASE_RETRY_DELAY,
2005 max_attempts: 3,
2006 }),
2007 // Retrying these errors definitely shouldn't help.
2008 HttpResponseError {
2009 status_code:
2010 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
2011 ..
2012 }
2013 | AuthenticationError { .. }
2014 | PermissionError { .. }
2015 | NoApiKey { .. }
2016 | ApiEndpointNotFound { .. }
2017 | PromptTooLarge { .. } => None,
2018            // These errors might be transient, so retry them once.
2019 SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
2020 delay: BASE_RETRY_DELAY,
2021 max_attempts: 1,
2022 }),
2023            // Retry all other 4xx and 5xx errors up to 3 times.
2024 HttpResponseError { status_code, .. }
2025 if status_code.is_client_error() || status_code.is_server_error() =>
2026 {
2027 Some(RetryStrategy::Fixed {
2028 delay: BASE_RETRY_DELAY,
2029 max_attempts: 3,
2030 })
2031 }
2032 Other(err)
2033 if err.is::<language_model::PaymentRequiredError>()
2034 || err.is::<language_model::ModelRequestLimitReachedError>() =>
2035 {
2036 // Retrying won't help for Payment Required or Model Request Limit errors (where
2037 // the user must upgrade to usage-based billing to get more requests, or else wait
2038 // for a significant amount of time for the request limit to reset).
2039 None
2040 }
2041            // Conservatively retry any other errors a couple of times in case they're transient.
2042 HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
2043 delay: BASE_RETRY_DELAY,
2044 max_attempts: 2,
2045 }),
2046 }
2047 }
2048}
2049
2050struct RunningTurn {
2051    /// Holds the task that handles agent interaction until the end of the turn.
2052    /// Survives across multiple requests as the model performs tool calls and
2053    /// we run those tools and report their results back to the model.
2054 _task: Task<()>,
2055 /// The current event stream for the running turn. Used to report a final
2056 /// cancellation event if we cancel the turn.
2057 event_stream: ThreadEventStream,
2058 /// The tools that were enabled for this turn.
2059 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
2060}
2061
2062impl RunningTurn {
2063 fn cancel(self) {
2064        log::debug!("Cancelling in-progress turn");
2065 self.event_stream.send_canceled();
2066 }
2067}
2068
2069pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);
2070
2071impl EventEmitter<TokenUsageUpdated> for Thread {}
2072
2073pub struct TitleUpdated;
2074
2075impl EventEmitter<TitleUpdated> for Thread {}
2076
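/// A strongly-typed tool that the agent can expose to the language model.
///
/// Below is a minimal sketch of an implementation. The `EchoTool` is hypothetical
/// and not part of this crate; it assumes `String` converts into
/// `LanguageModelToolResultContent`, and the tool kind is chosen arbitrarily.
///
/// ```ignore
/// #[derive(Deserialize, Serialize, JsonSchema)]
/// struct EchoToolInput {
///     /// The text to echo back.
///     text: String,
/// }
///
/// struct EchoTool;
///
/// impl AgentTool for EchoTool {
///     type Input = EchoToolInput;
///     type Output = String;
///
///     fn name() -> &'static str {
///         "echo"
///     }
///
///     fn kind() -> acp::ToolKind {
///         acp::ToolKind::Other
///     }
///
///     fn initial_title(
///         &self,
///         _input: Result<Self::Input, serde_json::Value>,
///         _cx: &mut App,
///     ) -> SharedString {
///         "Echo".into()
///     }
///
///     fn run(
///         self: Arc<Self>,
///         input: Self::Input,
///         _event_stream: ToolCallEventStream,
///         _cx: &mut App,
///     ) -> Task<Result<Self::Output>> {
///         // Simply return the input text as the tool's output.
///         Task::ready(Ok(input.text))
///     }
/// }
/// ```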
2077pub trait AgentTool
2078where
2079 Self: 'static + Sized,
2080{
2081 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
2082 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
2083
2084 fn name() -> &'static str;
2085
2086 fn description() -> SharedString {
2087 let schema = schemars::schema_for!(Self::Input);
2088 SharedString::new(
2089 schema
2090 .get("description")
2091 .and_then(|description| description.as_str())
2092 .unwrap_or_default(),
2093 )
2094 }
2095
2096 fn kind() -> acp::ToolKind;
2097
2098 /// The initial tool title to display. Can be updated during the tool run.
2099 fn initial_title(
2100 &self,
2101 input: Result<Self::Input, serde_json::Value>,
2102 cx: &mut App,
2103 ) -> SharedString;
2104
2105 /// Returns the JSON schema that describes the tool's input.
2106 fn input_schema(format: LanguageModelToolSchemaFormat) -> Schema {
2107 crate::tool_schema::root_schema_for::<Self::Input>(format)
2108 }
2109
2110    /// Some tools depend on a specific provider, for billing or other reasons.
2111    /// This lets a tool report whether it is compatible with the given provider or should be filtered out.
2112 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
2113 true
2114 }
2115
2116 /// Runs the tool with the provided input.
2117 fn run(
2118 self: Arc<Self>,
2119 input: Self::Input,
2120 event_stream: ToolCallEventStream,
2121 cx: &mut App,
2122 ) -> Task<Result<Self::Output>>;
2123
2124 /// Emits events for a previous execution of the tool.
2125 fn replay(
2126 &self,
2127 _input: Self::Input,
2128 _output: Self::Output,
2129 _event_stream: ToolCallEventStream,
2130 _cx: &mut App,
2131 ) -> Result<()> {
2132 Ok(())
2133 }
2134
2135 fn erase(self) -> Arc<dyn AnyAgentTool> {
2136 Arc::new(Erased(Arc::new(self)))
2137 }
2138}
2139
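/// Type-erasing wrapper that lets a strongly-typed [`AgentTool`] be stored and
/// invoked as an `Arc<dyn AnyAgentTool>`; constructed via [`AgentTool::erase`].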
2140pub struct Erased<T>(T);
2141
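/// The result of a tool run: `llm_output` is the content sent back to the model
/// as the tool result, while `raw_output` keeps the tool's full structured output
/// (used, among other things, when replaying a previous execution).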
2142pub struct AgentToolOutput {
2143 pub llm_output: LanguageModelToolResultContent,
2144 pub raw_output: serde_json::Value,
2145}
2146
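/// Object-safe counterpart of [`AgentTool`] that operates on untyped
/// `serde_json::Value`s, allowing tools with different input and output types to
/// be stored and dispatched uniformly.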
2147pub trait AnyAgentTool {
2148 fn name(&self) -> SharedString;
2149 fn description(&self) -> SharedString;
2150 fn kind(&self) -> acp::ToolKind;
2151    fn initial_title(&self, input: serde_json::Value, cx: &mut App) -> SharedString;
2152 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
2153 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
2154 true
2155 }
2156 fn run(
2157 self: Arc<Self>,
2158 input: serde_json::Value,
2159 event_stream: ToolCallEventStream,
2160 cx: &mut App,
2161 ) -> Task<Result<AgentToolOutput>>;
2162 fn replay(
2163 &self,
2164 input: serde_json::Value,
2165 output: serde_json::Value,
2166 event_stream: ToolCallEventStream,
2167 cx: &mut App,
2168 ) -> Result<()>;
2169}
2170
2171impl<T> AnyAgentTool for Erased<Arc<T>>
2172where
2173 T: AgentTool,
2174{
2175 fn name(&self) -> SharedString {
2176 T::name().into()
2177 }
2178
2179 fn description(&self) -> SharedString {
2180 T::description()
2181 }
2182
2183 fn kind(&self) -> agent_client_protocol::ToolKind {
2184 T::kind()
2185 }
2186
2187    fn initial_title(&self, input: serde_json::Value, cx: &mut App) -> SharedString {
2188        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
2189        self.0.initial_title(parsed_input, cx)
2190 }
2191
2192 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
2193 let mut json = serde_json::to_value(T::input_schema(format))?;
2194 crate::tool_schema::adapt_schema_to_format(&mut json, format)?;
2195 Ok(json)
2196 }
2197
2198 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
2199 self.0.supported_provider(provider)
2200 }
2201
2202 fn run(
2203 self: Arc<Self>,
2204 input: serde_json::Value,
2205 event_stream: ToolCallEventStream,
2206 cx: &mut App,
2207 ) -> Task<Result<AgentToolOutput>> {
2208 cx.spawn(async move |cx| {
2209 let input = serde_json::from_value(input)?;
2210 let output = cx
2211 .update(|cx| self.0.clone().run(input, event_stream, cx))?
2212 .await?;
2213 let raw_output = serde_json::to_value(&output)?;
2214 Ok(AgentToolOutput {
2215 llm_output: output.into(),
2216 raw_output,
2217 })
2218 })
2219 }
2220
2221 fn replay(
2222 &self,
2223 input: serde_json::Value,
2224 output: serde_json::Value,
2225 event_stream: ToolCallEventStream,
2226 cx: &mut App,
2227 ) -> Result<()> {
2228 let input = serde_json::from_value(input)?;
2229 let output = serde_json::from_value(output)?;
2230 self.0.replay(input, output, event_stream, cx)
2231 }
2232}
2233
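/// Cloneable sender side of a turn's event stream. The agent loop and running
/// tools push [`ThreadEvent`]s (text, thinking, tool calls, retries, stop
/// reasons) through it; sends are best-effort and ignored once the receiver is
/// dropped.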
2234#[derive(Clone)]
2235struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
2236
2237impl ThreadEventStream {
2238 fn send_user_message(&self, message: &UserMessage) {
2239 self.0
2240 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
2241 .ok();
2242 }
2243
2244 fn send_text(&self, text: &str) {
2245 self.0
2246 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
2247 .ok();
2248 }
2249
2250 fn send_thinking(&self, text: &str) {
2251 self.0
2252 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2253 .ok();
2254 }
2255
2256 fn send_tool_call(
2257 &self,
2258 id: &LanguageModelToolUseId,
2259 title: SharedString,
2260 kind: acp::ToolKind,
2261 input: serde_json::Value,
2262 ) {
2263 self.0
2264 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2265 id,
2266 title.to_string(),
2267 kind,
2268 input,
2269 ))))
2270 .ok();
2271 }
2272
2273 fn initial_tool_call(
2274 id: &LanguageModelToolUseId,
2275 title: String,
2276 kind: acp::ToolKind,
2277 input: serde_json::Value,
2278 ) -> acp::ToolCall {
2279 acp::ToolCall {
2280 meta: None,
2281 id: acp::ToolCallId(id.to_string().into()),
2282 title,
2283 kind,
2284 status: acp::ToolCallStatus::Pending,
2285 content: vec![],
2286 locations: vec![],
2287 raw_input: Some(input),
2288 raw_output: None,
2289 }
2290 }
2291
2292 fn update_tool_call_fields(
2293 &self,
2294 tool_use_id: &LanguageModelToolUseId,
2295 fields: acp::ToolCallUpdateFields,
2296 ) {
2297 self.0
2298 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2299 acp::ToolCallUpdate {
2300 meta: None,
2301 id: acp::ToolCallId(tool_use_id.to_string().into()),
2302 fields,
2303 }
2304 .into(),
2305 )))
2306 .ok();
2307 }
2308
2309 fn send_retry(&self, status: acp_thread::RetryStatus) {
2310 self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2311 }
2312
2313 fn send_stop(&self, reason: acp::StopReason) {
2314 self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
2315 }
2316
2317 fn send_canceled(&self) {
2318 self.0
2319 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
2320 .ok();
2321 }
2322
2323 fn send_error(&self, error: impl Into<anyhow::Error>) {
2324 self.0.unbounded_send(Err(error.into())).ok();
2325 }
2326}
2327
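/// Handle given to a running tool for streaming progress back to the thread:
/// tool-call field updates, diffs, and authorization requests are all scoped to
/// the originating `tool_use_id`.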
2328#[derive(Clone)]
2329pub struct ToolCallEventStream {
2330 tool_use_id: LanguageModelToolUseId,
2331 stream: ThreadEventStream,
2332 fs: Option<Arc<dyn Fs>>,
2333}
2334
2335impl ToolCallEventStream {
2336 #[cfg(any(test, feature = "test-support"))]
2337 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2338 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2339
2340 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2341
2342 (stream, ToolCallEventStreamReceiver(events_rx))
2343 }
2344
2345 fn new(
2346 tool_use_id: LanguageModelToolUseId,
2347 stream: ThreadEventStream,
2348 fs: Option<Arc<dyn Fs>>,
2349 ) -> Self {
2350 Self {
2351 tool_use_id,
2352 stream,
2353 fs,
2354 }
2355 }
2356
2357 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2358 self.stream
2359 .update_tool_call_fields(&self.tool_use_id, fields);
2360 }
2361
2362 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2363 self.stream
2364 .0
2365 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2366 acp_thread::ToolCallUpdateDiff {
2367 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2368 diff,
2369 }
2370 .into(),
2371 )))
2372 .ok();
2373 }
2374
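    /// Asks the user to authorize this tool call, unless `always_allow_tool_actions`
    /// is enabled in the agent settings. Resolves to `Ok(())` when the user picks
    /// "Allow" or "Always Allow" (the latter also persists the setting), and to an
    /// error when they pick "Deny".
    ///
    /// Typical usage from inside a tool's `run` (illustrative):
    ///
    /// ```ignore
    /// event_stream.authorize("Delete `build/` directory", cx).await?;
    /// ```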
2375 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2376 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2377 return Task::ready(Ok(()));
2378 }
2379
2380 let (response_tx, response_rx) = oneshot::channel();
2381 self.stream
2382 .0
2383 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2384 ToolCallAuthorization {
2385 tool_call: acp::ToolCallUpdate {
2386 meta: None,
2387 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2388 fields: acp::ToolCallUpdateFields {
2389 title: Some(title.into()),
2390 ..Default::default()
2391 },
2392 },
2393 options: vec![
2394 acp::PermissionOption {
2395 id: acp::PermissionOptionId("always_allow".into()),
2396 name: "Always Allow".into(),
2397 kind: acp::PermissionOptionKind::AllowAlways,
2398 meta: None,
2399 },
2400 acp::PermissionOption {
2401 id: acp::PermissionOptionId("allow".into()),
2402 name: "Allow".into(),
2403 kind: acp::PermissionOptionKind::AllowOnce,
2404 meta: None,
2405 },
2406 acp::PermissionOption {
2407 id: acp::PermissionOptionId("deny".into()),
2408 name: "Deny".into(),
2409 kind: acp::PermissionOptionKind::RejectOnce,
2410 meta: None,
2411 },
2412 ],
2413 response: response_tx,
2414 },
2415 )))
2416 .ok();
2417 let fs = self.fs.clone();
2418 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2419 "always_allow" => {
2420 if let Some(fs) = fs.clone() {
2421 cx.update(|cx| {
2422 update_settings_file(fs, cx, |settings, _| {
2423 settings
2424 .agent
2425 .get_or_insert_default()
2426 .set_always_allow_tool_actions(true);
2427 });
2428 })?;
2429 }
2430
2431 Ok(())
2432 }
2433 "allow" => Ok(()),
2434 _ => Err(anyhow!("Permission to run tool denied by user")),
2435 })
2436 }
2437}
2438
2439#[cfg(any(test, feature = "test-support"))]
2440pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2441
2442#[cfg(any(test, feature = "test-support"))]
2443impl ToolCallEventStreamReceiver {
2444 pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
2445 let event = self.0.next().await;
2446 if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event {
2447 auth
2448 } else {
2449 panic!("Expected ToolCallAuthorization but got: {:?}", event);
2450 }
2451 }
2452
2453 pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
2454 let event = self.0.next().await;
2455 if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2456 update,
2457 )))) = event
2458 {
2459 update.fields
2460 } else {
2461 panic!("Expected update fields but got: {:?}", event);
2462 }
2463 }
2464
2465 pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
2466 let event = self.0.next().await;
2467 if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
2468 update,
2469 )))) = event
2470 {
2471 update.diff
2472 } else {
2473 panic!("Expected diff but got: {:?}", event);
2474 }
2475 }
2476
2477 pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
2478 let event = self.0.next().await;
2479 if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
2480 update,
2481 )))) = event
2482 {
2483 update.terminal
2484 } else {
2485 panic!("Expected terminal but got: {:?}", event);
2486 }
2487 }
2488}
2489
2490#[cfg(any(test, feature = "test-support"))]
2491impl std::ops::Deref for ToolCallEventStreamReceiver {
2492 type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;
2493
2494 fn deref(&self) -> &Self::Target {
2495 &self.0
2496 }
2497}
2498
2499#[cfg(any(test, feature = "test-support"))]
2500impl std::ops::DerefMut for ToolCallEventStreamReceiver {
2501 fn deref_mut(&mut self) -> &mut Self::Target {
2502 &mut self.0
2503 }
2504}
2505
2506impl From<&str> for UserMessageContent {
2507 fn from(text: &str) -> Self {
2508 Self::Text(text.into())
2509 }
2510}
2511
2512impl From<acp::ContentBlock> for UserMessageContent {
2513 fn from(value: acp::ContentBlock) -> Self {
2514 match value {
2515 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2516 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2517 acp::ContentBlock::Audio(_) => {
2518 // TODO
2519 Self::Text("[audio]".to_string())
2520 }
2521 acp::ContentBlock::ResourceLink(resource_link) => {
2522 match MentionUri::parse(&resource_link.uri) {
2523 Ok(uri) => Self::Mention {
2524 uri,
2525 content: String::new(),
2526 },
2527 Err(err) => {
2528 log::error!("Failed to parse mention link: {}", err);
2529 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2530 }
2531 }
2532 }
2533 acp::ContentBlock::Resource(resource) => match resource.resource {
2534 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2535 match MentionUri::parse(&resource.uri) {
2536 Ok(uri) => Self::Mention {
2537 uri,
2538 content: resource.text,
2539 },
2540 Err(err) => {
2541 log::error!("Failed to parse mention link: {}", err);
2542 Self::Text(
2543 MarkdownCodeBlock {
2544 tag: &resource.uri,
2545 text: &resource.text,
2546 }
2547 .to_string(),
2548 )
2549 }
2550 }
2551 }
2552 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2553 // TODO
2554 Self::Text("[blob]".to_string())
2555 }
2556 },
2557 }
2558 }
2559}
2560
2561impl From<UserMessageContent> for acp::ContentBlock {
2562 fn from(content: UserMessageContent) -> Self {
2563 match content {
2564 UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2565 text,
2566 annotations: None,
2567 meta: None,
2568 }),
2569 UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2570 data: image.source.to_string(),
2571 mime_type: "image/png".to_string(),
2572 meta: None,
2573 annotations: None,
2574 uri: None,
2575 }),
2576 UserMessageContent::Mention { uri, content } => {
2577 acp::ContentBlock::Resource(acp::EmbeddedResource {
2578 meta: None,
2579 resource: acp::EmbeddedResourceResource::TextResourceContents(
2580 acp::TextResourceContents {
2581 meta: None,
2582 mime_type: None,
2583 text: content,
2584 uri: uri.to_uri().to_string(),
2585 },
2586 ),
2587 annotations: None,
2588 })
2589 }
2590 }
2591 }
2592}
2593
2594fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2595 LanguageModelImage {
2596 source: image_content.data.into(),
2597 // TODO: make this optional?
2598 size: gpui::Size::new(0.into(), 0.into()),
2599 }
2600}