1use crate::{
2 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
3 DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
4 ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
5 Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
6};
7use acp_thread::{MentionUri, UserMessageId};
8use action_log::ActionLog;
9use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
10use agent_client_protocol as acp;
11use agent_settings::{
12 AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
13 SUMMARIZE_THREAD_PROMPT,
14};
15use anyhow::{Context as _, Result, anyhow};
16use assistant_tool::adapt_schema_to_format;
17use chrono::{DateTime, Utc};
18use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
19use collections::{HashMap, IndexMap};
20use fs::Fs;
21use futures::{
22 FutureExt,
23 channel::{mpsc, oneshot},
24 future::Shared,
25 stream::FuturesUnordered,
26};
27use git::repository::DiffType;
28use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
29use language_model::{
30 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
31 LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
32 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
33 LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
34 LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
35};
36use project::{
37 Project,
38 git_store::{GitStore, RepositoryState},
39};
40use prompt_store::ProjectContext;
41use schemars::{JsonSchema, Schema};
42use serde::{Deserialize, Serialize};
43use settings::{Settings, update_settings_file};
44use smol::stream::StreamExt;
45use std::{
46 collections::BTreeMap,
47 path::Path,
48 sync::Arc,
49 time::{Duration, Instant},
50};
51use std::{fmt::Write, ops::Range};
52use util::{ResultExt, markdown::MarkdownCodeBlock};
53use uuid::Uuid;
54
/// Message describing a tool call that was canceled by the user.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
56
57/// The ID of the user prompt that initiated a request.
58///
59/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
60#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
61pub struct PromptId(Arc<str>);
62
63impl PromptId {
64 pub fn new() -> Self {
65 Self(Uuid::new_v4().to_string().into())
66 }
67}
68
69impl std::fmt::Display for PromptId {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(f, "{}", self.0)
72 }
73}
74
/// Default cap on the number of retry attempts for a failed request.
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay between retry attempts (five seconds).
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_millis(5_000);

/// Describes how a failed request should be retried.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Exponential backoff starting from `initial_delay`, for at most
    /// `max_attempts` attempts.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// A constant `delay` between attempts, for at most `max_attempts` attempts.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
