1use crate::{
2 ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates,
3};
4use acp_thread::{MentionUri, UserMessageId};
5use action_log::ActionLog;
6use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
7use agent_client_protocol as acp;
8use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
9use anyhow::{Context as _, Result, anyhow};
10use assistant_tool::adapt_schema_to_format;
11use chrono::{DateTime, Utc};
12use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
13use collections::IndexMap;
14use fs::Fs;
15use futures::{
16 FutureExt,
17 channel::{mpsc, oneshot},
18 future::Shared,
19 stream::FuturesUnordered,
20};
21use git::repository::DiffType;
22use gpui::{App, AppContext, Context, Entity, SharedString, Task};
23use language_model::{
24 LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
25 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
26 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
27 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, TokenUsage,
28};
29use project::{
30 Project,
31 git_store::{GitStore, RepositoryState},
32};
33use prompt_store::ProjectContext;
34use schemars::{JsonSchema, Schema};
35use serde::{Deserialize, Serialize};
36use settings::{Settings, update_settings_file};
37use smol::stream::StreamExt;
38use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
39use std::{fmt::Write, ops::Range};
40use util::{ResultExt, markdown::MarkdownCodeBlock};
41use uuid::Uuid;
42
/// Text used to report that a tool call was canceled by the user.
// NOTE(review): no use site is visible in this part of the file — confirm where it is surfaced.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
44
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// The ID is stored as a reference-counted string slice, so cloning it does not copy the text.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
50
51impl PromptId {
52 pub fn new() -> Self {
53 Self(Uuid::new_v4().to_string().into())
54 }
55}
56
57impl std::fmt::Display for PromptId {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 write!(f, "{}", self.0)
60 }
61}
62
/// A single entry in a thread's message history.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// A message produced by the agent, together with the results of the tools it called.
    Agent(AgentMessage),
    /// Marker recorded when the thread is resumed; it is sent to the model as a
    /// user-role "Continue where you left off" message.
    Resume,
}
69
70impl Message {
71 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
72 match self {
73 Message::Agent(agent_message) => Some(agent_message),
74 _ => None,
75 }
76 }
77
78 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
79 match self {
80 Message::User(message) => vec![message.to_request()],
81 Message::Agent(message) => message.to_request(),
82 Message::Resume => vec![LanguageModelRequestMessage {
83 role: Role::User,
84 content: vec!["Continue where you left off".into()],
85 cache: false,
86 }],
87 }
88 }
89
90 pub fn to_markdown(&self) -> String {
91 match self {
92 Message::User(message) => message.to_markdown(),
93 Message::Agent(message) => message.to_markdown(),
94 Message::Resume => "[resumed after tool use limit was reached]".into(),
95 }
96 }
97
98 pub fn role(&self) -> Role {
99 match self {
100 Message::User(_) | Message::Resume => Role::User,
101 Message::Agent(_) => Role::Assistant,
102 }
103 }
104}
105
/// A message submitted by the user, made up of an ordered list of content chunks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifier of this message; `Thread::truncate` looks messages up by it.
    pub id: UserMessageId,
    /// The chunks (text, mentions, images) in the order they were provided.
    pub content: Vec<UserMessageContent>,
}
111
/// One chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text.
    Text(String),
    /// A reference to some resource (`uri`) along with the content attached for it.
    /// `content` may be empty, in which case only the link is rendered in markdown.
    Mention { uri: MentionUri, content: String },
    /// An image.
    Image(LanguageModelImage),
}
118
119impl UserMessage {
120 pub fn to_markdown(&self) -> String {
121 let mut markdown = String::from("## User\n\n");
122
123 for content in &self.content {
124 match content {
125 UserMessageContent::Text(text) => {
126 markdown.push_str(text);
127 markdown.push('\n');
128 }
129 UserMessageContent::Image(_) => {
130 markdown.push_str("<image />\n");
131 }
132 UserMessageContent::Mention { uri, content } => {
133 if !content.is_empty() {
134 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
135 } else {
136 let _ = write!(&mut markdown, "{}\n", uri.as_link());
137 }
138 }
139 }
140 }
141
142 markdown
143 }
144
145 fn to_request(&self) -> LanguageModelRequestMessage {
146 let mut message = LanguageModelRequestMessage {
147 role: Role::User,
148 content: Vec::with_capacity(self.content.len()),
149 cache: false,
150 };
151
152 const OPEN_CONTEXT: &str = "<context>\n\
153 The following items were attached by the user. \
154 They are up-to-date and don't need to be re-read.\n\n";
155
156 const OPEN_FILES_TAG: &str = "<files>";
157 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
158 const OPEN_THREADS_TAG: &str = "<threads>";
159 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
160 const OPEN_RULES_TAG: &str =
161 "<rules>\nThe user has specified the following rules that should be applied:\n";
162
163 let mut file_context = OPEN_FILES_TAG.to_string();
164 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
165 let mut thread_context = OPEN_THREADS_TAG.to_string();
166 let mut fetch_context = OPEN_FETCH_TAG.to_string();
167 let mut rules_context = OPEN_RULES_TAG.to_string();
168
169 for chunk in &self.content {
170 let chunk = match chunk {
171 UserMessageContent::Text(text) => {
172 language_model::MessageContent::Text(text.clone())
173 }
174 UserMessageContent::Image(value) => {
175 language_model::MessageContent::Image(value.clone())
176 }
177 UserMessageContent::Mention { uri, content } => {
178 match uri {
179 MentionUri::File { abs_path, .. } => {
180 write!(
181 &mut symbol_context,
182 "\n{}",
183 MarkdownCodeBlock {
184 tag: &codeblock_tag(&abs_path, None),
185 text: &content.to_string(),
186 }
187 )
188 .ok();
189 }
190 MentionUri::Symbol {
191 path, line_range, ..
192 }
193 | MentionUri::Selection {
194 path, line_range, ..
195 } => {
196 write!(
197 &mut rules_context,
198 "\n{}",
199 MarkdownCodeBlock {
200 tag: &codeblock_tag(&path, Some(line_range)),
201 text: &content
202 }
203 )
204 .ok();
205 }
206 MentionUri::Thread { .. } => {
207 write!(&mut thread_context, "\n{}\n", content).ok();
208 }
209 MentionUri::TextThread { .. } => {
210 write!(&mut thread_context, "\n{}\n", content).ok();
211 }
212 MentionUri::Rule { .. } => {
213 write!(
214 &mut rules_context,
215 "\n{}",
216 MarkdownCodeBlock {
217 tag: "",
218 text: &content
219 }
220 )
221 .ok();
222 }
223 MentionUri::Fetch { url } => {
224 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
225 }
226 }
227
228 language_model::MessageContent::Text(uri.as_link().to_string())
229 }
230 };
231
232 message.content.push(chunk);
233 }
234
235 let len_before_context = message.content.len();
236
237 if file_context.len() > OPEN_FILES_TAG.len() {
238 file_context.push_str("</files>\n");
239 message
240 .content
241 .push(language_model::MessageContent::Text(file_context));
242 }
243
244 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
245 symbol_context.push_str("</symbols>\n");
246 message
247 .content
248 .push(language_model::MessageContent::Text(symbol_context));
249 }
250
251 if thread_context.len() > OPEN_THREADS_TAG.len() {
252 thread_context.push_str("</threads>\n");
253 message
254 .content
255 .push(language_model::MessageContent::Text(thread_context));
256 }
257
258 if fetch_context.len() > OPEN_FETCH_TAG.len() {
259 fetch_context.push_str("</fetched_urls>\n");
260 message
261 .content
262 .push(language_model::MessageContent::Text(fetch_context));
263 }
264
265 if rules_context.len() > OPEN_RULES_TAG.len() {
266 rules_context.push_str("</user_rules>\n");
267 message
268 .content
269 .push(language_model::MessageContent::Text(rules_context));
270 }
271
272 if message.content.len() > len_before_context {
273 message.content.insert(
274 len_before_context,
275 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
276 );
277 message
278 .content
279 .push(language_model::MessageContent::Text("</context>".into()));
280 }
281
282 message
283 }
284}
285
/// Builds the info string for a fenced code block that shows (part of) a file.
///
/// The tag is the file extension followed by a space (when the path has a UTF-8
/// extension), then the displayed path, then an optional line suffix: `:N` when
/// the range's start equals its end, `:N-M` otherwise. The range bounds are
/// shifted by one before being written.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let mut tag = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{extension} "),
        None => String::new(),
    };

    tag.push_str(&full_path.display().to_string());

    match line_range {
        Some(range) if range.start == range.end => {
            write!(tag, ":{}", range.start + 1).ok();
        }
        Some(range) => {
            write!(tag, ":{}-{}", range.start + 1, range.end + 1).ok();
        }
        None => {}
    }

    tag
}
305
306impl AgentMessage {
307 pub fn to_markdown(&self) -> String {
308 let mut markdown = String::from("## Assistant\n\n");
309
310 for content in &self.content {
311 match content {
312 AgentMessageContent::Text(text) => {
313 markdown.push_str(text);
314 markdown.push('\n');
315 }
316 AgentMessageContent::Thinking { text, .. } => {
317 markdown.push_str("<think>");
318 markdown.push_str(text);
319 markdown.push_str("</think>\n");
320 }
321 AgentMessageContent::RedactedThinking(_) => {
322 markdown.push_str("<redacted_thinking />\n")
323 }
324 AgentMessageContent::ToolUse(tool_use) => {
325 markdown.push_str(&format!(
326 "**Tool Use**: {} (ID: {})\n",
327 tool_use.name, tool_use.id
328 ));
329 markdown.push_str(&format!(
330 "{}\n",
331 MarkdownCodeBlock {
332 tag: "json",
333 text: &format!("{:#}", tool_use.input)
334 }
335 ));
336 }
337 }
338 }
339
340 for tool_result in self.tool_results.values() {
341 markdown.push_str(&format!(
342 "**Tool Result**: {} (ID: {})\n\n",
343 tool_result.tool_name, tool_result.tool_use_id
344 ));
345 if tool_result.is_error {
346 markdown.push_str("**ERROR:**\n");
347 }
348
349 match &tool_result.content {
350 LanguageModelToolResultContent::Text(text) => {
351 writeln!(markdown, "{text}\n").ok();
352 }
353 LanguageModelToolResultContent::Image(_) => {
354 writeln!(markdown, "<image />\n").ok();
355 }
356 }
357
358 if let Some(output) = tool_result.output.as_ref() {
359 writeln!(
360 markdown,
361 "**Debug Output**:\n\n```json\n{}\n```\n",
362 serde_json::to_string_pretty(output).unwrap()
363 )
364 .unwrap();
365 }
366 }
367
368 markdown
369 }
370
371 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
372 let mut assistant_message = LanguageModelRequestMessage {
373 role: Role::Assistant,
374 content: Vec::with_capacity(self.content.len()),
375 cache: false,
376 };
377 for chunk in &self.content {
378 let chunk = match chunk {
379 AgentMessageContent::Text(text) => {
380 language_model::MessageContent::Text(text.clone())
381 }
382 AgentMessageContent::Thinking { text, signature } => {
383 language_model::MessageContent::Thinking {
384 text: text.clone(),
385 signature: signature.clone(),
386 }
387 }
388 AgentMessageContent::RedactedThinking(value) => {
389 language_model::MessageContent::RedactedThinking(value.clone())
390 }
391 AgentMessageContent::ToolUse(value) => {
392 language_model::MessageContent::ToolUse(value.clone())
393 }
394 };
395 assistant_message.content.push(chunk);
396 }
397
398 let mut user_message = LanguageModelRequestMessage {
399 role: Role::User,
400 content: Vec::new(),
401 cache: false,
402 };
403
404 for tool_result in self.tool_results.values() {
405 user_message
406 .content
407 .push(language_model::MessageContent::ToolResult(
408 tool_result.clone(),
409 ));
410 }
411
412 let mut messages = Vec::new();
413 if !assistant_message.content.is_empty() {
414 messages.push(assistant_message);
415 }
416 if !user_message.content.is_empty() {
417 messages.push(user_message);
418 }
419 messages
420 }
421}
422
/// A message produced by the agent, plus the results of the tools it invoked.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// The chunks emitted by the model, in order.
    pub content: Vec<AgentMessageContent>,
    /// Tool results keyed by the ID of the tool use they answer.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
