1use crate::{
2 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
3 DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
4 ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
5 Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
6};
7use acp_thread::{MentionUri, UserMessageId};
8use action_log::ActionLog;
9use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
10use agent_client_protocol as acp;
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
12use anyhow::{Context as _, Result, anyhow};
13use assistant_tool::adapt_schema_to_format;
14use chrono::{DateTime, Utc};
15use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
16use collections::IndexMap;
17use fs::Fs;
18use futures::{
19 FutureExt,
20 channel::{mpsc, oneshot},
21 future::Shared,
22 stream::FuturesUnordered,
23};
24use git::repository::DiffType;
25use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
26use language_model::{
27 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
28 LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
29 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
30 LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
31 LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
32};
33use project::{
34 Project,
35 git_store::{GitStore, RepositoryState},
36};
37use prompt_store::ProjectContext;
38use schemars::{JsonSchema, Schema};
39use serde::{Deserialize, Serialize};
40use settings::{Settings, update_settings_file};
41use smol::stream::StreamExt;
42use std::{
43 collections::BTreeMap,
44 path::Path,
45 sync::Arc,
46 time::{Duration, Instant},
47};
48use std::{fmt::Write, ops::Range};
49use util::{ResultExt, markdown::MarkdownCodeBlock};
50use uuid::Uuid;
51
/// Text describing a tool call that the user canceled.
///
/// NOTE(review): the consumer is outside this view — presumably used as the
/// tool result reported for canceled tool calls; confirm at the use site.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
53
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// Fresh ids are random v4 UUID strings (see [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
59
60impl PromptId {
61 pub fn new() -> Self {
62 Self(Uuid::new_v4().to_string().into())
63 }
64}
65
66impl std::fmt::Display for PromptId {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(f, "{}", self.0)
69 }
70}
71
/// Upper bound on retry attempts (consumer not visible in this chunk).
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay between retries: five seconds.
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_millis(5_000);
74
/// A retry policy.
///
/// NOTE(review): not used in the visible part of this file — presumably consumed
/// by `stream_completion_with_retries` for failed completion requests; confirm the
/// exact delay/attempt semantics there.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Retry starting from `initial_delay`, presumably growing the delay between
    /// attempts, for at most `max_attempts` attempts.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Retry after the same `delay` each time, for at most `max_attempts` attempts.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
86
/// One entry in a thread's message history.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// A message produced by the agent, together with its tool results.
    Agent(AgentMessage),
    /// Marker appended by `Thread::resume` after the tool use limit was reached;
    /// sent to the model as a user message asking it to continue.
    Resume,
}
93
94impl Message {
95 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
96 match self {
97 Message::Agent(agent_message) => Some(agent_message),
98 _ => None,
99 }
100 }
101
102 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
103 match self {
104 Message::User(message) => vec![message.to_request()],
105 Message::Agent(message) => message.to_request(),
106 Message::Resume => vec![LanguageModelRequestMessage {
107 role: Role::User,
108 content: vec!["Continue where you left off".into()],
109 cache: false,
110 }],
111 }
112 }
113
114 pub fn to_markdown(&self) -> String {
115 match self {
116 Message::User(message) => message.to_markdown(),
117 Message::Agent(message) => message.to_markdown(),
118 Message::Resume => "[resumed after tool use limit was reached]".into(),
119 }
120 }
121
122 pub fn role(&self) -> Role {
123 match self {
124 Message::User(_) | Message::Resume => Role::User,
125 Message::Agent(_) => Role::Assistant,
126 }
127 }
128}
129
/// A message submitted by the user.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifies the message, e.g. for `Thread::truncate`.
    pub id: UserMessageId,
    /// The message's parts, in the order the user composed them.
    pub content: Vec<UserMessageContent>,
}
135
/// One part of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text typed by the user.
    Text(String),
    /// A reference to some item (file, symbol, thread, URL, ...). `content` is the
    /// item's text, which is attached to the request's context section; it may be empty.
    Mention { uri: MentionUri, content: String },
    /// An attached image.
    Image(LanguageModelImage),
}
142
143impl UserMessage {
144 pub fn to_markdown(&self) -> String {
145 let mut markdown = String::from("## User\n\n");
146
147 for content in &self.content {
148 match content {
149 UserMessageContent::Text(text) => {
150 markdown.push_str(text);
151 markdown.push('\n');
152 }
153 UserMessageContent::Image(_) => {
154 markdown.push_str("<image />\n");
155 }
156 UserMessageContent::Mention { uri, content } => {
157 if !content.is_empty() {
158 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
159 } else {
160 let _ = write!(&mut markdown, "{}\n", uri.as_link());
161 }
162 }
163 }
164 }
165
166 markdown
167 }
168
169 fn to_request(&self) -> LanguageModelRequestMessage {
170 let mut message = LanguageModelRequestMessage {
171 role: Role::User,
172 content: Vec::with_capacity(self.content.len()),
173 cache: false,
174 };
175
176 const OPEN_CONTEXT: &str = "<context>\n\
177 The following items were attached by the user. \
178 They are up-to-date and don't need to be re-read.\n\n";
179
180 const OPEN_FILES_TAG: &str = "<files>";
181 const OPEN_DIRECTORIES_TAG: &str = "<directories>";
182 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
183 const OPEN_THREADS_TAG: &str = "<threads>";
184 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
185 const OPEN_RULES_TAG: &str =
186 "<rules>\nThe user has specified the following rules that should be applied:\n";
187
188 let mut file_context = OPEN_FILES_TAG.to_string();
189 let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
190 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
191 let mut thread_context = OPEN_THREADS_TAG.to_string();
192 let mut fetch_context = OPEN_FETCH_TAG.to_string();
193 let mut rules_context = OPEN_RULES_TAG.to_string();
194
195 for chunk in &self.content {
196 let chunk = match chunk {
197 UserMessageContent::Text(text) => {
198 language_model::MessageContent::Text(text.clone())
199 }
200 UserMessageContent::Image(value) => {
201 language_model::MessageContent::Image(value.clone())
202 }
203 UserMessageContent::Mention { uri, content } => {
204 match uri {
205 MentionUri::File { abs_path } => {
206 write!(
207 &mut symbol_context,
208 "\n{}",
209 MarkdownCodeBlock {
210 tag: &codeblock_tag(abs_path, None),
211 text: &content.to_string(),
212 }
213 )
214 .ok();
215 }
216 MentionUri::Directory { .. } => {
217 write!(&mut directory_context, "\n{}\n", content).ok();
218 }
219 MentionUri::Symbol {
220 path, line_range, ..
221 }
222 | MentionUri::Selection {
223 path, line_range, ..
224 } => {
225 write!(
226 &mut rules_context,
227 "\n{}",
228 MarkdownCodeBlock {
229 tag: &codeblock_tag(path, Some(line_range)),
230 text: content
231 }
232 )
233 .ok();
234 }
235 MentionUri::Thread { .. } => {
236 write!(&mut thread_context, "\n{}\n", content).ok();
237 }
238 MentionUri::TextThread { .. } => {
239 write!(&mut thread_context, "\n{}\n", content).ok();
240 }
241 MentionUri::Rule { .. } => {
242 write!(
243 &mut rules_context,
244 "\n{}",
245 MarkdownCodeBlock {
246 tag: "",
247 text: content
248 }
249 )
250 .ok();
251 }
252 MentionUri::Fetch { url } => {
253 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
254 }
255 }
256
257 language_model::MessageContent::Text(uri.as_link().to_string())
258 }
259 };
260
261 message.content.push(chunk);
262 }
263
264 let len_before_context = message.content.len();
265
266 if file_context.len() > OPEN_FILES_TAG.len() {
267 file_context.push_str("</files>\n");
268 message
269 .content
270 .push(language_model::MessageContent::Text(file_context));
271 }
272
273 if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
274 directory_context.push_str("</directories>\n");
275 message
276 .content
277 .push(language_model::MessageContent::Text(directory_context));
278 }
279
280 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
281 symbol_context.push_str("</symbols>\n");
282 message
283 .content
284 .push(language_model::MessageContent::Text(symbol_context));
285 }
286
287 if thread_context.len() > OPEN_THREADS_TAG.len() {
288 thread_context.push_str("</threads>\n");
289 message
290 .content
291 .push(language_model::MessageContent::Text(thread_context));
292 }
293
294 if fetch_context.len() > OPEN_FETCH_TAG.len() {
295 fetch_context.push_str("</fetched_urls>\n");
296 message
297 .content
298 .push(language_model::MessageContent::Text(fetch_context));
299 }
300
301 if rules_context.len() > OPEN_RULES_TAG.len() {
302 rules_context.push_str("</user_rules>\n");
303 message
304 .content
305 .push(language_model::MessageContent::Text(rules_context));
306 }
307
308 if message.content.len() > len_before_context {
309 message.content.insert(
310 len_before_context,
311 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
312 );
313 message
314 .content
315 .push(language_model::MessageContent::Text("</context>".into()));
316 }
317
318 message
319 }
320}
321
/// Builds the info string of a Markdown code fence for `full_path`.
///
/// The result is `"<extension> <path>"`, or just `"<path>"` when the path has no
/// UTF-8 extension. When `line_range` is given, a 1-based line suffix is appended:
/// `":N"` if the range starts and ends on the same line, otherwise `":N-M"`. Both
/// bounds are shifted by one, so the range is rendered as 0-based and inclusive.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let mut tag = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{extension} {}", full_path.display()),
        None => full_path.display().to_string(),
    };

    if let Some(Range { start, end }) = line_range {
        let first_line = start + 1;
        let _ = if start == end {
            write!(tag, ":{first_line}")
        } else {
            write!(tag, ":{first_line}-{}", end + 1)
        };
    }

    tag
}
341
342impl AgentMessage {
343 pub fn to_markdown(&self) -> String {
344 let mut markdown = String::from("## Assistant\n\n");
345
346 for content in &self.content {
347 match content {
348 AgentMessageContent::Text(text) => {
349 markdown.push_str(text);
350 markdown.push('\n');
351 }
352 AgentMessageContent::Thinking { text, .. } => {
353 markdown.push_str("<think>");
354 markdown.push_str(text);
355 markdown.push_str("</think>\n");
356 }
357 AgentMessageContent::RedactedThinking(_) => {
358 markdown.push_str("<redacted_thinking />\n")
359 }
360 AgentMessageContent::ToolUse(tool_use) => {
361 markdown.push_str(&format!(
362 "**Tool Use**: {} (ID: {})\n",
363 tool_use.name, tool_use.id
364 ));
365 markdown.push_str(&format!(
366 "{}\n",
367 MarkdownCodeBlock {
368 tag: "json",
369 text: &format!("{:#}", tool_use.input)
370 }
371 ));
372 }
373 }
374 }
375
376 for tool_result in self.tool_results.values() {
377 markdown.push_str(&format!(
378 "**Tool Result**: {} (ID: {})\n\n",
379 tool_result.tool_name, tool_result.tool_use_id
380 ));
381 if tool_result.is_error {
382 markdown.push_str("**ERROR:**\n");
383 }
384
385 match &tool_result.content {
386 LanguageModelToolResultContent::Text(text) => {
387 writeln!(markdown, "{text}\n").ok();
388 }
389 LanguageModelToolResultContent::Image(_) => {
390 writeln!(markdown, "<image />\n").ok();
391 }
392 }
393
394 if let Some(output) = tool_result.output.as_ref() {
395 writeln!(
396 markdown,
397 "**Debug Output**:\n\n```json\n{}\n```\n",
398 serde_json::to_string_pretty(output).unwrap()
399 )
400 .unwrap();
401 }
402 }
403
404 markdown
405 }
406
407 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
408 let mut assistant_message = LanguageModelRequestMessage {
409 role: Role::Assistant,
410 content: Vec::with_capacity(self.content.len()),
411 cache: false,
412 };
413 for chunk in &self.content {
414 let chunk = match chunk {
415 AgentMessageContent::Text(text) => {
416 language_model::MessageContent::Text(text.clone())
417 }
418 AgentMessageContent::Thinking { text, signature } => {
419 language_model::MessageContent::Thinking {
420 text: text.clone(),
421 signature: signature.clone(),
422 }
423 }
424 AgentMessageContent::RedactedThinking(value) => {
425 language_model::MessageContent::RedactedThinking(value.clone())
426 }
427 AgentMessageContent::ToolUse(value) => {
428 language_model::MessageContent::ToolUse(value.clone())
429 }
430 };
431 assistant_message.content.push(chunk);
432 }
433
434 let mut user_message = LanguageModelRequestMessage {
435 role: Role::User,
436 content: Vec::new(),
437 cache: false,
438 };
439
440 for tool_result in self.tool_results.values() {
441 user_message
442 .content
443 .push(language_model::MessageContent::ToolResult(
444 tool_result.clone(),
445 ));
446 }
447
448 let mut messages = Vec::new();
449 if !assistant_message.content.is_empty() {
450 messages.push(assistant_message);
451 }
452 if !user_message.content.is_empty() {
453 messages.push(user_message);
454 }
455 messages
456 }
457}
458
/// A message produced by the agent.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// The streamed parts of the response, in order.
    pub content: Vec<AgentMessageContent>,
    /// Results of the tool uses in `content`, keyed by tool use id, in insertion order.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
