1use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
18 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
19 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
20 LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
29use std::{fmt::Write, ops::Range};
30use util::{ResultExt, markdown::MarkdownCodeBlock};
31use uuid::Uuid;
32
/// Text recorded for a tool call that has no result of its own: it is used as the
/// error content when a pending message is flushed with unanswered tool uses, and as
/// the displayed content when replaying a tool call that has no stored output.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
34
35#[derive(
36 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
37)]
38pub struct ThreadId(pub(crate) Arc<str>);
39
40impl ThreadId {
41 pub fn new() -> Self {
42 Self(Uuid::new_v4().to_string().into())
43 }
44}
45
46impl std::fmt::Display for ThreadId {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 write!(f, "{}", self.0)
49 }
50}
51
52impl From<&str> for ThreadId {
53 fn from(value: &str) -> Self {
54 Self(value.into())
55 }
56}
57
58impl From<acp::SessionId> for ThreadId {
59 fn from(value: acp::SessionId) -> Self {
60 Self(value.0)
61 }
62}
63
64impl From<ThreadId> for acp::SessionId {
65 fn from(value: ThreadId) -> Self {
66 Self(value.0)
67 }
68}
69
70/// The ID of the user prompt that initiated a request.
71///
72/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
74pub struct PromptId(Arc<str>);
75
76impl PromptId {
77 pub fn new() -> Self {
78 Self(Uuid::new_v4().to_string().into())
79 }
80}
81
82impl std::fmt::Display for PromptId {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}", self.0)
85 }
86}
87
88#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
89pub enum Message {
90 User(UserMessage),
91 Agent(AgentMessage),
92 Resume,
93}
94
95impl Message {
96 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
97 match self {
98 Message::Agent(agent_message) => Some(agent_message),
99 _ => None,
100 }
101 }
102
103 pub fn to_markdown(&self) -> String {
104 match self {
105 Message::User(message) => message.to_markdown(),
106 Message::Agent(message) => message.to_markdown(),
107 Message::Resume => "[resumed after tool use limit was reached]".into(),
108 }
109 }
110}
111
112#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
113pub struct UserMessage {
114 pub id: UserMessageId,
115 pub content: Vec<UserMessageContent>,
116}
117
118#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
119pub enum UserMessageContent {
120 Text(String),
121 Mention { uri: MentionUri, content: String },
122 Image(LanguageModelImage),
123}
124
125impl UserMessage {
126 pub fn to_markdown(&self) -> String {
127 let mut markdown = String::from("## User\n\n");
128
129 for content in &self.content {
130 match content {
131 UserMessageContent::Text(text) => {
132 markdown.push_str(text);
133 markdown.push('\n');
134 }
135 UserMessageContent::Image(_) => {
136 markdown.push_str("<image />\n");
137 }
138 UserMessageContent::Mention { uri, content } => {
139 if !content.is_empty() {
140 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
141 } else {
142 let _ = write!(&mut markdown, "{}\n", uri.as_link());
143 }
144 }
145 }
146 }
147
148 markdown
149 }
150
151 fn to_request(&self) -> LanguageModelRequestMessage {
152 let mut message = LanguageModelRequestMessage {
153 role: Role::User,
154 content: Vec::with_capacity(self.content.len()),
155 cache: false,
156 };
157
158 const OPEN_CONTEXT: &str = "<context>\n\
159 The following items were attached by the user. \
160 They are up-to-date and don't need to be re-read.\n\n";
161
162 const OPEN_FILES_TAG: &str = "<files>";
163 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
164 const OPEN_THREADS_TAG: &str = "<threads>";
165 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
166 const OPEN_RULES_TAG: &str =
167 "<rules>\nThe user has specified the following rules that should be applied:\n";
168
169 let mut file_context = OPEN_FILES_TAG.to_string();
170 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
171 let mut thread_context = OPEN_THREADS_TAG.to_string();
172 let mut fetch_context = OPEN_FETCH_TAG.to_string();
173 let mut rules_context = OPEN_RULES_TAG.to_string();
174
175 for chunk in &self.content {
176 let chunk = match chunk {
177 UserMessageContent::Text(text) => {
178 language_model::MessageContent::Text(text.clone())
179 }
180 UserMessageContent::Image(value) => {
181 language_model::MessageContent::Image(value.clone())
182 }
183 UserMessageContent::Mention { uri, content } => {
184 match uri {
185 MentionUri::File { abs_path, .. } => {
186 write!(
187 &mut symbol_context,
188 "\n{}",
189 MarkdownCodeBlock {
190 tag: &codeblock_tag(&abs_path, None),
191 text: &content.to_string(),
192 }
193 )
194 .ok();
195 }
196 MentionUri::Symbol {
197 path, line_range, ..
198 }
199 | MentionUri::Selection {
200 path, line_range, ..
201 } => {
202 write!(
203 &mut rules_context,
204 "\n{}",
205 MarkdownCodeBlock {
206 tag: &codeblock_tag(&path, Some(line_range)),
207 text: &content
208 }
209 )
210 .ok();
211 }
212 MentionUri::Thread { .. } => {
213 write!(&mut thread_context, "\n{}\n", content).ok();
214 }
215 MentionUri::TextThread { .. } => {
216 write!(&mut thread_context, "\n{}\n", content).ok();
217 }
218 MentionUri::Rule { .. } => {
219 write!(
220 &mut rules_context,
221 "\n{}",
222 MarkdownCodeBlock {
223 tag: "",
224 text: &content
225 }
226 )
227 .ok();
228 }
229 MentionUri::Fetch { url } => {
230 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
231 }
232 }
233
234 language_model::MessageContent::Text(uri.as_link().to_string())
235 }
236 };
237
238 message.content.push(chunk);
239 }
240
241 let len_before_context = message.content.len();
242
243 if file_context.len() > OPEN_FILES_TAG.len() {
244 file_context.push_str("</files>\n");
245 message
246 .content
247 .push(language_model::MessageContent::Text(file_context));
248 }
249
250 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
251 symbol_context.push_str("</symbols>\n");
252 message
253 .content
254 .push(language_model::MessageContent::Text(symbol_context));
255 }
256
257 if thread_context.len() > OPEN_THREADS_TAG.len() {
258 thread_context.push_str("</threads>\n");
259 message
260 .content
261 .push(language_model::MessageContent::Text(thread_context));
262 }
263
264 if fetch_context.len() > OPEN_FETCH_TAG.len() {
265 fetch_context.push_str("</fetched_urls>\n");
266 message
267 .content
268 .push(language_model::MessageContent::Text(fetch_context));
269 }
270
271 if rules_context.len() > OPEN_RULES_TAG.len() {
272 rules_context.push_str("</user_rules>\n");
273 message
274 .content
275 .push(language_model::MessageContent::Text(rules_context));
276 }
277
278 if message.content.len() > len_before_context {
279 message.content.insert(
280 len_before_context,
281 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
282 );
283 message
284 .content
285 .push(language_model::MessageContent::Text("</context>".into()));
286 }
287
288 message
289 }
290}
291
/// Builds the info string for a markdown code block that shows (part of) a file.
///
/// The tag has the form `<ext> <path>[:<line>|:<first>-<last>]`: the extension (and the
/// space after it) is omitted when the path has none or it is not valid UTF-8, and the
/// line suffix is present only when `line_range` is given. Range bounds are printed
/// one-based (each bound plus one); a range whose bounds are equal prints a single line.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let mut tag = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{} ", extension),
        None => String::new(),
    };

    let _ = write!(tag, "{}", full_path.display());

    match line_range {
        Some(range) if range.start == range.end => {
            let _ = write!(tag, ":{}", range.start + 1);
        }
        Some(range) => {
            let _ = write!(tag, ":{}-{}", range.start + 1, range.end + 1);
        }
        None => {}
    }

    tag
}
311
312impl AgentMessage {
313 pub fn to_markdown(&self) -> String {
314 let mut markdown = String::from("## Assistant\n\n");
315
316 for content in &self.content {
317 match content {
318 AgentMessageContent::Text(text) => {
319 markdown.push_str(text);
320 markdown.push('\n');
321 }
322 AgentMessageContent::Thinking { text, .. } => {
323 markdown.push_str("<think>");
324 markdown.push_str(text);
325 markdown.push_str("</think>\n");
326 }
327 AgentMessageContent::RedactedThinking(_) => {
328 markdown.push_str("<redacted_thinking />\n")
329 }
330 AgentMessageContent::ToolUse(tool_use) => {
331 markdown.push_str(&format!(
332 "**Tool Use**: {} (ID: {})\n",
333 tool_use.name, tool_use.id
334 ));
335 markdown.push_str(&format!(
336 "{}\n",
337 MarkdownCodeBlock {
338 tag: "json",
339 text: &format!("{:#}", tool_use.input)
340 }
341 ));
342 }
343 }
344 }
345
346 for tool_result in self.tool_results.values() {
347 markdown.push_str(&format!(
348 "**Tool Result**: {} (ID: {})\n\n",
349 tool_result.tool_name, tool_result.tool_use_id
350 ));
351 if tool_result.is_error {
352 markdown.push_str("**ERROR:**\n");
353 }
354
355 match &tool_result.content {
356 LanguageModelToolResultContent::Text(text) => {
357 writeln!(markdown, "{text}\n").ok();
358 }
359 LanguageModelToolResultContent::Image(_) => {
360 writeln!(markdown, "<image />\n").ok();
361 }
362 }
363
364 if let Some(output) = tool_result.output.as_ref() {
365 writeln!(
366 markdown,
367 "**Debug Output**:\n\n```json\n{}\n```\n",
368 serde_json::to_string_pretty(output).unwrap()
369 )
370 .unwrap();
371 }
372 }
373
374 markdown
375 }
376
377 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
378 let mut assistant_message = LanguageModelRequestMessage {
379 role: Role::Assistant,
380 content: Vec::with_capacity(self.content.len()),
381 cache: false,
382 };
383 for chunk in &self.content {
384 let chunk = match chunk {
385 AgentMessageContent::Text(text) => {
386 language_model::MessageContent::Text(text.clone())
387 }
388 AgentMessageContent::Thinking { text, signature } => {
389 language_model::MessageContent::Thinking {
390 text: text.clone(),
391 signature: signature.clone(),
392 }
393 }
394 AgentMessageContent::RedactedThinking(value) => {
395 language_model::MessageContent::RedactedThinking(value.clone())
396 }
397 AgentMessageContent::ToolUse(value) => {
398 language_model::MessageContent::ToolUse(value.clone())
399 }
400 };
401 assistant_message.content.push(chunk);
402 }
403
404 let mut user_message = LanguageModelRequestMessage {
405 role: Role::User,
406 content: Vec::new(),
407 cache: false,
408 };
409
410 for tool_result in self.tool_results.values() {
411 user_message
412 .content
413 .push(language_model::MessageContent::ToolResult(
414 tool_result.clone(),
415 ));
416 }
417
418 let mut messages = Vec::new();
419 if !assistant_message.content.is_empty() {
420 messages.push(assistant_message);
421 }
422 if !user_message.content.is_empty() {
423 messages.push(user_message);
424 }
425 messages
426 }
427}
428
429#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
430pub struct AgentMessage {
431 pub content: Vec<AgentMessageContent>,
432 pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
433}
434
435#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
436pub enum AgentMessageContent {
437 Text(String),
438 Thinking {
439 text: String,
440 signature: Option<String>,
441 },
442 RedactedThinking(String),
443 ToolUse(LanguageModelToolUse),
444}
445
446#[derive(Debug)]
447pub enum ThreadEvent {
448 UserMessage(UserMessage),
449 AgentText(String),
450 AgentThinking(String),
451 ToolCall(acp::ToolCall),
452 ToolCallUpdate(acp_thread::ToolCallUpdate),
453 ToolCallAuthorization(ToolCallAuthorization),
454 Stop(acp::StopReason),
455}
456
457#[derive(Debug)]
458pub struct ToolCallAuthorization {
459 pub tool_call: acp::ToolCallUpdate,
460 pub options: Vec<acp::PermissionOption>,
461 pub response: oneshot::Sender<acp::PermissionOptionId>,
462}
463
464pub struct Thread {
465 id: ThreadId,
466 prompt_id: PromptId,
467 messages: Vec<Message>,
468 completion_mode: CompletionMode,
469 /// Holds the task that handles agent interaction until the end of the turn.
470 /// Survives across multiple requests as the model performs tool calls and
471 /// we run tools, report their results.
472 running_turn: Option<RunningTurn>,
473 pending_message: Option<AgentMessage>,
474 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
475 tool_use_limit_reached: bool,
476 context_server_registry: Entity<ContextServerRegistry>,
477 profile_id: AgentProfileId,
478 project_context: Rc<RefCell<ProjectContext>>,
479 templates: Arc<Templates>,
480 model: Arc<dyn LanguageModel>,
481 project: Entity<Project>,
482 action_log: Entity<ActionLog>,
483}
484
485impl Thread {
486 pub fn new(
487 project: Entity<Project>,
488 project_context: Rc<RefCell<ProjectContext>>,
489 context_server_registry: Entity<ContextServerRegistry>,
490 action_log: Entity<ActionLog>,
491 templates: Arc<Templates>,
492 model: Arc<dyn LanguageModel>,
493 cx: &mut Context<Self>,
494 ) -> Self {
495 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
496 Self {
497 id: ThreadId::new(),
498 prompt_id: PromptId::new(),
499 messages: Vec::new(),
500 completion_mode: CompletionMode::Normal,
501 running_turn: None,
502 pending_message: None,
503 tools: BTreeMap::default(),
504 tool_use_limit_reached: false,
505 context_server_registry,
506 profile_id,
507 project_context,
508 templates,
509 model,
510 project,
511 action_log,
512 }
513 }
514
515 pub fn from_db(
516 id: ThreadId,
517 db_thread: DbThread,
518 project: Entity<Project>,
519 project_context: Rc<RefCell<ProjectContext>>,
520 context_server_registry: Entity<ContextServerRegistry>,
521 action_log: Entity<ActionLog>,
522 templates: Arc<Templates>,
523 model: Arc<dyn LanguageModel>,
524 cx: &mut Context<Self>,
525 ) -> Self {
526 let profile_id = db_thread
527 .profile
528 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
529 Self {
530 id,
531 prompt_id: PromptId::new(),
532 messages: db_thread.messages,
533 completion_mode: CompletionMode::Normal,
534 running_turn: None,
535 pending_message: None,
536 tools: BTreeMap::default(),
537 tool_use_limit_reached: false,
538 context_server_registry,
539 profile_id,
540 project_context,
541 templates,
542 model,
543 project,
544 action_log,
545 }
546 }
547
548 pub fn replay(&self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
549 let (tx, rx) = mpsc::unbounded();
550 let stream = ThreadEventStream(tx);
551 for message in &self.messages {
552 match message {
553 Message::User(user_message) => stream.send_user_message(&user_message),
554 Message::Agent(assistant_message) => {
555 for content in &assistant_message.content {
556 match content {
557 AgentMessageContent::Text(text) => stream.send_text(text),
558 AgentMessageContent::Thinking { text, .. } => {
559 stream.send_thinking(text)
560 }
561 AgentMessageContent::RedactedThinking(_) => {}
562 AgentMessageContent::ToolUse(tool_use) => {
563 self.replay_tool_call(
564 tool_use,
565 assistant_message.tool_results.get(&tool_use.id),
566 &stream,
567 cx,
568 );
569 }
570 }
571 }
572 }
573 Message::Resume => {}
574 }
575 }
576 rx
577 }
578
579 fn replay_tool_call(
580 &self,
581 tool_use: &LanguageModelToolUse,
582 tool_result: Option<&LanguageModelToolResult>,
583 stream: &ThreadEventStream,
584 cx: &mut Context<Self>,
585 ) {
586 let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
587 stream
588 .0
589 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
590 id: acp::ToolCallId(tool_use.id.to_string().into()),
591 title: tool_use.name.to_string(),
592 kind: acp::ToolKind::Other,
593 status: acp::ToolCallStatus::Failed,
594 content: Vec::new(),
595 locations: Vec::new(),
596 raw_input: Some(tool_use.input.clone()),
597 raw_output: None,
598 })))
599 .ok();
600 return;
601 };
602
603 let title = tool.initial_title(tool_use.input.clone());
604 let kind = tool.kind();
605 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
606
607 if let Some(output) = tool_result
608 .as_ref()
609 .and_then(|result| result.output.clone())
610 {
611 let tool_event_stream = ToolCallEventStream::new(
612 tool_use.id.clone(),
613 stream.clone(),
614 Some(self.project.read(cx).fs().clone()),
615 );
616 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
617 .log_err();
618 } else {
619 stream.update_tool_call_fields(
620 &tool_use.id,
621 acp::ToolCallUpdateFields {
622 content: Some(vec![TOOL_CANCELED_MESSAGE.into()]),
623 status: Some(acp::ToolCallStatus::Failed),
624 ..Default::default()
625 },
626 );
627 }
628 }
629
630 pub fn project(&self) -> &Entity<Project> {
631 &self.project
632 }
633
634 pub fn action_log(&self) -> &Entity<ActionLog> {
635 &self.action_log
636 }
637
638 pub fn model(&self) -> &Arc<dyn LanguageModel> {
639 &self.model
640 }
641
642 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
643 self.model = model;
644 }
645
646 pub fn completion_mode(&self) -> CompletionMode {
647 self.completion_mode
648 }
649
650 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
651 self.completion_mode = mode;
652 }
653
654 #[cfg(any(test, feature = "test-support"))]
655 pub fn last_message(&self) -> Option<Message> {
656 if let Some(message) = self.pending_message.clone() {
657 Some(Message::Agent(message))
658 } else {
659 self.messages.last().cloned()
660 }
661 }
662
663 pub fn add_tool(&mut self, tool: impl AgentTool) {
664 self.tools.insert(tool.name(), tool.erase());
665 }
666
667 pub fn remove_tool(&mut self, name: &str) -> bool {
668 self.tools.remove(name).is_some()
669 }
670
671 pub fn profile(&self) -> &AgentProfileId {
672 &self.profile_id
673 }
674
675 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
676 self.profile_id = profile_id;
677 }
678
679 pub fn cancel(&mut self) {
680 if let Some(running_turn) = self.running_turn.take() {
681 running_turn.cancel();
682 }
683 self.flush_pending_message();
684 }
685
686 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
687 self.cancel();
688 let Some(position) = self.messages.iter().position(
689 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
690 ) else {
691 return Err(anyhow!("Message not found"));
692 };
693 self.messages.truncate(position);
694 Ok(())
695 }
696
697 pub fn resume(
698 &mut self,
699 cx: &mut Context<Self>,
700 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
701 anyhow::ensure!(
702 self.tool_use_limit_reached,
703 "can only resume after tool use limit is reached"
704 );
705
706 self.messages.push(Message::Resume);
707 cx.notify();
708
709 log::info!("Total messages in thread: {}", self.messages.len());
710 Ok(self.run_turn(cx))
711 }
712
713 /// Sending a message results in the model streaming a response, which could include tool calls.
714 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
715 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
716 pub fn send<T>(
717 &mut self,
718 id: UserMessageId,
719 content: impl IntoIterator<Item = T>,
720 cx: &mut Context<Self>,
721 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
722 where
723 T: Into<UserMessageContent>,
724 {
725 log::info!("Thread::send called with model: {:?}", self.model.name());
726 self.advance_prompt_id();
727
728 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
729 log::debug!("Thread::send content: {:?}", content);
730
731 self.messages
732 .push(Message::User(UserMessage { id, content }));
733 cx.notify();
734
735 log::info!("Total messages in thread: {}", self.messages.len());
736 self.run_turn(cx)
737 }
738
739 fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
740 self.cancel();
741
742 let model = self.model.clone();
743 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
744 let event_stream = ThreadEventStream(events_tx);
745 let message_ix = self.messages.len().saturating_sub(1);
746 self.tool_use_limit_reached = false;
747 self.running_turn = Some(RunningTurn {
748 event_stream: event_stream.clone(),
749 _task: cx.spawn(async move |this, cx| {
750 log::info!("Starting agent turn execution");
751 let turn_result: Result<()> = async {
752 let mut completion_intent = CompletionIntent::UserPrompt;
753 loop {
754 log::debug!(
755 "Building completion request with intent: {:?}",
756 completion_intent
757 );
758 let request = this.update(cx, |this, cx| {
759 this.build_completion_request(completion_intent, cx)
760 })?;
761
762 log::info!("Calling model.stream_completion");
763 let mut events = model.stream_completion(request, cx).await?;
764 log::debug!("Stream completion started successfully");
765
766 let mut tool_use_limit_reached = false;
767 let mut tool_uses = FuturesUnordered::new();
768 while let Some(event) = events.next().await {
769 match event? {
770 LanguageModelCompletionEvent::StatusUpdate(
771 CompletionRequestStatus::ToolUseLimitReached,
772 ) => {
773 tool_use_limit_reached = true;
774 }
775 LanguageModelCompletionEvent::Stop(reason) => {
776 event_stream.send_stop(reason);
777 if reason == StopReason::Refusal {
778 this.update(cx, |this, _cx| {
779 this.flush_pending_message();
780 this.messages.truncate(message_ix);
781 })?;
782 return Ok(());
783 }
784 }
785 event => {
786 log::trace!("Received completion event: {:?}", event);
787 this.update(cx, |this, cx| {
788 tool_uses.extend(this.handle_streamed_completion_event(
789 event,
790 &event_stream,
791 cx,
792 ));
793 })
794 .ok();
795 }
796 }
797 }
798
799 let used_tools = tool_uses.is_empty();
800 while let Some(tool_result) = tool_uses.next().await {
801 log::info!("Tool finished {:?}", tool_result);
802
803 event_stream.update_tool_call_fields(
804 &tool_result.tool_use_id,
805 acp::ToolCallUpdateFields {
806 status: Some(if tool_result.is_error {
807 acp::ToolCallStatus::Failed
808 } else {
809 acp::ToolCallStatus::Completed
810 }),
811 raw_output: tool_result.output.clone(),
812 ..Default::default()
813 },
814 );
815 this.update(cx, |this, _cx| {
816 this.pending_message()
817 .tool_results
818 .insert(tool_result.tool_use_id.clone(), tool_result);
819 })
820 .ok();
821 }
822
823 if tool_use_limit_reached {
824 log::info!("Tool use limit reached, completing turn");
825 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
826 return Err(language_model::ToolUseLimitReachedError.into());
827 } else if used_tools {
828 log::info!("No tool uses found, completing turn");
829 return Ok(());
830 } else {
831 this.update(cx, |this, _| this.flush_pending_message())?;
832 completion_intent = CompletionIntent::ToolResults;
833 }
834 }
835 }
836 .await;
837
838 if let Err(error) = turn_result {
839 log::error!("Turn execution failed: {:?}", error);
840 event_stream.send_error(error);
841 } else {
842 log::info!("Turn execution completed successfully");
843 }
844
845 this.update(cx, |this, _| {
846 this.flush_pending_message();
847 this.running_turn.take();
848 })
849 .ok();
850 }),
851 });
852 events_rx
853 }
854
855 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
856 log::debug!("Building system message");
857 let prompt = SystemPromptTemplate {
858 project: &self.project_context.borrow(),
859 available_tools: self.tools.keys().cloned().collect(),
860 }
861 .render(&self.templates)
862 .context("failed to build system prompt")
863 .expect("Invalid template");
864 log::debug!("System message built");
865 LanguageModelRequestMessage {
866 role: Role::System,
867 content: vec![prompt.into()],
868 cache: true,
869 }
870 }
871
872 /// A helper method that's called on every streamed completion event.
873 /// Returns an optional tool result task, which the main agentic loop in
874 /// send will send back to the model when it resolves.
875 fn handle_streamed_completion_event(
876 &mut self,
877 event: LanguageModelCompletionEvent,
878 event_stream: &ThreadEventStream,
879 cx: &mut Context<Self>,
880 ) -> Option<Task<LanguageModelToolResult>> {
881 log::trace!("Handling streamed completion event: {:?}", event);
882 use LanguageModelCompletionEvent::*;
883
884 match event {
885 StartMessage { .. } => {
886 self.flush_pending_message();
887 self.pending_message = Some(AgentMessage::default());
888 }
889 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
890 Thinking { text, signature } => {
891 self.handle_thinking_event(text, signature, event_stream, cx)
892 }
893 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
894 ToolUse(tool_use) => {
895 return self.handle_tool_use_event(tool_use, event_stream, cx);
896 }
897 ToolUseJsonParseError {
898 id,
899 tool_name,
900 raw_input,
901 json_parse_error,
902 } => {
903 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
904 id,
905 tool_name,
906 raw_input,
907 json_parse_error,
908 )));
909 }
910 UsageUpdate(_) | StatusUpdate(_) => {}
911 Stop(_) => unreachable!(),
912 }
913
914 None
915 }
916
917 fn handle_text_event(
918 &mut self,
919 new_text: String,
920 event_stream: &ThreadEventStream,
921 cx: &mut Context<Self>,
922 ) {
923 event_stream.send_text(&new_text);
924
925 let last_message = self.pending_message();
926 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
927 text.push_str(&new_text);
928 } else {
929 last_message
930 .content
931 .push(AgentMessageContent::Text(new_text));
932 }
933
934 cx.notify();
935 }
936
937 fn handle_thinking_event(
938 &mut self,
939 new_text: String,
940 new_signature: Option<String>,
941 event_stream: &ThreadEventStream,
942 cx: &mut Context<Self>,
943 ) {
944 event_stream.send_thinking(&new_text);
945
946 let last_message = self.pending_message();
947 if let Some(AgentMessageContent::Thinking { text, signature }) =
948 last_message.content.last_mut()
949 {
950 text.push_str(&new_text);
951 *signature = new_signature.or(signature.take());
952 } else {
953 last_message.content.push(AgentMessageContent::Thinking {
954 text: new_text,
955 signature: new_signature,
956 });
957 }
958
959 cx.notify();
960 }
961
962 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
963 let last_message = self.pending_message();
964 last_message
965 .content
966 .push(AgentMessageContent::RedactedThinking(data));
967 cx.notify();
968 }
969
970 fn handle_tool_use_event(
971 &mut self,
972 tool_use: LanguageModelToolUse,
973 event_stream: &ThreadEventStream,
974 cx: &mut Context<Self>,
975 ) -> Option<Task<LanguageModelToolResult>> {
976 cx.notify();
977
978 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
979 let mut title = SharedString::from(&tool_use.name);
980 let mut kind = acp::ToolKind::Other;
981 if let Some(tool) = tool.as_ref() {
982 title = tool.initial_title(tool_use.input.clone());
983 kind = tool.kind();
984 }
985
986 // Ensure the last message ends in the current tool use
987 let last_message = self.pending_message();
988 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
989 if let AgentMessageContent::ToolUse(last_tool_use) = content {
990 if last_tool_use.id == tool_use.id {
991 *last_tool_use = tool_use.clone();
992 false
993 } else {
994 true
995 }
996 } else {
997 true
998 }
999 });
1000
1001 if push_new_tool_use {
1002 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1003 last_message
1004 .content
1005 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1006 } else {
1007 event_stream.update_tool_call_fields(
1008 &tool_use.id,
1009 acp::ToolCallUpdateFields {
1010 title: Some(title.into()),
1011 kind: Some(kind),
1012 raw_input: Some(tool_use.input.clone()),
1013 ..Default::default()
1014 },
1015 );
1016 }
1017
1018 if !tool_use.is_input_complete {
1019 return None;
1020 }
1021
1022 let Some(tool) = tool else {
1023 let content = format!("No tool named {} exists", tool_use.name);
1024 return Some(Task::ready(LanguageModelToolResult {
1025 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1026 tool_use_id: tool_use.id,
1027 tool_name: tool_use.name,
1028 is_error: true,
1029 output: None,
1030 }));
1031 };
1032
1033 let fs = self.project.read(cx).fs().clone();
1034 let tool_event_stream =
1035 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1036 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1037 status: Some(acp::ToolCallStatus::InProgress),
1038 ..Default::default()
1039 });
1040 let supports_images = self.model.supports_images();
1041 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1042 log::info!("Running tool {}", tool_use.name);
1043 Some(cx.foreground_executor().spawn(async move {
1044 let tool_result = tool_result.await.and_then(|output| {
1045 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1046 if !supports_images {
1047 return Err(anyhow!(
1048 "Attempted to read an image, but this model doesn't support it.",
1049 ));
1050 }
1051 }
1052 Ok(output)
1053 });
1054
1055 match tool_result {
1056 Ok(output) => LanguageModelToolResult {
1057 tool_use_id: tool_use.id,
1058 tool_name: tool_use.name,
1059 is_error: false,
1060 content: output.llm_output,
1061 output: Some(output.raw_output),
1062 },
1063 Err(error) => LanguageModelToolResult {
1064 tool_use_id: tool_use.id,
1065 tool_name: tool_use.name,
1066 is_error: true,
1067 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1068 output: None,
1069 },
1070 }
1071 }))
1072 }
1073
1074 fn handle_tool_use_json_parse_error_event(
1075 &mut self,
1076 tool_use_id: LanguageModelToolUseId,
1077 tool_name: Arc<str>,
1078 raw_input: Arc<str>,
1079 json_parse_error: String,
1080 ) -> LanguageModelToolResult {
1081 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1082 LanguageModelToolResult {
1083 tool_use_id,
1084 tool_name,
1085 is_error: true,
1086 content: LanguageModelToolResultContent::Text(tool_output.into()),
1087 output: Some(serde_json::Value::String(raw_input.to_string())),
1088 }
1089 }
1090
1091 fn pending_message(&mut self) -> &mut AgentMessage {
1092 self.pending_message.get_or_insert_default()
1093 }
1094
1095 fn flush_pending_message(&mut self) {
1096 let Some(mut message) = self.pending_message.take() else {
1097 return;
1098 };
1099
1100 for content in &message.content {
1101 let AgentMessageContent::ToolUse(tool_use) = content else {
1102 continue;
1103 };
1104
1105 if !message.tool_results.contains_key(&tool_use.id) {
1106 message.tool_results.insert(
1107 tool_use.id.clone(),
1108 LanguageModelToolResult {
1109 tool_use_id: tool_use.id.clone(),
1110 tool_name: tool_use.name.clone(),
1111 is_error: true,
1112 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1113 output: None,
1114 },
1115 );
1116 }
1117 }
1118
1119 self.messages.push(Message::Agent(message));
1120 }
1121
1122 pub(crate) fn build_completion_request(
1123 &self,
1124 completion_intent: CompletionIntent,
1125 cx: &mut App,
1126 ) -> LanguageModelRequest {
1127 log::debug!("Building completion request");
1128 log::debug!("Completion intent: {:?}", completion_intent);
1129 log::debug!("Completion mode: {:?}", self.completion_mode);
1130
1131 let messages = self.build_request_messages();
1132 log::info!("Request will include {} messages", messages.len());
1133
1134 let tools = if let Some(tools) = self.tools(cx).log_err() {
1135 tools
1136 .filter_map(|tool| {
1137 let tool_name = tool.name().to_string();
1138 log::trace!("Including tool: {}", tool_name);
1139 Some(LanguageModelRequestTool {
1140 name: tool_name,
1141 description: tool.description().to_string(),
1142 input_schema: tool
1143 .input_schema(self.model.tool_input_format())
1144 .log_err()?,
1145 })
1146 })
1147 .collect()
1148 } else {
1149 Vec::new()
1150 };
1151
1152 log::info!("Request includes {} tools", tools.len());
1153
1154 let request = LanguageModelRequest {
1155 thread_id: Some(self.id.to_string()),
1156 prompt_id: Some(self.prompt_id.to_string()),
1157 intent: Some(completion_intent),
1158 mode: Some(self.completion_mode.into()),
1159 messages,
1160 tools,
1161 tool_choice: None,
1162 stop: Vec::new(),
1163 temperature: AgentSettings::temperature_for_model(self.model(), cx),
1164 thinking_allowed: true,
1165 };
1166
1167 log::debug!("Completion request built successfully");
1168 request
1169 }
1170
1171 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1172 let profile = AgentSettings::get_global(cx)
1173 .profiles
1174 .get(&self.profile_id)
1175 .context("profile not found")?;
1176 let provider_id = self.model.provider_id();
1177
1178 Ok(self
1179 .tools
1180 .iter()
1181 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1182 .filter_map(|(tool_name, tool)| {
1183 if profile.is_tool_enabled(tool_name) {
1184 Some(tool)
1185 } else {
1186 None
1187 }
1188 })
1189 .chain(self.context_server_registry.read(cx).servers().flat_map(
1190 |(server_id, tools)| {
1191 tools.iter().filter_map(|(tool_name, tool)| {
1192 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1193 Some(tool)
1194 } else {
1195 None
1196 }
1197 })
1198 },
1199 )))
1200 }
1201
1202 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1203 log::trace!(
1204 "Building request messages from {} thread messages",
1205 self.messages.len()
1206 );
1207 let mut messages = vec![self.build_system_message()];
1208 for message in &self.messages {
1209 match message {
1210 Message::User(message) => messages.push(message.to_request()),
1211 Message::Agent(message) => messages.extend(message.to_request()),
1212 Message::Resume => messages.push(LanguageModelRequestMessage {
1213 role: Role::User,
1214 content: vec!["Continue where you left off".into()],
1215 cache: false,
1216 }),
1217 }
1218 }
1219
1220 if let Some(message) = self.pending_message.as_ref() {
1221 messages.extend(message.to_request());
1222 }
1223
1224 if let Some(last_user_message) = messages
1225 .iter_mut()
1226 .rev()
1227 .find(|message| message.role == Role::User)
1228 {
1229 last_user_message.cache = true;
1230 }
1231
1232 messages
1233 }
1234
1235 pub fn to_markdown(&self) -> String {
1236 let mut markdown = String::new();
1237 for (ix, message) in self.messages.iter().enumerate() {
1238 if ix > 0 {
1239 markdown.push('\n');
1240 }
1241 markdown.push_str(&message.to_markdown());
1242 }
1243
1244 if let Some(message) = self.pending_message.as_ref() {
1245 markdown.push('\n');
1246 markdown.push_str(&message.to_markdown());
1247 }
1248
1249 markdown
1250 }
1251
1252 fn advance_prompt_id(&mut self) {
1253 self.prompt_id = PromptId::new();
1254 }
1255}
1256
1257struct RunningTurn {
1258 /// Holds the task that handles agent interaction until the end of the turn.
1259 /// Survives across multiple requests as the model performs tool calls and
1260 /// we run tools, report their results.
1261 _task: Task<()>,
1262 /// The current event stream for the running turn. Used to report a final
1263 /// cancellation event if we cancel the turn.
1264 event_stream: ThreadEventStream,
1265}
1266
1267impl RunningTurn {
1268 fn cancel(self) {
1269 log::debug!("Cancelling in progress turn");
1270 self.event_stream.send_canceled();
1271 }
1272}
1273
1274pub trait AgentTool
1275where
1276 Self: 'static + Sized,
1277{
1278 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1279 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1280
1281 fn name(&self) -> SharedString;
1282
1283 fn description(&self) -> SharedString {
1284 let schema = schemars::schema_for!(Self::Input);
1285 SharedString::new(
1286 schema
1287 .get("description")
1288 .and_then(|description| description.as_str())
1289 .unwrap_or_default(),
1290 )
1291 }
1292
1293 fn kind(&self) -> acp::ToolKind;
1294
1295 /// The initial tool title to display. Can be updated during the tool run.
1296 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1297
1298 /// Returns the JSON schema that describes the tool's input.
1299 fn input_schema(&self) -> Schema {
1300 schemars::schema_for!(Self::Input)
1301 }
1302
1303 /// Some tools rely on a provider for the underlying billing or other reasons.
1304 /// Allow the tool to check if they are compatible, or should be filtered out.
1305 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1306 true
1307 }
1308
1309 /// Runs the tool with the provided input.
1310 fn run(
1311 self: Arc<Self>,
1312 input: Self::Input,
1313 event_stream: ToolCallEventStream,
1314 cx: &mut App,
1315 ) -> Task<Result<Self::Output>>;
1316
1317 /// Emits events for a previous execution of the tool.
1318 fn replay(
1319 &self,
1320 _input: Self::Input,
1321 _output: Self::Output,
1322 _event_stream: ToolCallEventStream,
1323 _cx: &mut App,
1324 ) -> Result<()> {
1325 Ok(())
1326 }
1327
1328 fn erase(self) -> Arc<dyn AnyAgentTool> {
1329 Arc::new(Erased(Arc::new(self)))
1330 }
1331}
1332
/// Newtype wrapper around a tool; `Erased<Arc<T>>` is the type that implements
/// [`AnyAgentTool`] for any `T: AgentTool`.
pub struct Erased<T>(T);
1334
1335pub struct AgentToolOutput {
1336 pub llm_output: LanguageModelToolResultContent,
1337 pub raw_output: serde_json::Value,
1338}
1339
1340pub trait AnyAgentTool {
1341 fn name(&self) -> SharedString;
1342 fn description(&self) -> SharedString;
1343 fn kind(&self) -> acp::ToolKind;
1344 fn initial_title(&self, input: serde_json::Value) -> SharedString;
1345 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
1346 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1347 true
1348 }
1349 fn run(
1350 self: Arc<Self>,
1351 input: serde_json::Value,
1352 event_stream: ToolCallEventStream,
1353 cx: &mut App,
1354 ) -> Task<Result<AgentToolOutput>>;
1355 fn replay(
1356 &self,
1357 input: serde_json::Value,
1358 output: serde_json::Value,
1359 event_stream: ToolCallEventStream,
1360 cx: &mut App,
1361 ) -> Result<()>;
1362}
1363
1364impl<T> AnyAgentTool for Erased<Arc<T>>
1365where
1366 T: AgentTool,
1367{
1368 fn name(&self) -> SharedString {
1369 self.0.name()
1370 }
1371
1372 fn description(&self) -> SharedString {
1373 self.0.description()
1374 }
1375
1376 fn kind(&self) -> agent_client_protocol::ToolKind {
1377 self.0.kind()
1378 }
1379
1380 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1381 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1382 self.0.initial_title(parsed_input)
1383 }
1384
1385 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1386 let mut json = serde_json::to_value(self.0.input_schema())?;
1387 adapt_schema_to_format(&mut json, format)?;
1388 Ok(json)
1389 }
1390
1391 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1392 self.0.supported_provider(provider)
1393 }
1394
1395 fn run(
1396 self: Arc<Self>,
1397 input: serde_json::Value,
1398 event_stream: ToolCallEventStream,
1399 cx: &mut App,
1400 ) -> Task<Result<AgentToolOutput>> {
1401 cx.spawn(async move |cx| {
1402 let input = serde_json::from_value(input)?;
1403 let output = cx
1404 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1405 .await?;
1406 let raw_output = serde_json::to_value(&output)?;
1407 Ok(AgentToolOutput {
1408 llm_output: output.into(),
1409 raw_output,
1410 })
1411 })
1412 }
1413
1414 fn replay(
1415 &self,
1416 input: serde_json::Value,
1417 output: serde_json::Value,
1418 event_stream: ToolCallEventStream,
1419 cx: &mut App,
1420 ) -> Result<()> {
1421 let input = serde_json::from_value(input)?;
1422 let output = serde_json::from_value(output)?;
1423 self.0.replay(input, output, event_stream, cx)
1424 }
1425}
1426
1427#[derive(Clone)]
1428struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1429
1430impl ThreadEventStream {
1431 fn send_user_message(&self, message: &UserMessage) {
1432 self.0
1433 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1434 .ok();
1435 }
1436
1437 fn send_text(&self, text: &str) {
1438 self.0
1439 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1440 .ok();
1441 }
1442
1443 fn send_thinking(&self, text: &str) {
1444 self.0
1445 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1446 .ok();
1447 }
1448
1449 fn send_tool_call(
1450 &self,
1451 id: &LanguageModelToolUseId,
1452 title: SharedString,
1453 kind: acp::ToolKind,
1454 input: serde_json::Value,
1455 ) {
1456 self.0
1457 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1458 id,
1459 title.to_string(),
1460 kind,
1461 input,
1462 ))))
1463 .ok();
1464 }
1465
1466 fn initial_tool_call(
1467 id: &LanguageModelToolUseId,
1468 title: String,
1469 kind: acp::ToolKind,
1470 input: serde_json::Value,
1471 ) -> acp::ToolCall {
1472 acp::ToolCall {
1473 id: acp::ToolCallId(id.to_string().into()),
1474 title,
1475 kind,
1476 status: acp::ToolCallStatus::Pending,
1477 content: vec![],
1478 locations: vec![],
1479 raw_input: Some(input),
1480 raw_output: None,
1481 }
1482 }
1483
1484 fn update_tool_call_fields(
1485 &self,
1486 tool_use_id: &LanguageModelToolUseId,
1487 fields: acp::ToolCallUpdateFields,
1488 ) {
1489 self.0
1490 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1491 acp::ToolCallUpdate {
1492 id: acp::ToolCallId(tool_use_id.to_string().into()),
1493 fields,
1494 }
1495 .into(),
1496 )))
1497 .ok();
1498 }
1499
1500 fn send_stop(&self, reason: StopReason) {
1501 match reason {
1502 StopReason::EndTurn => {
1503 self.0
1504 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1505 .ok();
1506 }
1507 StopReason::MaxTokens => {
1508 self.0
1509 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1510 .ok();
1511 }
1512 StopReason::Refusal => {
1513 self.0
1514 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1515 .ok();
1516 }
1517 StopReason::ToolUse => {}
1518 }
1519 }
1520
1521 fn send_canceled(&self) {
1522 self.0
1523 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1524 .ok();
1525 }
1526
1527 fn send_error(&self, error: impl Into<anyhow::Error>) {
1528 self.0.unbounded_send(Err(error.into())).ok();
1529 }
1530}
1531
1532#[derive(Clone)]
1533pub struct ToolCallEventStream {
1534 tool_use_id: LanguageModelToolUseId,
1535 stream: ThreadEventStream,
1536 fs: Option<Arc<dyn Fs>>,
1537}
1538
1539impl ToolCallEventStream {
1540 #[cfg(test)]
1541 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1542 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1543
1544 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1545
1546 (stream, ToolCallEventStreamReceiver(events_rx))
1547 }
1548
1549 fn new(
1550 tool_use_id: LanguageModelToolUseId,
1551 stream: ThreadEventStream,
1552 fs: Option<Arc<dyn Fs>>,
1553 ) -> Self {
1554 Self {
1555 tool_use_id,
1556 stream,
1557 fs,
1558 }
1559 }
1560
1561 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1562 self.stream
1563 .update_tool_call_fields(&self.tool_use_id, fields);
1564 }
1565
1566 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1567 self.stream
1568 .0
1569 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1570 acp_thread::ToolCallUpdateDiff {
1571 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1572 diff,
1573 }
1574 .into(),
1575 )))
1576 .ok();
1577 }
1578
1579 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1580 self.stream
1581 .0
1582 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1583 acp_thread::ToolCallUpdateTerminal {
1584 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1585 terminal,
1586 }
1587 .into(),
1588 )))
1589 .ok();
1590 }
1591
1592 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1593 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1594 return Task::ready(Ok(()));
1595 }
1596
1597 let (response_tx, response_rx) = oneshot::channel();
1598 self.stream
1599 .0
1600 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1601 ToolCallAuthorization {
1602 tool_call: acp::ToolCallUpdate {
1603 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1604 fields: acp::ToolCallUpdateFields {
1605 title: Some(title.into()),
1606 ..Default::default()
1607 },
1608 },
1609 options: vec![
1610 acp::PermissionOption {
1611 id: acp::PermissionOptionId("always_allow".into()),
1612 name: "Always Allow".into(),
1613 kind: acp::PermissionOptionKind::AllowAlways,
1614 },
1615 acp::PermissionOption {
1616 id: acp::PermissionOptionId("allow".into()),
1617 name: "Allow".into(),
1618 kind: acp::PermissionOptionKind::AllowOnce,
1619 },
1620 acp::PermissionOption {
1621 id: acp::PermissionOptionId("deny".into()),
1622 name: "Deny".into(),
1623 kind: acp::PermissionOptionKind::RejectOnce,
1624 },
1625 ],
1626 response: response_tx,
1627 },
1628 )))
1629 .ok();
1630 let fs = self.fs.clone();
1631 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1632 "always_allow" => {
1633 if let Some(fs) = fs.clone() {
1634 cx.update(|cx| {
1635 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1636 settings.set_always_allow_tool_actions(true);
1637 });
1638 })?;
1639 }
1640
1641 Ok(())
1642 }
1643 "allow" => Ok(()),
1644 _ => Err(anyhow!("Permission to run tool denied by user")),
1645 })
1646 }
1647}
1648
/// Test-only receiving half of the channel created by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);

