1use crate::{SystemPromptTemplate, Template, Templates};
2use action_log::ActionLog;
3use agent_client_protocol as acp;
4use agent_settings::AgentSettings;
5use anyhow::{Context as _, Result, anyhow};
6use assistant_tool::adapt_schema_to_format;
7use cloud_llm_client::{CompletionIntent, CompletionMode};
8use collections::HashMap;
9use fs::Fs;
10use futures::{
11 channel::{mpsc, oneshot},
12 stream::FuturesUnordered,
13};
14use gpui::{App, Context, Entity, SharedString, Task};
15use language_model::{
16 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
17 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
18 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
19 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
20};
21use log;
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
29use util::{ResultExt, markdown::MarkdownCodeBlock};
30
31#[derive(Debug, Clone)]
32pub struct AgentMessage {
33 pub role: Role,
34 pub content: Vec<MessageContent>,
35}
36
37impl AgentMessage {
38 pub fn to_markdown(&self) -> String {
39 let mut markdown = format!("## {}\n", self.role);
40
41 for content in &self.content {
42 match content {
43 MessageContent::Text(text) => {
44 markdown.push_str(text);
45 markdown.push('\n');
46 }
47 MessageContent::Thinking { text, .. } => {
48 markdown.push_str("<think>");
49 markdown.push_str(text);
50 markdown.push_str("</think>\n");
51 }
52 MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
53 MessageContent::Image(_) => {
54 markdown.push_str("<image />\n");
55 }
56 MessageContent::ToolUse(tool_use) => {
57 markdown.push_str(&format!(
58 "**Tool Use**: {} (ID: {})\n",
59 tool_use.name, tool_use.id
60 ));
61 markdown.push_str(&format!(
62 "{}\n",
63 MarkdownCodeBlock {
64 tag: "json",
65 text: &format!("{:#}", tool_use.input)
66 }
67 ));
68 }
69 MessageContent::ToolResult(tool_result) => {
70 markdown.push_str(&format!(
71 "**Tool Result**: {} (ID: {})\n\n",
72 tool_result.tool_name, tool_result.tool_use_id
73 ));
74 if tool_result.is_error {
75 markdown.push_str("**ERROR:**\n");
76 }
77
78 match &tool_result.content {
79 LanguageModelToolResultContent::Text(text) => {
80 writeln!(markdown, "{text}\n").ok();
81 }
82 LanguageModelToolResultContent::Image(_) => {
83 writeln!(markdown, "<image />\n").ok();
84 }
85 }
86
87 if let Some(output) = tool_result.output.as_ref() {
88 writeln!(
89 markdown,
90 "**Debug Output**:\n\n```json\n{}\n```\n",
91 serde_json::to_string_pretty(output).unwrap()
92 )
93 .unwrap();
94 }
95 }
96 }
97 }
98
99 markdown
100 }
101}
102
103#[derive(Debug)]
104pub enum AgentResponseEvent {
105 Text(String),
106 Thinking(String),
107 ToolCall(acp::ToolCall),
108 ToolCallUpdate(acp_thread::ToolCallUpdate),
109 ToolCallAuthorization(ToolCallAuthorization),
110 Stop(acp::StopReason),
111}
112
113#[derive(Debug)]
114pub struct ToolCallAuthorization {
115 pub tool_call: acp::ToolCall,
116 pub options: Vec<acp::PermissionOption>,
117 pub response: oneshot::Sender<acp::PermissionOptionId>,
118}
119
120pub struct Thread {
121 messages: Vec<AgentMessage>,
122 completion_mode: CompletionMode,
123 /// Holds the task that handles agent interaction until the end of the turn.
124 /// Survives across multiple requests as the model performs tool calls and
125 /// we run tools, report their results.
126 running_turn: Option<Task<()>>,
127 pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
128 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
129 project_context: Rc<RefCell<ProjectContext>>,
130 templates: Arc<Templates>,
131 pub selected_model: Arc<dyn LanguageModel>,
132 project: Entity<Project>,
133 action_log: Entity<ActionLog>,
134}
135
136impl Thread {
137 pub fn new(
138 project: Entity<Project>,
139 project_context: Rc<RefCell<ProjectContext>>,
140 action_log: Entity<ActionLog>,
141 templates: Arc<Templates>,
142 default_model: Arc<dyn LanguageModel>,
143 ) -> Self {
144 Self {
145 messages: Vec::new(),
146 completion_mode: CompletionMode::Normal,
147 running_turn: None,
148 pending_tool_uses: HashMap::default(),
149 tools: BTreeMap::default(),
150 project_context,
151 templates,
152 selected_model: default_model,
153 project,
154 action_log,
155 }
156 }
157
158 pub fn project(&self) -> &Entity<Project> {
159 &self.project
160 }
161
162 pub fn action_log(&self) -> &Entity<ActionLog> {
163 &self.action_log
164 }
165
166 pub fn set_mode(&mut self, mode: CompletionMode) {
167 self.completion_mode = mode;
168 }
169
170 pub fn messages(&self) -> &[AgentMessage] {
171 &self.messages
172 }
173
174 pub fn add_tool(&mut self, tool: impl AgentTool) {
175 self.tools.insert(tool.name(), tool.erase());
176 }
177
178 pub fn remove_tool(&mut self, name: &str) -> bool {
179 self.tools.remove(name).is_some()
180 }
181
182 pub fn cancel(&mut self) {
183 self.running_turn.take();
184
185 let tool_results = self
186 .pending_tool_uses
187 .drain()
188 .map(|(tool_use_id, tool_use)| {
189 MessageContent::ToolResult(LanguageModelToolResult {
190 tool_use_id,
191 tool_name: tool_use.name.clone(),
192 is_error: true,
193 content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
194 output: None,
195 })
196 })
197 .collect::<Vec<_>>();
198 self.last_user_message().content.extend(tool_results);
199 }
200
201 /// Sending a message results in the model streaming a response, which could include tool calls.
202 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
203 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
204 pub fn send(
205 &mut self,
206 content: impl Into<MessageContent>,
207 cx: &mut Context<Self>,
208 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
209 let content = content.into();
210 let model = self.selected_model.clone();
211 log::info!("Thread::send called with model: {:?}", model.name());
212 log::debug!("Thread::send content: {:?}", content);
213
214 cx.notify();
215 let (events_tx, events_rx) =
216 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
217 let event_stream = AgentResponseEventStream(events_tx);
218
219 let user_message_ix = self.messages.len();
220 self.messages.push(AgentMessage {
221 role: Role::User,
222 content: vec![content],
223 });
224 log::info!("Total messages in thread: {}", self.messages.len());
225 self.running_turn = Some(cx.spawn(async move |thread, cx| {
226 log::info!("Starting agent turn execution");
227 let turn_result = async {
228 // Perform one request, then keep looping if the model makes tool calls.
229 let mut completion_intent = CompletionIntent::UserPrompt;
230 'outer: loop {
231 log::debug!(
232 "Building completion request with intent: {:?}",
233 completion_intent
234 );
235 let request = thread.update(cx, |thread, cx| {
236 thread.build_completion_request(completion_intent, cx)
237 })?;
238
239 // println!(
240 // "request: {}",
241 // serde_json::to_string_pretty(&request).unwrap()
242 // );
243
244 // Stream events, appending to messages and collecting up tool uses.
245 log::info!("Calling model.stream_completion");
246 let mut events = model.stream_completion(request, cx).await?;
247 log::debug!("Stream completion started successfully");
248 let mut tool_uses = FuturesUnordered::new();
249 while let Some(event) = events.next().await {
250 match event {
251 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
252 event_stream.send_stop(reason);
253 if reason == StopReason::Refusal {
254 thread.update(cx, |thread, _cx| {
255 thread.messages.truncate(user_message_ix);
256 })?;
257 break 'outer;
258 }
259 }
260 Ok(event) => {
261 log::trace!("Received completion event: {:?}", event);
262 thread
263 .update(cx, |thread, cx| {
264 tool_uses.extend(thread.handle_streamed_completion_event(
265 event,
266 &event_stream,
267 cx,
268 ));
269 })
270 .ok();
271 }
272 Err(error) => {
273 log::error!("Error in completion stream: {:?}", error);
274 event_stream.send_error(error);
275 break;
276 }
277 }
278 }
279
280 // If there are no tool uses, the turn is done.
281 if tool_uses.is_empty() {
282 log::info!("No tool uses found, completing turn");
283 break;
284 }
285 log::info!("Found {} tool uses to execute", tool_uses.len());
286
287 // As tool results trickle in, insert them in the last user
288 // message so that they can be sent on the next tick of the
289 // agentic loop.
290 while let Some(tool_result) = tool_uses.next().await {
291 log::info!("Tool finished {:?}", tool_result);
292
293 event_stream.update_tool_call_fields(
294 &tool_result.tool_use_id,
295 acp::ToolCallUpdateFields {
296 status: Some(if tool_result.is_error {
297 acp::ToolCallStatus::Failed
298 } else {
299 acp::ToolCallStatus::Completed
300 }),
301 ..Default::default()
302 },
303 );
304 thread
305 .update(cx, |thread, _cx| {
306 thread.pending_tool_uses.remove(&tool_result.tool_use_id);
307 thread
308 .last_user_message()
309 .content
310 .push(MessageContent::ToolResult(tool_result));
311 })
312 .ok();
313 }
314
315 completion_intent = CompletionIntent::ToolResults;
316 }
317
318 Ok(())
319 }
320 .await;
321
322 if let Err(error) = turn_result {
323 log::error!("Turn execution failed: {:?}", error);
324 event_stream.send_error(error);
325 } else {
326 log::info!("Turn execution completed successfully");
327 }
328 }));
329 events_rx
330 }
331
332 pub fn build_system_message(&self) -> AgentMessage {
333 log::debug!("Building system message");
334 let prompt = SystemPromptTemplate {
335 project: &self.project_context.borrow(),
336 available_tools: self.tools.keys().cloned().collect(),
337 }
338 .render(&self.templates)
339 .context("failed to build system prompt")
340 .expect("Invalid template");
341 log::debug!("System message built");
342 AgentMessage {
343 role: Role::System,
344 content: vec![prompt.into()],
345 }
346 }
347
348 /// A helper method that's called on every streamed completion event.
349 /// Returns an optional tool result task, which the main agentic loop in
350 /// send will send back to the model when it resolves.
351 fn handle_streamed_completion_event(
352 &mut self,
353 event: LanguageModelCompletionEvent,
354 event_stream: &AgentResponseEventStream,
355 cx: &mut Context<Self>,
356 ) -> Option<Task<LanguageModelToolResult>> {
357 log::trace!("Handling streamed completion event: {:?}", event);
358 use LanguageModelCompletionEvent::*;
359
360 match event {
361 StartMessage { .. } => {
362 self.messages.push(AgentMessage {
363 role: Role::Assistant,
364 content: Vec::new(),
365 });
366 }
367 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
368 Thinking { text, signature } => {
369 self.handle_thinking_event(text, signature, event_stream, cx)
370 }
371 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
372 ToolUse(tool_use) => {
373 return self.handle_tool_use_event(tool_use, event_stream, cx);
374 }
375 ToolUseJsonParseError {
376 id,
377 tool_name,
378 raw_input,
379 json_parse_error,
380 } => {
381 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
382 id,
383 tool_name,
384 raw_input,
385 json_parse_error,
386 )));
387 }
388 UsageUpdate(_) | StatusUpdate(_) => {}
389 Stop(_) => unreachable!(),
390 }
391
392 None
393 }
394
395 fn handle_text_event(
396 &mut self,
397 new_text: String,
398 events_stream: &AgentResponseEventStream,
399 cx: &mut Context<Self>,
400 ) {
401 events_stream.send_text(&new_text);
402
403 let last_message = self.last_assistant_message();
404 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
405 text.push_str(&new_text);
406 } else {
407 last_message.content.push(MessageContent::Text(new_text));
408 }
409
410 cx.notify();
411 }
412
413 fn handle_thinking_event(
414 &mut self,
415 new_text: String,
416 new_signature: Option<String>,
417 event_stream: &AgentResponseEventStream,
418 cx: &mut Context<Self>,
419 ) {
420 event_stream.send_thinking(&new_text);
421
422 let last_message = self.last_assistant_message();
423 if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
424 {
425 text.push_str(&new_text);
426 *signature = new_signature.or(signature.take());
427 } else {
428 last_message.content.push(MessageContent::Thinking {
429 text: new_text,
430 signature: new_signature,
431 });
432 }
433
434 cx.notify();
435 }
436
437 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
438 let last_message = self.last_assistant_message();
439 last_message
440 .content
441 .push(MessageContent::RedactedThinking(data));
442 cx.notify();
443 }
444
445 fn handle_tool_use_event(
446 &mut self,
447 tool_use: LanguageModelToolUse,
448 event_stream: &AgentResponseEventStream,
449 cx: &mut Context<Self>,
450 ) -> Option<Task<LanguageModelToolResult>> {
451 cx.notify();
452
453 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
454
455 self.pending_tool_uses
456 .insert(tool_use.id.clone(), tool_use.clone());
457 let last_message = self.last_assistant_message();
458
459 // Ensure the last message ends in the current tool use
460 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
461 if let MessageContent::ToolUse(last_tool_use) = content {
462 if last_tool_use.id == tool_use.id {
463 *last_tool_use = tool_use.clone();
464 false
465 } else {
466 true
467 }
468 } else {
469 true
470 }
471 });
472
473 let mut title = SharedString::from(&tool_use.name);
474 let mut kind = acp::ToolKind::Other;
475 if let Some(tool) = tool.as_ref() {
476 title = tool.initial_title(tool_use.input.clone());
477 kind = tool.kind();
478 }
479
480 if push_new_tool_use {
481 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
482 last_message
483 .content
484 .push(MessageContent::ToolUse(tool_use.clone()));
485 } else {
486 event_stream.update_tool_call_fields(
487 &tool_use.id,
488 acp::ToolCallUpdateFields {
489 title: Some(title.into()),
490 kind: Some(kind),
491 raw_input: Some(tool_use.input.clone()),
492 ..Default::default()
493 },
494 );
495 }
496
497 if !tool_use.is_input_complete {
498 return None;
499 }
500
501 let Some(tool) = tool else {
502 let content = format!("No tool named {} exists", tool_use.name);
503 return Some(Task::ready(LanguageModelToolResult {
504 content: LanguageModelToolResultContent::Text(Arc::from(content)),
505 tool_use_id: tool_use.id,
506 tool_name: tool_use.name,
507 is_error: true,
508 output: None,
509 }));
510 };
511
512 let fs = self.project.read(cx).fs().clone();
513 let tool_event_stream =
514 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
515 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
516 status: Some(acp::ToolCallStatus::InProgress),
517 ..Default::default()
518 });
519 let supports_images = self.selected_model.supports_images();
520 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
521 Some(cx.foreground_executor().spawn(async move {
522 let tool_result = tool_result.await.and_then(|output| {
523 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
524 if !supports_images {
525 return Err(anyhow!(
526 "Attempted to read an image, but this model doesn't support it.",
527 ));
528 }
529 }
530 Ok(output)
531 });
532
533 match tool_result {
534 Ok(output) => LanguageModelToolResult {
535 tool_use_id: tool_use.id,
536 tool_name: tool_use.name,
537 is_error: false,
538 content: output.llm_output,
539 output: Some(output.raw_output),
540 },
541 Err(error) => LanguageModelToolResult {
542 tool_use_id: tool_use.id,
543 tool_name: tool_use.name,
544 is_error: true,
545 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
546 output: None,
547 },
548 }
549 }))
550 }
551
552 fn handle_tool_use_json_parse_error_event(
553 &mut self,
554 tool_use_id: LanguageModelToolUseId,
555 tool_name: Arc<str>,
556 raw_input: Arc<str>,
557 json_parse_error: String,
558 ) -> LanguageModelToolResult {
559 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
560 LanguageModelToolResult {
561 tool_use_id,
562 tool_name,
563 is_error: true,
564 content: LanguageModelToolResultContent::Text(tool_output.into()),
565 output: Some(serde_json::Value::String(raw_input.to_string())),
566 }
567 }
568
569 /// Guarantees the last message is from the assistant and returns a mutable reference.
570 fn last_assistant_message(&mut self) -> &mut AgentMessage {
571 if self
572 .messages
573 .last()
574 .map_or(true, |m| m.role != Role::Assistant)
575 {
576 self.messages.push(AgentMessage {
577 role: Role::Assistant,
578 content: Vec::new(),
579 });
580 }
581 self.messages.last_mut().unwrap()
582 }
583
584 /// Guarantees the last message is from the user and returns a mutable reference.
585 fn last_user_message(&mut self) -> &mut AgentMessage {
586 if self.messages.last().map_or(true, |m| m.role != Role::User) {
587 self.messages.push(AgentMessage {
588 role: Role::User,
589 content: Vec::new(),
590 });
591 }
592 self.messages.last_mut().unwrap()
593 }
594
595 pub(crate) fn build_completion_request(
596 &self,
597 completion_intent: CompletionIntent,
598 cx: &mut App,
599 ) -> LanguageModelRequest {
600 log::debug!("Building completion request");
601 log::debug!("Completion intent: {:?}", completion_intent);
602 log::debug!("Completion mode: {:?}", self.completion_mode);
603
604 let messages = self.build_request_messages();
605 log::info!("Request will include {} messages", messages.len());
606
607 let tools: Vec<LanguageModelRequestTool> = self
608 .tools
609 .values()
610 .filter_map(|tool| {
611 let tool_name = tool.name().to_string();
612 log::trace!("Including tool: {}", tool_name);
613 Some(LanguageModelRequestTool {
614 name: tool_name,
615 description: tool.description(cx).to_string(),
616 input_schema: tool
617 .input_schema(self.selected_model.tool_input_format())
618 .log_err()?,
619 })
620 })
621 .collect();
622
623 log::info!("Request includes {} tools", tools.len());
624
625 let request = LanguageModelRequest {
626 thread_id: None,
627 prompt_id: None,
628 intent: Some(completion_intent),
629 mode: Some(self.completion_mode),
630 messages,
631 tools,
632 tool_choice: None,
633 stop: Vec::new(),
634 temperature: None,
635 thinking_allowed: true,
636 };
637
638 log::debug!("Completion request built successfully");
639 request
640 }
641
642 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
643 log::trace!(
644 "Building request messages from {} thread messages",
645 self.messages.len()
646 );
647
648 let messages = Some(self.build_system_message())
649 .iter()
650 .chain(self.messages.iter())
651 .map(|message| {
652 log::trace!(
653 " - {} message with {} content items",
654 match message.role {
655 Role::System => "System",
656 Role::User => "User",
657 Role::Assistant => "Assistant",
658 },
659 message.content.len()
660 );
661 LanguageModelRequestMessage {
662 role: message.role,
663 content: message.content.clone(),
664 cache: false,
665 }
666 })
667 .collect();
668 messages
669 }
670
671 pub fn to_markdown(&self) -> String {
672 let mut markdown = String::new();
673 for message in &self.messages {
674 markdown.push_str(&message.to_markdown());
675 }
676 markdown
677 }
678}
679
680pub trait AgentTool
681where
682 Self: 'static + Sized,
683{
684 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
685 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
686
687 fn name(&self) -> SharedString;
688
689 fn description(&self, _cx: &mut App) -> SharedString {
690 let schema = schemars::schema_for!(Self::Input);
691 SharedString::new(
692 schema
693 .get("description")
694 .and_then(|description| description.as_str())
695 .unwrap_or_default(),
696 )
697 }
698
699 fn kind(&self) -> acp::ToolKind;
700
701 /// The initial tool title to display. Can be updated during the tool run.
702 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
703
704 /// Returns the JSON schema that describes the tool's input.
705 fn input_schema(&self) -> Schema {
706 schemars::schema_for!(Self::Input)
707 }
708
709 /// Runs the tool with the provided input.
710 fn run(
711 self: Arc<Self>,
712 input: Self::Input,
713 event_stream: ToolCallEventStream,
714 cx: &mut App,
715 ) -> Task<Result<Self::Output>>;
716
717 fn erase(self) -> Arc<dyn AnyAgentTool> {
718 Arc::new(Erased(Arc::new(self)))
719 }
720}
721
/// Wrapper through which a typed [`AgentTool`] is exposed as an [`AnyAgentTool`].
pub struct Erased<T>(T);
723
724pub struct AgentToolOutput {
725 llm_output: LanguageModelToolResultContent,
726 raw_output: serde_json::Value,
727}
728
729pub trait AnyAgentTool {
730 fn name(&self) -> SharedString;
731 fn description(&self, cx: &mut App) -> SharedString;
732 fn kind(&self) -> acp::ToolKind;
733 fn initial_title(&self, input: serde_json::Value) -> SharedString;
734 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
735 fn run(
736 self: Arc<Self>,
737 input: serde_json::Value,
738 event_stream: ToolCallEventStream,
739 cx: &mut App,
740 ) -> Task<Result<AgentToolOutput>>;
741}
742
743impl<T> AnyAgentTool for Erased<Arc<T>>
744where
745 T: AgentTool,
746{
747 fn name(&self) -> SharedString {
748 self.0.name()
749 }
750
751 fn description(&self, cx: &mut App) -> SharedString {
752 self.0.description(cx)
753 }
754
755 fn kind(&self) -> agent_client_protocol::ToolKind {
756 self.0.kind()
757 }
758
759 fn initial_title(&self, input: serde_json::Value) -> SharedString {
760 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
761 self.0.initial_title(parsed_input)
762 }
763
764 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
765 let mut json = serde_json::to_value(self.0.input_schema())?;
766 adapt_schema_to_format(&mut json, format)?;
767 Ok(json)
768 }
769
770 fn run(
771 self: Arc<Self>,
772 input: serde_json::Value,
773 event_stream: ToolCallEventStream,
774 cx: &mut App,
775 ) -> Task<Result<AgentToolOutput>> {
776 cx.spawn(async move |cx| {
777 let input = serde_json::from_value(input)?;
778 let output = cx
779 .update(|cx| self.0.clone().run(input, event_stream, cx))?
780 .await?;
781 let raw_output = serde_json::to_value(&output)?;
782 Ok(AgentToolOutput {
783 llm_output: output.into(),
784 raw_output,
785 })
786 })
787 }
788}
789
790#[derive(Clone)]
791struct AgentResponseEventStream(
792 mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
793);
794
795impl AgentResponseEventStream {
796 fn send_text(&self, text: &str) {
797 self.0
798 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
799 .ok();
800 }
801
802 fn send_thinking(&self, text: &str) {
803 self.0
804 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
805 .ok();
806 }
807
808 fn send_tool_call(
809 &self,
810 id: &LanguageModelToolUseId,
811 title: SharedString,
812 kind: acp::ToolKind,
813 input: serde_json::Value,
814 ) {
815 self.0
816 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
817 id,
818 title.to_string(),
819 kind,
820 input,
821 ))))
822 .ok();
823 }
824
825 fn initial_tool_call(
826 id: &LanguageModelToolUseId,
827 title: String,
828 kind: acp::ToolKind,
829 input: serde_json::Value,
830 ) -> acp::ToolCall {
831 acp::ToolCall {
832 id: acp::ToolCallId(id.to_string().into()),
833 title,
834 kind,
835 status: acp::ToolCallStatus::Pending,
836 content: vec![],
837 locations: vec![],
838 raw_input: Some(input),
839 raw_output: None,
840 }
841 }
842
843 fn update_tool_call_fields(
844 &self,
845 tool_use_id: &LanguageModelToolUseId,
846 fields: acp::ToolCallUpdateFields,
847 ) {
848 self.0
849 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
850 acp::ToolCallUpdate {
851 id: acp::ToolCallId(tool_use_id.to_string().into()),
852 fields,
853 }
854 .into(),
855 )))
856 .ok();
857 }
858
859 fn send_stop(&self, reason: StopReason) {
860 match reason {
861 StopReason::EndTurn => {
862 self.0
863 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
864 .ok();
865 }
866 StopReason::MaxTokens => {
867 self.0
868 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
869 .ok();
870 }
871 StopReason::Refusal => {
872 self.0
873 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
874 .ok();
875 }
876 StopReason::ToolUse => {}
877 }
878 }
879
880 fn send_error(&self, error: LanguageModelCompletionError) {
881 self.0.unbounded_send(Err(error)).ok();
882 }
883}
884
885#[derive(Clone)]
886pub struct ToolCallEventStream {
887 tool_use_id: LanguageModelToolUseId,
888 kind: acp::ToolKind,
889 input: serde_json::Value,
890 stream: AgentResponseEventStream,
891 fs: Option<Arc<dyn Fs>>,
892}
893
894impl ToolCallEventStream {
895 #[cfg(test)]
896 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
897 let (events_tx, events_rx) =
898 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
899
900 let stream = ToolCallEventStream::new(
901 &LanguageModelToolUse {
902 id: "test_id".into(),
903 name: "test_tool".into(),
904 raw_input: String::new(),
905 input: serde_json::Value::Null,
906 is_input_complete: true,
907 },
908 acp::ToolKind::Other,
909 AgentResponseEventStream(events_tx),
910 None,
911 );
912
913 (stream, ToolCallEventStreamReceiver(events_rx))
914 }
915
916 fn new(
917 tool_use: &LanguageModelToolUse,
918 kind: acp::ToolKind,
919 stream: AgentResponseEventStream,
920 fs: Option<Arc<dyn Fs>>,
921 ) -> Self {
922 Self {
923 tool_use_id: tool_use.id.clone(),
924 kind,
925 input: tool_use.input.clone(),
926 stream,
927 fs,
928 }
929 }
930
931 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
932 self.stream
933 .update_tool_call_fields(&self.tool_use_id, fields);
934 }
935
936 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
937 self.stream
938 .0
939 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
940 acp_thread::ToolCallUpdateDiff {
941 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
942 diff,
943 }
944 .into(),
945 )))
946 .ok();
947 }
948
949 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
950 self.stream
951 .0
952 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
953 acp_thread::ToolCallUpdateTerminal {
954 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
955 terminal,
956 }
957 .into(),
958 )))
959 .ok();
960 }
961
962 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
963 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
964 return Task::ready(Ok(()));
965 }
966
967 let (response_tx, response_rx) = oneshot::channel();
968 self.stream
969 .0
970 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
971 ToolCallAuthorization {
972 tool_call: AgentResponseEventStream::initial_tool_call(
973 &self.tool_use_id,
974 title.into(),
975 self.kind.clone(),
976 self.input.clone(),
977 ),
978 options: vec![
979 acp::PermissionOption {
980 id: acp::PermissionOptionId("always_allow".into()),
981 name: "Always Allow".into(),
982 kind: acp::PermissionOptionKind::AllowAlways,
983 },
984 acp::PermissionOption {
985 id: acp::PermissionOptionId("allow".into()),
986 name: "Allow".into(),
987 kind: acp::PermissionOptionKind::AllowOnce,
988 },
989 acp::PermissionOption {
990 id: acp::PermissionOptionId("deny".into()),
991 name: "Deny".into(),
992 kind: acp::PermissionOptionKind::RejectOnce,
993 },
994 ],
995 response: response_tx,
996 },
997 )))
998 .ok();
999 let fs = self.fs.clone();
1000 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1001 "always_allow" => {
1002 if let Some(fs) = fs.clone() {
1003 cx.update(|cx| {
1004 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1005 settings.set_always_allow_tool_actions(true);
1006 });
1007 })?;
1008 }
1009
1010 Ok(())
1011 }
1012 "allow" => Ok(()),
1013 _ => Err(anyhow!("Permission to run tool denied by user")),
1014 })
1015 }
1016}
1017
/// Test-only receiving half paired with the stream created by
/// `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1022
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it as an authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Waits for the next event and returns the terminal it attaches.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1046
/// Gives tests shared access to the wrapped receiver.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1055
/// Gives tests mutable access to the wrapped receiver.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}