thread.rs

  1use crate::templates::{SystemPromptTemplate, Template, Templates};
  2use agent_client_protocol as acp;
  3use anyhow::{anyhow, Context as _, Result};
  4use assistant_tool::{adapt_schema_to_format, ActionLog};
  5use cloud_llm_client::{CompletionIntent, CompletionMode};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::FuturesUnordered,
 10};
 11use gpui::{App, Context, Entity, SharedString, Task};
 12use language_model::{
 13    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 14    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
 15    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
 16    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
 17};
 18use log;
 19use project::Project;
 20use prompt_store::ProjectContext;
 21use schemars::{JsonSchema, Schema};
 22use serde::{Deserialize, Serialize};
 23use smol::stream::StreamExt;
 24use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
 25use util::{markdown::MarkdownCodeBlock, ResultExt};
 26
 27#[derive(Debug, Clone)]
 28pub struct AgentMessage {
 29    pub role: Role,
 30    pub content: Vec<MessageContent>,
 31}
 32
 33impl AgentMessage {
 34    pub fn to_markdown(&self) -> String {
 35        let mut markdown = format!("## {}\n", self.role);
 36
 37        for content in &self.content {
 38            match content {
 39                MessageContent::Text(text) => {
 40                    markdown.push_str(text);
 41                    markdown.push('\n');
 42                }
 43                MessageContent::Thinking { text, .. } => {
 44                    markdown.push_str("<think>");
 45                    markdown.push_str(text);
 46                    markdown.push_str("</think>\n");
 47                }
 48                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
 49                MessageContent::Image(_) => {
 50                    markdown.push_str("<image />\n");
 51                }
 52                MessageContent::ToolUse(tool_use) => {
 53                    markdown.push_str(&format!(
 54                        "**Tool Use**: {} (ID: {})\n",
 55                        tool_use.name, tool_use.id
 56                    ));
 57                    markdown.push_str(&format!(
 58                        "{}\n",
 59                        MarkdownCodeBlock {
 60                            tag: "json",
 61                            text: &format!("{:#}", tool_use.input)
 62                        }
 63                    ));
 64                }
 65                MessageContent::ToolResult(tool_result) => {
 66                    markdown.push_str(&format!(
 67                        "**Tool Result**: {} (ID: {})\n\n",
 68                        tool_result.tool_name, tool_result.tool_use_id
 69                    ));
 70                    if tool_result.is_error {
 71                        markdown.push_str("**ERROR:**\n");
 72                    }
 73
 74                    match &tool_result.content {
 75                        LanguageModelToolResultContent::Text(text) => {
 76                            writeln!(markdown, "{text}\n").ok();
 77                        }
 78                        LanguageModelToolResultContent::Image(_) => {
 79                            writeln!(markdown, "<image />\n").ok();
 80                        }
 81                    }
 82
 83                    if let Some(output) = tool_result.output.as_ref() {
 84                        writeln!(
 85                            markdown,
 86                            "**Debug Output**:\n\n```json\n{}\n```\n",
 87                            serde_json::to_string_pretty(output).unwrap()
 88                        )
 89                        .unwrap();
 90                    }
 91                }
 92            }
 93        }
 94
 95        markdown
 96    }
 97}
 98
 99#[derive(Debug)]
