thread.rs

   1use crate::{
   2    ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates,
   3};
   4use acp_thread::{MentionUri, UserMessageId};
   5use action_log::ActionLog;
   6use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
   7use agent_client_protocol as acp;
   8use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
   9use anyhow::{Context as _, Result, anyhow};
  10use assistant_tool::adapt_schema_to_format;
  11use chrono::{DateTime, Utc};
  12use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
  13use collections::IndexMap;
  14use fs::Fs;
  15use futures::{
  16    FutureExt,
  17    channel::{mpsc, oneshot},
  18    future::Shared,
  19    stream::FuturesUnordered,
  20};
  21use git::repository::DiffType;
  22use gpui::{App, AppContext, Context, Entity, SharedString, Task};
  23use language_model::{
  24    LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
  25    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  26    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  27    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
  28};
  29use project::{
  30    Project,
  31    git_store::{GitStore, RepositoryState},
  32};
  33use prompt_store::ProjectContext;
  34use schemars::{JsonSchema, Schema};
  35use serde::{Deserialize, Serialize};
  36use settings::{Settings, update_settings_file};
  37use smol::stream::StreamExt;
  38use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
  39use std::{fmt::Write, ops::Range};
  40use util::{ResultExt, markdown::MarkdownCodeBlock};
  41use uuid::Uuid;
  42
/// Message text describing a tool call that the user canceled.
// NOTE(review): not referenced in the visible part of this file; presumably reported as the
// tool result when a running tool is canceled — confirm at the usage site.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  44
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// The ID is stored as a shared string (`Arc<str>`), so cloning it does not copy the text.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  50
  51impl PromptId {
  52    pub fn new() -> Self {
  53        Self(Uuid::new_v4().to_string().into())
  54    }
  55}
  56
  57impl std::fmt::Display for PromptId {
  58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  59        write!(f, "{}", self.0)
  60    }
  61}
  62
/// A single entry in a thread's message history.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// A message produced by the agent, including any tool results attached to it.
    Agent(AgentMessage),
    /// Marker rendered in Markdown as "[resumed after tool use limit was reached]".
    Resume,
}
  69
  70impl Message {
  71    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
  72        match self {
  73            Message::Agent(agent_message) => Some(agent_message),
  74            _ => None,
  75        }
  76    }
  77
  78    pub fn to_markdown(&self) -> String {
  79        match self {
  80            Message::User(message) => message.to_markdown(),
  81            Message::Agent(message) => message.to_markdown(),
  82            Message::Resume => "[resumed after tool use limit was reached]".into(),
  83        }
  84    }
  85}
  86
/// A message submitted by the user, made up of ordered content chunks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifier of this message; `Thread::truncate` looks messages up by it.
    pub id: UserMessageId,
    /// The chunks (text, mentions, images) in the order they appear in the message.
    pub content: Vec<UserMessageContent>,
}
  92