89
90#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
91pub enum Message {
92 User(UserMessage),
93 Agent(AgentMessage),
94 Resume,
95}
96
97impl Message {
98 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
99 match self {
100 Message::Agent(agent_message) => Some(agent_message),
101 _ => None,
102 }
103 }
104
105 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
106 match self {
107 Message::User(message) => vec![message.to_request()],
108 Message::Agent(message) => message.to_request(),
109 Message::Resume => vec![LanguageModelRequestMessage {
110 role: Role::User,
111 content: vec!["Continue where you left off".into()],
112 cache: false,
113 }],
114 }
115 }
116
117 pub fn to_markdown(&self) -> String {
118 match self {
119 Message::User(message) => message.to_markdown(),
120 Message::Agent(message) => message.to_markdown(),
121 Message::Resume => "[resumed after tool use limit was reached]".into(),
122 }
123 }
124
125 pub fn role(&self) -> Role {
126 match self {
127 Message::User(_) | Message::Resume => Role::User,
128 Message::Agent(_) => Role::Assistant,
129 }
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
134pub struct UserMessage {
135 pub id: UserMessageId,
136 pub content: Vec<UserMessageContent>,
137}
138
139#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
140pub enum UserMessageContent {
141 Text(String),
142 Mention { uri: MentionUri, content: String },
143 Image(LanguageModelImage),
144}
145
146impl UserMessage {
147 pub fn to_markdown(&self) -> String {
148 let mut markdown = String::from("## User\n\n");
149
150 for content in &self.content {
151 match content {
152 UserMessageContent::Text(text) => {
153 markdown.push_str(text);
154 markdown.push('\n');
155 }
156 UserMessageContent::Image(_) => {
157 markdown.push_str("<image />\n");
158 }
159 UserMessageContent::Mention { uri, content } => {
160 if !content.is_empty() {
161 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
162 } else {
163 let _ = write!(&mut markdown, "{}\n", uri.as_link());
164 }
165 }
166 }
167 }
168
169 markdown
170 }
171
172 fn to_request(&self) -> LanguageModelRequestMessage {
173 let mut message = LanguageModelRequestMessage {
174 role: Role::User,
175 content: Vec::with_capacity(self.content.len()),
176 cache: false,
177 };
178
179 const OPEN_CONTEXT: &str = "<context>\n\
180 The following items were attached by the user. \
181 They are up-to-date and don't need to be re-read.\n\n";
182
183 const OPEN_FILES_TAG: &str = "<files>";
184 const OPEN_DIRECTORIES_TAG: &str = "<directories>";
185 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
186 const OPEN_THREADS_TAG: &str = "<threads>";
187 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
188 const OPEN_RULES_TAG: &str =
189 "<rules>\nThe user has specified the following rules that should be applied:\n";
190
191 let mut file_context = OPEN_FILES_TAG.to_string();
192 let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
193 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
194 let mut thread_context = OPEN_THREADS_TAG.to_string();
195 let mut fetch_context = OPEN_FETCH_TAG.to_string();
196 let mut rules_context = OPEN_RULES_TAG.to_string();
197
198 for chunk in &self.content {
199 let chunk = match chunk {
200 UserMessageContent::Text(text) => {
201 language_model::MessageContent::Text(text.clone())
202 }
203 UserMessageContent::Image(value) => {
204 language_model::MessageContent::Image(value.clone())
205 }
206 UserMessageContent::Mention { uri, content } => {
207 match uri {
208 MentionUri::File { abs_path } => {
209 write!(
210 &mut symbol_context,
211 "\n{}",
212 MarkdownCodeBlock {
213 tag: &codeblock_tag(abs_path, None),
214 text: &content.to_string(),
215 }
216 )
217 .ok();
218 }
219 MentionUri::Directory { .. } => {
220 write!(&mut directory_context, "\n{}\n", content).ok();
221 }
222 MentionUri::Symbol {
223 path, line_range, ..
224 }
225 | MentionUri::Selection {
226 path, line_range, ..
227 } => {
228 write!(
229 &mut rules_context,
230 "\n{}",
231 MarkdownCodeBlock {
232 tag: &codeblock_tag(path, Some(line_range)),
233 text: content
234 }
235 )
236 .ok();
237 }
238 MentionUri::Thread { .. } => {
239 write!(&mut thread_context, "\n{}\n", content).ok();
240 }
241 MentionUri::TextThread { .. } => {
242 write!(&mut thread_context, "\n{}\n", content).ok();
243 }
244 MentionUri::Rule { .. } => {
245 write!(
246 &mut rules_context,
247 "\n{}",
248 MarkdownCodeBlock {
249 tag: "",
250 text: content
251 }
252 )
253 .ok();
254 }
255 MentionUri::Fetch { url } => {
256 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
257 }
258 }
259
260 language_model::MessageContent::Text(uri.as_link().to_string())
261 }
262 };
263
264 message.content.push(chunk);
265 }
266
267 let len_before_context = message.content.len();
268
269 if file_context.len() > OPEN_FILES_TAG.len() {
270 file_context.push_str("</files>\n");
271 message
272 .content
273 .push(language_model::MessageContent::Text(file_context));
274 }
275
276 if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
277 directory_context.push_str("</directories>\n");
278 message
279 .content
280 .push(language_model::MessageContent::Text(directory_context));
281 }
282
283 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
284 symbol_context.push_str("</symbols>\n");
285 message
286 .content
287 .push(language_model::MessageContent::Text(symbol_context));
288 }
289
290 if thread_context.len() > OPEN_THREADS_TAG.len() {
291 thread_context.push_str("</threads>\n");
292 message
293 .content
294 .push(language_model::MessageContent::Text(thread_context));
295 }
296
297 if fetch_context.len() > OPEN_FETCH_TAG.len() {
298 fetch_context.push_str("</fetched_urls>\n");
299 message
300 .content
301 .push(language_model::MessageContent::Text(fetch_context));
302 }
303
304 if rules_context.len() > OPEN_RULES_TAG.len() {
305 rules_context.push_str("</user_rules>\n");
306 message
307 .content
308 .push(language_model::MessageContent::Text(rules_context));
309 }
310
311 if message.content.len() > len_before_context {
312 message.content.insert(
313 len_before_context,
314 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
315 );
316 message
317 .content
318 .push(language_model::MessageContent::Text("</context>".into()));
319 }
320
321 message
322 }
323}
324
/// Builds the info string of a fenced code block describing `full_path`.
///
/// The result is `"<extension> <path>"`; the extension prefix is omitted when
/// the path has no UTF-8 extension. When `line_range` is given, a 1-based
/// line suffix is appended: `:N` if the range is empty (`start == end`),
/// otherwise `:N-M`.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let extension_prefix = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => format!("{ext} "),
        None => String::new(),
    };

    let line_suffix = match line_range {
        None => String::new(),
        Some(range) if range.start == range.end => format!(":{}", range.start + 1),
        Some(range) => format!(":{}-{}", range.start + 1, range.end + 1),
    };

    format!("{extension_prefix}{}{line_suffix}", full_path.display())
}
344
345impl AgentMessage {
346 pub fn to_markdown(&self) -> String {
347 let mut markdown = String::from("## Assistant\n\n");
348
349 for content in &self.content {
350 match content {
351 AgentMessageContent::Text(text) => {
352 markdown.push_str(text);
353 markdown.push('\n');
354 }
355 AgentMessageContent::Thinking { text, .. } => {
356 markdown.push_str("<think>");
357 markdown.push_str(text);
358 markdown.push_str("</think>\n");
359 }
360 AgentMessageContent::RedactedThinking(_) => {
361 markdown.push_str("<redacted_thinking />\n")
362 }
363 AgentMessageContent::ToolUse(tool_use) => {
364 markdown.push_str(&format!(
365 "**Tool Use**: {} (ID: {})\n",
366 tool_use.name, tool_use.id
367 ));
368 markdown.push_str(&format!(
369 "{}\n",
370 MarkdownCodeBlock {
371 tag: "json",
372 text: &format!("{:#}", tool_use.input)
373 }
374 ));
375 }
376 }
377 }
378
379 for tool_result in self.tool_results.values() {
380 markdown.push_str(&format!(
381 "**Tool Result**: {} (ID: {})\n\n",
382 tool_result.tool_name, tool_result.tool_use_id
383 ));
384 if tool_result.is_error {
385 markdown.push_str("**ERROR:**\n");
386 }
387
388 match &tool_result.content {
389 LanguageModelToolResultContent::Text(text) => {
390 writeln!(markdown, "{text}\n").ok();
391 }
392 LanguageModelToolResultContent::Image(_) => {
393 writeln!(markdown, "<image />\n").ok();
394 }
395 }
396
397 if let Some(output) = tool_result.output.as_ref() {
398 writeln!(
399 markdown,
400 "**Debug Output**:\n\n```json\n{}\n```\n",
401 serde_json::to_string_pretty(output).unwrap()
402 )
403 .unwrap();
404 }
405 }
406
407 markdown
408 }
409
410 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
411 let mut assistant_message = LanguageModelRequestMessage {
412 role: Role::Assistant,
413 content: Vec::with_capacity(self.content.len()),
414 cache: false,
415 };
416 for chunk in &self.content {
417 let chunk = match chunk {
418 AgentMessageContent::Text(text) => {
419 language_model::MessageContent::Text(text.clone())
420 }
421 AgentMessageContent::Thinking { text, signature } => {
422 language_model::MessageContent::Thinking {
423 text: text.clone(),
424 signature: signature.clone(),
425 }
426 }
427 AgentMessageContent::RedactedThinking(value) => {
428 language_model::MessageContent::RedactedThinking(value.clone())
429 }
430 AgentMessageContent::ToolUse(value) => {
431 language_model::MessageContent::ToolUse(value.clone())
432 }
433 };
434 assistant_message.content.push(chunk);
435 }
436
437 let mut user_message = LanguageModelRequestMessage {
438 role: Role::User,
439 content: Vec::new(),
440 cache: false,
441 };
442
443 for tool_result in self.tool_results.values() {
444 user_message
445 .content
446 .push(language_model::MessageContent::ToolResult(
447 tool_result.clone(),
448 ));
449 }
450
451 let mut messages = Vec::new();
452 if !assistant_message.content.is_empty() {
453 messages.push(assistant_message);
454 }
455 if !user_message.content.is_empty() {
456 messages.push(user_message);
457 }
458 messages
459 }
460}
461
462#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
463pub struct AgentMessage {
464 pub content: Vec<AgentMessageContent>,
465 pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
466}
467
468#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
469pub enum AgentMessageContent {
470 Text(String),
471 Thinking {
472 text: String,
473 signature: Option<String>,
474 },
475 RedactedThinking(String),
476 ToolUse(LanguageModelToolUse),
477}
478
479#[derive(Debug)]
480pub enum ThreadEvent {
481 UserMessage(UserMessage),
482 AgentText(String),
483 AgentThinking(String),
484 ToolCall(acp::ToolCall),
485 ToolCallUpdate(acp_thread::ToolCallUpdate),
486 ToolCallAuthorization(ToolCallAuthorization),
487 TokenUsageUpdate(acp_thread::TokenUsage),
488 TitleUpdate(SharedString),
489 Retry(acp_thread::RetryStatus),
490 Stop(acp::StopReason),
491}
492
493#[derive(Debug)]
494pub struct ToolCallAuthorization {
495 pub tool_call: acp::ToolCallUpdate,
496 pub options: Vec<acp::PermissionOption>,
497 pub response: oneshot::Sender<acp::PermissionOptionId>,
498}
499
500pub struct Thread {
501 id: acp::SessionId,
502 prompt_id: PromptId,
503 updated_at: DateTime<Utc>,
504 title: Option<SharedString>,
505 summary: Option<SharedString>,
506 messages: Vec<Message>,
507 completion_mode: CompletionMode,
508 /// Holds the task that handles agent interaction until the end of the turn.
509 /// Survives across multiple requests as the model performs tool calls and
510 /// we run tools, report their results.
511 running_turn: Option<RunningTurn>,
512 pending_message: Option<AgentMessage>,
513 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
514 tool_use_limit_reached: bool,
515 request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
516 #[allow(unused)]
517 cumulative_token_usage: TokenUsage,
518 #[allow(unused)]
519 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
520 context_server_registry: Entity<ContextServerRegistry>,
521 profile_id: AgentProfileId,
522 project_context: Entity<ProjectContext>,
523 templates: Arc<Templates>,
524 model: Option<Arc<dyn LanguageModel>>,
525 summarization_model: Option<Arc<dyn LanguageModel>>,
526 pub(crate) project: Entity<Project>,
527 pub(crate) action_log: Entity<ActionLog>,
528}
529
530impl Thread {
531 pub fn new(
532 project: Entity<Project>,
533 project_context: Entity<ProjectContext>,
534 context_server_registry: Entity<ContextServerRegistry>,
535 action_log: Entity<ActionLog>,
536 templates: Arc<Templates>,
537 model: Option<Arc<dyn LanguageModel>>,
538 cx: &mut Context<Self>,
539 ) -> Self {
540 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
541 Self {
542 id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
543 prompt_id: PromptId::new(),
544 updated_at: Utc::now(),
545 title: None,
546 summary: None,
547 messages: Vec::new(),
548 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
549 running_turn: None,
550 pending_message: None,
551 tools: BTreeMap::default(),
552 tool_use_limit_reached: false,
553 request_token_usage: HashMap::default(),
554 cumulative_token_usage: TokenUsage::default(),
555 initial_project_snapshot: {
556 let project_snapshot = Self::project_snapshot(project.clone(), cx);
557 cx.foreground_executor()
558 .spawn(async move { Some(project_snapshot.await) })
559 .shared()
560 },
561 context_server_registry,
562 profile_id,
563 project_context,
564 templates,
565 model,
566 summarization_model: None,
567 project,
568 action_log,
569 }
570 }
571
572 pub fn id(&self) -> &acp::SessionId {
573 &self.id
574 }
575
576 pub fn replay(
577 &mut self,
578 cx: &mut Context<Self>,
579 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
580 let (tx, rx) = mpsc::unbounded();
581 let stream = ThreadEventStream(tx);
582 for message in &self.messages {
583 match message {
584 Message::User(user_message) => stream.send_user_message(user_message),
585 Message::Agent(assistant_message) => {
586 for content in &assistant_message.content {
587 match content {
588 AgentMessageContent::Text(text) => stream.send_text(text),
589 AgentMessageContent::Thinking { text, .. } => {
590 stream.send_thinking(text)
591 }
592 AgentMessageContent::RedactedThinking(_) => {}
593 AgentMessageContent::ToolUse(tool_use) => {
594 self.replay_tool_call(
595 tool_use,
596 assistant_message.tool_results.get(&tool_use.id),
597 &stream,
598 cx,
599 );
600 }
601 }
602 }
603 }
604 Message::Resume => {}
605 }
606 }
607 rx
608 }
609
610 fn replay_tool_call(
611 &self,
612 tool_use: &LanguageModelToolUse,
613 tool_result: Option<&LanguageModelToolResult>,
614 stream: &ThreadEventStream,
615 cx: &mut Context<Self>,
616 ) {
617 let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
618 stream
619 .0
620 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
621 id: acp::ToolCallId(tool_use.id.to_string().into()),
622 title: tool_use.name.to_string(),
623 kind: acp::ToolKind::Other,
624 status: acp::ToolCallStatus::Failed,
625 content: Vec::new(),
626 locations: Vec::new(),
627 raw_input: Some(tool_use.input.clone()),
628 raw_output: None,
629 })))
630 .ok();
631 return;
632 };
633
634 let title = tool.initial_title(tool_use.input.clone());
635 let kind = tool.kind();
636 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
637
638 let output = tool_result
639 .as_ref()
640 .and_then(|result| result.output.clone());
641 if let Some(output) = output.clone() {
642 let tool_event_stream = ToolCallEventStream::new(
643 tool_use.id.clone(),
644 stream.clone(),
645 Some(self.project.read(cx).fs().clone()),
646 );
647 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
648 .log_err();
649 }
650
651 stream.update_tool_call_fields(
652 &tool_use.id,
653 acp::ToolCallUpdateFields {
654 status: Some(acp::ToolCallStatus::Completed),
655 raw_output: output,
656 ..Default::default()
657 },
658 );
659 }
660
661 pub fn from_db(
662 id: acp::SessionId,
663 db_thread: DbThread,
664 project: Entity<Project>,
665 project_context: Entity<ProjectContext>,
666 context_server_registry: Entity<ContextServerRegistry>,
667 action_log: Entity<ActionLog>,
668 templates: Arc<Templates>,
669 cx: &mut Context<Self>,
670 ) -> Self {
671 let profile_id = db_thread
672 .profile
673 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
674 let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
675 db_thread
676 .model
677 .and_then(|model| {
678 let model = SelectedModel {
679 provider: model.provider.clone().into(),
680 model: model.model.clone().into(),
681 };
682 registry.select_model(&model, cx)
683 })
684 .or_else(|| registry.default_model())
685 .map(|model| model.model)
686 });
687
688 Self {
689 id,
690 prompt_id: PromptId::new(),
691 title: if db_thread.title.is_empty() {
692 None
693 } else {
694 Some(db_thread.title.clone())
695 },
696 summary: db_thread.detailed_summary,
697 messages: db_thread.messages,
698 completion_mode: db_thread.completion_mode.unwrap_or_default(),
699 running_turn: None,
700 pending_message: None,
701 tools: BTreeMap::default(),
702 tool_use_limit_reached: false,
703 request_token_usage: db_thread.request_token_usage.clone(),
704 cumulative_token_usage: db_thread.cumulative_token_usage,
705 initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
706 context_server_registry,
707 profile_id,
708 project_context,
709 templates,
710 model,
711 summarization_model: None,
712 project,
713 action_log,
714 updated_at: db_thread.updated_at,
715 }
716 }
717
718 pub fn to_db(&self, cx: &App) -> Task<DbThread> {
719 let initial_project_snapshot = self.initial_project_snapshot.clone();
720 let mut thread = DbThread {
721 title: self.title.clone().unwrap_or_default(),
722 messages: self.messages.clone(),
723 updated_at: self.updated_at,
724 detailed_summary: self.summary.clone(),
725 initial_project_snapshot: None,
726 cumulative_token_usage: self.cumulative_token_usage,
727 request_token_usage: self.request_token_usage.clone(),
728 model: self.model.as_ref().map(|model| DbLanguageModel {
729 provider: model.provider_id().to_string(),
730 model: model.name().0.to_string(),
731 }),
732 completion_mode: Some(self.completion_mode),
733 profile: Some(self.profile_id.clone()),
734 };
735
736 cx.background_spawn(async move {
737 let initial_project_snapshot = initial_project_snapshot.await;
738 thread.initial_project_snapshot = initial_project_snapshot;
739 thread
740 })
741 }
742
743 /// Create a snapshot of the current project state including git information and unsaved buffers.
744 fn project_snapshot(
745 project: Entity<Project>,
746 cx: &mut Context<Self>,
747 ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
748 let git_store = project.read(cx).git_store().clone();
749 let worktree_snapshots: Vec<_> = project
750 .read(cx)
751 .visible_worktrees(cx)
752 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
753 .collect();
754
755 cx.spawn(async move |_, cx| {
756 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
757
758 let mut unsaved_buffers = Vec::new();
759 cx.update(|app_cx| {
760 let buffer_store = project.read(app_cx).buffer_store();
761 for buffer_handle in buffer_store.read(app_cx).buffers() {
762 let buffer = buffer_handle.read(app_cx);
763 if buffer.is_dirty()
764 && let Some(file) = buffer.file()
765 {
766 let path = file.path().to_string_lossy().to_string();
767 unsaved_buffers.push(path);
768 }
769 }
770 })
771 .ok();
772
773 Arc::new(ProjectSnapshot {
774 worktree_snapshots,
775 unsaved_buffer_paths: unsaved_buffers,
776 timestamp: Utc::now(),
777 })
778 })
779 }
780
    /// Captures a [`WorktreeSnapshot`] for `worktree`: its absolute path and,
    /// when a repository in `git_store` contains the worktree's root, that
    /// repository's git state.
    ///
    /// If reading the worktree through `cx.update` fails, an empty path with no
    /// git state is returned. Remote URL, HEAD SHA, and diff are filled in only
    /// for `RepositoryState::Local` repositories; otherwise only the current
    /// branch name is reported.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<agent::thread::WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // NOTE(review): the worktree snapshot is computed but currently unused.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Pick the first repository that can map the worktree's root to a
            // repo-relative path, then query its git state via `send_job`.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only a local repository exposes a backend to query.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            // A failed diff is recorded as no diff rather than an error.
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // `None` when no repository matched, the repository update failed,
            // or awaiting the job's result failed.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