428
/// One chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Plain response text.
    Text(String),
    /// Thinking text with an optional signature; both are passed back to the
    /// model unchanged when the message is converted into a request.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking; rendered in markdown only as a placeholder.
    RedactedThinking(String),
    /// A request by the model to run a tool.
    ToolUse(LanguageModelToolUse),
}
439
/// Events reported over the channel returned by `Thread::send`, `Thread::resume`
/// and `Thread::replay`.
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message.
    UserMessage(UserMessage),
    /// Text produced by the agent.
    AgentText(String),
    /// Thinking text produced by the agent.
    AgentThinking(String),
    /// A tool call was started (or, during replay, could not be resolved).
    ToolCall(acp::ToolCall),
    /// Fields of a previously reported tool call changed.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool call is asking for permission; see [`ToolCallAuthorization`].
    ToolCallAuthorization(ToolCallAuthorization),
    /// The thread title changed.
    TitleUpdate(SharedString),
    /// The model stopped with the given reason.
    Stop(acp::StopReason),
}
451
/// A permission request for a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call the request is about.
    pub tool_call: acp::ToolCallUpdate,
    /// The options that can be chosen from.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which the ID of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
458
/// State of a thread's title.
enum ThreadTitle {
    /// No title has been generated or requested yet.
    None,
    /// A title is being generated; the task is held here while it runs.
    Pending(Task<()>),
    /// Title generation finished, either with a title or with an error.
    Done(Result<SharedString>),
}
464
465impl ThreadTitle {
466 pub fn unwrap_or_default(&self) -> SharedString {
467 if let ThreadTitle::Done(Ok(title)) = self {
468 title.clone()
469 } else {
470 "New Thread".into()
471 }
472 }
473}
474
475pub struct Thread {
476 id: acp::SessionId,
477 prompt_id: PromptId,
478 updated_at: DateTime<Utc>,
479 title: ThreadTitle,
480 summary: DetailedSummaryState,
481 messages: Vec<Message>,
482 completion_mode: CompletionMode,
483 /// Holds the task that handles agent interaction until the end of the turn.
484 /// Survives across multiple requests as the model performs tool calls and
485 /// we run tools, report their results.
486 running_turn: Option<RunningTurn>,
487 pending_message: Option<AgentMessage>,
488 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
489 tool_use_limit_reached: bool,
490 request_token_usage: Vec<TokenUsage>,
491 cumulative_token_usage: TokenUsage,
492 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
493 context_server_registry: Entity<ContextServerRegistry>,
494 profile_id: AgentProfileId,
495 project_context: Rc<RefCell<ProjectContext>>,
496 templates: Arc<Templates>,
497 model: Option<Arc<dyn LanguageModel>>,
498 project: Entity<Project>,
499 action_log: Entity<ActionLog>,
500}
501
502impl Thread {
503 pub fn new(
504 id: acp::SessionId,
505 project: Entity<Project>,
506 project_context: Rc<RefCell<ProjectContext>>,
507 context_server_registry: Entity<ContextServerRegistry>,
508 action_log: Entity<ActionLog>,
509 templates: Arc<Templates>,
510 model: Option<Arc<dyn LanguageModel>>,
511 cx: &mut Context<Self>,
512 ) -> Self {
513 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
514 Self {
515 id,
516 prompt_id: PromptId::new(),
517 updated_at: Utc::now(),
518 title: ThreadTitle::None,
519 summary: DetailedSummaryState::default(),
520 messages: Vec::new(),
521 completion_mode: CompletionMode::Normal,
522 running_turn: None,
523 pending_message: None,
524 tools: BTreeMap::default(),
525 tool_use_limit_reached: false,
526 request_token_usage: Vec::new(),
527 cumulative_token_usage: TokenUsage::default(),
528 initial_project_snapshot: {
529 let project_snapshot = Self::project_snapshot(project.clone(), cx);
530 cx.foreground_executor()
531 .spawn(async move { Some(project_snapshot.await) })
532 .shared()
533 },
534 context_server_registry,
535 profile_id,
536 project_context,
537 templates,
538 model,
539 summarization_model,
540 project,
541 action_log,
542 }
543 }
544
545 #[cfg(any(test, feature = "test-support"))]
546 pub fn test(
547 model: Arc<dyn LanguageModel>,
548 project: Entity<Project>,
549 action_log: Entity<ActionLog>,
550 cx: &mut Context<Self>,
551 ) -> Self {
552 use crate::generate_session_id;
553
554 let context_server_registry =
555 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
556
557 Self::new(
558 generate_session_id(),
559 project,
560 Rc::default(),
561 context_server_registry,
562 action_log,
563 Templates::new(),
564 model,
565 None,
566 cx,
567 )
568 }
569
    /// Returns the session ID of this thread.
    pub fn id(&self) -> &acp::SessionId {
        &self.id
    }
573
574 pub fn from_db(
575 id: acp::SessionId,
576 db_thread: DbThread,
577 project: Entity<Project>,
578 project_context: Rc<RefCell<ProjectContext>>,
579 context_server_registry: Entity<ContextServerRegistry>,
580 action_log: Entity<ActionLog>,
581 templates: Arc<Templates>,
582 model: Arc<dyn LanguageModel>,
583 summarization_model: Option<Arc<dyn LanguageModel>>,
584 cx: &mut Context<Self>,
585 ) -> Self {
586 let profile_id = db_thread
587 .profile
588 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
589 Self {
590 id,
591 prompt_id: PromptId::new(),
592 title: ThreadTitle::Done(Ok(db_thread.title.clone())),
593 summary: db_thread.summary,
594 messages: db_thread.messages,
595 completion_mode: CompletionMode::Normal,
596 running_turn: None,
597 pending_message: None,
598 tools: BTreeMap::default(),
599 tool_use_limit_reached: false,
600 request_token_usage: db_thread.request_token_usage.clone(),
601 cumulative_token_usage: db_thread.cumulative_token_usage.clone(),
602 initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
603 context_server_registry,
604 profile_id,
605 project_context,
606 templates,
607 model,
608 summarization_model,
609 project,
610 action_log,
611 updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
612 }
613 }
614
615 pub fn to_db(&self, cx: &App) -> Task<DbThread> {
616 let initial_project_snapshot = self.initial_project_snapshot.clone();
617 let mut thread = DbThread {
618 title: self.title.unwrap_or_default(),
619 messages: self.messages.clone(),
620 updated_at: self.updated_at.clone(),
621 summary: self.summary.clone(),
622 initial_project_snapshot: None,
623 cumulative_token_usage: self.cumulative_token_usage.clone(),
624 request_token_usage: self.request_token_usage.clone(),
625 model: Some(DbLanguageModel {
626 provider: self.model.provider_id().to_string(),
627 model: self.model.name().0.to_string(),
628 }),
629 completion_mode: Some(self.completion_mode.into()),
630 profile: Some(self.profile_id.clone()),
631 };
632
633 cx.background_spawn(async move {
634 let initial_project_snapshot = initial_project_snapshot.await;
635 thread.initial_project_snapshot = initial_project_snapshot;
636 thread
637 })
638 }
639
640 pub fn replay(
641 &mut self,
642 cx: &mut Context<Self>,
643 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
644 let (tx, rx) = mpsc::unbounded();
645 let stream = ThreadEventStream(tx);
646 for message in &self.messages {
647 match message {
648 Message::User(user_message) => stream.send_user_message(&user_message),
649 Message::Agent(assistant_message) => {
650 for content in &assistant_message.content {
651 match content {
652 AgentMessageContent::Text(text) => stream.send_text(text),
653 AgentMessageContent::Thinking { text, .. } => {
654 stream.send_thinking(text)
655 }
656 AgentMessageContent::RedactedThinking(_) => {}
657 AgentMessageContent::ToolUse(tool_use) => {
658 self.replay_tool_call(
659 tool_use,
660 assistant_message.tool_results.get(&tool_use.id),
661 &stream,
662 cx,
663 );
664 }
665 }
666 }
667 }
668 Message::Resume => {}
669 }
670 }
671 rx
672 }
673
674 fn replay_tool_call(
675 &self,
676 tool_use: &LanguageModelToolUse,
677 tool_result: Option<&LanguageModelToolResult>,
678 stream: &ThreadEventStream,
679 cx: &mut Context<Self>,
680 ) {
681 let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
682 stream
683 .0
684 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
685 id: acp::ToolCallId(tool_use.id.to_string().into()),
686 title: tool_use.name.to_string(),
687 kind: acp::ToolKind::Other,
688 status: acp::ToolCallStatus::Failed,
689 content: Vec::new(),
690 locations: Vec::new(),
691 raw_input: Some(tool_use.input.clone()),
692 raw_output: None,
693 })))
694 .ok();
695 return;
696 };
697
698 let title = tool.initial_title(tool_use.input.clone());
699 let kind = tool.kind();
700 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
701
702 let output = tool_result
703 .as_ref()
704 .and_then(|result| result.output.clone());
705 if let Some(output) = output.clone() {
706 let tool_event_stream = ToolCallEventStream::new(
707 tool_use.id.clone(),
708 stream.clone(),
709 Some(self.project.read(cx).fs().clone()),
710 );
711 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
712 .log_err();
713 }
714
715 stream.update_tool_call_fields(
716 &tool_use.id,
717 acp::ToolCallUpdateFields {
718 status: Some(acp::ToolCallStatus::Completed),
719 raw_output: output,
720 ..Default::default()
721 },
722 );
723 }
724
725 /// Create a snapshot of the current project state including git information and unsaved buffers.
726 fn project_snapshot(
727 project: Entity<Project>,
728 cx: &mut Context<Self>,
729 ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
730 let git_store = project.read(cx).git_store().clone();
731 let worktree_snapshots: Vec<_> = project
732 .read(cx)
733 .visible_worktrees(cx)
734 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
735 .collect();
736
737 cx.spawn(async move |_, cx| {
738 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
739
740 let mut unsaved_buffers = Vec::new();
741 cx.update(|app_cx| {
742 let buffer_store = project.read(app_cx).buffer_store();
743 for buffer_handle in buffer_store.read(app_cx).buffers() {
744 let buffer = buffer_handle.read(app_cx);
745 if buffer.is_dirty() {
746 if let Some(file) = buffer.file() {
747 let path = file.path().to_string_lossy().to_string();
748 unsaved_buffers.push(path);
749 }
750 }
751 }
752 })
753 .ok();
754
755 Arc::new(ProjectSnapshot {
756 worktree_snapshots,
757 unsaved_buffer_paths: unsaved_buffers,
758 timestamp: Utc::now(),
759 })
760 })
761 }
762
    /// Captures the absolute path of `worktree` and, when a repository in
    /// `git_store` contains that path, its git state (remote URL of `origin`,
    /// HEAD SHA, current branch and the HEAD-to-worktree diff).
    ///
    /// Failures degrade gracefully: if the app cannot be read, a snapshot with an
    /// empty path and no git state is returned; if no repository matches or the
    /// repository job fails, `git_state` is `None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<agent::thread::WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // NOTE(review): the worktree snapshot is computed above but never used.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the first repository that contains the worktree's path and
            // schedule a job on it that gathers the git state.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only the `Local` state exposes a backend to query;
                            // otherwise just the branch name is reported.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the layers: no matching repo, failed entity update, or a
            // failed job all collapse to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
