// thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use action_log::ActionLog;
  12use agent_settings::{
  13    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
  14    SUMMARIZE_THREAD_PROMPT,
  15};
  16use anyhow::{Result, anyhow};
  17use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
  18use chrono::{DateTime, Utc};
  19use client::{ModelRequestUsage, RequestUsage};
  20use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
  21use collections::HashMap;
  22use futures::{FutureExt, StreamExt as _, future::Shared};
  23use git::repository::DiffType;
  24use gpui::{
  25    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  26    WeakEntity, Window,
  27};
  28use http_client::StatusCode;
  29use language_model::{
  30    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  31    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
  32    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  33    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
  34    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
  35    TokenUsage,
  36};
  37use postage::stream::Stream as _;
  38use project::{
  39    Project,
  40    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  41};
  42use prompt_store::{ModelContext, PromptBuilder};
  43use schemars::JsonSchema;
  44use serde::{Deserialize, Serialize};
  45use settings::Settings;
  46use std::{
  47    io::Write,
  48    ops::Range,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use thiserror::Error;
  53use util::{ResultExt as _, post_inc};
  54use uuid::Uuid;
  55
/// Maximum number of times a failed completion request is re-attempted
/// before giving up.
const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Starting delay between retry attempts; backoff strategies scale it up.
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
  58
/// Policy for re-attempting a failed completion request.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// The delay grows from `initial_delay` on each attempt, up to
    /// `max_attempts` attempts (growth factor applied at the call site —
    /// not visible in this chunk).
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// The same `delay` is used between every attempt, up to `max_attempts`.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
  70
  71#[derive(
  72    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  73)]
  74pub struct ThreadId(Arc<str>);
  75
  76impl ThreadId {
  77    pub fn new() -> Self {
  78        Self(Uuid::new_v4().to_string().into())
  79    }
  80}
  81
  82impl std::fmt::Display for ThreadId {
  83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  84        write!(f, "{}", self.0)
  85    }
  86}
  87
  88impl From<&str> for ThreadId {
  89    fn from(value: &str) -> Self {
  90        Self(value.into())
  91    }
  92}
  93
  94/// The ID of the user prompt that initiated a request.
  95///
  96/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  97#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  98pub struct PromptId(Arc<str>);
  99
 100impl PromptId {
 101    pub fn new() -> Self {
 102        Self(Uuid::new_v4().to_string().into())
 103    }
 104}
 105
 106impl std::fmt::Display for PromptId {
 107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 108        write!(f, "{}", self.0)
 109    }
 110}
 111
 112#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 113pub struct MessageId(pub usize);
 114
 115impl MessageId {
 116    fn post_inc(&mut self) -> Self {
 117        Self(post_inc(&mut self.0))
 118    }
 119
 120    pub fn as_usize(&self) -> usize {
 121        self.0
 122    }
 123}
 124
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Span of the crease — presumably offsets into the message text; confirm against callers.
    pub range: Range<usize>,
    /// Icon shown in the crease placeholder.
    pub icon_path: SharedString,
    /// Label shown in the crease placeholder.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 134
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    /// Author of the message (user, assistant, …).
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases used to re-create context pills when re-editing this message.
    pub creases: Vec<MessageCrease>,
    /// Whether the message is hidden (excluded from turn-end detection; see `is_turn_end`).
    pub is_hidden: bool,
    /// Whether the message exists only for the UI; such messages are never persisted.
    pub ui_only: bool,
}
 146
 147impl Message {
 148    /// Returns whether the message contains any meaningful text that should be displayed
 149    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 150    pub fn should_display_content(&self) -> bool {
 151        self.segments.iter().all(|segment| segment.should_display())
 152    }
 153
 154    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 155        if let Some(MessageSegment::Thinking {
 156            text: segment,
 157            signature: current_signature,
 158        }) = self.segments.last_mut()
 159        {
 160            if let Some(signature) = signature {
 161                *current_signature = Some(signature);
 162            }
 163            segment.push_str(text);
 164        } else {
 165            self.segments.push(MessageSegment::Thinking {
 166                text: text.to_string(),
 167                signature,
 168            });
 169        }
 170    }
 171
 172    pub fn push_redacted_thinking(&mut self, data: String) {
 173        self.segments.push(MessageSegment::RedactedThinking(data));
 174    }
 175
 176    pub fn push_text(&mut self, text: &str) {
 177        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 178            segment.push_str(text);
 179        } else {
 180            self.segments.push(MessageSegment::Text(text.to_string()));
 181        }
 182    }
 183
 184    pub fn to_string(&self) -> String {
 185        let mut result = String::new();
 186
 187        if !self.loaded_context.text.is_empty() {
 188            result.push_str(&self.loaded_context.text);
 189        }
 190
 191        for segment in &self.segments {
 192            match segment {
 193                MessageSegment::Text(text) => result.push_str(text),
 194                MessageSegment::Thinking { text, .. } => {
 195                    result.push_str("<think>\n");
 196                    result.push_str(text);
 197                    result.push_str("\n</think>");
 198                }
 199                MessageSegment::RedactedThinking(_) => {}
 200            }
 201        }
 202
 203        result
 204    }
 205}
 206
 207#[derive(Debug, Clone, PartialEq, Eq)]
 208pub enum MessageSegment {
 209    Text(String),
 210    Thinking {
 211        text: String,
 212        signature: Option<String>,
 213    },
 214    RedactedThinking(String),
 215}
 216
 217impl MessageSegment {
 218    pub fn should_display(&self) -> bool {
 219        match self {
 220            Self::Text(text) => text.is_empty(),
 221            Self::Thinking { text, .. } => text.is_empty(),
 222            Self::RedactedThinking(_) => false,
 223        }
 224    }
 225
 226    pub fn text(&self) -> Option<&str> {
 227        match self {
 228            MessageSegment::Text(text) => Some(text),
 229            _ => None,
 230        }
 231    }
 232}
 233
/// Snapshot of the project's worktrees taken at a point in time and persisted
/// alongside the thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was captured.
    pub timestamp: DateTime<Utc>,
}
 240
/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata for the worktree — presumably `None` when it is not a
    /// git repository; confirm against the snapshot producer.
    pub git_state: Option<GitState>,
}
 246
/// Captured git repository state for a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Textual diff of uncommitted changes — TODO confirm diff scope (see `DiffType`).
    pub diff: Option<String>,
}
 254
/// Associates a git checkpoint with the message that was current when it was taken.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    /// The message this checkpoint belongs to (restoring truncates back to it).
    message_id: MessageId,
    /// The underlying git store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
}
 260
/// Thumbs-up / thumbs-down feedback a user can leave on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 266
 267pub enum LastRestoreCheckpoint {
 268    Pending {
 269        message_id: MessageId,
 270    },
 271    Error {
 272        message_id: MessageId,
 273        error: String,
 274    },
 275}
 276
 277impl LastRestoreCheckpoint {
 278    pub fn message_id(&self) -> MessageId {
 279        match self {
 280            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 281            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 282        }
 283    }
 284}
 285
/// Progress of the detailed (long-form) thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in progress — `message_id` presumably marks the last
    /// message covered; confirm against the producer.
    Generating {
        message_id: MessageId,
    },
    /// Generation finished, producing `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 298
 299impl DetailedSummaryState {
 300    fn text(&self) -> Option<SharedString> {
 301        if let Self::Generated { text, .. } = self {
 302            Some(text.clone())
 303        } else {
 304            None
 305        }
 306    }
 307}
 308
 309#[derive(Default, Debug)]
 310pub struct TotalTokenUsage {
 311    pub total: u64,
 312    pub max: u64,
 313}
 314
 315impl TotalTokenUsage {
 316    pub fn ratio(&self) -> TokenUsageRatio {
 317        #[cfg(debug_assertions)]
 318        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 319            .unwrap_or("0.8".to_string())
 320            .parse()
 321            .unwrap();
 322        #[cfg(not(debug_assertions))]
 323        let warning_threshold: f32 = 0.8;
 324
 325        // When the maximum is unknown because there is no selected model,
 326        // avoid showing the token limit warning.
 327        if self.max == 0 {
 328            TokenUsageRatio::Normal
 329        } else if self.total >= self.max {
 330            TokenUsageRatio::Exceeded
 331        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 332            TokenUsageRatio::Warning
 333        } else {
 334            TokenUsageRatio::Normal
 335        }
 336    }
 337
 338    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 339        TotalTokenUsage {
 340            total: self.total + tokens,
 341            max: self.max,
 342        }
 343    }
 344}
 345
 346#[derive(Debug, Default, PartialEq, Eq)]
 347pub enum TokenUsageRatio {
 348    #[default]
 349    Normal,
 350    Warning,
 351    Exceeded,
 352}
 353
/// State of a pending completion relative to the provider's request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting in the provider's queue at `position` —
    /// TODO confirm whether position is 0- or 1-based.
    Queued { position: usize },
    /// The provider has started processing the request.
    Started,
}
 360
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last-modified time; bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Short, single-line summary of the thread.
    summary: ThreadSummary,
    /// In-flight task generating `summary`.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// All messages, kept in ascending `MessageId` order (binary-searched by `message`).
    messages: Vec<Message>,
    /// ID to assign to the next inserted message.
    next_message_id: MessageId,
    /// ID of the most recent user prompt submission.
    last_prompt_id: PromptId,
    /// Shared system-prompt / project context.
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// Completions currently in flight; non-empty implies `is_generating()`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    /// Tracks tool invocations and their results across messages.
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Outcome of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once generation stops.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project/git snapshot captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage summed over the thread's lifetime.
    cumulative_token_usage: TokenUsage,
    /// Set when the last request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    /// Thread-level user feedback, if given.
    feedback: Option<ThreadFeedback>,
    /// Present while a failed completion is being retried.
    retry_state: Option<RetryState>,
    /// Per-message user feedback.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// When the last streaming chunk arrived; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Hook observing each request and its raw response events — TODO confirm callers.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns; initialized to `u32::MAX`.
    remaining_turns: u32,
    /// Currently selected model, if any.
    configured_model: Option<ConfiguredModel>,
    /// Active agent profile.
    profile: AgentProfile,
    /// Model and intent of the last errored request — TODO confirm usage.
    last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
}
 402