858
859 pub fn project_context(&self) -> &Entity<ProjectContext> {
860 &self.project_context
861 }
862
863 pub fn project(&self) -> &Entity<Project> {
864 &self.project
865 }
866
867 pub fn action_log(&self) -> &Entity<ActionLog> {
868 &self.action_log
869 }
870
871 pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
872 self.model.as_ref()
873 }
874
875 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
876 self.model = Some(model);
877 cx.notify()
878 }
879
880 pub fn set_summarization_model(
881 &mut self,
882 model: Option<Arc<dyn LanguageModel>>,
883 cx: &mut Context<Self>,
884 ) {
885 self.summarization_model = model;
886 cx.notify()
887 }
888
889 pub fn completion_mode(&self) -> CompletionMode {
890 self.completion_mode
891 }
892
893 pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
894 self.completion_mode = mode;
895 cx.notify()
896 }
897
898 #[cfg(any(test, feature = "test-support"))]
899 pub fn last_message(&self) -> Option<Message> {
900 if let Some(message) = self.pending_message.clone() {
901 Some(Message::Agent(message))
902 } else {
903 self.messages.last().cloned()
904 }
905 }
906
907 pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
908 let language_registry = self.project.read(cx).languages().clone();
909 self.add_tool(CopyPathTool::new(self.project.clone()));
910 self.add_tool(CreateDirectoryTool::new(self.project.clone()));
911 self.add_tool(DeletePathTool::new(
912 self.project.clone(),
913 self.action_log.clone(),
914 ));
915 self.add_tool(DiagnosticsTool::new(self.project.clone()));
916 self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
917 self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
918 self.add_tool(FindPathTool::new(self.project.clone()));
919 self.add_tool(GrepTool::new(self.project.clone()));
920 self.add_tool(ListDirectoryTool::new(self.project.clone()));
921 self.add_tool(MovePathTool::new(self.project.clone()));
922 self.add_tool(NowTool);
923 self.add_tool(OpenTool::new(self.project.clone()));
924 self.add_tool(ReadFileTool::new(
925 self.project.clone(),
926 self.action_log.clone(),
927 ));
928 self.add_tool(TerminalTool::new(self.project.clone(), cx));
929 self.add_tool(ThinkingTool);
930 self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
931 }
932
933 pub fn add_tool(&mut self, tool: impl AgentTool) {
934 self.tools.insert(tool.name(), tool.erase());
935 }
936
937 pub fn remove_tool(&mut self, name: &str) -> bool {
938 self.tools.remove(name).is_some()
939 }
940
941 pub fn profile(&self) -> &AgentProfileId {
942 &self.profile_id
943 }
944
945 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
946 self.profile_id = profile_id;
947 }
948
949 pub fn cancel(&mut self, cx: &mut Context<Self>) {
950 if let Some(running_turn) = self.running_turn.take() {
951 running_turn.cancel();
952 }
953 self.flush_pending_message(cx);
954 }
955
956 pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
957 let Some(last_user_message) = self.last_user_message() else {
958 return;
959 };
960
961 self.request_token_usage
962 .insert(last_user_message.id.clone(), update);
963 }
964
965 pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
966 self.cancel(cx);
967 let Some(position) = self.messages.iter().position(
968 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
969 ) else {
970 return Err(anyhow!("Message not found"));
971 };
972
973 for message in self.messages.drain(position..) {
974 match message {
975 Message::User(message) => {
976 self.request_token_usage.remove(&message.id);
977 }
978 Message::Agent(_) | Message::Resume => {}
979 }
980 }
981 self.summary = None;
982 cx.notify();
983 Ok(())
984 }
985
986 pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
987 let last_user_message = self.last_user_message()?;
988 let tokens = self.request_token_usage.get(&last_user_message.id)?;
989 let model = self.model.clone()?;
990
991 Some(acp_thread::TokenUsage {
992 max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
993 used_tokens: tokens.total_tokens(),
994 })
995 }
996
997 pub fn resume(
998 &mut self,
999 cx: &mut Context<Self>,
1000 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1001 anyhow::ensure!(
1002 self.tool_use_limit_reached,
1003 "can only resume after tool use limit is reached"
1004 );
1005
1006 self.messages.push(Message::Resume);
1007 cx.notify();
1008
1009 log::info!("Total messages in thread: {}", self.messages.len());
1010 self.run_turn(cx)
1011 }
1012
1013 /// Sending a message results in the model streaming a response, which could include tool calls.
1014 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
1015 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
1016 pub fn send<T>(
1017 &mut self,
1018 id: UserMessageId,
1019 content: impl IntoIterator<Item = T>,
1020 cx: &mut Context<Self>,
1021 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
1022 where
1023 T: Into<UserMessageContent>,
1024 {
1025 let model = self.model().context("No language model configured")?;
1026
1027 log::info!("Thread::send called with model: {:?}", model.name());
1028 self.advance_prompt_id();
1029
1030 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
1031 log::debug!("Thread::send content: {:?}", content);
1032
1033 self.messages
1034 .push(Message::User(UserMessage { id, content }));
1035 cx.notify();
1036
1037 log::info!("Total messages in thread: {}", self.messages.len());
1038 self.run_turn(cx)
1039 }
1040
    /// Starts a new agent turn over the current history, cancelling any turn
    /// already in progress.
    ///
    /// The spawned task repeatedly builds a completion request, streams the
    /// model's response, and waits for the requested tool calls to finish,
    /// looping with `CompletionIntent::ToolResults` while the model keeps
    /// calling tools. The turn stops on end of turn, refusal, max tokens, the
    /// tool use limit, or an error. Events are delivered on the returned
    /// channel. The task is kept in `running_turn` so `cancel` can stop it.
    ///
    /// # Errors
    ///
    /// Returns an error if no language model is configured.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        self.cancel(cx);

        let model = self.model.clone().context("No language model configured")?;
        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
        let event_stream = ThreadEventStream(events_tx);
        // Index of the newest message (the one that triggered this turn); on
        // refusal the history is truncated back to before it.
        let message_ix = self.messages.len().saturating_sub(1);
        self.tool_use_limit_reached = false;
        self.summary = None;
        self.running_turn = Some(RunningTurn {
            event_stream: event_stream.clone(),
            _task: cx.spawn(async move |this, cx| {
                log::info!("Starting agent turn execution");
                let turn_result: Result<StopReason> = async {
                    let mut completion_intent = CompletionIntent::UserPrompt;
                    loop {
                        log::debug!(
                            "Building completion request with intent: {:?}",
                            completion_intent
                        );
                        let request = this.update(cx, |this, cx| {
                            this.build_completion_request(completion_intent, cx)
                        })??;

                        log::info!("Calling model.stream_completion");

                        // Flags set by the streaming helper while it consumes
                        // the model's response.
                        let mut tool_use_limit_reached = false;
                        let mut refused = false;
                        let mut reached_max_tokens = false;
                        let mut tool_uses = Self::stream_completion_with_retries(
                            this.clone(),
                            model.clone(),
                            request,
                            &event_stream,
                            &mut tool_use_limit_reached,
                            &mut refused,
                            &mut reached_max_tokens,
                            cx,
                        )
                        .await?;

                        if refused {
                            return Ok(StopReason::Refusal);
                        } else if reached_max_tokens {
                            return Ok(StopReason::MaxTokens);
                        }

                        // Checked before draining: no tool uses means the model
                        // finished its turn.
                        let end_turn = tool_uses.is_empty();
                        while let Some(tool_result) = tool_uses.next().await {
                            log::info!("Tool finished {:?}", tool_result);

                            event_stream.update_tool_call_fields(
                                &tool_result.tool_use_id,
                                acp::ToolCallUpdateFields {
                                    status: Some(if tool_result.is_error {
                                        acp::ToolCallStatus::Failed
                                    } else {
                                        acp::ToolCallStatus::Completed
                                    }),
                                    raw_output: tool_result.output.clone(),
                                    ..Default::default()
                                },
                            );
                            this.update(cx, |this, _cx| {
                                this.pending_message()
                                    .tool_results
                                    .insert(tool_result.tool_use_id.clone(), tool_result);
                            })
                            .ok();
                        }

                        if tool_use_limit_reached {
                            log::info!("Tool use limit reached, completing turn");
                            // Recording the flag is what allows `resume` later.
                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
                            return Err(language_model::ToolUseLimitReachedError.into());
                        } else if end_turn {
                            log::info!("No tool uses found, completing turn");
                            return Ok(StopReason::EndTurn);
                        } else {
                            // Commit this round before sending tool results back.
                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
                            completion_intent = CompletionIntent::ToolResults;
                        }
                    }
                }
                .await;
                // Commit whatever was produced, however the turn ended.
                _ = this.update(cx, |this, cx| this.flush_pending_message(cx));

                match turn_result {
                    Ok(reason) => {
                        log::info!("Turn execution completed: {:?}", reason);

                        let update_title = this
                            .update(cx, |this, cx| this.update_title(&event_stream, cx))
                            .ok()
                            .flatten();
                        if let Some(update_title) = update_title {
                            update_title.await.context("update title failed").log_err();
                        }

                        event_stream.send_stop(reason);
                        if reason == StopReason::Refusal {
                            // Drop the refused prompt and everything after it.
                            _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
                        }
                    }
                    Err(error) => {
                        log::error!("Turn execution failed: {:?}", error);
                        event_stream.send_error(error);
                    }
                }

                // The turn is over; release its handle.
                _ = this.update(cx, |this, _| this.running_turn.take());
            }),
        });
        Ok(events_rx)
    }