840
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
844
    /// Returns the action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
848
    /// Returns the language model used for agent turns, if one is configured.
    pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
        self.model.as_ref()
    }
852
    /// Sets the language model used for agent turns.
    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
        self.model = Some(model);
    }
856
    /// Returns the current completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
860
    /// Sets the completion mode and notifies observers of this thread.
    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
        self.completion_mode = mode;
        cx.notify()
    }
865
866 #[cfg(any(test, feature = "test-support"))]
867 pub fn last_message(&self) -> Option<Message> {
868 if let Some(message) = self.pending_message.clone() {
869 Some(Message::Agent(message))
870 } else {
871 self.messages.last().cloned()
872 }
873 }
874
    /// Registers `tool` under its name, replacing any tool already registered
    /// under the same name.
    pub fn add_tool(&mut self, tool: impl AgentTool) {
        self.tools.insert(tool.name(), tool.erase());
    }
878
    /// Unregisters the tool called `name`. Returns `true` if such a tool was registered.
    pub fn remove_tool(&mut self, name: &str) -> bool {
        self.tools.remove(name).is_some()
    }
882
    /// Returns the ID of the agent profile this thread uses.
    pub fn profile(&self) -> &AgentProfileId {
        &self.profile_id
    }
886
    /// Sets the agent profile this thread uses.
    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
        self.profile_id = profile_id;
    }