/// Tracks an in-progress retry sequence for a failed completion request.
#[derive(Clone, Debug)]
struct RetryState {
    /// Attempts made so far — TODO confirm whether 0- or 1-based.
    attempt: u8,
    /// Cap from the active [`RetryStrategy`].
    max_attempts: u8,
    /// Intent of the request being retried.
    intent: CompletionIntent,
}
 409
 410#[derive(Clone, Debug, PartialEq, Eq)]
 411pub enum ThreadSummary {
 412    Pending,
 413    Generating,
 414    Ready(SharedString),
 415    Error,
 416}
 417
 418impl ThreadSummary {
 419    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 420
 421    pub fn or_default(&self) -> SharedString {
 422        self.unwrap_or(Self::DEFAULT)
 423    }
 424
 425    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 426        self.ready().unwrap_or_else(|| message.into())
 427    }
 428
 429    pub fn ready(&self) -> Option<SharedString> {
 430        match self {
 431            ThreadSummary::Ready(summary) => Some(summary.clone()),
 432            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 433        }
 434    }
 435}
 436
/// Recorded when a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 444
 445impl Thread {
    /// Creates an empty thread for `project`, picking up the global default
    /// model, agent profile, and preferred completion mode from settings.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project state asynchronously at creation time so
            // later checkpoints have something to compare against.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model: configured_model.clone(),
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 501
    /// Reconstructs a thread from its serialized form, re-resolving the model,
    /// profile, and completion mode (falling back to global defaults when the
    /// persisted values are absent or no longer available).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the persisted model selection; fall back to the registry's
        // default when it is missing or cannot be resolved anymore.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            // Rehydrate messages; context handles and crease contexts are not
            // persisted, so they come back empty/None.
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_error_context: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 625
    /// Installs a hook invoked with every completion request together with the
    /// events received in response (errors stringified). Intended for
    /// observing raw traffic — TODO confirm callers.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 633
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// The agent profile currently in effect for this thread.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
 641
 642    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
 643        if &id != self.profile.id() {
 644            self.profile = AgentProfile::new(id, self.tools.clone());
 645            cx.emit(ThreadEvent::ProfileChanged);
 646        }
 647    }
 648
    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The shared system-prompt/project context this thread was created with.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 668
 669    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 670        if self.configured_model.is_none() {
 671            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 672        }
 673        self.configured_model.clone()
 674    }
 675
    /// The currently selected model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the selected model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the thread's short summary.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
 688
 689    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 690        let current_summary = match &self.summary {
 691            ThreadSummary::Pending | ThreadSummary::Generating => return,
 692            ThreadSummary::Ready(summary) => summary,
 693            ThreadSummary::Error => &ThreadSummary::DEFAULT,
 694        };
 695
 696        let mut new_summary = new_summary.into();
 697
 698        if new_summary.is_empty() {
 699            new_summary = ThreadSummary::DEFAULT;
 700        }
 701
 702        if current_summary != &new_summary {
 703            self.summary = ThreadSummary::Ready(new_summary);
 704            cx.emit(ThreadEvent::SummaryChanged);
 705        }
 706    }
 707
    /// The completion mode used for this thread's requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 715
 716    pub fn message(&self, id: MessageId) -> Option<&Message> {
 717        let index = self
 718            .messages
 719            .binary_search_by(|message| message.id.cmp(&id))
 720            .ok()?;
 721
 722        self.messages.get(index)
 723    }
 724
    /// All messages in the thread, in ascending ID order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): this relies on `last_received_chunk_at` being reset to
    /// `None` when generation ends; that reset is not visible in this chunk —
    /// confirm against the completion-handling code.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // More than 250ms since the last chunk counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (see `is_generation_stale`).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
 745
    /// Queue position of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 755
 756    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 757        self.tool_use
 758            .pending_tool_uses()
 759            .into_iter()
 760            .find(|tool_use| &tool_use.id == id)
 761    }
 762
    /// Pending tool uses that are waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still pending (errored ones included).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint taken for the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 777
    /// Restores the project's files to `checkpoint`'s git state and, on
    /// success, truncates the thread back to the checkpoint's message.
    /// The returned task resolves with the underlying git restore result.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure around so it can be surfaced to the user.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 813
 814    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 815        let pending_checkpoint = if self.is_generating() {
 816            return;
 817        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 818            checkpoint
 819        } else {
 820            return;
 821        };
 822
 823        self.finalize_checkpoint(pending_checkpoint, cx);
 824    }
 825
    /// Takes a fresh git checkpoint and compares it against
    /// `pending_checkpoint`: when the repository is unchanged the pending
    /// checkpoint is carried forward; otherwise it is recorded and a new
    /// pending checkpoint is started at the next message.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated as "changed" (unwrap_or(false)).
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                this.update(cx, |this, cx| {
                    this.pending_checkpoint = if equal {
                        Some(pending_checkpoint)
                    } else {
                        // The repo changed during the turn: record the old
                        // checkpoint and start a new one at the next message.
                        this.insert_checkpoint(pending_checkpoint, cx);
                        Some(ThreadCheckpoint {
                            message_id: this.next_message_id,
                            git_checkpoint: final_checkpoint,
                        })
                    }
                })?;

                Ok(())
            }
            // Taking the final checkpoint failed; still record the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 866
    /// Records `checkpoint` under its message ID and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Outcome of the most recent checkpoint restore, if one happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 877
 878    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 879        let Some(message_ix) = self
 880            .messages
 881            .iter()
 882            .rposition(|message| message.id == message_id)
 883        else {
 884            return;
 885        };
 886        for deleted_message in self.messages.drain(message_ix..) {
 887            self.checkpoints_by_message.remove(&deleted_message.id);
 888        }
 889        cx.notify();
 890    }
 891
 892    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 893        self.messages
 894            .iter()
 895            .find(|message| message.id == id)
 896            .into_iter()
 897            .flat_map(|message| message.loaded_context.contexts.iter())
 898    }
 899
    /// Returns whether the message at index `ix` ends a turn.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        // While no generation is in flight, the final message always closes
        // the current turn.
        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        // Mid-conversation, only an assistant message can end a turn.
        if message.role != Role::Assistant {
            return false;
        }

        // A turn ends when the following message is a visible user message.
        // NOTE(review): `self.message(message.id)` re-resolves the message we
        // already obtained via `get(ix + 1)` — presumably equivalent; confirm
        // against `message()`'s definition before simplifying.
        self.messages
            .get(ix + 1)
            .and_then(|message| {
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            })
            .unwrap_or(false)
    }
 925
    /// Returns whether the tool-use limit was flagged during the last completion.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 929
 930    /// Returns whether all of the tool uses have finished running.
 931    pub fn all_tools_finished(&self) -> bool {
 932        // If the only pending tool uses left are the ones with errors, then
 933        // that means that we've finished running all of the pending tools.
 934        self.tool_use
 935            .pending_tool_uses()
 936            .iter()
 937            .all(|pending_tool_use| pending_tool_use.status.is_error())
 938    }
 939
 940    /// Returns whether any pending tool uses may perform edits
 941    pub fn has_pending_edit_tool_uses(&self) -> bool {
 942        self.tool_use
 943            .pending_tool_uses()
 944            .iter()
 945            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 946            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 947    }
 948
    /// Returns the tool uses requested by the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, &self.project, cx)
    }
 952
    /// Returns the tool results recorded for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 959
    /// Looks up the result of a single tool use by its id.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 963
 964    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 965        match &self.tool_use.tool_result(id)?.content {
 966            LanguageModelToolResultContent::Text(text) => Some(text),
 967            LanguageModelToolResultContent::Image(_) => {
 968                // TODO: We should display image
 969                None
 970            }
 971        }
 972    }
 973
    /// Returns a clone of the UI card associated with a tool use, if one exists.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 977
 978    /// Return tools that are both enabled and supported by the model
 979    pub fn available_tools(
 980        &self,
 981        cx: &App,
 982        model: Arc<dyn LanguageModel>,
 983    ) -> Vec<LanguageModelRequestTool> {
 984        if model.supports_tools() {
 985            self.profile
 986                .enabled_tools(cx)
 987                .into_iter()
 988                .filter_map(|(name, tool)| {
 989                    // Skip tools that cannot be supported
 990                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 991                    Some(LanguageModelRequestTool {
 992                        name: name.into(),
 993                        description: tool.description(),
 994                        input_schema,
 995                    })
 996                })
 997                .collect()
 998        } else {
 999            Vec::default()
1000        }
1001    }
1002
1003    pub fn insert_user_message(
1004        &mut self,
1005        text: impl Into<String>,
1006        loaded_context: ContextLoadResult,
1007        git_checkpoint: Option<GitStoreCheckpoint>,
1008        creases: Vec<MessageCrease>,
1009        cx: &mut Context<Self>,
1010    ) -> MessageId {
1011        if !loaded_context.referenced_buffers.is_empty() {
1012            self.action_log.update(cx, |log, cx| {
1013                for buffer in loaded_context.referenced_buffers {
1014                    log.buffer_read(buffer, cx);
1015                }
1016            });
1017        }
1018
1019        let message_id = self.insert_message(
1020            Role::User,
1021            vec![MessageSegment::Text(text.into())],
1022            loaded_context.loaded_context,
1023            creases,
1024            false,
1025            cx,
1026        );
1027
1028        if let Some(git_checkpoint) = git_checkpoint {
1029            self.pending_checkpoint = Some(ThreadCheckpoint {
1030                message_id,
1031                git_checkpoint,
1032            });
1033        }
1034
1035        message_id
1036    }
1037
1038    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1039        let id = self.insert_message(
1040            Role::User,
1041            vec![MessageSegment::Text("Continue where you left off".into())],
1042            LoadedContext::default(),
1043            vec![],
1044            true,
1045            cx,
1046        );
1047        self.pending_checkpoint = None;
1048
1049        id
1050    }
1051
1052    pub fn insert_assistant_message(
1053        &mut self,
1054        segments: Vec<MessageSegment>,
1055        cx: &mut Context<Self>,
1056    ) -> MessageId {
1057        self.insert_message(
1058            Role::Assistant,
1059            segments,
1060            LoadedContext::default(),
1061            Vec::new(),
1062            false,
1063            cx,
1064        )
1065    }
1066
1067    pub fn insert_message(
1068        &mut self,
1069        role: Role,
1070        segments: Vec<MessageSegment>,
1071        loaded_context: LoadedContext,
1072        creases: Vec<MessageCrease>,
1073        is_hidden: bool,
1074        cx: &mut Context<Self>,
1075    ) -> MessageId {
1076        let id = self.next_message_id.post_inc();
1077        self.messages.push(Message {
1078            id,
1079            role,
1080            segments,
1081            loaded_context,
1082            creases,
1083            is_hidden,
1084            ui_only: false,
1085        });
1086        self.touch_updated_at();
1087        cx.emit(ThreadEvent::MessageAdded(id));
1088        id
1089    }
1090
1091    pub fn edit_message(
1092        &mut self,
1093        id: MessageId,
1094        new_role: Role,
1095        new_segments: Vec<MessageSegment>,
1096        creases: Vec<MessageCrease>,
1097        loaded_context: Option<LoadedContext>,
1098        checkpoint: Option<GitStoreCheckpoint>,
1099        cx: &mut Context<Self>,
1100    ) -> bool {
1101        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1102            return false;
1103        };
1104        message.role = new_role;
1105        message.segments = new_segments;
1106        message.creases = creases;
1107        if let Some(context) = loaded_context {
1108            message.loaded_context = context;
1109        }
1110        if let Some(git_checkpoint) = checkpoint {
1111            self.checkpoints_by_message.insert(
1112                id,
1113                ThreadCheckpoint {
1114                    message_id: id,
1115                    git_checkpoint,
1116                },
1117            );
1118        }
1119        self.touch_updated_at();
1120        cx.emit(ThreadEvent::MessageEdited(id));
1121        true
1122    }
1123
1124    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1125        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1126            return false;
1127        };
1128        self.messages.remove(index);
1129        self.touch_updated_at();
1130        cx.emit(ThreadEvent::MessageDeleted(id));
1131        true
1132    }
1133
1134    /// Returns the representation of this [`Thread`] in a textual form.
1135    ///
1136    /// This is the representation we use when attaching a thread as context to another thread.
1137    pub fn text(&self) -> String {
1138        let mut text = String::new();
1139
1140        for message in &self.messages {
1141            text.push_str(match message.role {
1142                language_model::Role::User => "User:",
1143                language_model::Role::Assistant => "Agent:",
1144                language_model::Role::System => "System:",
1145            });
1146            text.push('\n');
1147
1148            for segment in &message.segments {
1149                match segment {
1150                    MessageSegment::Text(content) => text.push_str(content),
1151                    MessageSegment::Thinking { text: content, .. } => {
1152                        text.push_str(&format!("<think>{}</think>", content))
1153                    }
1154                    MessageSegment::RedactedThinking(_) => {}
1155                }
1156            }
1157            text.push('\n');
1158        }
1159
1160        text
1161    }
1162
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a [`Task`] because the initial project snapshot is produced
    /// asynchronously and must be awaited before the thread can be written out.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // The snapshot future is cloned so the spawned task can await it
        // independently of the thread entity.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // Messages that exist for display purposes only are never persisted.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and their results are stored alongside the
                        // message that requested them.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Provider and model ids are stored so the thread can be
                // reopened with the same configuration.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1249
    /// Returns how many turns the agent may still take before sending stops.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1253
    /// Sets the number of turns the agent may take; decremented on each send.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1257
