1use crate::templates::{SystemPromptTemplate, Template, Templates};
2use agent_client_protocol as acp;
3use anyhow::{anyhow, Context as _, Result};
4use assistant_tool::{adapt_schema_to_format, ActionLog};
5use cloud_llm_client::{CompletionIntent, CompletionMode};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::FuturesUnordered,
10};
11use gpui::{App, Context, Entity, SharedString, Task};
12use language_model::{
13 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
14 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
15 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
16 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
17};
18use log;
19use project::Project;
20use prompt_store::ProjectContext;
21use schemars::{JsonSchema, Schema};
22use serde::{Deserialize, Serialize};
23use smol::stream::StreamExt;
24use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
25use util::{markdown::MarkdownCodeBlock, ResultExt};
26
/// A single message in a `Thread`'s history.
///
/// `Thread` only stores user and assistant messages here; the system prompt is rebuilt for
/// every request by `Thread::build_system_message` instead of being kept in the history.
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// Who authored the message.
    pub role: Role,
    /// The message's content items (text, thinking, images, tool uses, tool results), in order.
    pub content: Vec<MessageContent>,
}
32
33impl AgentMessage {
34 pub fn to_markdown(&self) -> String {
35 let mut markdown = format!("## {}\n", self.role);
36
37 for content in &self.content {
38 match content {
39 MessageContent::Text(text) => {
40 markdown.push_str(text);
41 markdown.push('\n');
42 }
43 MessageContent::Thinking { text, .. } => {
44 markdown.push_str("<think>");
45 markdown.push_str(text);
46 markdown.push_str("</think>\n");
47 }
48 MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
49 MessageContent::Image(_) => {
50 markdown.push_str("<image />\n");
51 }
52 MessageContent::ToolUse(tool_use) => {
53 markdown.push_str(&format!(
54 "**Tool Use**: {} (ID: {})\n",
55 tool_use.name, tool_use.id
56 ));
57 markdown.push_str(&format!(
58 "{}\n",
59 MarkdownCodeBlock {
60 tag: "json",
61 text: &format!("{:#}", tool_use.input)
62 }
63 ));
64 }
65 MessageContent::ToolResult(tool_result) => {
66 markdown.push_str(&format!(
67 "**Tool Result**: {} (ID: {})\n\n",
68 tool_result.tool_name, tool_result.tool_use_id
69 ));
70 if tool_result.is_error {
71 markdown.push_str("**ERROR:**\n");
72 }
73
74 match &tool_result.content {
75 LanguageModelToolResultContent::Text(text) => {
76 writeln!(markdown, "{text}\n").ok();
77 }
78 LanguageModelToolResultContent::Image(_) => {
79 writeln!(markdown, "<image />\n").ok();
80 }
81 }
82
83 if let Some(output) = tool_result.output.as_ref() {
84 writeln!(
85 markdown,
86 "**Debug Output**:\n\n```json\n{}\n```\n",
87 serde_json::to_string_pretty(output).unwrap()
88 )
89 .unwrap();
90 }
91 }
92 }
93 }
94
95 markdown
96 }
97}
98
/// An event emitted on the channel returned by `Thread::send` while a turn runs.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of streamed assistant text.
    Text(String),
    /// A chunk of streamed assistant thinking.
    Thinking(String),
    /// A new tool call requested by the model, initially in the `Pending` status.
    ToolCall(acp::ToolCall),
    /// An update to a previously reported tool call (more input, or a status change).
    ToolCallUpdate(acp::ToolCallUpdate),
    /// A request for the user to authorize a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped. Stops caused by tool use are not reported.
    Stop(acp::StopReason),
}
108
/// A pending request for the user to allow or deny a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting authorization.
    pub tool_call: acp::ToolCall,
    /// The choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Receives the id of the option the user picked. `"allow"` and `"always_allow"` approve
    /// the call; any other id, or dropping this sender, denies it.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
