1use crate::{SystemPromptTemplate, Template, Templates};
2use acp_thread::Diff;
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use anyhow::{Context as _, Result, anyhow};
6use assistant_tool::adapt_schema_to_format;
7use cloud_llm_client::{CompletionIntent, CompletionMode};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 stream::FuturesUnordered,
12};
13use gpui::{App, Context, Entity, SharedString, Task};
14use language_model::{
15 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
16 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
17 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
18 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
19};
20use log;
21use project::Project;
22use prompt_store::ProjectContext;
23use schemars::{JsonSchema, Schema};
24use serde::{Deserialize, Serialize};
25use smol::stream::StreamExt;
26use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
27use util::{ResultExt, markdown::MarkdownCodeBlock};
28
29#[derive(Debug, Clone)]
30pub struct AgentMessage {
31 pub role: Role,
32 pub content: Vec<MessageContent>,
33}
34
35impl AgentMessage {
36 pub fn to_markdown(&self) -> String {
37 let mut markdown = format!("## {}\n", self.role);
38
39 for content in &self.content {
40 match content {
41 MessageContent::Text(text) => {
42 markdown.push_str(text);
43 markdown.push('\n');
44 }
45 MessageContent::Thinking { text, .. } => {
46 markdown.push_str("<think>");
47 markdown.push_str(text);
48 markdown.push_str("</think>\n");
49 }
50 MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
51 MessageContent::Image(_) => {
52 markdown.push_str("<image />\n");
53 }
54 MessageContent::ToolUse(tool_use) => {
55 markdown.push_str(&format!(
56 "**Tool Use**: {} (ID: {})\n",
57 tool_use.name, tool_use.id
58 ));
59 markdown.push_str(&format!(
60 "{}\n",
61 MarkdownCodeBlock {
62 tag: "json",
63 text: &format!("{:#}", tool_use.input)
64 }
65 ));
66 }
67 MessageContent::ToolResult(tool_result) => {
68 markdown.push_str(&format!(
69 "**Tool Result**: {} (ID: {})\n\n",
70 tool_result.tool_name, tool_result.tool_use_id
71 ));
72 if tool_result.is_error {
73 markdown.push_str("**ERROR:**\n");
74 }
75
76 match &tool_result.content {
77 LanguageModelToolResultContent::Text(text) => {
78 writeln!(markdown, "{text}\n").ok();
79 }
80 LanguageModelToolResultContent::Image(_) => {
81 writeln!(markdown, "<image />\n").ok();
82 }
83 }
84
85 if let Some(output) = tool_result.output.as_ref() {
86 writeln!(
87 markdown,
88 "**Debug Output**:\n\n```json\n{}\n```\n",
89 serde_json::to_string_pretty(output).unwrap()
90 )
91 .unwrap();
92 }
93 }
94 }
95 }
96
97 markdown
98 }
99}
100
101#[derive(Debug)]
102pub enum AgentResponseEvent {
103 Text(String),
104 Thinking(String),
105 ToolCall(acp::ToolCall),
106 ToolCallUpdate(acp_thread::ToolCallUpdate),
107 ToolCallAuthorization(ToolCallAuthorization),
108 Stop(acp::StopReason),
109}
110
111#[derive(Debug)]
112pub struct ToolCallAuthorization {
113 pub tool_call: acp::ToolCall,
114 pub options: Vec<acp::PermissionOption>,
115 pub response: oneshot::Sender<acp::PermissionOptionId>,
116}
117
118pub struct Thread {
119 messages: Vec<AgentMessage>,
120 completion_mode: CompletionMode,
121 /// Holds the task that handles agent interaction until the end of the turn.
122 /// Survives across multiple requests as the model performs tool calls and
123 /// we run tools, report their results.
124 running_turn: Option<Task<()>>,
125 pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
126 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
127 project_context: Rc<RefCell<ProjectContext>>,
128 templates: Arc<Templates>,
129 pub selected_model: Arc<dyn LanguageModel>,
130 project: Entity<Project>,
131 action_log: Entity<ActionLog>,
132}
133
134impl Thread {
135 pub fn new(
136 project: Entity<Project>,
137 project_context: Rc<RefCell<ProjectContext>>,
138 action_log: Entity<ActionLog>,
139 templates: Arc<Templates>,
140 default_model: Arc<dyn LanguageModel>,
141 ) -> Self {
142 Self {
143 messages: Vec::new(),
144 completion_mode: CompletionMode::Normal,
145 running_turn: None,
146 pending_tool_uses: HashMap::default(),
147 tools: BTreeMap::default(),
148 project_context,
149 templates,
150 selected_model: default_model,
151 project,
152 action_log,
153 }
154 }
155
156 pub fn project(&self) -> &Entity<Project> {
157 &self.project
158 }
159
160 pub fn action_log(&self) -> &Entity<ActionLog> {
161 &self.action_log
162 }
163
164 pub fn set_mode(&mut self, mode: CompletionMode) {
165 self.completion_mode = mode;
166 }
167
168 pub fn messages(&self) -> &[AgentMessage] {
169 &self.messages
170 }
171
172 pub fn add_tool(&mut self, tool: impl AgentTool) {
173 self.tools.insert(tool.name(), tool.erase());
174 }
175
176 pub fn remove_tool(&mut self, name: &str) -> bool {
177 self.tools.remove(name).is_some()
178 }
179
180 pub fn cancel(&mut self) {
181 self.running_turn.take();
182
183 let tool_results = self
184 .pending_tool_uses
185 .drain()
186 .map(|(tool_use_id, tool_use)| {
187 MessageContent::ToolResult(LanguageModelToolResult {
188 tool_use_id,
189 tool_name: tool_use.name.clone(),
190 is_error: true,
191 content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
192 output: None,
193 })
194 })
195 .collect::<Vec<_>>();
196 self.last_user_message().content.extend(tool_results);
197 }
198
199 /// Sending a message results in the model streaming a response, which could include tool calls.
200 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
201 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
202 pub fn send(
203 &mut self,
204 content: impl Into<MessageContent>,
205 cx: &mut Context<Self>,
206 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
207 let content = content.into();
208 let model = self.selected_model.clone();
209 log::info!("Thread::send called with model: {:?}", model.name());
210 log::debug!("Thread::send content: {:?}", content);
211
212 cx.notify();
213 let (events_tx, events_rx) =
214 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
215 let event_stream = AgentResponseEventStream(events_tx);
216
217 let user_message_ix = self.messages.len();
218 self.messages.push(AgentMessage {
219 role: Role::User,
220 content: vec![content],
221 });
222 log::info!("Total messages in thread: {}", self.messages.len());
223 self.running_turn = Some(cx.spawn(async move |thread, cx| {
224 log::info!("Starting agent turn execution");
225 let turn_result = async {
226 // Perform one request, then keep looping if the model makes tool calls.
227 let mut completion_intent = CompletionIntent::UserPrompt;
228 'outer: loop {
229 log::debug!(
230 "Building completion request with intent: {:?}",
231 completion_intent
232 );
233 let request = thread.update(cx, |thread, cx| {
234 thread.build_completion_request(completion_intent, cx)
235 })?;
236
237 // println!(
238 // "request: {}",
239 // serde_json::to_string_pretty(&request).unwrap()
240 // );
241
242 // Stream events, appending to messages and collecting up tool uses.
243 log::info!("Calling model.stream_completion");
244 let mut events = model.stream_completion(request, cx).await?;
245 log::debug!("Stream completion started successfully");
246 let mut tool_uses = FuturesUnordered::new();
247 while let Some(event) = events.next().await {
248 match event {
249 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
250 event_stream.send_stop(reason);
251 if reason == StopReason::Refusal {
252 thread.update(cx, |thread, _cx| {
253 thread.messages.truncate(user_message_ix);
254 })?;
255 break 'outer;
256 }
257 }
258 Ok(event) => {
259 log::trace!("Received completion event: {:?}", event);
260 thread
261 .update(cx, |thread, cx| {
262 tool_uses.extend(thread.handle_streamed_completion_event(
263 event,
264 &event_stream,
265 cx,
266 ));
267 })
268 .ok();
269 }
270 Err(error) => {
271 log::error!("Error in completion stream: {:?}", error);
272 event_stream.send_error(error);
273 break;
274 }
275 }
276 }
277
278 // If there are no tool uses, the turn is done.
279 if tool_uses.is_empty() {
280 log::info!("No tool uses found, completing turn");
281 break;
282 }
283 log::info!("Found {} tool uses to execute", tool_uses.len());
284
285 // As tool results trickle in, insert them in the last user
286 // message so that they can be sent on the next tick of the
287 // agentic loop.
288 while let Some(tool_result) = tool_uses.next().await {
289 log::info!("Tool finished {:?}", tool_result);
290
291 event_stream.update_tool_call_fields(
292 &tool_result.tool_use_id,
293 acp::ToolCallUpdateFields {
294 status: Some(if tool_result.is_error {
295 acp::ToolCallStatus::Failed
296 } else {
297 acp::ToolCallStatus::Completed
298 }),
299 ..Default::default()
300 },
301 );
302 thread
303 .update(cx, |thread, _cx| {
304 thread.pending_tool_uses.remove(&tool_result.tool_use_id);
305 thread
306 .last_user_message()
307 .content
308 .push(MessageContent::ToolResult(tool_result));
309 })
310 .ok();
311 }
312
313 completion_intent = CompletionIntent::ToolResults;
314 }
315
316 Ok(())
317 }
318 .await;
319
320 if let Err(error) = turn_result {
321 log::error!("Turn execution failed: {:?}", error);
322 event_stream.send_error(error);
323 } else {
324 log::info!("Turn execution completed successfully");
325 }
326 }));
327 events_rx
328 }
329
330 pub fn build_system_message(&self) -> AgentMessage {
331 log::debug!("Building system message");
332 let prompt = SystemPromptTemplate {
333 project: &self.project_context.borrow(),
334 available_tools: self.tools.keys().cloned().collect(),
335 }
336 .render(&self.templates)
337 .context("failed to build system prompt")
338 .expect("Invalid template");
339 log::debug!("System message built");
340 AgentMessage {
341 role: Role::System,
342 content: vec![prompt.into()],
343 }
344 }
345
346 /// A helper method that's called on every streamed completion event.
347 /// Returns an optional tool result task, which the main agentic loop in
348 /// send will send back to the model when it resolves.
349 fn handle_streamed_completion_event(
350 &mut self,
351 event: LanguageModelCompletionEvent,
352 event_stream: &AgentResponseEventStream,
353 cx: &mut Context<Self>,
354 ) -> Option<Task<LanguageModelToolResult>> {
355 log::trace!("Handling streamed completion event: {:?}", event);
356 use LanguageModelCompletionEvent::*;
357
358 match event {
359 StartMessage { .. } => {
360 self.messages.push(AgentMessage {
361 role: Role::Assistant,
362 content: Vec::new(),
363 });
364 }
365 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
366 Thinking { text, signature } => {
367 self.handle_thinking_event(text, signature, event_stream, cx)
368 }
369 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
370 ToolUse(tool_use) => {
371 return self.handle_tool_use_event(tool_use, event_stream, cx);
372 }
373 ToolUseJsonParseError {
374 id,
375 tool_name,
376 raw_input,
377 json_parse_error,
378 } => {
379 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
380 id,
381 tool_name,
382 raw_input,
383 json_parse_error,
384 )));
385 }
386 UsageUpdate(_) | StatusUpdate(_) => {}
387 Stop(_) => unreachable!(),
388 }
389
390 None
391 }
392
393 fn handle_text_event(
394 &mut self,
395 new_text: String,
396 events_stream: &AgentResponseEventStream,
397 cx: &mut Context<Self>,
398 ) {
399 events_stream.send_text(&new_text);
400
401 let last_message = self.last_assistant_message();
402 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
403 text.push_str(&new_text);
404 } else {
405 last_message.content.push(MessageContent::Text(new_text));
406 }
407
408 cx.notify();
409 }
410
411 fn handle_thinking_event(
412 &mut self,
413 new_text: String,
414 new_signature: Option<String>,
415 event_stream: &AgentResponseEventStream,
416 cx: &mut Context<Self>,
417 ) {
418 event_stream.send_thinking(&new_text);
419
420 let last_message = self.last_assistant_message();
421 if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
422 {
423 text.push_str(&new_text);
424 *signature = new_signature.or(signature.take());
425 } else {
426 last_message.content.push(MessageContent::Thinking {
427 text: new_text,
428 signature: new_signature,
429 });
430 }
431
432 cx.notify();
433 }
434
435 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
436 let last_message = self.last_assistant_message();
437 last_message
438 .content
439 .push(MessageContent::RedactedThinking(data));
440 cx.notify();
441 }
442
443 fn handle_tool_use_event(
444 &mut self,
445 tool_use: LanguageModelToolUse,
446 event_stream: &AgentResponseEventStream,
447 cx: &mut Context<Self>,
448 ) -> Option<Task<LanguageModelToolResult>> {
449 cx.notify();
450
451 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
452
453 self.pending_tool_uses
454 .insert(tool_use.id.clone(), tool_use.clone());
455 let last_message = self.last_assistant_message();
456
457 // Ensure the last message ends in the current tool use
458 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
459 if let MessageContent::ToolUse(last_tool_use) = content {
460 if last_tool_use.id == tool_use.id {
461 *last_tool_use = tool_use.clone();
462 false
463 } else {
464 true
465 }
466 } else {
467 true
468 }
469 });
470
471 let mut title = SharedString::from(&tool_use.name);
472 let mut kind = acp::ToolKind::Other;
473 if let Some(tool) = tool.as_ref() {
474 title = tool.initial_title(tool_use.input.clone());
475 kind = tool.kind();
476 }
477
478 if push_new_tool_use {
479 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
480 last_message
481 .content
482 .push(MessageContent::ToolUse(tool_use.clone()));
483 } else {
484 event_stream.update_tool_call_fields(
485 &tool_use.id,
486 acp::ToolCallUpdateFields {
487 title: Some(title.into()),
488 kind: Some(kind),
489 raw_input: Some(tool_use.input.clone()),
490 ..Default::default()
491 },
492 );
493 }
494
495 if !tool_use.is_input_complete {
496 return None;
497 }
498
499 let Some(tool) = tool else {
500 let content = format!("No tool named {} exists", tool_use.name);
501 return Some(Task::ready(LanguageModelToolResult {
502 content: LanguageModelToolResultContent::Text(Arc::from(content)),
503 tool_use_id: tool_use.id,
504 tool_name: tool_use.name,
505 is_error: true,
506 output: None,
507 }));
508 };
509
510 let tool_event_stream =
511 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
512 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
513 status: Some(acp::ToolCallStatus::InProgress),
514 ..Default::default()
515 });
516 let supports_images = self.selected_model.supports_images();
517 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
518 Some(cx.foreground_executor().spawn(async move {
519 let tool_result = tool_result.await.and_then(|output| {
520 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
521 if !supports_images {
522 return Err(anyhow!(
523 "Attempted to read an image, but this model doesn't support it.",
524 ));
525 }
526 }
527 Ok(output)
528 });
529
530 match tool_result {
531 Ok(output) => LanguageModelToolResult {
532 tool_use_id: tool_use.id,
533 tool_name: tool_use.name,
534 is_error: false,
535 content: output.llm_output,
536 output: Some(output.raw_output),
537 },
538 Err(error) => LanguageModelToolResult {
539 tool_use_id: tool_use.id,
540 tool_name: tool_use.name,
541 is_error: true,
542 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
543 output: None,
544 },
545 }
546 }))
547 }
548
549 fn handle_tool_use_json_parse_error_event(
550 &mut self,
551 tool_use_id: LanguageModelToolUseId,
552 tool_name: Arc<str>,
553 raw_input: Arc<str>,
554 json_parse_error: String,
555 ) -> LanguageModelToolResult {
556 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
557 LanguageModelToolResult {
558 tool_use_id,
559 tool_name,
560 is_error: true,
561 content: LanguageModelToolResultContent::Text(tool_output.into()),
562 output: Some(serde_json::Value::String(raw_input.to_string())),
563 }
564 }
565
566 /// Guarantees the last message is from the assistant and returns a mutable reference.
567 fn last_assistant_message(&mut self) -> &mut AgentMessage {
568 if self
569 .messages
570 .last()
571 .map_or(true, |m| m.role != Role::Assistant)
572 {
573 self.messages.push(AgentMessage {
574 role: Role::Assistant,
575 content: Vec::new(),
576 });
577 }
578 self.messages.last_mut().unwrap()
579 }
580
581 /// Guarantees the last message is from the user and returns a mutable reference.
582 fn last_user_message(&mut self) -> &mut AgentMessage {
583 if self.messages.last().map_or(true, |m| m.role != Role::User) {
584 self.messages.push(AgentMessage {
585 role: Role::User,
586 content: Vec::new(),
587 });
588 }
589 self.messages.last_mut().unwrap()
590 }
591
592 pub(crate) fn build_completion_request(
593 &self,
594 completion_intent: CompletionIntent,
595 cx: &mut App,
596 ) -> LanguageModelRequest {
597 log::debug!("Building completion request");
598 log::debug!("Completion intent: {:?}", completion_intent);
599 log::debug!("Completion mode: {:?}", self.completion_mode);
600
601 let messages = self.build_request_messages();
602 log::info!("Request will include {} messages", messages.len());
603
604 let tools: Vec<LanguageModelRequestTool> = self
605 .tools
606 .values()
607 .filter_map(|tool| {
608 let tool_name = tool.name().to_string();
609 log::trace!("Including tool: {}", tool_name);
610 Some(LanguageModelRequestTool {
611 name: tool_name,
612 description: tool.description(cx).to_string(),
613 input_schema: tool
614 .input_schema(self.selected_model.tool_input_format())
615 .log_err()?,
616 })
617 })
618 .collect();
619
620 log::info!("Request includes {} tools", tools.len());
621
622 let request = LanguageModelRequest {
623 thread_id: None,
624 prompt_id: None,
625 intent: Some(completion_intent),
626 mode: Some(self.completion_mode),
627 messages,
628 tools,
629 tool_choice: None,
630 stop: Vec::new(),
631 temperature: None,
632 thinking_allowed: true,
633 };
634
635 log::debug!("Completion request built successfully");
636 request
637 }
638
639 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
640 log::trace!(
641 "Building request messages from {} thread messages",
642 self.messages.len()
643 );
644
645 let messages = Some(self.build_system_message())
646 .iter()
647 .chain(self.messages.iter())
648 .map(|message| {
649 log::trace!(
650 " - {} message with {} content items",
651 match message.role {
652 Role::System => "System",
653 Role::User => "User",
654 Role::Assistant => "Assistant",
655 },
656 message.content.len()
657 );
658 LanguageModelRequestMessage {
659 role: message.role,
660 content: message.content.clone(),
661 cache: false,
662 }
663 })
664 .collect();
665 messages
666 }
667
668 pub fn to_markdown(&self) -> String {
669 let mut markdown = String::new();
670 for message in &self.messages {
671 markdown.push_str(&message.to_markdown());
672 }
673 markdown
674 }
675}
676
677pub trait AgentTool
678where
679 Self: 'static + Sized,
680{
681 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
682 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
683
684 fn name(&self) -> SharedString;
685
686 fn description(&self, _cx: &mut App) -> SharedString {
687 let schema = schemars::schema_for!(Self::Input);
688 SharedString::new(
689 schema
690 .get("description")
691 .and_then(|description| description.as_str())
692 .unwrap_or_default(),
693 )
694 }
695
696 fn kind(&self) -> acp::ToolKind;
697
698 /// The initial tool title to display. Can be updated during the tool run.
699 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
700
701 /// Returns the JSON schema that describes the tool's input.
702 fn input_schema(&self) -> Schema {
703 schemars::schema_for!(Self::Input)
704 }
705
706 /// Runs the tool with the provided input.
707 fn run(
708 self: Arc<Self>,
709 input: Self::Input,
710 event_stream: ToolCallEventStream,
711 cx: &mut App,
712 ) -> Task<Result<Self::Output>>;
713
714 fn erase(self) -> Arc<dyn AnyAgentTool> {
715 Arc::new(Erased(Arc::new(self)))
716 }
717}
718
/// Wrapper that adapts a typed [`AgentTool`] to the object-safe [`AnyAgentTool`]
/// interface.
pub struct Erased<T>(T);
720
721pub struct AgentToolOutput {
722 llm_output: LanguageModelToolResultContent,
723 raw_output: serde_json::Value,
724}
725
726pub trait AnyAgentTool {
727 fn name(&self) -> SharedString;
728 fn description(&self, cx: &mut App) -> SharedString;
729 fn kind(&self) -> acp::ToolKind;
730 fn initial_title(&self, input: serde_json::Value) -> SharedString;
731 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
732 fn run(
733 self: Arc<Self>,
734 input: serde_json::Value,
735 event_stream: ToolCallEventStream,
736 cx: &mut App,
737 ) -> Task<Result<AgentToolOutput>>;
738}
739
740impl<T> AnyAgentTool for Erased<Arc<T>>
741where
742 T: AgentTool,
743{
744 fn name(&self) -> SharedString {
745 self.0.name()
746 }
747
748 fn description(&self, cx: &mut App) -> SharedString {
749 self.0.description(cx)
750 }
751
752 fn kind(&self) -> agent_client_protocol::ToolKind {
753 self.0.kind()
754 }
755
756 fn initial_title(&self, input: serde_json::Value) -> SharedString {
757 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
758 self.0.initial_title(parsed_input)
759 }
760
761 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
762 let mut json = serde_json::to_value(self.0.input_schema())?;
763 adapt_schema_to_format(&mut json, format)?;
764 Ok(json)
765 }
766
767 fn run(
768 self: Arc<Self>,
769 input: serde_json::Value,
770 event_stream: ToolCallEventStream,
771 cx: &mut App,
772 ) -> Task<Result<AgentToolOutput>> {
773 cx.spawn(async move |cx| {
774 let input = serde_json::from_value(input)?;
775 let output = cx
776 .update(|cx| self.0.clone().run(input, event_stream, cx))?
777 .await?;
778 let raw_output = serde_json::to_value(&output)?;
779 Ok(AgentToolOutput {
780 llm_output: output.into(),
781 raw_output,
782 })
783 })
784 }
785}
786
787#[derive(Clone)]
788struct AgentResponseEventStream(
789 mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
790);
791
792impl AgentResponseEventStream {
793 fn send_text(&self, text: &str) {
794 self.0
795 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
796 .ok();
797 }
798
799 fn send_thinking(&self, text: &str) {
800 self.0
801 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
802 .ok();
803 }
804
805 fn authorize_tool_call(
806 &self,
807 id: &LanguageModelToolUseId,
808 title: String,
809 kind: acp::ToolKind,
810 input: serde_json::Value,
811 ) -> impl use<> + Future<Output = Result<()>> {
812 let (response_tx, response_rx) = oneshot::channel();
813 self.0
814 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
815 ToolCallAuthorization {
816 tool_call: Self::initial_tool_call(id, title, kind, input),
817 options: vec![
818 acp::PermissionOption {
819 id: acp::PermissionOptionId("always_allow".into()),
820 name: "Always Allow".into(),
821 kind: acp::PermissionOptionKind::AllowAlways,
822 },
823 acp::PermissionOption {
824 id: acp::PermissionOptionId("allow".into()),
825 name: "Allow".into(),
826 kind: acp::PermissionOptionKind::AllowOnce,
827 },
828 acp::PermissionOption {
829 id: acp::PermissionOptionId("deny".into()),
830 name: "Deny".into(),
831 kind: acp::PermissionOptionKind::RejectOnce,
832 },
833 ],
834 response: response_tx,
835 },
836 )))
837 .ok();
838 async move {
839 match response_rx.await?.0.as_ref() {
840 "allow" | "always_allow" => Ok(()),
841 _ => Err(anyhow!("Permission to run tool denied by user")),
842 }
843 }
844 }
845
846 fn send_tool_call(
847 &self,
848 id: &LanguageModelToolUseId,
849 title: SharedString,
850 kind: acp::ToolKind,
851 input: serde_json::Value,
852 ) {
853 self.0
854 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
855 id,
856 title.to_string(),
857 kind,
858 input,
859 ))))
860 .ok();
861 }
862
863 fn initial_tool_call(
864 id: &LanguageModelToolUseId,
865 title: String,
866 kind: acp::ToolKind,
867 input: serde_json::Value,
868 ) -> acp::ToolCall {
869 acp::ToolCall {
870 id: acp::ToolCallId(id.to_string().into()),
871 title,
872 kind,
873 status: acp::ToolCallStatus::Pending,
874 content: vec![],
875 locations: vec![],
876 raw_input: Some(input),
877 raw_output: None,
878 }
879 }
880
881 fn update_tool_call_fields(
882 &self,
883 tool_use_id: &LanguageModelToolUseId,
884 fields: acp::ToolCallUpdateFields,
885 ) {
886 self.0
887 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
888 acp::ToolCallUpdate {
889 id: acp::ToolCallId(tool_use_id.to_string().into()),
890 fields,
891 }
892 .into(),
893 )))
894 .ok();
895 }
896
897 fn update_tool_call_diff(&self, tool_use_id: &LanguageModelToolUseId, diff: Entity<Diff>) {
898 self.0
899 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
900 acp_thread::ToolCallUpdateDiff {
901 id: acp::ToolCallId(tool_use_id.to_string().into()),
902 diff,
903 }
904 .into(),
905 )))
906 .ok();
907 }
908
909 fn send_stop(&self, reason: StopReason) {
910 match reason {
911 StopReason::EndTurn => {
912 self.0
913 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
914 .ok();
915 }
916 StopReason::MaxTokens => {
917 self.0
918 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
919 .ok();
920 }
921 StopReason::Refusal => {
922 self.0
923 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
924 .ok();
925 }
926 StopReason::ToolUse => {}
927 }
928 }
929
930 fn send_error(&self, error: LanguageModelCompletionError) {
931 self.0.unbounded_send(Err(error)).ok();
932 }
933}
934
935#[derive(Clone)]
936pub struct ToolCallEventStream {
937 tool_use_id: LanguageModelToolUseId,
938 kind: acp::ToolKind,
939 input: serde_json::Value,
940 stream: AgentResponseEventStream,
941}
942
943impl ToolCallEventStream {
944 #[cfg(test)]
945 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
946 let (events_tx, events_rx) =
947 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
948
949 let stream = ToolCallEventStream::new(
950 &LanguageModelToolUse {
951 id: "test_id".into(),
952 name: "test_tool".into(),
953 raw_input: String::new(),
954 input: serde_json::Value::Null,
955 is_input_complete: true,
956 },
957 acp::ToolKind::Other,
958 AgentResponseEventStream(events_tx),
959 );
960
961 (stream, ToolCallEventStreamReceiver(events_rx))
962 }
963
964 fn new(
965 tool_use: &LanguageModelToolUse,
966 kind: acp::ToolKind,
967 stream: AgentResponseEventStream,
968 ) -> Self {
969 Self {
970 tool_use_id: tool_use.id.clone(),
971 kind,
972 input: tool_use.input.clone(),
973 stream,
974 }
975 }
976
977 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
978 self.stream
979 .update_tool_call_fields(&self.tool_use_id, fields);
980 }
981
982 pub fn update_diff(&self, diff: Entity<Diff>) {
983 self.stream.update_tool_call_diff(&self.tool_use_id, diff);
984 }
985
986 pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
987 self.stream.authorize_tool_call(
988 &self.tool_use_id,
989 title,
990 self.kind.clone(),
991 self.input.clone(),
992 )
993 }
994}
995
/// Test-only receiving end of a `ToolCallEventStream` created with
/// `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1000
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it as an authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }
}
1012
/// Exposes the wrapped receiver so tests can use it directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1021
/// Mutable access to the wrapped receiver, e.g. for polling events in tests.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}