/// One chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text.
    Text(String),
    /// A reference to some resource (`uri`) together with the content attached for it.
    /// `content` may be empty, in which case only the link is rendered in Markdown.
    Mention { uri: MentionUri, content: String },
    /// An image attachment.
    Image(LanguageModelImage),
}
  99
 100impl UserMessage {
 101    pub fn to_markdown(&self) -> String {
 102        let mut markdown = String::from("## User\n\n");
 103
 104        for content in &self.content {
 105            match content {
 106                UserMessageContent::Text(text) => {
 107                    markdown.push_str(text);
 108                    markdown.push('\n');
 109                }
 110                UserMessageContent::Image(_) => {
 111                    markdown.push_str("<image />\n");
 112                }
 113                UserMessageContent::Mention { uri, content } => {
 114                    if !content.is_empty() {
 115                        let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
 116                    } else {
 117                        let _ = write!(&mut markdown, "{}\n", uri.as_link());
 118                    }
 119                }
 120            }
 121        }
 122
 123        markdown
 124    }
 125
 126    fn to_request(&self) -> LanguageModelRequestMessage {
 127        let mut message = LanguageModelRequestMessage {
 128            role: Role::User,
 129            content: Vec::with_capacity(self.content.len()),
 130            cache: false,
 131        };
 132
 133        const OPEN_CONTEXT: &str = "<context>\n\
 134            The following items were attached by the user. \
 135            They are up-to-date and don't need to be re-read.\n\n";
 136
 137        const OPEN_FILES_TAG: &str = "<files>";
 138        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 139        const OPEN_THREADS_TAG: &str = "<threads>";
 140        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 141        const OPEN_RULES_TAG: &str =
 142            "<rules>\nThe user has specified the following rules that should be applied:\n";
 143
 144        let mut file_context = OPEN_FILES_TAG.to_string();
 145        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 146        let mut thread_context = OPEN_THREADS_TAG.to_string();
 147        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 148        let mut rules_context = OPEN_RULES_TAG.to_string();
 149
 150        for chunk in &self.content {
 151            let chunk = match chunk {
 152                UserMessageContent::Text(text) => {
 153                    language_model::MessageContent::Text(text.clone())
 154                }
 155                UserMessageContent::Image(value) => {
 156                    language_model::MessageContent::Image(value.clone())
 157                }
 158                UserMessageContent::Mention { uri, content } => {
 159                    match uri {
 160                        MentionUri::File { abs_path, .. } => {
 161                            write!(
 162                                &mut symbol_context,
 163                                "\n{}",
 164                                MarkdownCodeBlock {
 165                                    tag: &codeblock_tag(&abs_path, None),
 166                                    text: &content.to_string(),
 167                                }
 168                            )
 169                            .ok();
 170                        }
 171                        MentionUri::Symbol {
 172                            path, line_range, ..
 173                        }
 174                        | MentionUri::Selection {
 175                            path, line_range, ..
 176                        } => {
 177                            write!(
 178                                &mut rules_context,
 179                                "\n{}",
 180                                MarkdownCodeBlock {
 181                                    tag: &codeblock_tag(&path, Some(line_range)),
 182                                    text: &content
 183                                }
 184                            )
 185                            .ok();
 186                        }
 187                        MentionUri::Thread { .. } => {
 188                            write!(&mut thread_context, "\n{}\n", content).ok();
 189                        }
 190                        MentionUri::TextThread { .. } => {
 191                            write!(&mut thread_context, "\n{}\n", content).ok();
 192                        }
 193                        MentionUri::Rule { .. } => {
 194                            write!(
 195                                &mut rules_context,
 196                                "\n{}",
 197                                MarkdownCodeBlock {
 198                                    tag: "",
 199                                    text: &content
 200                                }
 201                            )
 202                            .ok();
 203                        }
 204                        MentionUri::Fetch { url } => {
 205                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 206                        }
 207                    }
 208
 209                    language_model::MessageContent::Text(uri.as_link().to_string())
 210                }
 211            };
 212
 213            message.content.push(chunk);
 214        }
 215
 216        let len_before_context = message.content.len();
 217
 218        if file_context.len() > OPEN_FILES_TAG.len() {
 219            file_context.push_str("</files>\n");
 220            message
 221                .content
 222                .push(language_model::MessageContent::Text(file_context));
 223        }
 224
 225        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 226            symbol_context.push_str("</symbols>\n");
 227            message
 228                .content
 229                .push(language_model::MessageContent::Text(symbol_context));
 230        }
 231
 232        if thread_context.len() > OPEN_THREADS_TAG.len() {
 233            thread_context.push_str("</threads>\n");
 234            message
 235                .content
 236                .push(language_model::MessageContent::Text(thread_context));
 237        }
 238
 239        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 240            fetch_context.push_str("</fetched_urls>\n");
 241            message
 242                .content
 243                .push(language_model::MessageContent::Text(fetch_context));
 244        }
 245
 246        if rules_context.len() > OPEN_RULES_TAG.len() {
 247            rules_context.push_str("</user_rules>\n");
 248            message
 249                .content
 250                .push(language_model::MessageContent::Text(rules_context));
 251        }
 252
 253        if message.content.len() > len_before_context {
 254            message.content.insert(
 255                len_before_context,
 256                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 257            );
 258            message
 259                .content
 260                .push(language_model::MessageContent::Text("</context>".into()));
 261        }
 262
 263        message
 264    }
 265}
 266
 267fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 268    let mut result = String::new();
 269
 270    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 271        let _ = write!(result, "{} ", extension);
 272    }
 273
 274    let _ = write!(result, "{}", full_path.display());
 275
 276    if let Some(range) = line_range {
 277        if range.start == range.end {
 278            let _ = write!(result, ":{}", range.start + 1);
 279        } else {
 280            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 281        }
 282    }
 283
 284    result
 285}
 286
 287impl AgentMessage {
 288    pub fn to_markdown(&self) -> String {
 289        let mut markdown = String::from("## Assistant\n\n");
 290
 291        for content in &self.content {
 292            match content {
 293                AgentMessageContent::Text(text) => {
 294                    markdown.push_str(text);
 295                    markdown.push('\n');
 296                }
 297                AgentMessageContent::Thinking { text, .. } => {
 298                    markdown.push_str("<think>");
 299                    markdown.push_str(text);
 300                    markdown.push_str("</think>\n");
 301                }
 302                AgentMessageContent::RedactedThinking(_) => {
 303                    markdown.push_str("<redacted_thinking />\n")
 304                }
 305                AgentMessageContent::ToolUse(tool_use) => {
 306                    markdown.push_str(&format!(
 307                        "**Tool Use**: {} (ID: {})\n",
 308                        tool_use.name, tool_use.id
 309                    ));
 310                    markdown.push_str(&format!(
 311                        "{}\n",
 312                        MarkdownCodeBlock {
 313                            tag: "json",
 314                            text: &format!("{:#}", tool_use.input)
 315                        }
 316                    ));
 317                }
 318            }
 319        }
 320
 321        for tool_result in self.tool_results.values() {
 322            markdown.push_str(&format!(
 323                "**Tool Result**: {} (ID: {})\n\n",
 324                tool_result.tool_name, tool_result.tool_use_id
 325            ));
 326            if tool_result.is_error {
 327                markdown.push_str("**ERROR:**\n");
 328            }
 329
 330            match &tool_result.content {
 331                LanguageModelToolResultContent::Text(text) => {
 332                    writeln!(markdown, "{text}\n").ok();
 333                }
 334                LanguageModelToolResultContent::Image(_) => {
 335                    writeln!(markdown, "<image />\n").ok();
 336                }
 337            }
 338
 339            if let Some(output) = tool_result.output.as_ref() {
 340                writeln!(
 341                    markdown,
 342                    "**Debug Output**:\n\n```json\n{}\n```\n",
 343                    serde_json::to_string_pretty(output).unwrap()
 344                )
 345                .unwrap();
 346            }
 347        }
 348
 349        markdown
 350    }
 351
 352    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 353        let mut assistant_message = LanguageModelRequestMessage {
 354            role: Role::Assistant,
 355            content: Vec::with_capacity(self.content.len()),
 356            cache: false,
 357        };
 358        for chunk in &self.content {
 359            let chunk = match chunk {
 360                AgentMessageContent::Text(text) => {
 361                    language_model::MessageContent::Text(text.clone())
 362                }
 363                AgentMessageContent::Thinking { text, signature } => {
 364                    language_model::MessageContent::Thinking {
 365                        text: text.clone(),
 366                        signature: signature.clone(),
 367                    }
 368                }
 369                AgentMessageContent::RedactedThinking(value) => {
 370                    language_model::MessageContent::RedactedThinking(value.clone())
 371                }
 372                AgentMessageContent::ToolUse(value) => {
 373                    language_model::MessageContent::ToolUse(value.clone())
 374                }
 375            };
 376            assistant_message.content.push(chunk);
 377        }
 378
 379        let mut user_message = LanguageModelRequestMessage {
 380            role: Role::User,
 381            content: Vec::new(),
 382            cache: false,
 383        };
 384
 385        for tool_result in self.tool_results.values() {
 386            user_message
 387                .content
 388                .push(language_model::MessageContent::ToolResult(
 389                    tool_result.clone(),
 390                ));
 391        }
 392
 393        let mut messages = Vec::new();
 394        if !assistant_message.content.is_empty() {
 395            messages.push(assistant_message);
 396        }
 397        if !user_message.content.is_empty() {
 398            messages.push(user_message);
 399        }
 400        messages
 401    }
 402}
 403
/// A message produced by the agent, together with the results of the tools it invoked.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Ordered content chunks (text, thinking, tool uses).
    pub content: Vec<AgentMessageContent>,
    /// Tool results keyed by the ID of the tool use they answer.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
 409
/// One chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Plain response text.
    Text(String),
    /// Thinking text, rendered inside `<think>` tags in Markdown.
    Thinking {
        text: String,
        /// Optional signature that is passed back to the model unchanged in requests.
        signature: Option<String>,
    },
    /// Redacted thinking; rendered as a placeholder in Markdown and skipped on replay.
    RedactedThinking(String),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
}
 420
/// Events emitted by a [`Thread`], e.g. over the channel returned by `Thread::replay`.
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message.
    UserMessage(UserMessage),
    /// Text produced by the agent.
    AgentText(String),
    /// Thinking text produced by the agent.
    AgentThinking(String),
    /// A tool call.
    ToolCall(acp::ToolCall),
    /// An update to a previously reported tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for a permission decision on a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// Carries a stop reason.
    // NOTE(review): presumably marks the end of a turn — confirm where it is sent.
    Stop(acp::StopReason),
}
 431
/// A permission request for a tool call, carried by `ThreadEvent::ToolCallAuthorization`.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call the request is about.
    pub tool_call: acp::ToolCallUpdate,
    /// The permission options offered.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which a permission option ID is sent back.
    // NOTE(review): presumably the ID of the option the user picked — confirm with the receiver.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 438
/// State of a thread's title.
enum ThreadTitle {
    /// No title has been set.
    None,
    /// Holds a task while the title is pending.
    // NOTE(review): presumably the task generating the title — the producer is outside this view.
    Pending(Task<()>),
    /// The final outcome: the title, or the error that prevented producing one.
    Done(Result<SharedString>),
}
 444
 445impl ThreadTitle {
 446    pub fn unwrap_or_default(&self) -> SharedString {
 447        if let ThreadTitle::Done(Ok(title)) = self {
 448            title.clone()
 449        } else {
 450            "New Thread".into()
 451        }
 452    }
 453}
 454
