1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
18 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
19 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
20 LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
29use std::{fmt::Write, ops::Range};
30use util::{ResultExt, markdown::MarkdownCodeBlock};
31
/// A single entry in a thread's committed history.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A message authored by the user.
    User(UserMessage),
    /// A message produced by the agent: streamed model content plus tool results.
    Agent(AgentMessage),
    /// Marker pushed when the user resumes after a turn stopped because the
    /// tool use limit was reached.
    Resume,
}
38
39impl Message {
40 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
41 match self {
42 Message::Agent(agent_message) => Some(agent_message),
43 _ => None,
44 }
45 }
46
47 pub fn to_markdown(&self) -> String {
48 match self {
49 Message::User(message) => message.to_markdown(),
50 Message::Agent(message) => message.to_markdown(),
51 Message::Resume => "[resumed after tool use limit was reached]".into(),
52 }
53 }
54}
55
/// A message sent by the user.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserMessage {
    /// Identifier used to locate this message later (e.g. by `Thread::truncate`).
    pub id: UserMessageId,
    /// The ordered chunks that make up the message.
    pub content: Vec<UserMessageContent>,
}
61
/// One chunk of a user message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserMessageContent {
    /// Plain text typed by the user.
    Text(String),
    /// A mention of a file, symbol, selection, thread, rule, or URL, together
    /// with the text captured for it (may be empty).
    Mention { uri: MentionUri, content: String },
    /// An image attached to the message.
    Image(LanguageModelImage),
}
68
69impl UserMessage {
70 pub fn to_markdown(&self) -> String {
71 let mut markdown = String::from("## User\n\n");
72
73 for content in &self.content {
74 match content {
75 UserMessageContent::Text(text) => {
76 markdown.push_str(text);
77 markdown.push('\n');
78 }
79 UserMessageContent::Image(_) => {
80 markdown.push_str("<image />\n");
81 }
82 UserMessageContent::Mention { uri, content } => {
83 if !content.is_empty() {
84 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
85 } else {
86 let _ = write!(&mut markdown, "{}\n", uri.as_link());
87 }
88 }
89 }
90 }
91
92 markdown
93 }
94
95 fn to_request(&self) -> LanguageModelRequestMessage {
96 let mut message = LanguageModelRequestMessage {
97 role: Role::User,
98 content: Vec::with_capacity(self.content.len()),
99 cache: false,
100 };
101
102 const OPEN_CONTEXT: &str = "<context>\n\
103 The following items were attached by the user. \
104 They are up-to-date and don't need to be re-read.\n\n";
105
106 const OPEN_FILES_TAG: &str = "<files>";
107 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
108 const OPEN_THREADS_TAG: &str = "<threads>";
109 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
110 const OPEN_RULES_TAG: &str =
111 "<rules>\nThe user has specified the following rules that should be applied:\n";
112
113 let mut file_context = OPEN_FILES_TAG.to_string();
114 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
115 let mut thread_context = OPEN_THREADS_TAG.to_string();
116 let mut fetch_context = OPEN_FETCH_TAG.to_string();
117 let mut rules_context = OPEN_RULES_TAG.to_string();
118
119 for chunk in &self.content {
120 let chunk = match chunk {
121 UserMessageContent::Text(text) => {
122 language_model::MessageContent::Text(text.clone())
123 }
124 UserMessageContent::Image(value) => {
125 language_model::MessageContent::Image(value.clone())
126 }
127 UserMessageContent::Mention { uri, content } => {
128 match uri {
129 MentionUri::File { abs_path, .. } => {
130 write!(
131 &mut symbol_context,
132 "\n{}",
133 MarkdownCodeBlock {
134 tag: &codeblock_tag(&abs_path, None),
135 text: &content.to_string(),
136 }
137 )
138 .ok();
139 }
140 MentionUri::Symbol {
141 path, line_range, ..
142 }
143 | MentionUri::Selection {
144 path, line_range, ..
145 } => {
146 write!(
147 &mut rules_context,
148 "\n{}",
149 MarkdownCodeBlock {
150 tag: &codeblock_tag(&path, Some(line_range)),
151 text: &content
152 }
153 )
154 .ok();
155 }
156 MentionUri::Thread { .. } => {
157 write!(&mut thread_context, "\n{}\n", content).ok();
158 }
159 MentionUri::TextThread { .. } => {
160 write!(&mut thread_context, "\n{}\n", content).ok();
161 }
162 MentionUri::Rule { .. } => {
163 write!(
164 &mut rules_context,
165 "\n{}",
166 MarkdownCodeBlock {
167 tag: "",
168 text: &content
169 }
170 )
171 .ok();
172 }
173 MentionUri::Fetch { url } => {
174 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
175 }
176 }
177
178 language_model::MessageContent::Text(uri.as_link().to_string())
179 }
180 };
181
182 message.content.push(chunk);
183 }
184
185 let len_before_context = message.content.len();
186
187 if file_context.len() > OPEN_FILES_TAG.len() {
188 file_context.push_str("</files>\n");
189 message
190 .content
191 .push(language_model::MessageContent::Text(file_context));
192 }
193
194 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
195 symbol_context.push_str("</symbols>\n");
196 message
197 .content
198 .push(language_model::MessageContent::Text(symbol_context));
199 }
200
201 if thread_context.len() > OPEN_THREADS_TAG.len() {
202 thread_context.push_str("</threads>\n");
203 message
204 .content
205 .push(language_model::MessageContent::Text(thread_context));
206 }
207
208 if fetch_context.len() > OPEN_FETCH_TAG.len() {
209 fetch_context.push_str("</fetched_urls>\n");
210 message
211 .content
212 .push(language_model::MessageContent::Text(fetch_context));
213 }
214
215 if rules_context.len() > OPEN_RULES_TAG.len() {
216 rules_context.push_str("</user_rules>\n");
217 message
218 .content
219 .push(language_model::MessageContent::Text(rules_context));
220 }
221
222 if message.content.len() > len_before_context {
223 message.content.insert(
224 len_before_context,
225 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
226 );
227 message
228 .content
229 .push(language_model::MessageContent::Text("</context>".into()));
230 }
231
232 message
233 }
234}
235
/// Builds the info-string for a fenced code block describing `full_path`.
///
/// The result is `"<ext> <path>"` (the extension prefix is dropped when the
/// path has no UTF-8 extension), optionally followed by a one-based line
/// suffix: `:N` when the range is empty, `:N-M` otherwise.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let language = full_path
        .extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| format!("{ext} "))
        .unwrap_or_default();

    let location = match line_range {
        None => String::new(),
        Some(range) if range.start == range.end => format!(":{}", range.start + 1),
        Some(range) => format!(":{}-{}", range.start + 1, range.end + 1),
    };

    format!("{language}{}{location}", full_path.display())
}
255
256impl AgentMessage {
257 pub fn to_markdown(&self) -> String {
258 let mut markdown = String::from("## Assistant\n\n");
259
260 for content in &self.content {
261 match content {
262 AgentMessageContent::Text(text) => {
263 markdown.push_str(text);
264 markdown.push('\n');
265 }
266 AgentMessageContent::Thinking { text, .. } => {
267 markdown.push_str("<think>");
268 markdown.push_str(text);
269 markdown.push_str("</think>\n");
270 }
271 AgentMessageContent::RedactedThinking(_) => {
272 markdown.push_str("<redacted_thinking />\n")
273 }
274 AgentMessageContent::Image(_) => {
275 markdown.push_str("<image />\n");
276 }
277 AgentMessageContent::ToolUse(tool_use) => {
278 markdown.push_str(&format!(
279 "**Tool Use**: {} (ID: {})\n",
280 tool_use.name, tool_use.id
281 ));
282 markdown.push_str(&format!(
283 "{}\n",
284 MarkdownCodeBlock {
285 tag: "json",
286 text: &format!("{:#}", tool_use.input)
287 }
288 ));
289 }
290 }
291 }
292
293 for tool_result in self.tool_results.values() {
294 markdown.push_str(&format!(
295 "**Tool Result**: {} (ID: {})\n\n",
296 tool_result.tool_name, tool_result.tool_use_id
297 ));
298 if tool_result.is_error {
299 markdown.push_str("**ERROR:**\n");
300 }
301
302 match &tool_result.content {
303 LanguageModelToolResultContent::Text(text) => {
304 writeln!(markdown, "{text}\n").ok();
305 }
306 LanguageModelToolResultContent::Image(_) => {
307 writeln!(markdown, "<image />\n").ok();
308 }
309 }
310
311 if let Some(output) = tool_result.output.as_ref() {
312 writeln!(
313 markdown,
314 "**Debug Output**:\n\n```json\n{}\n```\n",
315 serde_json::to_string_pretty(output).unwrap()
316 )
317 .unwrap();
318 }
319 }
320
321 markdown
322 }
323
324 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
325 let mut assistant_message = LanguageModelRequestMessage {
326 role: Role::Assistant,
327 content: Vec::with_capacity(self.content.len()),
328 cache: false,
329 };
330 for chunk in &self.content {
331 let chunk = match chunk {
332 AgentMessageContent::Text(text) => {
333 language_model::MessageContent::Text(text.clone())
334 }
335 AgentMessageContent::Thinking { text, signature } => {
336 language_model::MessageContent::Thinking {
337 text: text.clone(),
338 signature: signature.clone(),
339 }
340 }
341 AgentMessageContent::RedactedThinking(value) => {
342 language_model::MessageContent::RedactedThinking(value.clone())
343 }
344 AgentMessageContent::ToolUse(value) => {
345 language_model::MessageContent::ToolUse(value.clone())
346 }
347 AgentMessageContent::Image(value) => {
348 language_model::MessageContent::Image(value.clone())
349 }
350 };
351 assistant_message.content.push(chunk);
352 }
353
354 let mut user_message = LanguageModelRequestMessage {
355 role: Role::User,
356 content: Vec::new(),
357 cache: false,
358 };
359
360 for tool_result in self.tool_results.values() {
361 user_message
362 .content
363 .push(language_model::MessageContent::ToolResult(
364 tool_result.clone(),
365 ));
366 }
367
368 let mut messages = Vec::new();
369 if !assistant_message.content.is_empty() {
370 messages.push(assistant_message);
371 }
372 if !user_message.content.is_empty() {
373 messages.push(user_message);
374 }
375 messages
376 }
377}
378
/// A message produced by the agent during a turn.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct AgentMessage {
    /// Ordered chunks streamed from the model (text, thinking, images, tool uses).
    pub content: Vec<AgentMessageContent>,
    /// Results for the tool uses in `content`, keyed by tool-use id; `IndexMap`
    /// preserves insertion order when they are rendered or sent back.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