1258    pub fn send_to_model(
1259        &mut self,
1260        model: Arc<dyn LanguageModel>,
1261        intent: CompletionIntent,
1262        window: Option<AnyWindowHandle>,
1263        cx: &mut Context<Self>,
1264    ) {
1265        if self.remaining_turns == 0 {
1266            return;
1267        }
1268
1269        self.remaining_turns -= 1;
1270
1271        self.flush_notifications(model.clone(), intent, cx);
1272
1273        let _checkpoint = self.finalize_pending_checkpoint(cx);
1274        self.stream_completion(
1275            self.to_completion_request(model.clone(), intent, cx),
1276            model,
1277            intent,
1278            window,
1279            cx,
1280        );
1281    }
1282
1283    pub fn retry_last_completion(
1284        &mut self,
1285        window: Option<AnyWindowHandle>,
1286        cx: &mut Context<Self>,
1287    ) {
1288        // Clear any existing error state
1289        self.retry_state = None;
1290
1291        // Use the last error context if available, otherwise fall back to configured model
1292        let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() {
1293            (model, intent)
1294        } else if let Some(configured_model) = self.configured_model.as_ref() {
1295            let model = configured_model.model.clone();
1296            let intent = if self.has_pending_tool_uses() {
1297                CompletionIntent::ToolResults
1298            } else {
1299                CompletionIntent::UserPrompt
1300            };
1301            (model, intent)
1302        } else if let Some(configured_model) = self.get_or_init_configured_model(cx) {
1303            let model = configured_model.model.clone();
1304            let intent = if self.has_pending_tool_uses() {
1305                CompletionIntent::ToolResults
1306            } else {
1307                CompletionIntent::UserPrompt
1308            };
1309            (model, intent)
1310        } else {
1311            return;
1312        };
1313
1314        self.send_to_model(model, intent, window, cx);
1315    }
1316
    /// Switches the thread into burn mode and retries the last completion
    /// with the new mode in effect.
    pub fn enable_burn_mode_and_retry(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.completion_mode = CompletionMode::Burn;
        // NOTE(review): emits ProfileChanged although only the completion mode
        // changed — presumably listeners refresh their UI on this event; confirm.
        cx.emit(ThreadEvent::ProfileChanged);
        self.retry_last_completion(window, cx);
    }
1326
1327    pub fn used_tools_since_last_user_message(&self) -> bool {
1328        for message in self.messages.iter().rev() {
1329            if self.tool_use.message_has_tool_results(message.id) {
1330                return true;
1331            } else if message.role == Role::User {
1332                return false;
1333            }
1334        }
1335
1336        false
1337    }
1338
    /// Builds the [`LanguageModelRequest`] for a completion over this thread:
    /// system prompt, prior messages (with their tool uses and results),
    /// available tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        // Resolve tools up front: the system prompt needs their names.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt requires the project context; surface an error to
        // the UI if it isn't ready yet or fails to render, and continue
        // building the request without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching; the
        // cache flag is applied once, after the loop (see the link below).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool results are sent back as a separate User message
            // that follows the assistant's tool uses. A message with a tool
            // use that is still pending must not be marked cacheable.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode only applies when the model supports it.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1503
