1use crate::{
2 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
3 DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
4 ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
5 Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
6};
7use acp_thread::{MentionUri, UserMessageId};
8use action_log::ActionLog;
9use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
10use agent_client_protocol as acp;
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
12use anyhow::{Context as _, Result, anyhow};
13use assistant_tool::adapt_schema_to_format;
14use chrono::{DateTime, Utc};
15use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
16use collections::{HashMap, IndexMap};
17use fs::Fs;
18use futures::{
19 FutureExt,
20 channel::{mpsc, oneshot},
21 future::Shared,
22 stream::FuturesUnordered,
23};
24use git::repository::DiffType;
25use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
26use language_model::{
27 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
28 LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
29 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
30 LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
31 LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
32};
33use project::{
34 Project,
35 git_store::{GitStore, RepositoryState},
36};
37use prompt_store::ProjectContext;
38use schemars::{JsonSchema, Schema};
39use serde::{Deserialize, Serialize};
40use settings::{Settings, update_settings_file};
41use smol::stream::StreamExt;
42use std::{
43 collections::BTreeMap,
44 path::Path,
45 sync::Arc,
46 time::{Duration, Instant},
47};
48use std::{fmt::Write, ops::Range};
49use util::{ResultExt, markdown::MarkdownCodeBlock};
50use uuid::Uuid;
51
/// Message text used to report that a tool call was canceled by the user.
/// NOTE(review): its consumers are outside this view.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
53
54/// The ID of the user prompt that initiated a request.
55///
56/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
57#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
58pub struct PromptId(Arc<str>);
59
60impl PromptId {
61 pub fn new() -> Self {
62 Self(Uuid::new_v4().to_string().into())
63 }
64}
65
66impl std::fmt::Display for PromptId {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "{}", self.0)
69 }
70}
71
/// Upper bound on retry attempts.
/// NOTE(review): presumably applied to failed completion requests; usage is outside this view.
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay used when spacing out retries.
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);

/// How retries of a failed request are spaced out.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Delays grow exponentially starting from `initial_delay`, for at most `max_attempts` tries.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Every retry waits the same `delay`, for at most `max_attempts` tries.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
86
87#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
88pub enum Message {
89 User(UserMessage),
90 Agent(AgentMessage),
91 Resume,
92}
93
94impl Message {
95 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
96 match self {
97 Message::Agent(agent_message) => Some(agent_message),
98 _ => None,
99 }
100 }
101
102 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
103 match self {
104 Message::User(message) => vec![message.to_request()],
105 Message::Agent(message) => message.to_request(),
106 Message::Resume => vec![LanguageModelRequestMessage {
107 role: Role::User,
108 content: vec!["Continue where you left off".into()],
109 cache: false,
110 }],
111 }
112 }
113
114 pub fn to_markdown(&self) -> String {
115 match self {
116 Message::User(message) => message.to_markdown(),
117 Message::Agent(message) => message.to_markdown(),
118 Message::Resume => "[resumed after tool use limit was reached]".into(),
119 }
120 }
121
122 pub fn role(&self) -> Role {
123 match self {
124 Message::User(_) | Message::Resume => Role::User,
125 Message::Agent(_) => Role::Assistant,
126 }
127 }
128}
129
/// A message submitted by the user, identified by `id`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    pub id: UserMessageId,
    /// The message's parts, in the order the user composed them.
    pub content: Vec<UserMessageContent>,
}

/// One part of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text typed by the user.
    Text(String),
    /// A mention of some resource (`uri`) together with its captured `content`,
    /// which may be empty.
    Mention { uri: MentionUri, content: String },
    /// An attached image.
    Image(LanguageModelImage),
}
142
143impl UserMessage {
144 pub fn to_markdown(&self) -> String {
145 let mut markdown = String::from("## User\n\n");
146
147 for content in &self.content {
148 match content {
149 UserMessageContent::Text(text) => {
150 markdown.push_str(text);
151 markdown.push('\n');
152 }
153 UserMessageContent::Image(_) => {
154 markdown.push_str("<image />\n");
155 }
156 UserMessageContent::Mention { uri, content } => {
157 if !content.is_empty() {
158 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
159 } else {
160 let _ = write!(&mut markdown, "{}\n", uri.as_link());
161 }
162 }
163 }
164 }
165
166 markdown
167 }
168
169 fn to_request(&self) -> LanguageModelRequestMessage {
170 let mut message = LanguageModelRequestMessage {
171 role: Role::User,
172 content: Vec::with_capacity(self.content.len()),
173 cache: false,
174 };
175
176 const OPEN_CONTEXT: &str = "<context>\n\
177 The following items were attached by the user. \
178 They are up-to-date and don't need to be re-read.\n\n";
179
180 const OPEN_FILES_TAG: &str = "<files>";
181 const OPEN_DIRECTORIES_TAG: &str = "<directories>";
182 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
183 const OPEN_THREADS_TAG: &str = "<threads>";
184 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
185 const OPEN_RULES_TAG: &str =
186 "<rules>\nThe user has specified the following rules that should be applied:\n";
187
188 let mut file_context = OPEN_FILES_TAG.to_string();
189 let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
190 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
191 let mut thread_context = OPEN_THREADS_TAG.to_string();
192 let mut fetch_context = OPEN_FETCH_TAG.to_string();
193 let mut rules_context = OPEN_RULES_TAG.to_string();
194
195 for chunk in &self.content {
196 let chunk = match chunk {
197 UserMessageContent::Text(text) => {
198 language_model::MessageContent::Text(text.clone())
199 }
200 UserMessageContent::Image(value) => {
201 language_model::MessageContent::Image(value.clone())
202 }
203 UserMessageContent::Mention { uri, content } => {
204 match uri {
205 MentionUri::File { abs_path } => {
206 write!(
207 &mut symbol_context,
208 "\n{}",
209 MarkdownCodeBlock {
210 tag: &codeblock_tag(abs_path, None),
211 text: &content.to_string(),
212 }
213 )
214 .ok();
215 }
216 MentionUri::Directory { .. } => {
217 write!(&mut directory_context, "\n{}\n", content).ok();
218 }
219 MentionUri::Symbol {
220 path, line_range, ..
221 }
222 | MentionUri::Selection {
223 path, line_range, ..
224 } => {
225 write!(
226 &mut rules_context,
227 "\n{}",
228 MarkdownCodeBlock {
229 tag: &codeblock_tag(path, Some(line_range)),
230 text: content
231 }
232 )
233 .ok();
234 }
235 MentionUri::Thread { .. } => {
236 write!(&mut thread_context, "\n{}\n", content).ok();
237 }
238 MentionUri::TextThread { .. } => {
239 write!(&mut thread_context, "\n{}\n", content).ok();
240 }
241 MentionUri::Rule { .. } => {
242 write!(
243 &mut rules_context,
244 "\n{}",
245 MarkdownCodeBlock {
246 tag: "",
247 text: content
248 }
249 )
250 .ok();
251 }
252 MentionUri::Fetch { url } => {
253 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
254 }
255 }
256
257 language_model::MessageContent::Text(uri.as_link().to_string())
258 }
259 };
260
261 message.content.push(chunk);
262 }
263
264 let len_before_context = message.content.len();
265
266 if file_context.len() > OPEN_FILES_TAG.len() {
267 file_context.push_str("</files>\n");
268 message
269 .content
270 .push(language_model::MessageContent::Text(file_context));
271 }
272
273 if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
274 directory_context.push_str("</directories>\n");
275 message
276 .content
277 .push(language_model::MessageContent::Text(directory_context));
278 }
279
280 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
281 symbol_context.push_str("</symbols>\n");
282 message
283 .content
284 .push(language_model::MessageContent::Text(symbol_context));
285 }
286
287 if thread_context.len() > OPEN_THREADS_TAG.len() {
288 thread_context.push_str("</threads>\n");
289 message
290 .content
291 .push(language_model::MessageContent::Text(thread_context));
292 }
293
294 if fetch_context.len() > OPEN_FETCH_TAG.len() {
295 fetch_context.push_str("</fetched_urls>\n");
296 message
297 .content
298 .push(language_model::MessageContent::Text(fetch_context));
299 }
300
301 if rules_context.len() > OPEN_RULES_TAG.len() {
302 rules_context.push_str("</user_rules>\n");
303 message
304 .content
305 .push(language_model::MessageContent::Text(rules_context));
306 }
307
308 if message.content.len() > len_before_context {
309 message.content.insert(
310 len_before_context,
311 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
312 );
313 message
314 .content
315 .push(language_model::MessageContent::Text("</context>".into()));
316 }
317
318 message
319 }
320}
321
/// Builds the info string of a Markdown code fence that describes `full_path`.
///
/// The result is `"<extension> <path>"`. The extension prefix is omitted when the
/// path has no extension or it is not valid UTF-8. It is optionally followed by a
/// 1-based line suffix: `:N` when the range starts and ends on the same line,
/// and `:N-M` otherwise.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let prefix = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => format!("{ext} "),
        None => String::new(),
    };

    let suffix = line_range.map_or_else(String::new, |range| {
        let first = range.start + 1;
        let last = range.end + 1;
        if first == last {
            format!(":{first}")
        } else {
            format!(":{first}-{last}")
        }
    });

    format!("{prefix}{}{suffix}", full_path.display())
}
341
342impl AgentMessage {
343 pub fn to_markdown(&self) -> String {
344 let mut markdown = String::from("## Assistant\n\n");
345
346 for content in &self.content {
347 match content {
348 AgentMessageContent::Text(text) => {
349 markdown.push_str(text);
350 markdown.push('\n');
351 }
352 AgentMessageContent::Thinking { text, .. } => {
353 markdown.push_str("<think>");
354 markdown.push_str(text);
355 markdown.push_str("</think>\n");
356 }
357 AgentMessageContent::RedactedThinking(_) => {
358 markdown.push_str("<redacted_thinking />\n")
359 }
360 AgentMessageContent::ToolUse(tool_use) => {
361 markdown.push_str(&format!(
362 "**Tool Use**: {} (ID: {})\n",
363 tool_use.name, tool_use.id
364 ));
365 markdown.push_str(&format!(
366 "{}\n",
367 MarkdownCodeBlock {
368 tag: "json",
369 text: &format!("{:#}", tool_use.input)
370 }
371 ));
372 }
373 }
374 }
375
376 for tool_result in self.tool_results.values() {
377 markdown.push_str(&format!(
378 "**Tool Result**: {} (ID: {})\n\n",
379 tool_result.tool_name, tool_result.tool_use_id
380 ));
381 if tool_result.is_error {
382 markdown.push_str("**ERROR:**\n");
383 }
384
385 match &tool_result.content {
386 LanguageModelToolResultContent::Text(text) => {
387 writeln!(markdown, "{text}\n").ok();
388 }
389 LanguageModelToolResultContent::Image(_) => {
390 writeln!(markdown, "<image />\n").ok();
391 }
392 }
393
394 if let Some(output) = tool_result.output.as_ref() {
395 writeln!(
396 markdown,
397 "**Debug Output**:\n\n```json\n{}\n```\n",
398 serde_json::to_string_pretty(output).unwrap()
399 )
400 .unwrap();
401 }
402 }
403
404 markdown
405 }
406
407 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
408 let mut assistant_message = LanguageModelRequestMessage {
409 role: Role::Assistant,
410 content: Vec::with_capacity(self.content.len()),
411 cache: false,
412 };
413 for chunk in &self.content {
414 let chunk = match chunk {
415 AgentMessageContent::Text(text) => {
416 language_model::MessageContent::Text(text.clone())
417 }
418 AgentMessageContent::Thinking { text, signature } => {
419 language_model::MessageContent::Thinking {
420 text: text.clone(),
421 signature: signature.clone(),
422 }
423 }
424 AgentMessageContent::RedactedThinking(value) => {
425 language_model::MessageContent::RedactedThinking(value.clone())
426 }
427 AgentMessageContent::ToolUse(value) => {
428 language_model::MessageContent::ToolUse(value.clone())
429 }
430 };
431 assistant_message.content.push(chunk);
432 }
433
434 let mut user_message = LanguageModelRequestMessage {
435 role: Role::User,
436 content: Vec::new(),
437 cache: false,
438 };
439
440 for tool_result in self.tool_results.values() {
441 user_message
442 .content
443 .push(language_model::MessageContent::ToolResult(
444 tool_result.clone(),
445 ));
446 }
447
448 let mut messages = Vec::new();
449 if !assistant_message.content.is_empty() {
450 messages.push(assistant_message);
451 }
452 if !user_message.content.is_empty() {
453 messages.push(user_message);
454 }
455 messages
456 }
457}
458
/// A message produced by the agent during a turn.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// The streamed content parts, in the order they were produced.
    pub content: Vec<AgentMessageContent>,
    /// Results of the tool calls issued in `content`, keyed by tool-use ID in insertion order.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}