/// An agent conversation: its message history, the model and tools in use, and the
/// bookkeeping needed to persist it (`to_db`) and restore it (`from_db`).
pub struct Thread {
    /// Session ID identifying this thread.
    id: acp::SessionId,
    /// ID of the current user prompt; a fresh one is generated on construction.
    prompt_id: PromptId,
    /// Last-updated timestamp; persisted with the thread.
    updated_at: DateTime<Utc>,
    /// Title state; `"New Thread"` is used when no finished title exists.
    title: ThreadTitle,
    /// Detailed-summary state; persisted with the thread.
    summary: DetailedSummaryState,
    /// Message history.
    messages: Vec<Message>,
    /// Completion mode; starts as `CompletionMode::Normal`.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<RunningTurn>,
    /// Agent message that is not yet part of `messages`; flushed when the thread is canceled.
    pending_message: Option<AgentMessage>,
    /// Registered tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Must be `true` for `resume` to be allowed.
    tool_use_limit_reached: bool,
    /// Token usage entries; persisted with the thread.
    // NOTE(review): presumably one entry per request — confirm where it is appended.
    request_token_usage: Vec<TokenUsage>,
    /// Cumulative token usage; persisted with the thread.
    cumulative_token_usage: TokenUsage,
    /// Project snapshot taken when the thread was created (or loaded from the database).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    context_server_registry: Entity<ContextServerRegistry>,
    /// Active agent profile; defaults to the profile from `AgentSettings`.
    profile_id: AgentProfileId,
    project_context: Rc<RefCell<ProjectContext>>,
    templates: Arc<Templates>,
    /// Language model used by this thread.
    model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
 481
 482impl Thread {
 483    pub fn new(
 484        id: acp::SessionId,
 485        project: Entity<Project>,
 486        project_context: Rc<RefCell<ProjectContext>>,
 487        context_server_registry: Entity<ContextServerRegistry>,
 488        action_log: Entity<ActionLog>,
 489        templates: Arc<Templates>,
 490        model: Arc<dyn LanguageModel>,
 491        cx: &mut Context<Self>,
 492    ) -> Self {
 493        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 494        Self {
 495            id,
 496            prompt_id: PromptId::new(),
 497            updated_at: Utc::now(),
 498            title: ThreadTitle::None,
 499            summary: DetailedSummaryState::default(),
 500            messages: Vec::new(),
 501            completion_mode: CompletionMode::Normal,
 502            running_turn: None,
 503            pending_message: None,
 504            tools: BTreeMap::default(),
 505            tool_use_limit_reached: false,
 506            request_token_usage: Vec::new(),
 507            cumulative_token_usage: TokenUsage::default(),
 508            initial_project_snapshot: {
 509                let project_snapshot = Self::project_snapshot(project.clone(), cx);
 510                cx.foreground_executor()
 511                    .spawn(async move { Some(project_snapshot.await) })
 512                    .shared()
 513            },
 514            context_server_registry,
 515            profile_id,
 516            project_context,
 517            templates,
 518            model,
 519            project,
 520            action_log,
 521        }
 522    }
 523
    /// Returns the session ID of this thread.
    pub fn id(&self) -> &acp::SessionId {
        &self.id
    }
 527
 528    pub fn from_db(
 529        id: acp::SessionId,
 530        db_thread: DbThread,
 531        project: Entity<Project>,
 532        project_context: Rc<RefCell<ProjectContext>>,
 533        context_server_registry: Entity<ContextServerRegistry>,
 534        action_log: Entity<ActionLog>,
 535        templates: Arc<Templates>,
 536        model: Arc<dyn LanguageModel>,
 537        cx: &mut Context<Self>,
 538    ) -> Self {
 539        let profile_id = db_thread
 540            .profile
 541            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 542        Self {
 543            id,
 544            prompt_id: PromptId::new(),
 545            title: ThreadTitle::Done(Ok(db_thread.title.clone())),
 546            summary: db_thread.summary,
 547            messages: db_thread.messages,
 548            completion_mode: CompletionMode::Normal,
 549            running_turn: None,
 550            pending_message: None,
 551            tools: BTreeMap::default(),
 552            tool_use_limit_reached: false,
 553            request_token_usage: db_thread.request_token_usage.clone(),
 554            cumulative_token_usage: db_thread.cumulative_token_usage.clone(),
 555            initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
 556            context_server_registry,
 557            profile_id,
 558            project_context,
 559            templates,
 560            model,
 561            project,
 562            action_log,
 563            updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
 564        }
 565    }
 566
 567    pub fn to_db(&self, cx: &App) -> Task<DbThread> {
 568        let initial_project_snapshot = self.initial_project_snapshot.clone();
 569        let mut thread = DbThread {
 570            title: self.title.unwrap_or_default(),
 571            messages: self.messages.clone(),
 572            updated_at: self.updated_at.clone(),
 573            summary: self.summary.clone(),
 574            initial_project_snapshot: None,
 575            cumulative_token_usage: self.cumulative_token_usage.clone(),
 576            request_token_usage: self.request_token_usage.clone(),
 577            model: Some(DbLanguageModel {
 578                provider: self.model.provider_id().to_string(),
 579                model: self.model.name().0.to_string(),
 580            }),
 581            completion_mode: Some(self.completion_mode.into()),
 582            profile: Some(self.profile_id.clone()),
 583        };
 584
 585        cx.background_spawn(async move {
 586            let initial_project_snapshot = initial_project_snapshot.await;
 587            thread.initial_project_snapshot = initial_project_snapshot;
 588            thread
 589        })
 590    }
 591
 592    pub fn replay(
 593        &mut self,
 594        cx: &mut Context<Self>,
 595    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 596        let (tx, rx) = mpsc::unbounded();
 597        let stream = ThreadEventStream(tx);
 598        for message in &self.messages {
 599            match message {
 600                Message::User(user_message) => stream.send_user_message(&user_message),
 601                Message::Agent(assistant_message) => {
 602                    for content in &assistant_message.content {
 603                        match content {
 604                            AgentMessageContent::Text(text) => stream.send_text(text),
 605                            AgentMessageContent::Thinking { text, .. } => {
 606                                stream.send_thinking(text)
 607                            }
 608                            AgentMessageContent::RedactedThinking(_) => {}
 609                            AgentMessageContent::ToolUse(tool_use) => {
 610                                self.replay_tool_call(
 611                                    tool_use,
 612                                    assistant_message.tool_results.get(&tool_use.id),
 613                                    &stream,
 614                                    cx,
 615                                );
 616                            }
 617                        }
 618                    }
 619                }
 620                Message::Resume => {}
 621            }
 622        }
 623        rx
 624    }
 625
 626    fn replay_tool_call(
 627        &self,
 628        tool_use: &LanguageModelToolUse,
 629        tool_result: Option<&LanguageModelToolResult>,
 630        stream: &ThreadEventStream,
 631        cx: &mut Context<Self>,
 632    ) {
 633        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 634            stream
 635                .0
 636                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 637                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 638                    title: tool_use.name.to_string(),
 639                    kind: acp::ToolKind::Other,
 640                    status: acp::ToolCallStatus::Failed,
 641                    content: Vec::new(),
 642                    locations: Vec::new(),
 643                    raw_input: Some(tool_use.input.clone()),
 644                    raw_output: None,
 645                })))
 646                .ok();
 647            return;
 648        };
 649
 650        let title = tool.initial_title(tool_use.input.clone());
 651        let kind = tool.kind();
 652        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 653
 654        let output = tool_result
 655            .as_ref()
 656            .and_then(|result| result.output.clone());
 657        if let Some(output) = output.clone() {
 658            let tool_event_stream = ToolCallEventStream::new(
 659                tool_use.id.clone(),
 660                stream.clone(),
 661                Some(self.project.read(cx).fs().clone()),
 662            );
 663            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 664                .log_err();
 665        }
 666
 667        stream.update_tool_call_fields(
 668            &tool_use.id,
 669            acp::ToolCallUpdateFields {
 670                status: Some(acp::ToolCallStatus::Completed),
 671                raw_output: output,
 672                ..Default::default()
 673            },
 674        );
 675    }
 676
 677    /// Create a snapshot of the current project state including git information and unsaved buffers.
 678    fn project_snapshot(
 679        project: Entity<Project>,
 680        cx: &mut Context<Self>,
 681    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
 682        let git_store = project.read(cx).git_store().clone();
 683        let worktree_snapshots: Vec<_> = project
 684            .read(cx)
 685            .visible_worktrees(cx)
 686            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
 687            .collect();
 688
 689        cx.spawn(async move |_, cx| {
 690            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
 691
 692            let mut unsaved_buffers = Vec::new();
 693            cx.update(|app_cx| {
 694                let buffer_store = project.read(app_cx).buffer_store();
 695                for buffer_handle in buffer_store.read(app_cx).buffers() {
 696                    let buffer = buffer_handle.read(app_cx);
 697                    if buffer.is_dirty() {
 698                        if let Some(file) = buffer.file() {
 699                            let path = file.path().to_string_lossy().to_string();
 700                            unsaved_buffers.push(path);
 701                        }
 702                    }
 703                }
 704            })
 705            .ok();
 706
 707            Arc::new(ProjectSnapshot {
 708                worktree_snapshots,
 709                unsaved_buffer_paths: unsaved_buffers,
 710                timestamp: Utc::now(),
 711            })
 712        })
 713    }
 714
 715    fn worktree_snapshot(
 716        worktree: Entity<project::Worktree>,
 717        git_store: Entity<GitStore>,
 718        cx: &App,
 719    ) -> Task<agent::thread::WorktreeSnapshot> {
 720        cx.spawn(async move |cx| {
 721            // Get worktree path and snapshot
 722            let worktree_info = cx.update(|app_cx| {
 723                let worktree = worktree.read(app_cx);
 724                let path = worktree.abs_path().to_string_lossy().to_string();
 725                let snapshot = worktree.snapshot();
 726                (path, snapshot)
 727            });
 728
 729            let Ok((worktree_path, _snapshot)) = worktree_info else {
 730                return WorktreeSnapshot {
 731                    worktree_path: String::new(),
 732                    git_state: None,
 733                };
 734            };
 735
 736            let git_state = git_store
 737                .update(cx, |git_store, cx| {
 738                    git_store
 739                        .repositories()
 740                        .values()
 741                        .find(|repo| {
 742                            repo.read(cx)
 743                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
 744                                .is_some()
 745                        })
 746                        .cloned()
 747                })
 748                .ok()
 749                .flatten()
 750                .map(|repo| {
 751                    repo.update(cx, |repo, _| {
 752                        let current_branch =
 753                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
 754                        repo.send_job(None, |state, _| async move {
 755                            let RepositoryState::Local { backend, .. } = state else {
 756                                return GitState {
 757                                    remote_url: None,
 758                                    head_sha: None,
 759                                    current_branch,
 760                                    diff: None,
 761                                };
 762                            };
 763
 764                            let remote_url = backend.remote_url("origin");
 765                            let head_sha = backend.head_sha().await;
 766                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
 767
 768                            GitState {
 769                                remote_url,
 770                                head_sha,
 771                                current_branch,
 772                                diff,
 773                            }
 774                        })
 775                    })
 776                });
 777
 778            let git_state = match git_state {
 779                Some(git_state) => match git_state.ok() {
 780                    Some(git_state) => git_state.await.ok(),
 781                    None => None,
 782                },
 783                None => None,
 784            };
 785
 786            WorktreeSnapshot {
 787                worktree_path,
 788                git_state,
 789            }
 790        })
 791    }
 792
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 796
    /// Returns the action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 800
    /// Returns the language model currently used by this thread.
    pub fn model(&self) -> &Arc<dyn LanguageModel> {
        &self.model
    }
 804
    /// Replaces the language model and notifies observers of this thread via `cx.notify()`.
    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
        self.model = model;
        cx.notify()
    }
 809
    /// Returns the current completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
 813
    /// Replaces the thread's completion mode with `mode` and calls `cx.notify()`.
    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
        self.completion_mode = mode;
        cx.notify()
    }
 818
 819    #[cfg(any(test, feature = "test-support"))]
 820    pub fn last_message(&self) -> Option<Message> {
 821        if let Some(message) = self.pending_message.clone() {
 822            Some(Message::Agent(message))
 823        } else {
 824            self.messages.last().cloned()
 825        }
 826    }
 827
    /// Inserts `tool` into `self.tools`, keyed by `tool.name()`, after
    /// type-erasing it with `AgentTool::erase`.
    pub fn add_tool(&mut self, tool: impl AgentTool) {
        self.tools.insert(tool.name(), tool.erase());
    }
 831
    /// Removes the entry stored under `name` from `self.tools`.
    ///
    /// Returns `true` if an entry was present and removed, `false` otherwise.
    pub fn remove_tool(&mut self, name: &str) -> bool {
        self.tools.remove(name).is_some()
    }
 835
    /// Returns the id of the agent profile this thread uses.
    pub fn profile(&self) -> &AgentProfileId {
        &self.profile_id
    }
 839
    /// Replaces the id of the agent profile this thread uses.
    ///
    /// Unlike `set_model` and `set_completion_mode`, this takes no context and
    /// therefore does not call `cx.notify()`.
    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
        self.profile_id = profile_id;
    }
 843
    /// Cancels the running turn, if any, and flushes the pending agent message.
    ///
    /// The running turn is taken out of `self.running_turn` and consumed by
    /// `RunningTurn::cancel`. The pending message is flushed unconditionally, so
    /// calling this with no turn in progress is harmless.
    pub fn cancel(&mut self, cx: &mut Context<Self>) {
        if let Some(running_turn) = self.running_turn.take() {
            running_turn.cancel();
        }
        // Persist whatever the agent produced so far into `self.messages`.
        self.flush_pending_message(cx);
    }
 850
 851    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
 852        self.cancel(cx);
 853        let Some(position) = self.messages.iter().position(
 854            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 855        ) else {
 856            return Err(anyhow!("Message not found"));
 857        };
 858        self.messages.truncate(position);
 859        cx.notify();
 860        Ok(())
 861    }
 862
 863    pub fn resume(
 864        &mut self,
 865        cx: &mut Context<Self>,
 866    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
 867        anyhow::ensure!(
 868            self.tool_use_limit_reached,
 869            "can only resume after tool use limit is reached"
 870        );
 871
 872        self.messages.push(Message::Resume);
 873        cx.notify();
 874
 875        log::info!("Total messages in thread: {}", self.messages.len());
 876        Ok(self.run_turn(cx))
 877    }
 878
 879    /// Sending a message results in the model streaming a response, which could include tool calls.
 880    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 881    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 882    pub fn send<T>(
 883        &mut self,
 884        id: UserMessageId,
 885        content: impl IntoIterator<Item = T>,
 886        cx: &mut Context<Self>,
 887    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
 888    where
 889        T: Into<UserMessageContent>,
 890    {
 891        log::info!("Thread::send called with model: {:?}", self.model.name());
 892        self.advance_prompt_id();
 893
 894        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
 895        log::debug!("Thread::send content: {:?}", content);
 896
 897        self.messages
 898            .push(Message::User(UserMessage { id, content }));
 899        cx.notify();
 900
 901        log::info!("Total messages in thread: {}", self.messages.len());
 902        self.run_turn(cx)
 903    }
 904
 905    fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 906        self.cancel(cx);
 907
 908        let model = self.model.clone();
 909        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
 910        let event_stream = ThreadEventStream(events_tx);
 911        let message_ix = self.messages.len().saturating_sub(1);
 912        self.tool_use_limit_reached = false;
 913        self.running_turn = Some(RunningTurn {
 914            event_stream: event_stream.clone(),
 915            _task: cx.spawn(async move |this, cx| {
 916                log::info!("Starting agent turn execution");
 917                let turn_result: Result<()> = async {
 918                    let mut completion_intent = CompletionIntent::UserPrompt;
 919                    loop {
 920                        log::debug!(
 921                            "Building completion request with intent: {:?}",
 922                            completion_intent
 923                        );
 924                        let request = this.update(cx, |this, cx| {
 925                            this.build_completion_request(completion_intent, cx)
 926                        })?;
 927
 928                        log::info!("Calling model.stream_completion");
 929                        let mut events = model.stream_completion(request, cx).await?;
 930                        log::debug!("Stream completion started successfully");
 931
 932                        let mut tool_use_limit_reached = false;
 933                        let mut tool_uses = FuturesUnordered::new();
 934                        while let Some(event) = events.next().await {
 935                            match event? {
 936                                LanguageModelCompletionEvent::StatusUpdate(
 937                                    CompletionRequestStatus::ToolUseLimitReached,
 938                                ) => {
 939                                    tool_use_limit_reached = true;
 940                                }
 941                                LanguageModelCompletionEvent::Stop(reason) => {
 942                                    event_stream.send_stop(reason);
 943                                    if reason == StopReason::Refusal {
 944                                        this.update(cx, |this, cx| {
 945                                            this.flush_pending_message(cx);
 946                                            this.messages.truncate(message_ix);
 947                                        })?;
 948                                        return Ok(());
 949                                    }
 950                                }
 951                                event => {
 952                                    log::trace!("Received completion event: {:?}", event);
 953                                    this.update(cx, |this, cx| {
 954                                        tool_uses.extend(this.handle_streamed_completion_event(
 955                                            event,
 956                                            &event_stream,
 957                                            cx,
 958                                        ));
 959                                    })
 960                                    .ok();
 961                                }
 962                            }
 963                        }
 964
 965                        let used_tools = tool_uses.is_empty();
 966                        while let Some(tool_result) = tool_uses.next().await {
 967                            log::info!("Tool finished {:?}", tool_result);
 968
 969                            event_stream.update_tool_call_fields(
 970                                &tool_result.tool_use_id,
 971                                acp::ToolCallUpdateFields {
 972                                    status: Some(if tool_result.is_error {
 973                                        acp::ToolCallStatus::Failed
 974                                    } else {
 975                                        acp::ToolCallStatus::Completed
 976                                    }),
 977                                    raw_output: tool_result.output.clone(),
 978                                    ..Default::default()
 979                                },
 980                            );
 981                            this.update(cx, |this, _cx| {
 982                                this.pending_message()
 983                                    .tool_results
 984                                    .insert(tool_result.tool_use_id.clone(), tool_result);
 985                            })
 986                            .ok();
 987                        }
 988
 989                        if tool_use_limit_reached {
 990                            log::info!("Tool use limit reached, completing turn");
 991                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
 992                            return Err(language_model::ToolUseLimitReachedError.into());
 993                        } else if used_tools {
 994                            log::info!("No tool uses found, completing turn");
 995                            return Ok(());
 996                        } else {
 997                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
 998                            completion_intent = CompletionIntent::ToolResults;
 999                        }
1000                    }
1001                }
1002                .await;
1003
1004                if let Err(error) = turn_result {
1005                    log::error!("Turn execution failed: {:?}", error);
1006                    event_stream.send_error(error);
1007                } else {
1008                    log::info!("Turn execution completed successfully");
1009                }
1010
1011                this.update(cx, |this, cx| {
1012                    this.flush_pending_message(cx);
1013                    this.running_turn.take();
1014                })
1015                .ok();
1016            }),
1017        });
1018        events_rx
1019    }
1020
1021    pub fn build_system_message(&self) -> LanguageModelRequestMessage {
1022        log::debug!("Building system message");
1023        let prompt = SystemPromptTemplate {
1024            project: &self.project_context.borrow(),
1025            available_tools: self.tools.keys().cloned().collect(),
1026        }
1027        .render(&self.templates)
1028        .context("failed to build system prompt")
1029        .expect("Invalid template");
1030        log::debug!("System message built");
1031        LanguageModelRequestMessage {
1032            role: Role::System,
1033            content: vec![prompt.into()],
1034            cache: true,
1035        }
1036    }
1037
1038    /// A helper method that's called on every streamed completion event.
1039    /// Returns an optional tool result task, which the main agentic loop in
1040    /// send will send back to the model when it resolves.
1041    fn handle_streamed_completion_event(
1042        &mut self,
1043        event: LanguageModelCompletionEvent,
1044        event_stream: &ThreadEventStream,
1045        cx: &mut Context<Self>,
1046    ) -> Option<Task<LanguageModelToolResult>> {
1047        log::trace!("Handling streamed completion event: {:?}", event);
1048        use LanguageModelCompletionEvent::*;
1049
1050        match event {
1051            StartMessage { .. } => {
1052                self.flush_pending_message(cx);
1053                self.pending_message = Some(AgentMessage::default());
1054            }
1055            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1056            Thinking { text, signature } => {
1057                self.handle_thinking_event(text, signature, event_stream, cx)
1058            }
1059            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1060            ToolUse(tool_use) => {
1061                return self.handle_tool_use_event(tool_use, event_stream, cx);
1062            }
1063            ToolUseJsonParseError {
1064                id,
1065                tool_name,
1066                raw_input,
1067                json_parse_error,
1068            } => {
1069                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1070                    id,
1071                    tool_name,
1072                    raw_input,
1073                    json_parse_error,
1074                )));
1075            }
1076            UsageUpdate(_) | StatusUpdate(_) => {}
1077            Stop(_) => unreachable!(),
1078        }
1079
1080        None
1081    }
1082
1083    fn handle_text_event(
1084        &mut self,
1085        new_text: String,
1086        event_stream: &ThreadEventStream,
1087        cx: &mut Context<Self>,
1088    ) {
1089        event_stream.send_text(&new_text);
1090
1091        let last_message = self.pending_message();
1092        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1093            text.push_str(&new_text);
1094        } else {
1095            last_message
1096                .content
1097                .push(AgentMessageContent::Text(new_text));
1098        }
1099
1100        cx.notify();
1101    }
1102
1103    fn handle_thinking_event(
1104        &mut self,
1105        new_text: String,
1106        new_signature: Option<String>,
1107        event_stream: &ThreadEventStream,
1108        cx: &mut Context<Self>,
1109    ) {
1110        event_stream.send_thinking(&new_text);
1111
1112        let last_message = self.pending_message();
1113        if let Some(AgentMessageContent::Thinking { text, signature }) =
1114            last_message.content.last_mut()
1115        {
1116            text.push_str(&new_text);
1117            *signature = new_signature.or(signature.take());
1118        } else {
1119            last_message.content.push(AgentMessageContent::Thinking {
1120                text: new_text,
1121                signature: new_signature,
1122            });
1123        }
1124
1125        cx.notify();
1126    }
1127
1128    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1129        let last_message = self.pending_message();
1130        last_message
1131            .content
1132            .push(AgentMessageContent::RedactedThinking(data));
1133        cx.notify();
1134    }
1135
1136    fn handle_tool_use_event(
1137        &mut self,
1138        tool_use: LanguageModelToolUse,
1139        event_stream: &ThreadEventStream,
1140        cx: &mut Context<Self>,
1141    ) -> Option<Task<LanguageModelToolResult>> {
1142        cx.notify();
1143
1144        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1145        let mut title = SharedString::from(&tool_use.name);
1146        let mut kind = acp::ToolKind::Other;
1147        if let Some(tool) = tool.as_ref() {
1148            title = tool.initial_title(tool_use.input.clone());
1149            kind = tool.kind();
1150        }
1151
1152        // Ensure the last message ends in the current tool use
1153        let last_message = self.pending_message();
1154        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
1155            if let AgentMessageContent::ToolUse(last_tool_use) = content {
1156                if last_tool_use.id == tool_use.id {
1157                    *last_tool_use = tool_use.clone();
1158                    false
1159                } else {
1160                    true
1161                }
1162            } else {
1163                true
1164            }
1165        });
1166
1167        if push_new_tool_use {
1168            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1169            last_message
1170                .content
1171                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1172        } else {
1173            event_stream.update_tool_call_fields(
1174                &tool_use.id,
1175                acp::ToolCallUpdateFields {
1176                    title: Some(title.into()),
1177                    kind: Some(kind),
1178                    raw_input: Some(tool_use.input.clone()),
1179                    ..Default::default()
1180                },
1181            );
1182        }
1183
1184        if !tool_use.is_input_complete {
1185            return None;
1186        }
1187
1188        let Some(tool) = tool else {
1189            let content = format!("No tool named {} exists", tool_use.name);
1190            return Some(Task::ready(LanguageModelToolResult {
1191                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1192                tool_use_id: tool_use.id,
1193                tool_name: tool_use.name,
1194                is_error: true,
1195                output: None,
1196            }));
1197        };
1198
1199        let fs = self.project.read(cx).fs().clone();
1200        let tool_event_stream =
1201            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1202        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1203            status: Some(acp::ToolCallStatus::InProgress),
1204            ..Default::default()
1205        });
1206        let supports_images = self.model.supports_images();
1207        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1208        log::info!("Running tool {}", tool_use.name);
1209        Some(cx.foreground_executor().spawn(async move {
1210            let tool_result = tool_result.await.and_then(|output| {
1211                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1212                    if !supports_images {
1213                        return Err(anyhow!(
1214                            "Attempted to read an image, but this model doesn't support it.",
1215                        ));
1216                    }
1217                }
1218                Ok(output)
1219            });
1220
1221            match tool_result {
1222                Ok(output) => LanguageModelToolResult {
1223                    tool_use_id: tool_use.id,
1224                    tool_name: tool_use.name,
1225                    is_error: false,
1226                    content: output.llm_output,
1227                    output: Some(output.raw_output),
1228                },
1229                Err(error) => LanguageModelToolResult {
1230                    tool_use_id: tool_use.id,
1231                    tool_name: tool_use.name,
1232                    is_error: true,
1233                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1234                    output: None,
1235                },
1236            }
1237        }))
1238    }
1239
1240    fn handle_tool_use_json_parse_error_event(
1241        &mut self,
1242        tool_use_id: LanguageModelToolUseId,
1243        tool_name: Arc<str>,
1244        raw_input: Arc<str>,
1245        json_parse_error: String,
1246    ) -> LanguageModelToolResult {
1247        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1248        LanguageModelToolResult {
1249            tool_use_id,
1250            tool_name,
1251            is_error: true,
1252            content: LanguageModelToolResultContent::Text(tool_output.into()),
1253            output: Some(serde_json::Value::String(raw_input.to_string())),
1254        }
1255    }
1256
    /// Returns a mutable reference to the pending agent message, first creating
    /// a default one if none exists.
    fn pending_message(&mut self) -> &mut AgentMessage {
        self.pending_message.get_or_insert_default()
    }