1504    fn to_summarize_request(
1505        &self,
1506        model: &Arc<dyn LanguageModel>,
1507        intent: CompletionIntent,
1508        added_user_message: String,
1509        cx: &App,
1510    ) -> LanguageModelRequest {
1511        let mut request = LanguageModelRequest {
1512            thread_id: None,
1513            prompt_id: None,
1514            intent: Some(intent),
1515            mode: None,
1516            messages: vec![],
1517            tools: Vec::new(),
1518            tool_choice: None,
1519            stop: Vec::new(),
1520            temperature: AgentSettings::temperature_for_model(model, cx),
1521            thinking_allowed: false,
1522        };
1523
1524        for message in &self.messages {
1525            let mut request_message = LanguageModelRequestMessage {
1526                role: message.role,
1527                content: Vec::new(),
1528                cache: false,
1529            };
1530
1531            for segment in &message.segments {
1532                match segment {
1533                    MessageSegment::Text(text) => request_message
1534                        .content
1535                        .push(MessageContent::Text(text.clone())),
1536                    MessageSegment::Thinking { .. } => {}
1537                    MessageSegment::RedactedThinking(_) => {}
1538                }
1539            }
1540
1541            if request_message.content.is_empty() {
1542                continue;
1543            }
1544
1545            request.messages.push(request_message);
1546        }
1547
1548        request.messages.push(LanguageModelRequestMessage {
1549            role: Role::User,
1550            content: vec![MessageContent::Text(added_user_message)],
1551            cache: false,
1552        });
1553
1554        request
1555    }
1556
1557    /// Insert auto-generated notifications (if any) to the thread
1558    fn flush_notifications(
1559        &mut self,
1560        model: Arc<dyn LanguageModel>,
1561        intent: CompletionIntent,
1562        cx: &mut Context<Self>,
1563    ) {
1564        match intent {
1565            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1566                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1567                    cx.emit(ThreadEvent::ToolFinished {
1568                        tool_use_id: pending_tool_use.id.clone(),
1569                        pending_tool_use: Some(pending_tool_use),
1570                    });
1571                }
1572            }
1573            CompletionIntent::ThreadSummarization
1574            | CompletionIntent::ThreadContextSummarization
1575            | CompletionIntent::CreateFile
1576            | CompletionIntent::EditFile
1577            | CompletionIntent::InlineAssist
1578            | CompletionIntent::TerminalInlineAssist
1579            | CompletionIntent::GenerateGitCommitMessage => {}
1580        };
1581    }
1582
1583    fn attach_tracked_files_state(
1584        &mut self,
1585        model: Arc<dyn LanguageModel>,
1586        cx: &mut App,
1587    ) -> Option<PendingToolUse> {
1588        // Represent notification as a simulated `project_notifications` tool call
1589        let tool_name = Arc::from("project_notifications");
1590        let tool = self.tools.read(cx).tool(&tool_name, cx)?;
1591
1592        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
1593            return None;
1594        }
1595
1596        if self
1597            .action_log
1598            .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
1599        {
1600            return None;
1601        }
1602
1603        let input = serde_json::json!({});
1604        let request = Arc::new(LanguageModelRequest::default()); // unused
1605        let window = None;
1606        let tool_result = tool.run(
1607            input,
1608            request,
1609            self.project.clone(),
1610            self.action_log.clone(),
1611            model.clone(),
1612            window,
1613            cx,
1614        );
1615
1616        let tool_use_id =
1617            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
1618
1619        let tool_use = LanguageModelToolUse {
1620            id: tool_use_id.clone(),
1621            name: tool_name.clone(),
1622            raw_input: "{}".to_string(),
1623            input: serde_json::json!({}),
1624            is_input_complete: true,
1625        };
1626
1627        let tool_output = cx.background_executor().block(tool_result.output);
1628
1629        // Attach a project_notification tool call to the latest existing
1630        // Assistant message. We cannot create a new Assistant message
1631        // because thinking models require a `thinking` block that we
1632        // cannot mock. We cannot send a notification as a normal
1633        // (non-tool-use) User message because this distracts Agent
1634        // too much.
1635        let tool_message_id = self
1636            .messages
1637            .iter()
1638            .enumerate()
1639            .rfind(|(_, message)| message.role == Role::Assistant)
1640            .map(|(_, message)| message.id)?;
1641
1642        let tool_use_metadata = ToolUseMetadata {
1643            model: model.clone(),
1644            thread_id: self.id.clone(),
1645            prompt_id: self.last_prompt_id.clone(),
1646        };
1647
1648        self.tool_use
1649            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
1650
1651        self.tool_use.insert_tool_output(
1652            tool_use_id.clone(),
1653            tool_name,
1654            tool_output,
1655            self.configured_model.as_ref(),
1656            self.completion_mode,
1657        )
1658    }
1659
    /// Streams a completion from `model` for `request`, applying events to this
    /// thread as they arrive.
    ///
    /// Streamed text/thinking chunks are appended to the trailing assistant
    /// message (creating one when needed), tool uses are registered with
    /// `self.tool_use`, and token-usage / queue-status updates are applied as
    /// they stream in. When the stream finishes: on `StopReason::ToolUse` the
    /// pending tools are run; on `StopReason::Refusal` the refused turn is
    /// deleted; on error the failure is either surfaced via
    /// `ThreadEvent::ShowError` or retried per `get_retry_strategy`. The
    /// spawned task is tracked in `self.pending_completions` until it settles.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        // Each completion gets a fresh id so we can find it (status updates)
        // and remove it (on completion) in `pending_completions`.
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request and its response events when a callback is
        // installed to consume them.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let completion_mode = request
            .mode
            .unwrap_or(cloud_llm_client::CompletionMode::Normal);

        // Start the received-chunk clock; it is cleared again once the event
        // stream is exhausted below.
        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, cx);
            // Snapshot usage before the request so we can report the delta in
            // telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message created for this response, if any yet.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // add only the delta since the previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message; create an empty one if the model
                                // emitted a tool use before any text.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only emit a streaming event while the tool
                                // input is still partial.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some(tool_use.input.clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        Ok(())
                    })??;

                    // Yield so other tasks can make progress between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            // Drive the stream to completion, then handle the final result.
            let result = stream_completion.await;
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        // Walk backwards collecting messages up
                                        // to (and including) the user message
                                        // that started the refused turn.
                                        let mut messages_to_remove = Vec::new();

                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                    && prev_message.role == Role::Assistant {
                                                        break;
                                                    }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                match &completion_error {
                                    LanguageModelCompletionError::PromptTooLarge {
                                        tokens, ..
                                    } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    _ => {
                                        if let Some(retry_strategy) =
                                            Thread::get_retry_strategy(completion_error)
                                        {
                                            log::info!(
                                                "Retrying with {:?} for language model completion error {:?}",
                                                retry_strategy,
                                                completion_error
                                            );

                                            retry_scheduled = thread
                                                .handle_retryable_error_with_delay(
                                                    completion_error,
                                                    Some(retry_strategy),
                                                    model.clone(),
                                                    intent,
                                                    window,
                                                    cx,
                                                );
                                        }
                                    }
                                }
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    // Suppress the Stopped event while a retry is pending; the
                    // retried completion will produce its own terminal event.
                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    // Report the token usage accumulated by this completion
                    // (delta from the snapshot taken before the request).
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2101
2102    pub fn summarize(&mut self, cx: &mut Context<Self>) {
2103        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2104            println!("No thread summary model");
2105            return;
2106        };
2107
2108        if !model.provider.is_authenticated(cx) {
2109            return;
2110        }
2111
2112        let request = self.to_summarize_request(
2113            &model.model,
2114            CompletionIntent::ThreadSummarization,
2115            SUMMARIZE_THREAD_PROMPT.into(),
2116            cx,
2117        );
2118
2119        self.summary = ThreadSummary::Generating;
2120
2121        self.pending_summary = cx.spawn(async move |this, cx| {
2122            let result = async {
2123                let mut messages = model.model.stream_completion(request, cx).await?;
2124
2125                let mut new_summary = String::new();
2126                while let Some(event) = messages.next().await {
2127                    let Ok(event) = event else {
2128                        continue;
2129                    };
2130                    let text = match event {
2131                        LanguageModelCompletionEvent::Text(text) => text,
2132                        LanguageModelCompletionEvent::StatusUpdate(
2133                            CompletionRequestStatus::UsageUpdated { amount, limit },
2134                        ) => {
2135                            this.update(cx, |thread, cx| {
2136                                thread.update_model_request_usage(amount as u32, limit, cx);
2137                            })?;
2138                            continue;
2139                        }
2140                        _ => continue,
2141                    };
2142
2143                    let mut lines = text.lines();
2144                    new_summary.extend(lines.next());
2145
2146                    // Stop if the LLM generated multiple lines.
2147                    if lines.next().is_some() {
2148                        break;
2149                    }
2150                }
2151
2152                anyhow::Ok(new_summary)
2153            }
2154            .await;
2155
2156            this.update(cx, |this, cx| {
2157                match result {
2158                    Ok(new_summary) => {
2159                        if new_summary.is_empty() {
2160                            this.summary = ThreadSummary::Error;
2161                        } else {
2162                            this.summary = ThreadSummary::Ready(new_summary.into());
2163                        }
2164                    }
2165                    Err(err) => {
2166                        this.summary = ThreadSummary::Error;
2167                        log::error!("Failed to generate thread summary: {}", err);
2168                    }
2169                }
2170                cx.emit(ThreadEvent::SummaryGenerated);
2171            })
2172            .log_err()?;
2173
2174            Some(())
2175        });
2176    }
2177
    /// Maps a completion error to the retry strategy (if any) to use when
    /// auto-retrying the request.
    ///
    /// Returns `None` for errors where retrying cannot help (e.g. auth or
    /// permission failures, oversized prompts, billing/request-limit errors).
    /// Note: arm order matters here — the specific `HttpResponseError`
    /// patterns and guards must stay before the final catch-all arm.
    fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            // Honor the server-provided `retry_after` when present.
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // Internal Server Error could be anything, retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we frequently get them in practice. See https://http.dev/529
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        // Other upstream statuses: retry up to 2 times.
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Transport / parsing failures might be transient: retry up to 3 times.
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them once.
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Retry all other 4xx and 5xx errors up to 3 times.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<PaymentRequiredError>()
                    || err.is::<ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // for a significant amount of time for the request limit to reset).
                None
            }
            // Conservatively assume that any other errors are non-retryable
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }
2283
    /// Handles a completion error that may warrant an automatic retry.
    ///
    /// Returns `true` when a retry was scheduled, `false` otherwise. Auto-retry
    /// only happens in Burn Mode; in other modes the user is shown a retryable
    /// error offering to enable it. `strategy` overrides the default strategy
    /// derived from `error` when provided.
    fn handle_retryable_error_with_delay(
        &mut self,
        error: &LanguageModelCompletionError,
        strategy: Option<RetryStrategy>,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Store context for the Retry button
        self.last_error_context = Some((model.clone(), intent));

        // Only auto-retry if Burn Mode is enabled
        if self.completion_mode != CompletionMode::Burn {
            // Show error with retry options
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!(
                    "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
                    error
                )
                .into(),
                can_enable_burn_mode: true,
            }));
            return false;
        }

        let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
            return false;
        };

        let max_attempts = match &strategy {
            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
        };

        // Reuse an in-progress retry sequence (and its attempt count / intent)
        // if one exists; otherwise start a fresh one.
        let retry_state = self.retry_state.get_or_insert(RetryState {
            attempt: 0,
            max_attempts,
            intent,
        });

        retry_state.attempt += 1;
        let attempt = retry_state.attempt;
        let max_attempts = retry_state.max_attempts;
        let intent = retry_state.intent;

        if attempt <= max_attempts {
            let delay = match &strategy {
                RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
                    // Double the delay on each subsequent attempt
                    // (whole seconds only; sub-second precision is dropped).
                    let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
                    Duration::from_secs(delay_secs)
                }
                RetryStrategy::Fixed { delay, .. } => *delay,
            };

            // Add a transient message to inform the user
            let delay_secs = delay.as_secs();
            let retry_message = if max_attempts == 1 {
                format!("{error}. Retrying in {delay_secs} seconds...")
            } else {
                format!(
                    "{error}. Retrying (attempt {attempt} of {max_attempts}) \
                    in {delay_secs} seconds..."
                )
            };
            log::warn!(
                "Retrying completion request (attempt {attempt} of {max_attempts}) \
                in {delay_secs} seconds: {error:?}",
            );

            // Add a UI-only message instead of a regular message
            let id = self.next_message_id.post_inc();
            self.messages.push(Message {
                id,
                role: Role::System,
                segments: vec![MessageSegment::Text(retry_message)],
                loaded_context: LoadedContext::default(),
                creases: Vec::new(),
                is_hidden: false,
                ui_only: true,
            });
            cx.emit(ThreadEvent::MessageAdded(id));

            // Schedule the retry
            let thread_handle = cx.entity().downgrade();

            cx.spawn(async move |_thread, cx| {
                cx.background_executor().timer(delay).await;

                thread_handle
                    .update(cx, |thread, cx| {
                        // Retry the completion
                        thread.send_to_model(model, intent, window, cx);
                    })
                    .log_err();
            })
            .detach();

            true
        } else {
            // Max retries exceeded
            self.retry_state = None;

            // Stop generating since we're giving up on retrying.
            self.pending_completions.clear();

            // Show error alongside a Retry button, but no
            // Enable Burn Mode button (since it's already enabled)
            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
                message: format!("Failed after retrying: {}", error).into(),
                can_enable_burn_mode: false,
            }));

            false
        }
    }
2400
    /// Starts generating a detailed summary of the thread, unless one is
    /// already generated (or generating) for the current last message.
    ///
    /// Uses the registry's thread-summary model; does nothing when no model is
    /// configured or its provider is not authenticated. On success the thread
    /// is saved via `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; reset the state so a later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            // Accumulate streamed chunks; chunks that fail to decode are
            // logged and skipped rather than aborting the whole summary.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade()
                && let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                })
            {
                save_task.await.log_err();
            }

            Some(())
        });
    }
2488
    /// Waits until the detailed summary is no longer in the `Generating` state,
    /// then returns it.
    ///
    /// Returns the generated summary text, the thread's full text when no
    /// summary was generated, or `None` when the thread entity or the state
    /// channel has been dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in flight; wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2506
2507    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2508        self.detailed_summary_rx
2509            .borrow()
2510            .text()
2511            .unwrap_or_else(|| self.text().into())
2512    }
2513
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2520
2521    pub fn use_pending_tools(
2522        &mut self,
2523        window: Option<AnyWindowHandle>,
2524        model: Arc<dyn LanguageModel>,
2525        cx: &mut Context<Self>,
2526    ) -> Vec<PendingToolUse> {
2527        let request =
2528            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2529        let pending_tool_uses = self
2530            .tool_use
2531            .pending_tool_uses()
2532            .into_iter()
2533            .filter(|tool_use| tool_use.status.is_idle())
2534            .cloned()
2535            .collect::<Vec<_>>();
2536
2537        for tool_use in pending_tool_uses.iter() {
2538            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2539        }
2540
2541        pending_tool_uses
2542    }
2543
    /// Dispatches a single pending tool use: runs it immediately, queues it for
    /// user confirmation, or rejects it when the tool is unknown or disabled.
    fn use_pending_tool(
        &mut self,
        tool_use: PendingToolUse,
        request: Arc<LanguageModelRequest>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        // The model may hallucinate a tool name that doesn't exist.
        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        };

        // A tool disabled by the active profile is treated like an unknown one.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
        }

        if tool.needs_confirmation(&tool_use.input, &self.project, cx)
            && !AgentSettings::get_global(cx).always_allow_tool_actions
        {
            // Hold the tool use until the user confirms it.
            self.tool_use.confirm_tool_use(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
            );
            cx.emit(ThreadEvent::ToolConfirmationNeeded);
        } else {
            self.run_tool(
                tool_use.id,
                tool_use.ui_text,
                tool_use.input,
                request,
                tool,
                model,
                window,
                cx,
            );
        }
    }