/// One part of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Visible text output.
    Text(String),
    /// Model reasoning, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning whose text was withheld. It is kept as an opaque string and
    /// forwarded back to the model in requests.
    RedactedThinking(String),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
}
475
/// Events emitted to consumers of a thread while a turn runs, or while it is replayed.
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message (emitted when replaying history).
    UserMessage(UserMessage),
    /// A chunk of agent text.
    AgentText(String),
    /// A chunk of agent thinking.
    AgentThinking(String),
    /// A new tool call was started.
    ToolCall(acp::ToolCall),
    /// Fields of an existing tool call changed (status, output, …).
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool call needs the user's permission before it can proceed.
    ToolCallAuthorization(ToolCallAuthorization),
    /// Updated token usage for the current request.
    TokenUsageUpdate(acp_thread::TokenUsage),
    /// The thread's title changed.
    TitleUpdate(SharedString),
    /// A request is being retried.
    Retry(acp_thread::RetryStatus),
    /// The turn stopped for the given reason.
    Stop(acp::StopReason),
}

/// A request for the user to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call being authorized.
    pub tool_call: acp::ToolCallUpdate,
    /// The permission choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the ID of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
496
/// A conversation between the user and an agent, together with the state needed
/// to run turns against a language model.
pub struct Thread {
    /// ACP session identifier for this thread.
    id: acp::SessionId,
    /// ID of the current user prompt; advanced via `advance_prompt_id` on each `send`.
    prompt_id: PromptId,
    /// Last-updated timestamp (persisted in `to_db`).
    updated_at: DateTime<Utc>,
    /// Thread title; `None` until one is set (an empty stored title loads as `None`).
    title: Option<SharedString>,
    // NOTE(review): `allow(unused)` presumably because this is only round-tripped
    // through the DB in this file — confirm.
    #[allow(unused)]
    summary: DetailedSummaryState,
    /// Committed conversation history.
    messages: Vec<Message>,
    /// Completion mode used for requests and for the token-limit lookup.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<RunningTurn>,
    /// Agent message currently being assembled; it is not yet part of `messages`.
    pending_message: Option<AgentMessage>,
    /// Registered tools keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when the last turn stopped because the tool-use limit was reached;
    /// `resume` requires it.
    tool_use_limit_reached: bool,
    /// Latest token usage reported for each user message.
    request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
    #[allow(unused)]
    cumulative_token_usage: TokenUsage,
    /// Project snapshot taken when the thread was created (or loaded from the DB).
    #[allow(unused)]
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    context_server_registry: Entity<ContextServerRegistry>,
    /// Active agent profile.
    profile_id: AgentProfileId,
    project_context: Entity<ProjectContext>,
    templates: Arc<Templates>,
    /// Model used for completions; `None` when no model is configured.
    model: Option<Arc<dyn LanguageModel>>,
    /// Optional separate model; presumably used for title/summary generation — confirm.
    summarization_model: Option<Arc<dyn LanguageModel>>,
    pub(crate) project: Entity<Project>,
    pub(crate) action_log: Entity<ActionLog>,
}
527
528impl Thread {
529 pub fn new(
530 project: Entity<Project>,
531 project_context: Entity<ProjectContext>,
532 context_server_registry: Entity<ContextServerRegistry>,
533 action_log: Entity<ActionLog>,
534 templates: Arc<Templates>,
535 model: Option<Arc<dyn LanguageModel>>,
536 cx: &mut Context<Self>,
537 ) -> Self {
538 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
539 Self {
540 id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
541 prompt_id: PromptId::new(),
542 updated_at: Utc::now(),
543 title: None,
544 summary: DetailedSummaryState::default(),
545 messages: Vec::new(),
546 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
547 running_turn: None,
548 pending_message: None,
549 tools: BTreeMap::default(),
550 tool_use_limit_reached: false,
551 request_token_usage: HashMap::default(),
552 cumulative_token_usage: TokenUsage::default(),
553 initial_project_snapshot: {
554 let project_snapshot = Self::project_snapshot(project.clone(), cx);
555 cx.foreground_executor()
556 .spawn(async move { Some(project_snapshot.await) })
557 .shared()
558 },
559 context_server_registry,
560 profile_id,
561 project_context,
562 templates,
563 model,
564 summarization_model: None,
565 project,
566 action_log,
567 }
568 }
569
570 pub fn id(&self) -> &acp::SessionId {
571 &self.id
572 }
573
574 pub fn replay(
575 &mut self,
576 cx: &mut Context<Self>,
577 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
578 let (tx, rx) = mpsc::unbounded();
579 let stream = ThreadEventStream(tx);
580 for message in &self.messages {
581 match message {
582 Message::User(user_message) => stream.send_user_message(user_message),
583 Message::Agent(assistant_message) => {
584 for content in &assistant_message.content {
585 match content {
586 AgentMessageContent::Text(text) => stream.send_text(text),
587 AgentMessageContent::Thinking { text, .. } => {
588 stream.send_thinking(text)
589 }
590 AgentMessageContent::RedactedThinking(_) => {}
591 AgentMessageContent::ToolUse(tool_use) => {
592 self.replay_tool_call(
593 tool_use,
594 assistant_message.tool_results.get(&tool_use.id),
595 &stream,
596 cx,
597 );
598 }
599 }
600 }
601 }
602 Message::Resume => {}
603 }
604 }
605 rx
606 }
607
608 fn replay_tool_call(
609 &self,
610 tool_use: &LanguageModelToolUse,
611 tool_result: Option<&LanguageModelToolResult>,
612 stream: &ThreadEventStream,
613 cx: &mut Context<Self>,
614 ) {
615 let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
616 stream
617 .0
618 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
619 id: acp::ToolCallId(tool_use.id.to_string().into()),
620 title: tool_use.name.to_string(),
621 kind: acp::ToolKind::Other,
622 status: acp::ToolCallStatus::Failed,
623 content: Vec::new(),
624 locations: Vec::new(),
625 raw_input: Some(tool_use.input.clone()),
626 raw_output: None,
627 })))
628 .ok();
629 return;
630 };
631
632 let title = tool.initial_title(tool_use.input.clone());
633 let kind = tool.kind();
634 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
635
636 let output = tool_result
637 .as_ref()
638 .and_then(|result| result.output.clone());
639 if let Some(output) = output.clone() {
640 let tool_event_stream = ToolCallEventStream::new(
641 tool_use.id.clone(),
642 stream.clone(),
643 Some(self.project.read(cx).fs().clone()),
644 );
645 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
646 .log_err();
647 }
648
649 stream.update_tool_call_fields(
650 &tool_use.id,
651 acp::ToolCallUpdateFields {
652 status: Some(acp::ToolCallStatus::Completed),
653 raw_output: output,
654 ..Default::default()
655 },
656 );
657 }
658
659 pub fn from_db(
660 id: acp::SessionId,
661 db_thread: DbThread,
662 project: Entity<Project>,
663 project_context: Entity<ProjectContext>,
664 context_server_registry: Entity<ContextServerRegistry>,
665 action_log: Entity<ActionLog>,
666 templates: Arc<Templates>,
667 cx: &mut Context<Self>,
668 ) -> Self {
669 let profile_id = db_thread
670 .profile
671 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
672 let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
673 db_thread
674 .model
675 .and_then(|model| {
676 let model = SelectedModel {
677 provider: model.provider.clone().into(),
678 model: model.model.clone().into(),
679 };
680 registry.select_model(&model, cx)
681 })
682 .or_else(|| registry.default_model())
683 .map(|model| model.model)
684 });
685
686 Self {
687 id,
688 prompt_id: PromptId::new(),
689 title: if db_thread.title.is_empty() {
690 None
691 } else {
692 Some(db_thread.title.clone())
693 },
694 summary: db_thread.summary,
695 messages: db_thread.messages,
696 completion_mode: db_thread.completion_mode.unwrap_or_default(),
697 running_turn: None,
698 pending_message: None,
699 tools: BTreeMap::default(),
700 tool_use_limit_reached: false,
701 request_token_usage: db_thread.request_token_usage.clone(),
702 cumulative_token_usage: db_thread.cumulative_token_usage,
703 initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
704 context_server_registry,
705 profile_id,
706 project_context,
707 templates,
708 model,
709 summarization_model: None,
710 project,
711 action_log,
712 updated_at: db_thread.updated_at,
713 }
714 }
715
716 pub fn to_db(&self, cx: &App) -> Task<DbThread> {
717 let initial_project_snapshot = self.initial_project_snapshot.clone();
718 let mut thread = DbThread {
719 title: self.title.clone().unwrap_or_default(),
720 messages: self.messages.clone(),
721 updated_at: self.updated_at,
722 summary: self.summary.clone(),
723 initial_project_snapshot: None,
724 cumulative_token_usage: self.cumulative_token_usage,
725 request_token_usage: self.request_token_usage.clone(),
726 model: self.model.as_ref().map(|model| DbLanguageModel {
727 provider: model.provider_id().to_string(),
728 model: model.name().0.to_string(),
729 }),
730 completion_mode: Some(self.completion_mode),
731 profile: Some(self.profile_id.clone()),
732 };
733
734 cx.background_spawn(async move {
735 let initial_project_snapshot = initial_project_snapshot.await;
736 thread.initial_project_snapshot = initial_project_snapshot;
737 thread
738 })
739 }
740
741 /// Create a snapshot of the current project state including git information and unsaved buffers.
742 fn project_snapshot(
743 project: Entity<Project>,
744 cx: &mut Context<Self>,
745 ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
746 let git_store = project.read(cx).git_store().clone();
747 let worktree_snapshots: Vec<_> = project
748 .read(cx)
749 .visible_worktrees(cx)
750 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
751 .collect();
752
753 cx.spawn(async move |_, cx| {
754 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
755
756 let mut unsaved_buffers = Vec::new();
757 cx.update(|app_cx| {
758 let buffer_store = project.read(app_cx).buffer_store();
759 for buffer_handle in buffer_store.read(app_cx).buffers() {
760 let buffer = buffer_handle.read(app_cx);
761 if buffer.is_dirty()
762 && let Some(file) = buffer.file()
763 {
764 let path = file.path().to_string_lossy().to_string();
765 unsaved_buffers.push(path);
766 }
767 }
768 })
769 .ok();
770
771 Arc::new(ProjectSnapshot {
772 worktree_snapshots,
773 unsaved_buffer_paths: unsaved_buffers,
774 timestamp: Utc::now(),
775 })
776 })
777 }
778
779 fn worktree_snapshot(
780 worktree: Entity<project::Worktree>,
781 git_store: Entity<GitStore>,
782 cx: &App,
783 ) -> Task<agent::thread::WorktreeSnapshot> {
784 cx.spawn(async move |cx| {
785 // Get worktree path and snapshot
786 let worktree_info = cx.update(|app_cx| {
787 let worktree = worktree.read(app_cx);
788 let path = worktree.abs_path().to_string_lossy().to_string();
789 let snapshot = worktree.snapshot();
790 (path, snapshot)
791 });
792
793 let Ok((worktree_path, _snapshot)) = worktree_info else {
794 return WorktreeSnapshot {
795 worktree_path: String::new(),
796 git_state: None,
797 };
798 };
799
800 let git_state = git_store
801 .update(cx, |git_store, cx| {
802 git_store
803 .repositories()
804 .values()
805 .find(|repo| {
806 repo.read(cx)
807 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
808 .is_some()
809 })
810 .cloned()
811 })
812 .ok()
813 .flatten()
814 .map(|repo| {
815 repo.update(cx, |repo, _| {
816 let current_branch =
817 repo.branch.as_ref().map(|branch| branch.name().to_owned());
818 repo.send_job(None, |state, _| async move {
819 let RepositoryState::Local { backend, .. } = state else {
820 return GitState {
821 remote_url: None,
822 head_sha: None,
823 current_branch,
824 diff: None,
825 };
826 };
827
828 let remote_url = backend.remote_url("origin");
829 let head_sha = backend.head_sha().await;
830 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
831
832 GitState {
833 remote_url,
834 head_sha,
835 current_branch,
836 diff,
837 }
838 })
839 })
840 });
841
842 let git_state = match git_state {
843 Some(git_state) => match git_state.ok() {
844 Some(git_state) => git_state.await.ok(),
845 None => None,
846 },
847 None => None,
848 };
849
850 WorktreeSnapshot {
851 worktree_path,
852 git_state,
853 }
854 })
855 }
856
857 pub fn project_context(&self) -> &Entity<ProjectContext> {
858 &self.project_context
859 }
860
861 pub fn project(&self) -> &Entity<Project> {
862 &self.project
863 }
864
865 pub fn action_log(&self) -> &Entity<ActionLog> {
866 &self.action_log
867 }
868
869 pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
870 self.model.as_ref()
871 }
872
873 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
874 self.model = Some(model);
875 cx.notify()
876 }
877
878 pub fn set_summarization_model(
879 &mut self,
880 model: Option<Arc<dyn LanguageModel>>,
881 cx: &mut Context<Self>,
882 ) {
883 self.summarization_model = model;
884 cx.notify()
885 }
886
887 pub fn completion_mode(&self) -> CompletionMode {
888 self.completion_mode
889 }
890
891 pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
892 self.completion_mode = mode;
893 cx.notify()
894 }
895
896 #[cfg(any(test, feature = "test-support"))]
897 pub fn last_message(&self) -> Option<Message> {
898 if let Some(message) = self.pending_message.clone() {
899 Some(Message::Agent(message))
900 } else {
901 self.messages.last().cloned()
902 }
903 }
904
905 pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
906 let language_registry = self.project.read(cx).languages().clone();
907 self.add_tool(CopyPathTool::new(self.project.clone()));
908 self.add_tool(CreateDirectoryTool::new(self.project.clone()));
909 self.add_tool(DeletePathTool::new(
910 self.project.clone(),
911 self.action_log.clone(),
912 ));
913 self.add_tool(DiagnosticsTool::new(self.project.clone()));
914 self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
915 self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
916 self.add_tool(FindPathTool::new(self.project.clone()));
917 self.add_tool(GrepTool::new(self.project.clone()));
918 self.add_tool(ListDirectoryTool::new(self.project.clone()));
919 self.add_tool(MovePathTool::new(self.project.clone()));
920 self.add_tool(NowTool);
921 self.add_tool(OpenTool::new(self.project.clone()));
922 self.add_tool(ReadFileTool::new(
923 self.project.clone(),
924 self.action_log.clone(),
925 ));
926 self.add_tool(TerminalTool::new(self.project.clone(), cx));
927 self.add_tool(ThinkingTool);
928 self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
929 }
930
931 pub fn add_tool(&mut self, tool: impl AgentTool) {
932 self.tools.insert(tool.name(), tool.erase());
933 }
934
935 pub fn remove_tool(&mut self, name: &str) -> bool {
936 self.tools.remove(name).is_some()
937 }
938
939 pub fn profile(&self) -> &AgentProfileId {
940 &self.profile_id
941 }
942
943 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
944 self.profile_id = profile_id;
945 }
946
947 pub fn cancel(&mut self, cx: &mut Context<Self>) {
948 if let Some(running_turn) = self.running_turn.take() {
949 running_turn.cancel();
950 }
951 self.flush_pending_message(cx);
952 }
953
954 pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
955 let Some(last_user_message) = self.last_user_message() else {
956 return;
957 };
958
959 self.request_token_usage
960 .insert(last_user_message.id.clone(), update);
961 }
962
963 pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
964 self.cancel(cx);
965 let Some(position) = self.messages.iter().position(
966 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
967 ) else {
968 return Err(anyhow!("Message not found"));
969 };
970
971 for message in self.messages.drain(position..) {
972 match message {
973 Message::User(message) => {
974 self.request_token_usage.remove(&message.id);
975 }
976 Message::Agent(_) | Message::Resume => {}
977 }
978 }
979
980 cx.notify();
981 Ok(())
982 }
983
984 pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
985 let last_user_message = self.last_user_message()?;
986 let tokens = self.request_token_usage.get(&last_user_message.id)?;
987 let model = self.model.clone()?;
988
989 Some(acp_thread::TokenUsage {
990 max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
991 used_tokens: tokens.total_tokens(),
992 })
993 }
994
995 pub fn resume(
996 &mut self,
997 cx: &mut Context<Self>,
998 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
999 anyhow::ensure!(
1000 self.tool_use_limit_reached,
1001 "can only resume after tool use limit is reached"
1002 );
1003
1004 self.messages.push(Message::Resume);
1005 cx.notify();
1006
1007 log::info!("Total messages in thread: {}", self.messages.len());
1008 self.run_turn(cx)
1009 }
1010
1011 /// Sending a message results in the model streaming a response, which could include tool calls.
1012 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
1013 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
1014 pub fn send<T>(
1015 &mut self,
1016 id: UserMessageId,
1017 content: impl IntoIterator<Item = T>,
1018 cx: &mut Context<Self>,
1019 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
1020 where
1021 T: Into<UserMessageContent>,
1022 {
1023 let model = self.model().context("No language model configured")?;
1024
1025 log::info!("Thread::send called with model: {:?}", model.name());
1026 self.advance_prompt_id();
1027
1028 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
1029 log::debug!("Thread::send content: {:?}", content);
1030
1031 self.messages
1032 .push(Message::User(UserMessage { id, content }));
1033 cx.notify();
1034
1035 log::info!("Total messages in thread: {}", self.messages.len());
1036 self.run_turn(cx)
1037 }
1038
    /// Starts a new agent turn for the most recent message and returns the channel
    /// on which the turn's events are delivered.
    ///
    /// Any turn already in flight is canceled first, which also flushes its pending
    /// message. The spawned task loops as follows:
    /// 1. build a completion request;
    /// 2. stream the model's response;
    /// 3. wait for every tool call issued in that response;
    /// 4. record the tool results on the pending message;
    /// 5. flush the pending message and repeat with `CompletionIntent::ToolResults`.
    ///
    /// The loop ends when the model produces no tool calls, refuses, hits its
    /// max-token limit, or reports that the tool-use limit was reached.
    ///
    /// # Errors
    /// Returns an error immediately if no language model is configured. Errors that
    /// occur during the turn are sent through the returned channel instead.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        self.cancel(cx);

        let model = self.model.clone().context("No language model configured")?;
        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
        let event_stream = ThreadEventStream(events_tx);
        // Index of the message that started this turn (the last one at this point);
        // used below to roll the history back if the model refuses.
        let message_ix = self.messages.len().saturating_sub(1);
        self.tool_use_limit_reached = false;
        self.running_turn = Some(RunningTurn {
            event_stream: event_stream.clone(),
            _task: cx.spawn(async move |this, cx| {
                log::info!("Starting agent turn execution");
                let turn_result: Result<StopReason> = async {
                    let mut completion_intent = CompletionIntent::UserPrompt;
                    loop {
                        log::debug!(
                            "Building completion request with intent: {:?}",
                            completion_intent
                        );
                        let request = this.update(cx, |this, cx| {
                            this.build_completion_request(completion_intent, cx)
                        })??;

                        log::info!("Calling model.stream_completion");

                        // Out-flags set by the streaming helper to tell us why the
                        // response ended.
                        let mut tool_use_limit_reached = false;
                        let mut refused = false;
                        let mut reached_max_tokens = false;
                        let mut tool_uses = Self::stream_completion_with_retries(
                            this.clone(),
                            model.clone(),
                            request,
                            &event_stream,
                            &mut tool_use_limit_reached,
                            &mut refused,
                            &mut reached_max_tokens,
                            cx,
                        )
                        .await?;

                        if refused {
                            return Ok(StopReason::Refusal);
                        } else if reached_max_tokens {
                            return Ok(StopReason::MaxTokens);
                        }

                        // No tool calls in this response means the model is done.
                        let end_turn = tool_uses.is_empty();
                        while let Some(tool_result) = tool_uses.next().await {
                            log::info!("Tool finished {:?}", tool_result);

                            event_stream.update_tool_call_fields(
                                &tool_result.tool_use_id,
                                acp::ToolCallUpdateFields {
                                    status: Some(if tool_result.is_error {
                                        acp::ToolCallStatus::Failed
                                    } else {
                                        acp::ToolCallStatus::Completed
                                    }),
                                    raw_output: tool_result.output.clone(),
                                    ..Default::default()
                                },
                            );
                            this.update(cx, |this, _cx| {
                                this.pending_message()
                                    .tool_results
                                    .insert(tool_result.tool_use_id.clone(), tool_result);
                            })
                            .ok();
                        }

                        if tool_use_limit_reached {
                            log::info!("Tool use limit reached, completing turn");
                            // Remember this so that `resume` is allowed afterwards.
                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
                            return Err(language_model::ToolUseLimitReachedError.into());
                        } else if end_turn {
                            log::info!("No tool uses found, completing turn");
                            return Ok(StopReason::EndTurn);
                        } else {
                            // Commit this round's agent message (with its tool results)
                            // before sending the results back to the model.
                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
                            completion_intent = CompletionIntent::ToolResults;
                        }
                    }
                }
                .await;
                _ = this.update(cx, |this, cx| this.flush_pending_message(cx));

                match turn_result {
                    Ok(reason) => {
                        log::info!("Turn execution completed: {:?}", reason);

                        let update_title = this
                            .update(cx, |this, cx| this.update_title(&event_stream, cx))
                            .ok()
                            .flatten();
                        if let Some(update_title) = update_title {
                            update_title.await.context("update title failed").log_err();
                        }

                        event_stream.send_stop(reason);
                        // On refusal, drop the message that started this turn and
                        // everything produced after it.
                        if reason == StopReason::Refusal {
                            _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
                        }
                    }
                    Err(error) => {
                        log::error!("Turn execution failed: {:?}", error);
                        event_stream.send_error(error);
                    }
                }

                _ = this.update(cx, |this, _| this.running_turn.take());
            }),
        });
        Ok(events_rx)
    }