464
/// One part of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Response text.
    Text(String),
    /// Thinking text, with the signature passed back to the model when present.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking; forwarded to the model but not shown on replay.
    RedactedThinking(String),
    /// A tool call requested by the model.
    ToolUse(LanguageModelToolUse),
}
475
/// Events sent on a thread's event channel, both while a turn runs and when the
/// history is replayed.
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message.
    UserMessage(UserMessage),
    /// Agent response text.
    AgentText(String),
    /// Agent thinking text.
    AgentThinking(String),
    /// A new tool call.
    ToolCall(acp::ToolCall),
    /// An update to a previously announced tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for the user to authorize a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The thread's title changed.
    TitleUpdate(SharedString),
    /// Retry status (presumably emitted while a failed request is retried).
    Retry(acp_thread::RetryStatus),
    /// The turn stopped for the given reason.
    Stop(acp::StopReason),
}
488
/// A request for the user to choose one of `options` for `tool_call`.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting a decision.
    pub tool_call: acp::ToolCallUpdate,
    /// The permission options offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which the id of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
495
496pub struct Thread {
497 id: acp::SessionId,
498 prompt_id: PromptId,
499 updated_at: DateTime<Utc>,
500 title: Option<SharedString>,
501 #[allow(unused)]
502 summary: DetailedSummaryState,
503 messages: Vec<Message>,
504 completion_mode: CompletionMode,
505 /// Holds the task that handles agent interaction until the end of the turn.
506 /// Survives across multiple requests as the model performs tool calls and
507 /// we run tools, report their results.
508 running_turn: Option<RunningTurn>,
509 pending_message: Option<AgentMessage>,
510 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
511 tool_use_limit_reached: bool,
512 #[allow(unused)]
513 request_token_usage: Vec<TokenUsage>,
514 #[allow(unused)]
515 cumulative_token_usage: TokenUsage,
516 #[allow(unused)]
517 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
518 context_server_registry: Entity<ContextServerRegistry>,
519 profile_id: AgentProfileId,
520 project_context: Entity<ProjectContext>,
521 templates: Arc<Templates>,
522 model: Option<Arc<dyn LanguageModel>>,
523 summarization_model: Option<Arc<dyn LanguageModel>>,
524 pub(crate) project: Entity<Project>,
525 pub(crate) action_log: Entity<ActionLog>,
526}
527
528impl Thread {
529 pub fn new(
530 project: Entity<Project>,
531 project_context: Entity<ProjectContext>,
532 context_server_registry: Entity<ContextServerRegistry>,
533 action_log: Entity<ActionLog>,
534 templates: Arc<Templates>,
535 model: Option<Arc<dyn LanguageModel>>,
536 cx: &mut Context<Self>,
537 ) -> Self {
538 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
539 Self {
540 id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
541 prompt_id: PromptId::new(),
542 updated_at: Utc::now(),
543 title: None,
544 summary: DetailedSummaryState::default(),
545 messages: Vec::new(),
546 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
547 running_turn: None,
548 pending_message: None,
549 tools: BTreeMap::default(),
550 tool_use_limit_reached: false,
551 request_token_usage: Vec::new(),
552 cumulative_token_usage: TokenUsage::default(),
553 initial_project_snapshot: {
554 let project_snapshot = Self::project_snapshot(project.clone(), cx);
555 cx.foreground_executor()
556 .spawn(async move { Some(project_snapshot.await) })
557 .shared()
558 },
559 context_server_registry,
560 profile_id,
561 project_context,
562 templates,
563 model,
564 summarization_model: None,
565 project,
566 action_log,
567 }
568 }
569
/// Returns the session id of this thread.
pub fn id(&self) -> &acp::SessionId {
    &self.id
}
573
574 pub fn replay(
575 &mut self,
576 cx: &mut Context<Self>,
577 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
578 let (tx, rx) = mpsc::unbounded();
579 let stream = ThreadEventStream(tx);
580 for message in &self.messages {
581 match message {
582 Message::User(user_message) => stream.send_user_message(user_message),
583 Message::Agent(assistant_message) => {
584 for content in &assistant_message.content {
585 match content {
586 AgentMessageContent::Text(text) => stream.send_text(text),
587 AgentMessageContent::Thinking { text, .. } => {
588 stream.send_thinking(text)
589 }
590 AgentMessageContent::RedactedThinking(_) => {}
591 AgentMessageContent::ToolUse(tool_use) => {
592 self.replay_tool_call(
593 tool_use,
594 assistant_message.tool_results.get(&tool_use.id),
595 &stream,
596 cx,
597 );
598 }
599 }
600 }
601 }
602 Message::Resume => {}
603 }
604 }
605 rx
606 }
607
608 fn replay_tool_call(
609 &self,
610 tool_use: &LanguageModelToolUse,
611 tool_result: Option<&LanguageModelToolResult>,
612 stream: &ThreadEventStream,
613 cx: &mut Context<Self>,
614 ) {
615 let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
616 stream
617 .0
618 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
619 id: acp::ToolCallId(tool_use.id.to_string().into()),
620 title: tool_use.name.to_string(),
621 kind: acp::ToolKind::Other,
622 status: acp::ToolCallStatus::Failed,
623 content: Vec::new(),
624 locations: Vec::new(),
625 raw_input: Some(tool_use.input.clone()),
626 raw_output: None,
627 })))
628 .ok();
629 return;
630 };
631
632 let title = tool.initial_title(tool_use.input.clone());
633 let kind = tool.kind();
634 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
635
636 let output = tool_result
637 .as_ref()
638 .and_then(|result| result.output.clone());
639 if let Some(output) = output.clone() {
640 let tool_event_stream = ToolCallEventStream::new(
641 tool_use.id.clone(),
642 stream.clone(),
643 Some(self.project.read(cx).fs().clone()),
644 );
645 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
646 .log_err();
647 }
648
649 stream.update_tool_call_fields(
650 &tool_use.id,
651 acp::ToolCallUpdateFields {
652 status: Some(acp::ToolCallStatus::Completed),
653 raw_output: output,
654 ..Default::default()
655 },
656 );
657 }
658
659 pub fn from_db(
660 id: acp::SessionId,
661 db_thread: DbThread,
662 project: Entity<Project>,
663 project_context: Entity<ProjectContext>,
664 context_server_registry: Entity<ContextServerRegistry>,
665 action_log: Entity<ActionLog>,
666 templates: Arc<Templates>,
667 cx: &mut Context<Self>,
668 ) -> Self {
669 let profile_id = db_thread
670 .profile
671 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
672 let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
673 db_thread
674 .model
675 .and_then(|model| {
676 let model = SelectedModel {
677 provider: model.provider.clone().into(),
678 model: model.model.clone().into(),
679 };
680 registry.select_model(&model, cx)
681 })
682 .or_else(|| registry.default_model())
683 .map(|model| model.model)
684 });
685
686 Self {
687 id,
688 prompt_id: PromptId::new(),
689 title: if db_thread.title.is_empty() {
690 None
691 } else {
692 Some(db_thread.title.clone())
693 },
694 summary: db_thread.summary,
695 messages: db_thread.messages,
696 completion_mode: db_thread.completion_mode.unwrap_or_default(),
697 running_turn: None,
698 pending_message: None,
699 tools: BTreeMap::default(),
700 tool_use_limit_reached: false,
701 request_token_usage: db_thread.request_token_usage.clone(),
702 cumulative_token_usage: db_thread.cumulative_token_usage,
703 initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
704 context_server_registry,
705 profile_id,
706 project_context,
707 templates,
708 model,
709 summarization_model: None,
710 project,
711 action_log,
712 updated_at: db_thread.updated_at,
713 }
714 }
715
716 pub fn to_db(&self, cx: &App) -> Task<DbThread> {
717 let initial_project_snapshot = self.initial_project_snapshot.clone();
718 let mut thread = DbThread {
719 title: self.title.clone().unwrap_or_default(),
720 messages: self.messages.clone(),
721 updated_at: self.updated_at,
722 summary: self.summary.clone(),
723 initial_project_snapshot: None,
724 cumulative_token_usage: self.cumulative_token_usage,
725 request_token_usage: self.request_token_usage.clone(),
726 model: self.model.as_ref().map(|model| DbLanguageModel {
727 provider: model.provider_id().to_string(),
728 model: model.name().0.to_string(),
729 }),
730 completion_mode: Some(self.completion_mode),
731 profile: Some(self.profile_id.clone()),
732 };
733
734 cx.background_spawn(async move {
735 let initial_project_snapshot = initial_project_snapshot.await;
736 thread.initial_project_snapshot = initial_project_snapshot;
737 thread
738 })
739 }
740
741 /// Create a snapshot of the current project state including git information and unsaved buffers.
742 fn project_snapshot(
743 project: Entity<Project>,
744 cx: &mut Context<Self>,
745 ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
746 let git_store = project.read(cx).git_store().clone();
747 let worktree_snapshots: Vec<_> = project
748 .read(cx)
749 .visible_worktrees(cx)
750 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
751 .collect();
752
753 cx.spawn(async move |_, cx| {
754 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
755
756 let mut unsaved_buffers = Vec::new();
757 cx.update(|app_cx| {
758 let buffer_store = project.read(app_cx).buffer_store();
759 for buffer_handle in buffer_store.read(app_cx).buffers() {
760 let buffer = buffer_handle.read(app_cx);
761 if buffer.is_dirty()
762 && let Some(file) = buffer.file()
763 {
764 let path = file.path().to_string_lossy().to_string();
765 unsaved_buffers.push(path);
766 }
767 }
768 })
769 .ok();
770
771 Arc::new(ProjectSnapshot {
772 worktree_snapshots,
773 unsaved_buffer_paths: unsaved_buffers,
774 timestamp: Utc::now(),
775 })
776 })
777 }
778
/// Builds a [`WorktreeSnapshot`] for `worktree`.
///
/// The snapshot holds the worktree's absolute path and, when a repository in
/// `git_store` can map the worktree root to a repo path, that repository's git
/// state. If the worktree cannot be read through `cx`, an empty snapshot (empty
/// path, no git state) is returned. Failures while locating the repository or
/// running the git job produce `git_state: None` instead of an error.
fn worktree_snapshot(
    worktree: Entity<project::Worktree>,
    git_store: Entity<GitStore>,
    cx: &App,
) -> Task<agent::thread::WorktreeSnapshot> {
    cx.spawn(async move |cx| {
        // Get worktree path and snapshot
        let worktree_info = cx.update(|app_cx| {
            let worktree = worktree.read(app_cx);
            let path = worktree.abs_path().to_string_lossy().to_string();
            let snapshot = worktree.snapshot();
            (path, snapshot)
        });

        // If the update failed, fall back to an empty snapshot.
        let Ok((worktree_path, _snapshot)) = worktree_info else {
            return WorktreeSnapshot {
                worktree_path: String::new(),
                git_state: None,
            };
        };

        // Take the first repository (in `repositories()` order) that contains the
        // worktree root, then schedule a job on it that collects its git state.
        let git_state = git_store
            .update(cx, |git_store, cx| {
                git_store
                    .repositories()
                    .values()
                    .find(|repo| {
                        repo.read(cx)
                            .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                            .is_some()
                    })
                    .cloned()
            })
            .ok()
            .flatten()
            .map(|repo| {
                repo.update(cx, |repo, _| {
                    let current_branch =
                        repo.branch.as_ref().map(|branch| branch.name().to_owned());
                    repo.send_job(None, |state, _| async move {
                        // Only the `Local` state provides a `backend` here; for any
                        // other state, report just the branch name.
                        let RepositoryState::Local { backend, .. } = state else {
                            return GitState {
                                remote_url: None,
                                head_sha: None,
                                current_branch,
                                diff: None,
                            };
                        };

                        let remote_url = backend.remote_url("origin");
                        let head_sha = backend.head_sha().await;
                        // A failed diff is tolerated; the snapshot just omits it.
                        let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                        GitState {
                            remote_url,
                            head_sha,
                            current_branch,
                            diff,
                        }
                    })
                })
            });

        // Unwrap the nested fallible steps: the repository update, then the job's
        // result. Any failure along the way yields `None`.
        let git_state = match git_state {
            Some(git_state) => match git_state.ok() {
                Some(git_state) => git_state.await.ok(),
                None => None,
            },
            None => None,
        };

        WorktreeSnapshot {
            worktree_path,
            git_state,
        }
    })
}
856
857 pub fn project_context(&self) -> &Entity<ProjectContext> {
858 &self.project_context
859 }
860
861 pub fn project(&self) -> &Entity<Project> {
862 &self.project
863 }
864
865 pub fn action_log(&self) -> &Entity<ActionLog> {
866 &self.action_log
867 }
868
869 pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
870 self.model.as_ref()
871 }
872
873 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
874 self.model = Some(model);
875 cx.notify()
876 }
877
878 pub fn set_summarization_model(
879 &mut self,
880 model: Option<Arc<dyn LanguageModel>>,
881 cx: &mut Context<Self>,
882 ) {
883 self.summarization_model = model;
884 cx.notify()
885 }
886
887 pub fn completion_mode(&self) -> CompletionMode {
888 self.completion_mode
889 }
890
891 pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
892 self.completion_mode = mode;
893 cx.notify()
894 }
895
/// Returns the most recent message, preferring the in-progress agent message
/// over the last committed one.
#[cfg(any(test, feature = "test-support"))]
pub fn last_message(&self) -> Option<Message> {
    match &self.pending_message {
        Some(pending) => Some(Message::Agent(pending.clone())),
        None => self.messages.last().cloned(),
    }
}
904
905 pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
906 let language_registry = self.project.read(cx).languages().clone();
907 self.add_tool(CopyPathTool::new(self.project.clone()));
908 self.add_tool(CreateDirectoryTool::new(self.project.clone()));
909 self.add_tool(DeletePathTool::new(
910 self.project.clone(),
911 self.action_log.clone(),
912 ));
913 self.add_tool(DiagnosticsTool::new(self.project.clone()));
914 self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
915 self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
916 self.add_tool(FindPathTool::new(self.project.clone()));
917 self.add_tool(GrepTool::new(self.project.clone()));
918 self.add_tool(ListDirectoryTool::new(self.project.clone()));
919 self.add_tool(MovePathTool::new(self.project.clone()));
920 self.add_tool(NowTool);
921 self.add_tool(OpenTool::new(self.project.clone()));
922 self.add_tool(ReadFileTool::new(
923 self.project.clone(),
924 self.action_log.clone(),
925 ));
926 self.add_tool(TerminalTool::new(self.project.clone(), cx));
927 self.add_tool(ThinkingTool);
928 self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
929 }
930
931 pub fn add_tool(&mut self, tool: impl AgentTool) {
932 self.tools.insert(tool.name(), tool.erase());
933 }
934
935 pub fn remove_tool(&mut self, name: &str) -> bool {
936 self.tools.remove(name).is_some()
937 }
938
939 pub fn profile(&self) -> &AgentProfileId {
940 &self.profile_id
941 }
942
943 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
944 self.profile_id = profile_id;
945 }
946
947 pub fn cancel(&mut self, cx: &mut Context<Self>) {
948 if let Some(running_turn) = self.running_turn.take() {
949 running_turn.cancel();
950 }
951 self.flush_pending_message(cx);
952 }
953
954 pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
955 self.cancel(cx);
956 let Some(position) = self.messages.iter().position(
957 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
958 ) else {
959 return Err(anyhow!("Message not found"));
960 };
961 self.messages.truncate(position);
962 cx.notify();
963 Ok(())
964 }
965
966 pub fn resume(
967 &mut self,
968 cx: &mut Context<Self>,
969 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
970 anyhow::ensure!(
971 self.tool_use_limit_reached,
972 "can only resume after tool use limit is reached"
973 );
974
975 self.messages.push(Message::Resume);
976 cx.notify();
977
978 log::info!("Total messages in thread: {}", self.messages.len());
979 self.run_turn(cx)
980 }
981
982 /// Sending a message results in the model streaming a response, which could include tool calls.
983 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
984 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
985 pub fn send<T>(
986 &mut self,
987 id: UserMessageId,
988 content: impl IntoIterator<Item = T>,
989 cx: &mut Context<Self>,
990 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
991 where
992 T: Into<UserMessageContent>,
993 {
994 let model = self.model().context("No language model configured")?;
995
996 log::info!("Thread::send called with model: {:?}", model.name());
997 self.advance_prompt_id();
998
999 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
1000 log::debug!("Thread::send content: {:?}", content);
1001
1002 self.messages
1003 .push(Message::User(UserMessage { id, content }));
1004 cx.notify();
1005
1006 log::info!("Total messages in thread: {}", self.messages.len());
1007 self.run_turn(cx)
1008 }
1009
/// Starts a new agent turn over the current `messages`, canceling any turn that
/// is already running.
///
/// The spawned task loops: it builds a completion request, streams the model's
/// response, and waits for the requested tool calls to finish. It keeps feeding
/// the results back to the model until one of these happens:
/// - the model ends its turn;
/// - the model refuses;
/// - the max-token limit is hit;
/// - the tool use limit is hit;
/// - an error occurs.
///
/// Progress is reported on the returned channel. When the turn finishes (rather
/// than being canceled), it sends either a stop event or the error.
///
/// # Errors
///
/// Returns an error if no language model is configured.
fn run_turn(
    &mut self,
    cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
    self.cancel(cx);

    let model = self.model.clone().context("No language model configured")?;
    let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
    let event_stream = ThreadEventStream(events_tx);
    // Index of the last message when the turn starts (the message that triggered
    // it); on refusal the history is truncated back to this index.
    let message_ix = self.messages.len().saturating_sub(1);
    self.tool_use_limit_reached = false;
    self.running_turn = Some(RunningTurn {
        event_stream: event_stream.clone(),
        _task: cx.spawn(async move |this, cx| {
            log::info!("Starting agent turn execution");
            let turn_result: Result<StopReason> = async {
                let mut completion_intent = CompletionIntent::UserPrompt;
                // One model request per iteration; iterate while the model keeps
                // requesting tools.
                loop {
                    log::debug!(
                        "Building completion request with intent: {:?}",
                        completion_intent
                    );
                    let request = this.update(cx, |this, cx| {
                        this.build_completion_request(completion_intent, cx)
                    })??;

                    log::info!("Calling model.stream_completion");

                    let mut tool_use_limit_reached = false;
                    let mut refused = false;
                    let mut reached_max_tokens = false;
                    let mut tool_uses = Self::stream_completion_with_retries(
                        this.clone(),
                        model.clone(),
                        request,
                        &event_stream,
                        &mut tool_use_limit_reached,
                        &mut refused,
                        &mut reached_max_tokens,
                        cx,
                    )
                    .await?;

                    if refused {
                        return Ok(StopReason::Refusal);
                    } else if reached_max_tokens {
                        return Ok(StopReason::MaxTokens);
                    }

                    // Decided before draining: no pending tool calls means the model
                    // is done with this turn.
                    let end_turn = tool_uses.is_empty();
                    while let Some(tool_result) = tool_uses.next().await {
                        log::info!("Tool finished {:?}", tool_result);

                        event_stream.update_tool_call_fields(
                            &tool_result.tool_use_id,
                            acp::ToolCallUpdateFields {
                                status: Some(if tool_result.is_error {
                                    acp::ToolCallStatus::Failed
                                } else {
                                    acp::ToolCallStatus::Completed
                                }),
                                raw_output: tool_result.output.clone(),
                                ..Default::default()
                            },
                        );
                        this.update(cx, |this, _cx| {
                            this.pending_message()
                                .tool_results
                                .insert(tool_result.tool_use_id.clone(), tool_result);
                        })
                        .ok();
                    }

                    if tool_use_limit_reached {
                        log::info!("Tool use limit reached, completing turn");
                        // Recorded so that `resume` is permitted afterwards.
                        this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
                        return Err(language_model::ToolUseLimitReachedError.into());
                    } else if end_turn {
                        log::info!("No tool uses found, completing turn");
                        return Ok(StopReason::EndTurn);
                    } else {
                        // Flush the pending message (now holding the tool results)
                        // before sending those results back in the next request.
                        this.update(cx, |this, cx| this.flush_pending_message(cx))?;
                        completion_intent = CompletionIntent::ToolResults;
                    }
                }
            }
            .await;
            _ = this.update(cx, |this, cx| this.flush_pending_message(cx));

            match turn_result {
                Ok(reason) => {
                    log::info!("Turn execution completed: {:?}", reason);

                    // Any title update finishes before the stop event is sent.
                    let update_title = this
                        .update(cx, |this, cx| this.update_title(&event_stream, cx))
                        .ok()
                        .flatten();
                    if let Some(update_title) = update_title {
                        update_title.await.context("update title failed").log_err();
                    }

                    event_stream.send_stop(reason);
                    if reason == StopReason::Refusal {
                        // Drop the refused prompt and everything after it.
                        _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
                    }
                }
                Err(error) => {
                    log::error!("Turn execution failed: {:?}", error);
                    event_stream.send_error(error);
                }
            }

            // The turn is finished; clear the stored running turn.
            _ = this.update(cx, |this, _| this.running_turn.take());
        }),
    });
    Ok(events_rx)
}
1127
1128 async fn stream_completion_with_retries(
1129 this: WeakEntity<Self>,
1130 model: Arc<dyn LanguageModel>,
1131 request: LanguageModelRequest,
1132 event_stream: &ThreadEventStream,
1133 tool_use_limit_reached: &mut bool,
1134 refusal: &mut bool,
1135 max_tokens_reached: &mut bool,
1136 cx: &mut AsyncApp,
1137 ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
1138 log::debug!("Stream completion started successfully");
1139
1140 let mut attempt = None;
1141 'retry: loop {
1142 let mut events = model.stream_completion(request.clone(), cx).await?;
1143 let mut tool_uses = FuturesUnordered::new();
1144 while let Some(event) = events.next().await {
1145 match event {
1146 Ok(LanguageModelCompletionEvent::StatusUpdate(
1147 CompletionRequestStatus::ToolUseLimitReached,
1148 )) => {
1149 *tool_use_limit_reached = true;
1150 }
1151 Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
1152 *refusal = true;
1153 return Ok(FuturesUnordered::default());
1154 }
1155 Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
1156 *max_tokens_reached = true;
1157 return Ok(FuturesUnordered::default());
1158 }
1159 Ok(LanguageModelCompletionEvent::Stop(
1160 StopReason::ToolUse | StopReason::EndTurn,
1161 )) => break,
1162 Ok(event) => {
1163 log::trace!("Received completion event: {:?}", event);
1164 this.update(cx, |this, cx| {
1165 tool_uses.extend(this.handle_streamed_completion_event(
1166 event,
1167 event_stream,
1168 cx,
1169 ));
1170 })
1171 .ok();
1172 }
1173 Err(error) => {
1174 let completion_mode =
1175 this.read_with(cx, |thread, _cx| thread.completion_mode())?;
1176 if completion_mode == CompletionMode::Normal {
1177 return Err(error.into());
1178 }
1179
1180 let Some(strategy) = Self::retry_strategy_for(&error) else {
1181 return Err(error.into());
1182 };
1183
1184 let max_attempts = match &strategy {
1185 RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
1186 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
1187 };
1188
1189 let attempt = attempt.get_or_insert(0u8);
1190
1191 *attempt += 1;
1192
1193 let attempt = *attempt;
1194 if attempt > max_attempts {
1195 return Err(error.into());
1196 }
1197
1198 let delay = match &strategy {
1199 RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
1200 let delay_secs =
1201 initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
1202 Duration::from_secs(delay_secs)
1203 }
1204 RetryStrategy::Fixed { delay, .. } => *delay,
1205 };
1206 log::debug!("Retry attempt {attempt} with delay {delay:?}");
1207
1208 event_stream.send_retry(acp_thread::RetryStatus {
1209 last_error: error.to_string().into(),
1210 attempt: attempt as usize,
1211 max_attempts: max_attempts as usize,
1212 started_at: Instant::now(),
1213 duration: delay,
1214 });
1215
1216 cx.background_executor().timer(delay).await;
1217 continue 'retry;
1218 }
1219 }
1220 }
1221
1222 return Ok(tool_uses);
1223 }
1224 }
1225
1226 pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
1227 log::debug!("Building system message");
1228 let prompt = SystemPromptTemplate {
1229 project: self.project_context.read(cx),
1230 available_tools: self.tools.keys().cloned().collect(),
1231 }
1232 .render(&self.templates)
1233 .context("failed to build system prompt")
1234 .expect("Invalid template");
1235 log::debug!("System message built");
1236 LanguageModelRequestMessage {
1237 role: Role::System,
1238 content: vec![prompt.into()],
1239 cache: true,
1240 }
1241 }
1242
1243 /// A helper method that's called on every streamed completion event.
1244 /// Returns an optional tool result task, which the main agentic loop in
1245 /// send will send back to the model when it resolves.
1246 fn handle_streamed_completion_event(
1247 &mut self,
1248 event: LanguageModelCompletionEvent,
1249 event_stream: &ThreadEventStream,
1250 cx: &mut Context<Self>,
1251 ) -> Option<Task<LanguageModelToolResult>> {
1252 log::trace!("Handling streamed completion event: {:?}", event);
1253 use LanguageModelCompletionEvent::*;
1254
1255 match event {
1256 StartMessage { .. } => {
1257 self.flush_pending_message(cx);
1258 self.pending_message = Some(AgentMessage::default());
1259 }
1260 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1261 Thinking { text, signature } => {
1262 self.handle_thinking_event(text, signature, event_stream, cx)
1263 }
1264 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1265 ToolUse(tool_use) => {
1266 return self.handle_tool_use_event(tool_use, event_stream, cx);
1267 }
1268 ToolUseJsonParseError {
1269 id,
1270 tool_name,
1271 raw_input,
1272 json_parse_error,
1273 } => {
1274 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1275 id,
1276 tool_name,
1277 raw_input,
1278 json_parse_error,
1279 )));
1280 }
1281 UsageUpdate(_) | StatusUpdate(_) => {}
1282 Stop(_) => unreachable!(),
1283 }
1284
1285 None
1286 }
1287
1288 fn handle_text_event(
1289 &mut self,
1290 new_text: String,
1291 event_stream: &ThreadEventStream,
1292 cx: &mut Context<Self>,
1293 ) {
1294 event_stream.send_text(&new_text);
1295
1296 let last_message = self.pending_message();
1297 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1298 text.push_str(&new_text);
1299 } else {
1300 last_message
1301 .content
1302 .push(AgentMessageContent::Text(new_text));
1303 }
1304
1305 cx.notify();
1306 }
1307
1308 fn handle_thinking_event(
1309 &mut self,
1310 new_text: String,
1311 new_signature: Option<String>,
1312 event_stream: &ThreadEventStream,
1313 cx: &mut Context<Self>,
1314 ) {
1315 event_stream.send_thinking(&new_text);
1316
1317 let last_message = self.pending_message();
1318 if let Some(AgentMessageContent::Thinking { text, signature }) =
1319 last_message.content.last_mut()
1320 {
1321 text.push_str(&new_text);
1322 *signature = new_signature.or(signature.take());
1323 } else {
1324 last_message.content.push(AgentMessageContent::Thinking {
1325 text: new_text,
1326 signature: new_signature,
1327 });
1328 }
1329
1330 cx.notify();
1331 }
1332
1333 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1334 let last_message = self.pending_message();
1335 last_message
1336 .content
1337 .push(AgentMessageContent::RedactedThinking(data));
1338 cx.notify();
1339 }
1340
1341 fn handle_tool_use_event(
1342 &mut self,
1343 tool_use: LanguageModelToolUse,
1344 event_stream: &ThreadEventStream,
1345 cx: &mut Context<Self>,
1346 ) -> Option<Task<LanguageModelToolResult>> {
1347 cx.notify();
1348
1349 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1350 let mut title = SharedString::from(&tool_use.name);
1351 let mut kind = acp::ToolKind::Other;
1352 if let Some(tool) = tool.as_ref() {
1353 title = tool.initial_title(tool_use.input.clone());
1354 kind = tool.kind();
1355 }
1356
1357 // Ensure the last message ends in the current tool use
1358 let last_message = self.pending_message();
1359 let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
1360 if let AgentMessageContent::ToolUse(last_tool_use) = content {
1361 if last_tool_use.id == tool_use.id {
1362 *last_tool_use = tool_use.clone();
1363 false
1364 } else {
1365 true
1366 }
1367 } else {
1368 true
1369 }
1370 });
1371
1372 if push_new_tool_use {
1373 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1374 last_message
1375 .content
1376 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1377 } else {
1378 event_stream.update_tool_call_fields(
1379 &tool_use.id,
1380 acp::ToolCallUpdateFields {
1381 title: Some(title.into()),
1382 kind: Some(kind),
1383 raw_input: Some(tool_use.input.clone()),
1384 ..Default::default()
1385 },
1386 );
1387 }
1388
1389 if !tool_use.is_input_complete {
1390 return None;
1391 }
1392
1393 let Some(tool) = tool else {
1394 let content = format!("No tool named {} exists", tool_use.name);
1395 return Some(Task::ready(LanguageModelToolResult {
1396 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1397 tool_use_id: tool_use.id,
1398 tool_name: tool_use.name,
1399 is_error: true,
1400 output: None,
1401 }));
1402 };
1403
1404 let fs = self.project.read(cx).fs().clone();
1405 let tool_event_stream =
1406 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1407 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1408 status: Some(acp::ToolCallStatus::InProgress),
1409 ..Default::default()
1410 });
1411 let supports_images = self.model().is_some_and(|model| model.supports_images());
1412 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1413 log::info!("Running tool {}", tool_use.name);
1414 Some(cx.foreground_executor().spawn(async move {
1415 let tool_result = tool_result.await.and_then(|output| {
1416 if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1417 && !supports_images
1418 {
1419 return Err(anyhow!(
1420 "Attempted to read an image, but this model doesn't support it.",
1421 ));
1422 }
1423 Ok(output)
1424 });
1425
1426 match tool_result {
1427 Ok(output) => LanguageModelToolResult {
1428 tool_use_id: tool_use.id,
1429 tool_name: tool_use.name,
1430 is_error: false,
1431 content: output.llm_output,
1432 output: Some(output.raw_output),
1433 },
1434 Err(error) => LanguageModelToolResult {
1435 tool_use_id: tool_use.id,
1436 tool_name: tool_use.name,
1437 is_error: true,
1438 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1439 output: None,
1440 },
1441 }
1442 }))
1443 }
1444
1445 fn handle_tool_use_json_parse_error_event(
1446 &mut self,
1447 tool_use_id: LanguageModelToolUseId,
1448 tool_name: Arc<str>,
1449 raw_input: Arc<str>,
1450 json_parse_error: String,
1451 ) -> LanguageModelToolResult {
1452 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1453 LanguageModelToolResult {
1454 tool_use_id,
1455 tool_name,
1456 is_error: true,
1457 content: LanguageModelToolResultContent::Text(tool_output.into()),
1458 output: Some(serde_json::Value::String(raw_input.to_string())),
1459 }
1460 }
1461
1462 pub fn title(&self) -> SharedString {
1463 self.title.clone().unwrap_or("New Thread".into())
1464 }
1465
1466 fn update_title(
1467 &mut self,
1468 event_stream: &ThreadEventStream,
1469 cx: &mut Context<Self>,
1470 ) -> Option<Task<Result<()>>> {
1471 if self.title.is_some() {
1472 log::debug!("Skipping title generation because we already have one.");
1473 return None;
1474 }
1475
1476 log::info!(
1477 "Generating title with model: {:?}",
1478 self.summarization_model.as_ref().map(|model| model.name())
1479 );
1480 let model = self.summarization_model.clone()?;
1481 let event_stream = event_stream.clone();
1482 let mut request = LanguageModelRequest {
1483 intent: Some(CompletionIntent::ThreadSummarization),
1484 temperature: AgentSettings::temperature_for_model(&model, cx),
1485 ..Default::default()
1486 };
1487
1488 for message in &self.messages {
1489 request.messages.extend(message.to_request());
1490 }
1491
1492 request.messages.push(LanguageModelRequestMessage {
1493 role: Role::User,
1494 content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1495 cache: false,
1496 });
1497 Some(cx.spawn(async move |this, cx| {
1498 let mut title = String::new();
1499 let mut messages = model.stream_completion(request, cx).await?;
1500 while let Some(event) = messages.next().await {
1501 let event = event?;
1502 let text = match event {
1503 LanguageModelCompletionEvent::Text(text) => text,
1504 LanguageModelCompletionEvent::StatusUpdate(
1505 CompletionRequestStatus::UsageUpdated { .. },
1506 ) => {
1507 // this.update(cx, |thread, cx| {
1508 // thread.update_model_request_usage(amount as u32, limit, cx);
1509 // })?;
1510 // TODO: handle usage update
1511 continue;
1512 }
1513 _ => continue,
1514 };
1515
1516 let mut lines = text.lines();
1517 title.extend(lines.next());
1518
1519 // Stop if the LLM generated multiple lines.
1520 if lines.next().is_some() {
1521 break;
1522 }
1523 }
1524
1525 log::info!("Setting title: {}", title);
1526
1527 this.update(cx, |this, cx| {
1528 let title = SharedString::from(title);
1529 event_stream.send_title_update(title.clone());
1530 this.title = Some(title);
1531 cx.notify();
1532 })
1533 }))
1534 }
1535
1536 fn pending_message(&mut self) -> &mut AgentMessage {
1537 self.pending_message.get_or_insert_default()
1538 }
1539
1540 fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1541 let Some(mut message) = self.pending_message.take() else {
1542 return;
1543 };
1544
1545 for content in &message.content {
1546 let AgentMessageContent::ToolUse(tool_use) = content else {
1547 continue;
1548 };
1549
1550 if !message.tool_results.contains_key(&tool_use.id) {
1551 message.tool_results.insert(
1552 tool_use.id.clone(),
1553 LanguageModelToolResult {
1554 tool_use_id: tool_use.id.clone(),
1555 tool_name: tool_use.name.clone(),
1556 is_error: true,
1557 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1558 output: None,
1559 },
1560 );
1561 }
1562 }
1563
1564 self.messages.push(Message::Agent(message));
1565 self.updated_at = Utc::now();
1566 cx.notify()
1567 }
1568
1569 pub(crate) fn build_completion_request(
1570 &self,
1571 completion_intent: CompletionIntent,
1572 cx: &mut App,
1573 ) -> Result<LanguageModelRequest> {
1574 let model = self.model().context("No language model configured")?;
1575
1576 log::debug!("Building completion request");
1577 log::debug!("Completion intent: {:?}", completion_intent);
1578 log::debug!("Completion mode: {:?}", self.completion_mode);
1579
1580 let messages = self.build_request_messages(cx);
1581 log::info!("Request will include {} messages", messages.len());
1582
1583 let tools = if let Some(tools) = self.tools(cx).log_err() {
1584 tools
1585 .filter_map(|tool| {
1586 let tool_name = tool.name().to_string();
1587 log::trace!("Including tool: {}", tool_name);
1588 Some(LanguageModelRequestTool {
1589 name: tool_name,
1590 description: tool.description().to_string(),
1591 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1592 })
1593 })
1594 .collect()
1595 } else {
1596 Vec::new()
1597 };
1598
1599 log::info!("Request includes {} tools", tools.len());
1600
1601 let request = LanguageModelRequest {
1602 thread_id: Some(self.id.to_string()),
1603 prompt_id: Some(self.prompt_id.to_string()),
1604 intent: Some(completion_intent),
1605 mode: Some(self.completion_mode.into()),
1606 messages,
1607 tools,
1608 tool_choice: None,
1609 stop: Vec::new(),
1610 temperature: AgentSettings::temperature_for_model(model, cx),
1611 thinking_allowed: true,
1612 };
1613
1614 log::debug!("Completion request built successfully");
1615 Ok(request)
1616 }
1617
1618 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1619 let model = self.model().context("No language model configured")?;
1620
1621 let profile = AgentSettings::get_global(cx)
1622 .profiles
1623 .get(&self.profile_id)
1624 .context("profile not found")?;
1625 let provider_id = model.provider_id();
1626
1627 Ok(self
1628 .tools
1629 .iter()
1630 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1631 .filter_map(|(tool_name, tool)| {
1632 if profile.is_tool_enabled(tool_name) {
1633 Some(tool)
1634 } else {
1635 None
1636 }
1637 })
1638 .chain(self.context_server_registry.read(cx).servers().flat_map(
1639 |(server_id, tools)| {
1640 tools.iter().filter_map(|(tool_name, tool)| {
1641 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1642 Some(tool)
1643 } else {
1644 None
1645 }
1646 })
1647 },
1648 )))
1649 }
1650
1651 fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1652 log::trace!(
1653 "Building request messages from {} thread messages",
1654 self.messages.len()
1655 );
1656 let mut messages = vec![self.build_system_message(cx)];
1657 for message in &self.messages {
1658 messages.extend(message.to_request());
1659 }
1660
1661 if let Some(message) = self.pending_message.as_ref() {
1662 messages.extend(message.to_request());
1663 }
1664
1665 if let Some(last_user_message) = messages
1666 .iter_mut()
1667 .rev()
1668 .find(|message| message.role == Role::User)
1669 {
1670 last_user_message.cache = true;
1671 }
1672
1673 messages
1674 }
1675
1676 pub fn to_markdown(&self) -> String {
1677 let mut markdown = String::new();
1678 for (ix, message) in self.messages.iter().enumerate() {
1679 if ix > 0 {
1680 markdown.push('\n');
1681 }
1682 markdown.push_str(&message.to_markdown());
1683 }
1684
1685 if let Some(message) = self.pending_message.as_ref() {
1686 markdown.push('\n');
1687 markdown.push_str(&message.to_markdown());
1688 }
1689
1690 markdown
1691 }
1692
1693 fn advance_prompt_id(&mut self) {
1694 self.prompt_id = PromptId::new();
1695 }
1696
1697 fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
1698 use LanguageModelCompletionError::*;
1699 use http_client::StatusCode;
1700
1701 // General strategy here:
1702 // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
1703 // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
1704 // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
1705 match error {
1706 HttpResponseError {
1707 status_code: StatusCode::TOO_MANY_REQUESTS,
1708 ..
1709 } => Some(RetryStrategy::ExponentialBackoff {
1710 initial_delay: BASE_RETRY_DELAY,
1711 max_attempts: MAX_RETRY_ATTEMPTS,
1712 }),
1713 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
1714 Some(RetryStrategy::Fixed {
1715 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1716 max_attempts: MAX_RETRY_ATTEMPTS,
1717 })
1718 }
1719 UpstreamProviderError {
1720 status,
1721 retry_after,
1722 ..
1723 } => match *status {
1724 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
1725 Some(RetryStrategy::Fixed {
1726 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1727 max_attempts: MAX_RETRY_ATTEMPTS,
1728 })
1729 }
1730 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
1731 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1732 // Internal Server Error could be anything, retry up to 3 times.
1733 max_attempts: 3,
1734 }),
1735 status => {
1736 // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
1737 // but we frequently get them in practice. See https://http.dev/529
1738 if status.as_u16() == 529 {
1739 Some(RetryStrategy::Fixed {
1740 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1741 max_attempts: MAX_RETRY_ATTEMPTS,
1742 })
1743 } else {
1744 Some(RetryStrategy::Fixed {
1745 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1746 max_attempts: 2,
1747 })
1748 }
1749 }
1750 },
1751 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
1752 delay: BASE_RETRY_DELAY,
1753 max_attempts: 3,
1754 }),
1755 ApiReadResponseError { .. }
1756 | HttpSend { .. }
1757 | DeserializeResponse { .. }
1758 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
1759 delay: BASE_RETRY_DELAY,
1760 max_attempts: 3,
1761 }),
1762 // Retrying these errors definitely shouldn't help.
1763 HttpResponseError {
1764 status_code:
1765 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
1766 ..
1767 }
1768 | AuthenticationError { .. }
1769 | PermissionError { .. }
1770 | NoApiKey { .. }
1771 | ApiEndpointNotFound { .. }
1772 | PromptTooLarge { .. } => None,
1773 // These errors might be transient, so retry them
1774 SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
1775 delay: BASE_RETRY_DELAY,
1776 max_attempts: 1,
1777 }),
1778 // Retry all other 4xx and 5xx errors once.
1779 HttpResponseError { status_code, .. }
1780 if status_code.is_client_error() || status_code.is_server_error() =>
1781 {
1782 Some(RetryStrategy::Fixed {
1783 delay: BASE_RETRY_DELAY,
1784 max_attempts: 3,
1785 })
1786 }
1787 Other(err)
1788 if err.is::<language_model::PaymentRequiredError>()
1789 || err.is::<language_model::ModelRequestLimitReachedError>() =>
1790 {
1791 // Retrying won't help for Payment Required or Model Request Limit errors (where
1792 // the user must upgrade to usage-based billing to get more requests, or else wait
1793 // for a significant amount of time for the request limit to reset).
1794 None
1795 }
1796 // Conservatively assume that any other errors are non-retryable
1797 HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
1798 delay: BASE_RETRY_DELAY,
1799 max_attempts: 2,
1800 }),
1801 }
1802 }
1803}
1804
1805struct RunningTurn {
1806 /// Holds the task that handles agent interaction until the end of the turn.
1807 /// Survives across multiple requests as the model performs tool calls and
1808 /// we run tools, report their results.
1809 _task: Task<()>,
1810 /// The current event stream for the running turn. Used to report a final
1811 /// cancellation event if we cancel the turn.
1812 event_stream: ThreadEventStream,
1813}
1814
1815impl RunningTurn {
1816 fn cancel(self) {
1817 log::debug!("Cancelling in progress turn");
1818 self.event_stream.send_canceled();
1819 }
1820}
1821
1822pub trait AgentTool
1823where
1824 Self: 'static + Sized,
1825{
1826 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1827 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1828
1829 fn name(&self) -> SharedString;
1830
1831 fn description(&self) -> SharedString {
1832 let schema = schemars::schema_for!(Self::Input);
1833 SharedString::new(
1834 schema
1835 .get("description")
1836 .and_then(|description| description.as_str())
1837 .unwrap_or_default(),
1838 )
1839 }
1840
1841 fn kind(&self) -> acp::ToolKind;
1842
1843 /// The initial tool title to display. Can be updated during the tool run.
1844 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1845
1846 /// Returns the JSON schema that describes the tool's input.
1847 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
1848 crate::tool_schema::root_schema_for::<Self::Input>(format)
1849 }
1850
1851 /// Some tools rely on a provider for the underlying billing or other reasons.
1852 /// Allow the tool to check if they are compatible, or should be filtered out.
1853 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1854 true
1855 }
1856
1857 /// Runs the tool with the provided input.
1858 fn run(
1859 self: Arc<Self>,
1860 input: Self::Input,
1861 event_stream: ToolCallEventStream,
1862 cx: &mut App,
1863 ) -> Task<Result<Self::Output>>;
1864
1865 /// Emits events for a previous execution of the tool.
1866 fn replay(
1867 &self,
1868 _input: Self::Input,
1869 _output: Self::Output,
1870 _event_stream: ToolCallEventStream,
1871 _cx: &mut App,
1872 ) -> Result<()> {
1873 Ok(())
1874 }
1875
1876 fn erase(self) -> Arc<dyn AnyAgentTool> {
1877 Arc::new(Erased(Arc::new(self)))
1878 }
1879}
1880
/// Type-erasing wrapper that lets a typed `AgentTool` be used as an `AnyAgentTool`.
pub struct Erased<T>(T);
1882
1883pub struct AgentToolOutput {
1884 pub llm_output: LanguageModelToolResultContent,
1885 pub raw_output: serde_json::Value,
1886}
1887
1888pub trait AnyAgentTool {
1889 fn name(&self) -> SharedString;
1890 fn description(&self) -> SharedString;
1891 fn kind(&self) -> acp::ToolKind;
1892 fn initial_title(&self, input: serde_json::Value) -> SharedString;
1893 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
1894 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1895 true
1896 }
1897 fn run(
1898 self: Arc<Self>,
1899 input: serde_json::Value,
1900 event_stream: ToolCallEventStream,
1901 cx: &mut App,
1902 ) -> Task<Result<AgentToolOutput>>;
1903 fn replay(
1904 &self,
1905 input: serde_json::Value,
1906 output: serde_json::Value,
1907 event_stream: ToolCallEventStream,
1908 cx: &mut App,
1909 ) -> Result<()>;
1910}
1911
1912impl<T> AnyAgentTool for Erased<Arc<T>>
1913where
1914 T: AgentTool,
1915{
1916 fn name(&self) -> SharedString {
1917 self.0.name()
1918 }
1919
1920 fn description(&self) -> SharedString {
1921 self.0.description()
1922 }
1923
1924 fn kind(&self) -> agent_client_protocol::ToolKind {
1925 self.0.kind()
1926 }
1927
1928 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1929 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1930 self.0.initial_title(parsed_input)
1931 }
1932
1933 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1934 let mut json = serde_json::to_value(self.0.input_schema(format))?;
1935 adapt_schema_to_format(&mut json, format)?;
1936 Ok(json)
1937 }
1938
1939 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1940 self.0.supported_provider(provider)
1941 }
1942
1943 fn run(
1944 self: Arc<Self>,
1945 input: serde_json::Value,
1946 event_stream: ToolCallEventStream,
1947 cx: &mut App,
1948 ) -> Task<Result<AgentToolOutput>> {
1949 cx.spawn(async move |cx| {
1950 let input = serde_json::from_value(input)?;
1951 let output = cx
1952 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1953 .await?;
1954 let raw_output = serde_json::to_value(&output)?;
1955 Ok(AgentToolOutput {
1956 llm_output: output.into(),
1957 raw_output,
1958 })
1959 })
1960 }
1961
1962 fn replay(
1963 &self,
1964 input: serde_json::Value,
1965 output: serde_json::Value,
1966 event_stream: ToolCallEventStream,
1967 cx: &mut App,
1968 ) -> Result<()> {
1969 let input = serde_json::from_value(input)?;
1970 let output = serde_json::from_value(output)?;
1971 self.0.replay(input, output, event_stream, cx)
1972 }
1973}
1974
1975#[derive(Clone)]
1976struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1977
1978impl ThreadEventStream {
1979 fn send_title_update(&self, text: SharedString) {
1980 self.0
1981 .unbounded_send(Ok(ThreadEvent::TitleUpdate(text)))
1982 .ok();
1983 }
1984
1985 fn send_user_message(&self, message: &UserMessage) {
1986 self.0
1987 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1988 .ok();
1989 }
1990
1991 fn send_text(&self, text: &str) {
1992 self.0
1993 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1994 .ok();
1995 }
1996
1997 fn send_thinking(&self, text: &str) {
1998 self.0
1999 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2000 .ok();
2001 }
2002
2003 fn send_tool_call(
2004 &self,
2005 id: &LanguageModelToolUseId,
2006 title: SharedString,
2007 kind: acp::ToolKind,
2008 input: serde_json::Value,
2009 ) {
2010 self.0
2011 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2012 id,
2013 title.to_string(),
2014 kind,
2015 input,
2016 ))))
2017 .ok();
2018 }
2019
2020 fn initial_tool_call(
2021 id: &LanguageModelToolUseId,
2022 title: String,
2023 kind: acp::ToolKind,
2024 input: serde_json::Value,
2025 ) -> acp::ToolCall {
2026 acp::ToolCall {
2027 id: acp::ToolCallId(id.to_string().into()),
2028 title,
2029 kind,
2030 status: acp::ToolCallStatus::Pending,
2031 content: vec![],
2032 locations: vec![],
2033 raw_input: Some(input),
2034 raw_output: None,
2035 }
2036 }
2037
2038 fn update_tool_call_fields(
2039 &self,
2040 tool_use_id: &LanguageModelToolUseId,
2041 fields: acp::ToolCallUpdateFields,
2042 ) {
2043 self.0
2044 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2045 acp::ToolCallUpdate {
2046 id: acp::ToolCallId(tool_use_id.to_string().into()),
2047 fields,
2048 }
2049 .into(),
2050 )))
2051 .ok();
2052 }
2053
2054 fn send_retry(&self, status: acp_thread::RetryStatus) {
2055 self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2056 }
2057
2058 fn send_stop(&self, reason: StopReason) {
2059 match reason {
2060 StopReason::EndTurn => {
2061 self.0
2062 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
2063 .ok();
2064 }
2065 StopReason::MaxTokens => {
2066 self.0
2067 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
2068 .ok();
2069 }
2070 StopReason::Refusal => {
2071 self.0
2072 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
2073 .ok();
2074 }
2075 StopReason::ToolUse => {}
2076 }
2077 }
2078
2079 fn send_canceled(&self) {
2080 self.0
2081 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
2082 .ok();
2083 }
2084
2085 fn send_error(&self, error: impl Into<anyhow::Error>) {
2086 self.0.unbounded_send(Err(error.into())).ok();
2087 }
2088}
2089
2090#[derive(Clone)]
2091pub struct ToolCallEventStream {
2092 tool_use_id: LanguageModelToolUseId,
2093 stream: ThreadEventStream,
2094 fs: Option<Arc<dyn Fs>>,
2095}
2096
2097impl ToolCallEventStream {
2098 #[cfg(test)]
2099 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2100 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2101
2102 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2103
2104 (stream, ToolCallEventStreamReceiver(events_rx))
2105 }
2106
2107 fn new(
2108 tool_use_id: LanguageModelToolUseId,
2109 stream: ThreadEventStream,
2110 fs: Option<Arc<dyn Fs>>,
2111 ) -> Self {
2112 Self {
2113 tool_use_id,
2114 stream,
2115 fs,
2116 }
2117 }
2118
2119 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2120 self.stream
2121 .update_tool_call_fields(&self.tool_use_id, fields);
2122 }
2123
2124 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2125 self.stream
2126 .0
2127 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2128 acp_thread::ToolCallUpdateDiff {
2129 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2130 diff,
2131 }
2132 .into(),
2133 )))
2134 .ok();
2135 }
2136
2137 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
2138 self.stream
2139 .0
2140 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2141 acp_thread::ToolCallUpdateTerminal {
2142 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2143 terminal,
2144 }
2145 .into(),
2146 )))
2147 .ok();
2148 }
2149
2150 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2151 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2152 return Task::ready(Ok(()));
2153 }
2154
2155 let (response_tx, response_rx) = oneshot::channel();
2156 self.stream
2157 .0
2158 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2159 ToolCallAuthorization {
2160 tool_call: acp::ToolCallUpdate {
2161 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2162 fields: acp::ToolCallUpdateFields {
2163 title: Some(title.into()),
2164 ..Default::default()
2165 },
2166 },
2167 options: vec![
2168 acp::PermissionOption {
2169 id: acp::PermissionOptionId("always_allow".into()),
2170 name: "Always Allow".into(),
2171 kind: acp::PermissionOptionKind::AllowAlways,
2172 },
2173 acp::PermissionOption {
2174 id: acp::PermissionOptionId("allow".into()),
2175 name: "Allow".into(),
2176 kind: acp::PermissionOptionKind::AllowOnce,
2177 },
2178 acp::PermissionOption {
2179 id: acp::PermissionOptionId("deny".into()),
2180 name: "Deny".into(),
2181 kind: acp::PermissionOptionKind::RejectOnce,
2182 },
2183 ],
2184 response: response_tx,
2185 },
2186 )))
2187 .ok();
2188 let fs = self.fs.clone();
2189 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2190 "always_allow" => {
2191 if let Some(fs) = fs.clone() {
2192 cx.update(|cx| {
2193 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
2194 settings.set_always_allow_tool_actions(true);
2195 });
2196 })?;
2197 }
2198
2199 Ok(())
2200 }
2201 "allow" => Ok(()),
2202 _ => Err(anyhow!("Permission to run tool denied by user")),
2203 })
2204 }
2205}
2206
/// Test-only receiving end of a tool call event stream.
///
/// Wraps an unbounded receiver of `Result<ThreadEvent>`. It provides helpers that
/// assert the kind of the next event (`expect_authorization`, `expect_terminal`) and
/// derefs to the raw receiver for other inspections.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2209
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as a `ToolCallAuthorization`.
    ///
    /// # Panics
    ///
    /// Panics with the received value if the next event is anything else, including an
    /// error or the end of the stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal from a terminal tool-call update.
    ///
    /// # Panics
    ///
    /// Panics with the received value if the next event is not a
    /// `ToolCallUpdate::UpdateTerminal`.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
                update,
            )))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
