1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
18 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
19 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
20 LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
29use std::{fmt::Write, ops::Range};
30use util::{ResultExt, markdown::MarkdownCodeBlock};
31
32#[derive(Debug, Clone, PartialEq, Eq)]
33pub enum Message {
34 User(UserMessage),
35 Agent(AgentMessage),
36 Resume,
37}
38
39impl Message {
40 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
41 match self {
42 Message::Agent(agent_message) => Some(agent_message),
43 _ => None,
44 }
45 }
46
47 pub fn to_markdown(&self) -> String {
48 match self {
49 Message::User(message) => message.to_markdown(),
50 Message::Agent(message) => message.to_markdown(),
51 Message::Resume => "[resumed after tool use limit was reached]".into(),
52 }
53 }
54}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct UserMessage {
58 pub id: UserMessageId,
59 pub content: Vec<UserMessageContent>,
60}
61
62#[derive(Debug, Clone, PartialEq, Eq)]
63pub enum UserMessageContent {
64 Text(String),
65 Mention { uri: MentionUri, content: String },
66 Image(LanguageModelImage),
67}
68
69impl UserMessage {
70 pub fn to_markdown(&self) -> String {
71 let mut markdown = String::from("## User\n\n");
72
73 for content in &self.content {
74 match content {
75 UserMessageContent::Text(text) => {
76 markdown.push_str(text);
77 markdown.push('\n');
78 }
79 UserMessageContent::Image(_) => {
80 markdown.push_str("<image />\n");
81 }
82 UserMessageContent::Mention { uri, content } => {
83 if !content.is_empty() {
84 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
85 } else {
86 let _ = write!(&mut markdown, "{}\n", uri.as_link());
87 }
88 }
89 }
90 }
91
92 markdown
93 }
94
95 fn to_request(&self) -> LanguageModelRequestMessage {
96 let mut message = LanguageModelRequestMessage {
97 role: Role::User,
98 content: Vec::with_capacity(self.content.len()),
99 cache: false,
100 };
101
102 const OPEN_CONTEXT: &str = "<context>\n\
103 The following items were attached by the user. \
104 They are up-to-date and don't need to be re-read.\n\n";
105
106 const OPEN_FILES_TAG: &str = "<files>";
107 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
108 const OPEN_THREADS_TAG: &str = "<threads>";
109 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
110 const OPEN_RULES_TAG: &str =
111 "<rules>\nThe user has specified the following rules that should be applied:\n";
112
113 let mut file_context = OPEN_FILES_TAG.to_string();
114 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
115 let mut thread_context = OPEN_THREADS_TAG.to_string();
116 let mut fetch_context = OPEN_FETCH_TAG.to_string();
117 let mut rules_context = OPEN_RULES_TAG.to_string();
118
119 for chunk in &self.content {
120 let chunk = match chunk {
121 UserMessageContent::Text(text) => {
122 language_model::MessageContent::Text(text.clone())
123 }
124 UserMessageContent::Image(value) => {
125 language_model::MessageContent::Image(value.clone())
126 }
127 UserMessageContent::Mention { uri, content } => {
128 match uri {
129 MentionUri::File { abs_path, .. } => {
130 write!(
131 &mut symbol_context,
132 "\n{}",
133 MarkdownCodeBlock {
134 tag: &codeblock_tag(&abs_path, None),
135 text: &content.to_string(),
136 }
137 )
138 .ok();
139 }
140 MentionUri::Symbol {
141 path, line_range, ..
142 }
143 | MentionUri::Selection {
144 path, line_range, ..
145 } => {
146 write!(
147 &mut rules_context,
148 "\n{}",
149 MarkdownCodeBlock {
150 tag: &codeblock_tag(&path, Some(line_range)),
151 text: &content
152 }
153 )
154 .ok();
155 }
156 MentionUri::Thread { .. } => {
157 write!(&mut thread_context, "\n{}\n", content).ok();
158 }
159 MentionUri::TextThread { .. } => {
160 write!(&mut thread_context, "\n{}\n", content).ok();
161 }
162 MentionUri::Rule { .. } => {
163 write!(
164 &mut rules_context,
165 "\n{}",
166 MarkdownCodeBlock {
167 tag: "",
168 text: &content
169 }
170 )
171 .ok();
172 }
173 MentionUri::Fetch { url } => {
174 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
175 }
176 }
177
178 language_model::MessageContent::Text(uri.as_link().to_string())
179 }
180 };
181
182 message.content.push(chunk);
183 }
184
185 let len_before_context = message.content.len();
186
187 if file_context.len() > OPEN_FILES_TAG.len() {
188 file_context.push_str("</files>\n");
189 message
190 .content
191 .push(language_model::MessageContent::Text(file_context));
192 }
193
194 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
195 symbol_context.push_str("</symbols>\n");
196 message
197 .content
198 .push(language_model::MessageContent::Text(symbol_context));
199 }
200
201 if thread_context.len() > OPEN_THREADS_TAG.len() {
202 thread_context.push_str("</threads>\n");
203 message
204 .content
205 .push(language_model::MessageContent::Text(thread_context));
206 }
207
208 if fetch_context.len() > OPEN_FETCH_TAG.len() {
209 fetch_context.push_str("</fetched_urls>\n");
210 message
211 .content
212 .push(language_model::MessageContent::Text(fetch_context));
213 }
214
215 if rules_context.len() > OPEN_RULES_TAG.len() {
216 rules_context.push_str("</user_rules>\n");
217 message
218 .content
219 .push(language_model::MessageContent::Text(rules_context));
220 }
221
222 if message.content.len() > len_before_context {
223 message.content.insert(
224 len_before_context,
225 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
226 );
227 message
228 .content
229 .push(language_model::MessageContent::Text("</context>".into()));
230 }
231
232 message
233 }
234}
235
/// Builds the info string for a fenced code block that shows a file excerpt.
///
/// The tag is `"<ext> <path>"`; the extension prefix is dropped when the path
/// has no UTF-8 extension. When `line_range` is given, a line suffix is
/// appended with 1 added to each bound (the rows are presumably 0-based):
/// `:N` when start equals end, otherwise `:N-M`.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let mut tag = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => format!("{ext} "),
        None => String::new(),
    };

    let _ = write!(tag, "{}", full_path.display());

    match line_range {
        Some(rows) if rows.start == rows.end => {
            let _ = write!(tag, ":{}", rows.start + 1);
        }
        Some(rows) => {
            let _ = write!(tag, ":{}-{}", rows.start + 1, rows.end + 1);
        }
        None => {}
    }

    tag
}
255
256impl AgentMessage {
257 pub fn to_markdown(&self) -> String {
258 let mut markdown = String::from("## Assistant\n\n");
259
260 for content in &self.content {
261 match content {
262 AgentMessageContent::Text(text) => {
263 markdown.push_str(text);
264 markdown.push('\n');
265 }
266 AgentMessageContent::Thinking { text, .. } => {
267 markdown.push_str("<think>");
268 markdown.push_str(text);
269 markdown.push_str("</think>\n");
270 }
271 AgentMessageContent::RedactedThinking(_) => {
272 markdown.push_str("<redacted_thinking />\n")
273 }
274 AgentMessageContent::Image(_) => {
275 markdown.push_str("<image />\n");
276 }
277 AgentMessageContent::ToolUse(tool_use) => {
278 markdown.push_str(&format!(
279 "**Tool Use**: {} (ID: {})\n",
280 tool_use.name, tool_use.id
281 ));
282 markdown.push_str(&format!(
283 "{}\n",
284 MarkdownCodeBlock {
285 tag: "json",
286 text: &format!("{:#}", tool_use.input)
287 }
288 ));
289 }
290 }
291 }
292
293 for tool_result in self.tool_results.values() {
294 markdown.push_str(&format!(
295 "**Tool Result**: {} (ID: {})\n\n",
296 tool_result.tool_name, tool_result.tool_use_id
297 ));
298 if tool_result.is_error {
299 markdown.push_str("**ERROR:**\n");
300 }
301
302 match &tool_result.content {
303 LanguageModelToolResultContent::Text(text) => {
304 writeln!(markdown, "{text}\n").ok();
305 }
306 LanguageModelToolResultContent::Image(_) => {
307 writeln!(markdown, "<image />\n").ok();
308 }
309 }
310
311 if let Some(output) = tool_result.output.as_ref() {
312 writeln!(
313 markdown,
314 "**Debug Output**:\n\n```json\n{}\n```\n",
315 serde_json::to_string_pretty(output).unwrap()
316 )
317 .unwrap();
318 }
319 }
320
321 markdown
322 }
323
324 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
325 let mut assistant_message = LanguageModelRequestMessage {
326 role: Role::Assistant,
327 content: Vec::with_capacity(self.content.len()),
328 cache: false,
329 };
330 for chunk in &self.content {
331 let chunk = match chunk {
332 AgentMessageContent::Text(text) => {
333 language_model::MessageContent::Text(text.clone())
334 }
335 AgentMessageContent::Thinking { text, signature } => {
336 language_model::MessageContent::Thinking {
337 text: text.clone(),
338 signature: signature.clone(),
339 }
340 }
341 AgentMessageContent::RedactedThinking(value) => {
342 language_model::MessageContent::RedactedThinking(value.clone())
343 }
344 AgentMessageContent::ToolUse(value) => {
345 language_model::MessageContent::ToolUse(value.clone())
346 }
347 AgentMessageContent::Image(value) => {
348 language_model::MessageContent::Image(value.clone())
349 }
350 };
351 assistant_message.content.push(chunk);
352 }
353
354 let mut user_message = LanguageModelRequestMessage {
355 role: Role::User,
356 content: Vec::new(),
357 cache: false,
358 };
359
360 for tool_result in self.tool_results.values() {
361 user_message
362 .content
363 .push(language_model::MessageContent::ToolResult(
364 tool_result.clone(),
365 ));
366 }
367
368 let mut messages = Vec::new();
369 if !assistant_message.content.is_empty() {
370 messages.push(assistant_message);
371 }
372 if !user_message.content.is_empty() {
373 messages.push(user_message);
374 }
375 messages
376 }
377}
378
379#[derive(Default, Debug, Clone, PartialEq, Eq)]
380pub struct AgentMessage {
381 pub content: Vec<AgentMessageContent>,
382 pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
383}
384
385#[derive(Debug, Clone, PartialEq, Eq)]
386pub enum AgentMessageContent {
387 Text(String),
388 Thinking {
389 text: String,
390 signature: Option<String>,
391 },
392 RedactedThinking(String),
393 Image(LanguageModelImage),
394 ToolUse(LanguageModelToolUse),
395}
396
/// Events streamed to the caller of `Thread::send` / `Thread::resume` while a
/// turn is running.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of response text.
    Text(String),
    /// A chunk of model reasoning text.
    Thinking(String),
    /// A new tool call was started.
    ToolCall(acp::ToolCall),
    /// Fields of an existing tool call changed (status, title, output, ...).
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool call is asking for permission before it proceeds.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, with the given reason.
    Stop(acp::StopReason),
}
406
/// A permission request raised for a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call the request concerns.
    pub tool_call: acp::ToolCall,
    /// The permission options offered for this request.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot sender for the id of the chosen option (presumably the user's
    /// choice; confirm against the caller that answers it).
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
413
/// An agent conversation: message history, the in-progress agent message, the
/// registered tools, and the model used to run turns.
pub struct Thread {
    /// Committed conversation history.
    messages: Vec<Message>,
    /// Completion mode forwarded with each request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// The agent message currently being streamed; moved into `messages` when flushed.
    pending_message: Option<AgentMessage>,
    /// Tools registered on this thread, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when the model reports that the tool use limit was reached; gates `resume`.
    tool_use_limit_reached: bool,
    /// Source of additional tools provided by context servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Profile whose settings decide which tools are enabled.
    profile_id: AgentProfileId,
    /// Project context rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// The language model that turns are run against.
    model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
432
433impl Thread {
434 pub fn new(
435 project: Entity<Project>,
436 project_context: Rc<RefCell<ProjectContext>>,
437 context_server_registry: Entity<ContextServerRegistry>,
438 action_log: Entity<ActionLog>,
439 templates: Arc<Templates>,
440 model: Arc<dyn LanguageModel>,
441 cx: &mut Context<Self>,
442 ) -> Self {
443 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
444 Self {
445 messages: Vec::new(),
446 completion_mode: CompletionMode::Normal,
447 running_turn: None,
448 pending_message: None,
449 tools: BTreeMap::default(),
450 tool_use_limit_reached: false,
451 context_server_registry,
452 profile_id,
453 project_context,
454 templates,
455 model,
456 project,
457 action_log,
458 }
459 }
460
461 pub fn project(&self) -> &Entity<Project> {
462 &self.project
463 }
464
465 pub fn action_log(&self) -> &Entity<ActionLog> {
466 &self.action_log
467 }
468
469 pub fn model(&self) -> &Arc<dyn LanguageModel> {
470 &self.model
471 }
472
473 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
474 self.model = model;
475 }
476
477 pub fn completion_mode(&self) -> CompletionMode {
478 self.completion_mode
479 }
480
481 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
482 self.completion_mode = mode;
483 }
484
485 #[cfg(any(test, feature = "test-support"))]
486 pub fn last_message(&self) -> Option<Message> {
487 if let Some(message) = self.pending_message.clone() {
488 Some(Message::Agent(message))
489 } else {
490 self.messages.last().cloned()
491 }
492 }
493
494 pub fn add_tool(&mut self, tool: impl AgentTool) {
495 self.tools.insert(tool.name(), tool.erase());
496 }
497
498 pub fn remove_tool(&mut self, name: &str) -> bool {
499 self.tools.remove(name).is_some()
500 }
501
502 pub fn profile(&self) -> &AgentProfileId {
503 &self.profile_id
504 }
505
506 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
507 self.profile_id = profile_id;
508 }
509
510 pub fn cancel(&mut self) {
511 // TODO: do we need to emit a stop::cancel for ACP?
512 self.running_turn.take();
513 self.flush_pending_message();
514 }
515
516 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
517 self.cancel();
518 let Some(position) = self.messages.iter().position(
519 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
520 ) else {
521 return Err(anyhow!("Message not found"));
522 };
523 self.messages.truncate(position);
524 Ok(())
525 }
526
527 pub fn resume(
528 &mut self,
529 cx: &mut Context<Self>,
530 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
531 anyhow::ensure!(
532 self.tool_use_limit_reached,
533 "can only resume after tool use limit is reached"
534 );
535
536 self.messages.push(Message::Resume);
537 cx.notify();
538
539 log::info!("Total messages in thread: {}", self.messages.len());
540 Ok(self.run_turn(cx))
541 }
542
543 /// Sending a message results in the model streaming a response, which could include tool calls.
544 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
545 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
546 pub fn send<T>(
547 &mut self,
548 id: UserMessageId,
549 content: impl IntoIterator<Item = T>,
550 cx: &mut Context<Self>,
551 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
552 where
553 T: Into<UserMessageContent>,
554 {
555 log::info!("Thread::send called with model: {:?}", self.model.name());
556
557 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
558 log::debug!("Thread::send content: {:?}", content);
559
560 self.messages
561 .push(Message::User(UserMessage { id, content }));
562 cx.notify();
563
564 log::info!("Total messages in thread: {}", self.messages.len());
565 self.run_turn(cx)
566 }
567
568 fn run_turn(
569 &mut self,
570 cx: &mut Context<Self>,
571 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
572 let model = self.model.clone();
573 let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
574 let event_stream = AgentResponseEventStream(events_tx);
575 let message_ix = self.messages.len().saturating_sub(1);
576 self.tool_use_limit_reached = false;
577 self.running_turn = Some(cx.spawn(async move |this, cx| {
578 log::info!("Starting agent turn execution");
579 let turn_result: Result<()> = async {
580 let mut completion_intent = CompletionIntent::UserPrompt;
581 loop {
582 log::debug!(
583 "Building completion request with intent: {:?}",
584 completion_intent
585 );
586 let request = this.update(cx, |this, cx| {
587 this.build_completion_request(completion_intent, cx)
588 })?;
589
590 log::info!("Calling model.stream_completion");
591 let mut events = model.stream_completion(request, cx).await?;
592 log::debug!("Stream completion started successfully");
593
594 let mut tool_use_limit_reached = false;
595 let mut tool_uses = FuturesUnordered::new();
596 while let Some(event) = events.next().await {
597 match event? {
598 LanguageModelCompletionEvent::StatusUpdate(
599 CompletionRequestStatus::ToolUseLimitReached,
600 ) => {
601 tool_use_limit_reached = true;
602 }
603 LanguageModelCompletionEvent::Stop(reason) => {
604 event_stream.send_stop(reason);
605 if reason == StopReason::Refusal {
606 this.update(cx, |this, _cx| {
607 this.flush_pending_message();
608 this.messages.truncate(message_ix);
609 })?;
610 return Ok(());
611 }
612 }
613 event => {
614 log::trace!("Received completion event: {:?}", event);
615 this.update(cx, |this, cx| {
616 tool_uses.extend(this.handle_streamed_completion_event(
617 event,
618 &event_stream,
619 cx,
620 ));
621 })
622 .ok();
623 }
624 }
625 }
626
627 let used_tools = tool_uses.is_empty();
628 while let Some(tool_result) = tool_uses.next().await {
629 log::info!("Tool finished {:?}", tool_result);
630
631 event_stream.update_tool_call_fields(
632 &tool_result.tool_use_id,
633 acp::ToolCallUpdateFields {
634 status: Some(if tool_result.is_error {
635 acp::ToolCallStatus::Failed
636 } else {
637 acp::ToolCallStatus::Completed
638 }),
639 raw_output: tool_result.output.clone(),
640 ..Default::default()
641 },
642 );
643 this.update(cx, |this, _cx| {
644 this.pending_message()
645 .tool_results
646 .insert(tool_result.tool_use_id.clone(), tool_result);
647 })
648 .ok();
649 }
650
651 if tool_use_limit_reached {
652 log::info!("Tool use limit reached, completing turn");
653 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
654 return Err(language_model::ToolUseLimitReachedError.into());
655 } else if used_tools {
656 log::info!("No tool uses found, completing turn");
657 return Ok(());
658 } else {
659 this.update(cx, |this, _| this.flush_pending_message())?;
660 completion_intent = CompletionIntent::ToolResults;
661 }
662 }
663 }
664 .await;
665
666 this.update(cx, |this, _| this.flush_pending_message()).ok();
667 if let Err(error) = turn_result {
668 log::error!("Turn execution failed: {:?}", error);
669 event_stream.send_error(error);
670 } else {
671 log::info!("Turn execution completed successfully");
672 }
673 }));
674 events_rx
675 }
676
677 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
678 log::debug!("Building system message");
679 let prompt = SystemPromptTemplate {
680 project: &self.project_context.borrow(),
681 available_tools: self.tools.keys().cloned().collect(),
682 }
683 .render(&self.templates)
684 .context("failed to build system prompt")
685 .expect("Invalid template");
686 log::debug!("System message built");
687 LanguageModelRequestMessage {
688 role: Role::System,
689 content: vec![prompt.into()],
690 cache: true,
691 }
692 }
693
694 /// A helper method that's called on every streamed completion event.
695 /// Returns an optional tool result task, which the main agentic loop in
696 /// send will send back to the model when it resolves.
697 fn handle_streamed_completion_event(
698 &mut self,
699 event: LanguageModelCompletionEvent,
700 event_stream: &AgentResponseEventStream,
701 cx: &mut Context<Self>,
702 ) -> Option<Task<LanguageModelToolResult>> {
703 log::trace!("Handling streamed completion event: {:?}", event);
704 use LanguageModelCompletionEvent::*;
705
706 match event {
707 StartMessage { .. } => {
708 self.flush_pending_message();
709 self.pending_message = Some(AgentMessage::default());
710 }
711 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
712 Thinking { text, signature } => {
713 self.handle_thinking_event(text, signature, event_stream, cx)
714 }
715 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
716 ToolUse(tool_use) => {
717 return self.handle_tool_use_event(tool_use, event_stream, cx);
718 }
719 ToolUseJsonParseError {
720 id,
721 tool_name,
722 raw_input,
723 json_parse_error,
724 } => {
725 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
726 id,
727 tool_name,
728 raw_input,
729 json_parse_error,
730 )));
731 }
732 UsageUpdate(_) | StatusUpdate(_) => {}
733 Stop(_) => unreachable!(),
734 }
735
736 None
737 }
738
739 fn handle_text_event(
740 &mut self,
741 new_text: String,
742 event_stream: &AgentResponseEventStream,
743 cx: &mut Context<Self>,
744 ) {
745 event_stream.send_text(&new_text);
746
747 let last_message = self.pending_message();
748 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
749 text.push_str(&new_text);
750 } else {
751 last_message
752 .content
753 .push(AgentMessageContent::Text(new_text));
754 }
755
756 cx.notify();
757 }
758
759 fn handle_thinking_event(
760 &mut self,
761 new_text: String,
762 new_signature: Option<String>,
763 event_stream: &AgentResponseEventStream,
764 cx: &mut Context<Self>,
765 ) {
766 event_stream.send_thinking(&new_text);
767
768 let last_message = self.pending_message();
769 if let Some(AgentMessageContent::Thinking { text, signature }) =
770 last_message.content.last_mut()
771 {
772 text.push_str(&new_text);
773 *signature = new_signature.or(signature.take());
774 } else {
775 last_message.content.push(AgentMessageContent::Thinking {
776 text: new_text,
777 signature: new_signature,
778 });
779 }
780
781 cx.notify();
782 }
783
784 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
785 let last_message = self.pending_message();
786 last_message
787 .content
788 .push(AgentMessageContent::RedactedThinking(data));
789 cx.notify();
790 }
791
792 fn handle_tool_use_event(
793 &mut self,
794 tool_use: LanguageModelToolUse,
795 event_stream: &AgentResponseEventStream,
796 cx: &mut Context<Self>,
797 ) -> Option<Task<LanguageModelToolResult>> {
798 cx.notify();
799
800 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
801 let mut title = SharedString::from(&tool_use.name);
802 let mut kind = acp::ToolKind::Other;
803 if let Some(tool) = tool.as_ref() {
804 title = tool.initial_title(tool_use.input.clone());
805 kind = tool.kind();
806 }
807
808 // Ensure the last message ends in the current tool use
809 let last_message = self.pending_message();
810 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
811 if let AgentMessageContent::ToolUse(last_tool_use) = content {
812 if last_tool_use.id == tool_use.id {
813 *last_tool_use = tool_use.clone();
814 false
815 } else {
816 true
817 }
818 } else {
819 true
820 }
821 });
822
823 if push_new_tool_use {
824 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
825 last_message
826 .content
827 .push(AgentMessageContent::ToolUse(tool_use.clone()));
828 } else {
829 event_stream.update_tool_call_fields(
830 &tool_use.id,
831 acp::ToolCallUpdateFields {
832 title: Some(title.into()),
833 kind: Some(kind),
834 raw_input: Some(tool_use.input.clone()),
835 ..Default::default()
836 },
837 );
838 }
839
840 if !tool_use.is_input_complete {
841 return None;
842 }
843
844 let Some(tool) = tool else {
845 let content = format!("No tool named {} exists", tool_use.name);
846 return Some(Task::ready(LanguageModelToolResult {
847 content: LanguageModelToolResultContent::Text(Arc::from(content)),
848 tool_use_id: tool_use.id,
849 tool_name: tool_use.name,
850 is_error: true,
851 output: None,
852 }));
853 };
854
855 let fs = self.project.read(cx).fs().clone();
856 let tool_event_stream =
857 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
858 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
859 status: Some(acp::ToolCallStatus::InProgress),
860 ..Default::default()
861 });
862 let supports_images = self.model.supports_images();
863 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
864 log::info!("Running tool {}", tool_use.name);
865 Some(cx.foreground_executor().spawn(async move {
866 let tool_result = tool_result.await.and_then(|output| {
867 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
868 if !supports_images {
869 return Err(anyhow!(
870 "Attempted to read an image, but this model doesn't support it.",
871 ));
872 }
873 }
874 Ok(output)
875 });
876
877 match tool_result {
878 Ok(output) => LanguageModelToolResult {
879 tool_use_id: tool_use.id,
880 tool_name: tool_use.name,
881 is_error: false,
882 content: output.llm_output,
883 output: Some(output.raw_output),
884 },
885 Err(error) => LanguageModelToolResult {
886 tool_use_id: tool_use.id,
887 tool_name: tool_use.name,
888 is_error: true,
889 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
890 output: None,
891 },
892 }
893 }))
894 }
895
896 fn handle_tool_use_json_parse_error_event(
897 &mut self,
898 tool_use_id: LanguageModelToolUseId,
899 tool_name: Arc<str>,
900 raw_input: Arc<str>,
901 json_parse_error: String,
902 ) -> LanguageModelToolResult {
903 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
904 LanguageModelToolResult {
905 tool_use_id,
906 tool_name,
907 is_error: true,
908 content: LanguageModelToolResultContent::Text(tool_output.into()),
909 output: Some(serde_json::Value::String(raw_input.to_string())),
910 }
911 }
912
913 fn pending_message(&mut self) -> &mut AgentMessage {
914 self.pending_message.get_or_insert_default()
915 }
916
917 fn flush_pending_message(&mut self) {
918 let Some(mut message) = self.pending_message.take() else {
919 return;
920 };
921
922 for content in &message.content {
923 let AgentMessageContent::ToolUse(tool_use) = content else {
924 continue;
925 };
926
927 if !message.tool_results.contains_key(&tool_use.id) {
928 message.tool_results.insert(
929 tool_use.id.clone(),
930 LanguageModelToolResult {
931 tool_use_id: tool_use.id.clone(),
932 tool_name: tool_use.name.clone(),
933 is_error: true,
934 content: LanguageModelToolResultContent::Text(
935 "Tool canceled by user".into(),
936 ),
937 output: None,
938 },
939 );
940 }
941 }
942
943 self.messages.push(Message::Agent(message));
944 }
945
946 pub(crate) fn build_completion_request(
947 &self,
948 completion_intent: CompletionIntent,
949 cx: &mut App,
950 ) -> LanguageModelRequest {
951 log::debug!("Building completion request");
952 log::debug!("Completion intent: {:?}", completion_intent);
953 log::debug!("Completion mode: {:?}", self.completion_mode);
954
955 let messages = self.build_request_messages();
956 log::info!("Request will include {} messages", messages.len());
957
958 let tools = if let Some(tools) = self.tools(cx).log_err() {
959 tools
960 .filter_map(|tool| {
961 let tool_name = tool.name().to_string();
962 log::trace!("Including tool: {}", tool_name);
963 Some(LanguageModelRequestTool {
964 name: tool_name,
965 description: tool.description().to_string(),
966 input_schema: tool
967 .input_schema(self.model.tool_input_format())
968 .log_err()?,
969 })
970 })
971 .collect()
972 } else {
973 Vec::new()
974 };
975
976 log::info!("Request includes {} tools", tools.len());
977
978 let request = LanguageModelRequest {
979 thread_id: None,
980 prompt_id: None,
981 intent: Some(completion_intent),
982 mode: Some(self.completion_mode.into()),
983 messages,
984 tools,
985 tool_choice: None,
986 stop: Vec::new(),
987 temperature: None,
988 thinking_allowed: true,
989 };
990
991 log::debug!("Completion request built successfully");
992 request
993 }
994
995 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
996 let profile = AgentSettings::get_global(cx)
997 .profiles
998 .get(&self.profile_id)
999 .context("profile not found")?;
1000 let provider_id = self.model.provider_id();
1001
1002 Ok(self
1003 .tools
1004 .iter()
1005 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1006 .filter_map(|(tool_name, tool)| {
1007 if profile.is_tool_enabled(tool_name) {
1008 Some(tool)
1009 } else {
1010 None
1011 }
1012 })
1013 .chain(self.context_server_registry.read(cx).servers().flat_map(
1014 |(server_id, tools)| {
1015 tools.iter().filter_map(|(tool_name, tool)| {
1016 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1017 Some(tool)
1018 } else {
1019 None
1020 }
1021 })
1022 },
1023 )))
1024 }
1025
1026 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1027 log::trace!(
1028 "Building request messages from {} thread messages",
1029 self.messages.len()
1030 );
1031 let mut messages = vec![self.build_system_message()];
1032 for message in &self.messages {
1033 match message {
1034 Message::User(message) => messages.push(message.to_request()),
1035 Message::Agent(message) => messages.extend(message.to_request()),
1036 Message::Resume => messages.push(LanguageModelRequestMessage {
1037 role: Role::User,
1038 content: vec!["Continue where you left off".into()],
1039 cache: false,
1040 }),
1041 }
1042 }
1043
1044 if let Some(message) = self.pending_message.as_ref() {
1045 messages.extend(message.to_request());
1046 }
1047
1048 if let Some(last_user_message) = messages
1049 .iter_mut()
1050 .rev()
1051 .find(|message| message.role == Role::User)
1052 {
1053 last_user_message.cache = true;
1054 }
1055
1056 messages
1057 }
1058
1059 pub fn to_markdown(&self) -> String {
1060 let mut markdown = String::new();
1061 for (ix, message) in self.messages.iter().enumerate() {
1062 if ix > 0 {
1063 markdown.push('\n');
1064 }
1065 markdown.push_str(&message.to_markdown());
1066 }
1067
1068 if let Some(message) = self.pending_message.as_ref() {
1069 markdown.push('\n');
1070 markdown.push_str(&message.to_markdown());
1071 }
1072
1073 markdown
1074 }
1075}
1076
1077pub trait AgentTool
1078where
1079 Self: 'static + Sized,
1080{
1081 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1082 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1083
1084 fn name(&self) -> SharedString;
1085
1086 fn description(&self) -> SharedString {
1087 let schema = schemars::schema_for!(Self::Input);
1088 SharedString::new(
1089 schema
1090 .get("description")
1091 .and_then(|description| description.as_str())
1092 .unwrap_or_default(),
1093 )
1094 }
1095
1096 fn kind(&self) -> acp::ToolKind;
1097
1098 /// The initial tool title to display. Can be updated during the tool run.
1099 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1100
1101 /// Returns the JSON schema that describes the tool's input.
1102 fn input_schema(&self) -> Schema {
1103 schemars::schema_for!(Self::Input)
1104 }
1105
1106 /// Some tools rely on a provider for the underlying billing or other reasons.
1107 /// Allow the tool to check if they are compatible, or should be filtered out.
1108 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1109 true
1110 }
1111
1112 /// Runs the tool with the provided input.
1113 fn run(
1114 self: Arc<Self>,
1115 input: Self::Input,
1116 event_stream: ToolCallEventStream,
1117 cx: &mut App,
1118 ) -> Task<Result<Self::Output>>;
1119
1120 fn erase(self) -> Arc<dyn AnyAgentTool> {
1121 Arc::new(Erased(Arc::new(self)))
1122 }
1123}
1124
/// Type-erasing adapter around a typed tool.
///
/// `Erased<Arc<T>>` implements [`AnyAgentTool`] for any `T: AgentTool`,
/// converting JSON input into `T::Input` and `T::Output` back into JSON.
pub struct Erased<T>(T);
1126
/// The result of running a type-erased tool.
pub struct AgentToolOutput {
    /// The tool's output converted into language-model tool-result content.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's output serialized as JSON (via `serde_json::to_value` in the
    /// `Erased` adapter).
    pub raw_output: serde_json::Value,
}
1131
/// Object-safe, JSON-based interface to an agent tool.
///
/// Typed tools are turned into `Arc<dyn AnyAgentTool>` via `AgentTool::erase`.
pub trait AnyAgentTool {
    /// The tool's name.
    fn name(&self) -> SharedString;
    /// The tool's description.
    fn description(&self) -> SharedString;
    /// The ACP tool kind reported for this tool.
    fn kind(&self) -> acp::ToolKind;
    /// Title to display for a call with the given raw JSON input. The input
    /// is not guaranteed to be valid for the tool.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The tool's input JSON schema, adapted to the requested schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether this tool may be offered with the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with raw JSON input, yielding both model-facing and raw JSON output.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
1148
1149impl<T> AnyAgentTool for Erased<Arc<T>>
1150where
1151 T: AgentTool,
1152{
1153 fn name(&self) -> SharedString {
1154 self.0.name()
1155 }
1156
1157 fn description(&self) -> SharedString {
1158 self.0.description()
1159 }
1160
1161 fn kind(&self) -> agent_client_protocol::ToolKind {
1162 self.0.kind()
1163 }
1164
1165 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1166 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1167 self.0.initial_title(parsed_input)
1168 }
1169
1170 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1171 let mut json = serde_json::to_value(self.0.input_schema())?;
1172 adapt_schema_to_format(&mut json, format)?;
1173 Ok(json)
1174 }
1175
1176 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1177 self.0.supported_provider(provider)
1178 }
1179
1180 fn run(
1181 self: Arc<Self>,
1182 input: serde_json::Value,
1183 event_stream: ToolCallEventStream,
1184 cx: &mut App,
1185 ) -> Task<Result<AgentToolOutput>> {
1186 cx.spawn(async move |cx| {
1187 let input = serde_json::from_value(input)?;
1188 let output = cx
1189 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1190 .await?;
1191 let raw_output = serde_json::to_value(&output)?;
1192 Ok(AgentToolOutput {
1193 llm_output: output.into(),
1194 raw_output,
1195 })
1196 })
1197 }
1198}
1199
/// Sending half of the channel that carries agent response events (or errors).
///
/// Clones share the same underlying unbounded channel.
#[derive(Clone)]
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
1202
1203impl AgentResponseEventStream {
1204 fn send_text(&self, text: &str) {
1205 self.0
1206 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
1207 .ok();
1208 }
1209
1210 fn send_thinking(&self, text: &str) {
1211 self.0
1212 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
1213 .ok();
1214 }
1215
1216 fn send_tool_call(
1217 &self,
1218 id: &LanguageModelToolUseId,
1219 title: SharedString,
1220 kind: acp::ToolKind,
1221 input: serde_json::Value,
1222 ) {
1223 self.0
1224 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
1225 id,
1226 title.to_string(),
1227 kind,
1228 input,
1229 ))))
1230 .ok();
1231 }
1232
1233 fn initial_tool_call(
1234 id: &LanguageModelToolUseId,
1235 title: String,
1236 kind: acp::ToolKind,
1237 input: serde_json::Value,
1238 ) -> acp::ToolCall {
1239 acp::ToolCall {
1240 id: acp::ToolCallId(id.to_string().into()),
1241 title,
1242 kind,
1243 status: acp::ToolCallStatus::Pending,
1244 content: vec![],
1245 locations: vec![],
1246 raw_input: Some(input),
1247 raw_output: None,
1248 }
1249 }
1250
1251 fn update_tool_call_fields(
1252 &self,
1253 tool_use_id: &LanguageModelToolUseId,
1254 fields: acp::ToolCallUpdateFields,
1255 ) {
1256 self.0
1257 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1258 acp::ToolCallUpdate {
1259 id: acp::ToolCallId(tool_use_id.to_string().into()),
1260 fields,
1261 }
1262 .into(),
1263 )))
1264 .ok();
1265 }
1266
1267 fn send_stop(&self, reason: StopReason) {
1268 match reason {
1269 StopReason::EndTurn => {
1270 self.0
1271 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
1272 .ok();
1273 }
1274 StopReason::MaxTokens => {
1275 self.0
1276 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
1277 .ok();
1278 }
1279 StopReason::Refusal => {
1280 self.0
1281 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
1282 .ok();
1283 }
1284 StopReason::ToolUse => {}
1285 }
1286 }
1287
1288 fn send_error(&self, error: impl Into<anyhow::Error>) {
1289 self.0.unbounded_send(Err(error.into())).ok();
1290 }
1291}
1292
/// Handle given to a running tool for reporting progress on its tool call
/// (field updates, diffs, terminals) and for requesting user authorization.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind, echoed back in authorization requests.
    kind: acp::ToolKind,
    /// The tool call's JSON input, echoed back in authorization requests.
    input: serde_json::Value,
    /// Underlying agent response event channel.
    stream: AgentResponseEventStream,
    /// Filesystem used to persist the "always allow" setting; `None` skips persisting.
    fs: Option<Arc<dyn Fs>>,
}
1301
1302impl ToolCallEventStream {
1303 #[cfg(test)]
1304 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1305 let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
1306
1307 let stream = ToolCallEventStream::new(
1308 &LanguageModelToolUse {
1309 id: "test_id".into(),
1310 name: "test_tool".into(),
1311 raw_input: String::new(),
1312 input: serde_json::Value::Null,
1313 is_input_complete: true,
1314 },
1315 acp::ToolKind::Other,
1316 AgentResponseEventStream(events_tx),
1317 None,
1318 );
1319
1320 (stream, ToolCallEventStreamReceiver(events_rx))
1321 }
1322
1323 fn new(
1324 tool_use: &LanguageModelToolUse,
1325 kind: acp::ToolKind,
1326 stream: AgentResponseEventStream,
1327 fs: Option<Arc<dyn Fs>>,
1328 ) -> Self {
1329 Self {
1330 tool_use_id: tool_use.id.clone(),
1331 kind,
1332 input: tool_use.input.clone(),
1333 stream,
1334 fs,
1335 }
1336 }
1337
1338 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1339 self.stream
1340 .update_tool_call_fields(&self.tool_use_id, fields);
1341 }
1342
1343 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1344 self.stream
1345 .0
1346 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1347 acp_thread::ToolCallUpdateDiff {
1348 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1349 diff,
1350 }
1351 .into(),
1352 )))
1353 .ok();
1354 }
1355
1356 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1357 self.stream
1358 .0
1359 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1360 acp_thread::ToolCallUpdateTerminal {
1361 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1362 terminal,
1363 }
1364 .into(),
1365 )))
1366 .ok();
1367 }
1368
1369 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1370 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1371 return Task::ready(Ok(()));
1372 }
1373
1374 let (response_tx, response_rx) = oneshot::channel();
1375 self.stream
1376 .0
1377 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1378 ToolCallAuthorization {
1379 tool_call: AgentResponseEventStream::initial_tool_call(
1380 &self.tool_use_id,
1381 title.into(),
1382 self.kind.clone(),
1383 self.input.clone(),
1384 ),
1385 options: vec![
1386 acp::PermissionOption {
1387 id: acp::PermissionOptionId("always_allow".into()),
1388 name: "Always Allow".into(),
1389 kind: acp::PermissionOptionKind::AllowAlways,
1390 },
1391 acp::PermissionOption {
1392 id: acp::PermissionOptionId("allow".into()),
1393 name: "Allow".into(),
1394 kind: acp::PermissionOptionKind::AllowOnce,
1395 },
1396 acp::PermissionOption {
1397 id: acp::PermissionOptionId("deny".into()),
1398 name: "Deny".into(),
1399 kind: acp::PermissionOptionKind::RejectOnce,
1400 },
1401 ],
1402 response: response_tx,
1403 },
1404 )))
1405 .ok();
1406 let fs = self.fs.clone();
1407 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1408 "always_allow" => {
1409 if let Some(fs) = fs.clone() {
1410 cx.update(|cx| {
1411 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1412 settings.set_always_allow_tool_actions(true);
1413 });
1414 })?;
1415 }
1416
1417 Ok(())
1418 }
1419 "allow" => Ok(()),
1420 _ => Err(anyhow!("Permission to run tool denied by user")),
1421 })
1422 }
1423}
1424
/// Test-only receiving half paired with `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
1427
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as an authorization request.
    ///
    /// Panics if the next event is anything else or the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(authorization))) => authorization,
            unexpected => panic!("Expected ToolCallAuthorization but got: {:?}", unexpected),
        }
    }

    /// Awaits the next event and returns the terminal it attaches.
    ///
    /// Panics if the next event is not a terminal tool call update.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(terminal_update),
            ))) => terminal_update.terminal,
            unexpected => panic!("Expected terminal but got: {:?}", unexpected),
        }
    }
}
1451
/// Lets tests use the underlying receiver directly (e.g. its stream methods).
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
1460
/// Mutable access to the underlying receiver, needed to poll it from tests.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
1467
1468impl From<&str> for UserMessageContent {
1469 fn from(text: &str) -> Self {
1470 Self::Text(text.into())
1471 }
1472}
1473
1474impl From<acp::ContentBlock> for UserMessageContent {
1475 fn from(value: acp::ContentBlock) -> Self {
1476 match value {
1477 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1478 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1479 acp::ContentBlock::Audio(_) => {
1480 // TODO
1481 Self::Text("[audio]".to_string())
1482 }
1483 acp::ContentBlock::ResourceLink(resource_link) => {
1484 match MentionUri::parse(&resource_link.uri) {
1485 Ok(uri) => Self::Mention {
1486 uri,
1487 content: String::new(),
1488 },
1489 Err(err) => {
1490 log::error!("Failed to parse mention link: {}", err);
1491 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1492 }
1493 }
1494 }
1495 acp::ContentBlock::Resource(resource) => match resource.resource {
1496 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1497 match MentionUri::parse(&resource.uri) {
1498 Ok(uri) => Self::Mention {
1499 uri,
1500 content: resource.text,
1501 },
1502 Err(err) => {
1503 log::error!("Failed to parse mention link: {}", err);
1504 Self::Text(
1505 MarkdownCodeBlock {
1506 tag: &resource.uri,
1507 text: &resource.text,
1508 }
1509 .to_string(),
1510 )
1511 }
1512 }
1513 }
1514 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1515 // TODO
1516 Self::Text("[blob]".to_string())
1517 }
1518 },
1519 }
1520 }
1521}
1522
1523fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1524 LanguageModelImage {
1525 source: image_content.data.into(),
1526 // TODO: make this optional?
1527 size: gpui::Size::new(0.into(), 0.into()),
1528 }
1529}