1260
1261    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1262        let Some(mut message) = self.pending_message.take() else {
1263            return;
1264        };
1265
1266        for content in &message.content {
1267            let AgentMessageContent::ToolUse(tool_use) = content else {
1268                continue;
1269            };
1270
1271            if !message.tool_results.contains_key(&tool_use.id) {
1272                message.tool_results.insert(
1273                    tool_use.id.clone(),
1274                    LanguageModelToolResult {
1275                        tool_use_id: tool_use.id.clone(),
1276                        tool_name: tool_use.name.clone(),
1277                        is_error: true,
1278                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1279                        output: None,
1280                    },
1281                );
1282            }
1283        }
1284
1285        self.messages.push(Message::Agent(message));
1286        dbg!("!!!!!!!!!!!!!!!!!!!!!!!");
1287        cx.notify()
1288    }
1289
1290    pub(crate) fn build_completion_request(
1291        &self,
1292        completion_intent: CompletionIntent,
1293        cx: &mut App,
1294    ) -> LanguageModelRequest {
1295        log::debug!("Building completion request");
1296        log::debug!("Completion intent: {:?}", completion_intent);
1297        log::debug!("Completion mode: {:?}", self.completion_mode);
1298
1299        let messages = self.build_request_messages();
1300        log::info!("Request will include {} messages", messages.len());
1301
1302        let tools = if let Some(tools) = self.tools(cx).log_err() {
1303            tools
1304                .filter_map(|tool| {
1305                    let tool_name = tool.name().to_string();
1306                    log::trace!("Including tool: {}", tool_name);
1307                    Some(LanguageModelRequestTool {
1308                        name: tool_name,
1309                        description: tool.description().to_string(),
1310                        input_schema: tool
1311                            .input_schema(self.model.tool_input_format())
1312                            .log_err()?,
1313                    })
1314                })
1315                .collect()
1316        } else {
1317            Vec::new()
1318        };
1319
1320        log::info!("Request includes {} tools", tools.len());
1321
1322        let request = LanguageModelRequest {
1323            thread_id: Some(self.id.to_string()),
1324            prompt_id: Some(self.prompt_id.to_string()),
1325            intent: Some(completion_intent),
1326            mode: Some(self.completion_mode.into()),
1327            messages,
1328            tools,
1329            tool_choice: None,
1330            stop: Vec::new(),
1331            temperature: AgentSettings::temperature_for_model(self.model(), cx),
1332            thinking_allowed: true,
1333        };
1334
1335        log::debug!("Completion request built successfully");
1336        request
1337    }
1338
1339    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1340        let profile = AgentSettings::get_global(cx)
1341            .profiles
1342            .get(&self.profile_id)
1343            .context("profile not found")?;
1344        let provider_id = self.model.provider_id();
1345
1346        Ok(self
1347            .tools
1348            .iter()
1349            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1350            .filter_map(|(tool_name, tool)| {
1351                if profile.is_tool_enabled(tool_name) {
1352                    Some(tool)
1353                } else {
1354                    None
1355                }
1356            })
1357            .chain(self.context_server_registry.read(cx).servers().flat_map(
1358                |(server_id, tools)| {
1359                    tools.iter().filter_map(|(tool_name, tool)| {
1360                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1361                            Some(tool)
1362                        } else {
1363                            None
1364                        }
1365                    })
1366                },
1367            )))
1368    }
1369
1370    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1371        log::trace!(
1372            "Building request messages from {} thread messages",
1373            self.messages.len()
1374        );
1375        let mut messages = vec![self.build_system_message()];
1376        for message in &self.messages {
1377            match message {
1378                Message::User(message) => messages.push(message.to_request()),
1379                Message::Agent(message) => messages.extend(message.to_request()),
1380                Message::Resume => messages.push(LanguageModelRequestMessage {
1381                    role: Role::User,
1382                    content: vec!["Continue where you left off".into()],
1383                    cache: false,
1384                }),
1385            }
1386        }
1387
1388        if let Some(message) = self.pending_message.as_ref() {
1389            messages.extend(message.to_request());
1390        }
1391
1392        if let Some(last_user_message) = messages
1393            .iter_mut()
1394            .rev()
1395            .find(|message| message.role == Role::User)
1396        {
1397            last_user_message.cache = true;
1398        }
1399
1400        messages
1401    }
1402
1403    pub fn to_markdown(&self) -> String {
1404        let mut markdown = String::new();
1405        for (ix, message) in self.messages.iter().enumerate() {
1406            if ix > 0 {
1407                markdown.push('\n');
1408            }
1409            markdown.push_str(&message.to_markdown());
1410        }
1411
1412        if let Some(message) = self.pending_message.as_ref() {
1413            markdown.push('\n');
1414            markdown.push_str(&message.to_markdown());
1415        }
1416
1417        markdown
1418    }
1419
    /// Replaces the current prompt id with a freshly created one.
    fn advance_prompt_id(&mut self) {
        self.prompt_id = PromptId::new();
    }