100pub enum AgentResponseEvent {
101    Text(String),
102    Thinking(String),
103    ToolCall(acp::ToolCall),
104    ToolCallUpdate(acp::ToolCallUpdate),
105    ToolCallAuthorization(ToolCallAuthorization),
106    Stop(acp::StopReason),
107}
108
109#[derive(Debug)]
110pub struct ToolCallAuthorization {
111    pub tool_call: acp::ToolCall,
112    pub options: Vec<acp::PermissionOption>,
113    pub response: oneshot::Sender<acp::PermissionOptionId>,
114}
115
116pub struct Thread {
117    messages: Vec<AgentMessage>,
118    completion_mode: CompletionMode,
119    /// Holds the task that handles agent interaction until the end of the turn.
120    /// Survives across multiple requests as the model performs tool calls and
121    /// we run tools, report their results.
122    running_turn: Option<Task<()>>,
123    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
124    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
125    project_context: Rc<RefCell<ProjectContext>>,
126    templates: Arc<Templates>,
127    pub selected_model: Arc<dyn LanguageModel>,
128    _action_log: Entity<ActionLog>,
129}
130
131impl Thread {
132    pub fn new(
133        _project: Entity<Project>,
134        project_context: Rc<RefCell<ProjectContext>>,
135        action_log: Entity<ActionLog>,
136        templates: Arc<Templates>,
137        default_model: Arc<dyn LanguageModel>,
138    ) -> Self {
139        Self {
140            messages: Vec::new(),
141            completion_mode: CompletionMode::Normal,
142            running_turn: None,
143            pending_tool_uses: HashMap::default(),
144            tools: BTreeMap::default(),
145            project_context,
146            templates,
147            selected_model: default_model,
148            _action_log: action_log,
149        }
150    }
151
152    pub fn set_mode(&mut self, mode: CompletionMode) {
153        self.completion_mode = mode;
154    }
155
156    pub fn messages(&self) -> &[AgentMessage] {
157        &self.messages
158    }
159
160    pub fn add_tool(&mut self, tool: impl AgentTool) {
161        self.tools.insert(tool.name(), tool.erase());
162    }
163
164    pub fn remove_tool(&mut self, name: &str) -> bool {
165        self.tools.remove(name).is_some()
166    }
167
168    pub fn cancel(&mut self) {
169        self.running_turn.take();
170
171        let tool_results = self
172            .pending_tool_uses
173            .drain()
174            .map(|(tool_use_id, tool_use)| {
175                MessageContent::ToolResult(LanguageModelToolResult {
176                    tool_use_id,
177                    tool_name: tool_use.name.clone(),
178                    is_error: true,
179                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
180                    output: None,
181                })
182            })
183            .collect::<Vec<_>>();
184        self.last_user_message().content.extend(tool_results);
185    }
186
187    /// Sending a message results in the model streaming a response, which could include tool calls.
188    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
189    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
190    pub fn send(
191        &mut self,
192        model: Arc<dyn LanguageModel>,
193        content: impl Into<MessageContent>,
194        cx: &mut Context<Self>,
195    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
196        let content = content.into();
197        log::info!("Thread::send called with model: {:?}", model.name());
198        log::debug!("Thread::send content: {:?}", content);
199
200        cx.notify();
201        let (events_tx, events_rx) =
202            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
203        let event_stream = AgentResponseEventStream(events_tx);
204
205        let user_message_ix = self.messages.len();
206        self.messages.push(AgentMessage {
207            role: Role::User,
208            content: vec![content],
209        });
210        log::info!("Total messages in thread: {}", self.messages.len());
211        self.running_turn = Some(cx.spawn(async move |thread, cx| {
212            log::info!("Starting agent turn execution");
213            let turn_result = async {
214                // Perform one request, then keep looping if the model makes tool calls.
215                let mut completion_intent = CompletionIntent::UserPrompt;
216                'outer: loop {
217                    log::debug!(
218                        "Building completion request with intent: {:?}",
219                        completion_intent
220                    );
221                    let request = thread.update(cx, |thread, cx| {
222                        thread.build_completion_request(completion_intent, cx)
223                    })?;
224
225                    // println!(
226                    //     "request: {}",
227                    //     serde_json::to_string_pretty(&request).unwrap()
228                    // );
229
230                    // Stream events, appending to messages and collecting up tool uses.
231                    log::info!("Calling model.stream_completion");
232                    let mut events = model.stream_completion(request, cx).await?;
233                    log::debug!("Stream completion started successfully");
234                    let mut tool_uses = FuturesUnordered::new();
235                    while let Some(event) = events.next().await {
236                        match event {
237                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
238                                event_stream.send_stop(reason);
239                                if reason == StopReason::Refusal {
240                                    thread.update(cx, |thread, _cx| {
241                                        thread.messages.truncate(user_message_ix);
242                                    })?;
243                                    break 'outer;
244                                }
245                            }
246                            Ok(event) => {
247                                log::trace!("Received completion event: {:?}", event);
248                                thread
249                                    .update(cx, |thread, cx| {
250                                        tool_uses.extend(thread.handle_streamed_completion_event(
251                                            event,
252                                            &event_stream,
253                                            cx,
254                                        ));
255                                    })
256                                    .ok();
257                            }
258                            Err(error) => {
259                                log::error!("Error in completion stream: {:?}", error);
260                                event_stream.send_error(error);
261                                break;
262                            }
263                        }
264                    }
265
266                    // If there are no tool uses, the turn is done.
267                    if tool_uses.is_empty() {
268                        log::info!("No tool uses found, completing turn");
269                        break;
270                    }
271                    log::info!("Found {} tool uses to execute", tool_uses.len());
272
273                    // As tool results trickle in, insert them in the last user
274                    // message so that they can be sent on the next tick of the
275                    // agentic loop.
276                    while let Some(tool_result) = tool_uses.next().await {
277                        log::info!("Tool finished {:?}", tool_result);
278
279                        event_stream.send_tool_call_update(
280                            &tool_result.tool_use_id,
281                            acp::ToolCallUpdateFields {
282                                status: Some(if tool_result.is_error {
283                                    acp::ToolCallStatus::Failed
284                                } else {
285                                    acp::ToolCallStatus::Completed
286                                }),
287                                ..Default::default()
288                            },
289                        );
290                        thread
291                            .update(cx, |thread, _cx| {
292                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
293                                thread
294                                    .last_user_message()
295                                    .content
296                                    .push(MessageContent::ToolResult(tool_result));
297                            })
298                            .ok();
299                    }
300
301                    completion_intent = CompletionIntent::ToolResults;
302                }
303
304                Ok(())
305            }
306            .await;
307
308            if let Err(error) = turn_result {
309                log::error!("Turn execution failed: {:?}", error);
310                event_stream.send_error(error);
311            } else {
312                log::info!("Turn execution completed successfully");
313            }
314        }));
315        events_rx
316    }
317
318    pub fn build_system_message(&self) -> AgentMessage {
319        log::debug!("Building system message");
320        let prompt = SystemPromptTemplate {
321            project: &self.project_context.borrow(),
322            available_tools: self.tools.keys().cloned().collect(),
323        }
324        .render(&self.templates)
325        .context("failed to build system prompt")
326        .expect("Invalid template");
327        log::debug!("System message built");
328        AgentMessage {
329            role: Role::System,
330            content: vec![prompt.into()],
331        }
332    }
333
334    /// A helper method that's called on every streamed completion event.
335    /// Returns an optional tool result task, which the main agentic loop in
336    /// send will send back to the model when it resolves.
337    fn handle_streamed_completion_event(
338        &mut self,
339        event: LanguageModelCompletionEvent,
340        event_stream: &AgentResponseEventStream,
341        cx: &mut Context<Self>,
342    ) -> Option<Task<LanguageModelToolResult>> {
343        log::trace!("Handling streamed completion event: {:?}", event);
344        use LanguageModelCompletionEvent::*;
345
346        match event {
347            StartMessage { .. } => {
348                self.messages.push(AgentMessage {
349                    role: Role::Assistant,
350                    content: Vec::new(),
351                });
352            }
353            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
354            Thinking { text, signature } => {
355                self.handle_thinking_event(text, signature, event_stream, cx)
356            }
357            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
358            ToolUse(tool_use) => {
359                return self.handle_tool_use_event(tool_use, event_stream, cx);
360            }
361            ToolUseJsonParseError {
362                id,
363                tool_name,
364                raw_input,
365                json_parse_error,
366            } => {
367                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
368                    id,
369                    tool_name,
370                    raw_input,
371                    json_parse_error,
372                )));
373            }
374            UsageUpdate(_) | StatusUpdate(_) => {}
375            Stop(_) => unreachable!(),
376        }
377
378        None
379    }
380
381    fn handle_text_event(
382        &mut self,
383        new_text: String,
384        events_stream: &AgentResponseEventStream,
385        cx: &mut Context<Self>,
386    ) {
387        events_stream.send_text(&new_text);
388
389        let last_message = self.last_assistant_message();
390        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
391            text.push_str(&new_text);
392        } else {
393            last_message.content.push(MessageContent::Text(new_text));
394        }
395
396        cx.notify();
397    }
398
399    fn handle_thinking_event(
400        &mut self,
401        new_text: String,
402        new_signature: Option<String>,
403        event_stream: &AgentResponseEventStream,
404        cx: &mut Context<Self>,
405    ) {
406        event_stream.send_thinking(&new_text);
407
408        let last_message = self.last_assistant_message();
409        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
410        {
411            text.push_str(&new_text);
412            *signature = new_signature.or(signature.take());
413        } else {
414            last_message.content.push(MessageContent::Thinking {
415                text: new_text,
416                signature: new_signature,
417            });
418        }
419
420        cx.notify();
421    }
422
423    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
424        let last_message = self.last_assistant_message();
425        last_message
426            .content
427            .push(MessageContent::RedactedThinking(data));
428        cx.notify();
429    }
430
431    fn handle_tool_use_event(
432        &mut self,
433        tool_use: LanguageModelToolUse,
434        event_stream: &AgentResponseEventStream,
435        cx: &mut Context<Self>,
436    ) -> Option<Task<LanguageModelToolResult>> {
437        cx.notify();
438
439        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
440
441        self.pending_tool_uses
442            .insert(tool_use.id.clone(), tool_use.clone());
443        let last_message = self.last_assistant_message();
444
445        // Ensure the last message ends in the current tool use
446        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
447            if let MessageContent::ToolUse(last_tool_use) = content {
448                if last_tool_use.id == tool_use.id {
449                    *last_tool_use = tool_use.clone();
450                    false
451                } else {
452                    true
453                }
454            } else {
455                true
456            }
457        });
458
459        if push_new_tool_use {
460            event_stream.send_tool_call(tool.as_ref(), &tool_use);
461            last_message
462                .content
463                .push(MessageContent::ToolUse(tool_use.clone()));
464        } else {
465            event_stream.send_tool_call_update(
466                &tool_use.id,
467                acp::ToolCallUpdateFields {
468                    raw_input: Some(tool_use.input.clone()),
469                    ..Default::default()
470                },
471            );
472        }
473
474        if !tool_use.is_input_complete {
475            return None;
476        }
477
478        let Some(tool) = tool else {
479            let content = format!("No tool named {} exists", tool_use.name);
480            return Some(Task::ready(LanguageModelToolResult {
481                content: LanguageModelToolResultContent::Text(Arc::from(content)),
482                tool_use_id: tool_use.id,
483                tool_name: tool_use.name,
484                is_error: true,
485                output: None,
486            }));
487        };
488
489        let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
490        Some(cx.foreground_executor().spawn(async move {
491            match tool_result.await {
492                Ok(tool_output) => LanguageModelToolResult {
493                    tool_use_id: tool_use.id,
494                    tool_name: tool_use.name,
495                    is_error: false,
496                    content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
497                    output: None,
498                },
499                Err(error) => LanguageModelToolResult {
500                    tool_use_id: tool_use.id,
501                    tool_name: tool_use.name,
502                    is_error: true,
503                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
504                    output: None,
505                },
506            }
507        }))
508    }
509
510    fn run_tool(
511        &self,
512        tool: Arc<dyn AnyAgentTool>,
513        tool_use: LanguageModelToolUse,
514        event_stream: AgentResponseEventStream,
515        cx: &mut Context<Self>,
516    ) -> Task<Result<String>> {
517        cx.spawn(async move |_this, cx| {
518            let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
519            tool_event_stream.send_update(acp::ToolCallUpdateFields {
520                status: Some(acp::ToolCallStatus::InProgress),
521                ..Default::default()
522            });
523            cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
524                .await
525        })
526    }
527
528    fn handle_tool_use_json_parse_error_event(
529        &mut self,
530        tool_use_id: LanguageModelToolUseId,
531        tool_name: Arc<str>,
532        raw_input: Arc<str>,
533        json_parse_error: String,
534    ) -> LanguageModelToolResult {
535        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
536        LanguageModelToolResult {
537            tool_use_id,
538            tool_name,
539            is_error: true,
540            content: LanguageModelToolResultContent::Text(tool_output.into()),
541            output: Some(serde_json::Value::String(raw_input.to_string())),
542        }
543    }
544
545    /// Guarantees the last message is from the assistant and returns a mutable reference.
546    fn last_assistant_message(&mut self) -> &mut AgentMessage {
547        if self
548            .messages
549            .last()
550            .map_or(true, |m| m.role != Role::Assistant)
551        {
552            self.messages.push(AgentMessage {
553                role: Role::Assistant,
554                content: Vec::new(),
555            });
556        }
557        self.messages.last_mut().unwrap()
558    }
559
560    /// Guarantees the last message is from the user and returns a mutable reference.
561    fn last_user_message(&mut self) -> &mut AgentMessage {
562        if self.messages.last().map_or(true, |m| m.role != Role::User) {
563            self.messages.push(AgentMessage {
564                role: Role::User,
565                content: Vec::new(),
566            });
567        }
568        self.messages.last_mut().unwrap()
569    }
570
571    fn build_completion_request(
572        &self,
573        completion_intent: CompletionIntent,
574        cx: &mut App,
575    ) -> LanguageModelRequest {
576        log::debug!("Building completion request");
577        log::debug!("Completion intent: {:?}", completion_intent);
578        log::debug!("Completion mode: {:?}", self.completion_mode);
579
580        let messages = self.build_request_messages();
581        log::info!("Request will include {} messages", messages.len());
582
583        let tools: Vec<LanguageModelRequestTool> = self
584            .tools
585            .values()
586            .filter_map(|tool| {
587                let tool_name = tool.name().to_string();
588                log::trace!("Including tool: {}", tool_name);
589                Some(LanguageModelRequestTool {
590                    name: tool_name,
591                    description: tool.description(cx).to_string(),
592                    input_schema: tool
593                        .input_schema(self.selected_model.tool_input_format())
594                        .log_err()?,
595                })
596            })
597            .collect();
598
599        log::info!("Request includes {} tools", tools.len());
600
601        let request = LanguageModelRequest {
602            thread_id: None,
603            prompt_id: None,
604            intent: Some(completion_intent),
605            mode: Some(self.completion_mode),
606            messages,
607            tools,
608            tool_choice: None,
609            stop: Vec::new(),
610            temperature: None,
611            thinking_allowed: true,
612        };
613
614        log::debug!("Completion request built successfully");
615        request
616    }
617
618    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
619        log::trace!(
620            "Building request messages from {} thread messages",
621            self.messages.len()
622        );
623
624        let messages = Some(self.build_system_message())
625            .iter()
626            .chain(self.messages.iter())
627            .map(|message| {
628                log::trace!(
629                    "  - {} message with {} content items",
630                    match message.role {
631                        Role::System => "System",
632                        Role::User => "User",
633                        Role::Assistant => "Assistant",
634                    },
635                    message.content.len()
636                );
637                LanguageModelRequestMessage {
638                    role: message.role,
639                    content: message.content.clone(),
640                    cache: false,
641                }
642            })
643            .collect();
644        messages
645    }
646
647    pub fn to_markdown(&self) -> String {
648        let mut markdown = String::new();
649        for message in &self.messages {
650            markdown.push_str(&message.to_markdown());
651        }
652        markdown
653    }
654}
655
656pub trait AgentTool
657where
658    Self: 'static + Sized,
659{
660    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
661
662    fn name(&self) -> SharedString;
663
664    fn description(&self, _cx: &mut App) -> SharedString {
665        let schema = schemars::schema_for!(Self::Input);
666        SharedString::new(
667            schema
668                .get("description")
669                .and_then(|description| description.as_str())
670                .unwrap_or_default(),
671        )
672    }
673
674    fn kind(&self) -> acp::ToolKind;
675
676    /// The initial tool title to display. Can be updated during the tool run.
677    fn initial_title(&self, input: Self::Input) -> SharedString;
678
679    /// Returns the JSON schema that describes the tool's input.
680    fn input_schema(&self) -> Schema {
681        schemars::schema_for!(Self::Input)
682    }
683
684    /// Allows the tool to authorize a given tool call with the user if necessary
685    fn authorize(
686        &self,
687        input: Self::Input,
688        event_stream: ToolCallEventStream,
689    ) -> impl use<Self> + Future<Output = Result<()>> {
690        let json_input = serde_json::json!(&input);
691        event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
692    }
693
694    /// Runs the tool with the provided input.
695    fn run(
696        self: Arc<Self>,
697        input: Self::Input,
698        event_stream: ToolCallEventStream,
699        cx: &mut App,
700    ) -> Task<Result<String>>;
701
702    fn erase(self) -> Arc<dyn AnyAgentTool> {
703        Arc::new(Erased(Arc::new(self)))
704    }
705}
706
707pub struct Erased<T>(T);
708
709pub trait AnyAgentTool {
710    fn name(&self) -> SharedString;
711    fn description(&self, cx: &mut App) -> SharedString;
712    fn kind(&self) -> acp::ToolKind;
713    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
714    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
715    fn run(
716        self: Arc<Self>,
717        input: serde_json::Value,
718        event_stream: ToolCallEventStream,
719        cx: &mut App,
720    ) -> Task<Result<String>>;
721}
722
723impl<T> AnyAgentTool for Erased<Arc<T>>
724where
725    T: AgentTool,
726{
727    fn name(&self) -> SharedString {
728        self.0.name()
729    }
730
731    fn description(&self, cx: &mut App) -> SharedString {
732        self.0.description(cx)
733    }
734
735    fn kind(&self) -> agent_client_protocol::ToolKind {
736        self.0.kind()
737    }
738
739    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
740        let parsed_input = serde_json::from_value(input)?;
741        Ok(self.0.initial_title(parsed_input))
742    }
743
744    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
745        let mut json = serde_json::to_value(self.0.input_schema())?;
746        adapt_schema_to_format(&mut json, format)?;
747        Ok(json)
748    }
749
750    fn run(
751        self: Arc<Self>,
752        input: serde_json::Value,
753        event_stream: ToolCallEventStream,
754        cx: &mut App,
755    ) -> Task<Result<String>> {
756        let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
757        match parsed_input {
758            Ok(input) => self.0.clone().run(input, event_stream, cx),
759            Err(error) => Task::ready(Err(anyhow!(error))),
760        }
761    }
762}
763
764#[derive(Clone)]
765struct AgentResponseEventStream(
766    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
767);
768
769impl AgentResponseEventStream {
770    fn send_text(&self, text: &str) {
771        self.0
772            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
773            .ok();
774    }
775
776    fn send_thinking(&self, text: &str) {
777        self.0
778            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
779            .ok();
780    }
781
782    fn authorize_tool_call(
783        &self,
784        id: &LanguageModelToolUseId,
785        title: String,
786        kind: acp::ToolKind,
787        input: serde_json::Value,
788    ) -> impl use<> + Future<Output = Result<()>> {
789        let (response_tx, response_rx) = oneshot::channel();
790        self.0
791            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
792                ToolCallAuthorization {
793                    tool_call: Self::initial_tool_call(id, title, kind, input),
794                    options: vec![
795                        acp::PermissionOption {
796                            id: acp::PermissionOptionId("always_allow".into()),
797                            name: "Always Allow".into(),
798                            kind: acp::PermissionOptionKind::AllowAlways,
799                        },
800                        acp::PermissionOption {
801                            id: acp::PermissionOptionId("allow".into()),
802                            name: "Allow".into(),
803                            kind: acp::PermissionOptionKind::AllowOnce,
804                        },
805                        acp::PermissionOption {
806                            id: acp::PermissionOptionId("deny".into()),
807                            name: "Deny".into(),
808                            kind: acp::PermissionOptionKind::RejectOnce,
809                        },
810                    ],
811                    response: response_tx,
812                },
813            )))
814            .ok();
815        async move {
816            match response_rx.await?.0.as_ref() {
817                "allow" | "always_allow" => Ok(()),
818                _ => Err(anyhow!("Permission to run tool denied by user")),
819            }
820        }
821    }
822
823    fn send_tool_call(
824        &self,
825        tool: Option<&Arc<dyn AnyAgentTool>>,
826        tool_use: &LanguageModelToolUse,
827    ) {
828        self.0
829            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
830                &tool_use.id,
831                tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
832                    .map(|i| i.into())
833                    .unwrap_or_else(|| tool_use.name.to_string()),
834                tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
835                tool_use.input.clone(),
836            ))))
837            .ok();
838    }
839
840    fn initial_tool_call(
841        id: &LanguageModelToolUseId,
842        title: String,
843        kind: acp::ToolKind,
844        input: serde_json::Value,
845    ) -> acp::ToolCall {
846        acp::ToolCall {
847            id: acp::ToolCallId(id.to_string().into()),
848            title,
849            kind,
850            status: acp::ToolCallStatus::Pending,
851            content: vec![],
852            locations: vec![],
853            raw_input: Some(input),
854            raw_output: None,
855        }
856    }
857
858    fn send_tool_call_update(
859        &self,
860        tool_use_id: &LanguageModelToolUseId,
861        fields: acp::ToolCallUpdateFields,
862    ) {
863        self.0
864            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
865                acp::ToolCallUpdate {
866                    id: acp::ToolCallId(tool_use_id.to_string().into()),
867                    fields,
868                },
869            )))
870            .ok();
871    }
872
873    fn send_stop(&self, reason: StopReason) {
874        match reason {
875            StopReason::EndTurn => {
876                self.0
877                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
878                    .ok();
879            }
880            StopReason::MaxTokens => {
881                self.0
882                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
883                    .ok();
884            }
885            StopReason::Refusal => {
886                self.0
887                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
888                    .ok();
889            }
890            StopReason::ToolUse => {}
891        }
892    }
893
894    fn send_error(&self, error: LanguageModelCompletionError) {
895        self.0.unbounded_send(Err(error)).ok();
896    }
897}
898
899#[derive(Clone)]
900pub struct ToolCallEventStream {
901    tool_use_id: LanguageModelToolUseId,
902    stream: AgentResponseEventStream,
903}
904
905impl ToolCallEventStream {
906    fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
907        Self {
908            tool_use_id,
909            stream,
910        }
911    }
912
913    pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
914        self.stream.send_tool_call_update(&self.tool_use_id, fields);
915    }
916
917    pub fn authorize(
918        &self,
919        title: String,
920        kind: acp::ToolKind,
921        input: serde_json::Value,
922    ) -> impl use<> + Future<Output = Result<()>> {
923        self.stream
924            .authorize_tool_call(&self.tool_use_id, title, kind, input)
925    }
926}