thread.rs

   1use crate::{
   2    ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates,
   3};
   4use acp_thread::{MentionUri, UserMessageId};
   5use action_log::ActionLog;
   6use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
   7use agent_client_protocol as acp;
   8use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
   9use anyhow::{Context as _, Result, anyhow};
  10use assistant_tool::adapt_schema_to_format;
  11use chrono::{DateTime, Utc};
  12use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
  13use collections::IndexMap;
  14use fs::Fs;
  15use futures::{
  16    FutureExt,
  17    channel::{mpsc, oneshot},
  18    future::Shared,
  19    stream::FuturesUnordered,
  20};
  21use git::repository::DiffType;
  22use gpui::{App, AppContext, Context, Entity, SharedString, Task};
  23use language_model::{
  24    LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
  25    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  26    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  27    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
  28};
  29use project::{
  30    Project,
  31    git_store::{GitStore, RepositoryState},
  32};
  33use prompt_store::ProjectContext;
  34use schemars::{JsonSchema, Schema};
  35use serde::{Deserialize, Serialize};
  36use settings::{Settings, update_settings_file};
  37use smol::stream::StreamExt;
  38use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
  39use std::{fmt::Write, ops::Range};
  40use util::{ResultExt, markdown::MarkdownCodeBlock};
  41use uuid::Uuid;
  42
  43const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  44
  45/// The ID of the user prompt that initiated a request.
  46///
  47/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  48#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  49pub struct PromptId(Arc<str>);
  50
  51impl PromptId {
  52    pub fn new() -> Self {
  53        Self(Uuid::new_v4().to_string().into())
  54    }
  55}
  56
  57impl std::fmt::Display for PromptId {
  58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  59        write!(f, "{}", self.0)
  60    }
  61}
  62
  63#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  64pub enum Message {
  65    User(UserMessage),
  66    Agent(AgentMessage),
  67    Resume,
  68}
  69
  70impl Message {
  71    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
  72        match self {
  73            Message::Agent(agent_message) => Some(agent_message),
  74            _ => None,
  75        }
  76    }
  77
  78    pub fn to_markdown(&self) -> String {
  79        match self {
  80            Message::User(message) => message.to_markdown(),
  81            Message::Agent(message) => message.to_markdown(),
  82            Message::Resume => "[resumed after tool use limit was reached]".into(),
  83        }
  84    }
  85}
  86
  87#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  88pub struct UserMessage {
  89    pub id: UserMessageId,
  90    pub content: Vec<UserMessageContent>,
  91}
  92
  93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  94pub enum UserMessageContent {
  95    Text(String),
  96    Mention { uri: MentionUri, content: String },
  97    Image(LanguageModelImage),
  98}
  99
 100impl UserMessage {
 101    pub fn to_markdown(&self) -> String {
 102        let mut markdown = String::from("## User\n\n");
 103
 104        for content in &self.content {
 105            match content {
 106                UserMessageContent::Text(text) => {
 107                    markdown.push_str(text);
 108                    markdown.push('\n');
 109                }
 110                UserMessageContent::Image(_) => {
 111                    markdown.push_str("<image />\n");
 112                }
 113                UserMessageContent::Mention { uri, content } => {
 114                    if !content.is_empty() {
 115                        let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
 116                    } else {
 117                        let _ = write!(&mut markdown, "{}\n", uri.as_link());
 118                    }
 119                }
 120            }
 121        }
 122
 123        markdown
 124    }
 125
 126    fn to_request(&self) -> LanguageModelRequestMessage {
 127        let mut message = LanguageModelRequestMessage {
 128            role: Role::User,
 129            content: Vec::with_capacity(self.content.len()),
 130            cache: false,
 131        };
 132
 133        const OPEN_CONTEXT: &str = "<context>\n\
 134            The following items were attached by the user. \
 135            They are up-to-date and don't need to be re-read.\n\n";
 136
 137        const OPEN_FILES_TAG: &str = "<files>";
 138        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 139        const OPEN_THREADS_TAG: &str = "<threads>";
 140        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 141        const OPEN_RULES_TAG: &str =
 142            "<rules>\nThe user has specified the following rules that should be applied:\n";
 143
 144        let mut file_context = OPEN_FILES_TAG.to_string();
 145        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 146        let mut thread_context = OPEN_THREADS_TAG.to_string();
 147        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 148        let mut rules_context = OPEN_RULES_TAG.to_string();
 149
 150        for chunk in &self.content {
 151            let chunk = match chunk {
 152                UserMessageContent::Text(text) => {
 153                    language_model::MessageContent::Text(text.clone())
 154                }
 155                UserMessageContent::Image(value) => {
 156                    language_model::MessageContent::Image(value.clone())
 157                }
 158                UserMessageContent::Mention { uri, content } => {
 159                    match uri {
 160                        MentionUri::File { abs_path, .. } => {
 161                            write!(
 162                                &mut symbol_context,
 163                                "\n{}",
 164                                MarkdownCodeBlock {
 165                                    tag: &codeblock_tag(&abs_path, None),
 166                                    text: &content.to_string(),
 167                                }
 168                            )
 169                            .ok();
 170                        }
 171                        MentionUri::Symbol {
 172                            path, line_range, ..
 173                        }
 174                        | MentionUri::Selection {
 175                            path, line_range, ..
 176                        } => {
 177                            write!(
 178                                &mut rules_context,
 179                                "\n{}",
 180                                MarkdownCodeBlock {
 181                                    tag: &codeblock_tag(&path, Some(line_range)),
 182                                    text: &content
 183                                }
 184                            )
 185                            .ok();
 186                        }
 187                        MentionUri::Thread { .. } => {
 188                            write!(&mut thread_context, "\n{}\n", content).ok();
 189                        }
 190                        MentionUri::TextThread { .. } => {
 191                            write!(&mut thread_context, "\n{}\n", content).ok();
 192                        }
 193                        MentionUri::Rule { .. } => {
 194                            write!(
 195                                &mut rules_context,
 196                                "\n{}",
 197                                MarkdownCodeBlock {
 198                                    tag: "",
 199                                    text: &content
 200                                }
 201                            )
 202                            .ok();
 203                        }
 204                        MentionUri::Fetch { url } => {
 205                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 206                        }
 207                    }
 208
 209                    language_model::MessageContent::Text(uri.as_link().to_string())
 210                }
 211            };
 212
 213            message.content.push(chunk);
 214        }
 215
 216        let len_before_context = message.content.len();
 217
 218        if file_context.len() > OPEN_FILES_TAG.len() {
 219            file_context.push_str("</files>\n");
 220            message
 221                .content
 222                .push(language_model::MessageContent::Text(file_context));
 223        }
 224
 225        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 226            symbol_context.push_str("</symbols>\n");
 227            message
 228                .content
 229                .push(language_model::MessageContent::Text(symbol_context));
 230        }
 231
 232        if thread_context.len() > OPEN_THREADS_TAG.len() {
 233            thread_context.push_str("</threads>\n");
 234            message
 235                .content
 236                .push(language_model::MessageContent::Text(thread_context));
 237        }
 238
 239        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 240            fetch_context.push_str("</fetched_urls>\n");
 241            message
 242                .content
 243                .push(language_model::MessageContent::Text(fetch_context));
 244        }
 245
 246        if rules_context.len() > OPEN_RULES_TAG.len() {
 247            rules_context.push_str("</user_rules>\n");
 248            message
 249                .content
 250                .push(language_model::MessageContent::Text(rules_context));
 251        }
 252
 253        if message.content.len() > len_before_context {
 254            message.content.insert(
 255                len_before_context,
 256                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 257            );
 258            message
 259                .content
 260                .push(language_model::MessageContent::Text("</context>".into()));
 261        }
 262
 263        message
 264    }
 265}
 266
 /// Builds the info string of a Markdown code fence that quotes `full_path`.
 ///
 /// The result is `"<extension> <path>"`; the extension prefix is omitted when the path has no
 /// UTF-8 extension. When `line_range` is given, a line suffix is appended with each bound
 /// printed plus one (presumably converting 0-based lines to 1-based): `:N` when
 /// `start == end`, otherwise `:START-END`.
 fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
     let mut tag = match full_path.extension().and_then(|ext| ext.to_str()) {
         Some(extension) => format!("{extension} "),
         None => String::new(),
     };
     let _ = write!(tag, "{}", full_path.display());

     match line_range {
         Some(range) if range.start == range.end => {
             let _ = write!(tag, ":{}", range.start + 1);
         }
         Some(range) => {
             let _ = write!(tag, ":{}-{}", range.start + 1, range.end + 1);
         }
         None => {}
     }

     tag
 }
 286
 287impl AgentMessage {
 288    pub fn to_markdown(&self) -> String {
 289        let mut markdown = String::from("## Assistant\n\n");
 290
 291        for content in &self.content {
 292            match content {
 293                AgentMessageContent::Text(text) => {
 294                    markdown.push_str(text);
 295                    markdown.push('\n');
 296                }
 297                AgentMessageContent::Thinking { text, .. } => {
 298                    markdown.push_str("<think>");
 299                    markdown.push_str(text);
 300                    markdown.push_str("</think>\n");
 301                }
 302                AgentMessageContent::RedactedThinking(_) => {
 303                    markdown.push_str("<redacted_thinking />\n")
 304                }
 305                AgentMessageContent::ToolUse(tool_use) => {
 306                    markdown.push_str(&format!(
 307                        "**Tool Use**: {} (ID: {})\n",
 308                        tool_use.name, tool_use.id
 309                    ));
 310                    markdown.push_str(&format!(
 311                        "{}\n",
 312                        MarkdownCodeBlock {
 313                            tag: "json",
 314                            text: &format!("{:#}", tool_use.input)
 315                        }
 316                    ));
 317                }
 318            }
 319        }
 320
 321        for tool_result in self.tool_results.values() {
 322            markdown.push_str(&format!(
 323                "**Tool Result**: {} (ID: {})\n\n",
 324                tool_result.tool_name, tool_result.tool_use_id
 325            ));
 326            if tool_result.is_error {
 327                markdown.push_str("**ERROR:**\n");
 328            }
 329
 330            match &tool_result.content {
 331                LanguageModelToolResultContent::Text(text) => {
 332                    writeln!(markdown, "{text}\n").ok();
 333                }
 334                LanguageModelToolResultContent::Image(_) => {
 335                    writeln!(markdown, "<image />\n").ok();
 336                }
 337            }
 338
 339            if let Some(output) = tool_result.output.as_ref() {
 340                writeln!(
 341                    markdown,
 342                    "**Debug Output**:\n\n```json\n{}\n```\n",
 343                    serde_json::to_string_pretty(output).unwrap()
 344                )
 345                .unwrap();
 346            }
 347        }
 348
 349        markdown
 350    }
 351
 352    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 353        let mut assistant_message = LanguageModelRequestMessage {
 354            role: Role::Assistant,
 355            content: Vec::with_capacity(self.content.len()),
 356            cache: false,
 357        };
 358        for chunk in &self.content {
 359            let chunk = match chunk {
 360                AgentMessageContent::Text(text) => {
 361                    language_model::MessageContent::Text(text.clone())
 362                }
 363                AgentMessageContent::Thinking { text, signature } => {
 364                    language_model::MessageContent::Thinking {
 365                        text: text.clone(),
 366                        signature: signature.clone(),
 367                    }
 368                }
 369                AgentMessageContent::RedactedThinking(value) => {
 370                    language_model::MessageContent::RedactedThinking(value.clone())
 371                }
 372                AgentMessageContent::ToolUse(value) => {
 373                    language_model::MessageContent::ToolUse(value.clone())
 374                }
 375            };
 376            assistant_message.content.push(chunk);
 377        }
 378
 379        let mut user_message = LanguageModelRequestMessage {
 380            role: Role::User,
 381            content: Vec::new(),
 382            cache: false,
 383        };
 384
 385        for tool_result in self.tool_results.values() {
 386            user_message
 387                .content
 388                .push(language_model::MessageContent::ToolResult(
 389                    tool_result.clone(),
 390                ));
 391        }
 392
 393        let mut messages = Vec::new();
 394        if !assistant_message.content.is_empty() {
 395            messages.push(assistant_message);
 396        }
 397        if !user_message.content.is_empty() {
 398            messages.push(user_message);
 399        }
 400        messages
 401    }
 402}
 403
 404#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 405pub struct AgentMessage {
 406    pub content: Vec<AgentMessageContent>,
 407    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
 408}
 409
 410#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 411pub enum AgentMessageContent {
 412    Text(String),
 413    Thinking {
 414        text: String,
 415        signature: Option<String>,
 416    },
 417    RedactedThinking(String),
 418    ToolUse(LanguageModelToolUse),
 419}
 420
 421#[derive(Debug)]
 422pub enum ThreadEvent {
 423    UserMessage(UserMessage),
 424    AgentText(String),
 425    AgentThinking(String),
 426    ToolCall(acp::ToolCall),
 427    ToolCallUpdate(acp_thread::ToolCallUpdate),
 428    ToolCallAuthorization(ToolCallAuthorization),
 429    Stop(acp::StopReason),
 430}
 431
 432#[derive(Debug)]
 433pub struct ToolCallAuthorization {
 434    pub tool_call: acp::ToolCallUpdate,
 435    pub options: Vec<acp::PermissionOption>,
 436    pub response: oneshot::Sender<acp::PermissionOptionId>,
 437}
 438
 439enum ThreadTitle {
 440    None,
 441    Pending(Task<()>),
 442    Done(Result<SharedString>),
 443}
 444
 445impl ThreadTitle {
 446    pub fn unwrap_or_default(&self) -> SharedString {
 447        if let ThreadTitle::Done(Ok(title)) = self {
 448            title.clone()
 449        } else {
 450            "New Thread".into()
 451        }
 452    }
 453}
 454
 /// An agent conversation: its message history, the in-progress turn and pending message,
 /// the registered tools, and the project, model and settings it works with.
 455pub struct Thread {
        /// Session identifier, returned by `id()`.
 456    id: acp::SessionId,
        /// A fresh `PromptId` is generated by both `new` and `from_db`.
 457    prompt_id: PromptId,
        /// `Utc::now()` for new threads; taken from the database record in `from_db`.
 458    updated_at: DateTime<Utc>,
        /// `ThreadTitle::unwrap_or_default` falls back to "New Thread" unless this is `Done(Ok(_))`.
 459    title: ThreadTitle,
 460    summary: DetailedSummaryState,
        /// Committed history; persisted by `to_db` and re-emitted by `replay`.
 461    messages: Vec<Message>,
        /// Persisted by `to_db`.
 462    completion_mode: CompletionMode,
 463    /// Holds the task that handles agent interaction until the end of the turn.
 464    /// Survives across multiple requests as the model performs tool calls and
 465    /// we run the tools and report their results.
 466    running_turn: Option<RunningTurn>,
        /// Agent message that is not (yet) part of `messages`; `last_message` returns it in
        /// preference to the history, and `cancel` calls `flush_pending_message`.
 467    pending_message: Option<AgentMessage>,
        /// Tools available to this thread, keyed by tool name (see `add_tool` / `remove_tool`).
 468    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
        /// `resume` refuses to run unless this is set.
 469    tool_use_limit_reached: bool,
        /// Token usage records; persisted by `to_db`.
 470    request_token_usage: Vec<TokenUsage>,
        /// Accumulated token usage; persisted by `to_db`.
 471    cumulative_token_usage: TokenUsage,
        /// Project snapshot computed in `new` (or loaded in `from_db`); `Shared` so that
        /// `to_db` can clone and await it.
 472    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 473    context_server_registry: Entity<ContextServerRegistry>,
        /// Agent profile; defaults to `AgentSettings::default_profile`.
 474    profile_id: AgentProfileId,
 475    project_context: Rc<RefCell<ProjectContext>>,
 476    templates: Arc<Templates>,
        /// Model for this thread; its provider id and name are recorded by `to_db`.
 477    model: Arc<dyn LanguageModel>,
 478    project: Entity<Project>,
 479    action_log: Entity<ActionLog>,
 480}
 481
 482impl Thread {
 483    pub fn new(
 484        id: acp::SessionId,
 485        project: Entity<Project>,
 486        project_context: Rc<RefCell<ProjectContext>>,
 487        context_server_registry: Entity<ContextServerRegistry>,
 488        action_log: Entity<ActionLog>,
 489        templates: Arc<Templates>,
 490        model: Arc<dyn LanguageModel>,
 491        cx: &mut Context<Self>,
 492    ) -> Self {
 493        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 494        Self {
 495            id,
 496            prompt_id: PromptId::new(),
 497            updated_at: Utc::now(),
 498            title: ThreadTitle::None,
 499            summary: DetailedSummaryState::default(),
 500            messages: Vec::new(),
 501            completion_mode: CompletionMode::Normal,
 502            running_turn: None,
 503            pending_message: None,
 504            tools: BTreeMap::default(),
 505            tool_use_limit_reached: false,
 506            request_token_usage: Vec::new(),
 507            cumulative_token_usage: TokenUsage::default(),
 508            initial_project_snapshot: {
 509                let project_snapshot = Self::project_snapshot(project.clone(), cx);
 510                cx.foreground_executor()
 511                    .spawn(async move { Some(project_snapshot.await) })
 512                    .shared()
 513            },
 514            context_server_registry,
 515            profile_id,
 516            project_context,
 517            templates,
 518            model,
 519            project,
 520            action_log,
 521        }
 522    }
 523
 524    pub fn id(&self) -> &acp::SessionId {
 525        &self.id
 526    }
 527
 528    pub fn from_db(
 529        id: acp::SessionId,
 530        db_thread: DbThread,
 531        project: Entity<Project>,
 532        project_context: Rc<RefCell<ProjectContext>>,
 533        context_server_registry: Entity<ContextServerRegistry>,
 534        action_log: Entity<ActionLog>,
 535        templates: Arc<Templates>,
 536        model: Arc<dyn LanguageModel>,
 537        cx: &mut Context<Self>,
 538    ) -> Self {
 539        let profile_id = db_thread
 540            .profile
 541            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 542        Self {
 543            id,
 544            prompt_id: PromptId::new(),
 545            title: ThreadTitle::Done(Ok(db_thread.title.clone())),
 546            summary: db_thread.summary,
 547            messages: db_thread.messages,
 548            completion_mode: CompletionMode::Normal,
 549            running_turn: None,
 550            pending_message: None,
 551            tools: BTreeMap::default(),
 552            tool_use_limit_reached: false,
 553            request_token_usage: db_thread.request_token_usage.clone(),
 554            cumulative_token_usage: db_thread.cumulative_token_usage.clone(),
 555            initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
 556            context_server_registry,
 557            profile_id,
 558            project_context,
 559            templates,
 560            model,
 561            project,
 562            action_log,
 563            updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
 564        }
 565    }
 566
 567    pub fn to_db(&self, cx: &App) -> Task<DbThread> {
 568        let initial_project_snapshot = self.initial_project_snapshot.clone();
 569        let mut thread = DbThread {
 570            title: self.title.unwrap_or_default(),
 571            messages: self.messages.clone(),
 572            updated_at: self.updated_at.clone(),
 573            summary: self.summary.clone(),
 574            initial_project_snapshot: None,
 575            cumulative_token_usage: self.cumulative_token_usage.clone(),
 576            request_token_usage: self.request_token_usage.clone(),
 577            model: Some(DbLanguageModel {
 578                provider: self.model.provider_id().to_string(),
 579                model: self.model.name().0.to_string(),
 580            }),
 581            completion_mode: Some(self.completion_mode.into()),
 582            profile: Some(self.profile_id.clone()),
 583        };
 584
 585        cx.background_spawn(async move {
 586            let initial_project_snapshot = initial_project_snapshot.await;
 587            thread.initial_project_snapshot = initial_project_snapshot;
 588            thread
 589        })
 590    }
 591
 592    pub fn replay(
 593        &mut self,
 594        cx: &mut Context<Self>,
 595    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 596        let (tx, rx) = mpsc::unbounded();
 597        let stream = ThreadEventStream(tx);
 598        for message in &self.messages {
 599            match message {
 600                Message::User(user_message) => stream.send_user_message(&user_message),
 601                Message::Agent(assistant_message) => {
 602                    for content in &assistant_message.content {
 603                        match content {
 604                            AgentMessageContent::Text(text) => stream.send_text(text),
 605                            AgentMessageContent::Thinking { text, .. } => {
 606                                stream.send_thinking(text)
 607                            }
 608                            AgentMessageContent::RedactedThinking(_) => {}
 609                            AgentMessageContent::ToolUse(tool_use) => {
 610                                self.replay_tool_call(
 611                                    tool_use,
 612                                    assistant_message.tool_results.get(&tool_use.id),
 613                                    &stream,
 614                                    cx,
 615                                );
 616                            }
 617                        }
 618                    }
 619                }
 620                Message::Resume => {}
 621            }
 622        }
 623        rx
 624    }
 625
 626    fn replay_tool_call(
 627        &self,
 628        tool_use: &LanguageModelToolUse,
 629        tool_result: Option<&LanguageModelToolResult>,
 630        stream: &ThreadEventStream,
 631        cx: &mut Context<Self>,
 632    ) {
 633        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 634            stream
 635                .0
 636                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 637                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 638                    title: tool_use.name.to_string(),
 639                    kind: acp::ToolKind::Other,
 640                    status: acp::ToolCallStatus::Failed,
 641                    content: Vec::new(),
 642                    locations: Vec::new(),
 643                    raw_input: Some(tool_use.input.clone()),
 644                    raw_output: None,
 645                })))
 646                .ok();
 647            return;
 648        };
 649
 650        let title = tool.initial_title(tool_use.input.clone());
 651        let kind = tool.kind();
 652        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 653
 654        let output = tool_result
 655            .as_ref()
 656            .and_then(|result| result.output.clone());
 657        if let Some(output) = output.clone() {
 658            let tool_event_stream = ToolCallEventStream::new(
 659                tool_use.id.clone(),
 660                stream.clone(),
 661                Some(self.project.read(cx).fs().clone()),
 662            );
 663            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 664                .log_err();
 665        }
 666
 667        stream.update_tool_call_fields(
 668            &tool_use.id,
 669            acp::ToolCallUpdateFields {
 670                status: Some(acp::ToolCallStatus::Completed),
 671                raw_output: output,
 672                ..Default::default()
 673            },
 674        );
 675    }
 676
 677    /// Create a snapshot of the current project state including git information and unsaved buffers.
 678    fn project_snapshot(
 679        project: Entity<Project>,
 680        cx: &mut Context<Self>,
 681    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
 682        let git_store = project.read(cx).git_store().clone();
 683        let worktree_snapshots: Vec<_> = project
 684            .read(cx)
 685            .visible_worktrees(cx)
 686            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
 687            .collect();
 688
 689        cx.spawn(async move |_, cx| {
 690            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
 691
 692            let mut unsaved_buffers = Vec::new();
 693            cx.update(|app_cx| {
 694                let buffer_store = project.read(app_cx).buffer_store();
 695                for buffer_handle in buffer_store.read(app_cx).buffers() {
 696                    let buffer = buffer_handle.read(app_cx);
 697                    if buffer.is_dirty() {
 698                        if let Some(file) = buffer.file() {
 699                            let path = file.path().to_string_lossy().to_string();
 700                            unsaved_buffers.push(path);
 701                        }
 702                    }
 703                }
 704            })
 705            .ok();
 706
 707            Arc::new(ProjectSnapshot {
 708                worktree_snapshots,
 709                unsaved_buffer_paths: unsaved_buffers,
 710                timestamp: Utc::now(),
 711            })
 712        })
 713    }
 714
 715    fn worktree_snapshot(
 716        worktree: Entity<project::Worktree>,
 717        git_store: Entity<GitStore>,
 718        cx: &App,
 719    ) -> Task<agent::thread::WorktreeSnapshot> {
 720        cx.spawn(async move |cx| {
 721            // Get worktree path and snapshot
 722            let worktree_info = cx.update(|app_cx| {
 723                let worktree = worktree.read(app_cx);
 724                let path = worktree.abs_path().to_string_lossy().to_string();
 725                let snapshot = worktree.snapshot();
 726                (path, snapshot)
 727            });
 728
 729            let Ok((worktree_path, _snapshot)) = worktree_info else {
 730                return WorktreeSnapshot {
 731                    worktree_path: String::new(),
 732                    git_state: None,
 733                };
 734            };
 735
 736            let git_state = git_store
 737                .update(cx, |git_store, cx| {
 738                    git_store
 739                        .repositories()
 740                        .values()
 741                        .find(|repo| {
 742                            repo.read(cx)
 743                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
 744                                .is_some()
 745                        })
 746                        .cloned()
 747                })
 748                .ok()
 749                .flatten()
 750                .map(|repo| {
 751                    repo.update(cx, |repo, _| {
 752                        let current_branch =
 753                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
 754                        repo.send_job(None, |state, _| async move {
 755                            let RepositoryState::Local { backend, .. } = state else {
 756                                return GitState {
 757                                    remote_url: None,
 758                                    head_sha: None,
 759                                    current_branch,
 760                                    diff: None,
 761                                };
 762                            };
 763
 764                            let remote_url = backend.remote_url("origin");
 765                            let head_sha = backend.head_sha().await;
 766                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
 767
 768                            GitState {
 769                                remote_url,
 770                                head_sha,
 771                                current_branch,
 772                                diff,
 773                            }
 774                        })
 775                    })
 776                });
 777
 778            let git_state = match git_state {
 779                Some(git_state) => match git_state.ok() {
 780                    Some(git_state) => git_state.await.ok(),
 781                    None => None,
 782                },
 783                None => None,
 784            };
 785
 786            WorktreeSnapshot {
 787                worktree_path,
 788                git_state,
 789            }
 790        })
 791    }
 792
 793    pub fn project(&self) -> &Entity<Project> {
 794        &self.project
 795    }
 796
 797    pub fn action_log(&self) -> &Entity<ActionLog> {
 798        &self.action_log
 799    }
 800
 801    pub fn model(&self) -> &Arc<dyn LanguageModel> {
 802        &self.model
 803    }
 804
 805    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
 806        self.model = model;
 807    }
 808
 809    pub fn completion_mode(&self) -> CompletionMode {
 810        self.completion_mode
 811    }
 812
 813    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
 814        self.completion_mode = mode;
 815    }
 816
 817    #[cfg(any(test, feature = "test-support"))]
 818    pub fn last_message(&self) -> Option<Message> {
 819        if let Some(message) = self.pending_message.clone() {
 820            Some(Message::Agent(message))
 821        } else {
 822            self.messages.last().cloned()
 823        }
 824    }
 825
 826    pub fn add_tool(&mut self, tool: impl AgentTool) {
 827        self.tools.insert(tool.name(), tool.erase());
 828    }
 829
 830    pub fn remove_tool(&mut self, name: &str) -> bool {
 831        self.tools.remove(name).is_some()
 832    }
 833
 834    pub fn profile(&self) -> &AgentProfileId {
 835        &self.profile_id
 836    }
 837
 838    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 839        self.profile_id = profile_id;
 840    }
 841
 842    pub fn cancel(&mut self) {
 843        if let Some(running_turn) = self.running_turn.take() {
 844            running_turn.cancel();
 845        }
 846        self.flush_pending_message();
 847    }
 848
 849    pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
 850        self.cancel();
 851        let Some(position) = self.messages.iter().position(
 852            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 853        ) else {
 854            return Err(anyhow!("Message not found"));
 855        };
 856        self.messages.truncate(position);
 857        Ok(())
 858    }
 859
 860    pub fn resume(
 861        &mut self,
 862        cx: &mut Context<Self>,
 863    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
 864        anyhow::ensure!(
 865            self.tool_use_limit_reached,
 866            "can only resume after tool use limit is reached"
 867        );
 868
 869        self.messages.push(Message::Resume);
 870        cx.notify();
 871
 872        log::info!("Total messages in thread: {}", self.messages.len());
 873        Ok(self.run_turn(cx))
 874    }
 875
 876    /// Sending a message results in the model streaming a response, which could include tool calls.
 877    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 878    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 879    pub fn send<T>(
 880        &mut self,
 881        id: UserMessageId,
 882        content: impl IntoIterator<Item = T>,
 883        cx: &mut Context<Self>,
 884    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
 885    where
 886        T: Into<UserMessageContent>,
 887    {
 888        log::info!("Thread::send called with model: {:?}", self.model.name());
 889        self.advance_prompt_id();
 890
 891        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
 892        log::debug!("Thread::send content: {:?}", content);
 893
 894        self.messages
 895            .push(Message::User(UserMessage { id, content }));
 896        cx.notify();
 897
 898        log::info!("Total messages in thread: {}", self.messages.len());
 899        self.run_turn(cx)
 900    }
 901
 902    fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 903        self.cancel();
 904
 905        let model = self.model.clone();
 906        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
 907        let event_stream = ThreadEventStream(events_tx);
 908        let message_ix = self.messages.len().saturating_sub(1);
 909        self.tool_use_limit_reached = false;
 910        self.running_turn = Some(RunningTurn {
 911            event_stream: event_stream.clone(),
 912            _task: cx.spawn(async move |this, cx| {
 913                log::info!("Starting agent turn execution");
 914                let turn_result: Result<()> = async {
 915                    let mut completion_intent = CompletionIntent::UserPrompt;
 916                    loop {
 917                        log::debug!(
 918                            "Building completion request with intent: {:?}",
 919                            completion_intent
 920                        );
 921                        let request = this.update(cx, |this, cx| {
 922                            this.build_completion_request(completion_intent, cx)
 923                        })?;
 924
 925                        log::info!("Calling model.stream_completion");
 926                        let mut events = model.stream_completion(request, cx).await?;
 927                        log::debug!("Stream completion started successfully");
 928
 929                        let mut tool_use_limit_reached = false;
 930                        let mut tool_uses = FuturesUnordered::new();
 931                        while let Some(event) = events.next().await {
 932                            match event? {
 933                                LanguageModelCompletionEvent::StatusUpdate(
 934                                    CompletionRequestStatus::ToolUseLimitReached,
 935                                ) => {
 936                                    tool_use_limit_reached = true;
 937                                }
 938                                LanguageModelCompletionEvent::Stop(reason) => {
 939                                    event_stream.send_stop(reason);
 940                                    if reason == StopReason::Refusal {
 941                                        this.update(cx, |this, _cx| {
 942                                            this.flush_pending_message();
 943                                            this.messages.truncate(message_ix);
 944                                        })?;
 945                                        return Ok(());
 946                                    }
 947                                }
 948                                event => {
 949                                    log::trace!("Received completion event: {:?}", event);
 950                                    this.update(cx, |this, cx| {
 951                                        tool_uses.extend(this.handle_streamed_completion_event(
 952                                            event,
 953                                            &event_stream,
 954                                            cx,
 955                                        ));
 956                                    })
 957                                    .ok();
 958                                }
 959                            }
 960                        }
 961
 962                        let used_tools = tool_uses.is_empty();
 963                        while let Some(tool_result) = tool_uses.next().await {
 964                            log::info!("Tool finished {:?}", tool_result);
 965
 966                            event_stream.update_tool_call_fields(
 967                                &tool_result.tool_use_id,
 968                                acp::ToolCallUpdateFields {
 969                                    status: Some(if tool_result.is_error {
 970                                        acp::ToolCallStatus::Failed
 971                                    } else {
 972                                        acp::ToolCallStatus::Completed
 973                                    }),
 974                                    raw_output: tool_result.output.clone(),
 975                                    ..Default::default()
 976                                },
 977                            );
 978                            this.update(cx, |this, _cx| {
 979                                this.pending_message()
 980                                    .tool_results
 981                                    .insert(tool_result.tool_use_id.clone(), tool_result);
 982                            })
 983                            .ok();
 984                        }
 985
 986                        if tool_use_limit_reached {
 987                            log::info!("Tool use limit reached, completing turn");
 988                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
 989                            return Err(language_model::ToolUseLimitReachedError.into());
 990                        } else if used_tools {
 991                            log::info!("No tool uses found, completing turn");
 992                            return Ok(());
 993                        } else {
 994                            this.update(cx, |this, _| this.flush_pending_message())?;
 995                            completion_intent = CompletionIntent::ToolResults;
 996                        }
 997                    }
 998                }
 999                .await;
1000
1001                if let Err(error) = turn_result {
1002                    log::error!("Turn execution failed: {:?}", error);
1003                    event_stream.send_error(error);
1004                } else {
1005                    log::info!("Turn execution completed successfully");
1006                }
1007
1008                this.update(cx, |this, _| {
1009                    this.flush_pending_message();
1010                    this.running_turn.take();
1011                })
1012                .ok();
1013            }),
1014        });
1015        events_rx
1016    }
1017
1018    pub fn build_system_message(&self) -> LanguageModelRequestMessage {
1019        log::debug!("Building system message");
1020        let prompt = SystemPromptTemplate {
1021            project: &self.project_context.borrow(),
1022            available_tools: self.tools.keys().cloned().collect(),
1023        }
1024        .render(&self.templates)
1025        .context("failed to build system prompt")
1026        .expect("Invalid template");
1027        log::debug!("System message built");
1028        LanguageModelRequestMessage {
1029            role: Role::System,
1030            content: vec![prompt.into()],
1031            cache: true,
1032        }
1033    }
1034
1035    /// A helper method that's called on every streamed completion event.
1036    /// Returns an optional tool result task, which the main agentic loop in
1037    /// send will send back to the model when it resolves.
1038    fn handle_streamed_completion_event(
1039        &mut self,
1040        event: LanguageModelCompletionEvent,
1041        event_stream: &ThreadEventStream,
1042        cx: &mut Context<Self>,
1043    ) -> Option<Task<LanguageModelToolResult>> {
1044        log::trace!("Handling streamed completion event: {:?}", event);
1045        use LanguageModelCompletionEvent::*;
1046
1047        match event {
1048            StartMessage { .. } => {
1049                self.flush_pending_message();
1050                self.pending_message = Some(AgentMessage::default());
1051            }
1052            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1053            Thinking { text, signature } => {
1054                self.handle_thinking_event(text, signature, event_stream, cx)
1055            }
1056            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1057            ToolUse(tool_use) => {
1058                return self.handle_tool_use_event(tool_use, event_stream, cx);
1059            }
1060            ToolUseJsonParseError {
1061                id,
1062                tool_name,
1063                raw_input,
1064                json_parse_error,
1065            } => {
1066                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1067                    id,
1068                    tool_name,
1069                    raw_input,
1070                    json_parse_error,
1071                )));
1072            }
1073            UsageUpdate(_) | StatusUpdate(_) => {}
1074            Stop(_) => unreachable!(),
1075        }
1076
1077        None
1078    }
1079
1080    fn handle_text_event(
1081        &mut self,
1082        new_text: String,
1083        event_stream: &ThreadEventStream,
1084        cx: &mut Context<Self>,
1085    ) {
1086        event_stream.send_text(&new_text);
1087
1088        let last_message = self.pending_message();
1089        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1090            text.push_str(&new_text);
1091        } else {
1092            last_message
1093                .content
1094                .push(AgentMessageContent::Text(new_text));
1095        }
1096
1097        cx.notify();
1098    }
1099
1100    fn handle_thinking_event(
1101        &mut self,
1102        new_text: String,
1103        new_signature: Option<String>,
1104        event_stream: &ThreadEventStream,
1105        cx: &mut Context<Self>,
1106    ) {
1107        event_stream.send_thinking(&new_text);
1108
1109        let last_message = self.pending_message();
1110        if let Some(AgentMessageContent::Thinking { text, signature }) =
1111            last_message.content.last_mut()
1112        {
1113            text.push_str(&new_text);
1114            *signature = new_signature.or(signature.take());
1115        } else {
1116            last_message.content.push(AgentMessageContent::Thinking {
1117                text: new_text,
1118                signature: new_signature,
1119            });
1120        }
1121
1122        cx.notify();
1123    }
1124
1125    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1126        let last_message = self.pending_message();
1127        last_message
1128            .content
1129            .push(AgentMessageContent::RedactedThinking(data));
1130        cx.notify();
1131    }
1132
1133    fn handle_tool_use_event(
1134        &mut self,
1135        tool_use: LanguageModelToolUse,
1136        event_stream: &ThreadEventStream,
1137        cx: &mut Context<Self>,
1138    ) -> Option<Task<LanguageModelToolResult>> {
1139        cx.notify();
1140
1141        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1142        let mut title = SharedString::from(&tool_use.name);
1143        let mut kind = acp::ToolKind::Other;
1144        if let Some(tool) = tool.as_ref() {
1145            title = tool.initial_title(tool_use.input.clone());
1146            kind = tool.kind();
1147        }
1148
1149        // Ensure the last message ends in the current tool use
1150        let last_message = self.pending_message();
1151        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
1152            if let AgentMessageContent::ToolUse(last_tool_use) = content {
1153                if last_tool_use.id == tool_use.id {
1154                    *last_tool_use = tool_use.clone();
1155                    false
1156                } else {
1157                    true
1158                }
1159            } else {
1160                true
1161            }
1162        });
1163
1164        if push_new_tool_use {
1165            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1166            last_message
1167                .content
1168                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1169        } else {
1170            event_stream.update_tool_call_fields(
1171                &tool_use.id,
1172                acp::ToolCallUpdateFields {
1173                    title: Some(title.into()),
1174                    kind: Some(kind),
1175                    raw_input: Some(tool_use.input.clone()),
1176                    ..Default::default()
1177                },
1178            );
1179        }
1180
1181        if !tool_use.is_input_complete {
1182            return None;
1183        }
1184
1185        let Some(tool) = tool else {
1186            let content = format!("No tool named {} exists", tool_use.name);
1187            return Some(Task::ready(LanguageModelToolResult {
1188                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1189                tool_use_id: tool_use.id,
1190                tool_name: tool_use.name,
1191                is_error: true,
1192                output: None,
1193            }));
1194        };
1195
1196        let fs = self.project.read(cx).fs().clone();
1197        let tool_event_stream =
1198            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1199        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1200            status: Some(acp::ToolCallStatus::InProgress),
1201            ..Default::default()
1202        });
1203        let supports_images = self.model.supports_images();
1204        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1205        log::info!("Running tool {}", tool_use.name);
1206        Some(cx.foreground_executor().spawn(async move {
1207            let tool_result = tool_result.await.and_then(|output| {
1208                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1209                    if !supports_images {
1210                        return Err(anyhow!(
1211                            "Attempted to read an image, but this model doesn't support it.",
1212                        ));
1213                    }
1214                }
1215                Ok(output)
1216            });
1217
1218            match tool_result {
1219                Ok(output) => LanguageModelToolResult {
1220                    tool_use_id: tool_use.id,
1221                    tool_name: tool_use.name,
1222                    is_error: false,
1223                    content: output.llm_output,
1224                    output: Some(output.raw_output),
1225                },
1226                Err(error) => LanguageModelToolResult {
1227                    tool_use_id: tool_use.id,
1228                    tool_name: tool_use.name,
1229                    is_error: true,
1230                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1231                    output: None,
1232                },
1233            }
1234        }))
1235    }
1236
1237    fn handle_tool_use_json_parse_error_event(
1238        &mut self,
1239        tool_use_id: LanguageModelToolUseId,
1240        tool_name: Arc<str>,
1241        raw_input: Arc<str>,
1242        json_parse_error: String,
1243    ) -> LanguageModelToolResult {
1244        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1245        LanguageModelToolResult {
1246            tool_use_id,
1247            tool_name,
1248            is_error: true,
1249            content: LanguageModelToolResultContent::Text(tool_output.into()),
1250            output: Some(serde_json::Value::String(raw_input.to_string())),
1251        }
1252    }
1253
1254    fn pending_message(&mut self) -> &mut AgentMessage {
1255        self.pending_message.get_or_insert_default()
1256    }
1257
1258    fn flush_pending_message(&mut self) {
1259        let Some(mut message) = self.pending_message.take() else {
1260            return;
1261        };
1262
1263        for content in &message.content {
1264            let AgentMessageContent::ToolUse(tool_use) = content else {
1265                continue;
1266            };
1267
1268            if !message.tool_results.contains_key(&tool_use.id) {
1269                message.tool_results.insert(
1270                    tool_use.id.clone(),
1271                    LanguageModelToolResult {
1272                        tool_use_id: tool_use.id.clone(),
1273                        tool_name: tool_use.name.clone(),
1274                        is_error: true,
1275                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1276                        output: None,
1277                    },
1278                );
1279            }
1280        }
1281
1282        self.messages.push(Message::Agent(message));
1283    }
1284
1285    pub(crate) fn build_completion_request(
1286        &self,
1287        completion_intent: CompletionIntent,
1288        cx: &mut App,
1289    ) -> LanguageModelRequest {
1290        log::debug!("Building completion request");
1291        log::debug!("Completion intent: {:?}", completion_intent);
1292        log::debug!("Completion mode: {:?}", self.completion_mode);
1293
1294        let messages = self.build_request_messages();
1295        log::info!("Request will include {} messages", messages.len());
1296
1297        let tools = if let Some(tools) = self.tools(cx).log_err() {
1298            tools
1299                .filter_map(|tool| {
1300                    let tool_name = tool.name().to_string();
1301                    log::trace!("Including tool: {}", tool_name);
1302                    Some(LanguageModelRequestTool {
1303                        name: tool_name,
1304                        description: tool.description().to_string(),
1305                        input_schema: tool
1306                            .input_schema(self.model.tool_input_format())
1307                            .log_err()?,
1308                    })
1309                })
1310                .collect()
1311        } else {
1312            Vec::new()
1313        };
1314
1315        log::info!("Request includes {} tools", tools.len());
1316
1317        let request = LanguageModelRequest {
1318            thread_id: Some(self.id.to_string()),
1319            prompt_id: Some(self.prompt_id.to_string()),
1320            intent: Some(completion_intent),
1321            mode: Some(self.completion_mode.into()),
1322            messages,
1323            tools,
1324            tool_choice: None,
1325            stop: Vec::new(),
1326            temperature: AgentSettings::temperature_for_model(self.model(), cx),
1327            thinking_allowed: true,
1328        };
1329
1330        log::debug!("Completion request built successfully");
1331        request
1332    }
1333
1334    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1335        let profile = AgentSettings::get_global(cx)
1336            .profiles
1337            .get(&self.profile_id)
1338            .context("profile not found")?;
1339        let provider_id = self.model.provider_id();
1340
1341        Ok(self
1342            .tools
1343            .iter()
1344            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1345            .filter_map(|(tool_name, tool)| {
1346                if profile.is_tool_enabled(tool_name) {
1347                    Some(tool)
1348                } else {
1349                    None
1350                }
1351            })
1352            .chain(self.context_server_registry.read(cx).servers().flat_map(
1353                |(server_id, tools)| {
1354                    tools.iter().filter_map(|(tool_name, tool)| {
1355                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1356                            Some(tool)
1357                        } else {
1358                            None
1359                        }
1360                    })
1361                },
1362            )))
1363    }
1364
1365    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1366        log::trace!(
1367            "Building request messages from {} thread messages",
1368            self.messages.len()
1369        );
1370        let mut messages = vec![self.build_system_message()];
1371        for message in &self.messages {
1372            match message {
1373                Message::User(message) => messages.push(message.to_request()),
1374                Message::Agent(message) => messages.extend(message.to_request()),
1375                Message::Resume => messages.push(LanguageModelRequestMessage {
1376                    role: Role::User,
1377                    content: vec!["Continue where you left off".into()],
1378                    cache: false,
1379                }),
1380            }
1381        }
1382
1383        if let Some(message) = self.pending_message.as_ref() {
1384            messages.extend(message.to_request());
1385        }
1386
1387        if let Some(last_user_message) = messages
1388            .iter_mut()
1389            .rev()
1390            .find(|message| message.role == Role::User)
1391        {
1392            last_user_message.cache = true;
1393        }
1394
1395        messages
1396    }
1397
1398    pub fn to_markdown(&self) -> String {
1399        let mut markdown = String::new();
1400        for (ix, message) in self.messages.iter().enumerate() {
1401            if ix > 0 {
1402                markdown.push('\n');
1403            }
1404            markdown.push_str(&message.to_markdown());
1405        }
1406
1407        if let Some(message) = self.pending_message.as_ref() {
1408            markdown.push('\n');
1409            markdown.push_str(&message.to_markdown());
1410        }
1411
1412        markdown
1413    }
1414
1415    fn advance_prompt_id(&mut self) {
1416        self.prompt_id = PromptId::new();
1417    }
1418}
1419
/// The turn currently being executed, created by `run_turn`.
///
/// NOTE(review): cancelling a turn only emits a cancellation event and then drops
/// this value. Stopping the work itself presumably relies on dropping a gpui
/// `Task` cancelling its future — confirm. Fields drop in declaration order, so
/// `_task` is dropped before `event_stream`.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// It survives across multiple requests while the model makes tool calls and
    /// we run the tools and report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1429
