1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::MentionUri;
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionMode};
9use collections::HashMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
18 LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
19 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
20 LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use log;
23use project::Project;
24use prompt_store::ProjectContext;
25use schemars::{JsonSchema, Schema};
26use serde::{Deserialize, Serialize};
27use settings::{Settings, update_settings_file};
28use smol::stream::StreamExt;
29use std::fmt::Write;
30use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
31use util::{ResultExt, markdown::MarkdownCodeBlock};
32
33#[derive(Debug, Clone)]
34pub struct AgentMessage {
35 pub role: Role,
36 pub content: Vec<MessageContent>,
37}
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum MessageContent {
41 Text(String),
42 Thinking {
43 text: String,
44 signature: Option<String>,
45 },
46 Mention {
47 uri: MentionUri,
48 content: String,
49 },
50 RedactedThinking(String),
51 Image(LanguageModelImage),
52 ToolUse(LanguageModelToolUse),
53 ToolResult(LanguageModelToolResult),
54}
55
56impl AgentMessage {
57 pub fn to_markdown(&self) -> String {
58 let mut markdown = format!("## {}\n", self.role);
59
60 for content in &self.content {
61 match content {
62 MessageContent::Text(text) => {
63 markdown.push_str(text);
64 markdown.push('\n');
65 }
66 MessageContent::Thinking { text, .. } => {
67 markdown.push_str("<think>");
68 markdown.push_str(text);
69 markdown.push_str("</think>\n");
70 }
71 MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
72 MessageContent::Image(_) => {
73 markdown.push_str("<image />\n");
74 }
75 MessageContent::ToolUse(tool_use) => {
76 markdown.push_str(&format!(
77 "**Tool Use**: {} (ID: {})\n",
78 tool_use.name, tool_use.id
79 ));
80 markdown.push_str(&format!(
81 "{}\n",
82 MarkdownCodeBlock {
83 tag: "json",
84 text: &format!("{:#}", tool_use.input)
85 }
86 ));
87 }
88 MessageContent::ToolResult(tool_result) => {
89 markdown.push_str(&format!(
90 "**Tool Result**: {} (ID: {})\n\n",
91 tool_result.tool_name, tool_result.tool_use_id
92 ));
93 if tool_result.is_error {
94 markdown.push_str("**ERROR:**\n");
95 }
96
97 match &tool_result.content {
98 LanguageModelToolResultContent::Text(text) => {
99 writeln!(markdown, "{text}\n").ok();
100 }
101 LanguageModelToolResultContent::Image(_) => {
102 writeln!(markdown, "<image />\n").ok();
103 }
104 }
105
106 if let Some(output) = tool_result.output.as_ref() {
107 writeln!(
108 markdown,
109 "**Debug Output**:\n\n```json\n{}\n```\n",
110 serde_json::to_string_pretty(output).unwrap()
111 )
112 .unwrap();
113 }
114 }
115 MessageContent::Mention { uri, .. } => {
116 write!(markdown, "{}", uri.to_link()).ok();
117 }
118 }
119 }
120
121 markdown
122 }
123}
124
125#[derive(Debug)]
126pub enum AgentResponseEvent {
127 Text(String),
128 Thinking(String),
129 ToolCall(acp::ToolCall),
130 ToolCallUpdate(acp_thread::ToolCallUpdate),
131 ToolCallAuthorization(ToolCallAuthorization),
132 Stop(acp::StopReason),
133}
134
135#[derive(Debug)]
136pub struct ToolCallAuthorization {
137 pub tool_call: acp::ToolCall,
138 pub options: Vec<acp::PermissionOption>,
139 pub response: oneshot::Sender<acp::PermissionOptionId>,
140}
141
142pub struct Thread {
143 messages: Vec<AgentMessage>,
144 completion_mode: CompletionMode,
145 /// Holds the task that handles agent interaction until the end of the turn.
146 /// Survives across multiple requests as the model performs tool calls and
147 /// we run tools, report their results.
148 running_turn: Option<Task<()>>,
149 pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
150 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
151 context_server_registry: Entity<ContextServerRegistry>,
152 profile_id: AgentProfileId,
153 project_context: Rc<RefCell<ProjectContext>>,
154 templates: Arc<Templates>,
155 pub selected_model: Arc<dyn LanguageModel>,
156 project: Entity<Project>,
157 action_log: Entity<ActionLog>,
158}
159
160impl Thread {
161 pub fn new(
162 project: Entity<Project>,
163 project_context: Rc<RefCell<ProjectContext>>,
164 context_server_registry: Entity<ContextServerRegistry>,
165 action_log: Entity<ActionLog>,
166 templates: Arc<Templates>,
167 default_model: Arc<dyn LanguageModel>,
168 cx: &mut Context<Self>,
169 ) -> Self {
170 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
171 Self {
172 messages: Vec::new(),
173 completion_mode: CompletionMode::Normal,
174 running_turn: None,
175 pending_tool_uses: HashMap::default(),
176 tools: BTreeMap::default(),
177 context_server_registry,
178 profile_id,
179 project_context,
180 templates,
181 selected_model: default_model,
182 project,
183 action_log,
184 }
185 }
186
187 pub fn project(&self) -> &Entity<Project> {
188 &self.project
189 }
190
191 pub fn action_log(&self) -> &Entity<ActionLog> {
192 &self.action_log
193 }
194
195 pub fn set_mode(&mut self, mode: CompletionMode) {
196 self.completion_mode = mode;
197 }
198
199 pub fn messages(&self) -> &[AgentMessage] {
200 &self.messages
201 }
202
203 pub fn add_tool(&mut self, tool: impl AgentTool) {
204 self.tools.insert(tool.name(), tool.erase());
205 }
206
207 pub fn remove_tool(&mut self, name: &str) -> bool {
208 self.tools.remove(name).is_some()
209 }
210
211 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
212 self.profile_id = profile_id;
213 }
214
215 pub fn cancel(&mut self) {
216 self.running_turn.take();
217
218 let tool_results = self
219 .pending_tool_uses
220 .drain()
221 .map(|(tool_use_id, tool_use)| {
222 MessageContent::ToolResult(LanguageModelToolResult {
223 tool_use_id,
224 tool_name: tool_use.name.clone(),
225 is_error: true,
226 content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
227 output: None,
228 })
229 })
230 .collect::<Vec<_>>();
231 self.last_user_message().content.extend(tool_results);
232 }
233
234 /// Sending a message results in the model streaming a response, which could include tool calls.
235 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
236 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
237 pub fn send(
238 &mut self,
239 content: impl Into<UserMessage>,
240 cx: &mut Context<Self>,
241 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
242 let content = content.into().0;
243
244 let model = self.selected_model.clone();
245 log::info!("Thread::send called with model: {:?}", model.name());
246 log::debug!("Thread::send content: {:?}", content);
247
248 cx.notify();
249 let (events_tx, events_rx) =
250 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
251 let event_stream = AgentResponseEventStream(events_tx);
252
253 let user_message_ix = self.messages.len();
254 self.messages.push(AgentMessage {
255 role: Role::User,
256 content,
257 });
258 log::info!("Total messages in thread: {}", self.messages.len());
259 self.running_turn = Some(cx.spawn(async move |thread, cx| {
260 log::info!("Starting agent turn execution");
261 let turn_result = async {
262 // Perform one request, then keep looping if the model makes tool calls.
263 let mut completion_intent = CompletionIntent::UserPrompt;
264 'outer: loop {
265 log::debug!(
266 "Building completion request with intent: {:?}",
267 completion_intent
268 );
269 let request = thread.update(cx, |thread, cx| {
270 thread.build_completion_request(completion_intent, cx)
271 })?;
272
273 // println!(
274 // "request: {}",
275 // serde_json::to_string_pretty(&request).unwrap()
276 // );
277
278 // Stream events, appending to messages and collecting up tool uses.
279 log::info!("Calling model.stream_completion");
280 let mut events = model.stream_completion(request, cx).await?;
281 log::debug!("Stream completion started successfully");
282 let mut tool_uses = FuturesUnordered::new();
283 while let Some(event) = events.next().await {
284 match event {
285 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
286 event_stream.send_stop(reason);
287 if reason == StopReason::Refusal {
288 thread.update(cx, |thread, _cx| {
289 thread.messages.truncate(user_message_ix);
290 })?;
291 break 'outer;
292 }
293 }
294 Ok(event) => {
295 log::trace!("Received completion event: {:?}", event);
296 thread
297 .update(cx, |thread, cx| {
298 tool_uses.extend(thread.handle_streamed_completion_event(
299 event,
300 &event_stream,
301 cx,
302 ));
303 })
304 .ok();
305 }
306 Err(error) => {
307 log::error!("Error in completion stream: {:?}", error);
308 event_stream.send_error(error);
309 break;
310 }
311 }
312 }
313
314 // If there are no tool uses, the turn is done.
315 if tool_uses.is_empty() {
316 log::info!("No tool uses found, completing turn");
317 break;
318 }
319 log::info!("Found {} tool uses to execute", tool_uses.len());
320
321 // As tool results trickle in, insert them in the last user
322 // message so that they can be sent on the next tick of the
323 // agentic loop.
324 while let Some(tool_result) = tool_uses.next().await {
325 log::info!("Tool finished {:?}", tool_result);
326
327 event_stream.update_tool_call_fields(
328 &tool_result.tool_use_id,
329 acp::ToolCallUpdateFields {
330 status: Some(if tool_result.is_error {
331 acp::ToolCallStatus::Failed
332 } else {
333 acp::ToolCallStatus::Completed
334 }),
335 raw_output: tool_result.output.clone(),
336 ..Default::default()
337 },
338 );
339 thread
340 .update(cx, |thread, _cx| {
341 thread.pending_tool_uses.remove(&tool_result.tool_use_id);
342 thread
343 .last_user_message()
344 .content
345 .push(MessageContent::ToolResult(tool_result));
346 })
347 .ok();
348 }
349
350 completion_intent = CompletionIntent::ToolResults;
351 }
352
353 Ok(())
354 }
355 .await;
356
357 if let Err(error) = turn_result {
358 log::error!("Turn execution failed: {:?}", error);
359 event_stream.send_error(error);
360 } else {
361 log::info!("Turn execution completed successfully");
362 }
363 }));
364 events_rx
365 }
366
367 pub fn build_system_message(&self) -> AgentMessage {
368 log::debug!("Building system message");
369 let prompt = SystemPromptTemplate {
370 project: &self.project_context.borrow(),
371 available_tools: self.tools.keys().cloned().collect(),
372 }
373 .render(&self.templates)
374 .context("failed to build system prompt")
375 .expect("Invalid template");
376 log::debug!("System message built");
377 AgentMessage {
378 role: Role::System,
379 content: vec![prompt.as_str().into()],
380 }
381 }
382
383 /// A helper method that's called on every streamed completion event.
384 /// Returns an optional tool result task, which the main agentic loop in
385 /// send will send back to the model when it resolves.
386 fn handle_streamed_completion_event(
387 &mut self,
388 event: LanguageModelCompletionEvent,
389 event_stream: &AgentResponseEventStream,
390 cx: &mut Context<Self>,
391 ) -> Option<Task<LanguageModelToolResult>> {
392 log::trace!("Handling streamed completion event: {:?}", event);
393 use LanguageModelCompletionEvent::*;
394
395 match event {
396 StartMessage { .. } => {
397 self.messages.push(AgentMessage {
398 role: Role::Assistant,
399 content: Vec::new(),
400 });
401 }
402 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
403 Thinking { text, signature } => {
404 self.handle_thinking_event(text, signature, event_stream, cx)
405 }
406 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
407 ToolUse(tool_use) => {
408 return self.handle_tool_use_event(tool_use, event_stream, cx);
409 }
410 ToolUseJsonParseError {
411 id,
412 tool_name,
413 raw_input,
414 json_parse_error,
415 } => {
416 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
417 id,
418 tool_name,
419 raw_input,
420 json_parse_error,
421 )));
422 }
423 UsageUpdate(_) | StatusUpdate(_) => {}
424 Stop(_) => unreachable!(),
425 }
426
427 None
428 }
429
430 fn handle_text_event(
431 &mut self,
432 new_text: String,
433 events_stream: &AgentResponseEventStream,
434 cx: &mut Context<Self>,
435 ) {
436 events_stream.send_text(&new_text);
437
438 let last_message = self.last_assistant_message();
439 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
440 text.push_str(&new_text);
441 } else {
442 last_message.content.push(MessageContent::Text(new_text));
443 }
444
445 cx.notify();
446 }
447
448 fn handle_thinking_event(
449 &mut self,
450 new_text: String,
451 new_signature: Option<String>,
452 event_stream: &AgentResponseEventStream,
453 cx: &mut Context<Self>,
454 ) {
455 event_stream.send_thinking(&new_text);
456
457 let last_message = self.last_assistant_message();
458 if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
459 {
460 text.push_str(&new_text);
461 *signature = new_signature.or(signature.take());
462 } else {
463 last_message.content.push(MessageContent::Thinking {
464 text: new_text,
465 signature: new_signature,
466 });
467 }
468
469 cx.notify();
470 }
471
472 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
473 let last_message = self.last_assistant_message();
474 last_message
475 .content
476 .push(MessageContent::RedactedThinking(data));
477 cx.notify();
478 }
479
480 fn handle_tool_use_event(
481 &mut self,
482 tool_use: LanguageModelToolUse,
483 event_stream: &AgentResponseEventStream,
484 cx: &mut Context<Self>,
485 ) -> Option<Task<LanguageModelToolResult>> {
486 cx.notify();
487
488 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
489
490 self.pending_tool_uses
491 .insert(tool_use.id.clone(), tool_use.clone());
492 let last_message = self.last_assistant_message();
493
494 // Ensure the last message ends in the current tool use
495 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
496 if let MessageContent::ToolUse(last_tool_use) = content {
497 if last_tool_use.id == tool_use.id {
498 *last_tool_use = tool_use.clone();
499 false
500 } else {
501 true
502 }
503 } else {
504 true
505 }
506 });
507
508 let mut title = SharedString::from(&tool_use.name);
509 let mut kind = acp::ToolKind::Other;
510 if let Some(tool) = tool.as_ref() {
511 title = tool.initial_title(tool_use.input.clone());
512 kind = tool.kind();
513 }
514
515 if push_new_tool_use {
516 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
517 last_message
518 .content
519 .push(MessageContent::ToolUse(tool_use.clone()));
520 } else {
521 event_stream.update_tool_call_fields(
522 &tool_use.id,
523 acp::ToolCallUpdateFields {
524 title: Some(title.into()),
525 kind: Some(kind),
526 raw_input: Some(tool_use.input.clone()),
527 ..Default::default()
528 },
529 );
530 }
531
532 if !tool_use.is_input_complete {
533 return None;
534 }
535
536 let Some(tool) = tool else {
537 let content = format!("No tool named {} exists", tool_use.name);
538 return Some(Task::ready(LanguageModelToolResult {
539 content: LanguageModelToolResultContent::Text(Arc::from(content)),
540 tool_use_id: tool_use.id,
541 tool_name: tool_use.name,
542 is_error: true,
543 output: None,
544 }));
545 };
546
547 let fs = self.project.read(cx).fs().clone();
548 let tool_event_stream =
549 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
550 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
551 status: Some(acp::ToolCallStatus::InProgress),
552 ..Default::default()
553 });
554 let supports_images = self.selected_model.supports_images();
555 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
556 Some(cx.foreground_executor().spawn(async move {
557 let tool_result = tool_result.await.and_then(|output| {
558 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
559 if !supports_images {
560 return Err(anyhow!(
561 "Attempted to read an image, but this model doesn't support it.",
562 ));
563 }
564 }
565 Ok(output)
566 });
567
568 match tool_result {
569 Ok(output) => LanguageModelToolResult {
570 tool_use_id: tool_use.id,
571 tool_name: tool_use.name,
572 is_error: false,
573 content: output.llm_output,
574 output: Some(output.raw_output),
575 },
576 Err(error) => LanguageModelToolResult {
577 tool_use_id: tool_use.id,
578 tool_name: tool_use.name,
579 is_error: true,
580 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
581 output: None,
582 },
583 }
584 }))
585 }
586
587 fn handle_tool_use_json_parse_error_event(
588 &mut self,
589 tool_use_id: LanguageModelToolUseId,
590 tool_name: Arc<str>,
591 raw_input: Arc<str>,
592 json_parse_error: String,
593 ) -> LanguageModelToolResult {
594 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
595 LanguageModelToolResult {
596 tool_use_id,
597 tool_name,
598 is_error: true,
599 content: LanguageModelToolResultContent::Text(tool_output.into()),
600 output: Some(serde_json::Value::String(raw_input.to_string())),
601 }
602 }
603
604 /// Guarantees the last message is from the assistant and returns a mutable reference.
605 fn last_assistant_message(&mut self) -> &mut AgentMessage {
606 if self
607 .messages
608 .last()
609 .map_or(true, |m| m.role != Role::Assistant)
610 {
611 self.messages.push(AgentMessage {
612 role: Role::Assistant,
613 content: Vec::new(),
614 });
615 }
616 self.messages.last_mut().unwrap()
617 }
618
619 /// Guarantees the last message is from the user and returns a mutable reference.
620 fn last_user_message(&mut self) -> &mut AgentMessage {
621 if self.messages.last().map_or(true, |m| m.role != Role::User) {
622 self.messages.push(AgentMessage {
623 role: Role::User,
624 content: Vec::new(),
625 });
626 }
627 self.messages.last_mut().unwrap()
628 }
629
630 pub(crate) fn build_completion_request(
631 &self,
632 completion_intent: CompletionIntent,
633 cx: &mut App,
634 ) -> LanguageModelRequest {
635 log::debug!("Building completion request");
636 log::debug!("Completion intent: {:?}", completion_intent);
637 log::debug!("Completion mode: {:?}", self.completion_mode);
638
639 let messages = self.build_request_messages();
640 log::info!("Request will include {} messages", messages.len());
641
642 let tools = if let Some(tools) = self.tools(cx).log_err() {
643 tools
644 .filter_map(|tool| {
645 let tool_name = tool.name().to_string();
646 log::trace!("Including tool: {}", tool_name);
647 Some(LanguageModelRequestTool {
648 name: tool_name,
649 description: tool.description().to_string(),
650 input_schema: tool
651 .input_schema(self.selected_model.tool_input_format())
652 .log_err()?,
653 })
654 })
655 .collect()
656 } else {
657 Vec::new()
658 };
659
660 log::info!("Request includes {} tools", tools.len());
661
662 let request = LanguageModelRequest {
663 thread_id: None,
664 prompt_id: None,
665 intent: Some(completion_intent),
666 mode: Some(self.completion_mode),
667 messages,
668 tools,
669 tool_choice: None,
670 stop: Vec::new(),
671 temperature: None,
672 thinking_allowed: true,
673 };
674
675 log::debug!("Completion request built successfully");
676 request
677 }
678
679 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
680 let profile = AgentSettings::get_global(cx)
681 .profiles
682 .get(&self.profile_id)
683 .context("profile not found")?;
684 let provider_id = self.selected_model.provider_id();
685
686 Ok(self
687 .tools
688 .iter()
689 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
690 .filter_map(|(tool_name, tool)| {
691 if profile.is_tool_enabled(tool_name) {
692 Some(tool)
693 } else {
694 None
695 }
696 })
697 .chain(self.context_server_registry.read(cx).servers().flat_map(
698 |(server_id, tools)| {
699 tools.iter().filter_map(|(tool_name, tool)| {
700 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
701 Some(tool)
702 } else {
703 None
704 }
705 })
706 },
707 )))
708 }
709
710 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
711 log::trace!(
712 "Building request messages from {} thread messages",
713 self.messages.len()
714 );
715
716 let messages = Some(self.build_system_message())
717 .iter()
718 .chain(self.messages.iter())
719 .map(|message| {
720 log::trace!(
721 " - {} message with {} content items",
722 match message.role {
723 Role::System => "System",
724 Role::User => "User",
725 Role::Assistant => "Assistant",
726 },
727 message.content.len()
728 );
729 message.to_request()
730 })
731 .collect();
732 messages
733 }
734
735 pub fn to_markdown(&self) -> String {
736 let mut markdown = String::new();
737 for message in &self.messages {
738 markdown.push_str(&message.to_markdown());
739 }
740 markdown
741 }
742}
743
744pub struct UserMessage(Vec<MessageContent>);
745
746impl From<Vec<MessageContent>> for UserMessage {
747 fn from(content: Vec<MessageContent>) -> Self {
748 UserMessage(content)
749 }
750}
751
752impl<T: Into<MessageContent>> From<T> for UserMessage {
753 fn from(content: T) -> Self {
754 UserMessage(vec![content.into()])
755 }
756}
757
758pub trait AgentTool
759where
760 Self: 'static + Sized,
761{
762 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
763 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
764
765 fn name(&self) -> SharedString;
766
767 fn description(&self) -> SharedString {
768 let schema = schemars::schema_for!(Self::Input);
769 SharedString::new(
770 schema
771 .get("description")
772 .and_then(|description| description.as_str())
773 .unwrap_or_default(),
774 )
775 }
776
777 fn kind(&self) -> acp::ToolKind;
778
779 /// The initial tool title to display. Can be updated during the tool run.
780 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
781
782 /// Returns the JSON schema that describes the tool's input.
783 fn input_schema(&self) -> Schema {
784 schemars::schema_for!(Self::Input)
785 }
786
787 /// Some tools rely on a provider for the underlying billing or other reasons.
788 /// Allow the tool to check if they are compatible, or should be filtered out.
789 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
790 true
791 }
792
793 /// Runs the tool with the provided input.
794 fn run(
795 self: Arc<Self>,
796 input: Self::Input,
797 event_stream: ToolCallEventStream,
798 cx: &mut App,
799 ) -> Task<Result<Self::Output>>;
800
801 fn erase(self) -> Arc<dyn AnyAgentTool> {
802 Arc::new(Erased(Arc::new(self)))
803 }
804}
805
/// Wrapper through which an [`AgentTool`] is exposed as an [`AnyAgentTool`].
pub struct Erased<T>(T);
807
808pub struct AgentToolOutput {
809 pub llm_output: LanguageModelToolResultContent,
810 pub raw_output: serde_json::Value,
811}
812
813pub trait AnyAgentTool {
814 fn name(&self) -> SharedString;
815 fn description(&self) -> SharedString;
816 fn kind(&self) -> acp::ToolKind;
817 fn initial_title(&self, input: serde_json::Value) -> SharedString;
818 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
819 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
820 true
821 }
822 fn run(
823 self: Arc<Self>,
824 input: serde_json::Value,
825 event_stream: ToolCallEventStream,
826 cx: &mut App,
827 ) -> Task<Result<AgentToolOutput>>;
828}
829
830impl<T> AnyAgentTool for Erased<Arc<T>>
831where
832 T: AgentTool,
833{
834 fn name(&self) -> SharedString {
835 self.0.name()
836 }
837
838 fn description(&self) -> SharedString {
839 self.0.description()
840 }
841
842 fn kind(&self) -> agent_client_protocol::ToolKind {
843 self.0.kind()
844 }
845
846 fn initial_title(&self, input: serde_json::Value) -> SharedString {
847 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
848 self.0.initial_title(parsed_input)
849 }
850
851 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
852 let mut json = serde_json::to_value(self.0.input_schema())?;
853 adapt_schema_to_format(&mut json, format)?;
854 Ok(json)
855 }
856
857 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
858 self.0.supported_provider(provider)
859 }
860
861 fn run(
862 self: Arc<Self>,
863 input: serde_json::Value,
864 event_stream: ToolCallEventStream,
865 cx: &mut App,
866 ) -> Task<Result<AgentToolOutput>> {
867 cx.spawn(async move |cx| {
868 let input = serde_json::from_value(input)?;
869 let output = cx
870 .update(|cx| self.0.clone().run(input, event_stream, cx))?
871 .await?;
872 let raw_output = serde_json::to_value(&output)?;
873 Ok(AgentToolOutput {
874 llm_output: output.into(),
875 raw_output,
876 })
877 })
878 }
879}
880
881#[derive(Clone)]
882struct AgentResponseEventStream(
883 mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
884);
885
886impl AgentResponseEventStream {
887 fn send_text(&self, text: &str) {
888 self.0
889 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
890 .ok();
891 }
892
893 fn send_thinking(&self, text: &str) {
894 self.0
895 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
896 .ok();
897 }
898
899 fn send_tool_call(
900 &self,
901 id: &LanguageModelToolUseId,
902 title: SharedString,
903 kind: acp::ToolKind,
904 input: serde_json::Value,
905 ) {
906 self.0
907 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
908 id,
909 title.to_string(),
910 kind,
911 input,
912 ))))
913 .ok();
914 }
915
916 fn initial_tool_call(
917 id: &LanguageModelToolUseId,
918 title: String,
919 kind: acp::ToolKind,
920 input: serde_json::Value,
921 ) -> acp::ToolCall {
922 acp::ToolCall {
923 id: acp::ToolCallId(id.to_string().into()),
924 title,
925 kind,
926 status: acp::ToolCallStatus::Pending,
927 content: vec![],
928 locations: vec![],
929 raw_input: Some(input),
930 raw_output: None,
931 }
932 }
933
934 fn update_tool_call_fields(
935 &self,
936 tool_use_id: &LanguageModelToolUseId,
937 fields: acp::ToolCallUpdateFields,
938 ) {
939 self.0
940 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
941 acp::ToolCallUpdate {
942 id: acp::ToolCallId(tool_use_id.to_string().into()),
943 fields,
944 }
945 .into(),
946 )))
947 .ok();
948 }
949
950 fn send_stop(&self, reason: StopReason) {
951 match reason {
952 StopReason::EndTurn => {
953 self.0
954 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
955 .ok();
956 }
957 StopReason::MaxTokens => {
958 self.0
959 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
960 .ok();
961 }
962 StopReason::Refusal => {
963 self.0
964 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
965 .ok();
966 }
967 StopReason::ToolUse => {}
968 }
969 }
970
971 fn send_error(&self, error: LanguageModelCompletionError) {
972 self.0.unbounded_send(Err(error)).ok();
973 }
974}
975
976#[derive(Clone)]
977pub struct ToolCallEventStream {
978 tool_use_id: LanguageModelToolUseId,
979 kind: acp::ToolKind,
980 input: serde_json::Value,
981 stream: AgentResponseEventStream,
982 fs: Option<Arc<dyn Fs>>,
983}
984
985impl ToolCallEventStream {
986 #[cfg(test)]
987 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
988 let (events_tx, events_rx) =
989 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
990
991 let stream = ToolCallEventStream::new(
992 &LanguageModelToolUse {
993 id: "test_id".into(),
994 name: "test_tool".into(),
995 raw_input: String::new(),
996 input: serde_json::Value::Null,
997 is_input_complete: true,
998 },
999 acp::ToolKind::Other,
1000 AgentResponseEventStream(events_tx),
1001 None,
1002 );
1003
1004 (stream, ToolCallEventStreamReceiver(events_rx))
1005 }
1006
1007 fn new(
1008 tool_use: &LanguageModelToolUse,
1009 kind: acp::ToolKind,
1010 stream: AgentResponseEventStream,
1011 fs: Option<Arc<dyn Fs>>,
1012 ) -> Self {
1013 Self {
1014 tool_use_id: tool_use.id.clone(),
1015 kind,
1016 input: tool_use.input.clone(),
1017 stream,
1018 fs,
1019 }
1020 }
1021
1022 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1023 self.stream
1024 .update_tool_call_fields(&self.tool_use_id, fields);
1025 }
1026
1027 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1028 self.stream
1029 .0
1030 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1031 acp_thread::ToolCallUpdateDiff {
1032 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1033 diff,
1034 }
1035 .into(),
1036 )))
1037 .ok();
1038 }
1039
1040 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1041 self.stream
1042 .0
1043 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1044 acp_thread::ToolCallUpdateTerminal {
1045 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1046 terminal,
1047 }
1048 .into(),
1049 )))
1050 .ok();
1051 }
1052
1053 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1054 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1055 return Task::ready(Ok(()));
1056 }
1057
1058 let (response_tx, response_rx) = oneshot::channel();
1059 self.stream
1060 .0
1061 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1062 ToolCallAuthorization {
1063 tool_call: AgentResponseEventStream::initial_tool_call(
1064 &self.tool_use_id,
1065 title.into(),
1066 self.kind.clone(),
1067 self.input.clone(),
1068 ),
1069 options: vec![
1070 acp::PermissionOption {
1071 id: acp::PermissionOptionId("always_allow".into()),
1072 name: "Always Allow".into(),
1073 kind: acp::PermissionOptionKind::AllowAlways,
1074 },
1075 acp::PermissionOption {
1076 id: acp::PermissionOptionId("allow".into()),
1077 name: "Allow".into(),
1078 kind: acp::PermissionOptionKind::AllowOnce,
1079 },
1080 acp::PermissionOption {
1081 id: acp::PermissionOptionId("deny".into()),
1082 name: "Deny".into(),
1083 kind: acp::PermissionOptionKind::RejectOnce,
1084 },
1085 ],
1086 response: response_tx,
1087 },
1088 )))
1089 .ok();
1090 let fs = self.fs.clone();
1091 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1092 "always_allow" => {
1093 if let Some(fs) = fs.clone() {
1094 cx.update(|cx| {
1095 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1096 settings.set_always_allow_tool_actions(true);
1097 });
1098 })?;
1099 }
1100
1101 Ok(())
1102 }
1103 "allow" => Ok(()),
1104 _ => Err(anyhow!("Permission to run tool denied by user")),
1105 })
1106 }
1107}
1108
/// Test-only receiver for the events sent through a [`ToolCallEventStream`].
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1113
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it as an authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Waits for the next event and returns the terminal of its terminal update.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1137
/// Lets tests call the wrapped receiver's methods directly on the newtype.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    /// Borrows the wrapped receiver.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1146
/// Mutable counterpart of the `Deref` impl, giving tests mutable access to
/// the wrapped receiver.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutably borrows the wrapped receiver.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1153
1154impl AgentMessage {
1155 fn to_request(&self) -> language_model::LanguageModelRequestMessage {
1156 let mut message = LanguageModelRequestMessage {
1157 role: self.role,
1158 content: Vec::with_capacity(self.content.len()),
1159 cache: false,
1160 };
1161
1162 const OPEN_CONTEXT: &str = "<context>\n\
1163 The following items were attached by the user. \
1164 They are up-to-date and don't need to be re-read.\n\n";
1165
1166 const OPEN_FILES_TAG: &str = "<files>";
1167 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
1168 const OPEN_THREADS_TAG: &str = "<threads>";
1169 const OPEN_RULES_TAG: &str =
1170 "<rules>\nThe user has specified the following rules that should be applied:\n";
1171
1172 let mut file_context = OPEN_FILES_TAG.to_string();
1173 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
1174 let mut thread_context = OPEN_THREADS_TAG.to_string();
1175 let mut rules_context = OPEN_RULES_TAG.to_string();
1176
1177 for chunk in &self.content {
1178 let chunk = match chunk {
1179 MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()),
1180 MessageContent::Thinking { text, signature } => {
1181 language_model::MessageContent::Thinking {
1182 text: text.clone(),
1183 signature: signature.clone(),
1184 }
1185 }
1186 MessageContent::RedactedThinking(value) => {
1187 language_model::MessageContent::RedactedThinking(value.clone())
1188 }
1189 MessageContent::ToolUse(value) => {
1190 language_model::MessageContent::ToolUse(value.clone())
1191 }
1192 MessageContent::ToolResult(value) => {
1193 language_model::MessageContent::ToolResult(value.clone())
1194 }
1195 MessageContent::Image(value) => {
1196 language_model::MessageContent::Image(value.clone())
1197 }
1198 MessageContent::Mention { uri, content } => {
1199 match uri {
1200 MentionUri::File(path) | MentionUri::Symbol(path, _) => {
1201 write!(
1202 &mut symbol_context,
1203 "\n{}",
1204 MarkdownCodeBlock {
1205 tag: &codeblock_tag(&path),
1206 text: &content.to_string(),
1207 }
1208 )
1209 .ok();
1210 }
1211 MentionUri::Thread(_session_id) => {
1212 write!(&mut thread_context, "\n{}\n", content).ok();
1213 }
1214 MentionUri::Rule(_user_prompt_id) => {
1215 write!(
1216 &mut rules_context,
1217 "\n{}",
1218 MarkdownCodeBlock {
1219 tag: "",
1220 text: &content
1221 }
1222 )
1223 .ok();
1224 }
1225 }
1226
1227 language_model::MessageContent::Text(uri.to_link())
1228 }
1229 };
1230
1231 message.content.push(chunk);
1232 }
1233
1234 let len_before_context = message.content.len();
1235
1236 if file_context.len() > OPEN_FILES_TAG.len() {
1237 file_context.push_str("</files>\n");
1238 message
1239 .content
1240 .push(language_model::MessageContent::Text(file_context));
1241 }
1242
1243 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
1244 symbol_context.push_str("</symbols>\n");
1245 message
1246 .content
1247 .push(language_model::MessageContent::Text(symbol_context));
1248 }
1249
1250 if thread_context.len() > OPEN_THREADS_TAG.len() {
1251 thread_context.push_str("</threads>\n");
1252 message
1253 .content
1254 .push(language_model::MessageContent::Text(thread_context));
1255 }
1256
1257 if rules_context.len() > OPEN_RULES_TAG.len() {
1258 rules_context.push_str("</user_rules>\n");
1259 message
1260 .content
1261 .push(language_model::MessageContent::Text(rules_context));
1262 }
1263
1264 if message.content.len() > len_before_context {
1265 message.content.insert(
1266 len_before_context,
1267 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
1268 );
1269 message
1270 .content
1271 .push(language_model::MessageContent::Text("</context>".into()));
1272 }
1273
1274 message
1275 }
1276}
1277
/// Builds the info string for a fenced code block showing `full_path`.
///
/// The result is `"<extension> <path>"` when the path has an extension that
/// is valid UTF-8, and just `"<path>"` otherwise. The path is rendered with
/// `Path::display`.
fn codeblock_tag(full_path: &Path) -> String {
    let rendered_path = full_path.display().to_string();

    match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{} {}", extension, rendered_path),
        None => rendered_path,
    }
}
1289
1290impl From<acp::ContentBlock> for MessageContent {
1291 fn from(value: acp::ContentBlock) -> Self {
1292 match value {
1293 acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text),
1294 acp::ContentBlock::Image(image_content) => {
1295 MessageContent::Image(convert_image(image_content))
1296 }
1297 acp::ContentBlock::Audio(_) => {
1298 // TODO
1299 MessageContent::Text("[audio]".to_string())
1300 }
1301 acp::ContentBlock::ResourceLink(resource_link) => {
1302 match MentionUri::parse(&resource_link.uri) {
1303 Ok(uri) => Self::Mention {
1304 uri,
1305 content: String::new(),
1306 },
1307 Err(err) => {
1308 log::error!("Failed to parse mention link: {}", err);
1309 MessageContent::Text(format!(
1310 "[{}]({})",
1311 resource_link.name, resource_link.uri
1312 ))
1313 }
1314 }
1315 }
1316 acp::ContentBlock::Resource(resource) => match resource.resource {
1317 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1318 match MentionUri::parse(&resource.uri) {
1319 Ok(uri) => Self::Mention {
1320 uri,
1321 content: resource.text,
1322 },
1323 Err(err) => {
1324 log::error!("Failed to parse mention link: {}", err);
1325 MessageContent::Text(
1326 MarkdownCodeBlock {
1327 tag: &resource.uri,
1328 text: &resource.text,
1329 }
1330 .to_string(),
1331 )
1332 }
1333 }
1334 }
1335 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1336 // TODO
1337 MessageContent::Text("[blob]".to_string())
1338 }
1339 },
1340 }
1341 }
1342}
1343
1344fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1345 LanguageModelImage {
1346 source: image_content.data.into(),
1347 // TODO: make this optional?
1348 size: gpui::Size::new(0.into(), 0.into()),
1349 }
1350}
1351
1352impl From<&str> for MessageContent {
1353 fn from(text: &str) -> Self {
1354 MessageContent::Text(text.into())
1355 }
1356}