890
    /// Cancels the running turn, if any, and flushes the pending agent message.
    pub fn cancel(&mut self, cx: &mut Context<Self>) {
        // Taking the turn out of `self` clears it whether or not one was running.
        if let Some(running_turn) = self.running_turn.take() {
            running_turn.cancel();
        }
        self.flush_pending_message(cx);
    }
897
898 pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
899 self.cancel(cx);
900 let Some(position) = self.messages.iter().position(
901 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
902 ) else {
903 return Err(anyhow!("Message not found"));
904 };
905 self.messages.truncate(position);
906 cx.notify();
907 Ok(())
908 }
909
    /// Continues the conversation after the tool use limit was reached by
    /// appending a `Resume` marker and starting a new turn.
    ///
    /// # Errors
    ///
    /// Returns an error when the tool use limit has not been reached.
    pub fn resume(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        anyhow::ensure!(
            self.tool_use_limit_reached,
            "can only resume after tool use limit is reached"
        );

        self.messages.push(Message::Resume);
        cx.notify();

        log::info!("Total messages in thread: {}", self.messages.len());
        self.run_turn(cx)
    }
925
926 /// Sending a message results in the model streaming a response, which could include tool calls.
927 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
928 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
929 pub fn send<T>(
930 &mut self,
931 id: UserMessageId,
932 content: impl IntoIterator<Item = T>,
933 cx: &mut Context<Self>,
934 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
935 where
936 T: Into<UserMessageContent>,
937 {
938 let model = self.model().context("No language model configured")?;
939
940 log::info!("Thread::send called with model: {:?}", model.name());
941 self.advance_prompt_id();
942
943 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
944 log::debug!("Thread::send content: {:?}", content);
945
946 self.messages
947 .push(Message::User(UserMessage { id, content }));
948 cx.notify();
949
950 log::info!("Total messages in thread: {}", self.messages.len());
951 self.run_turn(cx)
952 }
953
954 fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
955 self.cancel(cx);
956
957 let model = self.model.clone();
958 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
959 let event_stream = ThreadEventStream(events_tx);
960 let message_ix = self.messages.len().saturating_sub(1);
961 self.tool_use_limit_reached = false;
962 self.running_turn = Some(RunningTurn {
963 event_stream: event_stream.clone(),
964 _task: cx.spawn(async move |this, cx| {
965 log::info!("Starting agent turn execution");
966 let turn_result: Result<()> = async {
967 let mut completion_intent = CompletionIntent::UserPrompt;
968 loop {
969 log::debug!(
970 "Building completion request with intent: {:?}",
971 completion_intent
972 );
973 let request = this.update(cx, |this, cx| {
974 this.build_completion_request(completion_intent, cx)
975 })??;
976
977 log::info!("Calling model.stream_completion");
978 let mut events = model.stream_completion(request, cx).await?;
979 log::debug!("Stream completion started successfully");
980
981 let mut tool_use_limit_reached = false;
982 let mut tool_uses = FuturesUnordered::new();
983 while let Some(event) = events.next().await {
984 match event? {
985 LanguageModelCompletionEvent::StatusUpdate(
986 CompletionRequestStatus::ToolUseLimitReached,
987 ) => {
988 tool_use_limit_reached = true;
989 }
990 LanguageModelCompletionEvent::Stop(reason) => {
991 event_stream.send_stop(reason);
992 if reason == StopReason::Refusal {
993 this.update(cx, |this, cx| {
994 this.flush_pending_message(cx);
995 this.messages.truncate(message_ix);
996 })?;
997 return Ok(());
998 }
999 }
1000 event => {
1001 log::trace!("Received completion event: {:?}", event);
1002 this.update(cx, |this, cx| {
1003 tool_uses.extend(this.handle_streamed_completion_event(
1004 event,
1005 &event_stream,
1006 cx,
1007 ));
1008 })
1009 .ok();
1010 }
1011 }
1012 }
1013
1014 let used_tools = tool_uses.is_empty();
1015 while let Some(tool_result) = tool_uses.next().await {
1016 log::info!("Tool finished {:?}", tool_result);
1017
1018 event_stream.update_tool_call_fields(
1019 &tool_result.tool_use_id,
1020 acp::ToolCallUpdateFields {
1021 status: Some(if tool_result.is_error {
1022 acp::ToolCallStatus::Failed
1023 } else {
1024 acp::ToolCallStatus::Completed
1025 }),
1026 raw_output: tool_result.output.clone(),
1027 ..Default::default()
1028 },
1029 );
1030 this.update(cx, |this, _cx| {
1031 this.pending_message()
1032 .tool_results
1033 .insert(tool_result.tool_use_id.clone(), tool_result);
1034 })
1035 .ok();
1036 }
1037
1038 if tool_use_limit_reached {
1039 log::info!("Tool use limit reached, completing turn");
1040 this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
1041 return Err(language_model::ToolUseLimitReachedError.into());
1042 } else if used_tools {
1043 log::info!("No tool uses found, completing turn");
1044 return Ok(());
1045 } else {
1046 this.update(cx, |this, cx| this.flush_pending_message(cx))?;
1047 completion_intent = CompletionIntent::ToolResults;
1048 }
1049 }
1050 }
1051 .await;
1052
1053 if let Err(error) = turn_result {
1054 log::error!("Turn execution failed: {:?}", error);
1055 event_stream.send_error(error);
1056 } else {
1057 log::info!("Turn execution completed successfully");
1058 }
1059
1060 this.update(cx, |this, cx| {
1061 this.flush_pending_message(cx);
1062 this.running_turn.take();
1063 })
1064 .ok();
1065 }),
1066 });
1067 Ok(events_rx)
1068 }
1069
1070 pub fn generate_title_if_needed(&mut self, cx: &mut Context<Self>) {
1071 if !matches!(self.title, ThreadTitle::None) {
1072 return;
1073 }
1074
1075 // todo!() copy logic from agent1 re: tool calls, etc.?
1076 if self.messages.len() < 2 {
1077 return;
1078 }
1079
1080 self.generate_title(cx);
1081 }
1082
1083 fn generate_title(&mut self, cx: &mut Context<Self>) {
1084 let Some(model) = self.summarization_model.clone() else {
1085 println!("No thread summary model");
1086 return;
1087 };
1088 let mut request = LanguageModelRequest {
1089 intent: Some(CompletionIntent::ThreadSummarization),
1090 temperature: AgentSettings::temperature_for_model(&model, cx),
1091 ..Default::default()
1092 };
1093
1094 for message in &self.messages {
1095 request.messages.extend(message.to_request());
1096 }
1097
1098 request.messages.push(LanguageModelRequestMessage {
1099 role: Role::User,
1100 content: vec![MessageContent::Text(SUMMARIZE_THREAD_PROMPT.into())],
1101 cache: false,
1102 });
1103
1104 let task = cx.spawn(async move |this, cx| {
1105 let result = async {
1106 let mut messages = model.stream_completion(request, &cx).await?;
1107
1108 let mut new_summary = String::new();
1109 while let Some(event) = messages.next().await {
1110 let Ok(event) = event else {
1111 continue;
1112 };
1113 let text = match event {
1114 LanguageModelCompletionEvent::Text(text) => text,
1115 LanguageModelCompletionEvent::StatusUpdate(
1116 CompletionRequestStatus::UsageUpdated { .. },
1117 ) => {
1118 // this.update(cx, |thread, cx| {
1119 // thread.update_model_request_usage(amount as u32, limit, cx);
1120 // })?;
1121 // todo!()? not sure if this is the right place to do this.
1122 continue;
1123 }
1124 _ => continue,
1125 };
1126
1127 let mut lines = text.lines();
1128 new_summary.extend(lines.next());
1129
1130 // Stop if the LLM generated multiple lines.
1131 if lines.next().is_some() {
1132 break;
1133 }
1134 }
1135
1136 anyhow::Ok(new_summary.into())
1137 }
1138 .await;
1139
1140 this.update(cx, |this, cx| {
1141 this.title = ThreadTitle::Done(result);
1142 cx.notify();
1143 })
1144 .log_err();
1145 });
1146
1147 self.title = ThreadTitle::Pending(task);
1148 }
1149
1150 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
1151 log::debug!("Building system message");
1152 let prompt = SystemPromptTemplate {
1153 project: &self.project_context.borrow(),
1154 available_tools: self.tools.keys().cloned().collect(),
1155 }
1156 .render(&self.templates)
1157 .context("failed to build system prompt")
1158 .expect("Invalid template");
1159 log::debug!("System message built");
1160 LanguageModelRequestMessage {
1161 role: Role::System,
1162 content: vec![prompt.into()],
1163 cache: true,
1164 }
1165 }
1166
1167 /// A helper method that's called on every streamed completion event.
1168 /// Returns an optional tool result task, which the main agentic loop in
1169 /// send will send back to the model when it resolves.
1170 fn handle_streamed_completion_event(
1171 &mut self,
1172 event: LanguageModelCompletionEvent,
1173 event_stream: &ThreadEventStream,
1174 cx: &mut Context<Self>,
1175 ) -> Option<Task<LanguageModelToolResult>> {
1176 log::trace!("Handling streamed completion event: {:?}", event);
1177 use LanguageModelCompletionEvent::*;
1178
1179 match event {
1180 StartMessage { .. } => {
1181 self.flush_pending_message(cx);
1182 self.pending_message = Some(AgentMessage::default());
1183 }
1184 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1185 Thinking { text, signature } => {
1186 self.handle_thinking_event(text, signature, event_stream, cx)
1187 }
1188 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1189 ToolUse(tool_use) => {
1190 return self.handle_tool_use_event(tool_use, event_stream, cx);
1191 }
1192 ToolUseJsonParseError {
1193 id,
1194 tool_name,
1195 raw_input,
1196 json_parse_error,
1197 } => {
1198 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1199 id,
1200 tool_name,
1201 raw_input,
1202 json_parse_error,
1203 )));
1204 }
1205 UsageUpdate(_) | StatusUpdate(_) => {}
1206 Stop(_) => unreachable!(),
1207 }
1208
1209 None
1210 }
1211
1212 fn handle_text_event(
1213 &mut self,
1214 new_text: String,
1215 event_stream: &ThreadEventStream,
1216 cx: &mut Context<Self>,
1217 ) {
1218 event_stream.send_text(&new_text);
1219
1220 let last_message = self.pending_message();
1221 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1222 text.push_str(&new_text);
1223 } else {
1224 last_message
1225 .content
1226 .push(AgentMessageContent::Text(new_text));
1227 }
1228
1229 cx.notify();
1230 }
1231
1232 fn handle_thinking_event(
1233 &mut self,
1234 new_text: String,
1235 new_signature: Option<String>,
1236 event_stream: &ThreadEventStream,
1237 cx: &mut Context<Self>,
1238 ) {
1239 event_stream.send_thinking(&new_text);
1240
1241 let last_message = self.pending_message();
1242 if let Some(AgentMessageContent::Thinking { text, signature }) =
1243 last_message.content.last_mut()
1244 {
1245 text.push_str(&new_text);
1246 *signature = new_signature.or(signature.take());
1247 } else {
1248 last_message.content.push(AgentMessageContent::Thinking {
1249 text: new_text,
1250 signature: new_signature,
1251 });
1252 }
1253
1254 cx.notify();
1255 }
1256
1257 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1258 let last_message = self.pending_message();
1259 last_message
1260 .content
1261 .push(AgentMessageContent::RedactedThinking(data));
1262 cx.notify();
1263 }
1264
1265 fn handle_tool_use_event(
1266 &mut self,
1267 tool_use: LanguageModelToolUse,
1268 event_stream: &ThreadEventStream,
1269 cx: &mut Context<Self>,
1270 ) -> Option<Task<LanguageModelToolResult>> {
1271 cx.notify();
1272
1273 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1274 let mut title = SharedString::from(&tool_use.name);
1275 let mut kind = acp::ToolKind::Other;
1276 if let Some(tool) = tool.as_ref() {
1277 title = tool.initial_title(tool_use.input.clone());
1278 kind = tool.kind();
1279 }
1280
1281 // Ensure the last message ends in the current tool use
1282 let last_message = self.pending_message();
1283 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
1284 if let AgentMessageContent::ToolUse(last_tool_use) = content {
1285 if last_tool_use.id == tool_use.id {
1286 *last_tool_use = tool_use.clone();
1287 false
1288 } else {
1289 true
1290 }
1291 } else {
1292 true
1293 }
1294 });
1295
1296 if push_new_tool_use {
1297 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1298 last_message
1299 .content
1300 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1301 } else {
1302 event_stream.update_tool_call_fields(
1303 &tool_use.id,
1304 acp::ToolCallUpdateFields {
1305 title: Some(title.into()),
1306 kind: Some(kind),
1307 raw_input: Some(tool_use.input.clone()),
1308 ..Default::default()
1309 },
1310 );
1311 }
1312
1313 if !tool_use.is_input_complete {
1314 return None;
1315 }
1316
1317 let Some(tool) = tool else {
1318 let content = format!("No tool named {} exists", tool_use.name);
1319 return Some(Task::ready(LanguageModelToolResult {
1320 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1321 tool_use_id: tool_use.id,
1322 tool_name: tool_use.name,
1323 is_error: true,
1324 output: None,
1325 }));
1326 };
1327
1328 let fs = self.project.read(cx).fs().clone();
1329 let tool_event_stream =
1330 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1331 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1332 status: Some(acp::ToolCallStatus::InProgress),
1333 ..Default::default()
1334 });
1335 let supports_images = self.model().map_or(false, |model| model.supports_images());
1336 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1337 log::info!("Running tool {}", tool_use.name);
1338 Some(cx.foreground_executor().spawn(async move {
1339 let tool_result = tool_result.await.and_then(|output| {
1340 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1341 if !supports_images {
1342 return Err(anyhow!(
1343 "Attempted to read an image, but this model doesn't support it.",
1344 ));
1345 }
1346 }
1347 Ok(output)
1348 });
1349
1350 match tool_result {
1351 Ok(output) => LanguageModelToolResult {
1352 tool_use_id: tool_use.id,
1353 tool_name: tool_use.name,
1354 is_error: false,
1355 content: output.llm_output,
1356 output: Some(output.raw_output),
1357 },
1358 Err(error) => LanguageModelToolResult {
1359 tool_use_id: tool_use.id,
1360 tool_name: tool_use.name,
1361 is_error: true,
1362 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1363 output: None,
1364 },
1365 }
1366 }))
1367 }
1368
1369 fn handle_tool_use_json_parse_error_event(
1370 &mut self,
1371 tool_use_id: LanguageModelToolUseId,
1372 tool_name: Arc<str>,
1373 raw_input: Arc<str>,
1374 json_parse_error: String,
1375 ) -> LanguageModelToolResult {
1376 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1377 LanguageModelToolResult {
1378 tool_use_id,
1379 tool_name,
1380 is_error: true,
1381 content: LanguageModelToolResultContent::Text(tool_output.into()),
1382 output: Some(serde_json::Value::String(raw_input.to_string())),
1383 }
1384 }
1385
1386 fn pending_message(&mut self) -> &mut AgentMessage {
1387 self.pending_message.get_or_insert_default()
1388 }
1389
1390 fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1391 let Some(mut message) = self.pending_message.take() else {
1392 return;
1393 };
1394
1395 for content in &message.content {
1396 let AgentMessageContent::ToolUse(tool_use) = content else {
1397 continue;
1398 };
1399
1400 if !message.tool_results.contains_key(&tool_use.id) {
1401 message.tool_results.insert(
1402 tool_use.id.clone(),
1403 LanguageModelToolResult {
1404 tool_use_id: tool_use.id.clone(),
1405 tool_name: tool_use.name.clone(),
1406 is_error: true,
1407 content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1408 output: None,
1409 },
1410 );
1411 }
1412 }
1413
1414 self.messages.push(Message::Agent(message));
1415 cx.notify()
1416 }
1417
1418 pub(crate) fn build_completion_request(
1419 &self,
1420 completion_intent: CompletionIntent,
1421 cx: &mut App,
1422 ) -> Result<LanguageModelRequest> {
1423 let model = self.model().context("No language model configured")?;
1424
1425 log::debug!("Building completion request");
1426 log::debug!("Completion intent: {:?}", completion_intent);
1427 log::debug!("Completion mode: {:?}", self.completion_mode);
1428
1429 let messages = self.build_request_messages();
1430 log::info!("Request will include {} messages", messages.len());
1431
1432 let tools = if let Some(tools) = self.tools(cx).log_err() {
1433 tools
1434 .filter_map(|tool| {
1435 let tool_name = tool.name().to_string();
1436 log::trace!("Including tool: {}", tool_name);
1437 Some(LanguageModelRequestTool {
1438 name: tool_name,
1439 description: tool.description().to_string(),
1440 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1441 })
1442 })
1443 .collect()
1444 } else {
1445 Vec::new()
1446 };
1447
1448 log::info!("Request includes {} tools", tools.len());
1449
1450 let request = LanguageModelRequest {
1451 thread_id: Some(self.id.to_string()),
1452 prompt_id: Some(self.prompt_id.to_string()),
1453 intent: Some(completion_intent),
1454 mode: Some(self.completion_mode.into()),
1455 messages,
1456 tools,
1457 tool_choice: None,
1458 stop: Vec::new(),
1459 temperature: AgentSettings::temperature_for_model(&model, cx),
1460 thinking_allowed: true,
1461 };
1462
1463 log::debug!("Completion request built successfully");
1464 Ok(request)
1465 }
1466
1467 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1468 let model = self.model().context("No language model configured")?;
1469
1470 let profile = AgentSettings::get_global(cx)
1471 .profiles
1472 .get(&self.profile_id)
1473 .context("profile not found")?;
1474 let provider_id = model.provider_id();
1475
1476 Ok(self
1477 .tools
1478 .iter()
1479 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1480 .filter_map(|(tool_name, tool)| {
1481 if profile.is_tool_enabled(tool_name) {
1482 Some(tool)
1483 } else {
1484 None
1485 }
1486 })
1487 .chain(self.context_server_registry.read(cx).servers().flat_map(
1488 |(server_id, tools)| {
1489 tools.iter().filter_map(|(tool_name, tool)| {
1490 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1491 Some(tool)
1492 } else {
1493 None
1494 }
1495 })
1496 },
1497 )))
1498 }
1499
1500 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1501 log::trace!(
1502 "Building request messages from {} thread messages",
1503 self.messages.len()
1504 );
1505 let mut messages = vec![self.build_system_message()];
1506 for message in &self.messages {
1507 messages.extend(message.to_request());
1508 }
1509
1510 if let Some(message) = self.pending_message.as_ref() {
1511 messages.extend(message.to_request());
1512 }
1513
1514 if let Some(last_user_message) = messages
1515 .iter_mut()
1516 .rev()
1517 .find(|message| message.role == Role::User)
1518 {
1519 last_user_message.cache = true;
1520 }
1521
1522 messages
1523 }
1524
1525 pub fn to_markdown(&self) -> String {
1526 let mut markdown = String::new();
1527 for (ix, message) in self.messages.iter().enumerate() {
1528 if ix > 0 {
1529 markdown.push('\n');
1530 }
1531 markdown.push_str(&message.to_markdown());
1532 }
1533
1534 if let Some(message) = self.pending_message.as_ref() {
1535 markdown.push('\n');
1536 markdown.push_str(&message.to_markdown());
1537 }
1538
1539 markdown
1540 }
1541
    /// Replaces the thread's prompt id with a freshly created `PromptId`.
    fn advance_prompt_id(&mut self) {
        self.prompt_id = PromptId::new();
    }
