1use crate::{SystemPromptTemplate, Template, Templates};
2use action_log::ActionLog;
3use agent_client_protocol as acp;
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tool::adapt_schema_to_format;
6use cloud_llm_client::{CompletionIntent, CompletionMode};
7use collections::HashMap;
8use futures::{
9 channel::{mpsc, oneshot},
10 stream::FuturesUnordered,
11};
12use gpui::{App, Context, Entity, SharedString, Task};
13use language_model::{
14 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
15 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
16 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
17 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
18};
19use log;
20use project::Project;
21use prompt_store::ProjectContext;
22use schemars::{JsonSchema, Schema};
23use serde::{Deserialize, Serialize};
24use smol::stream::StreamExt;
25use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
26use util::{ResultExt, markdown::MarkdownCodeBlock};
27
28#[derive(Debug, Clone)]
29pub struct AgentMessage {
30 pub role: Role,
31 pub content: Vec<MessageContent>,
32}
33
34impl AgentMessage {
35 pub fn to_markdown(&self) -> String {
36 let mut markdown = format!("## {}\n", self.role);
37
38 for content in &self.content {
39 match content {
40 MessageContent::Text(text) => {
41 markdown.push_str(text);
42 markdown.push('\n');
43 }
44 MessageContent::Thinking { text, .. } => {
45 markdown.push_str("<think>");
46 markdown.push_str(text);
47 markdown.push_str("</think>\n");
48 }
49 MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
50 MessageContent::Image(_) => {
51 markdown.push_str("<image />\n");
52 }
53 MessageContent::ToolUse(tool_use) => {
54 markdown.push_str(&format!(
55 "**Tool Use**: {} (ID: {})\n",
56 tool_use.name, tool_use.id
57 ));
58 markdown.push_str(&format!(
59 "{}\n",
60 MarkdownCodeBlock {
61 tag: "json",
62 text: &format!("{:#}", tool_use.input)
63 }
64 ));
65 }
66 MessageContent::ToolResult(tool_result) => {
67 markdown.push_str(&format!(
68 "**Tool Result**: {} (ID: {})\n\n",
69 tool_result.tool_name, tool_result.tool_use_id
70 ));
71 if tool_result.is_error {
72 markdown.push_str("**ERROR:**\n");
73 }
74
75 match &tool_result.content {
76 LanguageModelToolResultContent::Text(text) => {
77 writeln!(markdown, "{text}\n").ok();
78 }
79 LanguageModelToolResultContent::Image(_) => {
80 writeln!(markdown, "<image />\n").ok();
81 }
82 }
83
84 if let Some(output) = tool_result.output.as_ref() {
85 writeln!(
86 markdown,
87 "**Debug Output**:\n\n```json\n{}\n```\n",
88 serde_json::to_string_pretty(output).unwrap()
89 )
90 .unwrap();
91 }
92 }
93 }
94 }
95
96 markdown
97 }
98}
99
/// Events delivered through the channel returned by [`Thread::send`] while a
/// turn is running.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of assistant text, forwarded as it streams in.
    Text(String),
    /// A chunk of the model's thinking output, forwarded as it streams in.
    Thinking(String),
    /// A tool call announced by the model for the first time.
    ToolCall(acp::ToolCall),
    /// A change to a previously announced tool call (fields, a diff, or a terminal).
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool asking the user for permission before it proceeds.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped; only end-of-turn, max-tokens and refusal stops are reported.
    Stop(acp::StopReason),
}

/// A permission request raised by a running tool.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call the permission request is about.
    pub tool_call: acp::ToolCall,
    /// The choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Receives the id of the option the user picked.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