2584
    /// Responds to a tool use whose name matches no enabled tool, feeding an
    /// error result (listing the available tools) back to the model so it can
    /// self-correct.
    pub fn handle_hallucinated_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        hallucinated_tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let available_tools = self.profile.enabled_tools(cx);

        // Enumerate the tools the model could have used.
        let tool_list = available_tools
            .iter()
            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
            .collect::<Vec<_>>()
            .join("\n");

        let error_message = format!(
            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
            hallucinated_tool_name, tool_list
        );

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            hallucinated_tool_name,
            Err(anyhow!("Missing tool call: {error_message}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        cx.emit(ThreadEvent::MissingToolUse {
            tool_use_id: tool_use_id.clone(),
            ui_text: error_message.into(),
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2620
    /// Handles a tool use whose input failed to parse as JSON, recording a
    /// parse-error result for the model and emitting
    /// `ThreadEvent::InvalidToolInput` for the UI.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
            self.completion_mode,
        );
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            // Shouldn't happen: the tool use should still have been pending.
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2656
2657    pub fn run_tool(
2658        &mut self,
2659        tool_use_id: LanguageModelToolUseId,
2660        ui_text: impl Into<SharedString>,
2661        input: serde_json::Value,
2662        request: Arc<LanguageModelRequest>,
2663        tool: Arc<dyn Tool>,
2664        model: Arc<dyn LanguageModel>,
2665        window: Option<AnyWindowHandle>,
2666        cx: &mut Context<Thread>,
2667    ) {
2668        let task =
2669            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2670        self.tool_use
2671            .run_pending_tool(tool_use_id, ui_text.into(), task);
2672    }
2673
    /// Runs `tool` and returns a task that, when the tool completes, stores its
    /// output on the thread and invokes `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Start the tool synchronously; its output is awaited below.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; in that
                // case the result is silently discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                            thread.completion_mode,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2721
    /// Records that a tool use finished and emits `ThreadEvent::ToolFinished`.
    ///
    /// When this was the last outstanding tool, a model is configured, and the
    /// run wasn't canceled, the accumulated tool results are sent back to the
    /// model before the event is emitted.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished()
            && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
            && !canceled
        {
            self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2742
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also clears any scheduled retry and cancels all pending tool uses.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // A scheduled retry counts as something to cancel too.
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        // Cancel in-flight tool uses, notifying listeners about each one.
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2781
    /// Signals that any in-progress editing should be canceled.
    ///
    /// Emits `ThreadEvent::CancelEditing` so listeners (like `ActiveThread`)
    /// can cancel editing operations they have in flight.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2789
    /// Returns the thread-level feedback, if any has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2793
    /// Returns the feedback reported for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2797
    /// Records feedback for a specific message and reports it to telemetry,
    /// attaching the message content, the serialized thread, and a snapshot of
    /// the project state.
    ///
    /// Returns an immediately-ready `Ok` task when the same feedback has
    /// already been recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Avoid re-sending identical feedback.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2854
    /// Reports feedback for the thread as a whole.
    ///
    /// When the thread contains an assistant message, the feedback is
    /// attributed to the most recent one via `report_message_feedback`;
    /// otherwise a thread-level telemetry event is sent directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // NOTE(review): this branch duplicates the telemetry logic in
            // `report_message_feedback`; consider extracting a shared helper.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2900
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off a snapshot task per visible worktree; they are awaited
        // together below.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all dirty, file-backed buffers.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty()
                        && let Some(file) = buffer.file()
                    {
                        let path = file.path().to_string_lossy().to_string();
                        unsaved_buffers.push(path);
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2938
2939    fn worktree_snapshot(
2940        worktree: Entity<project::Worktree>,
2941        git_store: Entity<GitStore>,
2942        cx: &App,
2943    ) -> Task<WorktreeSnapshot> {
2944        cx.spawn(async move |cx| {
2945            // Get worktree path and snapshot
2946            let worktree_info = cx.update(|app_cx| {
2947                let worktree = worktree.read(app_cx);
2948                let path = worktree.abs_path().to_string_lossy().to_string();
2949                let snapshot = worktree.snapshot();
2950                (path, snapshot)
2951            });
2952
2953            let Ok((worktree_path, _snapshot)) = worktree_info else {
2954                return WorktreeSnapshot {
2955                    worktree_path: String::new(),
2956                    git_state: None,
2957                };
2958            };
2959
2960            let git_state = git_store
2961                .update(cx, |git_store, cx| {
2962                    git_store
2963                        .repositories()
2964                        .values()
2965                        .find(|repo| {
2966                            repo.read(cx)
2967                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2968                                .is_some()
2969                        })
2970                        .cloned()
2971                })
2972                .ok()
2973                .flatten()
2974                .map(|repo| {
2975                    repo.update(cx, |repo, _| {
2976                        let current_branch =
2977                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2978                        repo.send_job(None, |state, _| async move {
2979                            let RepositoryState::Local { backend, .. } = state else {
2980                                return GitState {
2981                                    remote_url: None,
2982                                    head_sha: None,
2983                                    current_branch,
2984                                    diff: None,
2985                                };
2986                            };
2987
2988                            let remote_url = backend.remote_url("origin");
2989                            let head_sha = backend.head_sha().await;
2990                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2991
2992                            GitState {
2993                                remote_url,
2994                                head_sha,
2995                                current_branch,
2996                                diff,
2997                            }
2998                        })
2999                    })
3000                });
3001
3002            let git_state = match git_state {
3003                Some(git_state) => match git_state.ok() {
3004                    Some(git_state) => git_state.await.ok(),
3005                    None => None,
3006                },
3007                None => None,
3008            };
3009
3010            WorktreeSnapshot {
3011                worktree_path,
3012                git_state,
3013            }
3014        })
3015    }
3016
    /// Renders the whole thread — summary, messages, attached context, tool
    /// uses, and tool results — as a Markdown document.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        // Written as bytes via `write!`; converted (lossily) to a String at the end.
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from the export.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
3099
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted) in the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
3110
    /// Marks every outstanding agent edit as kept (accepted) in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
3115
3116    pub fn reject_edits_in_ranges(
3117        &mut self,
3118        buffer: Entity<language::Buffer>,
3119        buffer_ranges: Vec<Range<language::Anchor>>,
3120        cx: &mut Context<Self>,
3121    ) -> Task<Result<()>> {
3122        self.action_log.update(cx, |action_log, cx| {
3123            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
3124        })
3125    }
3126
    /// Returns the action log tracking this thread's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3130
    /// Returns the project this thread is operating on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3134
    /// Returns the token usage accumulated over this thread's lifetime.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3138
3139    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3140        let Some(model) = self.configured_model.as_ref() else {
3141            return TotalTokenUsage::default();
3142        };
3143
3144        let max = model
3145            .model
3146            .max_token_count_for_mode(self.completion_mode().into());
3147
3148        let index = self
3149            .messages
3150            .iter()
3151            .position(|msg| msg.id == message_id)
3152            .unwrap_or(0);
3153
3154        if index == 0 {
3155            return TotalTokenUsage { total: 0, max };
3156        }
3157
3158        let token_usage = &self
3159            .request_token_usage
3160            .get(index - 1)
3161            .cloned()
3162            .unwrap_or_default();
3163
3164        TotalTokenUsage {
3165            total: token_usage.total_tokens(),
3166            max,
3167        }
3168    }
3169
3170    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3171        let model = self.configured_model.as_ref()?;
3172
3173        let max = model
3174            .model
3175            .max_token_count_for_mode(self.completion_mode().into());
3176
3177        if let Some(exceeded_error) = &self.exceeded_window_error
3178            && model.model.id() == exceeded_error.model_id
3179        {
3180            return Some(TotalTokenUsage {
3181                total: exceeded_error.token_count,
3182                max,
3183            });
3184        }
3185
3186        let total = self
3187            .token_usage_at_last_message()
3188            .unwrap_or_default()
3189            .total_tokens();
3190
3191        Some(TotalTokenUsage { total, max })
3192    }
3193
3194    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3195        self.request_token_usage
3196            .get(self.messages.len().saturating_sub(1))
3197            .or_else(|| self.request_token_usage.last())
3198            .cloned()
3199    }
3200
3201    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3202        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3203        self.request_token_usage
3204            .resize(self.messages.len(), placeholder);
3205
3206        if let Some(last) = self.request_token_usage.last_mut() {
3207            *last = token_usage;
3208        }
3209    }
3210
3211    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3212        self.project
3213            .read(cx)
3214            .user_store()
3215            .update(cx, |user_store, cx| {
3216                user_store.update_model_request_usage(
3217                    ModelRequestUsage(RequestUsage {
3218                        amount: amount as i32,
3219                        limit,
3220                    }),
3221                    cx,
3222                )
3223            });
3224    }
3225
3226    pub fn deny_tool_use(
3227        &mut self,
3228        tool_use_id: LanguageModelToolUseId,
3229        tool_name: Arc<str>,
3230        window: Option<AnyWindowHandle>,
3231        cx: &mut Context<Self>,
3232    ) {
3233        let err = Err(anyhow::anyhow!(
3234            "Permission to run tool action denied by user"
3235        ));
3236
3237        self.tool_use.insert_tool_output(
3238            tool_use_id.clone(),
3239            tool_name,
3240            err,
3241            self.configured_model.as_ref(),
3242            self.completion_mode,
3243        );
3244        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3245    }
3246}
3247
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The model request allowance for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A general-purpose error with a short header and a longer message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
    /// An error the user may retry, optionally after enabling burn mode.
    #[error("Retryable error: {message}")]
    RetryableError {
        message: SharedString,
        can_enable_burn_mode: bool,
    },
}
3265
/// Events emitted by a thread as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text was streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" content was streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with the given input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model requested a tool that is not available.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input JSON that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint (e.g. git state) associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// The limit on consecutive tool uses was reached.
    ToolUseLimitReached,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
    /// The thread's agent profile changed.
    ProfileChanged,
}
3310
// Lets `Thread` broadcast `ThreadEvent`s to observers through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
3312
/// An in-flight completion request owned by a thread.
struct PendingCompletion {
    // Identifier for this completion request.
    id: usize,
    // State of this completion within the request queue.
    queue_state: QueueState,
    // Holds the background task driving the completion; dropping a gpui task
    // cancels it, so keeping it here ties its lifetime to this struct.
    _task: Task<()>,
}
3318
3319#[cfg(test)]
3320mod tests {
3321    use super::*;
3322    use crate::{
3323        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3324    };
3325
3326    // Test-specific constants
3327    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3328    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3329    use assistant_tool::ToolRegistry;
3330    use assistant_tools;
3331    use futures::StreamExt;
3332    use futures::future::BoxFuture;
3333    use futures::stream::BoxStream;
3334    use gpui::TestAppContext;
3335    use http_client;
3336    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3337    use language_model::{
3338        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3339        LanguageModelProviderName, LanguageModelToolChoice,
3340    };
3341    use parking_lot::Mutex;
3342    use project::{FakeFs, Project};
3343    use prompt_store::PromptBuilder;
3344    use serde_json::json;
3345    use settings::{Settings, SettingsStore};
3346    use std::sync::Arc;
3347    use std::time::Duration;
3348    use theme::ThemeSettings;
3349    use util::path;
3350    use workspace::Workspace;
3351
    /// Verifies that a file attached as context is rendered into the user
    /// message's `<context>` block and prepended to the message text in the
    /// completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two request messages: index 1 is the user message with the context
        // text prepended (index 0 is presumably the system prompt).
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3428
    /// Verifies that each user message only embeds context not already attached
    /// to an earlier message, and that `new_context_for_thread`'s optional
    /// message-id cutoff and context removal both behave as expected.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts attached at or after that
        // message (2, 3) plus the newly added file4 count as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3584
    /// Verifies that user messages inserted with an empty `ContextLoadResult`
    /// carry no context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3662
    /// Verifies that modifying a context-tracked buffer produces exactly one
    /// `project_notifications` tool result on the next flush, and that a
    /// second flush does not duplicate it.
    #[gpui::test]
    #[ignore] // turn this test on when project_notifications tool is re-enabled
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });
        cx.run_until_parked();

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });
        cx.run_until_parked();

        // Check for the stale buffer warning
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        assert!(notification_content.contains("These files have changed since the last read:"));
        assert!(notification_content.contains("code.rs"));

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3777
3778    fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3779        thread
3780            .messages()
3781            .flat_map(|message| {
3782                thread
3783                    .tool_results_for_message(message.id)
3784                    .into_iter()
3785                    .filter(|result| result.tool_name == tool_name.into())
3786                    .cloned()
3787                    .collect::<Vec<_>>()
3788            })
3789            .collect()
3790    }
3791
3792    #[gpui::test]
3793    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3794        init_test_settings(cx);
3795
3796        let project = create_test_project(
3797            cx,
3798            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3799        )
3800        .await;
3801
3802        let (_workspace, thread_store, thread, _context_store, _model) =
3803            setup_test_environment(cx, project.clone()).await;
3804
3805        // Check that we are starting with the default profile
3806        let profile = cx.read(|cx| thread.read(cx).profile.clone());
3807        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3808        assert_eq!(
3809            profile,
3810            AgentProfile::new(AgentProfileId::default(), tool_set)
3811        );
3812    }
3813
    /// Verifies that a thread's agent profile round-trips through
    /// serialization: it is written out as the profile id and restored to an
    /// equivalent `AgentProfile` on deserialization.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the original
        // thread's project, tools, and prompt builder.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3856