384
/// One chunk of an agent message as streamed from the model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentMessageContent {
    /// Response text.
    Text(String),
    /// Reasoning text, with the optional signature reported alongside it.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning data the model returned in redacted form.
    RedactedThinking(String),
    /// An image produced in the response.
    Image(LanguageModelImage),
    /// A request from the model to run a tool.
    ToolUse(LanguageModelToolUse),
}
396
/// Events delivered on the channel returned by `Thread::send` / `Thread::resume`
/// while a turn runs.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of streamed response text.
    Text(String),
    /// A chunk of streamed thinking text.
    Thinking(String),
    /// A new tool call was started.
    ToolCall(acp::ToolCall),
    /// An update to an existing tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool call is asking the user to pick one of several permission options.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, with the given reason.
    Stop(acp::StopReason),
}
406
/// A request for the user to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting authorization.
    pub tool_call: acp::ToolCall,
    /// The permission choices to present.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which the id of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
413
/// An agent conversation: committed message history plus the state needed to
/// run model turns (model, tools, profile, project handles).
pub struct Thread {
    /// Committed history, oldest first.
    messages: Vec<Message>,
    /// Completion mode attached to every request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Agent message currently being streamed; moved into `messages` when flushed.
    pending_message: Option<AgentMessage>,
    /// Registered built-in tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when the last turn ended on the tool use limit; required by `resume`.
    tool_use_limit_reached: bool,
    /// Source of additional tools exposed by context servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Agent profile that decides which tools are enabled.
    profile_id: AgentProfileId,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Model used for completion requests.
    model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
432
433impl Thread {
434 pub fn new(
435 project: Entity<Project>,
436 project_context: Rc<RefCell<ProjectContext>>,
437 context_server_registry: Entity<ContextServerRegistry>,
438 action_log: Entity<ActionLog>,
439 templates: Arc<Templates>,
440 model: Arc<dyn LanguageModel>,
441 cx: &mut Context<Self>,
442 ) -> Self {
443 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
444 Self {
445 messages: Vec::new(),
446 completion_mode: CompletionMode::Normal,
447 running_turn: None,
448 pending_message: None,
449 tools: BTreeMap::default(),
450 tool_use_limit_reached: false,
451 context_server_registry,
452 profile_id,
453 project_context,
454 templates,
455 model,
456 project,
457 action_log,
458 }
459 }
460
461 pub fn project(&self) -> &Entity<Project> {
462 &self.project
463 }
464
465 pub fn action_log(&self) -> &Entity<ActionLog> {
466 &self.action_log
467 }
468
469 pub fn model(&self) -> &Arc<dyn LanguageModel> {
470 &self.model
471 }
472
473 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
474 self.model = model;
475 }
476
477 pub fn completion_mode(&self) -> CompletionMode {
478 self.completion_mode
479 }
480
481 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
482 self.completion_mode = mode;
483 }
484
485 #[cfg(any(test, feature = "test-support"))]
486 pub fn last_message(&self) -> Option<Message> {
487 if let Some(message) = self.pending_message.clone() {
488 Some(Message::Agent(message))
489 } else {
490 self.messages.last().cloned()
491 }
492 }
493
494 pub fn add_tool(&mut self, tool: impl AgentTool) {
495 self.tools.insert(tool.name(), tool.erase());
496 }
497
498 pub fn remove_tool(&mut self, name: &str) -> bool {
499 self.tools.remove(name).is_some()
500 }
501
502 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
503 self.profile_id = profile_id;
504 }
505
506 pub fn cancel(&mut self) {
507 // TODO: do we need to emit a stop::cancel for ACP?
508 self.running_turn.take();
509 self.flush_pending_message();
510 }
511
512 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
513 self.cancel();
514 let Some(position) = self.messages.iter().position(
515 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
516 ) else {
517 return Err(anyhow!("Message not found"));
518 };
519 self.messages.truncate(position);
520 Ok(())
521 }
522
523 pub fn resume(
524 &mut self,
525 cx: &mut Context<Self>,
526 ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
527 anyhow::ensure!(
528 self.tool_use_limit_reached,
529 "can only resume after tool use limit is reached"
530 );
531
532 self.messages.push(Message::Resume);
533 cx.notify();
534
535 log::info!("Total messages in thread: {}", self.messages.len());
536 Ok(self.run_turn(cx))
537 }
538
539 /// Sending a message results in the model streaming a response, which could include tool calls.
540 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
541 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
542 pub fn send<T>(
543 &mut self,
544 id: UserMessageId,
545 content: impl IntoIterator<Item = T>,
546 cx: &mut Context<Self>,
547 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
548 where
549 T: Into<UserMessageContent>,
550 {
551 log::info!("Thread::send called with model: {:?}", self.model.name());
552
553 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
554 log::debug!("Thread::send content: {:?}", content);
555
556 self.messages
557 .push(Message::User(UserMessage { id, content }));
558 cx.notify();
559
560 log::info!("Total messages in thread: {}", self.messages.len());
561 self.run_turn(cx)
562 }
563
564 fn run_turn(
565 &mut self,
566 cx: &mut Context<Self>,
567 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
568 let model = self.model.clone();
569 let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
570 let event_stream = AgentResponseEventStream(events_tx);
571 let message_ix = self.messages.len().saturating_sub(1);
572 self.tool_use_limit_reached = false;
573 self.running_turn = Some(cx.spawn(async move |this, cx| {
574 log::info!("Starting agent turn execution");
575 let turn_result: Result<()> = async {
576 let mut completion_intent = CompletionIntent::UserPrompt;
577 loop {
578 log::debug!(
579 "Building completion request with intent: {:?}",
580 completion_intent
581 );
582 let request = this.update(cx, |this, cx| {
583 this.build_completion_request(completion_intent, cx)
584 })?;
585
586 log::info!("Calling model.stream_completion");
587 let mut events = model.stream_completion(request, cx).await?;
588 log::debug!("Stream completion started successfully");
589
590 let mut tool_use_limit_reached = false;
591 let mut tool_uses = FuturesUnordered::new();
592 while let Some(event) = events.next().await {
593 match event? {
594 LanguageModelCompletionEvent::StatusUpdate(
595 CompletionRequestStatus::ToolUseLimitReached,
596 ) => {
597 tool_use_limit_reached = true;
598 }
599 LanguageModelCompletionEvent::Stop(reason) => {
600 event_stream.send_stop(reason);
601 if reason == StopReason::Refusal {
602 this.update(cx, |this, _cx| {
603 this.flush_pending_message();
604 this.messages.truncate(message_ix);
605 })?;
606 return Ok(());
607 }
608 }
609 event => {
610 log::trace!("Received completion event: {:?}", event);
611 this.update(cx, |this, cx| {
612 tool_uses.extend(this.handle_streamed_completion_event(
613 event,
614 &event_stream,
615 cx,
616 ));
617 })
618 .ok();
619 }
620 }
621 }
622
623 let used_tools = tool_uses.is_empty();
624 while let Some(tool_result) = tool_uses.next().await {
625 log::info!("Tool finished {:?}", tool_result);
626
627 event_stream.update_tool_call_fields(
628 &tool_result.tool_use_id,
629 acp::ToolCallUpdateFields {
630 status: Some(if tool_result.is_error {
631 acp::ToolCallStatus::Failed
632 } else {
633 acp::ToolCallStatus::Completed
634 }),
635 raw_output: tool_result.output.clone(),
636 ..Default::default()
637 },
638 );
639 this.update(cx, |this, _cx| {
640 this.pending_message()
641 .tool_results
642 .insert(tool_result.tool_use_id.clone(), tool_result);
643 })
644 .ok();
645 }
646
647 if tool_use_limit_reached {
648 log::info!("Tool use limit reached, completing turn");
649 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
650 return Err(language_model::ToolUseLimitReachedError.into());
651 } else if used_tools {
652 log::info!("No tool uses found, completing turn");
653 return Ok(());
654 } else {
655 this.update(cx, |this, _| this.flush_pending_message())?;
656 completion_intent = CompletionIntent::ToolResults;
657 }
658 }
659 }
660 .await;
661
662 this.update(cx, |this, _| this.flush_pending_message()).ok();
663 if let Err(error) = turn_result {
664 log::error!("Turn execution failed: {:?}", error);
665 event_stream.send_error(error);
666 } else {
667 log::info!("Turn execution completed successfully");
668 }
669 }));
670 events_rx
671 }
672
673 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
674 log::debug!("Building system message");
675 let prompt = SystemPromptTemplate {
676 project: &self.project_context.borrow(),
677 available_tools: self.tools.keys().cloned().collect(),
678 }
679 .render(&self.templates)
680 .context("failed to build system prompt")
681 .expect("Invalid template");
682 log::debug!("System message built");
683 LanguageModelRequestMessage {
684 role: Role::System,
685 content: vec![prompt.into()],
686 cache: true,
687 }
688 }
689
690 /// A helper method that's called on every streamed completion event.
691 /// Returns an optional tool result task, which the main agentic loop in
692 /// send will send back to the model when it resolves.
693 fn handle_streamed_completion_event(
694 &mut self,
695 event: LanguageModelCompletionEvent,
696 event_stream: &AgentResponseEventStream,
697 cx: &mut Context<Self>,
698 ) -> Option<Task<LanguageModelToolResult>> {
699 log::trace!("Handling streamed completion event: {:?}", event);
700 use LanguageModelCompletionEvent::*;
701
702 match event {
703 StartMessage { .. } => {
704 self.flush_pending_message();
705 self.pending_message = Some(AgentMessage::default());
706 }
707 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
708 Thinking { text, signature } => {
709 self.handle_thinking_event(text, signature, event_stream, cx)
710 }
711 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
712 ToolUse(tool_use) => {
713 return self.handle_tool_use_event(tool_use, event_stream, cx);
714 }
715 ToolUseJsonParseError {
716 id,
717 tool_name,
718 raw_input,
719 json_parse_error,
720 } => {
721 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
722 id,
723 tool_name,
724 raw_input,
725 json_parse_error,
726 )));
727 }
728 UsageUpdate(_) | StatusUpdate(_) => {}
729 Stop(_) => unreachable!(),
730 }
731
732 None
733 }
734
735 fn handle_text_event(
736 &mut self,
737 new_text: String,
738 event_stream: &AgentResponseEventStream,
739 cx: &mut Context<Self>,
740 ) {
741 event_stream.send_text(&new_text);
742
743 let last_message = self.pending_message();
744 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
745 text.push_str(&new_text);
746 } else {
747 last_message
748 .content
749 .push(AgentMessageContent::Text(new_text));
750 }
751
752 cx.notify();
753 }
754
755 fn handle_thinking_event(
756 &mut self,
757 new_text: String,
758 new_signature: Option<String>,
759 event_stream: &AgentResponseEventStream,
760 cx: &mut Context<Self>,
761 ) {
762 event_stream.send_thinking(&new_text);
763
764 let last_message = self.pending_message();
765 if let Some(AgentMessageContent::Thinking { text, signature }) =
766 last_message.content.last_mut()
767 {
768 text.push_str(&new_text);
769 *signature = new_signature.or(signature.take());
770 } else {
771 last_message.content.push(AgentMessageContent::Thinking {
772 text: new_text,
773 signature: new_signature,
774 });
775 }
776
777 cx.notify();
778 }
779
780 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
781 let last_message = self.pending_message();
782 last_message
783 .content
784 .push(AgentMessageContent::RedactedThinking(data));
785 cx.notify();
786 }
787
788 fn handle_tool_use_event(
789 &mut self,
790 tool_use: LanguageModelToolUse,
791 event_stream: &AgentResponseEventStream,
792 cx: &mut Context<Self>,
793 ) -> Option<Task<LanguageModelToolResult>> {
794 cx.notify();
795
796 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
797 let mut title = SharedString::from(&tool_use.name);
798 let mut kind = acp::ToolKind::Other;
799 if let Some(tool) = tool.as_ref() {
800 title = tool.initial_title(tool_use.input.clone());
801 kind = tool.kind();
802 }
803
804 // Ensure the last message ends in the current tool use
805 let last_message = self.pending_message();
806 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
807 if let AgentMessageContent::ToolUse(last_tool_use) = content {
808 if last_tool_use.id == tool_use.id {
809 *last_tool_use = tool_use.clone();
810 false
811 } else {
812 true
813 }
814 } else {
815 true
816 }
817 });
818
819 if push_new_tool_use {
820 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
821 last_message
822 .content
823 .push(AgentMessageContent::ToolUse(tool_use.clone()));
824 } else {
825 event_stream.update_tool_call_fields(
826 &tool_use.id,
827 acp::ToolCallUpdateFields {
828 title: Some(title.into()),
829 kind: Some(kind),
830 raw_input: Some(tool_use.input.clone()),
831 ..Default::default()
832 },
833 );
834 }
835
836 if !tool_use.is_input_complete {
837 return None;
838 }
839
840 let Some(tool) = tool else {
841 let content = format!("No tool named {} exists", tool_use.name);
842 return Some(Task::ready(LanguageModelToolResult {
843 content: LanguageModelToolResultContent::Text(Arc::from(content)),
844 tool_use_id: tool_use.id,
845 tool_name: tool_use.name,
846 is_error: true,
847 output: None,
848 }));
849 };
850
851 let fs = self.project.read(cx).fs().clone();
852 let tool_event_stream =
853 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
854 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
855 status: Some(acp::ToolCallStatus::InProgress),
856 ..Default::default()
857 });
858 let supports_images = self.model.supports_images();
859 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
860 log::info!("Running tool {}", tool_use.name);
861 Some(cx.foreground_executor().spawn(async move {
862 let tool_result = tool_result.await.and_then(|output| {
863 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
864 if !supports_images {
865 return Err(anyhow!(
866 "Attempted to read an image, but this model doesn't support it.",
867 ));
868 }
869 }
870 Ok(output)
871 });
872
873 match tool_result {
874 Ok(output) => LanguageModelToolResult {
875 tool_use_id: tool_use.id,
876 tool_name: tool_use.name,
877 is_error: false,
878 content: output.llm_output,
879 output: Some(output.raw_output),
880 },
881 Err(error) => LanguageModelToolResult {
882 tool_use_id: tool_use.id,
883 tool_name: tool_use.name,
884 is_error: true,
885 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
886 output: None,
887 },
888 }
889 }))
890 }
891
892 fn handle_tool_use_json_parse_error_event(
893 &mut self,
894 tool_use_id: LanguageModelToolUseId,
895 tool_name: Arc<str>,
896 raw_input: Arc<str>,
897 json_parse_error: String,
898 ) -> LanguageModelToolResult {
899 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
900 LanguageModelToolResult {
901 tool_use_id,
902 tool_name,
903 is_error: true,
904 content: LanguageModelToolResultContent::Text(tool_output.into()),
905 output: Some(serde_json::Value::String(raw_input.to_string())),
906 }
907 }
908
909 fn pending_message(&mut self) -> &mut AgentMessage {
910 self.pending_message.get_or_insert_default()
911 }
912
913 fn flush_pending_message(&mut self) {
914 let Some(mut message) = self.pending_message.take() else {
915 return;
916 };
917
918 for content in &message.content {
919 let AgentMessageContent::ToolUse(tool_use) = content else {
920 continue;
921 };
922
923 if !message.tool_results.contains_key(&tool_use.id) {
924 message.tool_results.insert(
925 tool_use.id.clone(),
926 LanguageModelToolResult {
927 tool_use_id: tool_use.id.clone(),
928 tool_name: tool_use.name.clone(),
929 is_error: true,
930 content: LanguageModelToolResultContent::Text(
931 "Tool canceled by user".into(),
932 ),
933 output: None,
934 },
935 );
936 }
937 }
938
939 self.messages.push(Message::Agent(message));
940 }
941
942 pub(crate) fn build_completion_request(
943 &self,
944 completion_intent: CompletionIntent,
945 cx: &mut App,
946 ) -> LanguageModelRequest {
947 log::debug!("Building completion request");
948 log::debug!("Completion intent: {:?}", completion_intent);
949 log::debug!("Completion mode: {:?}", self.completion_mode);
950
951 let messages = self.build_request_messages();
952 log::info!("Request will include {} messages", messages.len());
953
954 let tools = if let Some(tools) = self.tools(cx).log_err() {
955 tools
956 .filter_map(|tool| {
957 let tool_name = tool.name().to_string();
958 log::trace!("Including tool: {}", tool_name);
959 Some(LanguageModelRequestTool {
960 name: tool_name,
961 description: tool.description().to_string(),
962 input_schema: tool
963 .input_schema(self.model.tool_input_format())
964 .log_err()?,
965 })
966 })
967 .collect()
968 } else {
969 Vec::new()
970 };
971
972 log::info!("Request includes {} tools", tools.len());
973
974 let request = LanguageModelRequest {
975 thread_id: None,
976 prompt_id: None,
977 intent: Some(completion_intent),
978 mode: Some(self.completion_mode.into()),
979 messages,
980 tools,
981 tool_choice: None,
982 stop: Vec::new(),
983 temperature: None,
984 thinking_allowed: true,
985 };
986
987 log::debug!("Completion request built successfully");
988 request
989 }
990
991 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
992 let profile = AgentSettings::get_global(cx)
993 .profiles
994 .get(&self.profile_id)
995 .context("profile not found")?;
996 let provider_id = self.model.provider_id();
997
998 Ok(self
999 .tools
1000 .iter()
1001 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1002 .filter_map(|(tool_name, tool)| {
1003 if profile.is_tool_enabled(tool_name) {
1004 Some(tool)
1005 } else {
1006 None
1007 }
1008 })
1009 .chain(self.context_server_registry.read(cx).servers().flat_map(
1010 |(server_id, tools)| {
1011 tools.iter().filter_map(|(tool_name, tool)| {
1012 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1013 Some(tool)
1014 } else {
1015 None
1016 }
1017 })
1018 },
1019 )))
1020 }
1021
1022 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1023 log::trace!(
1024 "Building request messages from {} thread messages",
1025 self.messages.len()
1026 );
1027 let mut messages = vec![self.build_system_message()];
1028 for message in &self.messages {
1029 match message {
1030 Message::User(message) => messages.push(message.to_request()),
1031 Message::Agent(message) => messages.extend(message.to_request()),
1032 Message::Resume => messages.push(LanguageModelRequestMessage {
1033 role: Role::User,
1034 content: vec!["Continue where you left off".into()],
1035 cache: false,
1036 }),
1037 }
1038 }
1039
1040 if let Some(message) = self.pending_message.as_ref() {
1041 messages.extend(message.to_request());
1042 }
1043
1044 if let Some(last_user_message) = messages
1045 .iter_mut()
1046 .rev()
1047 .find(|message| message.role == Role::User)
1048 {
1049 last_user_message.cache = true;
1050 }
1051
1052 messages
1053 }
1054
1055 pub fn to_markdown(&self) -> String {
1056 let mut markdown = String::new();
1057 for (ix, message) in self.messages.iter().enumerate() {
1058 if ix > 0 {
1059 markdown.push('\n');
1060 }
1061 markdown.push_str(&message.to_markdown());
1062 }
1063
1064 if let Some(message) = self.pending_message.as_ref() {
1065 markdown.push('\n');
1066 markdown.push_str(&message.to_markdown());
1067 }
1068
1069 markdown
1070 }
1071}
1072
1073pub trait AgentTool
1074where
1075 Self: 'static + Sized,
1076{
1077 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1078 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1079
1080 fn name(&self) -> SharedString;
1081
1082 fn description(&self) -> SharedString {
1083 let schema = schemars::schema_for!(Self::Input);
1084 SharedString::new(
1085 schema
1086 .get("description")
1087 .and_then(|description| description.as_str())
1088 .unwrap_or_default(),
1089 )
1090 }
1091
1092 fn kind(&self) -> acp::ToolKind;
1093
1094 /// The initial tool title to display. Can be updated during the tool run.
1095 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1096
1097 /// Returns the JSON schema that describes the tool's input.
1098 fn input_schema(&self) -> Schema {
1099 schemars::schema_for!(Self::Input)
1100 }
1101
1102 /// Some tools rely on a provider for the underlying billing or other reasons.
1103 /// Allow the tool to check if they are compatible, or should be filtered out.
1104 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1105 true
1106 }
1107
1108 /// Runs the tool with the provided input.
1109 fn run(
1110 self: Arc<Self>,
1111 input: Self::Input,
1112 event_stream: ToolCallEventStream,
1113 cx: &mut App,
1114 ) -> Task<Result<Self::Output>>;
1115
1116 fn erase(self) -> Arc<dyn AnyAgentTool> {
1117 Arc::new(Erased(Arc::new(self)))
1118 }
1119}
1120
/// Adapter that implements [`AnyAgentTool`] for a typed [`AgentTool`] by
/// wrapping it (as `Erased<Arc<T>>`); produced by [`AgentTool::erase`].
pub struct Erased<T>(T);
1122
/// Result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content handed back to the language model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON.
    pub raw_output: serde_json::Value,
}
1127
/// Type-erased, object-safe counterpart of [`AgentTool`] that takes and
/// produces raw JSON values instead of typed input/output.
pub trait AnyAgentTool {
    /// The tool's name.
    fn name(&self) -> SharedString;
    /// Human-readable description of the tool.
    fn description(&self) -> SharedString;
    /// The ACP kind reported for calls to this tool.
    fn kind(&self) -> acp::ToolKind;
    /// Title to display for a call, computed from its raw JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input, expressed in the requested `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered for the given provider; defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on raw JSON input, yielding both the model-facing content
    /// and the raw JSON output.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
1144
1145impl<T> AnyAgentTool for Erased<Arc<T>>
1146where
1147 T: AgentTool,
1148{
1149 fn name(&self) -> SharedString {
1150 self.0.name()
1151 }
1152
1153 fn description(&self) -> SharedString {
1154 self.0.description()
1155 }
1156
1157 fn kind(&self) -> agent_client_protocol::ToolKind {
1158 self.0.kind()
1159 }
1160
1161 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1162 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1163 self.0.initial_title(parsed_input)
1164 }
1165
1166 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1167 let mut json = serde_json::to_value(self.0.input_schema())?;
1168 adapt_schema_to_format(&mut json, format)?;
1169 Ok(json)
1170 }
1171
1172 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1173 self.0.supported_provider(provider)
1174 }
1175
1176 fn run(
1177 self: Arc<Self>,
1178 input: serde_json::Value,
1179 event_stream: ToolCallEventStream,
1180 cx: &mut App,
1181 ) -> Task<Result<AgentToolOutput>> {
1182 cx.spawn(async move |cx| {
1183 let input = serde_json::from_value(input)?;
1184 let output = cx
1185 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1186 .await?;
1187 let raw_output = serde_json::to_value(&output)?;
1188 Ok(AgentToolOutput {
1189 llm_output: output.into(),
1190 raw_output,
1191 })
1192 })
1193 }
1194}
1195
/// Sending half of the channel that carries agent response events (or errors)
/// from the running thread to its consumer.
#[derive(Clone)]
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
1198
1199impl AgentResponseEventStream {
1200 fn send_text(&self, text: &str) {
1201 self.0
1202 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
1203 .ok();
1204 }
1205
1206 fn send_thinking(&self, text: &str) {
1207 self.0
1208 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
1209 .ok();
1210 }
1211
1212 fn send_tool_call(
1213 &self,
1214 id: &LanguageModelToolUseId,
1215 title: SharedString,
1216 kind: acp::ToolKind,
1217 input: serde_json::Value,
1218 ) {
1219 self.0
1220 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
1221 id,
1222 title.to_string(),
1223 kind,
1224 input,
1225 ))))
1226 .ok();
1227 }
1228
1229 fn initial_tool_call(
1230 id: &LanguageModelToolUseId,
1231 title: String,
1232 kind: acp::ToolKind,
1233 input: serde_json::Value,
1234 ) -> acp::ToolCall {
1235 acp::ToolCall {
1236 id: acp::ToolCallId(id.to_string().into()),
1237 title,
1238 kind,
1239 status: acp::ToolCallStatus::Pending,
1240 content: vec![],
1241 locations: vec![],
1242 raw_input: Some(input),
1243 raw_output: None,
1244 }
1245 }
1246
1247 fn update_tool_call_fields(
1248 &self,
1249 tool_use_id: &LanguageModelToolUseId,
1250 fields: acp::ToolCallUpdateFields,
1251 ) {
1252 self.0
1253 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1254 acp::ToolCallUpdate {
1255 id: acp::ToolCallId(tool_use_id.to_string().into()),
1256 fields,
1257 }
1258 .into(),
1259 )))
1260 .ok();
1261 }
1262
1263 fn send_stop(&self, reason: StopReason) {
1264 match reason {
1265 StopReason::EndTurn => {
1266 self.0
1267 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
1268 .ok();
1269 }
1270 StopReason::MaxTokens => {
1271 self.0
1272 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
1273 .ok();
1274 }
1275 StopReason::Refusal => {
1276 self.0
1277 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
1278 .ok();
1279 }
1280 StopReason::ToolUse => {}
1281 }
1282 }
1283
1284 fn send_error(&self, error: impl Into<anyhow::Error>) {
1285 self.0.unbounded_send(Err(error.into())).ok();
1286 }
1287}
1288
/// Per-tool-call handle that tools use to report progress and request
/// authorization, scoped to a single tool use.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use; converted into the ACP tool call id for every event.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind reported when requesting authorization.
    kind: acp::ToolKind,
    /// Raw JSON input reported when requesting authorization.
    input: serde_json::Value,
    /// Channel the events are sent on.
    stream: AgentResponseEventStream,
    /// Used to persist the "always allow" choice to settings; when `None`, the
    /// choice is not persisted.
    fs: Option<Arc<dyn Fs>>,
}
1297
1298impl ToolCallEventStream {
1299 #[cfg(test)]
1300 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1301 let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
1302
1303 let stream = ToolCallEventStream::new(
1304 &LanguageModelToolUse {
1305 id: "test_id".into(),
1306 name: "test_tool".into(),
1307 raw_input: String::new(),
1308 input: serde_json::Value::Null,
1309 is_input_complete: true,
1310 },
1311 acp::ToolKind::Other,
1312 AgentResponseEventStream(events_tx),
1313 None,
1314 );
1315
1316 (stream, ToolCallEventStreamReceiver(events_rx))
1317 }
1318
1319 fn new(
1320 tool_use: &LanguageModelToolUse,
1321 kind: acp::ToolKind,
1322 stream: AgentResponseEventStream,
1323 fs: Option<Arc<dyn Fs>>,
1324 ) -> Self {
1325 Self {
1326 tool_use_id: tool_use.id.clone(),
1327 kind,
1328 input: tool_use.input.clone(),
1329 stream,
1330 fs,
1331 }
1332 }
1333
1334 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1335 self.stream
1336 .update_tool_call_fields(&self.tool_use_id, fields);
1337 }
1338
1339 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1340 self.stream
1341 .0
1342 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1343 acp_thread::ToolCallUpdateDiff {
1344 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1345 diff,
1346 }
1347 .into(),
1348 )))
1349 .ok();
1350 }
1351
1352 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1353 self.stream
1354 .0
1355 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1356 acp_thread::ToolCallUpdateTerminal {
1357 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1358 terminal,
1359 }
1360 .into(),
1361 )))
1362 .ok();
1363 }
1364
1365 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1366 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1367 return Task::ready(Ok(()));
1368 }
1369
1370 let (response_tx, response_rx) = oneshot::channel();
1371 self.stream
1372 .0
1373 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1374 ToolCallAuthorization {
1375 tool_call: AgentResponseEventStream::initial_tool_call(
1376 &self.tool_use_id,
1377 title.into(),
1378 self.kind.clone(),
1379 self.input.clone(),
1380 ),
1381 options: vec![
1382 acp::PermissionOption {
1383 id: acp::PermissionOptionId("always_allow".into()),
1384 name: "Always Allow".into(),
1385 kind: acp::PermissionOptionKind::AllowAlways,
1386 },
1387 acp::PermissionOption {
1388 id: acp::PermissionOptionId("allow".into()),
1389 name: "Allow".into(),
1390 kind: acp::PermissionOptionKind::AllowOnce,
1391 },
1392 acp::PermissionOption {
1393 id: acp::PermissionOptionId("deny".into()),
1394 name: "Deny".into(),
1395 kind: acp::PermissionOptionKind::RejectOnce,
1396 },
1397 ],
1398 response: response_tx,
1399 },
1400 )))
1401 .ok();
1402 let fs = self.fs.clone();
1403 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1404 "always_allow" => {
1405 if let Some(fs) = fs.clone() {
1406 cx.update(|cx| {
1407 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1408 settings.set_always_allow_tool_actions(true);
1409 });
1410 })?;
1411 }
1412
1413 Ok(())
1414 }
1415 "allow" => Ok(()),
1416 _ => Err(anyhow!("Permission to run tool denied by user")),
1417 })
1418 }
1419}
1420
/// Test-only receiving half of a [`ToolCallEventStream`] created by
/// `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
1423
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it when it is an authorization
    /// request. Panics on any other event, an error, or a closed channel.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(authorization))) => authorization,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal when it is a terminal
    /// update. Panics on any other event, an error, or a closed channel.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(terminal_update),
            ))) => terminal_update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1447
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;

    /// Exposes the wrapped receiver so tests can call channel methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1456
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutable access to the wrapped receiver, e.g. for polling it in tests.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1463
1464impl From<&str> for UserMessageContent {
1465 fn from(text: &str) -> Self {
1466 Self::Text(text.into())
1467 }
1468}
1469
1470impl From<acp::ContentBlock> for UserMessageContent {
1471 fn from(value: acp::ContentBlock) -> Self {
1472 match value {
1473 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1474 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1475 acp::ContentBlock::Audio(_) => {
1476 // TODO
1477 Self::Text("[audio]".to_string())
1478 }
1479 acp::ContentBlock::ResourceLink(resource_link) => {
1480 match MentionUri::parse(&resource_link.uri) {
1481 Ok(uri) => Self::Mention {
1482 uri,
1483 content: String::new(),
1484 },
1485 Err(err) => {
1486 log::error!("Failed to parse mention link: {}", err);
1487 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1488 }
1489 }
1490 }
1491 acp::ContentBlock::Resource(resource) => match resource.resource {
1492 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1493 match MentionUri::parse(&resource.uri) {
1494 Ok(uri) => Self::Mention {
1495 uri,
1496 content: resource.text,
1497 },
1498 Err(err) => {
1499 log::error!("Failed to parse mention link: {}", err);
1500 Self::Text(
1501 MarkdownCodeBlock {
1502 tag: &resource.uri,
1503 text: &resource.text,
1504 }
1505 .to_string(),
1506 )
1507 }
1508 }
1509 }
1510 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1511 // TODO
1512 Self::Text("[blob]".to_string())
1513 }
1514 },
1515 }
1516 }
1517}
1518
1519fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1520 LanguageModelImage {
1521 source: image_content.data.into(),
1522 // TODO: make this optional?
1523 size: gpui::Size::new(0.into(), 0.into()),
1524 }
1525}