1use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, AppContext, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
18 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
19 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
20 LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
29use std::{fmt::Write, ops::Range};
30use util::{ResultExt, markdown::MarkdownCodeBlock};
31use uuid::Uuid;
32
/// Text recorded as the (error) result of any tool call that still had no
/// result when the pending agent message was flushed into the history
/// (for example after `Thread::cancel`).
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
34
35#[derive(
36 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
37)]
38pub struct ThreadId(pub(crate) Arc<str>);
39
40impl ThreadId {
41 pub fn new() -> Self {
42 Self(Uuid::new_v4().to_string().into())
43 }
44}
45
46impl std::fmt::Display for ThreadId {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 write!(f, "{}", self.0)
49 }
50}
51
52impl From<&str> for ThreadId {
53 fn from(value: &str) -> Self {
54 Self(value.into())
55 }
56}
57
58impl From<acp::SessionId> for ThreadId {
59 fn from(value: acp::SessionId) -> Self {
60 Self(value.0)
61 }
62}
63
64impl From<ThreadId> for acp::SessionId {
65 fn from(value: ThreadId) -> Self {
66 Self(value.0)
67 }
68}
69
70/// The ID of the user prompt that initiated a request.
71///
72/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
74pub struct PromptId(Arc<str>);
75
76impl PromptId {
77 pub fn new() -> Self {
78 Self(Uuid::new_v4().to_string().into())
79 }
80}
81
82impl std::fmt::Display for PromptId {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}", self.0)
85 }
86}
87
88#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
89pub enum Message {
90 User(UserMessage),
91 Agent(AgentMessage),
92 Resume,
93}
94
95impl Message {
96 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
97 match self {
98 Message::Agent(agent_message) => Some(agent_message),
99 _ => None,
100 }
101 }
102
103 pub fn to_markdown(&self) -> String {
104 match self {
105 Message::User(message) => message.to_markdown(),
106 Message::Agent(message) => message.to_markdown(),
107 Message::Resume => "[resumed after tool use limit was reached]".into(),
108 }
109 }
110}
111
/// A message authored by the user.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifier of this message (used by `Thread::truncate` to locate it).
    pub id: UserMessageId,
    /// The message's pieces, in order.
    pub content: Vec<UserMessageContent>,
}