2233
/// Exposes the wrapped receiver so tests can inspect the stream directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2242
/// Mutable access to the wrapped receiver (e.g. for polling it directly in tests).
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2249
2250impl From<&str> for UserMessageContent {
2251 fn from(text: &str) -> Self {
2252 Self::Text(text.into())
2253 }
2254}
2255
2256impl From<acp::ContentBlock> for UserMessageContent {
2257 fn from(value: acp::ContentBlock) -> Self {
2258 match value {
2259 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2260 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2261 acp::ContentBlock::Audio(_) => {
2262 // TODO
2263 Self::Text("[audio]".to_string())
2264 }
2265 acp::ContentBlock::ResourceLink(resource_link) => {
2266 match MentionUri::parse(&resource_link.uri) {
2267 Ok(uri) => Self::Mention {
2268 uri,
2269 content: String::new(),
2270 },
2271 Err(err) => {
2272 log::error!("Failed to parse mention link: {}", err);
2273 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2274 }
2275 }
2276 }
2277 acp::ContentBlock::Resource(resource) => match resource.resource {
2278 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2279 match MentionUri::parse(&resource.uri) {
2280 Ok(uri) => Self::Mention {
2281 uri,
2282 content: resource.text,
2283 },
2284 Err(err) => {
2285 log::error!("Failed to parse mention link: {}", err);
2286 Self::Text(
2287 MarkdownCodeBlock {
2288 tag: &resource.uri,
2289 text: &resource.text,
2290 }
2291 .to_string(),
2292 )
2293 }
2294 }
2295 }
2296 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2297 // TODO
2298 Self::Text("[blob]".to_string())
2299 }
2300 },
2301 }
2302 }
2303}
2304
2305impl From<UserMessageContent> for acp::ContentBlock {
2306 fn from(content: UserMessageContent) -> Self {
2307 match content {
2308 UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2309 text,
2310 annotations: None,
2311 }),
2312 UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2313 data: image.source.to_string(),
2314 mime_type: "image/png".to_string(),
2315 annotations: None,
2316 uri: None,
2317 }),
2318 UserMessageContent::Mention { uri, content } => {
2319 acp::ContentBlock::ResourceLink(acp::ResourceLink {
2320 uri: uri.to_uri().to_string(),
2321 name: uri.name(),
2322 annotations: None,
2323 description: if content.is_empty() {
2324 None
2325 } else {
2326 Some(content)
2327 },
2328 mime_type: None,
2329 size: None,
2330 title: None,
2331 })
2332 }
2333 }
2334 }
2335}
2336
2337fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2338 LanguageModelImage {
2339 source: image_content.data.into(),
2340 // TODO: make this optional?
2341 size: gpui::Size::new(0.into(), 0.into()),
2342 }
2343}