3857    #[gpui::test]
3858    async fn test_temperature_setting(cx: &mut TestAppContext) {
3859        init_test_settings(cx);
3860
3861        let project = create_test_project(
3862            cx,
3863            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3864        )
3865        .await;
3866
3867        let (_workspace, _thread_store, thread, _context_store, model) =
3868            setup_test_environment(cx, project.clone()).await;
3869
3870        // Both model and provider
3871        cx.update(|cx| {
3872            AgentSettings::override_global(
3873                AgentSettings {
3874                    model_parameters: vec![LanguageModelParameters {
3875                        provider: Some(model.provider_id().0.to_string().into()),
3876                        model: Some(model.id().0.clone()),
3877                        temperature: Some(0.66),
3878                    }],
3879                    ..AgentSettings::get_global(cx).clone()
3880                },
3881                cx,
3882            );
3883        });
3884
3885        let request = thread.update(cx, |thread, cx| {
3886            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3887        });
3888        assert_eq!(request.temperature, Some(0.66));
3889
3890        // Only model
3891        cx.update(|cx| {
3892            AgentSettings::override_global(
3893                AgentSettings {
3894                    model_parameters: vec![LanguageModelParameters {
3895                        provider: None,
3896                        model: Some(model.id().0.clone()),
3897                        temperature: Some(0.66),
3898                    }],
3899                    ..AgentSettings::get_global(cx).clone()
3900                },
3901                cx,
3902            );
3903        });
3904
3905        let request = thread.update(cx, |thread, cx| {
3906            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3907        });
3908        assert_eq!(request.temperature, Some(0.66));
3909
3910        // Only provider
3911        cx.update(|cx| {
3912            AgentSettings::override_global(
3913                AgentSettings {
3914                    model_parameters: vec![LanguageModelParameters {
3915                        provider: Some(model.provider_id().0.to_string().into()),
3916                        model: None,
3917                        temperature: Some(0.66),
3918                    }],
3919                    ..AgentSettings::get_global(cx).clone()
3920                },
3921                cx,
3922            );
3923        });
3924
3925        let request = thread.update(cx, |thread, cx| {
3926            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3927        });
3928        assert_eq!(request.temperature, Some(0.66));
3929
3930        // Same model name, different provider
3931        cx.update(|cx| {
3932            AgentSettings::override_global(
3933                AgentSettings {
3934                    model_parameters: vec![LanguageModelParameters {
3935                        provider: Some("anthropic".into()),
3936                        model: Some(model.id().0.clone()),
3937                        temperature: Some(0.66),
3938                    }],
3939                    ..AgentSettings::get_global(cx).clone()
3940                },
3941                cx,
3942            );
3943        });
3944
3945        let request = thread.update(cx, |thread, cx| {
3946            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3947        });
3948        assert_eq!(request.temperature, None);
3949    }
3950
    // Exercises the summary state machine end to end: Pending on a fresh
    // thread, Generating once a first exchange completes, Ready after the
    // summary model streams its text, and manual `set_summary` overrides
    // taking effect only once the summary is Ready.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("Brief");
        fake_model.send_last_completion_stream_text_chunk(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
4035
4036    #[gpui::test]
4037    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
4038        init_test_settings(cx);
4039
4040        let project = create_test_project(cx, json!({})).await;
4041
4042        let (_, _thread_store, thread, _context_store, model) =
4043            setup_test_environment(cx, project.clone()).await;
4044
4045        test_summarize_error(&model, &thread, cx);
4046
4047        // Now we should be able to set a summary
4048        thread.update(cx, |thread, cx| {
4049            thread.set_summary("Brief Intro", cx);
4050        });
4051
4052        thread.read_with(cx, |thread, _| {
4053            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4054            assert_eq!(thread.summary().or_default(), "Brief Intro");
4055        });
4056    }
4057
    // After a failed summarization, sending further messages must not retry
    // summarization automatically; an explicit `summarize` call is required,
    // and that manual retry can then succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the summary into the Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream the retried summary and close the stream.
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
4108
    // Kinds of provider failure that `ErrorInjector` below can simulate.
    enum TestError {
        Overloaded,
        InternalServerError,
    }
4114
    // Test model that fails every `stream_completion` with the configured
    // error while delegating all other `LanguageModel` methods to an inner
    // `FakeLanguageModel`.
    struct ErrorInjector {
        inner: Arc<FakeLanguageModel>,
        error_type: TestError,
    }
4119
4120    impl ErrorInjector {
4121        fn new(error_type: TestError) -> Self {
4122            Self {
4123                inner: Arc::new(FakeLanguageModel::default()),
4124                error_type,
4125            }
4126        }
4127    }
4128
    // Delegates all metadata/capability queries to the wrapped fake model,
    // but every completion request yields a stream whose only item is the
    // configured error.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the injected error eagerly so the async block can own it.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            // The request itself "succeeds"; the failure arrives as the
            // stream's single item, like a provider erroring mid-response.
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4211
    // An overloaded-provider error (in Burn mode) must arm the retry state
    // and surface a ui-only system message describing the retry attempt.
    #[gpui::test]
    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // The failed attempt should have armed the retry state machine.
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
            );
        });

        // Check that a retry message was added
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("overloaded")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message"
            );
        });

        // Count the ui-only "Retrying ... seconds" notices.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4289
    // An internal server error (in Burn mode) must also schedule a retry,
    // and the ui-only retry notice must mention the provider's name.
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Check retry state on thread
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, 3,
                "Should have correct max attempts"
            );
        });

        // Check that a retry message was added with provider name
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text.contains("Retrying")
                                    && text.contains("attempt 1 of 3")
                                    && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Count retry messages
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4371
4372    #[gpui::test]
4373    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4374        init_test_settings(cx);
4375
4376        let project = create_test_project(cx, json!({})).await;
4377        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4378
4379        // Enable Burn Mode to allow retries
4380        thread.update(cx, |thread, _| {
4381            thread.set_completion_mode(CompletionMode::Burn);
4382        });
4383
4384        // Create model that returns internal server error
4385        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4386
4387        // Insert a user message
4388        thread.update(cx, |thread, cx| {
4389            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4390        });
4391
4392        // Track retry events and completion count
4393        // Track completion events
4394        let completion_count = Arc::new(Mutex::new(0));
4395        let completion_count_clone = completion_count.clone();
4396
4397        let _subscription = thread.update(cx, |_, cx| {
4398            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4399                if let ThreadEvent::NewRequest = event {
4400                    *completion_count_clone.lock() += 1;
4401                }
4402            })
4403        });
4404
4405        // First attempt
4406        thread.update(cx, |thread, cx| {
4407            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4408        });
4409        cx.run_until_parked();
4410
4411        // Should have scheduled first retry - count retry messages
4412        let retry_count = thread.update(cx, |thread, _| {
4413            thread
4414                .messages
4415                .iter()
4416                .filter(|m| {
4417                    m.ui_only
4418                        && m.segments.iter().any(|s| {
4419                            if let MessageSegment::Text(text) = s {
4420                                text.contains("Retrying") && text.contains("seconds")
4421                            } else {
4422                                false
4423                            }
4424                        })
4425                })
4426                .count()
4427        });
4428        assert_eq!(retry_count, 1, "Should have scheduled first retry");
4429
4430        // Check retry state
4431        thread.read_with(cx, |thread, _| {
4432            assert!(thread.retry_state.is_some(), "Should have retry state");
4433            let retry_state = thread.retry_state.as_ref().unwrap();
4434            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4435            assert_eq!(
4436                retry_state.max_attempts, 3,
4437                "Internal server errors should retry up to 3 times"
4438            );
4439        });
4440
4441        // Advance clock for first retry
4442        cx.executor().advance_clock(BASE_RETRY_DELAY);
4443        cx.run_until_parked();
4444
4445        // Advance clock for second retry
4446        cx.executor().advance_clock(BASE_RETRY_DELAY);
4447        cx.run_until_parked();
4448
4449        // Advance clock for third retry
4450        cx.executor().advance_clock(BASE_RETRY_DELAY);
4451        cx.run_until_parked();
4452
4453        // Should have completed all retries - count retry messages
4454        let retry_count = thread.update(cx, |thread, _| {
4455            thread
4456                .messages
4457                .iter()
4458                .filter(|m| {
4459                    m.ui_only
4460                        && m.segments.iter().any(|s| {
4461                            if let MessageSegment::Text(text) = s {
4462                                text.contains("Retrying") && text.contains("seconds")
4463                            } else {
4464                                false
4465                            }
4466                        })
4467                })
4468                .count()
4469        });
4470        assert_eq!(
4471            retry_count, 3,
4472            "Should have 3 retries for internal server errors"
4473        );
4474
4475        // For internal server errors, we retry 3 times and then give up
4476        // Check that retry_state is cleared after all retries
4477        thread.read_with(cx, |thread, _| {
4478            assert!(
4479                thread.retry_state.is_none(),
4480                "Retry state should be cleared after all retries"
4481            );
4482        });
4483
4484        // Verify total attempts (1 initial + 3 retries)
4485        assert_eq!(
4486            *completion_count.lock(),
4487            4,
4488            "Should have attempted once plus 3 retries"
4489        );
4490    }
4491
    // Once every retry attempt for an overloaded provider has failed, the
    // thread must emit Stopped(Err(..)) and clear its retry bookkeeping
    // instead of retrying forever.
    #[gpui::test]
    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create model that returns overloaded error
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track events
        let stopped_with_error = Arc::new(Mutex::new(false));
        let stopped_with_error_clone = stopped_with_error.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::Stopped(Err(_)) = event {
                    *stopped_with_error_clone.lock() = true;
                }
            })
        });

        // Start initial completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Advance through all retries
        for _ in 0..MAX_RETRY_ATTEMPTS {
            cx.executor().advance_clock(BASE_RETRY_DELAY);
            cx.run_until_parked();
        }

        // Count the ui-only "Retrying ... seconds" notices.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        // After max retries, should emit Stopped(Err(...)) event
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
        );
        assert!(
            *stopped_with_error.lock(),
            "Should emit Stopped(Err(...)) event after max retries exceeded"
        );

        // Retry state should be cleared
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after max retries"
            );

            // Verify we have the expected number of retry messages
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| msg.ui_only && msg.role == Role::System)
                .count();
            assert_eq!(
                retry_messages, MAX_RETRY_ATTEMPTS as usize,
                "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
            );
        });
    }