1156
1157 async fn stream_completion_with_retries(
1158 this: WeakEntity<Self>,
1159 model: Arc<dyn LanguageModel>,
1160 request: LanguageModelRequest,
1161 event_stream: &ThreadEventStream,
1162 tool_use_limit_reached: &mut bool,
1163 refusal: &mut bool,
1164 max_tokens_reached: &mut bool,
1165 cx: &mut AsyncApp,
1166 ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
1167 log::debug!("Stream completion started successfully");
1168
1169 let mut attempt = None;
1170 'retry: loop {
1171 let mut events = model.stream_completion(request.clone(), cx).await?;
1172 let mut tool_uses = FuturesUnordered::new();
1173 while let Some(event) = events.next().await {
1174 match event {
1175 Ok(LanguageModelCompletionEvent::StatusUpdate(
1176 CompletionRequestStatus::ToolUseLimitReached,
1177 )) => {
1178 *tool_use_limit_reached = true;
1179 }
1180 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1181 let usage = acp_thread::TokenUsage {
1182 max_tokens: model.max_token_count_for_mode(
1183 request
1184 .mode
1185 .unwrap_or(cloud_llm_client::CompletionMode::Normal),
1186 ),
1187 used_tokens: token_usage.total_tokens(),
1188 };
1189
1190 this.update(cx, |this, _cx| this.update_token_usage(token_usage))
1191 .ok();
1192
1193 event_stream.send_token_usage_update(usage);
1194 }
1195 Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
1196 *refusal = true;
1197 return Ok(FuturesUnordered::default());
1198 }
1199 Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
1200 *max_tokens_reached = true;
1201 return Ok(FuturesUnordered::default());
1202 }
1203 Ok(LanguageModelCompletionEvent::Stop(
1204 StopReason::ToolUse | StopReason::EndTurn,
1205 )) => break,
1206 Ok(event) => {
1207 log::trace!("Received completion event: {:?}", event);
1208 this.update(cx, |this, cx| {
1209 tool_uses.extend(this.handle_streamed_completion_event(
1210 event,
1211 event_stream,
1212 cx,
1213 ));
1214 })
1215 .ok();
1216 }
1217 Err(error) => {
1218 let completion_mode =
1219 this.read_with(cx, |thread, _cx| thread.completion_mode())?;
1220 if completion_mode == CompletionMode::Normal {
1221 return Err(error.into());
1222 }
1223
1224 let Some(strategy) = Self::retry_strategy_for(&error) else {
1225 return Err(error.into());
1226 };
1227
1228 let max_attempts = match &strategy {
1229 RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
1230 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
1231 };
1232
1233 let attempt = attempt.get_or_insert(0u8);
1234
1235 *attempt += 1;
1236
1237 let attempt = *attempt;
1238 if attempt > max_attempts {
1239 return Err(error.into());
1240 }
1241
1242 let delay = match &strategy {
1243 RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
1244 let delay_secs =
1245 initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
1246 Duration::from_secs(delay_secs)
1247 }
1248 RetryStrategy::Fixed { delay, .. } => *delay,
1249 };
1250 log::debug!("Retry attempt {attempt} with delay {delay:?}");
1251
1252 event_stream.send_retry(acp_thread::RetryStatus {
1253 last_error: error.to_string().into(),
1254 attempt: attempt as usize,
1255 max_attempts: max_attempts as usize,
1256 started_at: Instant::now(),
1257 duration: delay,
1258 });
1259
1260 cx.background_executor().timer(delay).await;
1261 continue 'retry;
1262 }
1263 }
1264 }
1265
1266 return Ok(tool_uses);
1267 }
1268 }
1269
1270 pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
1271 log::debug!("Building system message");
1272 let prompt = SystemPromptTemplate {
1273 project: self.project_context.read(cx),
1274 available_tools: self.tools.keys().cloned().collect(),
1275 }
1276 .render(&self.templates)
1277 .context("failed to build system prompt")
1278 .expect("Invalid template");
1279 log::debug!("System message built");
1280 LanguageModelRequestMessage {
1281 role: Role::System,
1282 content: vec![prompt.into()],
1283 cache: true,
1284 }
1285 }
1286
1287 /// A helper method that's called on every streamed completion event.
1288 /// Returns an optional tool result task, which the main agentic loop in
1289 /// send will send back to the model when it resolves.
1290 fn handle_streamed_completion_event(
1291 &mut self,
1292 event: LanguageModelCompletionEvent,
1293 event_stream: &ThreadEventStream,
1294 cx: &mut Context<Self>,
1295 ) -> Option<Task<LanguageModelToolResult>> {
1296 log::trace!("Handling streamed completion event: {:?}", event);
1297 use LanguageModelCompletionEvent::*;
1298
1299 match event {
1300 StartMessage { .. } => {
1301 self.flush_pending_message(cx);
1302 self.pending_message = Some(AgentMessage::default());
1303 }
1304 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1305 Thinking { text, signature } => {
1306 self.handle_thinking_event(text, signature, event_stream, cx)
1307 }
1308 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1309 ToolUse(tool_use) => {
1310 return self.handle_tool_use_event(tool_use, event_stream, cx);
1311 }
1312 ToolUseJsonParseError {
1313 id,
1314 tool_name,
1315 raw_input,
1316 json_parse_error,
1317 } => {
1318 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1319 id,
1320 tool_name,
1321 raw_input,
1322 json_parse_error,
1323 )));
1324 }
1325 UsageUpdate(_) | StatusUpdate(_) => {}
1326 Stop(_) => unreachable!(),
1327 }
1328
1329 None
1330 }
1331
1332 fn handle_text_event(
1333 &mut self,
1334 new_text: String,
1335 event_stream: &ThreadEventStream,
1336 cx: &mut Context<Self>,
1337 ) {
1338 event_stream.send_text(&new_text);
1339
1340 let last_message = self.pending_message();
1341 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1342 text.push_str(&new_text);
1343 } else {
1344 last_message
1345 .content
1346 .push(AgentMessageContent::Text(new_text));
1347 }
1348
1349 cx.notify();
1350 }
1351
1352 fn handle_thinking_event(
1353 &mut self,
1354 new_text: String,
1355 new_signature: Option<String>,
1356 event_stream: &ThreadEventStream,
1357 cx: &mut Context<Self>,
1358 ) {
1359 event_stream.send_thinking(&new_text);
1360
1361 let last_message = self.pending_message();
1362 if let Some(AgentMessageContent::Thinking { text, signature }) =
1363 last_message.content.last_mut()
1364 {
1365 text.push_str(&new_text);
1366 *signature = new_signature.or(signature.take());
1367 } else {
1368 last_message.content.push(AgentMessageContent::Thinking {
1369 text: new_text,
1370 signature: new_signature,
1371 });
1372 }
1373
1374 cx.notify();
1375 }
1376
1377 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1378 let last_message = self.pending_message();
1379 last_message
1380 .content
1381 .push(AgentMessageContent::RedactedThinking(data));
1382 cx.notify();
1383 }
1384
1385 fn handle_tool_use_event(
1386 &mut self,
1387 tool_use: LanguageModelToolUse,
1388 event_stream: &ThreadEventStream,
1389 cx: &mut Context<Self>,
1390 ) -> Option<Task<LanguageModelToolResult>> {
1391 cx.notify();
1392
1393 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1394 let mut title = SharedString::from(&tool_use.name);
1395 let mut kind = acp::ToolKind::Other;
1396 if let Some(tool) = tool.as_ref() {
1397 title = tool.initial_title(tool_use.input.clone());
1398 kind = tool.kind();
1399 }
1400
1401 // Ensure the last message ends in the current tool use
1402 let last_message = self.pending_message();
1403 let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
1404 if let AgentMessageContent::ToolUse(last_tool_use) = content {
1405 if last_tool_use.id == tool_use.id {
1406 *last_tool_use = tool_use.clone();
1407 false
1408 } else {
1409 true
1410 }
1411 } else {
1412 true
1413 }
1414 });
1415
1416 if push_new_tool_use {
1417 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1418 last_message
1419 .content
1420 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1421 } else {
1422 event_stream.update_tool_call_fields(
1423 &tool_use.id,
1424 acp::ToolCallUpdateFields {
1425 title: Some(title.into()),
1426 kind: Some(kind),
1427 raw_input: Some(tool_use.input.clone()),
1428 ..Default::default()
1429 },
1430 );
1431 }
1432
1433 if !tool_use.is_input_complete {
1434 return None;
1435 }
1436
1437 let Some(tool) = tool else {
1438 let content = format!("No tool named {} exists", tool_use.name);
1439 return Some(Task::ready(LanguageModelToolResult {
1440 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1441 tool_use_id: tool_use.id,
1442 tool_name: tool_use.name,
1443 is_error: true,
1444 output: None,
1445 }));
1446 };
1447
1448 let fs = self.project.read(cx).fs().clone();
1449 let tool_event_stream =
1450 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1451 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1452 status: Some(acp::ToolCallStatus::InProgress),
1453 ..Default::default()
1454 });
1455 let supports_images = self.model().is_some_and(|model| model.supports_images());
1456 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1457 log::info!("Running tool {}", tool_use.name);
1458 Some(cx.foreground_executor().spawn(async move {
1459 let tool_result = tool_result.await.and_then(|output| {
1460 if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1461 && !supports_images
1462 {
1463 return Err(anyhow!(
1464 "Attempted to read an image, but this model doesn't support it.",
1465 ));
1466 }
1467 Ok(output)
1468 });
1469
1470 match tool_result {
1471 Ok(output) => LanguageModelToolResult {
1472 tool_use_id: tool_use.id,
1473 tool_name: tool_use.name,
1474 is_error: false,
1475 content: output.llm_output,
1476 output: Some(output.raw_output),
1477 },
1478 Err(error) => LanguageModelToolResult {
1479 tool_use_id: tool_use.id,
1480 tool_name: tool_use.name,
1481 is_error: true,
1482 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1483 output: None,
1484 },
1485 }
1486 }))
1487 }
1488
1489 fn handle_tool_use_json_parse_error_event(
1490 &mut self,
1491 tool_use_id: LanguageModelToolUseId,
1492 tool_name: Arc<str>,
1493 raw_input: Arc<str>,
1494 json_parse_error: String,
1495 ) -> LanguageModelToolResult {
1496 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1497 LanguageModelToolResult {
1498 tool_use_id,
1499 tool_name,
1500 is_error: true,
1501 content: LanguageModelToolResultContent::Text(tool_output.into()),
1502 output: Some(serde_json::Value::String(raw_input.to_string())),
1503 }
1504 }
1505
1506 pub fn title(&self) -> SharedString {
1507 self.title.clone().unwrap_or("New Thread".into())
1508 }
1509
1510 fn update_title(
1511 &mut self,
1512 event_stream: &ThreadEventStream,
1513 cx: &mut Context<Self>,
1514 ) -> Option<Task<Result<()>>> {
1515 if self.title.is_some() {
1516 log::debug!("Skipping title generation because we already have one.");
1517 return None;
1518 }
1519
1520 log::info!(
1521 "Generating title with model: {:?}",
1522 self.summarization_model.as_ref().map(|model| model.name())
1523 );
1524 let model = self.summarization_model.clone()?;
1525 let event_stream = event_stream.clone();
1526 let mut request = LanguageModelRequest {
1527 intent: Some(CompletionIntent::ThreadSummarization),
1528 temperature: AgentSettings::temperature_for_model(&model, cx),
1529 ..Default::default()
1530 };
1531
1532 for message in &self.messages {
1533 request.messages.extend(message.to_request());
1534 }
1535
1536 request.messages.push(LanguageModelRequestMessage {
1537 role: Role::User,
1538 content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1539 cache: false,
1540 });
1541 Some(cx.spawn(async move |this, cx| {
1542 let mut title = String::new();
1543 let mut messages = model.stream_completion(request, cx).await?;
1544 while let Some(event) = messages.next().await {
1545 let event = event?;
1546 let text = match event {
1547 LanguageModelCompletionEvent::Text(text) => text,
1548 LanguageModelCompletionEvent::StatusUpdate(
1549 CompletionRequestStatus::UsageUpdated { .. },
1550 ) => {
1551 // this.update(cx, |thread, cx| {
1552 // thread.update_model_request_usage(amount as u32, limit, cx);
1553 // })?;
1554 // TODO: handle usage update
1555 continue;
1556 }
1557 _ => continue,
1558 };
1559
1560 let mut lines = text.lines();
1561 title.extend(lines.next());
1562
1563 // Stop if the LLM generated multiple lines.
1564 if lines.next().is_some() {
1565 break;
1566 }
1567 }
1568
1569 log::info!("Setting title: {}", title);
1570
1571 this.update(cx, |this, cx| {
1572 let title = SharedString::from(title);
1573 event_stream.send_title_update(title.clone());
1574 this.title = Some(title);
1575 cx.notify();
1576 })
1577 }))
1578 }
1579 fn last_user_message(&self) -> Option<&UserMessage> {
1580 self.messages
1581 .iter()
1582 .rev()
1583 .find_map(|message| match message {
1584 Message::User(user_message) => Some(user_message),
1585 Message::Agent(_) => None,
1586 Message::Resume => None,
1587 })
1588 }
1589
1590 fn pending_message(&mut self) -> &mut AgentMessage {
1591 self.pending_message.get_or_insert_default()
1592 }
1593
1594 fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1595 let Some(mut message) = self.pending_message.take() else {
1596 return;
1597 };
1598
1599 for content in &message.content {
1600 let AgentMessageContent::ToolUse(tool_use) = content else {
1601 continue;
1602 };
1603
1604 if !message.tool_results.contains_key(&tool_use.id) {
1605 message.tool_results.insert(
1606 tool_use.id.clone(),
1607 LanguageModelToolResult {
1608 tool_use_id: tool_use.id.clone(),
1609 tool_name: tool_use.name.clone(),
1610 is_error: true,
1611 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1612 output: None,
1613 },
1614 );
1615 }
1616 }
1617
1618 self.messages.push(Message::Agent(message));
1619 self.updated_at = Utc::now();
1620 cx.notify()
1621 }
1622
1623 pub(crate) fn build_completion_request(
1624 &self,
1625 completion_intent: CompletionIntent,
1626 cx: &mut App,
1627 ) -> Result<LanguageModelRequest> {
1628 let model = self.model().context("No language model configured")?;
1629
1630 log::debug!("Building completion request");
1631 log::debug!("Completion intent: {:?}", completion_intent);
1632 log::debug!("Completion mode: {:?}", self.completion_mode);
1633
1634 let messages = self.build_request_messages(cx);
1635 log::info!("Request will include {} messages", messages.len());
1636
1637 let tools = if let Some(tools) = self.tools(cx).log_err() {
1638 tools
1639 .filter_map(|tool| {
1640 let tool_name = tool.name().to_string();
1641 log::trace!("Including tool: {}", tool_name);
1642 Some(LanguageModelRequestTool {
1643 name: tool_name,
1644 description: tool.description().to_string(),
1645 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1646 })
1647 })
1648 .collect()
1649 } else {
1650 Vec::new()
1651 };
1652
1653 log::info!("Request includes {} tools", tools.len());
1654
1655 let request = LanguageModelRequest {
1656 thread_id: Some(self.id.to_string()),
1657 prompt_id: Some(self.prompt_id.to_string()),
1658 intent: Some(completion_intent),
1659 mode: Some(self.completion_mode.into()),
1660 messages,
1661 tools,
1662 tool_choice: None,
1663 stop: Vec::new(),
1664 temperature: AgentSettings::temperature_for_model(model, cx),
1665 thinking_allowed: true,
1666 };
1667
1668 log::debug!("Completion request built successfully");
1669 Ok(request)
1670 }
1671
1672 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1673 let model = self.model().context("No language model configured")?;
1674
1675 let profile = AgentSettings::get_global(cx)
1676 .profiles
1677 .get(&self.profile_id)
1678 .context("profile not found")?;
1679 let provider_id = model.provider_id();
1680
1681 Ok(self
1682 .tools
1683 .iter()
1684 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1685 .filter_map(|(tool_name, tool)| {
1686 if profile.is_tool_enabled(tool_name) {
1687 Some(tool)
1688 } else {
1689 None
1690 }
1691 })
1692 .chain(self.context_server_registry.read(cx).servers().flat_map(
1693 |(server_id, tools)| {
1694 tools.iter().filter_map(|(tool_name, tool)| {
1695 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1696 Some(tool)
1697 } else {
1698 None
1699 }
1700 })
1701 },
1702 )))
1703 }
1704
1705 fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1706 log::trace!(
1707 "Building request messages from {} thread messages",
1708 self.messages.len()
1709 );
1710 let mut messages = vec![self.build_system_message(cx)];
1711 for message in &self.messages {
1712 messages.extend(message.to_request());
1713 }
1714
1715 if let Some(message) = self.pending_message.as_ref() {
1716 messages.extend(message.to_request());
1717 }
1718
1719 if let Some(last_user_message) = messages
1720 .iter_mut()
1721 .rev()
1722 .find(|message| message.role == Role::User)
1723 {
1724 last_user_message.cache = true;
1725 }
1726
1727 messages
1728 }
1729
1730 pub fn to_markdown(&self) -> String {
1731 let mut markdown = String::new();
1732 for (ix, message) in self.messages.iter().enumerate() {
1733 if ix > 0 {
1734 markdown.push('\n');
1735 }
1736 markdown.push_str(&message.to_markdown());
1737 }
1738
1739 if let Some(message) = self.pending_message.as_ref() {
1740 markdown.push('\n');
1741 markdown.push_str(&message.to_markdown());
1742 }
1743
1744 markdown
1745 }
1746
1747 fn advance_prompt_id(&mut self) {
1748 self.prompt_id = PromptId::new();
1749 }
1750
1751 fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
1752 use LanguageModelCompletionError::*;
1753 use http_client::StatusCode;
1754
1755 // General strategy here:
1756 // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
1757 // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
1758 // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
1759 match error {
1760 HttpResponseError {
1761 status_code: StatusCode::TOO_MANY_REQUESTS,
1762 ..
1763 } => Some(RetryStrategy::ExponentialBackoff {
1764 initial_delay: BASE_RETRY_DELAY,
1765 max_attempts: MAX_RETRY_ATTEMPTS,
1766 }),
1767 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
1768 Some(RetryStrategy::Fixed {
1769 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1770 max_attempts: MAX_RETRY_ATTEMPTS,
1771 })
1772 }
1773 UpstreamProviderError {
1774 status,
1775 retry_after,
1776 ..
1777 } => match *status {
1778 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
1779 Some(RetryStrategy::Fixed {
1780 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1781 max_attempts: MAX_RETRY_ATTEMPTS,
1782 })
1783 }
1784 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
1785 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1786 // Internal Server Error could be anything, retry up to 3 times.
1787 max_attempts: 3,
1788 }),
1789 status => {
1790 // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
1791 // but we frequently get them in practice. See https://http.dev/529
1792 if status.as_u16() == 529 {
1793 Some(RetryStrategy::Fixed {
1794 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1795 max_attempts: MAX_RETRY_ATTEMPTS,
1796 })
1797 } else {
1798 Some(RetryStrategy::Fixed {
1799 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1800 max_attempts: 2,
1801 })
1802 }
1803 }
1804 },
1805 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
1806 delay: BASE_RETRY_DELAY,
1807 max_attempts: 3,
1808 }),
1809 ApiReadResponseError { .. }
1810 | HttpSend { .. }
1811 | DeserializeResponse { .. }
1812 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
1813 delay: BASE_RETRY_DELAY,
1814 max_attempts: 3,
1815 }),
1816 // Retrying these errors definitely shouldn't help.
1817 HttpResponseError {
1818 status_code:
1819 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
1820 ..
1821 }
1822 | AuthenticationError { .. }
1823 | PermissionError { .. }
1824 | NoApiKey { .. }
1825 | ApiEndpointNotFound { .. }
1826 | PromptTooLarge { .. } => None,
1827 // These errors might be transient, so retry them
1828 SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
1829 delay: BASE_RETRY_DELAY,
1830 max_attempts: 1,
1831 }),
1832 // Retry all other 4xx and 5xx errors once.
1833 HttpResponseError { status_code, .. }
1834 if status_code.is_client_error() || status_code.is_server_error() =>
1835 {
1836 Some(RetryStrategy::Fixed {
1837 delay: BASE_RETRY_DELAY,
1838 max_attempts: 3,
1839 })
1840 }
1841 Other(err)
1842 if err.is::<language_model::PaymentRequiredError>()
1843 || err.is::<language_model::ModelRequestLimitReachedError>() =>
1844 {
1845 // Retrying won't help for Payment Required or Model Request Limit errors (where
1846 // the user must upgrade to usage-based billing to get more requests, or else wait
1847 // for a significant amount of time for the request limit to reset).
1848 None
1849 }
1850 // Conservatively assume that any other errors are non-retryable
1851 HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
1852 delay: BASE_RETRY_DELAY,
1853 max_attempts: 2,
1854 }),
1855 }
1856 }
1857}
1858
/// Bookkeeping for the agent turn that is currently running on a thread.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// It survives across multiple requests as the model makes tool calls and we
    /// run the tools and report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1868