116
117pub struct Thread {
118 messages: Vec<AgentMessage>,
119 completion_mode: CompletionMode,
120 /// Holds the task that handles agent interaction until the end of the turn.
121 /// Survives across multiple requests as the model performs tool calls and
122 /// we run tools, report their results.
123 running_turn: Option<Task<()>>,
124 pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
125 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
126 project_context: Rc<RefCell<ProjectContext>>,
127 templates: Arc<Templates>,
128 pub selected_model: Arc<dyn LanguageModel>,
129 project: Entity<Project>,
130 action_log: Entity<ActionLog>,
131}
132
133impl Thread {
134 pub fn new(
135 project: Entity<Project>,
136 project_context: Rc<RefCell<ProjectContext>>,
137 action_log: Entity<ActionLog>,
138 templates: Arc<Templates>,
139 default_model: Arc<dyn LanguageModel>,
140 ) -> Self {
141 Self {
142 messages: Vec::new(),
143 completion_mode: CompletionMode::Normal,
144 running_turn: None,
145 pending_tool_uses: HashMap::default(),
146 tools: BTreeMap::default(),
147 project_context,
148 templates,
149 selected_model: default_model,
150 project,
151 action_log,
152 }
153 }
154
155 pub fn project(&self) -> &Entity<Project> {
156 &self.project
157 }
158
159 pub fn action_log(&self) -> &Entity<ActionLog> {
160 &self.action_log
161 }
162
163 pub fn set_mode(&mut self, mode: CompletionMode) {
164 self.completion_mode = mode;
165 }
166
167 pub fn messages(&self) -> &[AgentMessage] {
168 &self.messages
169 }
170
171 pub fn add_tool(&mut self, tool: impl AgentTool) {
172 self.tools.insert(tool.name(), tool.erase());
173 }
174
175 pub fn remove_tool(&mut self, name: &str) -> bool {
176 self.tools.remove(name).is_some()
177 }
178
179 pub fn cancel(&mut self) {
180 self.running_turn.take();
181
182 let tool_results = self
183 .pending_tool_uses
184 .drain()
185 .map(|(tool_use_id, tool_use)| {
186 MessageContent::ToolResult(LanguageModelToolResult {
187 tool_use_id,
188 tool_name: tool_use.name.clone(),
189 is_error: true,
190 content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
191 output: None,
192 })
193 })
194 .collect::<Vec<_>>();
195 self.last_user_message().content.extend(tool_results);
196 }
197
198 /// Sending a message results in the model streaming a response, which could include tool calls.
199 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
200 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
201 pub fn send(
202 &mut self,
203 content: impl Into<MessageContent>,
204 cx: &mut Context<Self>,
205 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
206 let content = content.into();
207 let model = self.selected_model.clone();
208 log::info!("Thread::send called with model: {:?}", model.name());
209 log::debug!("Thread::send content: {:?}", content);
210
211 cx.notify();
212 let (events_tx, events_rx) =
213 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
214 let event_stream = AgentResponseEventStream(events_tx);
215
216 let user_message_ix = self.messages.len();
217 self.messages.push(AgentMessage {
218 role: Role::User,
219 content: vec![content],
220 });
221 log::info!("Total messages in thread: {}", self.messages.len());
222 self.running_turn = Some(cx.spawn(async move |thread, cx| {
223 log::info!("Starting agent turn execution");
224 let turn_result = async {
225 // Perform one request, then keep looping if the model makes tool calls.
226 let mut completion_intent = CompletionIntent::UserPrompt;
227 'outer: loop {
228 log::debug!(
229 "Building completion request with intent: {:?}",
230 completion_intent
231 );
232 let request = thread.update(cx, |thread, cx| {
233 thread.build_completion_request(completion_intent, cx)
234 })?;
235
236 // println!(
237 // "request: {}",
238 // serde_json::to_string_pretty(&request).unwrap()
239 // );
240
241 // Stream events, appending to messages and collecting up tool uses.
242 log::info!("Calling model.stream_completion");
243 let mut events = model.stream_completion(request, cx).await?;
244 log::debug!("Stream completion started successfully");
245 let mut tool_uses = FuturesUnordered::new();
246 while let Some(event) = events.next().await {
247 match event {
248 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
249 event_stream.send_stop(reason);
250 if reason == StopReason::Refusal {
251 thread.update(cx, |thread, _cx| {
252 thread.messages.truncate(user_message_ix);
253 })?;
254 break 'outer;
255 }
256 }
257 Ok(event) => {
258 log::trace!("Received completion event: {:?}", event);
259 thread
260 .update(cx, |thread, cx| {
261 tool_uses.extend(thread.handle_streamed_completion_event(
262 event,
263 &event_stream,
264 cx,
265 ));
266 })
267 .ok();
268 }
269 Err(error) => {
270 log::error!("Error in completion stream: {:?}", error);
271 event_stream.send_error(error);
272 break;
273 }
274 }
275 }
276
277 // If there are no tool uses, the turn is done.
278 if tool_uses.is_empty() {
279 log::info!("No tool uses found, completing turn");
280 break;
281 }
282 log::info!("Found {} tool uses to execute", tool_uses.len());
283
284 // As tool results trickle in, insert them in the last user
285 // message so that they can be sent on the next tick of the
286 // agentic loop.
287 while let Some(tool_result) = tool_uses.next().await {
288 log::info!("Tool finished {:?}", tool_result);
289
290 event_stream.update_tool_call_fields(
291 &tool_result.tool_use_id,
292 acp::ToolCallUpdateFields {
293 status: Some(if tool_result.is_error {
294 acp::ToolCallStatus::Failed
295 } else {
296 acp::ToolCallStatus::Completed
297 }),
298 ..Default::default()
299 },
300 );
301 thread
302 .update(cx, |thread, _cx| {
303 thread.pending_tool_uses.remove(&tool_result.tool_use_id);
304 thread
305 .last_user_message()
306 .content
307 .push(MessageContent::ToolResult(tool_result));
308 })
309 .ok();
310 }
311
312 completion_intent = CompletionIntent::ToolResults;
313 }
314
315 Ok(())
316 }
317 .await;
318
319 if let Err(error) = turn_result {
320 log::error!("Turn execution failed: {:?}", error);
321 event_stream.send_error(error);
322 } else {
323 log::info!("Turn execution completed successfully");
324 }
325 }));
326 events_rx
327 }
328
329 pub fn build_system_message(&self) -> AgentMessage {
330 log::debug!("Building system message");
331 let prompt = SystemPromptTemplate {
332 project: &self.project_context.borrow(),
333 available_tools: self.tools.keys().cloned().collect(),
334 }
335 .render(&self.templates)
336 .context("failed to build system prompt")
337 .expect("Invalid template");
338 log::debug!("System message built");
339 AgentMessage {
340 role: Role::System,
341 content: vec![prompt.into()],
342 }
343 }
344
345 /// A helper method that's called on every streamed completion event.
346 /// Returns an optional tool result task, which the main agentic loop in
347 /// send will send back to the model when it resolves.
348 fn handle_streamed_completion_event(
349 &mut self,
350 event: LanguageModelCompletionEvent,
351 event_stream: &AgentResponseEventStream,
352 cx: &mut Context<Self>,
353 ) -> Option<Task<LanguageModelToolResult>> {
354 log::trace!("Handling streamed completion event: {:?}", event);
355 use LanguageModelCompletionEvent::*;
356
357 match event {
358 StartMessage { .. } => {
359 self.messages.push(AgentMessage {
360 role: Role::Assistant,
361 content: Vec::new(),
362 });
363 }
364 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
365 Thinking { text, signature } => {
366 self.handle_thinking_event(text, signature, event_stream, cx)
367 }
368 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
369 ToolUse(tool_use) => {
370 return self.handle_tool_use_event(tool_use, event_stream, cx);
371 }
372 ToolUseJsonParseError {
373 id,
374 tool_name,
375 raw_input,
376 json_parse_error,
377 } => {
378 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
379 id,
380 tool_name,
381 raw_input,
382 json_parse_error,
383 )));
384 }
385 UsageUpdate(_) | StatusUpdate(_) => {}
386 Stop(_) => unreachable!(),
387 }
388
389 None
390 }
391
392 fn handle_text_event(
393 &mut self,
394 new_text: String,
395 events_stream: &AgentResponseEventStream,
396 cx: &mut Context<Self>,
397 ) {
398 events_stream.send_text(&new_text);
399
400 let last_message = self.last_assistant_message();
401 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
402 text.push_str(&new_text);
403 } else {
404 last_message.content.push(MessageContent::Text(new_text));
405 }
406
407 cx.notify();
408 }
409
410 fn handle_thinking_event(
411 &mut self,
412 new_text: String,
413 new_signature: Option<String>,
414 event_stream: &AgentResponseEventStream,
415 cx: &mut Context<Self>,
416 ) {
417 event_stream.send_thinking(&new_text);
418
419 let last_message = self.last_assistant_message();
420 if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
421 {
422 text.push_str(&new_text);
423 *signature = new_signature.or(signature.take());
424 } else {
425 last_message.content.push(MessageContent::Thinking {
426 text: new_text,
427 signature: new_signature,
428 });
429 }
430
431 cx.notify();
432 }
433
434 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
435 let last_message = self.last_assistant_message();
436 last_message
437 .content
438 .push(MessageContent::RedactedThinking(data));
439 cx.notify();
440 }
441
442 fn handle_tool_use_event(
443 &mut self,
444 tool_use: LanguageModelToolUse,
445 event_stream: &AgentResponseEventStream,
446 cx: &mut Context<Self>,
447 ) -> Option<Task<LanguageModelToolResult>> {
448 cx.notify();
449
450 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
451
452 self.pending_tool_uses
453 .insert(tool_use.id.clone(), tool_use.clone());
454 let last_message = self.last_assistant_message();
455
456 // Ensure the last message ends in the current tool use
457 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
458 if let MessageContent::ToolUse(last_tool_use) = content {
459 if last_tool_use.id == tool_use.id {
460 *last_tool_use = tool_use.clone();
461 false
462 } else {
463 true
464 }
465 } else {
466 true
467 }
468 });
469
470 let mut title = SharedString::from(&tool_use.name);
471 let mut kind = acp::ToolKind::Other;
472 if let Some(tool) = tool.as_ref() {
473 title = tool.initial_title(tool_use.input.clone());
474 kind = tool.kind();
475 }
476
477 if push_new_tool_use {
478 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
479 last_message
480 .content
481 .push(MessageContent::ToolUse(tool_use.clone()));
482 } else {
483 event_stream.update_tool_call_fields(
484 &tool_use.id,
485 acp::ToolCallUpdateFields {
486 title: Some(title.into()),
487 kind: Some(kind),
488 raw_input: Some(tool_use.input.clone()),
489 ..Default::default()
490 },
491 );
492 }
493
494 if !tool_use.is_input_complete {
495 return None;
496 }
497
498 let Some(tool) = tool else {
499 let content = format!("No tool named {} exists", tool_use.name);
500 return Some(Task::ready(LanguageModelToolResult {
501 content: LanguageModelToolResultContent::Text(Arc::from(content)),
502 tool_use_id: tool_use.id,
503 tool_name: tool_use.name,
504 is_error: true,
505 output: None,
506 }));
507 };
508
509 let tool_event_stream =
510 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
511 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
512 status: Some(acp::ToolCallStatus::InProgress),
513 ..Default::default()
514 });
515 let supports_images = self.selected_model.supports_images();
516 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
517 Some(cx.foreground_executor().spawn(async move {
518 let tool_result = tool_result.await.and_then(|output| {
519 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
520 if !supports_images {
521 return Err(anyhow!(
522 "Attempted to read an image, but this model doesn't support it.",
523 ));
524 }
525 }
526 Ok(output)
527 });
528
529 match tool_result {
530 Ok(output) => LanguageModelToolResult {
531 tool_use_id: tool_use.id,
532 tool_name: tool_use.name,
533 is_error: false,
534 content: output.llm_output,
535 output: Some(output.raw_output),
536 },
537 Err(error) => LanguageModelToolResult {
538 tool_use_id: tool_use.id,
539 tool_name: tool_use.name,
540 is_error: true,
541 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
542 output: None,
543 },
544 }
545 }))
546 }
547
548 fn handle_tool_use_json_parse_error_event(
549 &mut self,
550 tool_use_id: LanguageModelToolUseId,
551 tool_name: Arc<str>,
552 raw_input: Arc<str>,
553 json_parse_error: String,
554 ) -> LanguageModelToolResult {
555 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
556 LanguageModelToolResult {
557 tool_use_id,
558 tool_name,
559 is_error: true,
560 content: LanguageModelToolResultContent::Text(tool_output.into()),
561 output: Some(serde_json::Value::String(raw_input.to_string())),
562 }
563 }
564
565 /// Guarantees the last message is from the assistant and returns a mutable reference.
566 fn last_assistant_message(&mut self) -> &mut AgentMessage {
567 if self
568 .messages
569 .last()
570 .map_or(true, |m| m.role != Role::Assistant)
571 {
572 self.messages.push(AgentMessage {
573 role: Role::Assistant,
574 content: Vec::new(),
575 });
576 }
577 self.messages.last_mut().unwrap()
578 }
579
580 /// Guarantees the last message is from the user and returns a mutable reference.
581 fn last_user_message(&mut self) -> &mut AgentMessage {
582 if self.messages.last().map_or(true, |m| m.role != Role::User) {
583 self.messages.push(AgentMessage {
584 role: Role::User,
585 content: Vec::new(),
586 });
587 }
588 self.messages.last_mut().unwrap()
589 }
590
591 pub(crate) fn build_completion_request(
592 &self,
593 completion_intent: CompletionIntent,
594 cx: &mut App,
595 ) -> LanguageModelRequest {
596 log::debug!("Building completion request");
597 log::debug!("Completion intent: {:?}", completion_intent);
598 log::debug!("Completion mode: {:?}", self.completion_mode);
599
600 let messages = self.build_request_messages();
601 log::info!("Request will include {} messages", messages.len());
602
603 let tools: Vec<LanguageModelRequestTool> = self
604 .tools
605 .values()
606 .filter_map(|tool| {
607 let tool_name = tool.name().to_string();
608 log::trace!("Including tool: {}", tool_name);
609 Some(LanguageModelRequestTool {
610 name: tool_name,
611 description: tool.description(cx).to_string(),
612 input_schema: tool
613 .input_schema(self.selected_model.tool_input_format())
614 .log_err()?,
615 })
616 })
617 .collect();
618
619 log::info!("Request includes {} tools", tools.len());
620
621 let request = LanguageModelRequest {
622 thread_id: None,
623 prompt_id: None,
624 intent: Some(completion_intent),
625 mode: Some(self.completion_mode),
626 messages,
627 tools,
628 tool_choice: None,
629 stop: Vec::new(),
630 temperature: None,
631 thinking_allowed: true,
632 };
633
634 log::debug!("Completion request built successfully");
635 request
636 }
637
638 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
639 log::trace!(
640 "Building request messages from {} thread messages",
641 self.messages.len()
642 );
643
644 let messages = Some(self.build_system_message())
645 .iter()
646 .chain(self.messages.iter())
647 .map(|message| {
648 log::trace!(
649 " - {} message with {} content items",
650 match message.role {
651 Role::System => "System",
652 Role::User => "User",
653 Role::Assistant => "Assistant",
654 },
655 message.content.len()
656 );
657 LanguageModelRequestMessage {
658 role: message.role,
659 content: message.content.clone(),
660 cache: false,
661 }
662 })
663 .collect();
664 messages
665 }
666
667 pub fn to_markdown(&self) -> String {
668 let mut markdown = String::new();
669 for message in &self.messages {
670 markdown.push_str(&message.to_markdown());
671 }
672 markdown
673 }
674}
675
676pub trait AgentTool
677where
678 Self: 'static + Sized,
679{
680 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
681 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
682
683 fn name(&self) -> SharedString;
684
685 fn description(&self, _cx: &mut App) -> SharedString {
686 let schema = schemars::schema_for!(Self::Input);
687 SharedString::new(
688 schema
689 .get("description")
690 .and_then(|description| description.as_str())
691 .unwrap_or_default(),
692 )
693 }
694
695 fn kind(&self) -> acp::ToolKind;
696
697 /// The initial tool title to display. Can be updated during the tool run.
698 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
699
700 /// Returns the JSON schema that describes the tool's input.
701 fn input_schema(&self) -> Schema {
702 schemars::schema_for!(Self::Input)
703 }
704
705 /// Runs the tool with the provided input.
706 fn run(
707 self: Arc<Self>,
708 input: Self::Input,
709 event_stream: ToolCallEventStream,
710 cx: &mut App,
711 ) -> Task<Result<Self::Output>>;
712
713 fn erase(self) -> Arc<dyn AnyAgentTool> {
714 Arc::new(Erased(Arc::new(self)))
715 }
716}
717
718pub struct Erased<T>(T);
719
720pub struct AgentToolOutput {
721 llm_output: LanguageModelToolResultContent,
722 raw_output: serde_json::Value,
723}
724
725pub trait AnyAgentTool {
726 fn name(&self) -> SharedString;
727 fn description(&self, cx: &mut App) -> SharedString;
728 fn kind(&self) -> acp::ToolKind;
729 fn initial_title(&self, input: serde_json::Value) -> SharedString;
730 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
731 fn run(
732 self: Arc<Self>,
733 input: serde_json::Value,
734 event_stream: ToolCallEventStream,
735 cx: &mut App,
736 ) -> Task<Result<AgentToolOutput>>;
737}
738
739impl<T> AnyAgentTool for Erased<Arc<T>>
740where
741 T: AgentTool,
742{
743 fn name(&self) -> SharedString {
744 self.0.name()
745 }
746
747 fn description(&self, cx: &mut App) -> SharedString {
748 self.0.description(cx)
749 }
750
751 fn kind(&self) -> agent_client_protocol::ToolKind {
752 self.0.kind()
753 }
754
755 fn initial_title(&self, input: serde_json::Value) -> SharedString {
756 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
757 self.0.initial_title(parsed_input)
758 }
759
760 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
761 let mut json = serde_json::to_value(self.0.input_schema())?;
762 adapt_schema_to_format(&mut json, format)?;
763 Ok(json)
764 }
765
766 fn run(
767 self: Arc<Self>,
768 input: serde_json::Value,
769 event_stream: ToolCallEventStream,
770 cx: &mut App,
771 ) -> Task<Result<AgentToolOutput>> {
772 cx.spawn(async move |cx| {
773 let input = serde_json::from_value(input)?;
774 let output = cx
775 .update(|cx| self.0.clone().run(input, event_stream, cx))?
776 .await?;
777 let raw_output = serde_json::to_value(&output)?;
778 Ok(AgentToolOutput {
779 llm_output: output.into(),
780 raw_output,
781 })
782 })
783 }
784}
785
786#[derive(Clone)]
787struct AgentResponseEventStream(
788 mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
789);
790
791impl AgentResponseEventStream {
792 fn send_text(&self, text: &str) {
793 self.0
794 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
795 .ok();
796 }
797
798 fn send_thinking(&self, text: &str) {
799 self.0
800 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
801 .ok();
802 }
803
804 fn send_tool_call(
805 &self,
806 id: &LanguageModelToolUseId,
807 title: SharedString,
808 kind: acp::ToolKind,
809 input: serde_json::Value,
810 ) {
811 self.0
812 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
813 id,
814 title.to_string(),
815 kind,
816 input,
817 ))))
818 .ok();
819 }
820
821 fn initial_tool_call(
822 id: &LanguageModelToolUseId,
823 title: String,
824 kind: acp::ToolKind,
825 input: serde_json::Value,
826 ) -> acp::ToolCall {
827 acp::ToolCall {
828 id: acp::ToolCallId(id.to_string().into()),
829 title,
830 kind,
831 status: acp::ToolCallStatus::Pending,
832 content: vec![],
833 locations: vec![],
834 raw_input: Some(input),
835 raw_output: None,
836 }
837 }
838
839 fn update_tool_call_fields(
840 &self,
841 tool_use_id: &LanguageModelToolUseId,
842 fields: acp::ToolCallUpdateFields,
843 ) {
844 self.0
845 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
846 acp::ToolCallUpdate {
847 id: acp::ToolCallId(tool_use_id.to_string().into()),
848 fields,
849 }
850 .into(),
851 )))
852 .ok();
853 }
854
855 fn send_stop(&self, reason: StopReason) {
856 match reason {
857 StopReason::EndTurn => {
858 self.0
859 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
860 .ok();
861 }
862 StopReason::MaxTokens => {
863 self.0
864 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
865 .ok();
866 }
867 StopReason::Refusal => {
868 self.0
869 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
870 .ok();
871 }
872 StopReason::ToolUse => {}
873 }
874 }
875
876 fn send_error(&self, error: LanguageModelCompletionError) {
877 self.0.unbounded_send(Err(error)).ok();
878 }
879}
880
881#[derive(Clone)]
882pub struct ToolCallEventStream {
883 tool_use_id: LanguageModelToolUseId,
884 kind: acp::ToolKind,
885 input: serde_json::Value,
886 stream: AgentResponseEventStream,
887}
888
889impl ToolCallEventStream {
890 #[cfg(test)]
891 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
892 let (events_tx, events_rx) =
893 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
894
895 let stream = ToolCallEventStream::new(
896 &LanguageModelToolUse {
897 id: "test_id".into(),
898 name: "test_tool".into(),
899 raw_input: String::new(),
900 input: serde_json::Value::Null,
901 is_input_complete: true,
902 },
903 acp::ToolKind::Other,
904 AgentResponseEventStream(events_tx),
905 );
906
907 (stream, ToolCallEventStreamReceiver(events_rx))
908 }
909
910 fn new(
911 tool_use: &LanguageModelToolUse,
912 kind: acp::ToolKind,
913 stream: AgentResponseEventStream,
914 ) -> Self {
915 Self {
916 tool_use_id: tool_use.id.clone(),
917 kind,
918 input: tool_use.input.clone(),
919 stream,
920 }
921 }
922
923 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
924 self.stream
925 .update_tool_call_fields(&self.tool_use_id, fields);
926 }
927
928 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
929 self.stream
930 .0
931 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
932 acp_thread::ToolCallUpdateDiff {
933 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
934 diff,
935 }
936 .into(),
937 )))
938 .ok();
939 }
940
941 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
942 self.stream
943 .0
944 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
945 acp_thread::ToolCallUpdateTerminal {
946 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
947 terminal,
948 }
949 .into(),
950 )))
951 .ok();
952 }
953
954 pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
955 let (response_tx, response_rx) = oneshot::channel();
956 self.stream
957 .0
958 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
959 ToolCallAuthorization {
960 tool_call: AgentResponseEventStream::initial_tool_call(
961 &self.tool_use_id,
962 title,
963 self.kind.clone(),
964 self.input.clone(),
965 ),
966 options: vec![
967 acp::PermissionOption {
968 id: acp::PermissionOptionId("always_allow".into()),
969 name: "Always Allow".into(),
970 kind: acp::PermissionOptionKind::AllowAlways,
971 },
972 acp::PermissionOption {
973 id: acp::PermissionOptionId("allow".into()),
974 name: "Allow".into(),
975 kind: acp::PermissionOptionKind::AllowOnce,
976 },
977 acp::PermissionOption {
978 id: acp::PermissionOptionId("deny".into()),
979 name: "Deny".into(),
980 kind: acp::PermissionOptionKind::RejectOnce,
981 },
982 ],
983 response: response_tx,
984 },
985 )))
986 .ok();
987 async move {
988 match response_rx.await?.0.as_ref() {
989 "allow" | "always_allow" => Ok(()),
990 _ => Err(anyhow!("Permission to run tool denied by user")),
991 }
992 }
993 }
994}
995
/// Test-only receiving end of a [`ToolCallEventStream`] created by
/// [`ToolCallEventStream::test`].
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);

#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as an authorization request.
    ///
    /// Panics if the next event is anything else.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(authorization))) => authorization,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal it attaches.
    ///
    /// Panics if the next event is not a terminal update.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(terminal_update),
            ))) => terminal_update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}

#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}

#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}