#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it as a tool call authorization.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Waits for the next event and returns the terminal from a terminal update.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1675
/// Gives tests read access to the wrapped receiver.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1684
/// Gives tests mutable access to the wrapped receiver.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1691
1692impl From<&str> for UserMessageContent {
1693 fn from(text: &str) -> Self {
1694 Self::Text(text.into())
1695 }
1696}
1697
1698impl From<acp::ContentBlock> for UserMessageContent {
1699 fn from(value: acp::ContentBlock) -> Self {
1700 match value {
1701 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1702 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1703 acp::ContentBlock::Audio(_) => {
1704 // TODO
1705 Self::Text("[audio]".to_string())
1706 }
1707 acp::ContentBlock::ResourceLink(resource_link) => {
1708 match MentionUri::parse(&resource_link.uri) {
1709 Ok(uri) => Self::Mention {
1710 uri,
1711 content: String::new(),
1712 },
1713 Err(err) => {
1714 log::error!("Failed to parse mention link: {}", err);
1715 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1716 }
1717 }
1718 }
1719 acp::ContentBlock::Resource(resource) => match resource.resource {
1720 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1721 match MentionUri::parse(&resource.uri) {
1722 Ok(uri) => Self::Mention {
1723 uri,
1724 content: resource.text,
1725 },
1726 Err(err) => {
1727 log::error!("Failed to parse mention link: {}", err);
1728 Self::Text(
1729 MarkdownCodeBlock {
1730 tag: &resource.uri,
1731 text: &resource.text,
1732 }
1733 .to_string(),
1734 )
1735 }
1736 }
1737 }
1738 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1739 // TODO
1740 Self::Text("[blob]".to_string())
1741 }
1742 },
1743 }
1744 }
1745}
1746
1747fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1748 LanguageModelImage {
1749 source: image_content.data.into(),
1750 // TODO: make this optional?
1751 size: gpui::Size::new(0.into(), 0.into()),
1752 }
1753}