/// One piece of a user message.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Literal text.
    Text(String),
    /// A mentioned resource (`uri`) plus its loaded `content`. When converted to a
    /// request, the link stays inline and the content goes into a `<context>` section.
    Mention { uri: MentionUri, content: String },
    /// An image.
    Image(LanguageModelImage),
}
124
125impl UserMessage {
126 pub fn to_markdown(&self) -> String {
127 let mut markdown = String::from("## User\n\n");
128
129 for content in &self.content {
130 match content {
131 UserMessageContent::Text(text) => {
132 markdown.push_str(text);
133 markdown.push('\n');
134 }
135 UserMessageContent::Image(_) => {
136 markdown.push_str("<image />\n");
137 }
138 UserMessageContent::Mention { uri, content } => {
139 if !content.is_empty() {
140 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
141 } else {
142 let _ = write!(&mut markdown, "{}\n", uri.as_link());
143 }
144 }
145 }
146 }
147
148 markdown
149 }
150
151 fn to_request(&self) -> LanguageModelRequestMessage {
152 let mut message = LanguageModelRequestMessage {
153 role: Role::User,
154 content: Vec::with_capacity(self.content.len()),
155 cache: false,
156 };
157
158 const OPEN_CONTEXT: &str = "<context>\n\
159 The following items were attached by the user. \
160 They are up-to-date and don't need to be re-read.\n\n";
161
162 const OPEN_FILES_TAG: &str = "<files>";
163 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
164 const OPEN_THREADS_TAG: &str = "<threads>";
165 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
166 const OPEN_RULES_TAG: &str =
167 "<rules>\nThe user has specified the following rules that should be applied:\n";
168
169 let mut file_context = OPEN_FILES_TAG.to_string();
170 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
171 let mut thread_context = OPEN_THREADS_TAG.to_string();
172 let mut fetch_context = OPEN_FETCH_TAG.to_string();
173 let mut rules_context = OPEN_RULES_TAG.to_string();
174
175 for chunk in &self.content {
176 let chunk = match chunk {
177 UserMessageContent::Text(text) => {
178 language_model::MessageContent::Text(text.clone())
179 }
180 UserMessageContent::Image(value) => {
181 language_model::MessageContent::Image(value.clone())
182 }
183 UserMessageContent::Mention { uri, content } => {
184 match uri {
185 MentionUri::File { abs_path, .. } => {
186 write!(
187 &mut symbol_context,
188 "\n{}",
189 MarkdownCodeBlock {
190 tag: &codeblock_tag(&abs_path, None),
191 text: &content.to_string(),
192 }
193 )
194 .ok();
195 }
196 MentionUri::Symbol {
197 path, line_range, ..
198 }
199 | MentionUri::Selection {
200 path, line_range, ..
201 } => {
202 write!(
203 &mut rules_context,
204 "\n{}",
205 MarkdownCodeBlock {
206 tag: &codeblock_tag(&path, Some(line_range)),
207 text: &content
208 }
209 )
210 .ok();
211 }
212 MentionUri::Thread { .. } => {
213 write!(&mut thread_context, "\n{}\n", content).ok();
214 }
215 MentionUri::TextThread { .. } => {
216 write!(&mut thread_context, "\n{}\n", content).ok();
217 }
218 MentionUri::Rule { .. } => {
219 write!(
220 &mut rules_context,
221 "\n{}",
222 MarkdownCodeBlock {
223 tag: "",
224 text: &content
225 }
226 )
227 .ok();
228 }
229 MentionUri::Fetch { url } => {
230 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
231 }
232 }
233
234 language_model::MessageContent::Text(uri.as_link().to_string())
235 }
236 };
237
238 message.content.push(chunk);
239 }
240
241 let len_before_context = message.content.len();
242
243 if file_context.len() > OPEN_FILES_TAG.len() {
244 file_context.push_str("</files>\n");
245 message
246 .content
247 .push(language_model::MessageContent::Text(file_context));
248 }
249
250 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
251 symbol_context.push_str("</symbols>\n");
252 message
253 .content
254 .push(language_model::MessageContent::Text(symbol_context));
255 }
256
257 if thread_context.len() > OPEN_THREADS_TAG.len() {
258 thread_context.push_str("</threads>\n");
259 message
260 .content
261 .push(language_model::MessageContent::Text(thread_context));
262 }
263
264 if fetch_context.len() > OPEN_FETCH_TAG.len() {
265 fetch_context.push_str("</fetched_urls>\n");
266 message
267 .content
268 .push(language_model::MessageContent::Text(fetch_context));
269 }
270
271 if rules_context.len() > OPEN_RULES_TAG.len() {
272 rules_context.push_str("</user_rules>\n");
273 message
274 .content
275 .push(language_model::MessageContent::Text(rules_context));
276 }
277
278 if message.content.len() > len_before_context {
279 message.content.insert(
280 len_before_context,
281 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
282 );
283 message
284 .content
285 .push(language_model::MessageContent::Text("</context>".into()));
286 }
287
288 message
289 }
290}
291
/// Builds the info string of a fenced code block for `full_path`.
///
/// The result is `"<extension> <path>"`. The extension and its trailing space
/// are omitted when the path has no extension or it is not valid UTF-8. It is
/// optionally followed by `:<line>` (when the range is empty) or
/// `:<start>-<end>`, where both bounds of `line_range` are printed plus one.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let extension_prefix = full_path
        .extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| format!("{ext} "))
        .unwrap_or_default();

    let line_suffix = match line_range {
        Some(range) if range.start == range.end => format!(":{}", range.start + 1),
        Some(range) => format!(":{}-{}", range.start + 1, range.end + 1),
        None => String::new(),
    };

    format!("{extension_prefix}{}{line_suffix}", full_path.display())
}
311
312impl AgentMessage {
313 pub fn to_markdown(&self) -> String {
314 let mut markdown = String::from("## Assistant\n\n");
315
316 for content in &self.content {
317 match content {
318 AgentMessageContent::Text(text) => {
319 markdown.push_str(text);
320 markdown.push('\n');
321 }
322 AgentMessageContent::Thinking { text, .. } => {
323 markdown.push_str("<think>");
324 markdown.push_str(text);
325 markdown.push_str("</think>\n");
326 }
327 AgentMessageContent::RedactedThinking(_) => {
328 markdown.push_str("<redacted_thinking />\n")
329 }
330 AgentMessageContent::ToolUse(tool_use) => {
331 markdown.push_str(&format!(
332 "**Tool Use**: {} (ID: {})\n",
333 tool_use.name, tool_use.id
334 ));
335 markdown.push_str(&format!(
336 "{}\n",
337 MarkdownCodeBlock {
338 tag: "json",
339 text: &format!("{:#}", tool_use.input)
340 }
341 ));
342 }
343 }
344 }
345
346 for tool_result in self.tool_results.values() {
347 markdown.push_str(&format!(
348 "**Tool Result**: {} (ID: {})\n\n",
349 tool_result.tool_name, tool_result.tool_use_id
350 ));
351 if tool_result.is_error {
352 markdown.push_str("**ERROR:**\n");
353 }
354
355 match &tool_result.content {
356 LanguageModelToolResultContent::Text(text) => {
357 writeln!(markdown, "{text}\n").ok();
358 }
359 LanguageModelToolResultContent::Image(_) => {
360 writeln!(markdown, "<image />\n").ok();
361 }
362 }
363
364 if let Some(output) = tool_result.output.as_ref() {
365 writeln!(
366 markdown,
367 "**Debug Output**:\n\n```json\n{}\n```\n",
368 serde_json::to_string_pretty(output).unwrap()
369 )
370 .unwrap();
371 }
372 }
373
374 markdown
375 }
376
377 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
378 let mut assistant_message = LanguageModelRequestMessage {
379 role: Role::Assistant,
380 content: Vec::with_capacity(self.content.len()),
381 cache: false,
382 };
383 for chunk in &self.content {
384 let chunk = match chunk {
385 AgentMessageContent::Text(text) => {
386 language_model::MessageContent::Text(text.clone())
387 }
388 AgentMessageContent::Thinking { text, signature } => {
389 language_model::MessageContent::Thinking {
390 text: text.clone(),
391 signature: signature.clone(),
392 }
393 }
394 AgentMessageContent::RedactedThinking(value) => {
395 language_model::MessageContent::RedactedThinking(value.clone())
396 }
397 AgentMessageContent::ToolUse(value) => {
398 language_model::MessageContent::ToolUse(value.clone())
399 }
400 };
401 assistant_message.content.push(chunk);
402 }
403
404 let mut user_message = LanguageModelRequestMessage {
405 role: Role::User,
406 content: Vec::new(),
407 cache: false,
408 };
409
410 for tool_result in self.tool_results.values() {
411 user_message
412 .content
413 .push(language_model::MessageContent::ToolResult(
414 tool_result.clone(),
415 ));
416 }
417
418 let mut messages = Vec::new();
419 if !assistant_message.content.is_empty() {
420 messages.push(assistant_message);
421 }
422 if !user_message.content.is_empty() {
423 messages.push(user_message);
424 }
425 messages
426 }
427}
428
/// A message produced by the model, together with the results of the tool
/// calls it requested.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Content chunks in the order they were received.
    pub content: Vec<AgentMessageContent>,
    /// Tool results keyed by the id of the tool use they answer (insertion-ordered).
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}

/// One chunk of an agent message.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Response text.
    Text(String),
    /// Thinking text, with the signature reported alongside it, if any.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking data, passed back to the model unchanged in `to_request`.
    RedactedThinking(String),
    /// A tool call requested by the model.
    ToolUse(LanguageModelToolUse),
}
445
/// An event reported on the channel returned by `Thread::send`, `Thread::resume`
/// or `Thread::replay`.
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message (e.g. re-emitted by `Thread::replay`).
    UserMessage(UserMessage),
    /// A chunk of agent response text.
    AgentText(String),
    /// A chunk of agent thinking text.
    AgentThinking(String),
    /// A newly announced tool call.
    ToolCall(acp::ToolCall),
    /// An update to a previously announced tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request to authorize a tool call; see [`ToolCallAuthorization`].
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped with the given reason.
    Stop(acp::StopReason),
}