1545}
1546
/// The turn currently in progress on a thread: its driving task plus the
/// event stream used to report on it.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1556
impl RunningTurn {
    /// Cancels the turn by reporting a final `Canceled` stop event on its
    /// event stream.
    ///
    /// Takes `self` by value, so the turn (including `_task`) is dropped when
    /// this returns.
    /// NOTE(review): presumably dropping `_task` is what actually stops the
    /// in-flight work — confirm against the task type's drop semantics.
    fn cancel(self) {
        log::debug!("Cancelling in progress turn");
        self.event_stream.send_canceled();
    }
}
1563
1564pub trait AgentTool
1565where
1566 Self: 'static + Sized,
1567{
1568 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1569 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1570
1571 fn name(&self) -> SharedString;
1572
1573 fn description(&self) -> SharedString {
1574 let schema = schemars::schema_for!(Self::Input);
1575 SharedString::new(
1576 schema
1577 .get("description")
1578 .and_then(|description| description.as_str())
1579 .unwrap_or_default(),
1580 )
1581 }
1582
1583 fn kind(&self) -> acp::ToolKind;
1584
1585 /// The initial tool title to display. Can be updated during the tool run.
1586 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1587
1588 /// Returns the JSON schema that describes the tool's input.
1589 fn input_schema(&self) -> Schema {
1590 schemars::schema_for!(Self::Input)
1591 }
1592
1593 /// Some tools rely on a provider for the underlying billing or other reasons.
1594 /// Allow the tool to check if they are compatible, or should be filtered out.
1595 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1596 true
1597 }
1598
1599 /// Runs the tool with the provided input.
1600 fn run(
1601 self: Arc<Self>,
1602 input: Self::Input,
1603 event_stream: ToolCallEventStream,
1604 cx: &mut App,
1605 ) -> Task<Result<Self::Output>>;
1606
1607 /// Emits events for a previous execution of the tool.
1608 fn replay(
1609 &self,
1610 _input: Self::Input,
1611 _output: Self::Output,
1612 _event_stream: ToolCallEventStream,
1613 _cx: &mut App,
1614 ) -> Result<()> {
1615 Ok(())
1616 }
1617
1618 fn erase(self) -> Arc<dyn AnyAgentTool> {
1619 Arc::new(Erased(Arc::new(self)))
1620 }
1621}
1622
/// Newtype wrapper used to implement `AnyAgentTool` for `Arc<T>` where
/// `T: AgentTool`.
pub struct Erased<T>(T);
1624
/// The result of running a type-erased tool.
pub struct AgentToolOutput {
    /// The output converted into content for the language model.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's output serialized to JSON.
    pub raw_output: serde_json::Value,
}
1629
/// Type-erased counterpart of `AgentTool`: inputs and outputs are raw JSON
/// values rather than the tool's associated types.
pub trait AnyAgentTool {
    /// The name of the tool.
    fn name(&self) -> SharedString;
    /// The tool's description.
    fn description(&self) -> SharedString;
    /// The kind of tool, as reported through the agent client protocol.
    fn kind(&self) -> acp::ToolKind;
    /// The initial title to display for a call with the given raw input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The JSON schema of the tool's input for the given schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool can be used with the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with the given raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Emits events for a previous execution of the tool, given its raw JSON
    /// input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1653
1654impl<T> AnyAgentTool for Erased<Arc<T>>
1655where
1656 T: AgentTool,
1657{
1658 fn name(&self) -> SharedString {
1659 self.0.name()
1660 }
1661
1662 fn description(&self) -> SharedString {
1663 self.0.description()
1664 }
1665
1666 fn kind(&self) -> agent_client_protocol::ToolKind {
1667 self.0.kind()
1668 }
1669
1670 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1671 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1672 self.0.initial_title(parsed_input)
1673 }
1674
1675 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1676 let mut json = serde_json::to_value(self.0.input_schema())?;
1677 adapt_schema_to_format(&mut json, format)?;
1678 Ok(json)
1679 }
1680
1681 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1682 self.0.supported_provider(provider)
1683 }
1684
1685 fn run(
1686 self: Arc<Self>,
1687 input: serde_json::Value,
1688 event_stream: ToolCallEventStream,
1689 cx: &mut App,
1690 ) -> Task<Result<AgentToolOutput>> {
1691 cx.spawn(async move |cx| {
1692 let input = serde_json::from_value(input)?;
1693 let output = cx
1694 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1695 .await?;
1696 let raw_output = serde_json::to_value(&output)?;
1697 Ok(AgentToolOutput {
1698 llm_output: output.into(),
1699 raw_output,
1700 })
1701 })
1702 }
1703
1704 fn replay(
1705 &self,
1706 input: serde_json::Value,
1707 output: serde_json::Value,
1708 event_stream: ToolCallEventStream,
1709 cx: &mut App,
1710 ) -> Result<()> {
1711 let input = serde_json::from_value(input)?;
1712 let output = serde_json::from_value(output)?;
1713 self.0.replay(input, output, event_stream, cx)
1714 }
1715}
1716
/// Sending half of the unbounded channel on which a thread reports
/// `ThreadEvent`s (or errors).
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1719
1720impl ThreadEventStream {
1721 fn send_user_message(&self, message: &UserMessage) {
1722 self.0
1723 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1724 .ok();
1725 }
1726
1727 fn send_text(&self, text: &str) {
1728 self.0
1729 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1730 .ok();
1731 }
1732
1733 fn send_thinking(&self, text: &str) {
1734 self.0
1735 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1736 .ok();
1737 }
1738
1739 fn send_tool_call(
1740 &self,
1741 id: &LanguageModelToolUseId,
1742 title: SharedString,
1743 kind: acp::ToolKind,
1744 input: serde_json::Value,
1745 ) {
1746 self.0
1747 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1748 id,
1749 title.to_string(),
1750 kind,
1751 input,
1752 ))))
1753 .ok();
1754 }
1755
1756 fn initial_tool_call(
1757 id: &LanguageModelToolUseId,
1758 title: String,
1759 kind: acp::ToolKind,
1760 input: serde_json::Value,
1761 ) -> acp::ToolCall {
1762 acp::ToolCall {
1763 id: acp::ToolCallId(id.to_string().into()),
1764 title,
1765 kind,
1766 status: acp::ToolCallStatus::Pending,
1767 content: vec![],
1768 locations: vec![],
1769 raw_input: Some(input),
1770 raw_output: None,
1771 }
1772 }
1773
1774 fn update_tool_call_fields(
1775 &self,
1776 tool_use_id: &LanguageModelToolUseId,
1777 fields: acp::ToolCallUpdateFields,
1778 ) {
1779 self.0
1780 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1781 acp::ToolCallUpdate {
1782 id: acp::ToolCallId(tool_use_id.to_string().into()),
1783 fields,
1784 }
1785 .into(),
1786 )))
1787 .ok();
1788 }
1789
1790 fn send_stop(&self, reason: StopReason) {
1791 match reason {
1792 StopReason::EndTurn => {
1793 self.0
1794 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1795 .ok();
1796 }
1797 StopReason::MaxTokens => {
1798 self.0
1799 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1800 .ok();
1801 }
1802 StopReason::Refusal => {
1803 self.0
1804 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1805 .ok();
1806 }
1807 StopReason::ToolUse => {}
1808 }
1809 }
1810
1811 fn send_canceled(&self) {
1812 self.0
1813 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1814 .ok();
1815 }
1816
1817 fn send_error(&self, error: impl Into<anyhow::Error>) {
1818 self.0.unbounded_send(Err(error.into())).ok();
1819 }
1820}
1821
/// An event stream scoped to a single tool call.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use that events sent through this stream refer to.
    tool_use_id: LanguageModelToolUseId,
    /// The thread-level stream the events are sent on.
    stream: ThreadEventStream,
    /// File system used to persist the "always allow" setting from
    /// `authorize`; when `None`, that setting is not written.
    fs: Option<Arc<dyn Fs>>,
}
1828
1829impl ToolCallEventStream {
1830 #[cfg(test)]
1831 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1832 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1833
1834 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1835
1836 (stream, ToolCallEventStreamReceiver(events_rx))
1837 }
1838
1839 fn new(
1840 tool_use_id: LanguageModelToolUseId,
1841 stream: ThreadEventStream,
1842 fs: Option<Arc<dyn Fs>>,
1843 ) -> Self {
1844 Self {
1845 tool_use_id,
1846 stream,
1847 fs,
1848 }
1849 }
1850
1851 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1852 self.stream
1853 .update_tool_call_fields(&self.tool_use_id, fields);
1854 }
1855
1856 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1857 self.stream
1858 .0
1859 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1860 acp_thread::ToolCallUpdateDiff {
1861 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1862 diff,
1863 }
1864 .into(),
1865 )))
1866 .ok();
1867 }
1868
1869 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1870 self.stream
1871 .0
1872 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1873 acp_thread::ToolCallUpdateTerminal {
1874 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1875 terminal,
1876 }
1877 .into(),
1878 )))
1879 .ok();
1880 }
1881
1882 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1883 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1884 return Task::ready(Ok(()));
1885 }
1886
1887 let (response_tx, response_rx) = oneshot::channel();
1888 self.stream
1889 .0
1890 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1891 ToolCallAuthorization {
1892 tool_call: acp::ToolCallUpdate {
1893 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1894 fields: acp::ToolCallUpdateFields {
1895 title: Some(title.into()),
1896 ..Default::default()
1897 },
1898 },
1899 options: vec![
1900 acp::PermissionOption {
1901 id: acp::PermissionOptionId("always_allow".into()),
1902 name: "Always Allow".into(),
1903 kind: acp::PermissionOptionKind::AllowAlways,
1904 },
1905 acp::PermissionOption {
1906 id: acp::PermissionOptionId("allow".into()),
1907 name: "Allow".into(),
1908 kind: acp::PermissionOptionKind::AllowOnce,
1909 },
1910 acp::PermissionOption {
1911 id: acp::PermissionOptionId("deny".into()),
1912 name: "Deny".into(),
1913 kind: acp::PermissionOptionKind::RejectOnce,
1914 },
1915 ],
1916 response: response_tx,
1917 },
1918 )))
1919 .ok();
1920 let fs = self.fs.clone();
1921 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1922 "always_allow" => {
1923 if let Some(fs) = fs.clone() {
1924 cx.update(|cx| {
1925 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1926 settings.set_always_allow_tool_actions(true);
1927 });
1928 })?;
1929 }
1930
1931 Ok(())
1932 }
1933 "allow" => Ok(()),
1934 _ => Err(anyhow!("Permission to run tool denied by user")),
1935 })
1936 }
1937}
1938
/// Test-only receiving half paired with `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
1941
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it as a tool call authorization.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else (including end of stream).
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Waits for the next event and returns the terminal it attaches.
    ///
    /// # Panics
    ///
    /// Panics if the next event is not a terminal tool call update.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1965