1430impl RunningTurn {
1431    fn cancel(self) {
1432        log::debug!("Cancelling in progress turn");
1433        self.event_stream.send_canceled();
1434    }
1435}
1436
1437pub trait AgentTool
1438where
1439    Self: 'static + Sized,
1440{
1441    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1442    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1443
1444    fn name(&self) -> SharedString;
1445
1446    fn description(&self) -> SharedString {
1447        let schema = schemars::schema_for!(Self::Input);
1448        SharedString::new(
1449            schema
1450                .get("description")
1451                .and_then(|description| description.as_str())
1452                .unwrap_or_default(),
1453        )
1454    }
1455
1456    fn kind(&self) -> acp::ToolKind;
1457
1458    /// The initial tool title to display. Can be updated during the tool run.
1459    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1460
1461    /// Returns the JSON schema that describes the tool's input.
1462    fn input_schema(&self) -> Schema {
1463        schemars::schema_for!(Self::Input)
1464    }
1465
1466    /// Some tools rely on a provider for the underlying billing or other reasons.
1467    /// Allow the tool to check if they are compatible, or should be filtered out.
1468    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1469        true
1470    }
1471
1472    /// Runs the tool with the provided input.
1473    fn run(
1474        self: Arc<Self>,
1475        input: Self::Input,
1476        event_stream: ToolCallEventStream,
1477        cx: &mut App,
1478    ) -> Task<Result<Self::Output>>;
1479
1480    /// Emits events for a previous execution of the tool.
1481    fn replay(
1482        &self,
1483        _input: Self::Input,
1484        _output: Self::Output,
1485        _event_stream: ToolCallEventStream,
1486        _cx: &mut App,
1487    ) -> Result<()> {
1488        Ok(())
1489    }
1490
1491    fn erase(self) -> Arc<dyn AnyAgentTool> {
1492        Arc::new(Erased(Arc::new(self)))
1493    }
1494}
1495
/// Adapter that implements [`AnyAgentTool`] for a typed [`AgentTool`]; see
/// [`AgentTool::erase`].
pub struct Erased<T>(T);
1497
/// Result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content sent back to the model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON; recorded as the tool result's
    /// raw `output`.
    pub raw_output: serde_json::Value,
}
1502
/// Object-safe, type-erased counterpart of [`AgentTool`] that exchanges inputs
/// and outputs as JSON. Typed tools are converted with [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// Name under which the tool is registered and offered to the model.
    fn name(&self) -> SharedString;
    /// Description offered to the model alongside the tool.
    fn description(&self) -> SharedString;
    /// Kind reported for this tool's calls.
    fn kind(&self) -> acp::ToolKind;
    /// Title to display for a call with the given raw, possibly still-streaming, JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema for the tool's input, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered for the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Emits events for a previous execution, given its recorded JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1526