impl RunningTurn {
    /// Reports the turn as canceled on its event stream, then consumes it.
    ///
    /// Taking `self` by value means the turn, including its task, is dropped when
    /// this returns. NOTE(review): this relies on dropping the gpui `Task` to stop
    /// the in-flight turn; confirm against gpui's `Task` drop semantics.
    fn cancel(self) {
        log::debug!("Cancelling in progress turn");
        self.event_stream.send_canceled();
    }
}
1875
1876pub trait AgentTool
1877where
1878 Self: 'static + Sized,
1879{
1880 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1881 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1882
1883 fn name(&self) -> SharedString;
1884
1885 fn description(&self) -> SharedString {
1886 let schema = schemars::schema_for!(Self::Input);
1887 SharedString::new(
1888 schema
1889 .get("description")
1890 .and_then(|description| description.as_str())
1891 .unwrap_or_default(),
1892 )
1893 }
1894
1895 fn kind(&self) -> acp::ToolKind;
1896
1897 /// The initial tool title to display. Can be updated during the tool run.
1898 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1899
1900 /// Returns the JSON schema that describes the tool's input.
1901 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
1902 crate::tool_schema::root_schema_for::<Self::Input>(format)
1903 }
1904
1905 /// Some tools rely on a provider for the underlying billing or other reasons.
1906 /// Allow the tool to check if they are compatible, or should be filtered out.
1907 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1908 true
1909 }
1910
1911 /// Runs the tool with the provided input.
1912 fn run(
1913 self: Arc<Self>,
1914 input: Self::Input,
1915 event_stream: ToolCallEventStream,
1916 cx: &mut App,
1917 ) -> Task<Result<Self::Output>>;
1918
1919 /// Emits events for a previous execution of the tool.
1920 fn replay(
1921 &self,
1922 _input: Self::Input,
1923 _output: Self::Output,
1924 _event_stream: ToolCallEventStream,
1925 _cx: &mut App,
1926 ) -> Result<()> {
1927 Ok(())
1928 }
1929
1930 fn erase(self) -> Arc<dyn AnyAgentTool> {
1931 Arc::new(Erased(Arc::new(self)))
1932 }
1933}
1934
/// Adapter that exposes a typed [`AgentTool`] as an [`AnyAgentTool`].
///
/// [`AgentTool::erase`] constructs it as `Erased(Arc<T>)`.
pub struct Erased<T>(T);
1936
/// Output of a type-erased tool run (see [`AnyAgentTool::run`]).
pub struct AgentToolOutput {
    /// Content sent back to the model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON.
    pub raw_output: serde_json::Value,
}
1941
/// Object-safe, JSON-based view of an agent tool.
///
/// This is how the thread stores and invokes tools of different concrete types.
/// Typed tools implement [`AgentTool`] and are converted with [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// Name the tool is advertised under.
    fn name(&self) -> SharedString;
    /// Description advertised to the model.
    fn description(&self) -> SharedString;
    /// Category reported alongside the tool call.
    fn kind(&self) -> acp::ToolKind;
    /// Title to display for a call with the given raw JSON `input`. The input may
    /// be incomplete or invalid.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool can be offered when using the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with raw JSON `input`, reporting progress on `event_stream`.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Re-emits events for a previous run from its raw JSON `input` and `output`.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1965
