thread.rs

  1use crate::templates::{SystemPromptTemplate, Template, Templates};
  2use agent_client_protocol as acp;
  3use anyhow::{anyhow, Context as _, Result};
  4use assistant_tool::{adapt_schema_to_format, ActionLog};
  5use cloud_llm_client::{CompletionIntent, CompletionMode};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::FuturesUnordered,
 10};
 11use gpui::{App, Context, Entity, SharedString, Task};
 12use language_model::{
 13    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 14    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
 15    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
 16    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
 17};
 18use log;
 19use project::Project;
 20use prompt_store::ProjectContext;
 21use schemars::{JsonSchema, Schema};
 22use serde::{Deserialize, Serialize};
 23use smol::stream::StreamExt;
 24use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
 25use util::{markdown::MarkdownCodeBlock, ResultExt};
 26
/// A single message in a thread's conversation history.
///
/// Assistant messages are built up incrementally while the model streams a
/// response, and every message is replayed (cloned) into each completion request.
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// Who authored the message (system, user, or assistant).
    pub role: Role,
    /// Ordered content items: text, thinking, redacted thinking, images,
    /// tool uses, and tool results.
    pub content: Vec<MessageContent>,
}
 32
 33impl AgentMessage {
 34    pub fn to_markdown(&self) -> String {
 35        let mut markdown = format!("## {}\n", self.role);
 36
 37        for content in &self.content {
 38            match content {
 39                MessageContent::Text(text) => {
 40                    markdown.push_str(text);
 41                    markdown.push('\n');
 42                }
 43                MessageContent::Thinking { text, .. } => {
 44                    markdown.push_str("<think>");
 45                    markdown.push_str(text);
 46                    markdown.push_str("</think>\n");
 47                }
 48                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
 49                MessageContent::Image(_) => {
 50                    markdown.push_str("<image />\n");
 51                }
 52                MessageContent::ToolUse(tool_use) => {
 53                    markdown.push_str(&format!(
 54                        "**Tool Use**: {} (ID: {})\n",
 55                        tool_use.name, tool_use.id
 56                    ));
 57                    markdown.push_str(&format!(
 58                        "{}\n",
 59                        MarkdownCodeBlock {
 60                            tag: "json",
 61                            text: &format!("{:#}", tool_use.input)
 62                        }
 63                    ));
 64                }
 65                MessageContent::ToolResult(tool_result) => {
 66                    markdown.push_str(&format!(
 67                        "**Tool Result**: {} (ID: {})\n\n",
 68                        tool_result.tool_name, tool_result.tool_use_id
 69                    ));
 70                    if tool_result.is_error {
 71                        markdown.push_str("**ERROR:**\n");
 72                    }
 73
 74                    match &tool_result.content {
 75                        LanguageModelToolResultContent::Text(text) => {
 76                            writeln!(markdown, "{text}\n").ok();
 77                        }
 78                        LanguageModelToolResultContent::Image(_) => {
 79                            writeln!(markdown, "<image />\n").ok();
 80                        }
 81                    }
 82
 83                    if let Some(output) = tool_result.output.as_ref() {
 84                        writeln!(
 85                            markdown,
 86                            "**Debug Output**:\n\n```json\n{}\n```\n",
 87                            serde_json::to_string_pretty(output).unwrap()
 88                        )
 89                        .unwrap();
 90                    }
 91                }
 92            }
 93        }
 94
 95        markdown
 96    }
 97}
 98