1527impl<T> AnyAgentTool for Erased<Arc<T>>
1528where
1529    T: AgentTool,
1530{
1531    fn name(&self) -> SharedString {
1532        self.0.name()
1533    }
1534
1535    fn description(&self) -> SharedString {
1536        self.0.description()
1537    }
1538
1539    fn kind(&self) -> agent_client_protocol::ToolKind {
1540        self.0.kind()
1541    }
1542
1543    fn initial_title(&self, input: serde_json::Value) -> SharedString {
1544        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1545        self.0.initial_title(parsed_input)
1546    }
1547
1548    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1549        let mut json = serde_json::to_value(self.0.input_schema())?;
1550        adapt_schema_to_format(&mut json, format)?;
1551        Ok(json)
1552    }
1553
1554    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1555        self.0.supported_provider(provider)
1556    }
1557
1558    fn run(
1559        self: Arc<Self>,
1560        input: serde_json::Value,
1561        event_stream: ToolCallEventStream,
1562        cx: &mut App,
1563    ) -> Task<Result<AgentToolOutput>> {
1564        cx.spawn(async move |cx| {
1565            let input = serde_json::from_value(input)?;
1566            let output = cx
1567                .update(|cx| self.0.clone().run(input, event_stream, cx))?
1568                .await?;
1569            let raw_output = serde_json::to_value(&output)?;
1570            Ok(AgentToolOutput {
1571                llm_output: output.into(),
1572                raw_output,
1573            })
1574        })
1575    }
1576
1577    fn replay(
1578        &self,
1579        input: serde_json::Value,
1580        output: serde_json::Value,
1581        event_stream: ToolCallEventStream,
1582        cx: &mut App,
1583    ) -> Result<()> {
1584        let input = serde_json::from_value(input)?;
1585        let output = serde_json::from_value(output)?;
1586        self.0.replay(input, output, event_stream, cx)
1587    }
1588}
1589
/// Wraps the sender half of a turn's event channel (created in `run_turn`).
/// It is cloned so that the per-tool-call event streams report through the
/// same channel.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1592
1593impl ThreadEventStream {
1594    fn send_user_message(&self, message: &UserMessage) {
1595        self.0
1596            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1597            .ok();
1598    }
1599
1600    fn send_text(&self, text: &str) {
1601        self.0
1602            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1603            .ok();
1604    }
1605
1606    fn send_thinking(&self, text: &str) {
1607        self.0
1608            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1609            .ok();
1610    }
1611
1612    fn send_tool_call(
1613        &self,
1614        id: &LanguageModelToolUseId,
1615        title: SharedString,
1616        kind: acp::ToolKind,
1617        input: serde_json::Value,
1618    ) {
1619        self.0
1620            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1621                id,
1622                title.to_string(),
1623                kind,
1624                input,
1625            ))))
1626            .ok();
1627    }
1628
1629    fn initial_tool_call(
1630        id: &LanguageModelToolUseId,
1631        title: String,
1632        kind: acp::ToolKind,
1633        input: serde_json::Value,
1634    ) -> acp::ToolCall {
1635        acp::ToolCall {
1636            id: acp::ToolCallId(id.to_string().into()),
1637            title,
1638            kind,
1639            status: acp::ToolCallStatus::Pending,
1640            content: vec![],
1641            locations: vec![],
1642            raw_input: Some(input),
1643            raw_output: None,
1644        }
1645    }
1646
1647    fn update_tool_call_fields(
1648        &self,
1649        tool_use_id: &LanguageModelToolUseId,
1650        fields: acp::ToolCallUpdateFields,
1651    ) {
1652        self.0
1653            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1654                acp::ToolCallUpdate {
1655                    id: acp::ToolCallId(tool_use_id.to_string().into()),
1656                    fields,
1657                }
1658                .into(),
1659            )))
1660            .ok();
1661    }
1662
1663    fn send_stop(&self, reason: StopReason) {
1664        match reason {
1665            StopReason::EndTurn => {
1666                self.0
1667                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1668                    .ok();
1669            }
1670            StopReason::MaxTokens => {
1671                self.0
1672                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1673                    .ok();
1674            }
1675            StopReason::Refusal => {
1676                self.0
1677                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1678                    .ok();
1679            }
1680            StopReason::ToolUse => {}
1681        }
1682    }
1683
1684    fn send_canceled(&self) {
1685        self.0
1686            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1687            .ok();
1688    }
1689
1690    fn send_error(&self, error: impl Into<anyhow::Error>) {
1691        self.0.unbounded_send(Err(error.into())).ok();
1692    }
1693}
1694
1695#[derive(Clone)]
1696pub struct ToolCallEventStream {
1697    tool_use_id: LanguageModelToolUseId,
1698    stream: ThreadEventStream,
1699    fs: Option<Arc<dyn Fs>>,
1700}
1701
1702impl ToolCallEventStream {
1703    #[cfg(test)]
1704    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1705        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1706
1707        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1708
1709        (stream, ToolCallEventStreamReceiver(events_rx))
1710    }
1711
1712    fn new(
1713        tool_use_id: LanguageModelToolUseId,
1714        stream: ThreadEventStream,
1715        fs: Option<Arc<dyn Fs>>,
1716    ) -> Self {
1717        Self {
1718            tool_use_id,
1719            stream,
1720            fs,
1721        }
1722    }
1723
1724    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1725        self.stream
1726            .update_tool_call_fields(&self.tool_use_id, fields);
1727    }
1728
1729    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1730        self.stream
1731            .0
1732            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1733                acp_thread::ToolCallUpdateDiff {
1734                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1735                    diff,
1736                }
1737                .into(),
1738            )))
1739            .ok();
1740    }
1741
1742    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1743        self.stream
1744            .0
1745            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1746                acp_thread::ToolCallUpdateTerminal {
1747                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1748                    terminal,
1749                }
1750                .into(),
1751            )))
1752            .ok();
1753    }
1754
1755    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1756        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1757            return Task::ready(Ok(()));
1758        }
1759
1760        let (response_tx, response_rx) = oneshot::channel();
1761        self.stream
1762            .0
1763            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1764                ToolCallAuthorization {
1765                    tool_call: acp::ToolCallUpdate {
1766                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1767                        fields: acp::ToolCallUpdateFields {
1768                            title: Some(title.into()),
1769                            ..Default::default()
1770                        },
1771                    },
1772                    options: vec![
1773                        acp::PermissionOption {
1774                            id: acp::PermissionOptionId("always_allow".into()),
1775                            name: "Always Allow".into(),
1776                            kind: acp::PermissionOptionKind::AllowAlways,
1777                        },
1778                        acp::PermissionOption {
1779                            id: acp::PermissionOptionId("allow".into()),
1780                            name: "Allow".into(),
1781                            kind: acp::PermissionOptionKind::AllowOnce,
1782                        },
1783                        acp::PermissionOption {
1784                            id: acp::PermissionOptionId("deny".into()),
1785                            name: "Deny".into(),
1786                            kind: acp::PermissionOptionKind::RejectOnce,
1787                        },
1788                    ],
1789                    response: response_tx,
1790                },
1791            )))
1792            .ok();
1793        let fs = self.fs.clone();
1794        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1795            "always_allow" => {
1796                if let Some(fs) = fs.clone() {
1797                    cx.update(|cx| {
1798                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1799                            settings.set_always_allow_tool_actions(true);
1800                        });
1801                    })?;
1802                }
1803
1804                Ok(())
1805            }
1806            "allow" => Ok(()),
1807            _ => Err(anyhow!("Permission to run tool denied by user")),
1808        })
1809    }
1810}
1811
#[cfg(test)]
/// Test-only receiving half paired with `ToolCallEventStream::test`.
///
/// Wraps the unbounded channel on which the stream's events (or errors) arrive.
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
1814
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next message and returns it if it is a tool call authorization request.
    ///
    /// # Panics
    /// Panics, printing what was received, if the next message is any other event,
    /// an error, or the end of the stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next message and returns the terminal entity if it is a terminal update.
    ///
    /// # Panics
    /// Panics, printing what was received, if the next message is any other event,
    /// an error, or the end of the stream.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1838