1159
    /// Streams one completion for `request` and applies each streamed event to the
    /// thread. Returns the tool-call tasks spawned while streaming.
    ///
    /// The out-parameters are set when the model reports the tool-use limit
    /// (`tool_use_limit_reached`), a refusal (`refusal`), or that it hit its max
    /// tokens (`max_tokens_reached`). In the last two cases an empty set of tool
    /// tasks is returned.
    ///
    /// If the event stream yields an error, the request is retried according to
    /// `retry_strategy_for`, but only when the thread's completion mode is not
    /// `CompletionMode::Normal`. Otherwise the error is returned. A failure to
    /// start the stream (`model.stream_completion(..)`) is returned immediately
    /// without retrying.
    ///
    /// NOTE(review): content already applied to the pending message by a failed
    /// attempt (e.g. streamed text) is not rolled back before the same request is
    /// re-sent, so a retry may append duplicate output. Confirm this is intended.
    async fn stream_completion_with_retries(
        this: WeakEntity<Self>,
        model: Arc<dyn LanguageModel>,
        request: LanguageModelRequest,
        event_stream: &ThreadEventStream,
        tool_use_limit_reached: &mut bool,
        refusal: &mut bool,
        max_tokens_reached: &mut bool,
        cx: &mut AsyncApp,
    ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
        log::debug!("Stream completion started successfully");

        // Number of retries made so far; `None` until the first retryable error.
        let mut attempt = None;
        'retry: loop {
            // Each attempt re-sends the same request with a fresh set of tool tasks.
            // Tasks spawned by a failed attempt are dropped along with the old set.
            let mut events = model.stream_completion(request.clone(), cx).await?;
            let mut tool_uses = FuturesUnordered::new();
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::StatusUpdate(
                        CompletionRequestStatus::ToolUseLimitReached,
                    )) => {
                        *tool_use_limit_reached = true;
                    }
                    Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                        // Record the usage on the thread and forward it with the model's
                        // token limit for the request's completion mode.
                        let usage = acp_thread::TokenUsage {
                            max_tokens: model.max_token_count_for_mode(
                                request
                                    .mode
                                    .unwrap_or(cloud_llm_client::CompletionMode::Normal),
                            ),
                            used_tokens: token_usage.total_tokens(),
                        };

                        this.update(cx, |this, _cx| this.update_token_usage(token_usage))
                            .ok();

                        event_stream.send_token_usage_update(usage);
                    }
                    // Refusal and max-tokens end the response; any tool tasks spawned so
                    // far in this attempt are discarded.
                    Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
                        *refusal = true;
                        return Ok(FuturesUnordered::default());
                    }
                    Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
                        *max_tokens_reached = true;
                        return Ok(FuturesUnordered::default());
                    }
                    // A normal stop ends the stream; remaining events are not read.
                    Ok(LanguageModelCompletionEvent::Stop(
                        StopReason::ToolUse | StopReason::EndTurn,
                    )) => break,
                    Ok(event) => {
                        log::trace!("Received completion event: {:?}", event);
                        this.update(cx, |this, cx| {
                            tool_uses.extend(this.handle_streamed_completion_event(
                                event,
                                event_stream,
                                cx,
                            ));
                        })
                        .ok();
                    }
                    Err(error) => {
                        // Retries only happen outside of `CompletionMode::Normal`.
                        let completion_mode =
                            this.read_with(cx, |thread, _cx| thread.completion_mode())?;
                        if completion_mode == CompletionMode::Normal {
                            return Err(error.into());
                        }

                        let Some(strategy) = Self::retry_strategy_for(&error) else {
                            return Err(error.into());
                        };

                        let max_attempts = match &strategy {
                            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
                            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
                        };

                        // The counter persists across retries of this request, whatever
                        // error kind triggered each one.
                        let attempt = attempt.get_or_insert(0u8);

                        *attempt += 1;

                        let attempt = *attempt;
                        if attempt > max_attempts {
                            return Err(error.into());
                        }

                        // Exponential backoff doubles the delay per attempt (initial, 2x, 4x, ...).
                        // NOTE(review): `as_secs` drops any sub-second part of `initial_delay`.
                        let delay = match &strategy {
                            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
                                let delay_secs =
                                    initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
                                Duration::from_secs(delay_secs)
                            }
                            RetryStrategy::Fixed { delay, .. } => *delay,
                        };
                        log::debug!("Retry attempt {attempt} with delay {delay:?}");

                        event_stream.send_retry(acp_thread::RetryStatus {
                            last_error: error.to_string().into(),
                            attempt: attempt as usize,
                            max_attempts: max_attempts as usize,
                            started_at: Instant::now(),
                            duration: delay,
                        });

                        cx.background_executor().timer(delay).await;
                        continue 'retry;
                    }
                }
            }

            return Ok(tool_uses);
        }
    }