/// An event emitted on the channel returned by `Thread::send` while a turn runs.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of assistant text, forwarded as it streams in.
    Text(String),
    /// A chunk of the model's thinking output, forwarded as it streams in.
    Thinking(String),
    /// The model started a new tool call.
    ToolCall(acp::ToolCall),
    /// An update (raw input, status, …) to a previously announced tool call.
    ToolCallUpdate(acp::ToolCallUpdate),
    /// A tool call is waiting for the user to grant or deny permission.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model emitted a stop event with the given reason.
    Stop(acp::StopReason),
}
108
/// A request for the user to authorize a tool call.
///
/// The receiver answers by sending the ID of one of `options` back on `response`.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting permission.
    pub tool_call: acp::ToolCall,
    /// The choices presented to the user (always allow / allow / deny).
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the chosen option's ID must be sent. `"allow"` and
    /// `"always_allow"` count as approval; any other ID, or dropping the sender,
    /// makes the authorization fail.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
115
/// An agent conversation: the message history plus everything needed to run
/// the agentic loop (model, tools, prompt templates, and the in-flight turn).
pub struct Thread {
    /// Conversation history. The system prompt is not stored here; it is
    /// re-rendered for every request.
    messages: Vec<AgentMessage>,
    /// Completion mode forwarded with every request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Tool uses announced by the model that have no recorded result yet,
    /// keyed by tool-use ID. `cancel` turns any leftovers into error results.
    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
    /// Registered tools keyed by name. A `BTreeMap` iterates in name order, so
    /// the tool list sent to the model is deterministic.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Project details rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Model whose tool-schema format is used when building requests.
    pub selected_model: Arc<dyn LanguageModel>,
    /// Action log handle, exposed through `action_log()`.
    action_log: Entity<ActionLog>,
}
130
131impl Thread {
132    pub fn new(
133        _project: Entity<Project>,
134        project_context: Rc<RefCell<ProjectContext>>,
135        action_log: Entity<ActionLog>,
136        templates: Arc<Templates>,
137        default_model: Arc<dyn LanguageModel>,
138    ) -> Self {
139        Self {
140            messages: Vec::new(),
141            completion_mode: CompletionMode::Normal,
142            running_turn: None,
143            pending_tool_uses: HashMap::default(),
144            tools: BTreeMap::default(),
145            project_context,
146            templates,
147            selected_model: default_model,
148            action_log,
149        }
150    }
151
    /// Sets the completion mode used by every request built after this call.
    pub fn set_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
155
156    pub fn messages(&self) -> &[AgentMessage] {
157        &self.messages
158    }
159
160    pub fn add_tool(&mut self, tool: impl AgentTool) {
161        self.tools.insert(tool.name(), tool.erase());
162    }
163
164    pub fn remove_tool(&mut self, name: &str) -> bool {
165        self.tools.remove(name).is_some()
166    }
167
168    pub fn cancel(&mut self) {
169        self.running_turn.take();
170
171        let tool_results = self
172            .pending_tool_uses
173            .drain()
174            .map(|(tool_use_id, tool_use)| {
175                MessageContent::ToolResult(LanguageModelToolResult {
176                    tool_use_id,
177                    tool_name: tool_use.name.clone(),
178                    is_error: true,
179                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
180                    output: None,
181                })
182            })
183            .collect::<Vec<_>>();
184        self.last_user_message().content.extend(tool_results);
185    }
186
    /// Sends a user message and runs the agentic loop until the model ends its turn.
    ///
    /// Sending a message results in the model streaming a response, which could include
    /// tool calls. When the model stops after making tool calls, the tools are run, their
    /// results are appended to the trailing user message, and a follow-up request is made.
    /// This repeats until a response contains no tool uses (or the model refuses).
    ///
    /// The returned channel reports streamed text and thinking, tool-call activity, every
    /// `Stop` the model emits, and any error that interrupts or ends the turn.
    ///
    /// The spawned turn is stored in `running_turn`, so a later `send` (which overwrites it)
    /// or `cancel` (which takes it) drops any turn still in flight.
    pub fn send(
        &mut self,
        model: Arc<dyn LanguageModel>,
        content: impl Into<MessageContent>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
        let content = content.into();
        log::info!("Thread::send called with model: {:?}", model.name());
        log::debug!("Thread::send content: {:?}", content);

        cx.notify();
        let (events_tx, events_rx) =
            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
        let event_stream = AgentResponseEventStream(events_tx);

        // Remember where this turn's user message sits so a refusal can roll it back.
        let user_message_ix = self.messages.len();
        self.messages.push(AgentMessage {
            role: Role::User,
            content: vec![content],
        });
        log::info!("Total messages in thread: {}", self.messages.len());
        self.running_turn = Some(cx.spawn(async move |thread, cx| {
            log::info!("Starting agent turn execution");
            let turn_result = async {
                // Perform one request, then keep looping if the model makes tool calls.
                let mut completion_intent = CompletionIntent::UserPrompt;
                'outer: loop {
                    log::debug!(
                        "Building completion request with intent: {:?}",
                        completion_intent
                    );
                    // NOTE(review): `build_completion_request` adapts tool schemas to
                    // `self.selected_model`'s format, but the request is sent to `model`
                    // below; confirm the two are always the same model.
                    let request = thread.update(cx, |thread, cx| {
                        thread.build_completion_request(completion_intent, cx)
                    })?;

                    // Stream events, appending to messages and collecting up tool uses.
                    log::info!("Calling model.stream_completion");
                    let mut events = model.stream_completion(request, cx).await?;
                    log::debug!("Stream completion started successfully");
                    let mut tool_uses = FuturesUnordered::new();
                    while let Some(event) = events.next().await {
                        match event {
                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
                                event_stream.send_stop(reason);
                                // On refusal, roll back the whole exchange (including the
                                // user message that triggered it) and end the turn.
                                if reason == StopReason::Refusal {
                                    thread.update(cx, |thread, _cx| {
                                        thread.messages.truncate(user_message_ix);
                                    })?;
                                    break 'outer;
                                }
                            }
                            Ok(event) => {
                                log::trace!("Received completion event: {:?}", event);
                                // A failed update (e.g. the thread was released) is ignored.
                                thread
                                    .update(cx, |thread, cx| {
                                        tool_uses.extend(thread.handle_streamed_completion_event(
                                            event,
                                            &event_stream,
                                            cx,
                                        ));
                                    })
                                    .ok();
                            }
                            Err(error) => {
                                log::error!("Error in completion stream: {:?}", error);
                                event_stream.send_error(error);
                                // NOTE(review): this only stops reading the stream. Tool uses
                                // already collected are still awaited below and, if there are
                                // any, another request is issued; confirm that is intended.
                                break;
                            }
                        }
                    }

                    // If there are no tool uses, the turn is done.
                    if tool_uses.is_empty() {
                        log::info!("No tool uses found, completing turn");
                        break;
                    }
                    log::info!("Found {} tool uses to execute", tool_uses.len());

                    // As tool results trickle in, insert them in the last user
                    // message so that they can be sent on the next tick of the
                    // agentic loop.
                    while let Some(tool_result) = tool_uses.next().await {
                        log::info!("Tool finished {:?}", tool_result);

                        event_stream.send_tool_call_update(
                            &tool_result.tool_use_id,
                            acp::ToolCallUpdateFields {
                                status: Some(if tool_result.is_error {
                                    acp::ToolCallStatus::Failed
                                } else {
                                    acp::ToolCallStatus::Completed
                                }),
                                ..Default::default()
                            },
                        );
                        thread
                            .update(cx, |thread, _cx| {
                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
                                thread
                                    .last_user_message()
                                    .content
                                    .push(MessageContent::ToolResult(tool_result));
                            })
                            .ok();
                    }

                    completion_intent = CompletionIntent::ToolResults;
                }

                Ok(())
            }
            .await;

            if let Err(error) = turn_result {
                log::error!("Turn execution failed: {:?}", error);
                event_stream.send_error(error);
            } else {
                log::info!("Turn execution completed successfully");
            }
        }));
        events_rx
    }