/// A permission request for a tool call.
///
/// NOTE(review): presumably the consumer picks one of `options` and sends its id
/// back through `response` — confirm against the code awaiting `response`.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call the request is about.
    pub tool_call: acp::ToolCallUpdate,
    /// The permission options offered.
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the chosen option id is delivered.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
463
/// An agent conversation: its message history, the model and tools it uses,
/// and the currently running turn (if any).
pub struct Thread {
    /// Identifier of this thread (sent as `thread_id` in completion requests).
    id: ThreadId,
    /// Identifier of the current user prompt; advanced by `send`.
    prompt_id: PromptId,
    /// Committed conversation history.
    messages: Vec<Message>,
    /// Completion mode forwarded in completion requests.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<RunningTurn>,
    /// Agent message currently being streamed; moved into `messages` by
    /// `flush_pending_message`.
    pending_message: Option<AgentMessage>,
    /// Registered tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when the last turn stopped at the tool use limit; `resume` requires it.
    tool_use_limit_reached: bool,
    /// Source of context-server tools, which `tools` chains after the built-in ones.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Agent profile deciding which tools are enabled.
    profile_id: AgentProfileId,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Language model used for completions.
    model: Arc<dyn LanguageModel>,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Action log associated with this thread.
    action_log: Entity<ActionLog>,
}
484
485impl Thread {
486 pub fn new(
487 project: Entity<Project>,
488 project_context: Rc<RefCell<ProjectContext>>,
489 context_server_registry: Entity<ContextServerRegistry>,
490 action_log: Entity<ActionLog>,
491 templates: Arc<Templates>,
492 model: Arc<dyn LanguageModel>,
493 cx: &mut Context<Self>,
494 ) -> Self {
495 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
496 Self {
497 id: ThreadId::new(),
498 prompt_id: PromptId::new(),
499 messages: Vec::new(),
500 completion_mode: CompletionMode::Normal,
501 running_turn: None,
502 pending_message: None,
503 tools: BTreeMap::default(),
504 tool_use_limit_reached: false,
505 context_server_registry,
506 profile_id,
507 project_context,
508 templates,
509 model,
510 project,
511 action_log,
512 }
513 }
514
515 pub fn from_db(
516 id: ThreadId,
517 db_thread: DbThread,
518 project: Entity<Project>,
519 project_context: Rc<RefCell<ProjectContext>>,
520 context_server_registry: Entity<ContextServerRegistry>,
521 action_log: Entity<ActionLog>,
522 templates: Arc<Templates>,
523 model: Arc<dyn LanguageModel>,
524 cx: &mut Context<Self>,
525 ) -> Self {
526 let profile_id = db_thread
527 .profile
528 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
529 Self {
530 id,
531 prompt_id: PromptId::new(),
532 messages: db_thread.messages,
533 completion_mode: CompletionMode::Normal,
534 running_turn: None,
535 pending_message: None,
536 tools: BTreeMap::default(),
537 tool_use_limit_reached: false,
538 context_server_registry,
539 profile_id,
540 project_context,
541 templates,
542 model,
543 project,
544 action_log,
545 }
546 }
547
548 pub fn replay(
549 &mut self,
550 cx: &mut Context<Self>,
551 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
552 let (tx, rx) = mpsc::unbounded();
553 let stream = ThreadEventStream(tx);
554 for message in &self.messages {
555 match message {
556 Message::User(user_message) => stream.send_user_message(&user_message),
557 Message::Agent(assistant_message) => {
558 for content in &assistant_message.content {
559 match content {
560 AgentMessageContent::Text(text) => stream.send_text(text),
561 AgentMessageContent::Thinking { text, .. } => {
562 stream.send_thinking(text)
563 }
564 AgentMessageContent::RedactedThinking(_) => {}
565 AgentMessageContent::ToolUse(tool_use) => {
566 self.replay_tool_call(
567 tool_use,
568 assistant_message.tool_results.get(&tool_use.id),
569 &stream,
570 cx,
571 );
572 }
573 }
574 }
575 }
576 Message::Resume => {}
577 }
578 }
579 rx
580 }
581
582 fn replay_tool_call(
583 &self,
584 tool_use: &LanguageModelToolUse,
585 tool_result: Option<&LanguageModelToolResult>,
586 stream: &ThreadEventStream,
587 cx: &mut Context<Self>,
588 ) {
589 let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
590 stream
591 .0
592 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
593 id: acp::ToolCallId(tool_use.id.to_string().into()),
594 title: tool_use.name.to_string(),
595 kind: acp::ToolKind::Other,
596 status: acp::ToolCallStatus::Failed,
597 content: Vec::new(),
598 locations: Vec::new(),
599 raw_input: Some(tool_use.input.clone()),
600 raw_output: None,
601 })))
602 .ok();
603 return;
604 };
605
606 let title = tool.initial_title(tool_use.input.clone());
607 let kind = tool.kind();
608 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
609
610 let output = tool_result
611 .as_ref()
612 .and_then(|result| result.output.clone());
613 if let Some(output) = output.clone() {
614 let tool_event_stream = ToolCallEventStream::new(
615 tool_use.id.clone(),
616 stream.clone(),
617 Some(self.project.read(cx).fs().clone()),
618 );
619 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
620 .log_err();
621 }
622
623 stream.update_tool_call_fields(
624 &tool_use.id,
625 acp::ToolCallUpdateFields {
626 status: Some(acp::ToolCallStatus::Completed),
627 raw_output: output,
628 ..Default::default()
629 },
630 );
631 }
632
633 pub fn project(&self) -> &Entity<Project> {
634 &self.project
635 }
636
637 pub fn action_log(&self) -> &Entity<ActionLog> {
638 &self.action_log
639 }
640
641 pub fn model(&self) -> &Arc<dyn LanguageModel> {
642 &self.model
643 }
644
645 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
646 self.model = model;
647 }
648
649 pub fn completion_mode(&self) -> CompletionMode {
650 self.completion_mode
651 }
652
653 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
654 self.completion_mode = mode;
655 }
656
657 #[cfg(any(test, feature = "test-support"))]
658 pub fn last_message(&self) -> Option<Message> {
659 if let Some(message) = self.pending_message.clone() {
660 Some(Message::Agent(message))
661 } else {
662 self.messages.last().cloned()
663 }
664 }
665
666 pub fn add_tool(&mut self, tool: impl AgentTool) {
667 self.tools.insert(tool.name(), tool.erase());
668 }
669
670 pub fn remove_tool(&mut self, name: &str) -> bool {
671 self.tools.remove(name).is_some()
672 }
673
674 pub fn profile(&self) -> &AgentProfileId {
675 &self.profile_id
676 }
677
678 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
679 self.profile_id = profile_id;
680 }
681
682 pub fn cancel(&mut self) {
683 if let Some(running_turn) = self.running_turn.take() {
684 running_turn.cancel();
685 }
686 self.flush_pending_message();
687 }
688
689 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
690 self.cancel();
691 let Some(position) = self.messages.iter().position(
692 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
693 ) else {
694 return Err(anyhow!("Message not found"));
695 };
696 self.messages.truncate(position);
697 Ok(())
698 }
699
700 pub fn resume(
701 &mut self,
702 cx: &mut Context<Self>,
703 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
704 anyhow::ensure!(
705 self.tool_use_limit_reached,
706 "can only resume after tool use limit is reached"
707 );
708
709 self.messages.push(Message::Resume);
710 cx.notify();
711
712 log::info!("Total messages in thread: {}", self.messages.len());
713 Ok(self.run_turn(cx))
714 }
715
716 /// Sending a message results in the model streaming a response, which could include tool calls.
717 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
718 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
719 pub fn send<T>(
720 &mut self,
721 id: UserMessageId,
722 content: impl IntoIterator<Item = T>,
723 cx: &mut Context<Self>,
724 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
725 where
726 T: Into<UserMessageContent>,
727 {
728 log::info!("Thread::send called with model: {:?}", self.model.name());
729 self.advance_prompt_id();
730
731 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
732 log::debug!("Thread::send content: {:?}", content);
733
734 self.messages
735 .push(Message::User(UserMessage { id, content }));
736 cx.notify();
737
738 log::info!("Total messages in thread: {}", self.messages.len());
739 self.run_turn(cx)
740 }
741
742 fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
743 self.cancel();
744
745 let model = self.model.clone();
746 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
747 let event_stream = ThreadEventStream(events_tx);
748 let message_ix = self.messages.len().saturating_sub(1);
749 self.tool_use_limit_reached = false;
750 self.running_turn = Some(RunningTurn {
751 event_stream: event_stream.clone(),
752 _task: cx.spawn(async move |this, cx| {
753 log::info!("Starting agent turn execution");
754 let turn_result: Result<()> = async {
755 let mut completion_intent = CompletionIntent::UserPrompt;
756 loop {
757 log::debug!(
758 "Building completion request with intent: {:?}",
759 completion_intent
760 );
761 let request = this.update(cx, |this, cx| {
762 this.build_completion_request(completion_intent, cx)
763 })?;
764
765 log::info!("Calling model.stream_completion");
766 let mut events = model.stream_completion(request, cx).await?;
767 log::debug!("Stream completion started successfully");
768
769 let mut tool_use_limit_reached = false;
770 let mut tool_uses = FuturesUnordered::new();
771 while let Some(event) = events.next().await {
772 match event? {
773 LanguageModelCompletionEvent::StatusUpdate(
774 CompletionRequestStatus::ToolUseLimitReached,
775 ) => {
776 tool_use_limit_reached = true;
777 }
778 LanguageModelCompletionEvent::Stop(reason) => {
779 event_stream.send_stop(reason);
780 if reason == StopReason::Refusal {
781 this.update(cx, |this, _cx| {
782 this.flush_pending_message();
783 this.messages.truncate(message_ix);
784 })?;
785 return Ok(());
786 }
787 }
788 event => {
789 log::trace!("Received completion event: {:?}", event);
790 this.update(cx, |this, cx| {
791 tool_uses.extend(this.handle_streamed_completion_event(
792 event,
793 &event_stream,
794 cx,
795 ));
796 })
797 .ok();
798 }
799 }
800 }
801
802 let used_tools = tool_uses.is_empty();
803 while let Some(tool_result) = tool_uses.next().await {
804 log::info!("Tool finished {:?}", tool_result);
805
806 event_stream.update_tool_call_fields(
807 &tool_result.tool_use_id,
808 acp::ToolCallUpdateFields {
809 status: Some(if tool_result.is_error {
810 acp::ToolCallStatus::Failed
811 } else {
812 acp::ToolCallStatus::Completed
813 }),
814 raw_output: tool_result.output.clone(),
815 ..Default::default()
816 },
817 );
818 this.update(cx, |this, _cx| {
819 this.pending_message()
820 .tool_results
821 .insert(tool_result.tool_use_id.clone(), tool_result);
822 })
823 .ok();
824 }
825
826 if tool_use_limit_reached {
827 log::info!("Tool use limit reached, completing turn");
828 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
829 return Err(language_model::ToolUseLimitReachedError.into());
830 } else if used_tools {
831 log::info!("No tool uses found, completing turn");
832 return Ok(());
833 } else {
834 this.update(cx, |this, _| this.flush_pending_message())?;
835 completion_intent = CompletionIntent::ToolResults;
836 }
837 }
838 }
839 .await;
840
841 if let Err(error) = turn_result {
842 log::error!("Turn execution failed: {:?}", error);
843 event_stream.send_error(error);
844 } else {
845 log::info!("Turn execution completed successfully");
846 }
847
848 this.update(cx, |this, _| {
849 this.flush_pending_message();
850 this.running_turn.take();
851 })
852 .ok();
853 }),
854 });
855 events_rx
856 }
857
858 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
859 log::debug!("Building system message");
860 let prompt = SystemPromptTemplate {
861 project: &self.project_context.borrow(),
862 available_tools: self.tools.keys().cloned().collect(),
863 }
864 .render(&self.templates)
865 .context("failed to build system prompt")
866 .expect("Invalid template");
867 log::debug!("System message built");
868 LanguageModelRequestMessage {
869 role: Role::System,
870 content: vec![prompt.into()],
871 cache: true,
872 }
873 }
874
875 /// A helper method that's called on every streamed completion event.
876 /// Returns an optional tool result task, which the main agentic loop in
877 /// send will send back to the model when it resolves.
878 fn handle_streamed_completion_event(
879 &mut self,
880 event: LanguageModelCompletionEvent,
881 event_stream: &ThreadEventStream,
882 cx: &mut Context<Self>,
883 ) -> Option<Task<LanguageModelToolResult>> {
884 log::trace!("Handling streamed completion event: {:?}", event);
885 use LanguageModelCompletionEvent::*;
886
887 match event {
888 StartMessage { .. } => {
889 self.flush_pending_message();
890 self.pending_message = Some(AgentMessage::default());
891 }
892 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
893 Thinking { text, signature } => {
894 self.handle_thinking_event(text, signature, event_stream, cx)
895 }
896 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
897 ToolUse(tool_use) => {
898 return self.handle_tool_use_event(tool_use, event_stream, cx);
899 }
900 ToolUseJsonParseError {
901 id,
902 tool_name,
903 raw_input,
904 json_parse_error,
905 } => {
906 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
907 id,
908 tool_name,
909 raw_input,
910 json_parse_error,
911 )));
912 }
913 UsageUpdate(_) | StatusUpdate(_) => {}
914 Stop(_) => unreachable!(),
915 }
916
917 None
918 }
919
920 fn handle_text_event(
921 &mut self,
922 new_text: String,
923 event_stream: &ThreadEventStream,
924 cx: &mut Context<Self>,
925 ) {
926 event_stream.send_text(&new_text);
927
928 let last_message = self.pending_message();
929 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
930 text.push_str(&new_text);
931 } else {
932 last_message
933 .content
934 .push(AgentMessageContent::Text(new_text));
935 }
936
937 cx.notify();
938 }
939
940 fn handle_thinking_event(
941 &mut self,
942 new_text: String,
943 new_signature: Option<String>,
944 event_stream: &ThreadEventStream,
945 cx: &mut Context<Self>,
946 ) {
947 event_stream.send_thinking(&new_text);
948
949 let last_message = self.pending_message();
950 if let Some(AgentMessageContent::Thinking { text, signature }) =
951 last_message.content.last_mut()
952 {
953 text.push_str(&new_text);
954 *signature = new_signature.or(signature.take());
955 } else {
956 last_message.content.push(AgentMessageContent::Thinking {
957 text: new_text,
958 signature: new_signature,
959 });
960 }
961
962 cx.notify();
963 }
964
965 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
966 let last_message = self.pending_message();
967 last_message
968 .content
969 .push(AgentMessageContent::RedactedThinking(data));
970 cx.notify();
971 }
972
973 fn handle_tool_use_event(
974 &mut self,
975 tool_use: LanguageModelToolUse,
976 event_stream: &ThreadEventStream,
977 cx: &mut Context<Self>,
978 ) -> Option<Task<LanguageModelToolResult>> {
979 cx.notify();
980
981 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
982 let mut title = SharedString::from(&tool_use.name);
983 let mut kind = acp::ToolKind::Other;
984 if let Some(tool) = tool.as_ref() {
985 title = tool.initial_title(tool_use.input.clone());
986 kind = tool.kind();
987 }
988
989 // Ensure the last message ends in the current tool use
990 let last_message = self.pending_message();
991 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
992 if let AgentMessageContent::ToolUse(last_tool_use) = content {
993 if last_tool_use.id == tool_use.id {
994 *last_tool_use = tool_use.clone();
995 false
996 } else {
997 true
998 }
999 } else {
1000 true
1001 }
1002 });
1003
1004 if push_new_tool_use {
1005 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1006 last_message
1007 .content
1008 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1009 } else {
1010 event_stream.update_tool_call_fields(
1011 &tool_use.id,
1012 acp::ToolCallUpdateFields {
1013 title: Some(title.into()),
1014 kind: Some(kind),
1015 raw_input: Some(tool_use.input.clone()),
1016 ..Default::default()
1017 },
1018 );
1019 }
1020
1021 if !tool_use.is_input_complete {
1022 return None;
1023 }
1024
1025 let Some(tool) = tool else {
1026 let content = format!("No tool named {} exists", tool_use.name);
1027 return Some(Task::ready(LanguageModelToolResult {
1028 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1029 tool_use_id: tool_use.id,
1030 tool_name: tool_use.name,
1031 is_error: true,
1032 output: None,
1033 }));
1034 };
1035
1036 let fs = self.project.read(cx).fs().clone();
1037 let tool_event_stream =
1038 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1039 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1040 status: Some(acp::ToolCallStatus::InProgress),
1041 ..Default::default()
1042 });
1043 let supports_images = self.model.supports_images();
1044 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1045 log::info!("Running tool {}", tool_use.name);
1046 Some(cx.foreground_executor().spawn(async move {
1047 let tool_result = tool_result.await.and_then(|output| {
1048 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1049 if !supports_images {
1050 return Err(anyhow!(
1051 "Attempted to read an image, but this model doesn't support it.",
1052 ));
1053 }
1054 }
1055 Ok(output)
1056 });
1057
1058 match tool_result {
1059 Ok(output) => LanguageModelToolResult {
1060 tool_use_id: tool_use.id,
1061 tool_name: tool_use.name,
1062 is_error: false,
1063 content: output.llm_output,
1064 output: Some(output.raw_output),
1065 },
1066 Err(error) => LanguageModelToolResult {
1067 tool_use_id: tool_use.id,
1068 tool_name: tool_use.name,
1069 is_error: true,
1070 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1071 output: None,
1072 },
1073 }
1074 }))
1075 }
1076
1077 fn handle_tool_use_json_parse_error_event(
1078 &mut self,
1079 tool_use_id: LanguageModelToolUseId,
1080 tool_name: Arc<str>,
1081 raw_input: Arc<str>,
1082 json_parse_error: String,
1083 ) -> LanguageModelToolResult {
1084 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1085 LanguageModelToolResult {
1086 tool_use_id,
1087 tool_name,
1088 is_error: true,
1089 content: LanguageModelToolResultContent::Text(tool_output.into()),
1090 output: Some(serde_json::Value::String(raw_input.to_string())),
1091 }
1092 }
1093
1094 fn pending_message(&mut self) -> &mut AgentMessage {
1095 self.pending_message.get_or_insert_default()
1096 }
1097
1098 fn flush_pending_message(&mut self) {
1099 let Some(mut message) = self.pending_message.take() else {
1100 return;
1101 };
1102
1103 for content in &message.content {
1104 let AgentMessageContent::ToolUse(tool_use) = content else {
1105 continue;
1106 };
1107
1108 if !message.tool_results.contains_key(&tool_use.id) {
1109 message.tool_results.insert(
1110 tool_use.id.clone(),
1111 LanguageModelToolResult {
1112 tool_use_id: tool_use.id.clone(),
1113 tool_name: tool_use.name.clone(),
1114 is_error: true,
1115 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1116 output: None,
1117 },
1118 );
1119 }
1120 }
1121
1122 self.messages.push(Message::Agent(message));
1123 }
1124
1125 pub(crate) fn build_completion_request(
1126 &self,
1127 completion_intent: CompletionIntent,
1128 cx: &mut App,
1129 ) -> LanguageModelRequest {
1130 log::debug!("Building completion request");
1131 log::debug!("Completion intent: {:?}", completion_intent);
1132 log::debug!("Completion mode: {:?}", self.completion_mode);
1133
1134 let messages = self.build_request_messages();
1135 log::info!("Request will include {} messages", messages.len());
1136
1137 let tools = if let Some(tools) = self.tools(cx).log_err() {
1138 tools
1139 .filter_map(|tool| {
1140 let tool_name = tool.name().to_string();
1141 log::trace!("Including tool: {}", tool_name);
1142 Some(LanguageModelRequestTool {
1143 name: tool_name,
1144 description: tool.description().to_string(),
1145 input_schema: tool
1146 .input_schema(self.model.tool_input_format())
1147 .log_err()?,
1148 })
1149 })
1150 .collect()
1151 } else {
1152 Vec::new()
1153 };
1154
1155 log::info!("Request includes {} tools", tools.len());
1156
1157 let request = LanguageModelRequest {
1158 thread_id: Some(self.id.to_string()),
1159 prompt_id: Some(self.prompt_id.to_string()),
1160 intent: Some(completion_intent),
1161 mode: Some(self.completion_mode.into()),
1162 messages,
1163 tools,
1164 tool_choice: None,
1165 stop: Vec::new(),
1166 temperature: AgentSettings::temperature_for_model(self.model(), cx),
1167 thinking_allowed: true,
1168 };
1169
1170 log::debug!("Completion request built successfully");
1171 request
1172 }
1173
1174 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1175 let profile = AgentSettings::get_global(cx)
1176 .profiles
1177 .get(&self.profile_id)
1178 .context("profile not found")?;
1179 let provider_id = self.model.provider_id();
1180
1181 Ok(self
1182 .tools
1183 .iter()
1184 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1185 .filter_map(|(tool_name, tool)| {
1186 if profile.is_tool_enabled(tool_name) {
1187 Some(tool)
1188 } else {
1189 None
1190 }
1191 })
1192 .chain(self.context_server_registry.read(cx).servers().flat_map(
1193 |(server_id, tools)| {
1194 tools.iter().filter_map(|(tool_name, tool)| {
1195 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1196 Some(tool)
1197 } else {
1198 None
1199 }
1200 })
1201 },
1202 )))
1203 }
1204
1205 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1206 log::trace!(
1207 "Building request messages from {} thread messages",
1208 self.messages.len()
1209 );
1210 let mut messages = vec![self.build_system_message()];
1211 for message in &self.messages {
1212 match message {
1213 Message::User(message) => messages.push(message.to_request()),
1214 Message::Agent(message) => messages.extend(message.to_request()),
1215 Message::Resume => messages.push(LanguageModelRequestMessage {
1216 role: Role::User,
1217 content: vec!["Continue where you left off".into()],
1218 cache: false,
1219 }),
1220 }
1221 }
1222
1223 if let Some(message) = self.pending_message.as_ref() {
1224 messages.extend(message.to_request());
1225 }
1226
1227 if let Some(last_user_message) = messages
1228 .iter_mut()
1229 .rev()
1230 .find(|message| message.role == Role::User)
1231 {
1232 last_user_message.cache = true;
1233 }
1234
1235 messages
1236 }
1237
1238 pub fn to_markdown(&self) -> String {
1239 let mut markdown = String::new();
1240 for (ix, message) in self.messages.iter().enumerate() {
1241 if ix > 0 {
1242 markdown.push('\n');
1243 }
1244 markdown.push_str(&message.to_markdown());
1245 }
1246
1247 if let Some(message) = self.pending_message.as_ref() {
1248 markdown.push('\n');
1249 markdown.push_str(&message.to_markdown());
1250 }
1251
1252 markdown
1253 }
1254
1255 fn advance_prompt_id(&mut self) {
1256 self.prompt_id = PromptId::new();
1257 }
1258}
1259
/// State kept for a turn that is currently in progress.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1269
1270impl RunningTurn {
1271 fn cancel(self) {
1272 log::debug!("Cancelling in progress turn");
1273 self.event_stream.send_canceled();
1274 }
1275}
1276
1277pub trait AgentTool
1278where
1279 Self: 'static + Sized,
1280{
1281 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1282 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1283
1284 fn name(&self) -> SharedString;
1285
1286 fn description(&self) -> SharedString {
1287 let schema = schemars::schema_for!(Self::Input);
1288 SharedString::new(
1289 schema
1290 .get("description")
1291 .and_then(|description| description.as_str())
1292 .unwrap_or_default(),
1293 )
1294 }
1295
1296 fn kind(&self) -> acp::ToolKind;
1297
1298 /// The initial tool title to display. Can be updated during the tool run.
1299 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1300
1301 /// Returns the JSON schema that describes the tool's input.
1302 fn input_schema(&self) -> Schema {
1303 schemars::schema_for!(Self::Input)
1304 }
1305
1306 /// Some tools rely on a provider for the underlying billing or other reasons.
1307 /// Allow the tool to check if they are compatible, or should be filtered out.
1308 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1309 true
1310 }
1311
1312 /// Runs the tool with the provided input.
1313 fn run(
1314 self: Arc<Self>,
1315 input: Self::Input,
1316 event_stream: ToolCallEventStream,
1317 cx: &mut App,
1318 ) -> Task<Result<Self::Output>>;
1319
1320 /// Emits events for a previous execution of the tool.
1321 fn replay(
1322 &self,
1323 _input: Self::Input,
1324 _output: Self::Output,
1325 _event_stream: ToolCallEventStream,
1326 _cx: &mut App,
1327 ) -> Result<()> {
1328 Ok(())
1329 }
1330
1331 fn erase(self) -> Arc<dyn AnyAgentTool> {
1332 Arc::new(Erased(Arc::new(self)))
1333 }
1334}
1335
/// Wrapper through which a typed [`AgentTool`] (held as `Arc<T>`) implements the
/// JSON-based, object-safe [`AnyAgentTool`] interface.
pub struct Erased<T>(T);
1337
/// The result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content sent back to the model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON, kept as the result's raw output.
    pub raw_output: serde_json::Value,
}
1342
1343pub trait AnyAgentTool {
1344 fn name(&self) -> SharedString;
1345 fn description(&self) -> SharedString;
1346 fn kind(&self) -> acp::ToolKind;
1347 fn initial_title(&self, input: serde_json::Value) -> SharedString;
1348 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
1349 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1350 true
1351 }
1352 fn run(
1353 self: Arc<Self>,
1354 input: serde_json::Value,
1355 event_stream: ToolCallEventStream,
1356 cx: &mut App,
1357 ) -> Task<Result<AgentToolOutput>>;
1358 fn replay(
1359 &self,
1360 input: serde_json::Value,
1361 output: serde_json::Value,
1362 event_stream: ToolCallEventStream,
1363 cx: &mut App,
1364 ) -> Result<()>;
1365}
1366
1367impl<T> AnyAgentTool for Erased<Arc<T>>
1368where
1369 T: AgentTool,
1370{
1371 fn name(&self) -> SharedString {
1372 self.0.name()
1373 }
1374
1375 fn description(&self) -> SharedString {
1376 self.0.description()
1377 }
1378
1379 fn kind(&self) -> agent_client_protocol::ToolKind {
1380 self.0.kind()
1381 }
1382
1383 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1384 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1385 self.0.initial_title(parsed_input)
1386 }
1387
1388 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1389 let mut json = serde_json::to_value(self.0.input_schema())?;
1390 adapt_schema_to_format(&mut json, format)?;
1391 Ok(json)
1392 }
1393
1394 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1395 self.0.supported_provider(provider)
1396 }
1397
1398 fn run(
1399 self: Arc<Self>,
1400 input: serde_json::Value,
1401 event_stream: ToolCallEventStream,
1402 cx: &mut App,
1403 ) -> Task<Result<AgentToolOutput>> {
1404 cx.spawn(async move |cx| {
1405 let input = serde_json::from_value(input)?;
1406 let output = cx
1407 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1408 .await?;
1409 let raw_output = serde_json::to_value(&output)?;
1410 Ok(AgentToolOutput {
1411 llm_output: output.into(),
1412 raw_output,
1413 })
1414 })
1415 }
1416
1417 fn replay(
1418 &self,
1419 input: serde_json::Value,
1420 output: serde_json::Value,
1421 event_stream: ToolCallEventStream,
1422 cx: &mut App,
1423 ) -> Result<()> {
1424 let input = serde_json::from_value(input)?;
1425 let output = serde_json::from_value(output)?;
1426 self.0.replay(input, output, event_stream, cx)
1427 }
1428}
1429
/// Sending half of a thread's event channel. Successful events are sent as
/// `Ok(ThreadEvent)`; failures are delivered as `Err`.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1432
1433impl ThreadEventStream {
1434 fn send_user_message(&self, message: &UserMessage) {
1435 self.0
1436 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1437 .ok();
1438 }
1439
1440 fn send_text(&self, text: &str) {
1441 self.0
1442 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1443 .ok();
1444 }
1445
1446 fn send_thinking(&self, text: &str) {
1447 self.0
1448 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1449 .ok();
1450 }
1451
1452 fn send_tool_call(
1453 &self,
1454 id: &LanguageModelToolUseId,
1455 title: SharedString,
1456 kind: acp::ToolKind,
1457 input: serde_json::Value,
1458 ) {
1459 self.0
1460 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1461 id,
1462 title.to_string(),
1463 kind,
1464 input,
1465 ))))
1466 .ok();
1467 }
1468
1469 fn initial_tool_call(
1470 id: &LanguageModelToolUseId,
1471 title: String,
1472 kind: acp::ToolKind,
1473 input: serde_json::Value,
1474 ) -> acp::ToolCall {
1475 acp::ToolCall {
1476 id: acp::ToolCallId(id.to_string().into()),
1477 title,
1478 kind,
1479 status: acp::ToolCallStatus::Pending,
1480 content: vec![],
1481 locations: vec![],
1482 raw_input: Some(input),
1483 raw_output: None,
1484 }
1485 }
1486
1487 fn update_tool_call_fields(
1488 &self,
1489 tool_use_id: &LanguageModelToolUseId,
1490 fields: acp::ToolCallUpdateFields,
1491 ) {
1492 self.0
1493 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1494 acp::ToolCallUpdate {
1495 id: acp::ToolCallId(tool_use_id.to_string().into()),
1496 fields,
1497 }
1498 .into(),
1499 )))
1500 .ok();
1501 }
1502
1503 fn send_stop(&self, reason: StopReason) {
1504 match reason {
1505 StopReason::EndTurn => {
1506 self.0
1507 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1508 .ok();
1509 }
1510 StopReason::MaxTokens => {
1511 self.0
1512 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1513 .ok();
1514 }
1515 StopReason::Refusal => {
1516 self.0
1517 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1518 .ok();
1519 }
1520 StopReason::ToolUse => {}
1521 }
1522 }
1523
1524 fn send_canceled(&self) {
1525 self.0
1526 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1527 .ok();
1528 }
1529
1530 fn send_error(&self, error: impl Into<anyhow::Error>) {
1531 self.0.unbounded_send(Err(error.into())).ok();
1532 }
1533}
1534
/// Handle given to a running tool for reporting progress on its own tool call
/// and for requesting user authorization.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// The tool use this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    /// The thread's event channel that updates are forwarded to.
    stream: ThreadEventStream,
    /// When present, used to persist an "Always Allow" answer to the settings file.
    fs: Option<Arc<dyn Fs>>,
}
1541
1542impl ToolCallEventStream {
1543 #[cfg(test)]
1544 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1545 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1546
1547 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1548
1549 (stream, ToolCallEventStreamReceiver(events_rx))
1550 }
1551
1552 fn new(
1553 tool_use_id: LanguageModelToolUseId,
1554 stream: ThreadEventStream,
1555 fs: Option<Arc<dyn Fs>>,
1556 ) -> Self {
1557 Self {
1558 tool_use_id,
1559 stream,
1560 fs,
1561 }
1562 }
1563
1564 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1565 self.stream
1566 .update_tool_call_fields(&self.tool_use_id, fields);
1567 }
1568
1569 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1570 self.stream
1571 .0
1572 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1573 acp_thread::ToolCallUpdateDiff {
1574 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1575 diff,
1576 }
1577 .into(),
1578 )))
1579 .ok();
1580 }
1581
1582 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1583 self.stream
1584 .0
1585 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1586 acp_thread::ToolCallUpdateTerminal {
1587 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1588 terminal,
1589 }
1590 .into(),
1591 )))
1592 .ok();
1593 }
1594
1595 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1596 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1597 return Task::ready(Ok(()));
1598 }
1599
1600 let (response_tx, response_rx) = oneshot::channel();
1601 self.stream
1602 .0
1603 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1604 ToolCallAuthorization {
1605 tool_call: acp::ToolCallUpdate {
1606 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1607 fields: acp::ToolCallUpdateFields {
1608 title: Some(title.into()),
1609 ..Default::default()
1610 },
1611 },
1612 options: vec![
1613 acp::PermissionOption {
1614 id: acp::PermissionOptionId("always_allow".into()),
1615 name: "Always Allow".into(),
1616 kind: acp::PermissionOptionKind::AllowAlways,
1617 },
1618 acp::PermissionOption {
1619 id: acp::PermissionOptionId("allow".into()),
1620 name: "Allow".into(),
1621 kind: acp::PermissionOptionKind::AllowOnce,
1622 },
1623 acp::PermissionOption {
1624 id: acp::PermissionOptionId("deny".into()),
1625 name: "Deny".into(),
1626 kind: acp::PermissionOptionKind::RejectOnce,
1627 },
1628 ],
1629 response: response_tx,
1630 },
1631 )))
1632 .ok();
1633 let fs = self.fs.clone();
1634 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1635 "always_allow" => {
1636 if let Some(fs) = fs.clone() {
1637 cx.update(|cx| {
1638 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1639 settings.set_always_allow_tool_actions(true);
1640 });
1641 })?;
1642 }
1643
1644 Ok(())
1645 }
1646 "allow" => Ok(()),
1647 _ => Err(anyhow!("Permission to run tool denied by user")),
1648 })
1649 }
1650}
1651
/// Test-only receiving end of the channel created by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
1654
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as a tool-call authorization request.
    ///
    /// # Panics
    /// Panics if the next event is anything else, including an error or a closed channel.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(authorization))) => authorization,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal from a terminal update.
    ///
    /// # Panics
    /// Panics if the next event is anything else, including an error or a closed channel.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(terminal_update),
            ))) => terminal_update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1678