1272
1273 pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
1274 log::debug!("Building system message");
1275 let prompt = SystemPromptTemplate {
1276 project: self.project_context.read(cx),
1277 available_tools: self.tools.keys().cloned().collect(),
1278 }
1279 .render(&self.templates)
1280 .context("failed to build system prompt")
1281 .expect("Invalid template");
1282 log::debug!("System message built");
1283 LanguageModelRequestMessage {
1284 role: Role::System,
1285 content: vec![prompt.into()],
1286 cache: true,
1287 }
1288 }
1289
1290 /// A helper method that's called on every streamed completion event.
1291 /// Returns an optional tool result task, which the main agentic loop in
1292 /// send will send back to the model when it resolves.
1293 fn handle_streamed_completion_event(
1294 &mut self,
1295 event: LanguageModelCompletionEvent,
1296 event_stream: &ThreadEventStream,
1297 cx: &mut Context<Self>,
1298 ) -> Option<Task<LanguageModelToolResult>> {
1299 log::trace!("Handling streamed completion event: {:?}", event);
1300 use LanguageModelCompletionEvent::*;
1301
1302 match event {
1303 StartMessage { .. } => {
1304 self.flush_pending_message(cx);
1305 self.pending_message = Some(AgentMessage::default());
1306 }
1307 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1308 Thinking { text, signature } => {
1309 self.handle_thinking_event(text, signature, event_stream, cx)
1310 }
1311 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1312 ToolUse(tool_use) => {
1313 return self.handle_tool_use_event(tool_use, event_stream, cx);
1314 }
1315 ToolUseJsonParseError {
1316 id,
1317 tool_name,
1318 raw_input,
1319 json_parse_error,
1320 } => {
1321 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1322 id,
1323 tool_name,
1324 raw_input,
1325 json_parse_error,
1326 )));
1327 }
1328 UsageUpdate(_) | StatusUpdate(_) => {}
1329 Stop(_) => unreachable!(),
1330 }
1331
1332 None
1333 }
1334
1335 fn handle_text_event(
1336 &mut self,
1337 new_text: String,
1338 event_stream: &ThreadEventStream,
1339 cx: &mut Context<Self>,
1340 ) {
1341 event_stream.send_text(&new_text);
1342
1343 let last_message = self.pending_message();
1344 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1345 text.push_str(&new_text);
1346 } else {
1347 last_message
1348 .content
1349 .push(AgentMessageContent::Text(new_text));
1350 }
1351
1352 cx.notify();
1353 }
1354
1355 fn handle_thinking_event(
1356 &mut self,
1357 new_text: String,
1358 new_signature: Option<String>,
1359 event_stream: &ThreadEventStream,
1360 cx: &mut Context<Self>,
1361 ) {
1362 event_stream.send_thinking(&new_text);
1363
1364 let last_message = self.pending_message();
1365 if let Some(AgentMessageContent::Thinking { text, signature }) =
1366 last_message.content.last_mut()
1367 {
1368 text.push_str(&new_text);
1369 *signature = new_signature.or(signature.take());
1370 } else {
1371 last_message.content.push(AgentMessageContent::Thinking {
1372 text: new_text,
1373 signature: new_signature,
1374 });
1375 }
1376
1377 cx.notify();
1378 }
1379
1380 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1381 let last_message = self.pending_message();
1382 last_message
1383 .content
1384 .push(AgentMessageContent::RedactedThinking(data));
1385 cx.notify();
1386 }
1387
    /// Handles a tool use streamed from the model. The tool use may still be
    /// partial.
    ///
    /// A tool call is announced on `event_stream` when this tool use id is not the
    /// last piece of pending content. Otherwise the existing call is updated in
    /// place. Returns `None` while the input is still streaming. Once the input is
    /// complete, returns a task that resolves to the tool result: an error result
    /// for unknown tools, otherwise the mapped outcome of running the tool.
    ///
    /// NOTE(review): the tool is looked up only in `self.tools`, while `tools()`
    /// also offers context-server tools to the model. Confirm those calls can be
    /// resolved here.
    fn handle_tool_use_event(
        &mut self,
        tool_use: LanguageModelToolUse,
        event_stream: &ThreadEventStream,
        cx: &mut Context<Self>,
    ) -> Option<Task<LanguageModelToolResult>> {
        cx.notify();

        // Unknown tools are displayed with their raw name and `ToolKind::Other`.
        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
        let mut title = SharedString::from(&tool_use.name);
        let mut kind = acp::ToolKind::Other;
        if let Some(tool) = tool.as_ref() {
            title = tool.initial_title(tool_use.input.clone());
            kind = tool.kind();
        }

        // Ensure the last message ends in the current tool use: replace the last
        // content in place if it is this same tool use (still streaming), otherwise
        // append it.
        let last_message = self.pending_message();
        let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
            if let AgentMessageContent::ToolUse(last_tool_use) = content {
                if last_tool_use.id == tool_use.id {
                    *last_tool_use = tool_use.clone();
                    false
                } else {
                    true
                }
            } else {
                true
            }
        });

        if push_new_tool_use {
            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
            last_message
                .content
                .push(AgentMessageContent::ToolUse(tool_use.clone()));
        } else {
            event_stream.update_tool_call_fields(
                &tool_use.id,
                acp::ToolCallUpdateFields {
                    title: Some(title.into()),
                    kind: Some(kind),
                    raw_input: Some(tool_use.input.clone()),
                    ..Default::default()
                },
            );
        }

        // Only run the tool once its input has finished streaming.
        if !tool_use.is_input_complete {
            return None;
        }

        // Report unknown tools back to the model as an error result.
        let Some(tool) = tool else {
            let content = format!("No tool named {} exists", tool_use.name);
            return Some(Task::ready(LanguageModelToolResult {
                content: LanguageModelToolResultContent::Text(Arc::from(content)),
                tool_use_id: tool_use.id,
                tool_name: tool_use.name,
                is_error: true,
                output: None,
            }));
        };

        let fs = self.project.read(cx).fs().clone();
        let tool_event_stream =
            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
            status: Some(acp::ToolCallStatus::InProgress),
            ..Default::default()
        });
        let supports_images = self.model().is_some_and(|model| model.supports_images());
        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
        log::info!("Running tool {}", tool_use.name);
        Some(cx.foreground_executor().spawn(async move {
            // Image output is rejected when the current model cannot accept images.
            let tool_result = tool_result.await.and_then(|output| {
                if let LanguageModelToolResultContent::Image(_) = &output.llm_output
                    && !supports_images
                {
                    return Err(anyhow!(
                        "Attempted to read an image, but this model doesn't support it.",
                    ));
                }
                Ok(output)
            });

            // Map the tool's outcome into the result sent back to the model.
            match tool_result {
                Ok(output) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: false,
                    content: output.llm_output,
                    output: Some(output.raw_output),
                },
                Err(error) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: true,
                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
                    output: None,
                },
            }
        }))
    }