4582
    // Uses a model that fails exactly once and then succeeds, and checks what
    // happens to the ui-only retry notice after the retried completion
    // streams successfully. NOTE(review): despite the test name, the final
    // assertions verify the retry message STILL EXISTS (as a ui-only System
    // message) after the successful retry — confirm the intended semantics.
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // We'll use a wrapper to switch behavior after first failure
        struct RetryTestModel {
            inner: Arc<FakeLanguageModel>,
            // Set to true after the first (failing) stream_completion call.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for RetryTestModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track when retry completes successfully
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // Stream some successful content
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4755
    /// After a retryable failure in Burn Mode, a subsequent successful
    /// completion must clear `thread.retry_state` and leave a real
    /// (non-ui-only) assistant message behind.
    #[gpui::test]
    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Enable Burn Mode to allow retries
        thread.update(cx, |thread, _| {
            thread.set_completion_mode(CompletionMode::Burn);
        });

        // Create a model that fails once then succeeds: the first
        // `stream_completion` call yields a `ServerOverloaded` error, every
        // later call delegates to the inner `FakeLanguageModel` so the test
        // can drive the retry stream manually.
        struct FailOnceModel {
            inner: Arc<FakeLanguageModel>,
            // Flipped to true by the first `stream_completion` call.
            failed_once: Arc<Mutex<bool>>,
        }

        // Everything except `stream_completion` delegates to the inner fake.
        impl LanguageModel for FailOnceModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }
        }

        let fail_once_model = Arc::new(FailOnceModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Test message",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
        });

        // Start completion with fail-once model
        thread.update(cx, |thread, cx| {
            thread.send_to_model(
                fail_once_model.clone(),
                CompletionIntent::UserPrompt,
                None,
                cx,
            );
        });

        cx.run_until_parked();

        // Verify retry state exists after first failure
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_some(),
                "Should have retry state after failure"
            );
        });

        // Wait for retry delay (the scheduled retry fires once the test clock
        // is advanced past BASE_RETRY_DELAY)
        cx.executor().advance_clock(BASE_RETRY_DELAY);
        cx.run_until_parked();

        // The retry should now use our FailOnceModel which should succeed
        // We need to help the FakeLanguageModel complete the stream
        let inner_fake = fail_once_model.inner.clone();

        // Wait a bit for the retry to start
        cx.run_until_parked();

        // Check for pending completions and complete them
        if let Some(pending) = inner_fake.pending_completions().first() {
            inner_fake.send_completion_stream_text_chunk(pending, "Success!");
            inner_fake.end_completion_stream(pending);
        }
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            // Success must reset the retry bookkeeping...
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after successful completion"
            );

            // ...and produce a genuine assistant reply (not a ui-only notice).
            let has_assistant_message = thread
                .messages
                .iter()
                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
            assert!(
                has_assistant_message,
                "Should have an assistant message after successful retry"
            );
        });
    }