/// Exposes the underlying receiver so tests can use its stream methods directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1687
/// Mutable access to the underlying receiver, needed to poll it.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1694
1695impl From<&str> for UserMessageContent {
1696 fn from(text: &str) -> Self {
1697 Self::Text(text.into())
1698 }
1699}
1700
1701impl From<acp::ContentBlock> for UserMessageContent {
1702 fn from(value: acp::ContentBlock) -> Self {
1703 match value {
1704 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1705 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1706 acp::ContentBlock::Audio(_) => {
1707 // TODO
1708 Self::Text("[audio]".to_string())
1709 }
1710 acp::ContentBlock::ResourceLink(resource_link) => {
1711 match MentionUri::parse(&resource_link.uri) {
1712 Ok(uri) => Self::Mention {
1713 uri,
1714 content: String::new(),
1715 },
1716 Err(err) => {
1717 log::error!("Failed to parse mention link: {}", err);
1718 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1719 }
1720 }
1721 }
1722 acp::ContentBlock::Resource(resource) => match resource.resource {
1723 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1724 match MentionUri::parse(&resource.uri) {
1725 Ok(uri) => Self::Mention {
1726 uri,
1727 content: resource.text,
1728 },
1729 Err(err) => {
1730 log::error!("Failed to parse mention link: {}", err);
1731 Self::Text(
1732 MarkdownCodeBlock {
1733 tag: &resource.uri,
1734 text: &resource.text,
1735 }
1736 .to_string(),
1737 )
1738 }
1739 }
1740 }
1741 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1742 // TODO
1743 Self::Text("[blob]".to_string())
1744 }
1745 },
1746 }
1747 }
1748}
1749
1750impl From<UserMessageContent> for acp::ContentBlock {
1751 fn from(content: UserMessageContent) -> Self {
1752 match content {
1753 UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
1754 text,
1755 annotations: None,
1756 }),
1757 UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
1758 data: image.source.to_string(),
1759 mime_type: "image/png".to_string(),
1760 annotations: None,
1761 uri: None,
1762 }),
1763 UserMessageContent::Mention { uri, content } => {
1764 todo!()
1765 }
1766 }
1767 }
1768}
1769
1770fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1771 LanguageModelImage {
1772 source: image_content.data.into(),
1773 // TODO: make this optional?
1774 size: gpui::Size::new(0.into(), 0.into()),
1775 }
1776}