1423}
1424
/// State kept on the thread while an agent turn is in progress.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1434
impl RunningTurn {
    /// Logs the cancellation and calls `send_canceled` on the turn's event stream.
    ///
    /// Takes `self` by value, so the turn (including `_task`) is dropped when
    /// this returns.
    fn cancel(self) {
        log::debug!("Cancelling in progress turn");
        self.event_stream.send_canceled();
    }
}
1441
1442pub trait AgentTool
1443where
1444    Self: 'static + Sized,
1445{
1446    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1447    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1448
1449    fn name(&self) -> SharedString;
1450
1451    fn description(&self) -> SharedString {
1452        let schema = schemars::schema_for!(Self::Input);
1453        SharedString::new(
1454            schema
1455                .get("description")
1456                .and_then(|description| description.as_str())
1457                .unwrap_or_default(),
1458        )
1459    }
1460
1461    fn kind(&self) -> acp::ToolKind;
1462
1463    /// The initial tool title to display. Can be updated during the tool run.
1464    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1465
1466    /// Returns the JSON schema that describes the tool's input.
1467    fn input_schema(&self) -> Schema {
1468        schemars::schema_for!(Self::Input)
1469    }
1470
1471    /// Some tools rely on a provider for the underlying billing or other reasons.
1472    /// Allow the tool to check if they are compatible, or should be filtered out.
1473    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1474        true
1475    }
1476
1477    /// Runs the tool with the provided input.
1478    fn run(
1479        self: Arc<Self>,
1480        input: Self::Input,
1481        event_stream: ToolCallEventStream,
1482        cx: &mut App,
1483    ) -> Task<Result<Self::Output>>;
1484
1485    /// Emits events for a previous execution of the tool.
1486    fn replay(
1487        &self,
1488        _input: Self::Input,
1489        _output: Self::Output,
1490        _event_stream: ToolCallEventStream,
1491        _cx: &mut App,
1492    ) -> Result<()> {
1493        Ok(())
1494    }
1495
1496    fn erase(self) -> Arc<dyn AnyAgentTool> {
1497        Arc::new(Erased(Arc::new(self)))
1498    }
1499}
1500
/// Wrapper produced by `AgentTool::erase`; `Erased<Arc<T>>` implements
/// `AnyAgentTool` for any `T: AgentTool`.
pub struct Erased<T>(T);
1502
/// The result of running a type-erased tool (see `AnyAgentTool::run`).
pub struct AgentToolOutput {
    /// The tool output converted into content for the language model.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool output serialized as JSON.
    pub raw_output: serde_json::Value,
}
1507
/// Type-erased counterpart of `AgentTool` that exchanges inputs and outputs as
/// `serde_json::Value`, so tools can be stored as `Arc<dyn AnyAgentTool>`.
pub trait AnyAgentTool {
    /// The name under which the tool is registered.
    fn name(&self) -> SharedString;
    /// The tool's description.
    fn description(&self) -> SharedString;
    /// The kind of tool, as reported in tool call events.
    fn kind(&self) -> acp::ToolKind;
    /// The initial tool title to display for the given raw JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// Returns the tool's input schema as JSON for the given schema format.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be used with the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with the provided raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Emits events for a previous execution of the tool, given its raw JSON
    /// input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1531
1532impl<T> AnyAgentTool for Erased<Arc<T>>
1533where
1534    T: AgentTool,
1535{
1536    fn name(&self) -> SharedString {
1537        self.0.name()
1538    }
1539
1540    fn description(&self) -> SharedString {
1541        self.0.description()
1542    }
1543
1544    fn kind(&self) -> agent_client_protocol::ToolKind {
1545        self.0.kind()
1546    }
1547
1548    fn initial_title(&self, input: serde_json::Value) -> SharedString {
1549        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1550        self.0.initial_title(parsed_input)
1551    }
1552
1553    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1554        let mut json = serde_json::to_value(self.0.input_schema())?;
1555        adapt_schema_to_format(&mut json, format)?;
1556        Ok(json)
1557    }
1558
1559    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1560        self.0.supported_provider(provider)
1561    }
1562
1563    fn run(
1564        self: Arc<Self>,
1565        input: serde_json::Value,
1566        event_stream: ToolCallEventStream,
1567        cx: &mut App,
1568    ) -> Task<Result<AgentToolOutput>> {
1569        cx.spawn(async move |cx| {
1570            let input = serde_json::from_value(input)?;
1571            let output = cx
1572                .update(|cx| self.0.clone().run(input, event_stream, cx))?
1573                .await?;
1574            let raw_output = serde_json::to_value(&output)?;
1575            Ok(AgentToolOutput {
1576                llm_output: output.into(),
1577                raw_output,
1578            })
1579        })
1580    }
1581
1582    fn replay(
1583        &self,
1584        input: serde_json::Value,
1585        output: serde_json::Value,
1586        event_stream: ToolCallEventStream,
1587        cx: &mut App,
1588    ) -> Result<()> {
1589        let input = serde_json::from_value(input)?;
1590        let output = serde_json::from_value(output)?;
1591        self.0.replay(input, output, event_stream, cx)
1592    }
1593}
1594
/// Cloneable handle around the sending half of the unbounded channel on which
/// a turn reports its `ThreadEvent`s (or an error).
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1597
1598impl ThreadEventStream {
1599    fn send_user_message(&self, message: &UserMessage) {
1600        self.0
1601            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1602            .ok();
1603    }
1604
1605    fn send_text(&self, text: &str) {
1606        self.0
1607            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1608            .ok();
1609    }
1610
1611    fn send_thinking(&self, text: &str) {
1612        self.0
1613            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1614            .ok();
1615    }
1616
1617    fn send_tool_call(
1618        &self,
1619        id: &LanguageModelToolUseId,
1620        title: SharedString,
1621        kind: acp::ToolKind,
1622        input: serde_json::Value,
1623    ) {
1624        self.0
1625            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1626                id,
1627                title.to_string(),
1628                kind,
1629                input,
1630            ))))
1631            .ok();
1632    }
1633
1634    fn initial_tool_call(
1635        id: &LanguageModelToolUseId,
1636        title: String,
1637        kind: acp::ToolKind,
1638        input: serde_json::Value,
1639    ) -> acp::ToolCall {
1640        acp::ToolCall {
1641            id: acp::ToolCallId(id.to_string().into()),
1642            title,
1643            kind,
1644            status: acp::ToolCallStatus::Pending,
1645            content: vec![],
1646            locations: vec![],
1647            raw_input: Some(input),
1648            raw_output: None,
1649        }
1650    }
1651
1652    fn update_tool_call_fields(
1653        &self,
1654        tool_use_id: &LanguageModelToolUseId,
1655        fields: acp::ToolCallUpdateFields,
1656    ) {
1657        self.0
1658            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1659                acp::ToolCallUpdate {
1660                    id: acp::ToolCallId(tool_use_id.to_string().into()),
1661                    fields,
1662                }
1663                .into(),
1664            )))
1665            .ok();
1666    }
1667
1668    fn send_stop(&self, reason: StopReason) {
1669        match reason {
1670            StopReason::EndTurn => {
1671                self.0
1672                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1673                    .ok();
1674            }
1675            StopReason::MaxTokens => {
1676                self.0
1677                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1678                    .ok();
1679            }
1680            StopReason::Refusal => {
1681                self.0
1682                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1683                    .ok();
1684            }
1685            StopReason::ToolUse => {}
1686        }
1687    }
1688
1689    fn send_canceled(&self) {
1690        self.0
1691            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1692            .ok();
1693    }
1694
1695    fn send_error(&self, error: impl Into<anyhow::Error>) {
1696        self.0.unbounded_send(Err(error.into())).ok();
1697    }
1698}
1699
1700#[derive(Clone)]
1701pub struct ToolCallEventStream {
1702    tool_use_id: LanguageModelToolUseId,
1703    stream: ThreadEventStream,
1704    fs: Option<Arc<dyn Fs>>,
1705}
1706
1707impl ToolCallEventStream {
1708    #[cfg(test)]
1709    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1710        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1711
1712        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1713
1714        (stream, ToolCallEventStreamReceiver(events_rx))
1715    }
1716
1717    fn new(
1718        tool_use_id: LanguageModelToolUseId,
1719        stream: ThreadEventStream,
1720        fs: Option<Arc<dyn Fs>>,
1721    ) -> Self {
1722        Self {
1723            tool_use_id,
1724            stream,
1725            fs,
1726        }
1727    }
1728
1729    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1730        self.stream
1731            .update_tool_call_fields(&self.tool_use_id, fields);
1732    }
1733
1734    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1735        self.stream
1736            .0
1737            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1738                acp_thread::ToolCallUpdateDiff {
1739                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1740                    diff,
1741                }
1742                .into(),
1743            )))
1744            .ok();
1745    }
1746
1747    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1748        self.stream
1749            .0
1750            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1751                acp_thread::ToolCallUpdateTerminal {
1752                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1753                    terminal,
1754                }
1755                .into(),
1756            )))
1757            .ok();
1758    }
1759
1760    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1761        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1762            return Task::ready(Ok(()));
1763        }
1764
1765        let (response_tx, response_rx) = oneshot::channel();
1766        self.stream
1767            .0
1768            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1769                ToolCallAuthorization {
1770                    tool_call: acp::ToolCallUpdate {
1771                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1772                        fields: acp::ToolCallUpdateFields {
1773                            title: Some(title.into()),
1774                            ..Default::default()
1775                        },
1776                    },
1777                    options: vec![
1778                        acp::PermissionOption {
1779                            id: acp::PermissionOptionId("always_allow".into()),
1780                            name: "Always Allow".into(),
1781                            kind: acp::PermissionOptionKind::AllowAlways,
1782                        },
1783                        acp::PermissionOption {
1784                            id: acp::PermissionOptionId("allow".into()),
1785                            name: "Allow".into(),
1786                            kind: acp::PermissionOptionKind::AllowOnce,
1787                        },
1788                        acp::PermissionOption {
1789                            id: acp::PermissionOptionId("deny".into()),
1790                            name: "Deny".into(),
1791                            kind: acp::PermissionOptionKind::RejectOnce,
1792                        },
1793                    ],
1794                    response: response_tx,
1795                },
1796            )))
1797            .ok();
1798        let fs = self.fs.clone();
1799        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1800            "always_allow" => {
1801                if let Some(fs) = fs.clone() {
1802                    cx.update(|cx| {
1803                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1804                            settings.set_always_allow_tool_actions(true);
1805                        });
1806                    })?;
1807                }
1808
1809                Ok(())
1810            }
1811            "allow" => Ok(()),
1812            _ => Err(anyhow!("Permission to run tool denied by user")),
1813        })
1814    }
1815}
1816
/// Test-only receiving half paired with `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);