/// Gives tests shared access to the wrapped receiver.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
1974
/// Gives tests mutable access to the wrapped receiver.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
1981
1982impl From<&str> for UserMessageContent {
1983 fn from(text: &str) -> Self {
1984 Self::Text(text.into())
1985 }
1986}
1987
1988impl From<acp::ContentBlock> for UserMessageContent {
1989 fn from(value: acp::ContentBlock) -> Self {
1990 match value {
1991 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1992 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1993 acp::ContentBlock::Audio(_) => {
1994 // TODO
1995 Self::Text("[audio]".to_string())
1996 }
1997 acp::ContentBlock::ResourceLink(resource_link) => {
1998 match MentionUri::parse(&resource_link.uri) {
1999 Ok(uri) => Self::Mention {
2000 uri,
2001 content: String::new(),
2002 },
2003 Err(err) => {
2004 log::error!("Failed to parse mention link: {}", err);
2005 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2006 }
2007 }
2008 }
2009 acp::ContentBlock::Resource(resource) => match resource.resource {
2010 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2011 match MentionUri::parse(&resource.uri) {
2012 Ok(uri) => Self::Mention {
2013 uri,
2014 content: resource.text,
2015 },
2016 Err(err) => {
2017 log::error!("Failed to parse mention link: {}", err);
2018 Self::Text(
2019 MarkdownCodeBlock {
2020 tag: &resource.uri,
2021 text: &resource.text,
2022 }
2023 .to_string(),
2024 )
2025 }
2026 }
2027 }
2028 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2029 // TODO
2030 Self::Text("[blob]".to_string())
2031 }
2032 },
2033 }
2034 }
2035}
2036
/// Converts user message content back into a protocol content block.
///
/// NOTE(review): `Mention` content is not implemented — converting one hits
/// `todo!()` and panics at runtime.
impl From<UserMessageContent> for acp::ContentBlock {
    fn from(content: UserMessageContent) -> Self {
        match content {
            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
                text,
                annotations: None,
            }),
            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
                data: image.source.to_string(),
                // NOTE(review): the MIME type is hard-coded; presumably all
                // stored images are PNG — confirm.
                mime_type: "image/png".to_string(),
                annotations: None,
                uri: None,
            }),
            UserMessageContent::Mention { .. } => {
                // Not implemented yet: panics if reached.
                todo!()
            }
        }
    }
}
2056
2057fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2058 LanguageModelImage {
2059 source: image_content.data.into(),
2060 // TODO: make this optional?
2061 size: gpui::Size::new(0.into(), 0.into()),
2062 }
2063}