317
    /// Returns the action log handle this thread was created with.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
321
322    pub fn build_system_message(&self) -> AgentMessage {
323        log::debug!("Building system message");
324        let prompt = SystemPromptTemplate {
325            project: &self.project_context.borrow(),
326            available_tools: self.tools.keys().cloned().collect(),
327        }
328        .render(&self.templates)
329        .context("failed to build system prompt")
330        .expect("Invalid template");
331        log::debug!("System message built");
332        AgentMessage {
333            role: Role::System,
334            content: vec![prompt.into()],
335        }
336    }
337
338    /// A helper method that's called on every streamed completion event.
339    /// Returns an optional tool result task, which the main agentic loop in
340    /// send will send back to the model when it resolves.
341    fn handle_streamed_completion_event(
342        &mut self,
343        event: LanguageModelCompletionEvent,
344        event_stream: &AgentResponseEventStream,
345        cx: &mut Context<Self>,
346    ) -> Option<Task<LanguageModelToolResult>> {
347        log::trace!("Handling streamed completion event: {:?}", event);
348        use LanguageModelCompletionEvent::*;
349
350        match event {
351            StartMessage { .. } => {
352                self.messages.push(AgentMessage {
353                    role: Role::Assistant,
354                    content: Vec::new(),
355                });
356            }
357            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
358            Thinking { text, signature } => {
359                self.handle_thinking_event(text, signature, event_stream, cx)
360            }
361            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
362            ToolUse(tool_use) => {
363                return self.handle_tool_use_event(tool_use, event_stream, cx);
364            }
365            ToolUseJsonParseError {
366                id,
367                tool_name,
368                raw_input,
369                json_parse_error,
370            } => {
371                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
372                    id,
373                    tool_name,
374                    raw_input,
375                    json_parse_error,
376                )));
377            }
378            UsageUpdate(_) | StatusUpdate(_) => {}
379            Stop(_) => unreachable!(),
380        }
381
382        None
383    }
384
385    fn handle_text_event(
386        &mut self,
387        new_text: String,
388        events_stream: &AgentResponseEventStream,
389        cx: &mut Context<Self>,
390    ) {
391        events_stream.send_text(&new_text);
392
393        let last_message = self.last_assistant_message();
394        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
395            text.push_str(&new_text);
396        } else {
397            last_message.content.push(MessageContent::Text(new_text));
398        }
399
400        cx.notify();
401    }
402
403    fn handle_thinking_event(
404        &mut self,
405        new_text: String,
406        new_signature: Option<String>,
407        event_stream: &AgentResponseEventStream,
408        cx: &mut Context<Self>,
409    ) {
410        event_stream.send_thinking(&new_text);
411
412        let last_message = self.last_assistant_message();
413        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
414        {
415            text.push_str(&new_text);
416            *signature = new_signature.or(signature.take());
417        } else {
418            last_message.content.push(MessageContent::Thinking {
419                text: new_text,
420                signature: new_signature,
421            });
422        }
423
424        cx.notify();
425    }
426
427    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
428        let last_message = self.last_assistant_message();
429        last_message
430            .content
431            .push(MessageContent::RedactedThinking(data));
432        cx.notify();
433    }
434
435    fn handle_tool_use_event(
436        &mut self,
437        tool_use: LanguageModelToolUse,
438        event_stream: &AgentResponseEventStream,
439        cx: &mut Context<Self>,
440    ) -> Option<Task<LanguageModelToolResult>> {
441        cx.notify();
442
443        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
444
445        self.pending_tool_uses
446            .insert(tool_use.id.clone(), tool_use.clone());
447        let last_message = self.last_assistant_message();
448
449        // Ensure the last message ends in the current tool use
450        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
451            if let MessageContent::ToolUse(last_tool_use) = content {
452                if last_tool_use.id == tool_use.id {
453                    *last_tool_use = tool_use.clone();
454                    false
455                } else {
456                    true
457                }
458            } else {
459                true
460            }
461        });
462
463        if push_new_tool_use {
464            event_stream.send_tool_call(tool.as_ref(), &tool_use);
465            last_message
466                .content
467                .push(MessageContent::ToolUse(tool_use.clone()));
468        } else {
469            event_stream.send_tool_call_update(
470                &tool_use.id,
471                acp::ToolCallUpdateFields {
472                    raw_input: Some(tool_use.input.clone()),
473                    ..Default::default()
474                },
475            );
476        }
477
478        if !tool_use.is_input_complete {
479            return None;
480        }
481
482        let Some(tool) = tool else {
483            let content = format!("No tool named {} exists", tool_use.name);
484            return Some(Task::ready(LanguageModelToolResult {
485                content: LanguageModelToolResultContent::Text(Arc::from(content)),
486                tool_use_id: tool_use.id,
487                tool_name: tool_use.name,
488                is_error: true,
489                output: None,
490            }));
491        };
492
493        let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
494        Some(cx.foreground_executor().spawn(async move {
495            match tool_result.await {
496                Ok(tool_output) => LanguageModelToolResult {
497                    tool_use_id: tool_use.id,
498                    tool_name: tool_use.name,
499                    is_error: false,
500                    content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
501                    output: None,
502                },
503                Err(error) => LanguageModelToolResult {
504                    tool_use_id: tool_use.id,
505                    tool_name: tool_use.name,
506                    is_error: true,
507                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
508                    output: None,
509                },
510            }
511        }))
512    }
513
514    fn run_tool(
515        &self,
516        tool: Arc<dyn AnyAgentTool>,
517        tool_use: LanguageModelToolUse,
518        event_stream: AgentResponseEventStream,
519        cx: &mut Context<Self>,
520    ) -> Task<Result<String>> {
521        cx.spawn(async move |_this, cx| {
522            let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
523            tool_event_stream.send_update(acp::ToolCallUpdateFields {
524                status: Some(acp::ToolCallStatus::InProgress),
525                ..Default::default()
526            });
527            cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
528                .await
529        })
530    }
531
532    fn handle_tool_use_json_parse_error_event(
533        &mut self,
534        tool_use_id: LanguageModelToolUseId,
535        tool_name: Arc<str>,
536        raw_input: Arc<str>,
537        json_parse_error: String,
538    ) -> LanguageModelToolResult {
539        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
540        LanguageModelToolResult {
541            tool_use_id,
542            tool_name,
543            is_error: true,
544            content: LanguageModelToolResultContent::Text(tool_output.into()),
545            output: Some(serde_json::Value::String(raw_input.to_string())),
546        }
547    }
548
549    /// Guarantees the last message is from the assistant and returns a mutable reference.
550    fn last_assistant_message(&mut self) -> &mut AgentMessage {
551        if self
552            .messages
553            .last()
554            .map_or(true, |m| m.role != Role::Assistant)
555        {
556            self.messages.push(AgentMessage {
557                role: Role::Assistant,
558                content: Vec::new(),
559            });
560        }
561        self.messages.last_mut().unwrap()
562    }
563
564    /// Guarantees the last message is from the user and returns a mutable reference.
565    fn last_user_message(&mut self) -> &mut AgentMessage {
566        if self.messages.last().map_or(true, |m| m.role != Role::User) {
567            self.messages.push(AgentMessage {
568                role: Role::User,
569                content: Vec::new(),
570            });
571        }
572        self.messages.last_mut().unwrap()
573    }
574
575    fn build_completion_request(
576        &self,
577        completion_intent: CompletionIntent,
578        cx: &mut App,
579    ) -> LanguageModelRequest {
580        log::debug!("Building completion request");
581        log::debug!("Completion intent: {:?}", completion_intent);
582        log::debug!("Completion mode: {:?}", self.completion_mode);
583
584        let messages = self.build_request_messages();
585        log::info!("Request will include {} messages", messages.len());
586
587        let tools: Vec<LanguageModelRequestTool> = self
588            .tools
589            .values()
590            .filter_map(|tool| {
591                let tool_name = tool.name().to_string();
592                log::trace!("Including tool: {}", tool_name);
593                Some(LanguageModelRequestTool {
594                    name: tool_name,
595                    description: tool.description(cx).to_string(),
596                    input_schema: tool
597                        .input_schema(self.selected_model.tool_input_format())
598                        .log_err()?,
599                })
600            })
601            .collect();
602
603        log::info!("Request includes {} tools", tools.len());
604
605        let request = LanguageModelRequest {
606            thread_id: None,
607            prompt_id: None,
608            intent: Some(completion_intent),
609            mode: Some(self.completion_mode),
610            messages,
611            tools,
612            tool_choice: None,
613            stop: Vec::new(),
614            temperature: None,
615            thinking_allowed: true,
616        };
617
618        log::debug!("Completion request built successfully");
619        request
620    }
621
622    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
623        log::trace!(
624            "Building request messages from {} thread messages",
625            self.messages.len()
626        );
627
628        let messages = Some(self.build_system_message())
629            .iter()
630            .chain(self.messages.iter())
631            .map(|message| {
632                log::trace!(
633                    "  - {} message with {} content items",
634                    match message.role {
635                        Role::System => "System",
636                        Role::User => "User",
637                        Role::Assistant => "Assistant",
638                    },
639                    message.content.len()
640                );
641                LanguageModelRequestMessage {
642                    role: message.role,
643                    content: message.content.clone(),
644                    cache: false,
645                }
646            })
647            .collect();
648        messages
649    }
650
651    pub fn to_markdown(&self) -> String {
652        let mut markdown = String::new();
653        for message in &self.messages {
654            markdown.push_str(&message.to_markdown());
655        }
656        markdown
657    }
658}
659
/// A strongly typed tool the agent can call.
///
/// Implementors declare a serde/JSON-schema input type. [`AgentTool::erase`]
/// wraps the tool in an [`AnyAgentTool`] trait object that accepts raw JSON,
/// which is how tools are stored on a thread.
pub trait AgentTool
where
    Self: 'static + Sized,
{
    /// The tool's input, deserialized from the model's JSON arguments. Its JSON
    /// schema is what gets advertised to the model.
    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;

    /// The name the model uses to call this tool. It is also the tool's key when
    /// registered on a thread, so another tool with the same name replaces it.
    fn name(&self) -> SharedString;

    /// The description advertised to the model.
    ///
    /// Defaults to the top-level `description` of the input type's JSON schema
    /// (schemars typically fills it from the type's doc comment). Returns an empty
    /// string when there is no such description.
    fn description(&self, _cx: &mut App) -> SharedString {
        let schema = schemars::schema_for!(Self::Input);
        SharedString::new(
            schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap_or_default(),
        )
    }

    /// The kind of tool, reported to the client alongside each tool call.
    fn kind(&self) -> acp::ToolKind;

    /// The initial tool title to display. Can be updated during the tool run.
    fn initial_title(&self, input: Self::Input) -> SharedString;

    /// Returns the JSON schema that describes the tool's input.
    ///
    /// Defaults to the schema generated for [`AgentTool::Input`].
    fn input_schema(&self) -> Schema {
        schemars::schema_for!(Self::Input)
    }

    /// Allows the tool to authorize a given tool call with the user if necessary.
    ///
    /// The default implementation delegates to `event_stream.authorize`, passing
    /// the initial title, the tool kind, and the JSON-serialized input.
    // NOTE(review): `serde_json::json!` unwraps serialization internally, so this
    // panics if `input` fails to serialize; confirm inputs always serialize.
    fn authorize(
        &self,
        input: Self::Input,
        event_stream: ToolCallEventStream,
    ) -> impl use<Self> + Future<Output = Result<()>> {
        let json_input = serde_json::json!(&input);
        event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
    }

    /// Runs the tool with the provided input.
    ///
    /// The `Ok` string becomes the tool result sent back to the model. An `Err`
    /// is sent back as an error result carrying the error's message.
    fn run(
        self: Arc<Self>,
        input: Self::Input,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<String>>;

    /// Type-erases this tool into a shared [`AnyAgentTool`] trait object.
    fn erase(self) -> Arc<dyn AnyAgentTool> {
        Arc::new(Erased(Arc::new(self)))
    }
}
710
/// Adapter that exposes an [`AgentTool`] as an [`AnyAgentTool`] by
/// deserializing raw JSON input into the tool's typed input.
pub struct Erased<T>(T);
712
/// Object-safe, JSON-based view of an [`AgentTool`], used to store tools with
/// different input types together (see [`AgentTool::erase`]).
pub trait AnyAgentTool {
    /// The name the model uses to call this tool.
    fn name(&self) -> SharedString;
    /// The description advertised to the model.
    fn description(&self, cx: &mut App) -> SharedString;
    /// The kind of tool, reported to the client alongside each tool call.
    fn kind(&self) -> acp::ToolKind;
    /// Computes the initial title from raw JSON input. Errors if the input does
    /// not match the tool's input type.
    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
    /// Returns the tool's input schema adapted to the model's schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool with raw JSON input. The task resolves to an error if the
    /// input cannot be parsed into the tool's input type.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<String>>;
}
726
727impl<T> AnyAgentTool for Erased<Arc<T>>
728where
729    T: AgentTool,
730{
731    fn name(&self) -> SharedString {
732        self.0.name()
733    }
734
735    fn description(&self, cx: &mut App) -> SharedString {
736        self.0.description(cx)
737    }
738
739    fn kind(&self) -> agent_client_protocol::ToolKind {
740        self.0.kind()
741    }
742
743    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
744        let parsed_input = serde_json::from_value(input)?;
745        Ok(self.0.initial_title(parsed_input))
746    }
747
748    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
749        let mut json = serde_json::to_value(self.0.input_schema())?;
750        adapt_schema_to_format(&mut json, format)?;
751        Ok(json)
752    }
753
754    fn run(
755        self: Arc<Self>,
756        input: serde_json::Value,
757        event_stream: ToolCallEventStream,
758        cx: &mut App,
759    ) -> Task<Result<String>> {
760        let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
761        match parsed_input {
762            Ok(input) => self.0.clone().run(input, event_stream, cx),
763            Err(error) => Task::ready(Err(anyhow!(error))),
764        }
765    }
766}
767
/// Sending half of the channel returned by `Thread::send`.
///
/// It is cloned and handed to running tools so they can report progress.
/// Send failures (e.g. the receiver was dropped) are ignored by its helper methods.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
772
773impl AgentResponseEventStream {
774    fn send_text(&self, text: &str) {
775        self.0
776            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
777            .ok();
778    }
779
780    fn send_thinking(&self, text: &str) {
781        self.0
782            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
783            .ok();
784    }
785
786    fn authorize_tool_call(
787        &self,
788        id: &LanguageModelToolUseId,
789        title: String,
790        kind: acp::ToolKind,
791        input: serde_json::Value,
792    ) -> impl use<> + Future<Output = Result<()>> {
793        let (response_tx, response_rx) = oneshot::channel();
794        self.0
795            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
796                ToolCallAuthorization {
797                    tool_call: Self::initial_tool_call(id, title, kind, input),
798                    options: vec![
799                        acp::PermissionOption {
800                            id: acp::PermissionOptionId("always_allow".into()),
801                            name: "Always Allow".into(),
802                            kind: acp::PermissionOptionKind::AllowAlways,
803                        },
804                        acp::PermissionOption {
805                            id: acp::PermissionOptionId("allow".into()),
806                            name: "Allow".into(),
807                            kind: acp::PermissionOptionKind::AllowOnce,
808                        },
809                        acp::PermissionOption {
810                            id: acp::PermissionOptionId("deny".into()),
811                            name: "Deny".into(),
812                            kind: acp::PermissionOptionKind::RejectOnce,
813                        },
814                    ],
815                    response: response_tx,
816                },
817            )))
818            .ok();
819        async move {
820            match response_rx.await?.0.as_ref() {
821                "allow" | "always_allow" => Ok(()),
822                _ => Err(anyhow!("Permission to run tool denied by user")),
823            }
824        }
825    }
826
827    fn send_tool_call(
828        &self,
829        tool: Option<&Arc<dyn AnyAgentTool>>,
830        tool_use: &LanguageModelToolUse,
831    ) {
832        self.0
833            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
834                &tool_use.id,
835                tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
836                    .map(|i| i.into())
837                    .unwrap_or_else(|| tool_use.name.to_string()),
838                tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
839                tool_use.input.clone(),
840            ))))
841            .ok();
842    }
843
844    fn initial_tool_call(
845        id: &LanguageModelToolUseId,
846        title: String,
847        kind: acp::ToolKind,
848        input: serde_json::Value,
849    ) -> acp::ToolCall {
850        acp::ToolCall {
851            id: acp::ToolCallId(id.to_string().into()),
852            title,
853            kind,
854            status: acp::ToolCallStatus::Pending,
855            content: vec![],
856            locations: vec![],
857            raw_input: Some(input),
858            raw_output: None,
859        }
860    }
861
862    fn send_tool_call_update(
863        &self,
864        tool_use_id: &LanguageModelToolUseId,
865        fields: acp::ToolCallUpdateFields,
866    ) {
867        self.0
868            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
869                acp::ToolCallUpdate {
870                    id: acp::ToolCallId(tool_use_id.to_string().into()),
871                    fields,
872                },
873            )))
874            .ok();
875    }
876
877    fn send_stop(&self, reason: StopReason) {
878        match reason {
879            StopReason::EndTurn => {
880                self.0
881                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
882                    .ok();
883            }
884            StopReason::MaxTokens => {
885                self.0
886                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
887                    .ok();
888            }
889            StopReason::Refusal => {
890                self.0
891                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
892                    .ok();
893            }
894            StopReason::ToolUse => {}
895        }
896    }
897
898    fn send_error(&self, error: LanguageModelCompletionError) {
899        self.0.unbounded_send(Err(error)).ok();
900    }
901}
902
903#[derive(Clone)]
904pub struct ToolCallEventStream {
905    tool_use_id: LanguageModelToolUseId,
906    stream: AgentResponseEventStream,
907}
908
909impl ToolCallEventStream {
910    fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
911        Self {
912            tool_use_id,
913            stream,
914        }
915    }
916
917    pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
918        self.stream.send_tool_call_update(&self.tool_use_id, fields);
919    }
920
921    pub fn authorize(
922        &self,
923        title: String,
924        kind: acp::ToolKind,
925        input: serde_json::Value,
926    ) -> impl use<> + Future<Output = Result<()>> {
927        self.stream
928            .authorize_tool_call(&self.tool_use_id, title, kind, input)
929    }
930}
931
/// Test-only harness that provides a `ToolCallEventStream` backed by a real channel.
///
/// The stream's tool use id is `"test"`.
#[cfg(test)]
pub struct TestToolCallEventStream {
    stream: ToolCallEventStream,
    // Holds the receiving half so the channel stays open for as long as the
    // harness lives; the events themselves are never read.
    _events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}

#[cfg(test)]
impl TestToolCallEventStream {
    /// Creates a harness whose stream reports under the tool use id `"test"`.
    pub fn new() -> Self {
        let (events_tx, _events_rx) = mpsc::unbounded();
        Self {
            stream: ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx)),
            _events_rx,
        }
    }

    /// Returns a clone of the underlying per-tool-call stream.
    pub fn stream(&self) -> ToolCallEventStream {
        ToolCallEventStream::clone(&self.stream)
    }
}