1491
1492 fn handle_tool_use_json_parse_error_event(
1493 &mut self,
1494 tool_use_id: LanguageModelToolUseId,
1495 tool_name: Arc<str>,
1496 raw_input: Arc<str>,
1497 json_parse_error: String,
1498 ) -> LanguageModelToolResult {
1499 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1500 LanguageModelToolResult {
1501 tool_use_id,
1502 tool_name,
1503 is_error: true,
1504 content: LanguageModelToolResultContent::Text(tool_output.into()),
1505 output: Some(serde_json::Value::String(raw_input.to_string())),
1506 }
1507 }
1508
1509 pub fn title(&self) -> SharedString {
1510 self.title.clone().unwrap_or("New Thread".into())
1511 }
1512
1513 pub fn summary(&mut self, cx: &mut Context<Self>) -> Task<Result<SharedString>> {
1514 if let Some(summary) = self.summary.as_ref() {
1515 return Task::ready(Ok(summary.clone()));
1516 }
1517 let Some(model) = self.summarization_model.clone() else {
1518 return Task::ready(Err(anyhow!("No summarization model available")));
1519 };
1520 let mut request = LanguageModelRequest {
1521 intent: Some(CompletionIntent::ThreadSummarization),
1522 temperature: AgentSettings::temperature_for_model(&model, cx),
1523 ..Default::default()
1524 };
1525
1526 for message in &self.messages {
1527 request.messages.extend(message.to_request());
1528 }
1529
1530 request.messages.push(LanguageModelRequestMessage {
1531 role: Role::User,
1532 content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()],
1533 cache: false,
1534 });
1535 cx.spawn(async move |this, cx| {
1536 let mut summary = String::new();
1537 let mut messages = model.stream_completion(request, cx).await?;
1538 while let Some(event) = messages.next().await {
1539 let event = event?;
1540 let text = match event {
1541 LanguageModelCompletionEvent::Text(text) => text,
1542 LanguageModelCompletionEvent::StatusUpdate(
1543 CompletionRequestStatus::UsageUpdated { .. },
1544 ) => {
1545 // this.update(cx, |thread, cx| {
1546 // thread.update_model_request_usage(amount as u32, limit, cx);
1547 // })?;
1548 // TODO: handle usage update
1549 continue;
1550 }
1551 _ => continue,
1552 };
1553
1554 let mut lines = text.lines();
1555 summary.extend(lines.next());
1556 }
1557
1558 log::info!("Setting summary: {}", summary);
1559 let summary = SharedString::from(summary);
1560
1561 this.update(cx, |this, cx| {
1562 this.summary = Some(summary.clone());
1563 cx.notify()
1564 })?;
1565
1566 Ok(summary)
1567 })
1568 }
1569
1570 fn update_title(
1571 &mut self,
1572 event_stream: &ThreadEventStream,
1573 cx: &mut Context<Self>,
1574 ) -> Option<Task<Result<()>>> {
1575 if self.title.is_some() {
1576 log::debug!("Skipping title generation because we already have one.");
1577 return None;
1578 }
1579
1580 log::info!(
1581 "Generating title with model: {:?}",
1582 self.summarization_model.as_ref().map(|model| model.name())
1583 );
1584 let model = self.summarization_model.clone()?;
1585 let event_stream = event_stream.clone();
1586 let mut request = LanguageModelRequest {
1587 intent: Some(CompletionIntent::ThreadSummarization),
1588 temperature: AgentSettings::temperature_for_model(&model, cx),
1589 ..Default::default()
1590 };
1591
1592 for message in &self.messages {
1593 request.messages.extend(message.to_request());
1594 }
1595
1596 request.messages.push(LanguageModelRequestMessage {
1597 role: Role::User,
1598 content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1599 cache: false,
1600 });
1601 Some(cx.spawn(async move |this, cx| {
1602 let mut title = String::new();
1603 let mut messages = model.stream_completion(request, cx).await?;
1604 while let Some(event) = messages.next().await {
1605 let event = event?;
1606 let text = match event {
1607 LanguageModelCompletionEvent::Text(text) => text,
1608 LanguageModelCompletionEvent::StatusUpdate(
1609 CompletionRequestStatus::UsageUpdated { .. },
1610 ) => {
1611 // this.update(cx, |thread, cx| {
1612 // thread.update_model_request_usage(amount as u32, limit, cx);
1613 // })?;
1614 // TODO: handle usage update
1615 continue;
1616 }
1617 _ => continue,
1618 };
1619
1620 let mut lines = text.lines();
1621 title.extend(lines.next());
1622
1623 // Stop if the LLM generated multiple lines.
1624 if lines.next().is_some() {
1625 break;
1626 }
1627 }
1628
1629 log::info!("Setting title: {}", title);
1630
1631 this.update(cx, |this, cx| {
1632 let title = SharedString::from(title);
1633 event_stream.send_title_update(title.clone());
1634 this.title = Some(title);
1635 cx.notify();
1636 })
1637 }))
1638 }
1639 fn last_user_message(&self) -> Option<&UserMessage> {
1640 self.messages
1641 .iter()
1642 .rev()
1643 .find_map(|message| match message {
1644 Message::User(user_message) => Some(user_message),
1645 Message::Agent(_) => None,
1646 Message::Resume => None,
1647 })
1648 }
1649
1650 fn pending_message(&mut self) -> &mut AgentMessage {
1651 self.pending_message.get_or_insert_default()
1652 }
1653
1654 fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1655 let Some(mut message) = self.pending_message.take() else {
1656 return;
1657 };
1658
1659 for content in &message.content {
1660 let AgentMessageContent::ToolUse(tool_use) = content else {
1661 continue;
1662 };
1663
1664 if !message.tool_results.contains_key(&tool_use.id) {
1665 message.tool_results.insert(
1666 tool_use.id.clone(),
1667 LanguageModelToolResult {
1668 tool_use_id: tool_use.id.clone(),
1669 tool_name: tool_use.name.clone(),
1670 is_error: true,
1671 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1672 output: None,
1673 },
1674 );
1675 }
1676 }
1677
1678 self.messages.push(Message::Agent(message));
1679 self.updated_at = Utc::now();
1680 self.summary = None;
1681 cx.notify()
1682 }
1683
1684 pub(crate) fn build_completion_request(
1685 &self,
1686 completion_intent: CompletionIntent,
1687 cx: &mut App,
1688 ) -> Result<LanguageModelRequest> {
1689 let model = self.model().context("No language model configured")?;
1690
1691 log::debug!("Building completion request");
1692 log::debug!("Completion intent: {:?}", completion_intent);
1693 log::debug!("Completion mode: {:?}", self.completion_mode);
1694
1695 let messages = self.build_request_messages(cx);
1696 log::info!("Request will include {} messages", messages.len());
1697
1698 let tools = if let Some(tools) = self.tools(cx).log_err() {
1699 tools
1700 .filter_map(|tool| {
1701 let tool_name = tool.name().to_string();
1702 log::trace!("Including tool: {}", tool_name);
1703 Some(LanguageModelRequestTool {
1704 name: tool_name,
1705 description: tool.description().to_string(),
1706 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1707 })
1708 })
1709 .collect()
1710 } else {
1711 Vec::new()
1712 };
1713
1714 log::info!("Request includes {} tools", tools.len());
1715
1716 let request = LanguageModelRequest {
1717 thread_id: Some(self.id.to_string()),
1718 prompt_id: Some(self.prompt_id.to_string()),
1719 intent: Some(completion_intent),
1720 mode: Some(self.completion_mode.into()),
1721 messages,
1722 tools,
1723 tool_choice: None,
1724 stop: Vec::new(),
1725 temperature: AgentSettings::temperature_for_model(model, cx),
1726 thinking_allowed: true,
1727 };
1728
1729 log::debug!("Completion request built successfully");
1730 Ok(request)
1731 }
1732
1733 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1734 let model = self.model().context("No language model configured")?;
1735
1736 let profile = AgentSettings::get_global(cx)
1737 .profiles
1738 .get(&self.profile_id)
1739 .context("profile not found")?;
1740 let provider_id = model.provider_id();
1741
1742 Ok(self
1743 .tools
1744 .iter()
1745 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1746 .filter_map(|(tool_name, tool)| {
1747 if profile.is_tool_enabled(tool_name) {
1748 Some(tool)
1749 } else {
1750 None
1751 }
1752 })
1753 .chain(self.context_server_registry.read(cx).servers().flat_map(
1754 |(server_id, tools)| {
1755 tools.iter().filter_map(|(tool_name, tool)| {
1756 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1757 Some(tool)
1758 } else {
1759 None
1760 }
1761 })
1762 },
1763 )))
1764 }
1765
1766 fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1767 log::trace!(
1768 "Building request messages from {} thread messages",
1769 self.messages.len()
1770 );
1771 let mut messages = vec![self.build_system_message(cx)];
1772 for message in &self.messages {
1773 messages.extend(message.to_request());
1774 }
1775
1776 if let Some(message) = self.pending_message.as_ref() {
1777 messages.extend(message.to_request());
1778 }
1779
1780 if let Some(last_user_message) = messages
1781 .iter_mut()
1782 .rev()
1783 .find(|message| message.role == Role::User)
1784 {
1785 last_user_message.cache = true;
1786 }
1787
1788 messages
1789 }
1790
1791 pub fn to_markdown(&self) -> String {
1792 let mut markdown = String::new();
1793 for (ix, message) in self.messages.iter().enumerate() {
1794 if ix > 0 {
1795 markdown.push('\n');
1796 }
1797 markdown.push_str(&message.to_markdown());
1798 }
1799
1800 if let Some(message) = self.pending_message.as_ref() {
1801 markdown.push('\n');
1802 markdown.push_str(&message.to_markdown());
1803 }
1804
1805 markdown
1806 }
1807
1808 fn advance_prompt_id(&mut self) {
1809 self.prompt_id = PromptId::new();
1810 }
1811
    /// Chooses how a failed completion request should be retried, if at all.
    ///
    /// Returns `None` for errors that retrying cannot fix. Otherwise returns the
    /// delay schedule and the maximum number of retries. The caller
    /// (`stream_completion_with_retries`) gives up once more than `max_attempts`
    /// retries would be needed.
    fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;
        use http_client::StatusCode;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to
        //   MAX_RETRY_ATTEMPTS times: with exponential backoff for a plain HTTP 429, otherwise
        //   after a fixed delay that honors the server's `retry_after` when one is provided.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
        // - Other errors get a small fixed number of retries (see the individual arms below).
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // Internal Server Error could be anything, retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we frequently get them in practice. See https://http.dev/529
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        // Any other upstream status: retry up to 2 times.
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            // Internal server errors from the API: retry up to 3 times.
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Transport, response-decoding and bad-request-format errors: retry up to 3 times.
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them once.
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Retry all other 4xx and 5xx errors, up to 3 times.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<language_model::PaymentRequiredError>()
                    || err.is::<language_model::ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // for a significant amount of time for the request limit to reset).
                None
            }
            // Any remaining errors are retried up to 2 times.
            // NOTE(review): an earlier comment here claimed these errors are treated as
            // non-retryable; confirm whether `None` was intended instead.
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }
1918}
1919
/// Bookkeeping for the agent turn that is currently running.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1929
1930impl RunningTurn {
1931 fn cancel(self) {
1932 log::debug!("Cancelling in progress turn");
1933 self.event_stream.send_canceled();
1934 }
1935}
1936
1937pub trait AgentTool
1938where
1939 Self: 'static + Sized,
1940{
1941 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1942 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1943
1944 fn name(&self) -> SharedString;
1945
1946 fn description(&self) -> SharedString {
1947 let schema = schemars::schema_for!(Self::Input);
1948 SharedString::new(
1949 schema
1950 .get("description")
1951 .and_then(|description| description.as_str())
1952 .unwrap_or_default(),
1953 )
1954 }
1955
1956 fn kind(&self) -> acp::ToolKind;
1957
1958 /// The initial tool title to display. Can be updated during the tool run.
1959 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1960
1961 /// Returns the JSON schema that describes the tool's input.
1962 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
1963 crate::tool_schema::root_schema_for::<Self::Input>(format)
1964 }
1965
1966 /// Some tools rely on a provider for the underlying billing or other reasons.
1967 /// Allow the tool to check if they are compatible, or should be filtered out.
1968 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1969 true
1970 }
1971
1972 /// Runs the tool with the provided input.
1973 fn run(
1974 self: Arc<Self>,
1975 input: Self::Input,
1976 event_stream: ToolCallEventStream,
1977 cx: &mut App,
1978 ) -> Task<Result<Self::Output>>;
1979
1980 /// Emits events for a previous execution of the tool.
1981 fn replay(
1982 &self,
1983 _input: Self::Input,
1984 _output: Self::Output,
1985 _event_stream: ToolCallEventStream,
1986 _cx: &mut App,
1987 ) -> Result<()> {
1988 Ok(())
1989 }
1990
1991 fn erase(self) -> Arc<dyn AnyAgentTool> {
1992 Arc::new(Erased(Arc::new(self)))
1993 }
1994}
1995
/// Newtype adapter used to expose a concrete [`AgentTool`] as an [`AnyAgentTool`];
/// `AnyAgentTool` is implemented for `Erased<Arc<T>>` where `T: AgentTool`.
pub struct Erased<T>(T);
1997
/// Output of a type-erased tool run.
pub struct AgentToolOutput {
    /// Content returned to the language model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON; kept as the tool result's raw output.
    pub raw_output: serde_json::Value,
}
2002
/// Object-safe, type-erased form of [`AgentTool`]: tool input and output are
/// exchanged as `serde_json::Value`. Concrete tools are converted with
/// [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// The tool's name.
    fn name(&self) -> SharedString;
    /// Description of the tool shown to the model.
    fn description(&self) -> SharedString;
    /// The ACP kind reported for this tool's calls.
    fn kind(&self) -> acp::ToolKind;
    /// Initial title for a call with the given (possibly incomplete) JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input, adapted to the requested schema format.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered to models from `_provider`; defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on JSON input, producing both model-facing and raw JSON output.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Re-emits events for a previous run from its recorded JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