4921
4922    #[gpui::test]
4923    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4924        init_test_settings(cx);
4925
4926        let project = create_test_project(cx, json!({})).await;
4927        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4928
4929        // Enable Burn Mode to allow retries
4930        thread.update(cx, |thread, _| {
4931            thread.set_completion_mode(CompletionMode::Burn);
4932        });
4933
4934        // Create a model that returns rate limit error with retry_after
4935        struct RateLimitModel {
4936            inner: Arc<FakeLanguageModel>,
4937        }
4938
4939        impl LanguageModel for RateLimitModel {
4940            fn id(&self) -> LanguageModelId {
4941                self.inner.id()
4942            }
4943
4944            fn name(&self) -> LanguageModelName {
4945                self.inner.name()
4946            }
4947
4948            fn provider_id(&self) -> LanguageModelProviderId {
4949                self.inner.provider_id()
4950            }
4951
4952            fn provider_name(&self) -> LanguageModelProviderName {
4953                self.inner.provider_name()
4954            }
4955
4956            fn supports_tools(&self) -> bool {
4957                self.inner.supports_tools()
4958            }
4959
4960            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4961                self.inner.supports_tool_choice(choice)
4962            }
4963
4964            fn supports_images(&self) -> bool {
4965                self.inner.supports_images()
4966            }
4967
4968            fn telemetry_id(&self) -> String {
4969                self.inner.telemetry_id()
4970            }
4971
4972            fn max_token_count(&self) -> u64 {
4973                self.inner.max_token_count()
4974            }
4975
4976            fn count_tokens(
4977                &self,
4978                request: LanguageModelRequest,
4979                cx: &App,
4980            ) -> BoxFuture<'static, Result<u64>> {
4981                self.inner.count_tokens(request, cx)
4982            }
4983
4984            fn stream_completion(
4985                &self,
4986                _request: LanguageModelRequest,
4987                _cx: &AsyncApp,
4988            ) -> BoxFuture<
4989                'static,
4990                Result<
4991                    BoxStream<
4992                        'static,
4993                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4994                    >,
4995                    LanguageModelCompletionError,
4996                >,
4997            > {
4998                let provider = self.provider_name();
4999                async move {
5000                    let stream = futures::stream::once(async move {
5001                        Err(LanguageModelCompletionError::RateLimitExceeded {
5002                            provider,
5003                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
5004                        })
5005                    });
5006                    Ok(stream.boxed())
5007                }
5008                .boxed()
5009            }
5010
5011            fn as_fake(&self) -> &FakeLanguageModel {
5012                &self.inner
5013            }
5014        }
5015
5016        let model = Arc::new(RateLimitModel {
5017            inner: Arc::new(FakeLanguageModel::default()),
5018        });
5019
5020        // Insert a user message
5021        thread.update(cx, |thread, cx| {
5022            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5023        });
5024
5025        // Start completion
5026        thread.update(cx, |thread, cx| {
5027            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5028        });
5029
5030        cx.run_until_parked();
5031
5032        let retry_count = thread.update(cx, |thread, _| {
5033            thread
5034                .messages
5035                .iter()
5036                .filter(|m| {
5037                    m.ui_only
5038                        && m.segments.iter().any(|s| {
5039                            if let MessageSegment::Text(text) = s {
5040                                text.contains("rate limit exceeded")
5041                            } else {
5042                                false
5043                            }
5044                        })
5045                })
5046                .count()
5047        });
5048        assert_eq!(retry_count, 1, "Should have scheduled one retry");
5049
5050        thread.read_with(cx, |thread, _| {
5051            assert!(
5052                thread.retry_state.is_some(),
5053                "Rate limit errors should set retry_state"
5054            );
5055            if let Some(retry_state) = &thread.retry_state {
5056                assert_eq!(
5057                    retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
5058                    "Rate limit errors should use MAX_RETRY_ATTEMPTS"
5059                );
5060            }
5061        });
5062
5063        // Verify we have one retry message
5064        thread.read_with(cx, |thread, _| {
5065            let retry_messages = thread
5066                .messages
5067                .iter()
5068                .filter(|msg| {
5069                    msg.ui_only
5070                        && msg.segments.iter().any(|seg| {
5071                            if let MessageSegment::Text(text) = seg {
5072                                text.contains("rate limit exceeded")
5073                            } else {
5074                                false
5075                            }
5076                        })
5077                })
5078                .count();
5079            assert_eq!(
5080                retry_messages, 1,
5081                "Should have one rate limit retry message"
5082            );
5083        });
5084
5085        // Check that retry message doesn't include attempt count
5086        thread.read_with(cx, |thread, _| {
5087            let retry_message = thread
5088                .messages
5089                .iter()
5090                .find(|msg| msg.role == Role::System && msg.ui_only)
5091                .expect("Should have a retry message");
5092
5093            // Check that the message contains attempt count since we use retry_state
5094            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5095                assert!(
5096                    text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
5097                    "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
5098                );
5099                assert!(
5100                    text.contains("Retrying"),
5101                    "Rate limit retry message should contain retry text"
5102                );
5103            }
5104        });
5105    }
5106
5107    #[gpui::test]
5108    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5109        init_test_settings(cx);
5110
5111        let project = create_test_project(cx, json!({})).await;
5112        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5113
5114        // Insert a regular user message
5115        thread.update(cx, |thread, cx| {
5116            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5117        });
5118
5119        // Insert a UI-only message (like our retry notifications)
5120        thread.update(cx, |thread, cx| {
5121            let id = thread.next_message_id.post_inc();
5122            thread.messages.push(Message {
5123                id,
5124                role: Role::System,
5125                segments: vec![MessageSegment::Text(
5126                    "This is a UI-only message that should not be sent to the model".to_string(),
5127                )],
5128                loaded_context: LoadedContext::default(),
5129                creases: Vec::new(),
5130                is_hidden: true,
5131                ui_only: true,
5132            });
5133            cx.emit(ThreadEvent::MessageAdded(id));
5134        });
5135
5136        // Insert another regular message
5137        thread.update(cx, |thread, cx| {
5138            thread.insert_user_message(
5139                "How are you?",
5140                ContextLoadResult::default(),
5141                None,
5142                vec![],
5143                cx,
5144            );
5145        });
5146
5147        // Generate the completion request
5148        let request = thread.update(cx, |thread, cx| {
5149            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5150        });
5151
5152        // Verify that the request only contains non-UI-only messages
5153        // Should have system prompt + 2 user messages, but not the UI-only message
5154        let user_messages: Vec<_> = request
5155            .messages
5156            .iter()
5157            .filter(|msg| msg.role == Role::User)
5158            .collect();
5159        assert_eq!(
5160            user_messages.len(),
5161            2,
5162            "Should have exactly 2 user messages"
5163        );
5164
5165        // Verify the UI-only content is not present anywhere in the request
5166        let request_text = request
5167            .messages
5168            .iter()
5169            .flat_map(|msg| &msg.content)
5170            .filter_map(|content| match content {
5171                MessageContent::Text(text) => Some(text.as_str()),
5172                _ => None,
5173            })
5174            .collect::<String>();
5175
5176        assert!(
5177            !request_text.contains("UI-only message"),
5178            "UI-only message content should not be in the request"
5179        );
5180
5181        // Verify the thread still has all 3 messages (including UI-only)
5182        thread.read_with(cx, |thread, _| {
5183            assert_eq!(
5184                thread.messages().count(),
5185                3,
5186                "Thread should have 3 messages"
5187            );
5188            assert_eq!(
5189                thread.messages().filter(|m| m.ui_only).count(),
5190                1,
5191                "Thread should have 1 UI-only message"
5192            );
5193        });
5194
5195        // Verify that UI-only messages are not serialized
5196        let serialized = thread
5197            .update(cx, |thread, cx| thread.serialize(cx))
5198            .await
5199            .unwrap();
5200        assert_eq!(
5201            serialized.messages.len(),
5202            2,
5203            "Serialized thread should only have 2 messages (no UI-only)"
5204        );
5205    }
5206
5207    #[gpui::test]
5208    async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
5209        init_test_settings(cx);
5210
5211        let project = create_test_project(cx, json!({})).await;
5212        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5213
5214        // Ensure we're in Normal mode (not Burn mode)
5215        thread.update(cx, |thread, _| {
5216            thread.set_completion_mode(CompletionMode::Normal);
5217        });
5218
5219        // Track error events
5220        let error_events = Arc::new(Mutex::new(Vec::new()));
5221        let error_events_clone = error_events.clone();
5222
5223        let _subscription = thread.update(cx, |_, cx| {
5224            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
5225                if let ThreadEvent::ShowError(error) = event {
5226                    error_events_clone.lock().push(error.clone());
5227                }
5228            })
5229        });
5230
5231        // Create model that returns overloaded error
5232        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5233
5234        // Insert a user message
5235        thread.update(cx, |thread, cx| {
5236            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5237        });
5238
5239        // Start completion
5240        thread.update(cx, |thread, cx| {
5241            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5242        });
5243
5244        cx.run_until_parked();
5245
5246        // Verify no retry state was created
5247        thread.read_with(cx, |thread, _| {
5248            assert!(
5249                thread.retry_state.is_none(),
5250                "Should not have retry state in Normal mode"
5251            );
5252        });
5253
5254        // Check that a retryable error was reported
5255        let errors = error_events.lock();
5256        assert!(!errors.is_empty(), "Should have received an error event");
5257
5258        if let ThreadError::RetryableError {
5259            message: _,
5260            can_enable_burn_mode,
5261        } = &errors[0]
5262        {
5263            assert!(
5264                *can_enable_burn_mode,
5265                "Error should indicate burn mode can be enabled"
5266            );
5267        } else {
5268            panic!("Expected RetryableError, got {:?}", errors[0]);
5269        }
5270
5271        // Verify the thread is no longer generating
5272        thread.read_with(cx, |thread, _| {
5273            assert!(
5274                !thread.is_generating(),
5275                "Should not be generating after error without retry"
5276            );
5277        });
5278    }
5279
5280    #[gpui::test]
5281    async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) {
5282        init_test_settings(cx);
5283
5284        let project = create_test_project(cx, json!({})).await;
5285        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5286
5287        // Enable Burn Mode to allow retries
5288        thread.update(cx, |thread, _| {
5289            thread.set_completion_mode(CompletionMode::Burn);
5290        });
5291
5292        // Create model that returns overloaded error
5293        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5294
5295        // Insert a user message
5296        thread.update(cx, |thread, cx| {
5297            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5298        });
5299
5300        // Start completion
5301        thread.update(cx, |thread, cx| {
5302            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5303        });
5304
5305        cx.run_until_parked();
5306
5307        // Verify retry was scheduled by checking for retry message
5308        let has_retry_message = thread.read_with(cx, |thread, _| {
5309            thread.messages.iter().any(|m| {
5310                m.ui_only
5311                    && m.segments.iter().any(|s| {
5312                        if let MessageSegment::Text(text) = s {
5313                            text.contains("Retrying") && text.contains("seconds")
5314                        } else {
5315                            false
5316                        }
5317                    })
5318            })
5319        });
5320        assert!(has_retry_message, "Should have scheduled a retry");
5321
5322        // Cancel the completion before the retry happens
5323        thread.update(cx, |thread, cx| {
5324            thread.cancel_last_completion(None, cx);
5325        });
5326
5327        cx.run_until_parked();
5328
5329        // The retry should not have happened - no pending completions
5330        let fake_model = model.as_fake();
5331        assert_eq!(
5332            fake_model.pending_completions().len(),
5333            0,
5334            "Should have no pending completions after cancellation"
5335        );
5336
5337        // Verify the retry was canceled by checking retry state
5338        thread.read_with(cx, |thread, _| {
5339            if let Some(retry_state) = &thread.retry_state {
5340                panic!(
5341                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5342                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
5343                );
5344            }
5345        });
5346    }
5347
5348    fn test_summarize_error(
5349        model: &Arc<dyn LanguageModel>,
5350        thread: &Entity<Thread>,
5351        cx: &mut TestAppContext,
5352    ) {
5353        thread.update(cx, |thread, cx| {
5354            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5355            thread.send_to_model(
5356                model.clone(),
5357                CompletionIntent::ThreadSummarization,
5358                None,
5359                cx,
5360            );
5361        });
5362
5363        let fake_model = model.as_fake();
5364        simulate_successful_response(fake_model, cx);
5365
5366        thread.read_with(cx, |thread, _| {
5367            assert!(matches!(thread.summary(), ThreadSummary::Generating));
5368            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5369        });
5370
5371        // Simulate summary request ending
5372        cx.run_until_parked();
5373        fake_model.end_last_completion_stream();
5374        cx.run_until_parked();
5375
5376        // State is set to Error and default message
5377        thread.read_with(cx, |thread, _| {
5378            assert!(matches!(thread.summary(), ThreadSummary::Error));
5379            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5380        });
5381    }
5382
    /// Completes the model's most recent pending request by streaming one text
    /// chunk ("Assistant response") and closing the stream, parking the
    /// executor before and after so all side effects settle.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5389
    /// Installs the global settings store and registers every subsystem these
    /// thread tests rely on (language, project, agent, prompt/thread stores,
    /// workspace, language models, theme, tools). Call once per test before
    /// touching any other helper.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // The built-in tools need an HTTP client; a fake client that
            // answers every request with 200 keeps the tests offline.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5413
5414    // Helper to create a test project with test files
5415    async fn create_test_project(
5416        cx: &mut TestAppContext,
5417        files: serde_json::Value,
5418    ) -> Entity<Project> {
5419        let fs = FakeFs::new(cx.executor());
5420        fs.insert_tree(path!("/test"), files).await;
5421        Project::test(fs, [path!("/test").as_ref()], cx).await
5422    }
5423
5424    async fn setup_test_environment(
5425        cx: &mut TestAppContext,
5426        project: Entity<Project>,
5427    ) -> (
5428        Entity<Workspace>,
5429        Entity<ThreadStore>,
5430        Entity<Thread>,
5431        Entity<ContextStore>,
5432        Arc<dyn LanguageModel>,
5433    ) {
5434        let (workspace, cx) =
5435            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5436
5437        let thread_store = cx
5438            .update(|_, cx| {
5439                ThreadStore::load(
5440                    project.clone(),
5441                    cx.new(|_| ToolWorkingSet::default()),
5442                    None,
5443                    Arc::new(PromptBuilder::new(None).unwrap()),
5444                    cx,
5445                )
5446            })
5447            .await
5448            .unwrap();
5449
5450        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5451        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5452
5453        let provider = Arc::new(FakeLanguageModelProvider::default());
5454        let model = provider.test_model();
5455        let model: Arc<dyn LanguageModel> = Arc::new(model);
5456
5457        cx.update(|_, cx| {
5458            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5459                registry.set_default_model(
5460                    Some(ConfiguredModel {
5461                        provider: provider.clone(),
5462                        model: model.clone(),
5463                    }),
5464                    cx,
5465                );
5466                registry.set_thread_summary_model(
5467                    Some(ConfiguredModel {
5468                        provider,
5469                        model: model.clone(),
5470                    }),
5471                    cx,
5472                );
5473            })
5474        });
5475
5476        (workspace, thread_store, thread, context_store, model)
5477    }
5478
5479    async fn add_file_to_context(
5480        project: &Entity<Project>,
5481        context_store: &Entity<ContextStore>,
5482        path: &str,
5483        cx: &mut TestAppContext,
5484    ) -> Result<Entity<language::Buffer>> {
5485        let buffer_path = project
5486            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5487            .unwrap();
5488
5489        let buffer = project
5490            .update(cx, |project, cx| {
5491                project.open_buffer(buffer_path.clone(), cx)
5492            })
5493            .await
5494            .unwrap();
5495
5496        context_store.update(cx, |context_store, cx| {
5497            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5498        });
5499
5500        Ok(buffer)
5501    }
5502}