1966impl<T> AnyAgentTool for Erased<Arc<T>>
1967where
1968 T: AgentTool,
1969{
1970 fn name(&self) -> SharedString {
1971 self.0.name()
1972 }
1973
1974 fn description(&self) -> SharedString {
1975 self.0.description()
1976 }
1977
1978 fn kind(&self) -> agent_client_protocol::ToolKind {
1979 self.0.kind()
1980 }
1981
1982 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1983 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1984 self.0.initial_title(parsed_input)
1985 }
1986
1987 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1988 let mut json = serde_json::to_value(self.0.input_schema(format))?;
1989 adapt_schema_to_format(&mut json, format)?;
1990 Ok(json)
1991 }
1992
1993 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1994 self.0.supported_provider(provider)
1995 }
1996
1997 fn run(
1998 self: Arc<Self>,
1999 input: serde_json::Value,
2000 event_stream: ToolCallEventStream,
2001 cx: &mut App,
2002 ) -> Task<Result<AgentToolOutput>> {
2003 cx.spawn(async move |cx| {
2004 let input = serde_json::from_value(input)?;
2005 let output = cx
2006 .update(|cx| self.0.clone().run(input, event_stream, cx))?
2007 .await?;
2008 let raw_output = serde_json::to_value(&output)?;
2009 Ok(AgentToolOutput {
2010 llm_output: output.into(),
2011 raw_output,
2012 })
2013 })
2014 }
2015
2016 fn replay(
2017 &self,
2018 input: serde_json::Value,
2019 output: serde_json::Value,
2020 event_stream: ToolCallEventStream,
2021 cx: &mut App,
2022 ) -> Result<()> {
2023 let input = serde_json::from_value(input)?;
2024 let output = serde_json::from_value(output)?;
2025 self.0.replay(input, output, event_stream, cx)
2026 }
2027}
2028
/// Sending half of a turn's event channel.
///
/// Cloned into the turn task and into each tool call's [`ToolCallEventStream`].
/// Send failures are ignored, since the receiver may already be gone.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
2031
2032impl ThreadEventStream {
2033 fn send_title_update(&self, text: SharedString) {
2034 self.0
2035 .unbounded_send(Ok(ThreadEvent::TitleUpdate(text)))
2036 .ok();
2037 }
2038
2039 fn send_user_message(&self, message: &UserMessage) {
2040 self.0
2041 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
2042 .ok();
2043 }
2044
2045 fn send_text(&self, text: &str) {
2046 self.0
2047 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
2048 .ok();
2049 }
2050
2051 fn send_thinking(&self, text: &str) {
2052 self.0
2053 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2054 .ok();
2055 }
2056
2057 fn send_tool_call(
2058 &self,
2059 id: &LanguageModelToolUseId,
2060 title: SharedString,
2061 kind: acp::ToolKind,
2062 input: serde_json::Value,
2063 ) {
2064 self.0
2065 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2066 id,
2067 title.to_string(),
2068 kind,
2069 input,
2070 ))))
2071 .ok();
2072 }
2073
2074 fn initial_tool_call(
2075 id: &LanguageModelToolUseId,
2076 title: String,
2077 kind: acp::ToolKind,
2078 input: serde_json::Value,
2079 ) -> acp::ToolCall {
2080 acp::ToolCall {
2081 id: acp::ToolCallId(id.to_string().into()),
2082 title,
2083 kind,
2084 status: acp::ToolCallStatus::Pending,
2085 content: vec![],
2086 locations: vec![],
2087 raw_input: Some(input),
2088 raw_output: None,
2089 }
2090 }
2091
2092 fn update_tool_call_fields(
2093 &self,
2094 tool_use_id: &LanguageModelToolUseId,
2095 fields: acp::ToolCallUpdateFields,
2096 ) {
2097 self.0
2098 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2099 acp::ToolCallUpdate {
2100 id: acp::ToolCallId(tool_use_id.to_string().into()),
2101 fields,
2102 }
2103 .into(),
2104 )))
2105 .ok();
2106 }
2107
2108 fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
2109 self.0
2110 .unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
2111 .ok();
2112 }
2113
2114 fn send_retry(&self, status: acp_thread::RetryStatus) {
2115 self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2116 }
2117
2118 fn send_stop(&self, reason: StopReason) {
2119 match reason {
2120 StopReason::EndTurn => {
2121 self.0
2122 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
2123 .ok();
2124 }
2125 StopReason::MaxTokens => {
2126 self.0
2127 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
2128 .ok();
2129 }
2130 StopReason::Refusal => {
2131 self.0
2132 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
2133 .ok();
2134 }
2135 StopReason::ToolUse => {}
2136 }
2137 }
2138
2139 fn send_canceled(&self) {
2140 self.0
2141 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
2142 .ok();
2143 }
2144
2145 fn send_error(&self, error: impl Into<anyhow::Error>) {
2146 self.0.unbounded_send(Err(error.into())).ok();
2147 }
2148}
2149
/// Event stream handed to a running tool, scoped to a single tool call.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool call that all updates from this stream refer to.
    tool_use_id: LanguageModelToolUseId,
    /// The turn's event stream that updates are forwarded to.
    stream: ThreadEventStream,
    /// Filesystem handle for the tool. `None` in the test constructor.
    fs: Option<Arc<dyn Fs>>,
}
2156
2157impl ToolCallEventStream {
2158 #[cfg(test)]
2159 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2160 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2161
2162 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2163
2164 (stream, ToolCallEventStreamReceiver(events_rx))
2165 }
2166
2167 fn new(
2168 tool_use_id: LanguageModelToolUseId,
2169 stream: ThreadEventStream,
2170 fs: Option<Arc<dyn Fs>>,
2171 ) -> Self {
2172 Self {
2173 tool_use_id,
2174 stream,
2175 fs,
2176 }
2177 }
2178
2179 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2180 self.stream
2181 .update_tool_call_fields(&self.tool_use_id, fields);
2182 }
2183
2184 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2185 self.stream
2186 .0
2187 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2188 acp_thread::ToolCallUpdateDiff {
2189 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2190 diff,
2191 }
2192 .into(),
2193 )))
2194 .ok();
2195 }
2196
2197 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
2198 self.stream
2199 .0
2200 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2201 acp_thread::ToolCallUpdateTerminal {
2202 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2203 terminal,
2204 }
2205 .into(),
2206 )))
2207 .ok();
2208 }
2209
2210 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2211 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2212 return Task::ready(Ok(()));
2213 }
2214
2215 let (response_tx, response_rx) = oneshot::channel();
2216 self.stream
2217 .0
2218 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2219 ToolCallAuthorization {
2220 tool_call: acp::ToolCallUpdate {
2221 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2222 fields: acp::ToolCallUpdateFields {
2223 title: Some(title.into()),
2224 ..Default::default()
2225 },
2226 },
2227 options: vec![
2228 acp::PermissionOption {
2229 id: acp::PermissionOptionId("always_allow".into()),
2230 name: "Always Allow".into(),
2231 kind: acp::PermissionOptionKind::AllowAlways,
2232 },
2233 acp::PermissionOption {
2234 id: acp::PermissionOptionId("allow".into()),
2235 name: "Allow".into(),
2236 kind: acp::PermissionOptionKind::AllowOnce,
2237 },
2238 acp::PermissionOption {
2239 id: acp::PermissionOptionId("deny".into()),
2240 name: "Deny".into(),
2241 kind: acp::PermissionOptionKind::RejectOnce,
2242 },
2243 ],
2244 response: response_tx,
2245 },
2246 )))
2247 .ok();
2248 let fs = self.fs.clone();
2249 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2250 "always_allow" => {
2251 if let Some(fs) = fs.clone() {
2252 cx.update(|cx| {
2253 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
2254 settings.set_always_allow_tool_actions(true);
2255 });
2256 })?;
2257 }
2258
2259 Ok(())
2260 }
2261 "allow" => Ok(()),
2262 _ => Err(anyhow!("Permission to run tool denied by user")),
2263 })
2264 }
2265}
2266
/// Test-only receiving end of the channel created by `ToolCallEventStream::test`,
/// with helpers that assert on the next event received.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2269
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as a `ToolCallAuthorization`.
    ///
    /// # Panics
    /// Panics if the next item is anything else, including an error or the
    /// end of the stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(authorization))) => authorization,
            unexpected => panic!("Expected ToolCallAuthorization but got: {:?}", unexpected),
        }
    }

    /// Awaits the next event and returns the terminal from a terminal tool-call
    /// update.
    ///
    /// # Panics
    /// Panics if the next item is anything else, including an error or the
    /// end of the stream.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(terminal_update),
            ))) => terminal_update.terminal,
            unexpected => panic!("Expected terminal but got: {:?}", unexpected),
        }
    }
}
2293
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    /// Exposes the wrapped receiver so tests can call its methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2302
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutable access to the wrapped receiver, e.g. for polling it in tests.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2309
2310impl From<&str> for UserMessageContent {
2311 fn from(text: &str) -> Self {
2312 Self::Text(text.into())
2313 }
2314}
2315
2316impl From<acp::ContentBlock> for UserMessageContent {
2317 fn from(value: acp::ContentBlock) -> Self {
2318 match value {
2319 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2320 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2321 acp::ContentBlock::Audio(_) => {
2322 // TODO
2323 Self::Text("[audio]".to_string())
2324 }
2325 acp::ContentBlock::ResourceLink(resource_link) => {
2326 match MentionUri::parse(&resource_link.uri) {
2327 Ok(uri) => Self::Mention {
2328 uri,
2329 content: String::new(),
2330 },
2331 Err(err) => {
2332 log::error!("Failed to parse mention link: {}", err);
2333 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2334 }
2335 }
2336 }
2337 acp::ContentBlock::Resource(resource) => match resource.resource {
2338 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2339 match MentionUri::parse(&resource.uri) {
2340 Ok(uri) => Self::Mention {
2341 uri,
2342 content: resource.text,
2343 },
2344 Err(err) => {
2345 log::error!("Failed to parse mention link: {}", err);
2346 Self::Text(
2347 MarkdownCodeBlock {
2348 tag: &resource.uri,
2349 text: &resource.text,
2350 }
2351 .to_string(),
2352 )
2353 }
2354 }
2355 }
2356 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2357 // TODO
2358 Self::Text("[blob]".to_string())
2359 }
2360 },
2361 }
2362 }
2363}
2364
2365impl From<UserMessageContent> for acp::ContentBlock {
2366 fn from(content: UserMessageContent) -> Self {
2367 match content {
2368 UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2369 text,
2370 annotations: None,
2371 }),
2372 UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2373 data: image.source.to_string(),
2374 mime_type: "image/png".to_string(),
2375 annotations: None,
2376 uri: None,
2377 }),
2378 UserMessageContent::Mention { uri, content } => {
2379 acp::ContentBlock::ResourceLink(acp::ResourceLink {
2380 uri: uri.to_uri().to_string(),
2381 name: uri.name(),
2382 annotations: None,
2383 description: if content.is_empty() {
2384 None
2385 } else {
2386 Some(content)
2387 },
2388 mime_type: None,
2389 size: None,
2390 title: None,
2391 })
2392 }
2393 }
2394 }
2395}
2396
2397fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2398 LanguageModelImage {
2399 source: image_content.data.into(),
2400 // TODO: make this optional?
2401 size: gpui::Size::new(0.into(), 0.into()),
2402 }
2403}