2026
2027impl<T> AnyAgentTool for Erased<Arc<T>>
2028where
2029 T: AgentTool,
2030{
2031 fn name(&self) -> SharedString {
2032 self.0.name()
2033 }
2034
2035 fn description(&self) -> SharedString {
2036 self.0.description()
2037 }
2038
2039 fn kind(&self) -> agent_client_protocol::ToolKind {
2040 self.0.kind()
2041 }
2042
2043 fn initial_title(&self, input: serde_json::Value) -> SharedString {
2044 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
2045 self.0.initial_title(parsed_input)
2046 }
2047
2048 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
2049 let mut json = serde_json::to_value(self.0.input_schema(format))?;
2050 adapt_schema_to_format(&mut json, format)?;
2051 Ok(json)
2052 }
2053
2054 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
2055 self.0.supported_provider(provider)
2056 }
2057
2058 fn run(
2059 self: Arc<Self>,
2060 input: serde_json::Value,
2061 event_stream: ToolCallEventStream,
2062 cx: &mut App,
2063 ) -> Task<Result<AgentToolOutput>> {
2064 cx.spawn(async move |cx| {
2065 let input = serde_json::from_value(input)?;
2066 let output = cx
2067 .update(|cx| self.0.clone().run(input, event_stream, cx))?
2068 .await?;
2069 let raw_output = serde_json::to_value(&output)?;
2070 Ok(AgentToolOutput {
2071 llm_output: output.into(),
2072 raw_output,
2073 })
2074 })
2075 }
2076
2077 fn replay(
2078 &self,
2079 input: serde_json::Value,
2080 output: serde_json::Value,
2081 event_stream: ToolCallEventStream,
2082 cx: &mut App,
2083 ) -> Result<()> {
2084 let input = serde_json::from_value(input)?;
2085 let output = serde_json::from_value(output)?;
2086 self.0.replay(input, output, event_stream, cx)
2087 }
2088}
2089
/// Sending half of the channel that carries a turn's [`ThreadEvent`]s (or errors)
/// to the consumer of `run_turn`. Send failures, such as a dropped receiver, are
/// ignored.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
2092
2093impl ThreadEventStream {
2094 fn send_title_update(&self, text: SharedString) {
2095 self.0
2096 .unbounded_send(Ok(ThreadEvent::TitleUpdate(text)))
2097 .ok();
2098 }
2099
2100 fn send_user_message(&self, message: &UserMessage) {
2101 self.0
2102 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
2103 .ok();
2104 }
2105
2106 fn send_text(&self, text: &str) {
2107 self.0
2108 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
2109 .ok();
2110 }
2111
2112 fn send_thinking(&self, text: &str) {
2113 self.0
2114 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2115 .ok();
2116 }
2117
2118 fn send_tool_call(
2119 &self,
2120 id: &LanguageModelToolUseId,
2121 title: SharedString,
2122 kind: acp::ToolKind,
2123 input: serde_json::Value,
2124 ) {
2125 self.0
2126 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2127 id,
2128 title.to_string(),
2129 kind,
2130 input,
2131 ))))
2132 .ok();
2133 }
2134
2135 fn initial_tool_call(
2136 id: &LanguageModelToolUseId,
2137 title: String,
2138 kind: acp::ToolKind,
2139 input: serde_json::Value,
2140 ) -> acp::ToolCall {
2141 acp::ToolCall {
2142 id: acp::ToolCallId(id.to_string().into()),
2143 title,
2144 kind,
2145 status: acp::ToolCallStatus::Pending,
2146 content: vec![],
2147 locations: vec![],
2148 raw_input: Some(input),
2149 raw_output: None,
2150 }
2151 }
2152
2153 fn update_tool_call_fields(
2154 &self,
2155 tool_use_id: &LanguageModelToolUseId,
2156 fields: acp::ToolCallUpdateFields,
2157 ) {
2158 self.0
2159 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2160 acp::ToolCallUpdate {
2161 id: acp::ToolCallId(tool_use_id.to_string().into()),
2162 fields,
2163 }
2164 .into(),
2165 )))
2166 .ok();
2167 }
2168
2169 fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
2170 self.0
2171 .unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
2172 .ok();
2173 }
2174
2175 fn send_retry(&self, status: acp_thread::RetryStatus) {
2176 self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2177 }
2178
2179 fn send_stop(&self, reason: StopReason) {
2180 match reason {
2181 StopReason::EndTurn => {
2182 self.0
2183 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
2184 .ok();
2185 }
2186 StopReason::MaxTokens => {
2187 self.0
2188 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
2189 .ok();
2190 }
2191 StopReason::Refusal => {
2192 self.0
2193 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
2194 .ok();
2195 }
2196 StopReason::ToolUse => {}
2197 }
2198 }
2199
2200 fn send_canceled(&self) {
2201 self.0
2202 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
2203 .ok();
2204 }
2205
2206 fn send_error(&self, error: impl Into<anyhow::Error>) {
2207 self.0.unbounded_send(Err(error.into())).ok();
2208 }
2209}
2210
/// Handle through which a single tool call reports progress and requests
/// permission.
///
/// Pairs the underlying [`ThreadEventStream`] with the id of the tool use that
/// every event sent through this handle refers to.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Language-model tool use that all events from this handle refer to.
    tool_use_id: LanguageModelToolUseId,
    /// Underlying event channel.
    stream: ThreadEventStream,
    /// Filesystem used by `authorize` to persist an "Always Allow" answer to
    /// the settings file; when `None`, that answer is not persisted.
    fs: Option<Arc<dyn Fs>>,
}
2217
2218impl ToolCallEventStream {
2219 #[cfg(test)]
2220 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2221 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2222
2223 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2224
2225 (stream, ToolCallEventStreamReceiver(events_rx))
2226 }
2227
2228 fn new(
2229 tool_use_id: LanguageModelToolUseId,
2230 stream: ThreadEventStream,
2231 fs: Option<Arc<dyn Fs>>,
2232 ) -> Self {
2233 Self {
2234 tool_use_id,
2235 stream,
2236 fs,
2237 }
2238 }
2239
2240 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2241 self.stream
2242 .update_tool_call_fields(&self.tool_use_id, fields);
2243 }
2244
2245 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2246 self.stream
2247 .0
2248 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2249 acp_thread::ToolCallUpdateDiff {
2250 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2251 diff,
2252 }
2253 .into(),
2254 )))
2255 .ok();
2256 }
2257
2258 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
2259 self.stream
2260 .0
2261 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2262 acp_thread::ToolCallUpdateTerminal {
2263 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2264 terminal,
2265 }
2266 .into(),
2267 )))
2268 .ok();
2269 }
2270
2271 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2272 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2273 return Task::ready(Ok(()));
2274 }
2275
2276 let (response_tx, response_rx) = oneshot::channel();
2277 self.stream
2278 .0
2279 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2280 ToolCallAuthorization {
2281 tool_call: acp::ToolCallUpdate {
2282 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2283 fields: acp::ToolCallUpdateFields {
2284 title: Some(title.into()),
2285 ..Default::default()
2286 },
2287 },
2288 options: vec![
2289 acp::PermissionOption {
2290 id: acp::PermissionOptionId("always_allow".into()),
2291 name: "Always Allow".into(),
2292 kind: acp::PermissionOptionKind::AllowAlways,
2293 },
2294 acp::PermissionOption {
2295 id: acp::PermissionOptionId("allow".into()),
2296 name: "Allow".into(),
2297 kind: acp::PermissionOptionKind::AllowOnce,
2298 },
2299 acp::PermissionOption {
2300 id: acp::PermissionOptionId("deny".into()),
2301 name: "Deny".into(),
2302 kind: acp::PermissionOptionKind::RejectOnce,
2303 },
2304 ],
2305 response: response_tx,
2306 },
2307 )))
2308 .ok();
2309 let fs = self.fs.clone();
2310 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2311 "always_allow" => {
2312 if let Some(fs) = fs.clone() {
2313 cx.update(|cx| {
2314 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
2315 settings.set_always_allow_tool_actions(true);
2316 });
2317 })?;
2318 }
2319
2320 Ok(())
2321 }
2322 "allow" => Ok(()),
2323 _ => Err(anyhow!("Permission to run tool denied by user")),
2324 })
2325 }
2326}
2327
/// Test-only receiving end returned by [`ToolCallEventStream::test`], used to
/// inspect the events a tool emits.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2330
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next item and returns it as a tool-call authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including an error or the end
    /// of the stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next item and returns the terminal from a terminal update.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything other than a terminal tool-call
    /// update.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
