use crate::{
    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
    DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
    ListDirectoryTool, MovePathTool, NowTool, OpenTool, ProjectSnapshot, ReadFileTool,
    RestoreFileFromDiskTool, SaveFileTool, SystemPromptTemplate, Template, Templates, TerminalTool,
    ThinkingTool, WebSearchTool,
};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;

use agent_client_protocol as acp;
use agent_settings::{
    AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
    SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage, UserStore};
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use futures::stream;
use futures::{
    FutureExt,
    channel::{mpsc, oneshot},
    future::Shared,
    stream::FuturesUnordered,
};
use gpui::{
    App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
    LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry,
    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
    LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
    ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{LanguageModelSelection, Settings, update_settings_file};
use smol::stream::StreamExt;
use std::{
    collections::BTreeMap,
    ops::RangeInclusive,
    path::Path,
    rc::Rc,
    sync::Arc,
    time::{Duration, Instant},
};
use std::{fmt::Write, path::PathBuf};
use util::{ResultExt, debug_panic, markdown::MarkdownCodeBlock, paths::PathStyle};
use uuid::Uuid;

const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
pub const MAX_TOOL_NAME_LENGTH: usize = 64;

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

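// Defaults used when retrying failed completion requests.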
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);

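/// How a failed completion request is retried: either with exponentially
/// increasing delays or with a fixed delay, in both cases giving up after
/// `max_attempts`.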
#[derive(Debug, Clone)]
enum RetryStrategy {
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    User(UserMessage),
    Agent(AgentMessage),
    Resume,
}

impl Message {
    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
        match self {
            Message::Agent(agent_message) => Some(agent_message),
            _ => None,
        }
    }

    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
        match self {
            Message::User(message) => {
                if message.content.is_empty() {
                    vec![]
                } else {
                    vec![message.to_request()]
                }
            }
            Message::Agent(message) => message.to_request(),
            Message::Resume => vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: false,
                reasoning_details: None,
            }],
        }
    }

    pub fn to_markdown(&self) -> String {
        match self {
            Message::User(message) => message.to_markdown(),
            Message::Agent(message) => message.to_markdown(),
            Message::Resume => "[resume]\n".into(),
        }
    }

    pub fn role(&self) -> Role {
        match self {
            Message::User(_) | Message::Resume => Role::User,
            Message::Agent(_) => Role::Assistant,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    pub id: UserMessageId,
    pub content: Vec<UserMessageContent>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    Text(String),
    Mention { uri: MentionUri, content: String },
    Image(LanguageModelImage),
}

impl UserMessage {
    pub fn to_markdown(&self) -> String {
        let mut markdown = String::from("## User\n\n");

        for content in &self.content {
            match content {
                UserMessageContent::Text(text) => {
                    markdown.push_str(text);
                    markdown.push('\n');
                }
                UserMessageContent::Image(_) => {
                    markdown.push_str("<image />\n");
                }
                UserMessageContent::Mention { uri, content } => {
                    if !content.is_empty() {
                        let _ = writeln!(&mut markdown, "{}\n\n{}", uri.as_link(), content);
                    } else {
                        let _ = writeln!(&mut markdown, "{}", uri.as_link());
                    }
                }
            }
        }

        markdown
    }

    fn to_request(&self) -> LanguageModelRequestMessage {
        let mut message = LanguageModelRequestMessage {
            role: Role::User,
            content: Vec::with_capacity(self.content.len()),
            cache: false,
            reasoning_details: None,
        };

        const OPEN_CONTEXT: &str = "<context>\n\
            The following items were attached by the user. \
            They are up-to-date and don't need to be re-read.\n\n";

        const OPEN_FILES_TAG: &str = "<files>";
        const OPEN_DIRECTORIES_TAG: &str = "<directories>";
        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
        const OPEN_SELECTIONS_TAG: &str = "<selections>";
        const OPEN_THREADS_TAG: &str = "<threads>";
        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
        const OPEN_RULES_TAG: &str =
            "<user_rules>\nThe user has specified the following rules that should be applied:\n";

        let mut file_context = OPEN_FILES_TAG.to_string();
        let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
        let mut selection_context = OPEN_SELECTIONS_TAG.to_string();
        let mut thread_context = OPEN_THREADS_TAG.to_string();
        let mut fetch_context = OPEN_FETCH_TAG.to_string();
        let mut rules_context = OPEN_RULES_TAG.to_string();

        for chunk in &self.content {
            let chunk = match chunk {
                UserMessageContent::Text(text) => {
                    language_model::MessageContent::Text(text.clone())
                }
                UserMessageContent::Image(value) => {
                    language_model::MessageContent::Image(value.clone())
                }
                UserMessageContent::Mention { uri, content } => {
                    match uri {
                        MentionUri::File { abs_path } => {
                            write!(
                                &mut file_context,
                                "\n{}",
                                MarkdownCodeBlock {
                                    tag: &codeblock_tag(abs_path, None),
                                    text: &content.to_string(),
                                }
                            )
                            .ok();
                        }
                        MentionUri::PastedImage => {
                            debug_panic!("pasted image URI should not be used in mention content")
                        }
                        MentionUri::Directory { .. } => {
                            write!(&mut directory_context, "\n{}\n", content).ok();
                        }
                        MentionUri::Symbol {
                            abs_path: path,
                            line_range,
                            ..
                        } => {
                            write!(
                                &mut symbol_context,
                                "\n{}",
                                MarkdownCodeBlock {
                                    tag: &codeblock_tag(path, Some(line_range)),
                                    text: content
                                }
                            )
                            .ok();
                        }
                        MentionUri::Selection {
                            abs_path: path,
                            line_range,
                            ..
                        } => {
                            write!(
                                &mut selection_context,
                                "\n{}",
                                MarkdownCodeBlock {
                                    tag: &codeblock_tag(
                                        path.as_deref().unwrap_or("Untitled".as_ref()),
                                        Some(line_range)
                                    ),
                                    text: content
                                }
                            )
                            .ok();
                        }
                        MentionUri::Thread { .. } => {
                            write!(&mut thread_context, "\n{}\n", content).ok();
                        }
                        MentionUri::TextThread { .. } => {
                            write!(&mut thread_context, "\n{}\n", content).ok();
                        }
                        MentionUri::Rule { .. } => {
                            write!(
                                &mut rules_context,
                                "\n{}",
                                MarkdownCodeBlock {
                                    tag: "",
                                    text: content
                                }
                            )
                            .ok();
                        }
                        MentionUri::Fetch { url } => {
                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
                        }
                    }

                    language_model::MessageContent::Text(uri.as_link().to_string())
                }
            };

            message.content.push(chunk);
        }

        let len_before_context = message.content.len();

        if file_context.len() > OPEN_FILES_TAG.len() {
            file_context.push_str("</files>\n");
            message
                .content
                .push(language_model::MessageContent::Text(file_context));
        }

        if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
            directory_context.push_str("</directories>\n");
            message
                .content
                .push(language_model::MessageContent::Text(directory_context));
        }

        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
            symbol_context.push_str("</symbols>\n");
            message
                .content
                .push(language_model::MessageContent::Text(symbol_context));
        }

        if selection_context.len() > OPEN_SELECTIONS_TAG.len() {
            selection_context.push_str("</selections>\n");
            message
                .content
                .push(language_model::MessageContent::Text(selection_context));
        }

        if thread_context.len() > OPEN_THREADS_TAG.len() {
            thread_context.push_str("</threads>\n");
            message
                .content
                .push(language_model::MessageContent::Text(thread_context));
        }

        if fetch_context.len() > OPEN_FETCH_TAG.len() {
            fetch_context.push_str("</fetched_urls>\n");
            message
                .content
                .push(language_model::MessageContent::Text(fetch_context));
        }

        if rules_context.len() > OPEN_RULES_TAG.len() {
            rules_context.push_str("</user_rules>\n");
            message
                .content
                .push(language_model::MessageContent::Text(rules_context));
        }

        if message.content.len() > len_before_context {
            message.content.insert(
                len_before_context,
                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
            );
            message
                .content
                .push(language_model::MessageContent::Text("</context>".into()));
        }

        message
    }
}

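/// Builds the info string for a Markdown code block from a file path and an
/// optional zero-based line range, rendered as `extension path:start-end`
/// with 1-based line numbers.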
fn codeblock_tag(full_path: &Path, line_range: Option<&RangeInclusive<u32>>) -> String {
    let mut result = String::new();

    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
        let _ = write!(result, "{} ", extension);
    }

    let _ = write!(result, "{}", full_path.display());

    if let Some(range) = line_range {
        if range.start() == range.end() {
            let _ = write!(result, ":{}", range.start() + 1);
        } else {
            let _ = write!(result, ":{}-{}", range.start() + 1, range.end() + 1);
        }
    }

    result
}

impl AgentMessage {
    pub fn to_markdown(&self) -> String {
        let mut markdown = String::from("## Assistant\n\n");

        for content in &self.content {
            match content {
                AgentMessageContent::Text(text) => {
                    markdown.push_str(text);
                    markdown.push('\n');
                }
                AgentMessageContent::Thinking { text, .. } => {
                    markdown.push_str("<think>");
                    markdown.push_str(text);
                    markdown.push_str("</think>\n");
                }
                AgentMessageContent::RedactedThinking(_) => {
                    markdown.push_str("<redacted_thinking />\n")
                }
                AgentMessageContent::ToolUse(tool_use) => {
                    markdown.push_str(&format!(
                        "**Tool Use**: {} (ID: {})\n",
                        tool_use.name, tool_use.id
                    ));
                    markdown.push_str(&format!(
                        "{}\n",
                        MarkdownCodeBlock {
                            tag: "json",
                            text: &format!("{:#}", tool_use.input)
                        }
                    ));
                }
            }
        }

        for tool_result in self.tool_results.values() {
            markdown.push_str(&format!(
                "**Tool Result**: {} (ID: {})\n\n",
                tool_result.tool_name, tool_result.tool_use_id
            ));
            if tool_result.is_error {
                markdown.push_str("**ERROR:**\n");
            }

            match &tool_result.content {
                LanguageModelToolResultContent::Text(text) => {
                    writeln!(markdown, "{text}\n").ok();
                }
                LanguageModelToolResultContent::Image(_) => {
                    writeln!(markdown, "<image />\n").ok();
                }
            }

            if let Some(output) = tool_result.output.as_ref() {
                writeln!(
                    markdown,
                    "**Debug Output**:\n\n```json\n{}\n```\n",
                    serde_json::to_string_pretty(output).unwrap()
                )
                .unwrap();
            }
        }

        markdown
    }

    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
        let mut assistant_message = LanguageModelRequestMessage {
            role: Role::Assistant,
            content: Vec::with_capacity(self.content.len()),
            cache: false,
            reasoning_details: self.reasoning_details.clone(),
        };
        for chunk in &self.content {
            match chunk {
                AgentMessageContent::Text(text) => {
                    assistant_message
                        .content
                        .push(language_model::MessageContent::Text(text.clone()));
                }
                AgentMessageContent::Thinking { text, signature } => {
                    assistant_message
                        .content
                        .push(language_model::MessageContent::Thinking {
                            text: text.clone(),
                            signature: signature.clone(),
                        });
                }
                AgentMessageContent::RedactedThinking(value) => {
                    assistant_message.content.push(
                        language_model::MessageContent::RedactedThinking(value.clone()),
                    );
                }
                AgentMessageContent::ToolUse(tool_use) => {
                    if self.tool_results.contains_key(&tool_use.id) {
                        assistant_message
                            .content
                            .push(language_model::MessageContent::ToolUse(tool_use.clone()));
                    }
                }
            };
        }

        let mut user_message = LanguageModelRequestMessage {
            role: Role::User,
            content: Vec::new(),
            cache: false,
            reasoning_details: None,
        };

        for tool_result in self.tool_results.values() {
            let mut tool_result = tool_result.clone();
            // Surprisingly, the API fails if we return an empty string here.
            // It thinks we are sending a tool use without a tool result.
            if tool_result.content.is_empty() {
                tool_result.content = "<Tool returned an empty string>".into();
            }
            user_message
                .content
                .push(language_model::MessageContent::ToolResult(tool_result));
        }

        let mut messages = Vec::new();
        if !assistant_message.content.is_empty() {
            messages.push(assistant_message);
        }
        if !user_message.content.is_empty() {
            messages.push(user_message);
        }
        messages
    }
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    pub content: Vec<AgentMessageContent>,
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
    pub reasoning_details: Option<serde_json::Value>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
    ToolUse(LanguageModelToolUse),
}

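/// Handle to a terminal created for the agent, used to inspect its output,
/// wait for it to exit, or kill it.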
pub trait TerminalHandle {
    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId>;
    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse>;
    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>>;
    fn kill(&self, cx: &AsyncApp) -> Result<()>;
}

pub trait ThreadEnvironment {
    fn create_terminal(
        &self,
        command: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn TerminalHandle>>>;
}

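/// Events emitted over the channels returned by [`Thread::send`],
/// [`Thread::resume`], and [`Thread::replay`].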
#[derive(Debug)]
pub enum ThreadEvent {
    UserMessage(UserMessage),
    AgentText(String),
    AgentThinking(String),
    ToolCall(acp::ToolCall),
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    ToolCallAuthorization(ToolCallAuthorization),
    Retry(acp_thread::RetryStatus),
    Stop(acp::StopReason),
}

#[derive(Debug)]
pub struct NewTerminal {
    pub command: String,
    pub output_byte_limit: Option<u64>,
    pub cwd: Option<PathBuf>,
    pub response: oneshot::Sender<Result<Entity<acp_thread::Terminal>>>,
}

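/// Emitted when a tool call requires user authorization; the selected
/// permission option is sent back through `response`.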
#[derive(Debug)]
pub struct ToolCallAuthorization {
    pub tool_call: acp::ToolCallUpdate,
    pub options: Vec<acp::PermissionOption>,
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}

#[derive(Debug, thiserror::Error)]
enum CompletionError {
    #[error("max tokens")]
    MaxTokens,
    #[error("refusal")]
    Refusal,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub struct Thread {
    id: acp::SessionId,
    prompt_id: PromptId,
    updated_at: DateTime<Utc>,
    title: Option<SharedString>,
    pending_title_generation: Option<Task<()>>,
    pending_summary_generation: Option<Shared<Task<Option<SharedString>>>>,
    summary: Option<SharedString>,
    messages: Vec<Message>,
    user_store: Entity<UserStore>,
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
    running_turn: Option<RunningTurn>,
    pending_message: Option<AgentMessage>,
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    tool_use_limit_reached: bool,
    request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
    #[allow(unused)]
    cumulative_token_usage: TokenUsage,
    #[allow(unused)]
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    context_server_registry: Entity<ContextServerRegistry>,
    profile_id: AgentProfileId,
    project_context: Entity<ProjectContext>,
    templates: Arc<Templates>,
    model: Option<Arc<dyn LanguageModel>>,
    summarization_model: Option<Arc<dyn LanguageModel>>,
    prompt_capabilities_tx: watch::Sender<acp::PromptCapabilities>,
    pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
    pub(crate) project: Entity<Project>,
    pub(crate) action_log: Entity<ActionLog>,
    /// Tracks the last time files were read by the agent, to detect external modifications
    pub(crate) file_read_times: HashMap<PathBuf, fs::MTime>,
    /// True if this thread was imported from a shared thread and can be synced.
    imported: bool,
}

impl Thread {
    fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities {
        let image = model.map_or(true, |model| model.supports_images());
        acp::PromptCapabilities::new()
            .image(image)
            .embedded_context(true)
    }

    pub fn new(
        project: Entity<Project>,
        project_context: Entity<ProjectContext>,
        context_server_registry: Entity<ContextServerRegistry>,
        templates: Arc<Templates>,
        model: Option<Arc<dyn LanguageModel>>,
        cx: &mut Context<Self>,
    ) -> Self {
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let (prompt_capabilities_tx, prompt_capabilities_rx) =
            watch::channel(Self::prompt_capabilities(model.as_deref()));
        Self {
            id: acp::SessionId::new(uuid::Uuid::new_v4().to_string()),
            prompt_id: PromptId::new(),
            updated_at: Utc::now(),
            title: None,
            pending_title_generation: None,
            pending_summary_generation: None,
            summary: None,
            messages: Vec::new(),
            user_store: project.read(cx).user_store(),
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            running_turn: None,
            pending_message: None,
            tools: BTreeMap::default(),
            tool_use_limit_reached: false,
            request_token_usage: HashMap::default(),
            cumulative_token_usage: TokenUsage::default(),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project.clone(), cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            context_server_registry,
            profile_id,
            project_context,
            templates,
            model,
            summarization_model: None,
            prompt_capabilities_tx,
            prompt_capabilities_rx,
            project,
            action_log,
            file_read_times: HashMap::default(),
            imported: false,
        }
    }

    pub fn id(&self) -> &acp::SessionId {
        &self.id
    }

    /// Returns true if this thread was imported from a shared thread.
    pub fn is_imported(&self) -> bool {
        self.imported
    }

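    /// Replays the stored conversation as a stream of [`ThreadEvent`]s so that
    /// a restored thread can be rendered as if it had just run.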
    pub fn replay(
        &mut self,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
        let (tx, rx) = mpsc::unbounded();
        let stream = ThreadEventStream(tx);
        for message in &self.messages {
            match message {
                Message::User(user_message) => stream.send_user_message(user_message),
                Message::Agent(assistant_message) => {
                    for content in &assistant_message.content {
                        match content {
                            AgentMessageContent::Text(text) => stream.send_text(text),
                            AgentMessageContent::Thinking { text, .. } => {
                                stream.send_thinking(text)
                            }
                            AgentMessageContent::RedactedThinking(_) => {}
                            AgentMessageContent::ToolUse(tool_use) => {
                                self.replay_tool_call(
                                    tool_use,
                                    assistant_message.tool_results.get(&tool_use.id),
                                    &stream,
                                    cx,
                                );
                            }
                        }
                    }
                }
                Message::Resume => {}
            }
        }
        rx
    }

    fn replay_tool_call(
        &self,
        tool_use: &LanguageModelToolUse,
        tool_result: Option<&LanguageModelToolResult>,
        stream: &ThreadEventStream,
        cx: &mut Context<Self>,
    ) {
        let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
            self.context_server_registry
                .read(cx)
                .servers()
                .find_map(|(_, tools)| {
                    if let Some(tool) = tools.get(tool_use.name.as_ref()) {
                        Some(tool.clone())
                    } else {
                        None
                    }
                })
        });

        let Some(tool) = tool else {
            stream
                .0
                .unbounded_send(Ok(ThreadEvent::ToolCall(
                    acp::ToolCall::new(tool_use.id.to_string(), tool_use.name.to_string())
                        .status(acp::ToolCallStatus::Failed)
                        .raw_input(tool_use.input.clone()),
                )))
                .ok();
            return;
        };

        let title = tool.initial_title(tool_use.input.clone(), cx);
        let kind = tool.kind();
        stream.send_tool_call(
            &tool_use.id,
            &tool_use.name,
            title,
            kind,
            tool_use.input.clone(),
        );

        let output = tool_result
            .as_ref()
            .and_then(|result| result.output.clone());
        if let Some(output) = output.clone() {
            let tool_event_stream = ToolCallEventStream::new(
                tool_use.id.clone(),
                stream.clone(),
                Some(self.project.read(cx).fs().clone()),
            );
            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
                .log_err();
        }

        stream.update_tool_call_fields(
            &tool_use.id,
            acp::ToolCallUpdateFields::new()
                .status(
                    tool_result
                        .as_ref()
                        .map_or(acp::ToolCallStatus::Failed, |result| {
                            if result.is_error {
                                acp::ToolCallStatus::Failed
                            } else {
                                acp::ToolCallStatus::Completed
                            }
                        }),
                )
                .raw_output(output),
        );
    }

    pub fn from_db(
        id: acp::SessionId,
        db_thread: DbThread,
        project: Entity<Project>,
        project_context: Entity<ProjectContext>,
        context_server_registry: Entity<ContextServerRegistry>,
        templates: Arc<Templates>,
        cx: &mut Context<Self>,
    ) -> Self {
        let profile_id = db_thread
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        let mut model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            db_thread
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
                .map(|model| model.model)
        });

        if model.is_none() {
            model = Self::resolve_profile_model(&profile_id, cx);
        }
        if model.is_none() {
            model = LanguageModelRegistry::global(cx).update(cx, |registry, _cx| {
                registry.default_model().map(|model| model.model)
            });
        }

        let (prompt_capabilities_tx, prompt_capabilities_rx) =
            watch::channel(Self::prompt_capabilities(model.as_deref()));

        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            id,
            prompt_id: PromptId::new(),
            title: if db_thread.title.is_empty() {
                None
            } else {
                Some(db_thread.title.clone())
            },
            pending_title_generation: None,
            pending_summary_generation: None,
            summary: db_thread.detailed_summary,
            messages: db_thread.messages,
            user_store: project.read(cx).user_store(),
            completion_mode: db_thread.completion_mode.unwrap_or_default(),
            running_turn: None,
            pending_message: None,
            tools: BTreeMap::default(),
            tool_use_limit_reached: false,
            request_token_usage: db_thread.request_token_usage.clone(),
            cumulative_token_usage: db_thread.cumulative_token_usage,
            initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
            context_server_registry,
            profile_id,
            project_context,
            templates,
            model,
            summarization_model: None,
            project,
            action_log,
            updated_at: db_thread.updated_at,
            prompt_capabilities_tx,
            prompt_capabilities_rx,
            file_read_times: HashMap::default(),
            imported: db_thread.imported,
        }
    }

    pub fn to_db(&self, cx: &App) -> Task<DbThread> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        let mut thread = DbThread {
            title: self.title(),
            messages: self.messages.clone(),
            updated_at: self.updated_at,
            detailed_summary: self.summary.clone(),
            initial_project_snapshot: None,
            cumulative_token_usage: self.cumulative_token_usage,
            request_token_usage: self.request_token_usage.clone(),
            model: self.model.as_ref().map(|model| DbLanguageModel {
                provider: model.provider_id().to_string(),
                model: model.name().0.to_string(),
            }),
            completion_mode: Some(self.completion_mode),
            profile: Some(self.profile_id.clone()),
            imported: self.imported,
        };

        cx.background_spawn(async move {
            let initial_project_snapshot = initial_project_snapshot.await;
            thread.initial_project_snapshot = initial_project_snapshot;
            thread
        })
    }

    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let task = project::telemetry_snapshot::TelemetrySnapshot::new(&project, cx);
        cx.spawn(async move |_, _| {
            let snapshot = task.await;

            Arc::new(ProjectSnapshot {
                worktree_snapshots: snapshot.worktree_snapshots,
                timestamp: Utc::now(),
            })
        })
    }

    pub fn project_context(&self) -> &Entity<ProjectContext> {
        &self.project_context
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty() && self.title.is_none()
    }

    pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
        self.model.as_ref()
    }

    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
        let old_usage = self.latest_token_usage();
        self.model = Some(model);
        let new_caps = Self::prompt_capabilities(self.model.as_deref());
        let new_usage = self.latest_token_usage();
        if old_usage != new_usage {
            cx.emit(TokenUsageUpdated(new_usage));
        }
        self.prompt_capabilities_tx.send(new_caps).log_err();
        cx.notify()
    }

    pub fn summarization_model(&self) -> Option<&Arc<dyn LanguageModel>> {
        self.summarization_model.as_ref()
    }

    pub fn set_summarization_model(
        &mut self,
        model: Option<Arc<dyn LanguageModel>>,
        cx: &mut Context<Self>,
    ) {
        self.summarization_model = model;
        cx.notify()
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
        let old_usage = self.latest_token_usage();
        self.completion_mode = mode;
        let new_usage = self.latest_token_usage();
        if old_usage != new_usage {
            cx.emit(TokenUsageUpdated(new_usage));
        }
        cx.notify()
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn last_message(&self) -> Option<Message> {
        if let Some(message) = self.pending_message.clone() {
            Some(Message::Agent(message))
        } else {
            self.messages.last().cloned()
        }
    }

    pub fn add_default_tools(
        &mut self,
        environment: Rc<dyn ThreadEnvironment>,
        cx: &mut Context<Self>,
    ) {
        let language_registry = self.project.read(cx).languages().clone();
        self.add_tool(CopyPathTool::new(self.project.clone()));
        self.add_tool(CreateDirectoryTool::new(self.project.clone()));
        self.add_tool(DeletePathTool::new(
            self.project.clone(),
            self.action_log.clone(),
        ));
        self.add_tool(DiagnosticsTool::new(self.project.clone()));
        self.add_tool(EditFileTool::new(
            self.project.clone(),
            cx.weak_entity(),
            language_registry,
            Templates::new(),
        ));
        self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
        self.add_tool(FindPathTool::new(self.project.clone()));
        self.add_tool(GrepTool::new(self.project.clone()));
        self.add_tool(ListDirectoryTool::new(self.project.clone()));
        self.add_tool(MovePathTool::new(self.project.clone()));
        self.add_tool(NowTool);
        self.add_tool(OpenTool::new(self.project.clone()));
        self.add_tool(ReadFileTool::new(
            cx.weak_entity(),
            self.project.clone(),
            self.action_log.clone(),
        ));
        self.add_tool(SaveFileTool::new(self.project.clone()));
        self.add_tool(RestoreFileFromDiskTool::new(self.project.clone()));
        self.add_tool(TerminalTool::new(self.project.clone(), environment));
        self.add_tool(ThinkingTool);
        self.add_tool(WebSearchTool);
    }

    pub fn add_tool<T: AgentTool>(&mut self, tool: T) {
        self.tools.insert(T::name().into(), tool.erase());
    }

    pub fn remove_tool(&mut self, name: &str) -> bool {
        self.tools.remove(name).is_some()
    }

    pub fn profile(&self) -> &AgentProfileId {
        &self.profile_id
    }

    pub fn set_profile(&mut self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
        if self.profile_id == profile_id {
            return;
        }

        self.profile_id = profile_id;

        // Swap to the profile's preferred model when available.
        if let Some(model) = Self::resolve_profile_model(&self.profile_id, cx) {
            self.set_model(model, cx);
        }
    }

    pub fn cancel(&mut self, cx: &mut Context<Self>) {
        if let Some(running_turn) = self.running_turn.take() {
            running_turn.cancel();
        }
        self.flush_pending_message(cx);
    }

    fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
        let Some(last_user_message) = self.last_user_message() else {
            return;
        };

        self.request_token_usage
            .insert(last_user_message.id.clone(), update);
        cx.emit(TokenUsageUpdated(self.latest_token_usage()));
        cx.notify();
    }

    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
        self.cancel(cx);
        let Some(position) = self.messages.iter().position(
            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
        ) else {
            return Err(anyhow!("Message not found"));
        };

        for message in self.messages.drain(position..) {
            match message {
                Message::User(message) => {
                    self.request_token_usage.remove(&message.id);
                }
                Message::Agent(_) | Message::Resume => {}
            }
        }
        self.clear_summary();
        cx.notify();
        Ok(())
    }

    pub fn latest_request_token_usage(&self) -> Option<language_model::TokenUsage> {
        let last_user_message = self.last_user_message()?;
        let tokens = self.request_token_usage.get(&last_user_message.id)?;
        Some(*tokens)
    }

    pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
        let usage = self.latest_request_token_usage()?;
        let model = self.model.clone()?;
        Some(acp_thread::TokenUsage {
            max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
            used_tokens: usage.total_tokens(),
        })
    }

    /// Get the total input token count as of the message before the given message.
    ///
    /// Returns `None` if:
    /// - `target_id` is the first message (no previous message)
    /// - The previous message hasn't received a response yet (no usage data)
    /// - `target_id` is not found in the messages
    pub fn tokens_before_message(&self, target_id: &UserMessageId) -> Option<u64> {
        let mut previous_user_message_id: Option<&UserMessageId> = None;

        for message in &self.messages {
            if let Message::User(user_msg) = message {
                if &user_msg.id == target_id {
                    let prev_id = previous_user_message_id?;
                    let usage = self.request_token_usage.get(prev_id)?;
                    return Some(usage.input_tokens);
                }
                previous_user_message_id = Some(&user_msg.id);
            }
        }
        None
    }

    /// Look up the active profile and resolve its preferred model if one is configured.
    fn resolve_profile_model(
        profile_id: &AgentProfileId,
        cx: &mut Context<Self>,
    ) -> Option<Arc<dyn LanguageModel>> {
        let selection = AgentSettings::get_global(cx)
            .profiles
            .get(profile_id)?
            .default_model
            .clone()?;
        Self::resolve_model_from_selection(&selection, cx)
    }

    /// Translate a stored model selection into the configured model from the registry.
    fn resolve_model_from_selection(
        selection: &LanguageModelSelection,
        cx: &mut Context<Self>,
    ) -> Option<Arc<dyn LanguageModel>> {
        let selected = SelectedModel {
            provider: LanguageModelProviderId::from(selection.provider.0.clone()),
            model: LanguageModelId::from(selection.model.clone()),
        };
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry
                .select_model(&selected, cx)
                .map(|configured| configured.model)
        })
    }

    pub fn resume(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        self.messages.push(Message::Resume);
        cx.notify();

        log::debug!("Total messages in thread: {}", self.messages.len());
        self.run_turn(cx)
    }

    /// Sending a message results in the model streaming a response, which could include tool calls.
    /// After calling tools, the model stops and waits for any outstanding tool calls to be completed and their results sent.
    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
    pub fn send<T>(
        &mut self,
        id: UserMessageId,
        content: impl IntoIterator<Item = T>,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
    where
        T: Into<UserMessageContent>,
    {
        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
        log::debug!("Thread::send content: {:?}", content);

        self.messages
            .push(Message::User(UserMessage { id, content }));
        cx.notify();

        self.send_existing(cx)
    }

    pub fn send_existing(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        let model = self.model().context("No language model configured")?;

        log::info!("Thread::send called with model: {}", model.name().0);
        self.advance_prompt_id();

        log::debug!("Total messages in thread: {}", self.messages.len());
        self.run_turn(cx)
    }

    pub fn push_acp_user_block(
        &mut self,
        id: UserMessageId,
        blocks: impl IntoIterator<Item = acp::ContentBlock>,
        path_style: PathStyle,
        cx: &mut Context<Self>,
    ) {
        let content = blocks
            .into_iter()
            .map(|block| UserMessageContent::from_content_block(block, path_style))
            .collect::<Vec<_>>();
        self.messages
            .push(Message::User(UserMessage { id, content }));
        cx.notify();
    }

    pub fn push_acp_agent_block(&mut self, block: acp::ContentBlock, cx: &mut Context<Self>) {
        let text = match block {
            acp::ContentBlock::Text(text_content) => text_content.text,
            acp::ContentBlock::Image(_) => "[image]".to_string(),
            acp::ContentBlock::Audio(_) => "[audio]".to_string(),
            acp::ContentBlock::ResourceLink(resource_link) => resource_link.uri,
            acp::ContentBlock::Resource(resource) => match resource.resource {
                acp::EmbeddedResourceResource::TextResourceContents(resource) => resource.uri,
                acp::EmbeddedResourceResource::BlobResourceContents(resource) => resource.uri,
                _ => "[resource]".to_string(),
            },
            _ => "[unknown]".to_string(),
        };

        self.messages.push(Message::Agent(AgentMessage {
            content: vec![AgentMessageContent::Text(text)],
            ..Default::default()
        }));
        cx.notify();
    }

    #[cfg(feature = "eval")]
    pub fn proceed(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        self.run_turn(cx)
    }

    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        self.cancel(cx);

        let model = self.model.clone().context("No language model configured")?;
        let profile = AgentSettings::get_global(cx)
            .profiles
            .get(&self.profile_id)
            .context("Profile not found")?;
        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
        let event_stream = ThreadEventStream(events_tx);
        let message_ix = self.messages.len().saturating_sub(1);
        self.tool_use_limit_reached = false;
        self.clear_summary();
        self.running_turn = Some(RunningTurn {
            event_stream: event_stream.clone(),
            tools: self.enabled_tools(profile, &model, cx),
            _task: cx.spawn(async move |this, cx| {
                log::debug!("Starting agent turn execution");

                let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await;
                _ = this.update(cx, |this, cx| this.flush_pending_message(cx));

                match turn_result {
                    Ok(()) => {
                        log::debug!("Turn execution completed");
                        event_stream.send_stop(acp::StopReason::EndTurn);
                    }
                    Err(error) => {
                        log::error!("Turn execution failed: {:?}", error);
                        match error.downcast::<CompletionError>() {
                            Ok(CompletionError::Refusal) => {
                                event_stream.send_stop(acp::StopReason::Refusal);
                                _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
                            }
                            Ok(CompletionError::MaxTokens) => {
                                event_stream.send_stop(acp::StopReason::MaxTokens);
                            }
                            Ok(CompletionError::Other(error)) | Err(error) => {
                                event_stream.send_error(error);
                            }
                        }
                    }
                }

                _ = this.update(cx, |this, _| this.running_turn.take());
            }),
        });
        Ok(events_rx)
    }

    async fn run_turn_internal(
        this: &WeakEntity<Self>,
        model: Arc<dyn LanguageModel>,
        event_stream: &ThreadEventStream,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let mut attempt = 0;
        let mut intent = CompletionIntent::UserPrompt;
        loop {
            let request =
                this.update(cx, |this, cx| this.build_completion_request(intent, cx))??;

            telemetry::event!(
                "Agent Thread Completion",
                thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
                prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
                model = model.telemetry_id(),
                model_provider = model.provider_id().to_string(),
                attempt
            );

            log::debug!("Calling model.stream_completion, attempt {}", attempt);

            let (mut events, mut error) = match model.stream_completion(request, cx).await {
                Ok(events) => (events, None),
                Err(err) => (stream::empty().boxed(), Some(err)),
            };
            let mut tool_results = FuturesUnordered::new();
            while let Some(event) = events.next().await {
                log::trace!("Received completion event: {:?}", event);
                match event {
                    Ok(event) => {
                        tool_results.extend(this.update(cx, |this, cx| {
                            this.handle_completion_event(event, event_stream, cx)
                        })??);
                    }
                    Err(err) => {
                        error = Some(err);
                        break;
                    }
                }
            }

            let end_turn = tool_results.is_empty();
            while let Some(tool_result) = tool_results.next().await {
                log::debug!("Tool finished {:?}", tool_result);

                event_stream.update_tool_call_fields(
                    &tool_result.tool_use_id,
                    acp::ToolCallUpdateFields::new()
                        .status(if tool_result.is_error {
                            acp::ToolCallStatus::Failed
                        } else {
                            acp::ToolCallStatus::Completed
                        })
                        .raw_output(tool_result.output.clone()),
                );
                this.update(cx, |this, _cx| {
                    this.pending_message()
                        .tool_results
                        .insert(tool_result.tool_use_id.clone(), tool_result);
                })?;
            }

            this.update(cx, |this, cx| {
                this.flush_pending_message(cx);
                if this.title.is_none() && this.pending_title_generation.is_none() {
                    this.generate_title(cx);
                }
            })?;

            if let Some(error) = error {
                attempt += 1;
                let retry = this.update(cx, |this, cx| {
                    let user_store = this.user_store.read(cx);
                    this.handle_completion_error(error, attempt, user_store.plan())
                })??;
                let timer = cx.background_executor().timer(retry.duration);
                event_stream.send_retry(retry);
                timer.await;
                this.update(cx, |this, _cx| {
                    if let Some(Message::Agent(message)) = this.messages.last() {
                        if message.tool_results.is_empty() {
                            intent = CompletionIntent::UserPrompt;
                            this.messages.push(Message::Resume);
                        }
                    }
                })?;
            } else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
                return Err(language_model::ToolUseLimitReachedError.into());
            } else if end_turn {
                return Ok(());
            } else {
                intent = CompletionIntent::ToolResults;
                attempt = 0;
            }
        }
    }

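    /// Decides whether a failed completion should be retried and, if so,
    /// returns the [`acp_thread::RetryStatus`] describing the delay and
    /// attempt counts.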
    fn handle_completion_error(
        &mut self,
        error: LanguageModelCompletionError,
        attempt: u8,
        plan: Option<Plan>,
    ) -> Result<acp_thread::RetryStatus> {
        let Some(model) = self.model.as_ref() else {
            return Err(anyhow!(error));
        };

        let auto_retry = if model.provider_id() == ZED_CLOUD_PROVIDER_ID {
            match plan {
                Some(Plan::V2(_)) => true,
                Some(Plan::V1(_)) => self.completion_mode == CompletionMode::Burn,
                None => false,
            }
        } else {
            true
        };

        if !auto_retry {
            return Err(anyhow!(error));
        }

        let Some(strategy) = Self::retry_strategy_for(&error) else {
            return Err(anyhow!(error));
        };

        let max_attempts = match &strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
        };

        if attempt > max_attempts {
            return Err(anyhow!(error));
        }

        let delay = match &strategy {
            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
                let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
                Duration::from_secs(delay_secs)
            }
            RetryStrategy::Fixed { delay, .. } => *delay,
        };
        log::debug!("Retry attempt {attempt} with delay {delay:?}");

        Ok(acp_thread::RetryStatus {
            last_error: error.to_string().into(),
            attempt: attempt as usize,
            max_attempts: max_attempts as usize,
            started_at: Instant::now(),
            duration: delay,
        })
    }

    /// A helper method that's called on every streamed completion event.
    /// Returns an optional tool result task, which the main agentic loop will
    /// send back to the model when it resolves.
    fn handle_completion_event(
        &mut self,
        event: LanguageModelCompletionEvent,
        event_stream: &ThreadEventStream,
        cx: &mut Context<Self>,
    ) -> Result<Option<Task<LanguageModelToolResult>>> {
        log::trace!("Handling streamed completion event: {:?}", event);
        use LanguageModelCompletionEvent::*;

        match event {
            StartMessage { .. } => {
                self.flush_pending_message(cx);
                self.pending_message = Some(AgentMessage::default());
            }
            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
            Thinking { text, signature } => {
                self.handle_thinking_event(text, signature, event_stream, cx)
            }
            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
            ReasoningDetails(details) => {
                let last_message = self.pending_message();
                // Store the last non-empty reasoning_details (overwrites earlier ones)
                // This ensures we keep the encrypted reasoning with signatures, not the early text reasoning
                if let serde_json::Value::Array(ref arr) = details {
                    if !arr.is_empty() {
                        last_message.reasoning_details = Some(details);
                    }
                } else {
                    last_message.reasoning_details = Some(details);
                }
            }
            ToolUse(tool_use) => {
                return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
            }
            ToolUseJsonParseError {
                id,
                tool_name,
                raw_input,
                json_parse_error,
            } => {
                return Ok(Some(Task::ready(
                    self.handle_tool_use_json_parse_error_event(
                        id,
                        tool_name,
                        raw_input,
                        json_parse_error,
                    ),
                )));
            }
            UsageUpdate(usage) => {
                telemetry::event!(
                    "Agent Thread Completion Usage Updated",
                    thread_id = self.id.to_string(),
                    prompt_id = self.prompt_id.to_string(),
                    model = self.model.as_ref().map(|m| m.telemetry_id()),
                    model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
                    input_tokens = usage.input_tokens,
                    output_tokens = usage.output_tokens,
                    cache_creation_input_tokens = usage.cache_creation_input_tokens,
                    cache_read_input_tokens = usage.cache_read_input_tokens,
                );
                self.update_token_usage(usage, cx);
            }
            UsageUpdated { amount, limit } => {
                self.update_model_request_usage(amount, limit, cx);
            }
            ToolUseLimitReached => {
                self.tool_use_limit_reached = true;
            }
            Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
            Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
            Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
            Started | Queued { .. } => {}
        }

        Ok(None)
    }

    fn handle_text_event(
        &mut self,
        new_text: String,
        event_stream: &ThreadEventStream,
        cx: &mut Context<Self>,
    ) {
        event_stream.send_text(&new_text);

        let last_message = self.pending_message();
        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
            text.push_str(&new_text);
        } else {
            last_message
                .content
                .push(AgentMessageContent::Text(new_text));
        }

        cx.notify();
    }

    fn handle_thinking_event(
        &mut self,
        new_text: String,
        new_signature: Option<String>,
        event_stream: &ThreadEventStream,
        cx: &mut Context<Self>,
    ) {
        event_stream.send_thinking(&new_text);

        let last_message = self.pending_message();
        if let Some(AgentMessageContent::Thinking { text, signature }) =
            last_message.content.last_mut()
        {
            text.push_str(&new_text);
            *signature = new_signature.or(signature.take());
        } else {
            last_message.content.push(AgentMessageContent::Thinking {
                text: new_text,
                signature: new_signature,
            });
        }

        cx.notify();
    }

    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
        let last_message = self.pending_message();
        last_message
            .content
            .push(AgentMessageContent::RedactedThinking(data));
        cx.notify();
    }

    fn handle_tool_use_event(
        &mut self,
        tool_use: LanguageModelToolUse,
        event_stream: &ThreadEventStream,
        cx: &mut Context<Self>,
    ) -> Option<Task<LanguageModelToolResult>> {
        cx.notify();

        let tool = self.tool(tool_use.name.as_ref());
        let mut title = SharedString::from(&tool_use.name);
        let mut kind = acp::ToolKind::Other;
        if let Some(tool) = tool.as_ref() {
            title = tool.initial_title(tool_use.input.clone(), cx);
            kind = tool.kind();
        }

        // Ensure the last message ends in the current tool use
        let last_message = self.pending_message();
        let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
            if let AgentMessageContent::ToolUse(last_tool_use) = content {
                if last_tool_use.id == tool_use.id {
                    *last_tool_use = tool_use.clone();
                    false
                } else {
                    true
                }
            } else {
                true
            }
        });

        if push_new_tool_use {
            event_stream.send_tool_call(
                &tool_use.id,
                &tool_use.name,
                title,
                kind,
                tool_use.input.clone(),
            );
            last_message
                .content
                .push(AgentMessageContent::ToolUse(tool_use.clone()));
        } else {
            event_stream.update_tool_call_fields(
                &tool_use.id,
                acp::ToolCallUpdateFields::new()
                    .title(title.as_str())
                    .kind(kind)
                    .raw_input(tool_use.input.clone()),
            );
        }

        if !tool_use.is_input_complete {
            return None;
        }

        let Some(tool) = tool else {
            let content = format!("No tool named {} exists", tool_use.name);
            return Some(Task::ready(LanguageModelToolResult {
                content: LanguageModelToolResultContent::Text(Arc::from(content)),
                tool_use_id: tool_use.id,
                tool_name: tool_use.name,
                is_error: true,
                output: None,
            }));
        };

        let fs = self.project.read(cx).fs().clone();
        let tool_event_stream =
            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
        tool_event_stream.update_fields(
            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress),
        );
        let supports_images = self.model().is_some_and(|model| model.supports_images());
        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
        log::debug!("Running tool {}", tool_use.name);
        Some(cx.foreground_executor().spawn(async move {
            let tool_result = tool_result.await.and_then(|output| {
                if let LanguageModelToolResultContent::Image(_) = &output.llm_output
                    && !supports_images
                {
                    return Err(anyhow!(
                        "Attempted to read an image, but this model doesn't support it.",
                    ));
                }
                Ok(output)
            });

            match tool_result {
                Ok(output) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: false,
                    content: output.llm_output,
                    output: Some(output.raw_output),
                },
                Err(error) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: true,
                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
                    output: Some(error.to_string().into()),
                },
            }
        }))
    }

    fn handle_tool_use_json_parse_error_event(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    ) -> LanguageModelToolResult {
        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
        LanguageModelToolResult {
            tool_use_id,
            tool_name,
            is_error: true,
            content: LanguageModelToolResultContent::Text(tool_output.into()),
            output: Some(serde_json::Value::String(raw_input.to_string())),
        }
    }

    fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project
            .read(cx)
            .user_store()
            .update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            });
    }

    pub fn title(&self) -> SharedString {
        self.title.clone().unwrap_or("New Thread".into())
    }

    pub fn is_generating_summary(&self) -> bool {
        self.pending_summary_generation.is_some()
    }

    pub fn is_generating_title(&self) -> bool {
        self.pending_title_generation.is_some()
    }

    pub fn summary(&mut self, cx: &mut Context<Self>) -> Shared<Task<Option<SharedString>>> {
        if let Some(summary) = self.summary.as_ref() {
            return Task::ready(Some(summary.clone())).shared();
        }
        if let Some(task) = self.pending_summary_generation.clone() {
            return task;
        }
        let Some(model) = self.summarization_model.clone() else {
            log::error!("No summarization model available");
            return Task::ready(None).shared();
        };
        let mut request = LanguageModelRequest {
            intent: Some(CompletionIntent::ThreadContextSummarization),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            ..Default::default()
        };

        for message in &self.messages {
            request.messages.extend(message.to_request());
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()],
            cache: false,
            reasoning_details: None,
        });

        let task = cx
            .spawn(async move |this, cx| {
                let mut summary = String::new();
                let mut messages = model.stream_completion(request, cx).await.log_err()?;
                while let Some(event) = messages.next().await {
                    let event = event.log_err()?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
                            this.update(cx, |thread, cx| {
                                thread.update_model_request_usage(amount, limit, cx);
                            })
                            .ok()?;
                            continue;
                        }
                        _ => continue,
                    };

                    let mut lines = text.lines();
                    summary.extend(lines.next());
                }

                log::debug!("Setting summary: {}", summary);
                let summary = SharedString::from(summary);

                this.update(cx, |this, cx| {
                    this.summary = Some(summary.clone());
                    this.pending_summary_generation = None;
                    cx.notify()
                })
                .ok()?;

                Some(summary)
            })
            .shared();
        self.pending_summary_generation = Some(task.clone());
        task
    }

    pub fn generate_title(&mut self, cx: &mut Context<Self>) {
        let Some(model) = self.summarization_model.clone() else {
            return;
        };

        log::debug!(
            "Generating title with model: {:?}",
            self.summarization_model.as_ref().map(|model| model.name())
        );
        let mut request = LanguageModelRequest {
            intent: Some(CompletionIntent::ThreadSummarization),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            ..Default::default()
        };

        for message in &self.messages {
            request.messages.extend(message.to_request());
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![SUMMARIZE_THREAD_PROMPT.into()],
            cache: false,
            reasoning_details: None,
        });
        self.pending_title_generation = Some(cx.spawn(async move |this, cx| {
            let mut title = String::new();

            let generate = async {
                let mut messages = model.stream_completion(request, cx).await?;
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
                            this.update(cx, |thread, cx| {
                                thread.update_model_request_usage(amount, limit, cx);
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    let mut lines = text.lines();
                    title.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }
                anyhow::Ok(())
            };

            if generate.await.context("failed to generate title").is_ok() {
                _ = this.update(cx, |this, cx| this.set_title(title.into(), cx));
            }
            _ = this.update(cx, |this, _| this.pending_title_generation = None);
        }));
    }

    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) {
        self.pending_title_generation = None;
        if Some(&title) != self.title.as_ref() {
            self.title = Some(title);
            cx.emit(TitleUpdated);
            cx.notify();
        }
    }

    fn clear_summary(&mut self) {
        self.summary = None;
        self.pending_summary_generation = None;
    }

    fn last_user_message(&self) -> Option<&UserMessage> {
        self.messages
            .iter()
            .rev()
            .find_map(|message| match message {
                Message::User(user_message) => Some(user_message),
                Message::Agent(_) => None,
                Message::Resume => None,
            })
    }

    fn pending_message(&mut self) -> &mut AgentMessage {
        self.pending_message.get_or_insert_default()
    }

1899 fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1900 let Some(mut message) = self.pending_message.take() else {
1901 return;
1902 };
1903
1904 if message.content.is_empty() {
1905 return;
1906 }
1907
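        // Any tool uses that never produced a result were interrupted (e.g. canceled);
        // record an error result for each one so every tool use has a matching result.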
1908 for content in &message.content {
1909 let AgentMessageContent::ToolUse(tool_use) = content else {
1910 continue;
1911 };
1912
1913 if !message.tool_results.contains_key(&tool_use.id) {
1914 message.tool_results.insert(
1915 tool_use.id.clone(),
1916 LanguageModelToolResult {
1917 tool_use_id: tool_use.id.clone(),
1918 tool_name: tool_use.name.clone(),
1919 is_error: true,
1920 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1921 output: None,
1922 },
1923 );
1924 }
1925 }
1926
1927 self.messages.push(Message::Agent(message));
1928 self.updated_at = Utc::now();
1929 self.clear_summary();
1930 cx.notify()
1931 }
1932
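    /// Assembles a [`LanguageModelRequest`] from the thread's messages, the tools
    /// enabled for the currently running turn (if any), and the configured
    /// completion mode and temperature.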
1933 pub(crate) fn build_completion_request(
1934 &self,
1935 completion_intent: CompletionIntent,
1936 cx: &App,
1937 ) -> Result<LanguageModelRequest> {
1938 let model = self.model().context("No language model configured")?;
1939 let tools = if let Some(turn) = self.running_turn.as_ref() {
1940 turn.tools
1941 .iter()
1942 .filter_map(|(tool_name, tool)| {
1943 log::trace!("Including tool: {}", tool_name);
1944 Some(LanguageModelRequestTool {
1945 name: tool_name.to_string(),
1946 description: tool.description().to_string(),
1947 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1948 })
1949 })
1950 .collect::<Vec<_>>()
1951 } else {
1952 Vec::new()
1953 };
1954
1955 log::debug!("Building completion request");
1956 log::debug!("Completion intent: {:?}", completion_intent);
1957 log::debug!("Completion mode: {:?}", self.completion_mode);
1958
1959 let available_tools: Vec<_> = self
1960 .running_turn
1961 .as_ref()
1962 .map(|turn| turn.tools.keys().cloned().collect())
1963 .unwrap_or_default();
1964
1965 log::debug!("Request includes {} tools", available_tools.len());
1966 let messages = self.build_request_messages(available_tools, cx);
1967 log::debug!("Request will include {} messages", messages.len());
1968
1969 let request = LanguageModelRequest {
1970 thread_id: Some(self.id.to_string()),
1971 prompt_id: Some(self.prompt_id.to_string()),
1972 intent: Some(completion_intent),
1973 mode: Some(self.completion_mode.into()),
1974 messages,
1975 tools,
1976 tool_choice: None,
1977 stop: Vec::new(),
1978 temperature: AgentSettings::temperature_for_model(model, cx),
1979 thinking_allowed: true,
1980 };
1981
1982 log::debug!("Completion request built successfully");
1983 Ok(request)
1984 }
1985
1986 fn enabled_tools(
1987 &self,
1988 profile: &AgentProfileSettings,
1989 model: &Arc<dyn LanguageModel>,
1990 cx: &App,
1991 ) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
1992 fn truncate(tool_name: &SharedString) -> SharedString {
1993 if tool_name.len() > MAX_TOOL_NAME_LENGTH {
1994 let mut truncated = tool_name.to_string();
1995 truncated.truncate(MAX_TOOL_NAME_LENGTH);
1996 truncated.into()
1997 } else {
1998 tool_name.clone()
1999 }
2000 }
2001
2002 let mut tools = self
2003 .tools
2004 .iter()
2005 .filter_map(|(tool_name, tool)| {
2006 if tool.supports_provider(&model.provider_id())
2007 && profile.is_tool_enabled(tool_name)
2008 {
2009 Some((truncate(tool_name), tool.clone()))
2010 } else {
2011 None
2012 }
2013 })
2014 .collect::<BTreeMap<_, _>>();
2015
2016 let mut context_server_tools = Vec::new();
2017 let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
2018 let mut duplicate_tool_names = HashSet::default();
2019 for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
2020 for (tool_name, tool) in server_tools {
2021 if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
2022 let tool_name = truncate(tool_name);
2023 if !seen_tools.insert(tool_name.clone()) {
2024 duplicate_tool_names.insert(tool_name.clone());
2025 }
2026 context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
2027 }
2028 }
2029 }
2030
2031 // When there are duplicate tool names, disambiguate by prefixing them
2032 // with the server ID. In the rare case there isn't enough space for the
2033 // disambiguated tool name, keep only the last tool with this name.
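        // For example, if servers `foo` and `bar` both expose a `search` tool, the
        // request ends up with `foo_search` and `bar_search` (with the server prefix
        // truncated as needed to stay within MAX_TOOL_NAME_LENGTH).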
2034 for (server_id, tool_name, tool) in context_server_tools {
2035 if duplicate_tool_names.contains(&tool_name) {
2036 let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
2037 if available >= 2 {
2038 let mut disambiguated = server_id.0.to_string();
2039 disambiguated.truncate(available - 1);
2040 disambiguated.push('_');
2041 disambiguated.push_str(&tool_name);
2042 tools.insert(disambiguated.into(), tool.clone());
2043 } else {
2044 tools.insert(tool_name, tool.clone());
2045 }
2046 } else {
2047 tools.insert(tool_name, tool.clone());
2048 }
2049 }
2050
2051 tools
2052 }
2053
2054 fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
2055 self.running_turn.as_ref()?.tools.get(name).cloned()
2056 }
2057
2058 pub fn has_tool(&self, name: &str) -> bool {
2059 self.running_turn
2060 .as_ref()
2061 .is_some_and(|turn| turn.tools.contains_key(name))
2062 }
2063
2064 fn build_request_messages(
2065 &self,
2066 available_tools: Vec<SharedString>,
2067 cx: &App,
2068 ) -> Vec<LanguageModelRequestMessage> {
2069 log::trace!(
2070 "Building request messages from {} thread messages",
2071 self.messages.len()
2072 );
2073
2074 let system_prompt = SystemPromptTemplate {
2075 project: self.project_context.read(cx),
2076 available_tools,
2077 model_name: self.model.as_ref().map(|m| m.name().0.to_string()),
2078 }
2079 .render(&self.templates)
2080 .context("failed to build system prompt")
2081 .expect("Invalid template");
2082 let mut messages = vec![LanguageModelRequestMessage {
2083 role: Role::System,
2084 content: vec![system_prompt.into()],
2085 cache: false,
2086 reasoning_details: None,
2087 }];
2088 for message in &self.messages {
2089 messages.extend(message.to_request());
2090 }
2091
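        // Flag the most recent persisted message for caching so providers that
        // support prompt caching can reuse the conversation prefix.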
2092 if let Some(last_message) = messages.last_mut() {
2093 last_message.cache = true;
2094 }
2095
2096 if let Some(message) = self.pending_message.as_ref() {
2097 messages.extend(message.to_request());
2098 }
2099
2100 messages
2101 }
2102
2103 pub fn to_markdown(&self) -> String {
2104 let mut markdown = String::new();
2105 for (ix, message) in self.messages.iter().enumerate() {
2106 if ix > 0 {
2107 markdown.push('\n');
2108 }
2109 markdown.push_str(&message.to_markdown());
2110 }
2111
2112 if let Some(message) = self.pending_message.as_ref() {
2113 markdown.push('\n');
2114 markdown.push_str(&message.to_markdown());
2115 }
2116
2117 markdown
2118 }
2119
2120 fn advance_prompt_id(&mut self) {
2121 self.prompt_id = PromptId::new();
2122 }
2123
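    /// Maps a completion error to a retry strategy, or `None` when retrying cannot
    /// help (for example authentication failures or prompts that are too large).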
2124 fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
2125 use LanguageModelCompletionError::*;
2126 use http_client::StatusCode;
2127
2128 // General strategy here:
2129 // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
2130 // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
2131 // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
2132 match error {
2133 HttpResponseError {
2134 status_code: StatusCode::TOO_MANY_REQUESTS,
2135 ..
2136 } => Some(RetryStrategy::ExponentialBackoff {
2137 initial_delay: BASE_RETRY_DELAY,
2138 max_attempts: MAX_RETRY_ATTEMPTS,
2139 }),
2140 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
2141 Some(RetryStrategy::Fixed {
2142 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2143 max_attempts: MAX_RETRY_ATTEMPTS,
2144 })
2145 }
2146 UpstreamProviderError {
2147 status,
2148 retry_after,
2149 ..
2150 } => match *status {
2151 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
2152 Some(RetryStrategy::Fixed {
2153 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2154 max_attempts: MAX_RETRY_ATTEMPTS,
2155 })
2156 }
2157 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
2158 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2159 // Internal Server Error could be anything, retry up to 3 times.
2160 max_attempts: 3,
2161 }),
2162 status => {
2163 // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
2164 // but we frequently get them in practice. See https://http.dev/529
2165 if status.as_u16() == 529 {
2166 Some(RetryStrategy::Fixed {
2167 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2168 max_attempts: MAX_RETRY_ATTEMPTS,
2169 })
2170 } else {
2171 Some(RetryStrategy::Fixed {
2172 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2173 max_attempts: 2,
2174 })
2175 }
2176 }
2177 },
2178 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
2179 delay: BASE_RETRY_DELAY,
2180 max_attempts: 3,
2181 }),
2182 ApiReadResponseError { .. }
2183 | HttpSend { .. }
2184 | DeserializeResponse { .. }
2185 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
2186 delay: BASE_RETRY_DELAY,
2187 max_attempts: 3,
2188 }),
2189 // Retrying these errors definitely shouldn't help.
2190 HttpResponseError {
2191 status_code:
2192 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
2193 ..
2194 }
2195 | AuthenticationError { .. }
2196 | PermissionError { .. }
2197 | NoApiKey { .. }
2198 | ApiEndpointNotFound { .. }
2199 | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them once.
2201 SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
2202 delay: BASE_RETRY_DELAY,
2203 max_attempts: 1,
2204 }),
            // Retry all other 4xx and 5xx errors up to 3 times.
2206 HttpResponseError { status_code, .. }
2207 if status_code.is_client_error() || status_code.is_server_error() =>
2208 {
2209 Some(RetryStrategy::Fixed {
2210 delay: BASE_RETRY_DELAY,
2211 max_attempts: 3,
2212 })
2213 }
2214 Other(err)
2215 if err.is::<language_model::PaymentRequiredError>()
2216 || err.is::<language_model::ModelRequestLimitReachedError>() =>
2217 {
2218 // Retrying won't help for Payment Required or Model Request Limit errors (where
2219 // the user must upgrade to usage-based billing to get more requests, or else wait
2220 // for a significant amount of time for the request limit to reset).
2221 None
2222 }
            // Conservatively retry any other errors a couple of times, in case they're transient.
2224 HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
2225 delay: BASE_RETRY_DELAY,
2226 max_attempts: 2,
2227 }),
2228 }
2229 }
2230}
2231
2232struct RunningTurn {
2233 /// Holds the task that handles agent interaction until the end of the turn.
2234 /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
2236 _task: Task<()>,
2237 /// The current event stream for the running turn. Used to report a final
2238 /// cancellation event if we cancel the turn.
2239 event_stream: ThreadEventStream,
2240 /// The tools that were enabled for this turn.
2241 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
2242}
2243
2244impl RunningTurn {
2245 fn cancel(self) {
2246 log::debug!("Cancelling in progress turn");
2247 self.event_stream.send_canceled();
2248 }
2249}
2250
2251pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);
2252
2253impl EventEmitter<TokenUsageUpdated> for Thread {}
2254
2255pub struct TitleUpdated;
2256
2257impl EventEmitter<TitleUpdated> for Thread {}
2258
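/// A strongly typed tool that the agent can invoke during a turn.
///
/// Implementors declare a serializable `Input`/`Output` pair plus metadata (name,
/// kind, title) and are type-erased into an [`AnyAgentTool`] via [`AgentTool::erase`]
/// before being registered on a thread.
///
/// A minimal sketch of an implementation. `EchoTool`, its input type, the chosen
/// `ToolKind`, and the use of `String` as the output are illustrative assumptions,
/// not part of this crate:
///
/// ```ignore
/// #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
/// struct EchoToolInput {
///     /// The text to echo back to the model.
///     text: String,
/// }
///
/// struct EchoTool;
///
/// impl AgentTool for EchoTool {
///     type Input = EchoToolInput;
///     type Output = String;
///
///     fn name() -> &'static str {
///         "echo"
///     }
///
///     fn kind() -> acp::ToolKind {
///         acp::ToolKind::Other
///     }
///
///     fn initial_title(
///         &self,
///         _input: Result<Self::Input, serde_json::Value>,
///         _cx: &mut App,
///     ) -> SharedString {
///         "Echo".into()
///     }
///
///     fn run(
///         self: Arc<Self>,
///         input: Self::Input,
///         _event_stream: ToolCallEventStream,
///         _cx: &mut App,
///     ) -> Task<Result<Self::Output>> {
///         Task::ready(Ok(input.text))
///     }
/// }
/// ```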
2259pub trait AgentTool
2260where
2261 Self: 'static + Sized,
2262{
2263 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
2264 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
2265
2266 fn name() -> &'static str;
2267
2268 fn description() -> SharedString {
2269 let schema = schemars::schema_for!(Self::Input);
2270 SharedString::new(
2271 schema
2272 .get("description")
2273 .and_then(|description| description.as_str())
2274 .unwrap_or_default(),
2275 )
2276 }
2277
2278 fn kind() -> acp::ToolKind;
2279
2280 /// The initial tool title to display. Can be updated during the tool run.
2281 fn initial_title(
2282 &self,
2283 input: Result<Self::Input, serde_json::Value>,
2284 cx: &mut App,
2285 ) -> SharedString;
2286
2287 /// Returns the JSON schema that describes the tool's input.
2288 fn input_schema(format: LanguageModelToolSchemaFormat) -> Schema {
2289 language_model::tool_schema::root_schema_for::<Self::Input>(format)
2290 }
2291
    /// Some tools depend on a specific provider, e.g. for billing purposes.
    /// This lets a tool declare whether it supports the given provider; incompatible
    /// tools are filtered out of the request.
2294 fn supports_provider(_provider: &LanguageModelProviderId) -> bool {
2295 true
2296 }
2297
2298 /// Runs the tool with the provided input.
2299 fn run(
2300 self: Arc<Self>,
2301 input: Self::Input,
2302 event_stream: ToolCallEventStream,
2303 cx: &mut App,
2304 ) -> Task<Result<Self::Output>>;
2305
2306 /// Emits events for a previous execution of the tool.
2307 fn replay(
2308 &self,
2309 _input: Self::Input,
2310 _output: Self::Output,
2311 _event_stream: ToolCallEventStream,
2312 _cx: &mut App,
2313 ) -> Result<()> {
2314 Ok(())
2315 }
2316
2317 fn erase(self) -> Arc<dyn AnyAgentTool> {
2318 Arc::new(Erased(Arc::new(self)))
2319 }
2320}
2321
2322pub struct Erased<T>(T);
2323
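/// The result of running an [`AnyAgentTool`]: `llm_output` is sent back to the
/// language model as the tool result, while `raw_output` preserves the tool's full
/// structured output.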
2324pub struct AgentToolOutput {
2325 pub llm_output: LanguageModelToolResultContent,
2326 pub raw_output: serde_json::Value,
2327}
2328
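/// An object-safe, type-erased counterpart to [`AgentTool`] that operates on raw
/// `serde_json::Value`s. Obtained by calling [`AgentTool::erase`].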
2329pub trait AnyAgentTool {
2330 fn name(&self) -> SharedString;
2331 fn description(&self) -> SharedString;
2332 fn kind(&self) -> acp::ToolKind;
2333 fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString;
2334 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
2335 fn supports_provider(&self, _provider: &LanguageModelProviderId) -> bool {
2336 true
2337 }
2338 fn run(
2339 self: Arc<Self>,
2340 input: serde_json::Value,
2341 event_stream: ToolCallEventStream,
2342 cx: &mut App,
2343 ) -> Task<Result<AgentToolOutput>>;
2344 fn replay(
2345 &self,
2346 input: serde_json::Value,
2347 output: serde_json::Value,
2348 event_stream: ToolCallEventStream,
2349 cx: &mut App,
2350 ) -> Result<()>;
2351}
2352
2353impl<T> AnyAgentTool for Erased<Arc<T>>
2354where
2355 T: AgentTool,
2356{
2357 fn name(&self) -> SharedString {
2358 T::name().into()
2359 }
2360
2361 fn description(&self) -> SharedString {
2362 T::description()
2363 }
2364
2365 fn kind(&self) -> agent_client_protocol::ToolKind {
2366 T::kind()
2367 }
2368
2369 fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString {
2370 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
2371 self.0.initial_title(parsed_input, _cx)
2372 }
2373
2374 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
2375 let mut json = serde_json::to_value(T::input_schema(format))?;
2376 language_model::tool_schema::adapt_schema_to_format(&mut json, format)?;
2377 Ok(json)
2378 }
2379
2380 fn supports_provider(&self, provider: &LanguageModelProviderId) -> bool {
2381 T::supports_provider(provider)
2382 }
2383
2384 fn run(
2385 self: Arc<Self>,
2386 input: serde_json::Value,
2387 event_stream: ToolCallEventStream,
2388 cx: &mut App,
2389 ) -> Task<Result<AgentToolOutput>> {
2390 cx.spawn(async move |cx| {
2391 let input = serde_json::from_value(input)?;
2392 let output = cx
2393 .update(|cx| self.0.clone().run(input, event_stream, cx))?
2394 .await?;
2395 let raw_output = serde_json::to_value(&output)?;
2396 Ok(AgentToolOutput {
2397 llm_output: output.into(),
2398 raw_output,
2399 })
2400 })
2401 }
2402
2403 fn replay(
2404 &self,
2405 input: serde_json::Value,
2406 output: serde_json::Value,
2407 event_stream: ToolCallEventStream,
2408 cx: &mut App,
2409 ) -> Result<()> {
2410 let input = serde_json::from_value(input)?;
2411 let output = serde_json::from_value(output)?;
2412 self.0.replay(input, output, event_stream, cx)
2413 }
2414}
2415
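/// Sender half of the event channel for a running turn. Every user-visible update
/// (user messages, agent text and thinking, tool calls, retries, stop reasons)
/// flows through it as a [`ThreadEvent`].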
2416#[derive(Clone)]
2417struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
2418
2419impl ThreadEventStream {
2420 fn send_user_message(&self, message: &UserMessage) {
2421 self.0
2422 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
2423 .ok();
2424 }
2425
2426 fn send_text(&self, text: &str) {
2427 self.0
2428 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
2429 .ok();
2430 }
2431
2432 fn send_thinking(&self, text: &str) {
2433 self.0
2434 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2435 .ok();
2436 }
2437
2438 fn send_tool_call(
2439 &self,
2440 id: &LanguageModelToolUseId,
2441 tool_name: &str,
2442 title: SharedString,
2443 kind: acp::ToolKind,
2444 input: serde_json::Value,
2445 ) {
2446 self.0
2447 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2448 id,
2449 tool_name,
2450 title.to_string(),
2451 kind,
2452 input,
2453 ))))
2454 .ok();
2455 }
2456
2457 fn initial_tool_call(
2458 id: &LanguageModelToolUseId,
2459 tool_name: &str,
2460 title: String,
2461 kind: acp::ToolKind,
2462 input: serde_json::Value,
2463 ) -> acp::ToolCall {
2464 acp::ToolCall::new(id.to_string(), title)
2465 .kind(kind)
2466 .raw_input(input)
2467 .meta(acp::Meta::from_iter([(
2468 "tool_name".into(),
2469 tool_name.into(),
2470 )]))
2471 }
2472
2473 fn update_tool_call_fields(
2474 &self,
2475 tool_use_id: &LanguageModelToolUseId,
2476 fields: acp::ToolCallUpdateFields,
2477 ) {
2478 self.0
2479 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2480 acp::ToolCallUpdate::new(tool_use_id.to_string(), fields).into(),
2481 )))
2482 .ok();
2483 }
2484
2485 fn send_retry(&self, status: acp_thread::RetryStatus) {
2486 self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2487 }
2488
2489 fn send_stop(&self, reason: acp::StopReason) {
2490 self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
2491 }
2492
2493 fn send_canceled(&self) {
2494 self.0
2495 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
2496 .ok();
2497 }
2498
2499 fn send_error(&self, error: impl Into<anyhow::Error>) {
2500 self.0.unbounded_send(Err(error.into())).ok();
2501 }
2502}
2503
2504#[derive(Clone)]
2505pub struct ToolCallEventStream {
2506 tool_use_id: LanguageModelToolUseId,
2507 stream: ThreadEventStream,
2508 fs: Option<Arc<dyn Fs>>,
2509}
2510
2511impl ToolCallEventStream {
2512 #[cfg(any(test, feature = "test-support"))]
2513 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2514 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2515
2516 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2517
2518 (stream, ToolCallEventStreamReceiver(events_rx))
2519 }
2520
2521 fn new(
2522 tool_use_id: LanguageModelToolUseId,
2523 stream: ThreadEventStream,
2524 fs: Option<Arc<dyn Fs>>,
2525 ) -> Self {
2526 Self {
2527 tool_use_id,
2528 stream,
2529 fs,
2530 }
2531 }
2532
2533 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2534 self.stream
2535 .update_tool_call_fields(&self.tool_use_id, fields);
2536 }
2537
2538 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2539 self.stream
2540 .0
2541 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2542 acp_thread::ToolCallUpdateDiff {
2543 id: acp::ToolCallId::new(self.tool_use_id.to_string()),
2544 diff,
2545 }
2546 .into(),
2547 )))
2548 .ok();
2549 }
2550
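    /// Asks the user for permission to proceed by emitting a
    /// [`ThreadEvent::ToolCallAuthorization`]. The returned task resolves
    /// immediately when `always_allow_tool_actions` is enabled, and resolves with
    /// an error if the user denies the request.
    ///
    /// A sketch of how a hypothetical tool's `run` might gate its work on
    /// authorization (the title, output type, and surrounding tool are assumptions):
    ///
    /// ```ignore
    /// fn run(
    ///     self: Arc<Self>,
    ///     _input: Self::Input,
    ///     event_stream: ToolCallEventStream,
    ///     cx: &mut App,
    /// ) -> Task<Result<Self::Output>> {
    ///     let authorize = event_stream.authorize("Delete scratch file", cx);
    ///     cx.spawn(async move |_cx| {
    ///         authorize.await?;
    ///         // ...perform the privileged action...
    ///         Ok("done".to_string())
    ///     })
    /// }
    /// ```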
2551 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2552 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2553 return Task::ready(Ok(()));
2554 }
2555
2556 let (response_tx, response_rx) = oneshot::channel();
2557 self.stream
2558 .0
2559 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2560 ToolCallAuthorization {
2561 tool_call: acp::ToolCallUpdate::new(
2562 self.tool_use_id.to_string(),
2563 acp::ToolCallUpdateFields::new().title(title.into()),
2564 ),
2565 options: vec![
2566 acp::PermissionOption::new(
2567 acp::PermissionOptionId::new("always_allow"),
2568 "Always Allow",
2569 acp::PermissionOptionKind::AllowAlways,
2570 ),
2571 acp::PermissionOption::new(
2572 acp::PermissionOptionId::new("allow"),
2573 "Allow",
2574 acp::PermissionOptionKind::AllowOnce,
2575 ),
2576 acp::PermissionOption::new(
2577 acp::PermissionOptionId::new("deny"),
2578 "Deny",
2579 acp::PermissionOptionKind::RejectOnce,
2580 ),
2581 ],
2582 response: response_tx,
2583 },
2584 )))
2585 .ok();
2586 let fs = self.fs.clone();
2587 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2588 "always_allow" => {
2589 if let Some(fs) = fs.clone() {
2590 cx.update(|cx| {
2591 update_settings_file(fs, cx, |settings, _| {
2592 settings
2593 .agent
2594 .get_or_insert_default()
2595 .set_always_allow_tool_actions(true);
2596 });
2597 })?;
2598 }
2599
2600 Ok(())
2601 }
2602 "allow" => Ok(()),
2603 _ => Err(anyhow!("Permission to run tool denied by user")),
2604 })
2605 }
2606}
2607
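/// Receiving half of [`ToolCallEventStream::test`], used in tests to assert on the
/// events a tool emits.
///
/// A sketch of typical usage inside a test (the tool, its input, and the test
/// context are hypothetical stand-ins):
///
/// ```ignore
/// let (event_stream, mut events) = ToolCallEventStream::test();
/// let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
/// let _auth = events.expect_authorization().await;
/// // ...respond to the authorization, then assert on subsequent updates...
/// let _fields = events.expect_update_fields().await;
/// task.await.unwrap();
/// ```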
2608#[cfg(any(test, feature = "test-support"))]
2609pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2610
2611#[cfg(any(test, feature = "test-support"))]
2612impl ToolCallEventStreamReceiver {
2613 pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
2614 let event = self.0.next().await;
2615 if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event {
2616 auth
2617 } else {
2618 panic!("Expected ToolCallAuthorization but got: {:?}", event);
2619 }
2620 }
2621
2622 pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
2623 let event = self.0.next().await;
2624 if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2625 update,
2626 )))) = event
2627 {
2628 update.fields
2629 } else {
2630 panic!("Expected update fields but got: {:?}", event);
2631 }
2632 }
2633
2634 pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
2635 let event = self.0.next().await;
2636 if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
2637 update,
2638 )))) = event
2639 {
2640 update.diff
2641 } else {
2642 panic!("Expected diff but got: {:?}", event);
2643 }
2644 }
2645
2646 pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
2647 let event = self.0.next().await;
2648 if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
2649 update,
2650 )))) = event
2651 {
2652 update.terminal
2653 } else {
2654 panic!("Expected terminal but got: {:?}", event);
2655 }
2656 }
2657}
2658
2659#[cfg(any(test, feature = "test-support"))]
2660impl std::ops::Deref for ToolCallEventStreamReceiver {
2661 type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;
2662
2663 fn deref(&self) -> &Self::Target {
2664 &self.0
2665 }
2666}
2667
2668#[cfg(any(test, feature = "test-support"))]
2669impl std::ops::DerefMut for ToolCallEventStreamReceiver {
2670 fn deref_mut(&mut self) -> &mut Self::Target {
2671 &mut self.0
2672 }
2673}
2674
2675impl From<&str> for UserMessageContent {
2676 fn from(text: &str) -> Self {
2677 Self::Text(text.into())
2678 }
2679}
2680
2681impl UserMessageContent {
2682 pub fn from_content_block(value: acp::ContentBlock, path_style: PathStyle) -> Self {
2683 match value {
2684 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2685 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2686 acp::ContentBlock::Audio(_) => {
2687 // TODO
2688 Self::Text("[audio]".to_string())
2689 }
2690 acp::ContentBlock::ResourceLink(resource_link) => {
2691 match MentionUri::parse(&resource_link.uri, path_style) {
2692 Ok(uri) => Self::Mention {
2693 uri,
2694 content: String::new(),
2695 },
2696 Err(err) => {
2697 log::error!("Failed to parse mention link: {}", err);
2698 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2699 }
2700 }
2701 }
2702 acp::ContentBlock::Resource(resource) => match resource.resource {
2703 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2704 match MentionUri::parse(&resource.uri, path_style) {
2705 Ok(uri) => Self::Mention {
2706 uri,
2707 content: resource.text,
2708 },
2709 Err(err) => {
2710 log::error!("Failed to parse mention link: {}", err);
2711 Self::Text(
2712 MarkdownCodeBlock {
2713 tag: &resource.uri,
2714 text: &resource.text,
2715 }
2716 .to_string(),
2717 )
2718 }
2719 }
2720 }
2721 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2722 // TODO
2723 Self::Text("[blob]".to_string())
2724 }
2725 other => {
2726 log::warn!("Unexpected content type: {:?}", other);
2727 Self::Text("[unknown]".to_string())
2728 }
2729 },
2730 other => {
2731 log::warn!("Unexpected content type: {:?}", other);
2732 Self::Text("[unknown]".to_string())
2733 }
2734 }
2735 }
2736}
2737
2738impl From<UserMessageContent> for acp::ContentBlock {
2739 fn from(content: UserMessageContent) -> Self {
2740 match content {
2741 UserMessageContent::Text(text) => text.into(),
2742 UserMessageContent::Image(image) => {
2743 acp::ContentBlock::Image(acp::ImageContent::new(image.source, "image/png"))
2744 }
2745 UserMessageContent::Mention { uri, content } => acp::ContentBlock::Resource(
2746 acp::EmbeddedResource::new(acp::EmbeddedResourceResource::TextResourceContents(
2747 acp::TextResourceContents::new(content, uri.to_uri().to_string()),
2748 )),
2749 ),
2750 }
2751 }
2752}
2753
2754fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2755 LanguageModelImage {
2756 source: image_content.data.into(),
2757 size: None,
2758 }
2759}