115
/// A conversation between the user and a language model, including the tools the model may
/// call and the state of the turn currently being processed.
pub struct Thread {
    /// Conversation history (user and assistant messages; the system prompt is built per request).
    messages: Vec<AgentMessage>,
    /// Completion mode attached to every request built by this thread.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Tool uses received from the model that do not have a result yet. Entries are removed
    /// when their result arrives, and `cancel` turns any leftovers into error results.
    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
    /// Registered tools keyed by name; a `BTreeMap`, so they are iterated in name order.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// The model whose tool input format is used when adapting tool schemas for requests.
    pub selected_model: Arc<dyn LanguageModel>,
    /// Action log this thread was created with, exposed via `action_log()`.
    action_log: Entity<ActionLog>,
}
130
131impl Thread {
132 pub fn new(
133 _project: Entity<Project>,
134 project_context: Rc<RefCell<ProjectContext>>,
135 action_log: Entity<ActionLog>,
136 templates: Arc<Templates>,
137 default_model: Arc<dyn LanguageModel>,
138 ) -> Self {
139 Self {
140 messages: Vec::new(),
141 completion_mode: CompletionMode::Normal,
142 running_turn: None,
143 pending_tool_uses: HashMap::default(),
144 tools: BTreeMap::default(),
145 project_context,
146 templates,
147 selected_model: default_model,
148 action_log,
149 }
150 }
151
152 pub fn set_mode(&mut self, mode: CompletionMode) {
153 self.completion_mode = mode;
154 }
155
156 pub fn messages(&self) -> &[AgentMessage] {
157 &self.messages
158 }
159
160 pub fn add_tool(&mut self, tool: impl AgentTool) {
161 self.tools.insert(tool.name(), tool.erase());
162 }
163
164 pub fn remove_tool(&mut self, name: &str) -> bool {
165 self.tools.remove(name).is_some()
166 }
167
168 pub fn cancel(&mut self) {
169 self.running_turn.take();
170
171 let tool_results = self
172 .pending_tool_uses
173 .drain()
174 .map(|(tool_use_id, tool_use)| {
175 MessageContent::ToolResult(LanguageModelToolResult {
176 tool_use_id,
177 tool_name: tool_use.name.clone(),
178 is_error: true,
179 content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
180 output: None,
181 })
182 })
183 .collect::<Vec<_>>();
184 self.last_user_message().content.extend(tool_results);
185 }
186
187 /// Sending a message results in the model streaming a response, which could include tool calls.
188 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
189 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
190 pub fn send(
191 &mut self,
192 model: Arc<dyn LanguageModel>,
193 content: impl Into<MessageContent>,
194 cx: &mut Context<Self>,
195 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
196 let content = content.into();
197 log::info!("Thread::send called with model: {:?}", model.name());
198 log::debug!("Thread::send content: {:?}", content);
199
200 cx.notify();
201 let (events_tx, events_rx) =
202 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
203 let event_stream = AgentResponseEventStream(events_tx);
204
205 let user_message_ix = self.messages.len();
206 self.messages.push(AgentMessage {
207 role: Role::User,
208 content: vec![content],
209 });
210 log::info!("Total messages in thread: {}", self.messages.len());
211 self.running_turn = Some(cx.spawn(async move |thread, cx| {
212 log::info!("Starting agent turn execution");
213 let turn_result = async {
214 // Perform one request, then keep looping if the model makes tool calls.
215 let mut completion_intent = CompletionIntent::UserPrompt;
216 'outer: loop {
217 log::debug!(
218 "Building completion request with intent: {:?}",
219 completion_intent
220 );
221 let request = thread.update(cx, |thread, cx| {
222 thread.build_completion_request(completion_intent, cx)
223 })?;
224
225 // println!(
226 // "request: {}",
227 // serde_json::to_string_pretty(&request).unwrap()
228 // );
229
230 // Stream events, appending to messages and collecting up tool uses.
231 log::info!("Calling model.stream_completion");
232 let mut events = model.stream_completion(request, cx).await?;
233 log::debug!("Stream completion started successfully");
234 let mut tool_uses = FuturesUnordered::new();
235 while let Some(event) = events.next().await {
236 match event {
237 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
238 event_stream.send_stop(reason);
239 if reason == StopReason::Refusal {
240 thread.update(cx, |thread, _cx| {
241 thread.messages.truncate(user_message_ix);
242 })?;
243 break 'outer;
244 }
245 }
246 Ok(event) => {
247 log::trace!("Received completion event: {:?}", event);
248 thread
249 .update(cx, |thread, cx| {
250 tool_uses.extend(thread.handle_streamed_completion_event(
251 event,
252 &event_stream,
253 cx,
254 ));
255 })
256 .ok();
257 }
258 Err(error) => {
259 log::error!("Error in completion stream: {:?}", error);
260 event_stream.send_error(error);
261 break;
262 }
263 }
264 }
265
266 // If there are no tool uses, the turn is done.
267 if tool_uses.is_empty() {
268 log::info!("No tool uses found, completing turn");
269 break;
270 }
271 log::info!("Found {} tool uses to execute", tool_uses.len());
272
273 // As tool results trickle in, insert them in the last user
274 // message so that they can be sent on the next tick of the
275 // agentic loop.
276 while let Some(tool_result) = tool_uses.next().await {
277 log::info!("Tool finished {:?}", tool_result);
278
279 event_stream.send_tool_call_update(
280 &tool_result.tool_use_id,
281 acp::ToolCallUpdateFields {
282 status: Some(if tool_result.is_error {
283 acp::ToolCallStatus::Failed
284 } else {
285 acp::ToolCallStatus::Completed
286 }),
287 ..Default::default()
288 },
289 );
290 thread
291 .update(cx, |thread, _cx| {
292 thread.pending_tool_uses.remove(&tool_result.tool_use_id);
293 thread
294 .last_user_message()
295 .content
296 .push(MessageContent::ToolResult(tool_result));
297 })
298 .ok();
299 }
300
301 completion_intent = CompletionIntent::ToolResults;
302 }
303
304 Ok(())
305 }
306 .await;
307
308 if let Err(error) = turn_result {
309 log::error!("Turn execution failed: {:?}", error);
310 event_stream.send_error(error);
311 } else {
312 log::info!("Turn execution completed successfully");
313 }
314 }));
315 events_rx
316 }
317
318 pub fn action_log(&self) -> &Entity<ActionLog> {
319 &self.action_log
320 }
321
322 pub fn build_system_message(&self) -> AgentMessage {
323 log::debug!("Building system message");
324 let prompt = SystemPromptTemplate {
325 project: &self.project_context.borrow(),
326 available_tools: self.tools.keys().cloned().collect(),
327 }
328 .render(&self.templates)
329 .context("failed to build system prompt")
330 .expect("Invalid template");
331 log::debug!("System message built");
332 AgentMessage {
333 role: Role::System,
334 content: vec![prompt.into()],
335 }
336 }
337
338 /// A helper method that's called on every streamed completion event.
339 /// Returns an optional tool result task, which the main agentic loop in
340 /// send will send back to the model when it resolves.
341 fn handle_streamed_completion_event(
342 &mut self,
343 event: LanguageModelCompletionEvent,
344 event_stream: &AgentResponseEventStream,
345 cx: &mut Context<Self>,
346 ) -> Option<Task<LanguageModelToolResult>> {
347 log::trace!("Handling streamed completion event: {:?}", event);
348 use LanguageModelCompletionEvent::*;
349
350 match event {
351 StartMessage { .. } => {
352 self.messages.push(AgentMessage {
353 role: Role::Assistant,
354 content: Vec::new(),
355 });
356 }
357 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
358 Thinking { text, signature } => {
359 self.handle_thinking_event(text, signature, event_stream, cx)
360 }
361 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
362 ToolUse(tool_use) => {
363 return self.handle_tool_use_event(tool_use, event_stream, cx);
364 }
365 ToolUseJsonParseError {
366 id,
367 tool_name,
368 raw_input,
369 json_parse_error,
370 } => {
371 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
372 id,
373 tool_name,
374 raw_input,
375 json_parse_error,
376 )));
377 }
378 UsageUpdate(_) | StatusUpdate(_) => {}
379 Stop(_) => unreachable!(),
380 }
381
382 None
383 }
384
385 fn handle_text_event(
386 &mut self,
387 new_text: String,
388 events_stream: &AgentResponseEventStream,
389 cx: &mut Context<Self>,
390 ) {
391 events_stream.send_text(&new_text);
392
393 let last_message = self.last_assistant_message();
394 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
395 text.push_str(&new_text);
396 } else {
397 last_message.content.push(MessageContent::Text(new_text));
398 }
399
400 cx.notify();
401 }
402
403 fn handle_thinking_event(
404 &mut self,
405 new_text: String,
406 new_signature: Option<String>,
407 event_stream: &AgentResponseEventStream,
408 cx: &mut Context<Self>,
409 ) {
410 event_stream.send_thinking(&new_text);
411
412 let last_message = self.last_assistant_message();
413 if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
414 {
415 text.push_str(&new_text);
416 *signature = new_signature.or(signature.take());
417 } else {
418 last_message.content.push(MessageContent::Thinking {
419 text: new_text,
420 signature: new_signature,
421 });
422 }
423
424 cx.notify();
425 }
426
427 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
428 let last_message = self.last_assistant_message();
429 last_message
430 .content
431 .push(MessageContent::RedactedThinking(data));
432 cx.notify();
433 }
434
435 fn handle_tool_use_event(
436 &mut self,
437 tool_use: LanguageModelToolUse,
438 event_stream: &AgentResponseEventStream,
439 cx: &mut Context<Self>,
440 ) -> Option<Task<LanguageModelToolResult>> {
441 cx.notify();
442
443 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
444
445 self.pending_tool_uses
446 .insert(tool_use.id.clone(), tool_use.clone());
447 let last_message = self.last_assistant_message();
448
449 // Ensure the last message ends in the current tool use
450 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
451 if let MessageContent::ToolUse(last_tool_use) = content {
452 if last_tool_use.id == tool_use.id {
453 *last_tool_use = tool_use.clone();
454 false
455 } else {
456 true
457 }
458 } else {
459 true
460 }
461 });
462
463 if push_new_tool_use {
464 event_stream.send_tool_call(tool.as_ref(), &tool_use);
465 last_message
466 .content
467 .push(MessageContent::ToolUse(tool_use.clone()));
468 } else {
469 event_stream.send_tool_call_update(
470 &tool_use.id,
471 acp::ToolCallUpdateFields {
472 raw_input: Some(tool_use.input.clone()),
473 ..Default::default()
474 },
475 );
476 }
477
478 if !tool_use.is_input_complete {
479 return None;
480 }
481
482 let Some(tool) = tool else {
483 let content = format!("No tool named {} exists", tool_use.name);
484 return Some(Task::ready(LanguageModelToolResult {
485 content: LanguageModelToolResultContent::Text(Arc::from(content)),
486 tool_use_id: tool_use.id,
487 tool_name: tool_use.name,
488 is_error: true,
489 output: None,
490 }));
491 };
492
493 let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
494 Some(cx.foreground_executor().spawn(async move {
495 match tool_result.await {
496 Ok(tool_output) => LanguageModelToolResult {
497 tool_use_id: tool_use.id,
498 tool_name: tool_use.name,
499 is_error: false,
500 content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
501 output: None,
502 },
503 Err(error) => LanguageModelToolResult {
504 tool_use_id: tool_use.id,
505 tool_name: tool_use.name,
506 is_error: true,
507 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
508 output: None,
509 },
510 }
511 }))
512 }
513
514 fn run_tool(
515 &self,
516 tool: Arc<dyn AnyAgentTool>,
517 tool_use: LanguageModelToolUse,
518 event_stream: AgentResponseEventStream,
519 cx: &mut Context<Self>,
520 ) -> Task<Result<String>> {
521 cx.spawn(async move |_this, cx| {
522 let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
523 tool_event_stream.send_update(acp::ToolCallUpdateFields {
524 status: Some(acp::ToolCallStatus::InProgress),
525 ..Default::default()
526 });
527 cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
528 .await
529 })
530 }
531
532 fn handle_tool_use_json_parse_error_event(
533 &mut self,
534 tool_use_id: LanguageModelToolUseId,
535 tool_name: Arc<str>,
536 raw_input: Arc<str>,
537 json_parse_error: String,
538 ) -> LanguageModelToolResult {
539 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
540 LanguageModelToolResult {
541 tool_use_id,
542 tool_name,
543 is_error: true,
544 content: LanguageModelToolResultContent::Text(tool_output.into()),
545 output: Some(serde_json::Value::String(raw_input.to_string())),
546 }
547 }
548
549 /// Guarantees the last message is from the assistant and returns a mutable reference.
550 fn last_assistant_message(&mut self) -> &mut AgentMessage {
551 if self
552 .messages
553 .last()
554 .map_or(true, |m| m.role != Role::Assistant)
555 {
556 self.messages.push(AgentMessage {
557 role: Role::Assistant,
558 content: Vec::new(),
559 });
560 }
561 self.messages.last_mut().unwrap()
562 }
563
564 /// Guarantees the last message is from the user and returns a mutable reference.
565 fn last_user_message(&mut self) -> &mut AgentMessage {
566 if self.messages.last().map_or(true, |m| m.role != Role::User) {
567 self.messages.push(AgentMessage {
568 role: Role::User,
569 content: Vec::new(),
570 });
571 }
572 self.messages.last_mut().unwrap()
573 }
574
575 fn build_completion_request(
576 &self,
577 completion_intent: CompletionIntent,
578 cx: &mut App,
579 ) -> LanguageModelRequest {
580 log::debug!("Building completion request");
581 log::debug!("Completion intent: {:?}", completion_intent);
582 log::debug!("Completion mode: {:?}", self.completion_mode);
583
584 let messages = self.build_request_messages();
585 log::info!("Request will include {} messages", messages.len());
586
587 let tools: Vec<LanguageModelRequestTool> = self
588 .tools
589 .values()
590 .filter_map(|tool| {
591 let tool_name = tool.name().to_string();
592 log::trace!("Including tool: {}", tool_name);
593 Some(LanguageModelRequestTool {
594 name: tool_name,
595 description: tool.description(cx).to_string(),
596 input_schema: tool
597 .input_schema(self.selected_model.tool_input_format())
598 .log_err()?,
599 })
600 })
601 .collect();
602
603 log::info!("Request includes {} tools", tools.len());
604
605 let request = LanguageModelRequest {
606 thread_id: None,
607 prompt_id: None,
608 intent: Some(completion_intent),
609 mode: Some(self.completion_mode),
610 messages,
611 tools,
612 tool_choice: None,
613 stop: Vec::new(),
614 temperature: None,
615 thinking_allowed: true,
616 };
617
618 log::debug!("Completion request built successfully");
619 request
620 }
621
622 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
623 log::trace!(
624 "Building request messages from {} thread messages",
625 self.messages.len()
626 );
627
628 let messages = Some(self.build_system_message())
629 .iter()
630 .chain(self.messages.iter())
631 .map(|message| {
632 log::trace!(
633 " - {} message with {} content items",
634 match message.role {
635 Role::System => "System",
636 Role::User => "User",
637 Role::Assistant => "Assistant",
638 },
639 message.content.len()
640 );
641 LanguageModelRequestMessage {
642 role: message.role,
643 content: message.content.clone(),
644 cache: false,
645 }
646 })
647 .collect();
648 messages
649 }
650
651 pub fn to_markdown(&self) -> String {
652 let mut markdown = String::new();
653 for message in &self.messages {
654 markdown.push_str(&message.to_markdown());
655 }
656 markdown
657 }
658}
659
660pub trait AgentTool
661where
662 Self: 'static + Sized,
663{
664 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
665
666 fn name(&self) -> SharedString;
667
668 fn description(&self, _cx: &mut App) -> SharedString {
669 let schema = schemars::schema_for!(Self::Input);
670 SharedString::new(
671 schema
672 .get("description")
673 .and_then(|description| description.as_str())
674 .unwrap_or_default(),
675 )
676 }
677
678 fn kind(&self) -> acp::ToolKind;
679
680 /// The initial tool title to display. Can be updated during the tool run.
681 fn initial_title(&self, input: Self::Input) -> SharedString;
682
683 /// Returns the JSON schema that describes the tool's input.
684 fn input_schema(&self) -> Schema {
685 schemars::schema_for!(Self::Input)
686 }
687
688 /// Allows the tool to authorize a given tool call with the user if necessary
689 fn authorize(
690 &self,
691 input: Self::Input,
692 event_stream: ToolCallEventStream,
693 ) -> impl use<Self> + Future<Output = Result<()>> {
694 let json_input = serde_json::json!(&input);
695 event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
696 }
697
698 /// Runs the tool with the provided input.
699 fn run(
700 self: Arc<Self>,
701 input: Self::Input,
702 event_stream: ToolCallEventStream,
703 cx: &mut App,
704 ) -> Task<Result<String>>;
705
706 fn erase(self) -> Arc<dyn AnyAgentTool> {
707 Arc::new(Erased(Arc::new(self)))
708 }
709}
710
/// Adapter that exposes an [`AgentTool`] (held as `Arc<T>`) through the type-erased
/// [`AnyAgentTool`] interface; created by [`AgentTool::erase`].
pub struct Erased<T>(T);
712
/// Type-erased counterpart of [`AgentTool`] that works on raw JSON input, so tools with
/// different input types can be stored together as `Arc<dyn AnyAgentTool>`.
///
/// In the `Erased` implementation, methods taking a `serde_json::Value` first deserialize it
/// into the tool's `Input` type and fail if that does not succeed.
pub trait AnyAgentTool {
    /// The name under which the tool is registered and exposed to the model.
    fn name(&self) -> SharedString;
    /// A description of the tool for the model.
    fn description(&self, cx: &mut App) -> SharedString;
    /// The kind of tool call reported to the client.
    fn kind(&self) -> acp::ToolKind;
    /// The initial title for a call with the given raw input.
    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
    /// The tool's input schema as JSON, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool with the given raw input, resolving to its output text.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<String>>;
}
726
727impl<T> AnyAgentTool for Erased<Arc<T>>
728where
729 T: AgentTool,
730{
731 fn name(&self) -> SharedString {
732 self.0.name()
733 }
734
735 fn description(&self, cx: &mut App) -> SharedString {
736 self.0.description(cx)
737 }
738
739 fn kind(&self) -> agent_client_protocol::ToolKind {
740 self.0.kind()
741 }
742
743 fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
744 let parsed_input = serde_json::from_value(input)?;
745 Ok(self.0.initial_title(parsed_input))
746 }
747
748 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
749 let mut json = serde_json::to_value(self.0.input_schema())?;
750 adapt_schema_to_format(&mut json, format)?;
751 Ok(json)
752 }
753
754 fn run(
755 self: Arc<Self>,
756 input: serde_json::Value,
757 event_stream: ToolCallEventStream,
758 cx: &mut App,
759 ) -> Task<Result<String>> {
760 let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
761 match parsed_input {
762 Ok(input) => self.0.clone().run(input, event_stream, cx),
763 Err(error) => Task::ready(Err(anyhow!(error))),
764 }
765 }
766}
767
/// Sending half of the channel returned by `Thread::send`.
///
/// All send helpers ignore failures, so events are silently dropped once the receiver is gone.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
772
773impl AgentResponseEventStream {
774 fn send_text(&self, text: &str) {
775 self.0
776 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
777 .ok();
778 }
779
780 fn send_thinking(&self, text: &str) {
781 self.0
782 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
783 .ok();
784 }
785
786 fn authorize_tool_call(
787 &self,
788 id: &LanguageModelToolUseId,
789 title: String,
790 kind: acp::ToolKind,
791 input: serde_json::Value,
792 ) -> impl use<> + Future<Output = Result<()>> {
793 let (response_tx, response_rx) = oneshot::channel();
794 self.0
795 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
796 ToolCallAuthorization {
797 tool_call: Self::initial_tool_call(id, title, kind, input),
798 options: vec![
799 acp::PermissionOption {
800 id: acp::PermissionOptionId("always_allow".into()),
801 name: "Always Allow".into(),
802 kind: acp::PermissionOptionKind::AllowAlways,
803 },
804 acp::PermissionOption {
805 id: acp::PermissionOptionId("allow".into()),
806 name: "Allow".into(),
807 kind: acp::PermissionOptionKind::AllowOnce,
808 },
809 acp::PermissionOption {
810 id: acp::PermissionOptionId("deny".into()),
811 name: "Deny".into(),
812 kind: acp::PermissionOptionKind::RejectOnce,
813 },
814 ],
815 response: response_tx,
816 },
817 )))
818 .ok();
819 async move {
820 match response_rx.await?.0.as_ref() {
821 "allow" | "always_allow" => Ok(()),
822 _ => Err(anyhow!("Permission to run tool denied by user")),
823 }
824 }
825 }
826
827 fn send_tool_call(
828 &self,
829 tool: Option<&Arc<dyn AnyAgentTool>>,
830 tool_use: &LanguageModelToolUse,
831 ) {
832 self.0
833 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
834 &tool_use.id,
835 tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
836 .map(|i| i.into())
837 .unwrap_or_else(|| tool_use.name.to_string()),
838 tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
839 tool_use.input.clone(),
840 ))))
841 .ok();
842 }
843
844 fn initial_tool_call(
845 id: &LanguageModelToolUseId,
846 title: String,
847 kind: acp::ToolKind,
848 input: serde_json::Value,
849 ) -> acp::ToolCall {
850 acp::ToolCall {
851 id: acp::ToolCallId(id.to_string().into()),
852 title,
853 kind,
854 status: acp::ToolCallStatus::Pending,
855 content: vec![],
856 locations: vec![],
857 raw_input: Some(input),
858 raw_output: None,
859 }
860 }
861
862 fn send_tool_call_update(
863 &self,
864 tool_use_id: &LanguageModelToolUseId,
865 fields: acp::ToolCallUpdateFields,
866 ) {
867 self.0
868 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
869 acp::ToolCallUpdate {
870 id: acp::ToolCallId(tool_use_id.to_string().into()),
871 fields,
872 },
873 )))
874 .ok();
875 }
876
877 fn send_stop(&self, reason: StopReason) {
878 match reason {
879 StopReason::EndTurn => {
880 self.0
881 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
882 .ok();
883 }
884 StopReason::MaxTokens => {
885 self.0
886 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
887 .ok();
888 }
889 StopReason::Refusal => {
890 self.0
891 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
892 .ok();
893 }
894 StopReason::ToolUse => {}
895 }
896 }
897
898 fn send_error(&self, error: LanguageModelCompletionError) {
899 self.0.unbounded_send(Err(error)).ok();
900 }
901}
902
/// Event stream scoped to a single tool call.
///
/// Tools receive one in `run`, so they can report updates and request authorization for
/// their own call.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use all events refer to.
    tool_use_id: LanguageModelToolUseId,
    /// Underlying thread-wide event stream.
    stream: AgentResponseEventStream,
}
908
909impl ToolCallEventStream {
910 fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
911 Self {
912 tool_use_id,
913 stream,
914 }
915 }
916
917 pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
918 self.stream.send_tool_call_update(&self.tool_use_id, fields);
919 }
920
921 pub fn authorize(
922 &self,
923 title: String,
924 kind: acp::ToolKind,
925 input: serde_json::Value,
926 ) -> impl use<> + Future<Output = Result<()>> {
927 self.stream
928 .authorize_tool_call(&self.tool_use_id, title, kind, input)
929 }
930}
931
/// Test helper providing a `ToolCallEventStream` whose channel stays open.
///
/// The receiving end is held but never read.
#[cfg(test)]
pub struct TestToolCallEventStream {
    /// Stream handed to tools under test.
    stream: ToolCallEventStream,
    /// Kept alive only so the channel is not closed.
    _events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}
937
#[cfg(test)]
impl TestToolCallEventStream {
    /// Creates a stream for a tool use with id `"test"`, keeping the receiver alive.
    pub fn new() -> Self {
        let (events_tx, _events_rx) =
            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
        let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
        Self { stream, _events_rx }
    }

    /// Returns a clone of the wrapped stream to hand to a tool.
    pub fn stream(&self) -> ToolCallEventStream {
        ToolCallEventStream::clone(&self.stream)
    }
}