2354
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    /// Exposes the wrapped receiver so tests can use its stream methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2363
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutable access to the wrapped receiver (needed to poll it).
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2370
2371impl From<&str> for UserMessageContent {
2372 fn from(text: &str) -> Self {
2373 Self::Text(text.into())
2374 }
2375}
2376
2377impl From<acp::ContentBlock> for UserMessageContent {
2378 fn from(value: acp::ContentBlock) -> Self {
2379 match value {
2380 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2381 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2382 acp::ContentBlock::Audio(_) => {
2383 // TODO
2384 Self::Text("[audio]".to_string())
2385 }
2386 acp::ContentBlock::ResourceLink(resource_link) => {
2387 match MentionUri::parse(&resource_link.uri) {
2388 Ok(uri) => Self::Mention {
2389 uri,
2390 content: String::new(),
2391 },
2392 Err(err) => {
2393 log::error!("Failed to parse mention link: {}", err);
2394 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2395 }
2396 }
2397 }
2398 acp::ContentBlock::Resource(resource) => match resource.resource {
2399 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2400 match MentionUri::parse(&resource.uri) {
2401 Ok(uri) => Self::Mention {
2402 uri,
2403 content: resource.text,
2404 },
2405 Err(err) => {
2406 log::error!("Failed to parse mention link: {}", err);
2407 Self::Text(
2408 MarkdownCodeBlock {
2409 tag: &resource.uri,
2410 text: &resource.text,
2411 }
2412 .to_string(),
2413 )
2414 }
2415 }
2416 }
2417 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2418 // TODO
2419 Self::Text("[blob]".to_string())
2420 }
2421 },
2422 }
2423 }
2424}
2425
2426impl From<UserMessageContent> for acp::ContentBlock {
2427 fn from(content: UserMessageContent) -> Self {
2428 match content {
2429 UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2430 text,
2431 annotations: None,
2432 }),
2433 UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2434 data: image.source.to_string(),
2435 mime_type: "image/png".to_string(),
2436 annotations: None,
2437 uri: None,
2438 }),
2439 UserMessageContent::Mention { uri, content } => {
2440 acp::ContentBlock::ResourceLink(acp::ResourceLink {
2441 uri: uri.to_uri().to_string(),
2442 name: uri.name(),
2443 annotations: None,
2444 description: if content.is_empty() {
2445 None
2446 } else {
2447 Some(content)
2448 },
2449 mime_type: None,
2450 size: None,
2451 title: None,
2452 })
2453 }
2454 }
2455 }
2456}
2457
2458fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2459 LanguageModelImage {
2460 source: image_content.data.into(),
2461 // TODO: make this optional?
2462 size: gpui::Size::new(0.into(), 0.into()),
2463 }
2464}