#[cfg(test)]
/// Lets tests use the wrapped channel receiver's methods directly.
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1847
#[cfg(test)]
/// Lets tests poll the wrapped channel receiver mutably.
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1854
1855impl From<&str> for UserMessageContent {
1856    fn from(text: &str) -> Self {
1857        Self::Text(text.into())
1858    }
1859}
1860
1861impl From<acp::ContentBlock> for UserMessageContent {
1862    fn from(value: acp::ContentBlock) -> Self {
1863        match value {
1864            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1865            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1866            acp::ContentBlock::Audio(_) => {
1867                // TODO
1868                Self::Text("[audio]".to_string())
1869            }
1870            acp::ContentBlock::ResourceLink(resource_link) => {
1871                match MentionUri::parse(&resource_link.uri) {
1872                    Ok(uri) => Self::Mention {
1873                        uri,
1874                        content: String::new(),
1875                    },
1876                    Err(err) => {
1877                        log::error!("Failed to parse mention link: {}", err);
1878                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1879                    }
1880                }
1881            }
1882            acp::ContentBlock::Resource(resource) => match resource.resource {
1883                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1884                    match MentionUri::parse(&resource.uri) {
1885                        Ok(uri) => Self::Mention {
1886                            uri,
1887                            content: resource.text,
1888                        },
1889                        Err(err) => {
1890                            log::error!("Failed to parse mention link: {}", err);
1891                            Self::Text(
1892                                MarkdownCodeBlock {
1893                                    tag: &resource.uri,
1894                                    text: &resource.text,
1895                                }
1896                                .to_string(),
1897                            )
1898                        }
1899                    }
1900                }
1901                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1902                    // TODO
1903                    Self::Text("[blob]".to_string())
1904                }
1905            },
1906        }
1907    }
1908}
1909
1910impl From<UserMessageContent> for acp::ContentBlock {
1911    fn from(content: UserMessageContent) -> Self {
1912        match content {
1913            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
1914                text,
1915                annotations: None,
1916            }),
1917            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
1918                data: image.source.to_string(),
1919                mime_type: "image/png".to_string(),
1920                annotations: None,
1921                uri: None,
1922            }),
1923            UserMessageContent::Mention { uri, content } => {
1924                todo!()
1925            }
1926        }
1927    }
1928}
1929
1930fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1931    LanguageModelImage {
1932        source: image_content.data.into(),
1933        // TODO: make this optional?
1934        size: gpui::Size::new(0.into(), 0.into()),
1935    }
1936}