#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it if it is a tool-call authorization.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including an error or end of stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Waits for the next event and returns its terminal if it is a terminal update.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including an error or end of stream.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1843
/// Exposes the wrapped receiver by shared reference.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1852
/// Exposes the wrapped receiver by mutable reference.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1859
1860impl From<&str> for UserMessageContent {
1861    fn from(text: &str) -> Self {
1862        Self::Text(text.into())
1863    }
1864}
1865
1866impl From<acp::ContentBlock> for UserMessageContent {
1867    fn from(value: acp::ContentBlock) -> Self {
1868        match value {
1869            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1870            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1871            acp::ContentBlock::Audio(_) => {
1872                // TODO
1873                Self::Text("[audio]".to_string())
1874            }
1875            acp::ContentBlock::ResourceLink(resource_link) => {
1876                match MentionUri::parse(&resource_link.uri) {
1877                    Ok(uri) => Self::Mention {
1878                        uri,
1879                        content: String::new(),
1880                    },
1881                    Err(err) => {
1882                        log::error!("Failed to parse mention link: {}", err);
1883                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1884                    }
1885                }
1886            }
1887            acp::ContentBlock::Resource(resource) => match resource.resource {
1888                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1889                    match MentionUri::parse(&resource.uri) {
1890                        Ok(uri) => Self::Mention {
1891                            uri,
1892                            content: resource.text,
1893                        },
1894                        Err(err) => {
1895                            log::error!("Failed to parse mention link: {}", err);
1896                            Self::Text(
1897                                MarkdownCodeBlock {
1898                                    tag: &resource.uri,
1899                                    text: &resource.text,
1900                                }
1901                                .to_string(),
1902                            )
1903                        }
1904                    }
1905                }
1906                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1907                    // TODO
1908                    Self::Text("[blob]".to_string())
1909                }
1910            },
1911        }
1912    }
1913}
1914
1915impl From<UserMessageContent> for acp::ContentBlock {
1916    fn from(content: UserMessageContent) -> Self {
1917        match content {
1918            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
1919                text,
1920                annotations: None,
1921            }),
1922            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
1923                data: image.source.to_string(),
1924                mime_type: "image/png".to_string(),
1925                annotations: None,
1926                uri: None,
1927            }),
1928            UserMessageContent::Mention { uri, content } => {
1929                todo!()
1930            }
1931        }
1932    }
1933}
1934
1935fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1936    LanguageModelImage {
1937        source: image_content.data.into(),
1938        // TODO: make this optional?
1939        size: gpui::Size::new(0.into(), 0.into()),
1940    }
1941}