thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use action_log::ActionLog;
  12use agent_settings::{
  13    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
  14    SUMMARIZE_THREAD_PROMPT,
  15};
  16use anyhow::{Result, anyhow};
  17use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet};
  18use chrono::{DateTime, Utc};
  19use client::{ModelRequestUsage, RequestUsage};
  20use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
  21use collections::HashMap;
  22use futures::{FutureExt, StreamExt as _, future::Shared};
  23use git::repository::DiffType;
  24use gpui::{
  25    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  26    WeakEntity, Window,
  27};
  28use http_client::StatusCode;
  29use language_model::{
  30    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  31    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
  32    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  33    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
  34    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
  35    TokenUsage,
  36};
  37use postage::stream::Stream as _;
  38use project::{
  39    Project,
  40    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  41};
  42use prompt_store::{ModelContext, PromptBuilder};
  43use schemars::JsonSchema;
  44use serde::{Deserialize, Serialize};
  45use settings::Settings;
  46use std::{
  47    io::Write,
  48    ops::Range,
  49    sync::Arc,
  50    time::{Duration, Instant},
  51};
  52use thiserror::Error;
  53use util::{ResultExt as _, post_inc};
  54use uuid::Uuid;
  55
  56const MAX_RETRY_ATTEMPTS: u8 = 4;
  57const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
  58
  59#[derive(Debug, Clone)]
  60enum RetryStrategy {
  61    ExponentialBackoff {
  62        initial_delay: Duration,
  63        max_attempts: u8,
  64    },
  65    Fixed {
  66        delay: Duration,
  67        max_attempts: u8,
  68    },
  69}
  70
  71#[derive(
  72    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  73)]
  74pub struct ThreadId(Arc<str>);
  75
  76impl ThreadId {
  77    pub fn new() -> Self {
  78        Self(Uuid::new_v4().to_string().into())
  79    }
  80}
  81
  82impl std::fmt::Display for ThreadId {
  83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  84        write!(f, "{}", self.0)
  85    }
  86}
  87
  88impl From<&str> for ThreadId {
  89    fn from(value: &str) -> Self {
  90        Self(value.into())
  91    }
  92}
  93
  94/// The ID of the user prompt that initiated a request.
  95///
  96/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  97#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  98pub struct PromptId(Arc<str>);
  99
 100impl PromptId {
 101    pub fn new() -> Self {
 102        Self(Uuid::new_v4().to_string().into())
 103    }
 104}
 105
 106impl std::fmt::Display for PromptId {
 107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 108        write!(f, "{}", self.0)
 109    }
 110}
 111
 112#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 113pub struct MessageId(pub usize);
 114
 115impl MessageId {
 116    fn post_inc(&mut self) -> Self {
 117        Self(post_inc(&mut self.0))
 118    }
 119
 120    pub fn as_usize(&self) -> usize {
 121        self.0
 122    }
 123}
 124
 125/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
 126#[derive(Clone, Debug)]
 127pub struct MessageCrease {
 128    pub range: Range<usize>,
 129    pub icon_path: SharedString,
 130    pub label: SharedString,
 131    /// None for a deserialized message, Some otherwise.
 132    pub context: Option<AgentContextHandle>,
 133}
 134
 135/// A message in a [`Thread`].
 136#[derive(Debug, Clone)]
 137pub struct Message {
 138    pub id: MessageId,
 139    pub role: Role,
 140    pub segments: Vec<MessageSegment>,
 141    pub loaded_context: LoadedContext,
 142    pub creases: Vec<MessageCrease>,
 143    pub is_hidden: bool,
 144    pub ui_only: bool,
 145}
 146
 147impl Message {
 148    /// Returns whether the message contains any meaningful text that should be displayed
 149    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 150    pub fn should_display_content(&self) -> bool {
 151        self.segments.iter().all(|segment| segment.should_display())
 152    }
 153
 154    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 155        if let Some(MessageSegment::Thinking {
 156            text: segment,
 157            signature: current_signature,
 158        }) = self.segments.last_mut()
 159        {
 160            if let Some(signature) = signature {
 161                *current_signature = Some(signature);
 162            }
 163            segment.push_str(text);
 164        } else {
 165            self.segments.push(MessageSegment::Thinking {
 166                text: text.to_string(),
 167                signature,
 168            });
 169        }
 170    }
 171
 172    pub fn push_redacted_thinking(&mut self, data: String) {
 173        self.segments.push(MessageSegment::RedactedThinking(data));
 174    }
 175
 176    pub fn push_text(&mut self, text: &str) {
 177        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 178            segment.push_str(text);
 179        } else {
 180            self.segments.push(MessageSegment::Text(text.to_string()));
 181        }
 182    }
 183
 184    pub fn to_message_content(&self) -> String {
 185        let mut result = String::new();
 186
 187        if !self.loaded_context.text.is_empty() {
 188            result.push_str(&self.loaded_context.text);
 189        }
 190
 191        for segment in &self.segments {
 192            match segment {
 193                MessageSegment::Text(text) => result.push_str(text),
 194                MessageSegment::Thinking { text, .. } => {
 195                    result.push_str("<think>\n");
 196                    result.push_str(text);
 197                    result.push_str("\n</think>");
 198                }
 199                MessageSegment::RedactedThinking(_) => {}
 200            }
 201        }
 202
 203        result
 204    }
 205}
 206
 207#[derive(Debug, Clone, PartialEq, Eq)]
 208pub enum MessageSegment {
 209    Text(String),
 210    Thinking {
 211        text: String,
 212        signature: Option<String>,
 213    },
 214    RedactedThinking(String),
 215}
 216
 217impl MessageSegment {
 218    pub fn should_display(&self) -> bool {
 219        match self {
 220            Self::Text(text) => text.is_empty(),
 221            Self::Thinking { text, .. } => text.is_empty(),
 222            Self::RedactedThinking(_) => false,
 223        }
 224    }
 225
 226    pub fn text(&self) -> Option<&str> {
 227        match self {
 228            MessageSegment::Text(text) => Some(text),
 229            _ => None,
 230        }
 231    }
 232}
 233
 234#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 235pub struct ProjectSnapshot {
 236    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 237    pub timestamp: DateTime<Utc>,
 238}
 239
 240#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 241pub struct WorktreeSnapshot {
 242    pub worktree_path: String,
 243    pub git_state: Option<GitState>,
 244}
 245
 246#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 247pub struct GitState {
 248    pub remote_url: Option<String>,
 249    pub head_sha: Option<String>,
 250    pub current_branch: Option<String>,
 251    pub diff: Option<String>,
 252}
 253
 254#[derive(Clone, Debug)]
 255pub struct ThreadCheckpoint {
 256    message_id: MessageId,
 257    git_checkpoint: GitStoreCheckpoint,
 258}
 259
 260#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 261pub enum ThreadFeedback {
 262    Positive,
 263    Negative,
 264}
 265
 266pub enum LastRestoreCheckpoint {
 267    Pending {
 268        message_id: MessageId,
 269    },
 270    Error {
 271        message_id: MessageId,
 272        error: String,
 273    },
 274}
 275
 276impl LastRestoreCheckpoint {
 277    pub fn message_id(&self) -> MessageId {
 278        match self {
 279            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 280            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 281        }
 282    }
 283}
 284
 285#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 286pub enum DetailedSummaryState {
 287    #[default]
 288    NotGenerated,
 289    Generating {
 290        message_id: MessageId,
 291    },
 292    Generated {
 293        text: SharedString,
 294        message_id: MessageId,
 295    },
 296}
 297
 298impl DetailedSummaryState {
 299    fn text(&self) -> Option<SharedString> {
 300        if let Self::Generated { text, .. } = self {
 301            Some(text.clone())
 302        } else {
 303            None
 304        }
 305    }
 306}
 307
 308#[derive(Default, Debug)]
 309pub struct TotalTokenUsage {
 310    pub total: u64,
 311    pub max: u64,
 312}
 313
 314impl TotalTokenUsage {
 315    pub fn ratio(&self) -> TokenUsageRatio {
 316        #[cfg(debug_assertions)]
 317        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 318            .unwrap_or("0.8".to_string())
 319            .parse()
 320            .unwrap();
 321        #[cfg(not(debug_assertions))]
 322        let warning_threshold: f32 = 0.8;
 323
 324        // When the maximum is unknown because there is no selected model,
 325        // avoid showing the token limit warning.
 326        if self.max == 0 {
 327            TokenUsageRatio::Normal
 328        } else if self.total >= self.max {
 329            TokenUsageRatio::Exceeded
 330        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 331            TokenUsageRatio::Warning
 332        } else {
 333            TokenUsageRatio::Normal
 334        }
 335    }
 336
 337    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 338        TotalTokenUsage {
 339            total: self.total + tokens,
 340            max: self.max,
 341        }
 342    }
 343}
 344
 345#[derive(Debug, Default, PartialEq, Eq)]
 346pub enum TokenUsageRatio {
 347    #[default]
 348    Normal,
 349    Warning,
 350    Exceeded,
 351}
 352
 353#[derive(Debug, Clone, Copy)]
 354pub enum QueueState {
 355    Sending,
 356    Queued { position: usize },
 357    Started,
 358}
 359
 360/// A thread of conversation with the LLM.
 361pub struct Thread {
 362    id: ThreadId,
 363    updated_at: DateTime<Utc>,
 364    summary: ThreadSummary,
 365    pending_summary: Task<Option<()>>,
 366    detailed_summary_task: Task<Option<()>>,
 367    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
 368    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
 369    completion_mode: agent_settings::CompletionMode,
 370    messages: Vec<Message>,
 371    next_message_id: MessageId,
 372    last_prompt_id: PromptId,
 373    project_context: SharedProjectContext,
 374    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 375    completion_count: usize,
 376    pending_completions: Vec<PendingCompletion>,
 377    project: Entity<Project>,
 378    prompt_builder: Arc<PromptBuilder>,
 379    tools: Entity<ToolWorkingSet>,
 380    tool_use: ToolUseState,
 381    action_log: Entity<ActionLog>,
 382    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 383    pending_checkpoint: Option<ThreadCheckpoint>,
 384    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 385    request_token_usage: Vec<TokenUsage>,
 386    cumulative_token_usage: TokenUsage,
 387    exceeded_window_error: Option<ExceededWindowError>,
 388    tool_use_limit_reached: bool,
 389    retry_state: Option<RetryState>,
 390    message_feedback: HashMap<MessageId, ThreadFeedback>,
 391    last_received_chunk_at: Option<Instant>,
 392    request_callback: Option<
 393        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
 394    >,
 395    remaining_turns: u32,
 396    configured_model: Option<ConfiguredModel>,
 397    profile: AgentProfile,
 398    last_error_context: Option<(Arc<dyn LanguageModel>, CompletionIntent)>,
 399}
 400
 401#[derive(Clone, Debug)]
 402struct RetryState {
 403    attempt: u8,
 404    max_attempts: u8,
 405    intent: CompletionIntent,
 406}
 407
 408#[derive(Clone, Debug, PartialEq, Eq)]
 409pub enum ThreadSummary {
 410    Pending,
 411    Generating,
 412    Ready(SharedString),
 413    Error,
 414}
 415
 416impl ThreadSummary {
 417    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 418
 419    pub fn or_default(&self) -> SharedString {
 420        self.unwrap_or(Self::DEFAULT)
 421    }
 422
 423    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 424        self.ready().unwrap_or_else(|| message.into())
 425    }
 426
 427    pub fn ready(&self) -> Option<SharedString> {
 428        match self {
 429            ThreadSummary::Ready(summary) => Some(summary.clone()),
 430            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 431        }
 432    }
 433}
 434
 435#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 436pub struct ExceededWindowError {
 437    /// Model used when last message exceeded context window
 438    model_id: LanguageModelId,
 439    /// Token count including last message
 440    token_count: u64,
 441}
 442
 443impl Thread {
 444    pub fn new(
 445        project: Entity<Project>,
 446        tools: Entity<ToolWorkingSet>,
 447        prompt_builder: Arc<PromptBuilder>,
 448        system_prompt: SharedProjectContext,
 449        cx: &mut Context<Self>,
 450    ) -> Self {
 451        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
 452        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
 453        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 454
 455        Self {
 456            id: ThreadId::new(),
 457            updated_at: Utc::now(),
 458            summary: ThreadSummary::Pending,
 459            pending_summary: Task::ready(None),
 460            detailed_summary_task: Task::ready(None),
 461            detailed_summary_tx,
 462            detailed_summary_rx,
 463            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
 464            messages: Vec::new(),
 465            next_message_id: MessageId(0),
 466            last_prompt_id: PromptId::new(),
 467            project_context: system_prompt,
 468            checkpoints_by_message: HashMap::default(),
 469            completion_count: 0,
 470            pending_completions: Vec::new(),
 471            project: project.clone(),
 472            prompt_builder,
 473            tools: tools.clone(),
 474            last_restore_checkpoint: None,
 475            pending_checkpoint: None,
 476            tool_use: ToolUseState::new(tools.clone()),
 477            action_log: cx.new(|_| ActionLog::new(project.clone())),
 478            initial_project_snapshot: {
 479                let project_snapshot = Self::project_snapshot(project, cx);
 480                cx.foreground_executor()
 481                    .spawn(async move { Some(project_snapshot.await) })
 482                    .shared()
 483            },
 484            request_token_usage: Vec::new(),
 485            cumulative_token_usage: TokenUsage::default(),
 486            exceeded_window_error: None,
 487            tool_use_limit_reached: false,
 488            retry_state: None,
 489            message_feedback: HashMap::default(),
 490            last_error_context: None,
 491            last_received_chunk_at: None,
 492            request_callback: None,
 493            remaining_turns: u32::MAX,
 494            configured_model,
 495            profile: AgentProfile::new(profile_id, tools),
 496        }
 497    }
 498
 499    pub fn deserialize(
 500        id: ThreadId,
 501        serialized: SerializedThread,
 502        project: Entity<Project>,
 503        tools: Entity<ToolWorkingSet>,
 504        prompt_builder: Arc<PromptBuilder>,
 505        project_context: SharedProjectContext,
 506        window: Option<&mut Window>, // None in headless mode
 507        cx: &mut Context<Self>,
 508    ) -> Self {
 509        let next_message_id = MessageId(
 510            serialized
 511                .messages
 512                .last()
 513                .map(|message| message.id.0 + 1)
 514                .unwrap_or(0),
 515        );
 516        let tool_use = ToolUseState::from_serialized_messages(
 517            tools.clone(),
 518            &serialized.messages,
 519            project.clone(),
 520            window,
 521            cx,
 522        );
 523        let (detailed_summary_tx, detailed_summary_rx) =
 524            postage::watch::channel_with(serialized.detailed_summary_state);
 525
 526        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 527            serialized
 528                .model
 529                .and_then(|model| {
 530                    let model = SelectedModel {
 531                        provider: model.provider.clone().into(),
 532                        model: model.model.into(),
 533                    };
 534                    registry.select_model(&model, cx)
 535                })
 536                .or_else(|| registry.default_model())
 537        });
 538
 539        let completion_mode = serialized
 540            .completion_mode
 541            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
 542        let profile_id = serialized
 543            .profile
 544            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 545
 546        Self {
 547            id,
 548            updated_at: serialized.updated_at,
 549            summary: ThreadSummary::Ready(serialized.summary),
 550            pending_summary: Task::ready(None),
 551            detailed_summary_task: Task::ready(None),
 552            detailed_summary_tx,
 553            detailed_summary_rx,
 554            completion_mode,
 555            retry_state: None,
 556            messages: serialized
 557                .messages
 558                .into_iter()
 559                .map(|message| Message {
 560                    id: message.id,
 561                    role: message.role,
 562                    segments: message
 563                        .segments
 564                        .into_iter()
 565                        .map(|segment| match segment {
 566                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 567                            SerializedMessageSegment::Thinking { text, signature } => {
 568                                MessageSegment::Thinking { text, signature }
 569                            }
 570                            SerializedMessageSegment::RedactedThinking { data } => {
 571                                MessageSegment::RedactedThinking(data)
 572                            }
 573                        })
 574                        .collect(),
 575                    loaded_context: LoadedContext {
 576                        contexts: Vec::new(),
 577                        text: message.context,
 578                        images: Vec::new(),
 579                    },
 580                    creases: message
 581                        .creases
 582                        .into_iter()
 583                        .map(|crease| MessageCrease {
 584                            range: crease.start..crease.end,
 585                            icon_path: crease.icon_path,
 586                            label: crease.label,
 587                            context: None,
 588                        })
 589                        .collect(),
 590                    is_hidden: message.is_hidden,
 591                    ui_only: false, // UI-only messages are not persisted
 592                })
 593                .collect(),
 594            next_message_id,
 595            last_prompt_id: PromptId::new(),
 596            project_context,
 597            checkpoints_by_message: HashMap::default(),
 598            completion_count: 0,
 599            pending_completions: Vec::new(),
 600            last_restore_checkpoint: None,
 601            pending_checkpoint: None,
 602            project: project.clone(),
 603            prompt_builder,
 604            tools: tools.clone(),
 605            tool_use,
 606            action_log: cx.new(|_| ActionLog::new(project)),
 607            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 608            request_token_usage: serialized.request_token_usage,
 609            cumulative_token_usage: serialized.cumulative_token_usage,
 610            exceeded_window_error: None,
 611            tool_use_limit_reached: serialized.tool_use_limit_reached,
 612            message_feedback: HashMap::default(),
 613            last_error_context: None,
 614            last_received_chunk_at: None,
 615            request_callback: None,
 616            remaining_turns: u32::MAX,
 617            configured_model,
 618            profile: AgentProfile::new(profile_id, tools),
 619        }
 620    }
 621
 622    pub fn set_request_callback(
 623        &mut self,
 624        callback: impl 'static
 625        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
 626    ) {
 627        self.request_callback = Some(Box::new(callback));
 628    }
 629
 630    pub fn id(&self) -> &ThreadId {
 631        &self.id
 632    }
 633
 634    pub fn profile(&self) -> &AgentProfile {
 635        &self.profile
 636    }
 637
 638    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
 639        if &id != self.profile.id() {
 640            self.profile = AgentProfile::new(id, self.tools.clone());
 641            cx.emit(ThreadEvent::ProfileChanged);
 642        }
 643    }
 644
 645    pub fn is_empty(&self) -> bool {
 646        self.messages.is_empty()
 647    }
 648
 649    pub fn updated_at(&self) -> DateTime<Utc> {
 650        self.updated_at
 651    }
 652
 653    pub fn touch_updated_at(&mut self) {
 654        self.updated_at = Utc::now();
 655    }
 656
 657    pub fn advance_prompt_id(&mut self) {
 658        self.last_prompt_id = PromptId::new();
 659    }
 660
 661    pub fn project_context(&self) -> SharedProjectContext {
 662        self.project_context.clone()
 663    }
 664
 665    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 666        if self.configured_model.is_none() {
 667            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 668        }
 669        self.configured_model.clone()
 670    }
 671
 672    pub fn configured_model(&self) -> Option<ConfiguredModel> {
 673        self.configured_model.clone()
 674    }
 675
 676    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
 677        self.configured_model = model;
 678        cx.notify();
 679    }
 680
 681    pub fn summary(&self) -> &ThreadSummary {
 682        &self.summary
 683    }
 684
 685    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 686        let current_summary = match &self.summary {
 687            ThreadSummary::Pending | ThreadSummary::Generating => return,
 688            ThreadSummary::Ready(summary) => summary,
 689            ThreadSummary::Error => &ThreadSummary::DEFAULT,
 690        };
 691
 692        let mut new_summary = new_summary.into();
 693
 694        if new_summary.is_empty() {
 695            new_summary = ThreadSummary::DEFAULT;
 696        }
 697
 698        if current_summary != &new_summary {
 699            self.summary = ThreadSummary::Ready(new_summary);
 700            cx.emit(ThreadEvent::SummaryChanged);
 701        }
 702    }
 703
 704    pub fn completion_mode(&self) -> CompletionMode {
 705        self.completion_mode
 706    }
 707
 708    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
 709        self.completion_mode = mode;
 710    }
 711
 712    pub fn message(&self, id: MessageId) -> Option<&Message> {
 713        let index = self
 714            .messages
 715            .binary_search_by(|message| message.id.cmp(&id))
 716            .ok()?;
 717
 718        self.messages.get(index)
 719    }
 720
 721    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
 722        self.messages.iter()
 723    }
 724
 725    pub fn is_generating(&self) -> bool {
 726        !self.pending_completions.is_empty() || !self.all_tools_finished()
 727    }
 728
 729    /// Indicates whether streaming of language model events is stale.
 730    /// When `is_generating()` is false, this method returns `None`.
 731    pub fn is_generation_stale(&self) -> Option<bool> {
 732        const STALE_THRESHOLD: u128 = 250;
 733
 734        self.last_received_chunk_at
 735            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
 736    }
 737
 738    fn received_chunk(&mut self) {
 739        self.last_received_chunk_at = Some(Instant::now());
 740    }
 741
 742    pub fn queue_state(&self) -> Option<QueueState> {
 743        self.pending_completions
 744            .first()
 745            .map(|pending_completion| pending_completion.queue_state)
 746    }
 747
 748    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 749        &self.tools
 750    }
 751
 752    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 753        self.tool_use
 754            .pending_tool_uses()
 755            .into_iter()
 756            .find(|tool_use| &tool_use.id == id)
 757    }
 758
 759    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 760        self.tool_use
 761            .pending_tool_uses()
 762            .into_iter()
 763            .filter(|tool_use| tool_use.status.needs_confirmation())
 764    }
 765
 766    pub fn has_pending_tool_uses(&self) -> bool {
 767        !self.tool_use.pending_tool_uses().is_empty()
 768    }
 769
 770    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 771        self.checkpoints_by_message.get(&id).cloned()
 772    }
 773
 774    pub fn restore_checkpoint(
 775        &mut self,
 776        checkpoint: ThreadCheckpoint,
 777        cx: &mut Context<Self>,
 778    ) -> Task<Result<()>> {
 779        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 780            message_id: checkpoint.message_id,
 781        });
 782        cx.emit(ThreadEvent::CheckpointChanged);
 783        cx.notify();
 784
 785        let git_store = self.project().read(cx).git_store().clone();
 786        let restore = git_store.update(cx, |git_store, cx| {
 787            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 788        });
 789
 790        cx.spawn(async move |this, cx| {
 791            let result = restore.await;
 792            this.update(cx, |this, cx| {
 793                if let Err(err) = result.as_ref() {
 794                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 795                        message_id: checkpoint.message_id,
 796                        error: err.to_string(),
 797                    });
 798                } else {
 799                    this.truncate(checkpoint.message_id, cx);
 800                    this.last_restore_checkpoint = None;
 801                }
 802                this.pending_checkpoint = None;
 803                cx.emit(ThreadEvent::CheckpointChanged);
 804                cx.notify();
 805            })?;
 806            result
 807        })
 808    }
 809
 810    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 811        let pending_checkpoint = if self.is_generating() {
 812            return;
 813        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 814            checkpoint
 815        } else {
 816            return;
 817        };
 818
 819        self.finalize_checkpoint(pending_checkpoint, cx);
 820    }
 821
 822    fn finalize_checkpoint(
 823        &mut self,
 824        pending_checkpoint: ThreadCheckpoint,
 825        cx: &mut Context<Self>,
 826    ) {
 827        let git_store = self.project.read(cx).git_store().clone();
 828        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 829        cx.spawn(async move |this, cx| match final_checkpoint.await {
 830            Ok(final_checkpoint) => {
 831                let equal = git_store
 832                    .update(cx, |store, cx| {
 833                        store.compare_checkpoints(
 834                            pending_checkpoint.git_checkpoint.clone(),
 835                            final_checkpoint.clone(),
 836                            cx,
 837                        )
 838                    })?
 839                    .await
 840                    .unwrap_or(false);
 841
 842                this.update(cx, |this, cx| {
 843                    this.pending_checkpoint = if equal {
 844                        Some(pending_checkpoint)
 845                    } else {
 846                        this.insert_checkpoint(pending_checkpoint, cx);
 847                        Some(ThreadCheckpoint {
 848                            message_id: this.next_message_id,
 849                            git_checkpoint: final_checkpoint,
 850                        })
 851                    }
 852                })?;
 853
 854                Ok(())
 855            }
 856            Err(_) => this.update(cx, |this, cx| {
 857                this.insert_checkpoint(pending_checkpoint, cx)
 858            }),
 859        })
 860        .detach();
 861    }
 862
 863    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 864        self.checkpoints_by_message
 865            .insert(checkpoint.message_id, checkpoint);
 866        cx.emit(ThreadEvent::CheckpointChanged);
 867        cx.notify();
 868    }
 869
 870    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 871        self.last_restore_checkpoint.as_ref()
 872    }
 873
 874    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 875        let Some(message_ix) = self
 876            .messages
 877            .iter()
 878            .rposition(|message| message.id == message_id)
 879        else {
 880            return;
 881        };
 882        for deleted_message in self.messages.drain(message_ix..) {
 883            self.checkpoints_by_message.remove(&deleted_message.id);
 884        }
 885        cx.notify();
 886    }
 887
 888    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 889        self.messages
 890            .iter()
 891            .find(|message| message.id == id)
 892            .into_iter()
 893            .flat_map(|message| message.loaded_context.contexts.iter())
 894    }
 895
 896    pub fn is_turn_end(&self, ix: usize) -> bool {
 897        if self.messages.is_empty() {
 898            return false;
 899        }
 900
 901        if !self.is_generating() && ix == self.messages.len() - 1 {
 902            return true;
 903        }
 904
 905        let Some(message) = self.messages.get(ix) else {
 906            return false;
 907        };
 908
 909        if message.role != Role::Assistant {
 910            return false;
 911        }
 912
 913        self.messages
 914            .get(ix + 1)
 915            .and_then(|message| {
 916                self.message(message.id)
 917                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
 918            })
 919            .unwrap_or(false)
 920    }
 921
 922    pub fn tool_use_limit_reached(&self) -> bool {
 923        self.tool_use_limit_reached
 924    }
 925
 926    /// Returns whether all of the tool uses have finished running.
 927    pub fn all_tools_finished(&self) -> bool {
 928        // If the only pending tool uses left are the ones with errors, then
 929        // that means that we've finished running all of the pending tools.
 930        self.tool_use
 931            .pending_tool_uses()
 932            .iter()
 933            .all(|pending_tool_use| pending_tool_use.status.is_error())
 934    }
 935
 936    /// Returns whether any pending tool uses may perform edits
 937    pub fn has_pending_edit_tool_uses(&self) -> bool {
 938        self.tool_use
 939            .pending_tool_uses()
 940            .iter()
 941            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 942            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 943    }
 944
 945    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 946        self.tool_use.tool_uses_for_message(id, &self.project, cx)
 947    }
 948
 949    pub fn tool_results_for_message(
 950        &self,
 951        assistant_message_id: MessageId,
 952    ) -> Vec<&LanguageModelToolResult> {
 953        self.tool_use.tool_results_for_message(assistant_message_id)
 954    }
 955
 956    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 957        self.tool_use.tool_result(id)
 958    }
 959
 960    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 961        match &self.tool_use.tool_result(id)?.content {
 962            LanguageModelToolResultContent::Text(text) => Some(text),
 963            LanguageModelToolResultContent::Image(_) => {
 964                // TODO: We should display image
 965                None
 966            }
 967        }
 968    }
 969
 970    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 971        self.tool_use.tool_result_card(id).cloned()
 972    }
 973
 974    /// Return tools that are both enabled and supported by the model
 975    pub fn available_tools(
 976        &self,
 977        cx: &App,
 978        model: Arc<dyn LanguageModel>,
 979    ) -> Vec<LanguageModelRequestTool> {
 980        if model.supports_tools() {
 981            self.profile
 982                .enabled_tools(cx)
 983                .into_iter()
 984                .filter_map(|(name, tool)| {
 985                    // Skip tools that cannot be supported
 986                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 987                    Some(LanguageModelRequestTool {
 988                        name: name.into(),
 989                        description: tool.description(),
 990                        input_schema,
 991                    })
 992                })
 993                .collect()
 994        } else {
 995            Vec::default()
 996        }
 997    }
 998
 999    pub fn insert_user_message(
1000        &mut self,
1001        text: impl Into<String>,
1002        loaded_context: ContextLoadResult,
1003        git_checkpoint: Option<GitStoreCheckpoint>,
1004        creases: Vec<MessageCrease>,
1005        cx: &mut Context<Self>,
1006    ) -> MessageId {
1007        if !loaded_context.referenced_buffers.is_empty() {
1008            self.action_log.update(cx, |log, cx| {
1009                for buffer in loaded_context.referenced_buffers {
1010                    log.buffer_read(buffer, cx);
1011                }
1012            });
1013        }
1014
1015        let message_id = self.insert_message(
1016            Role::User,
1017            vec![MessageSegment::Text(text.into())],
1018            loaded_context.loaded_context,
1019            creases,
1020            false,
1021            cx,
1022        );
1023
1024        if let Some(git_checkpoint) = git_checkpoint {
1025            self.pending_checkpoint = Some(ThreadCheckpoint {
1026                message_id,
1027                git_checkpoint,
1028            });
1029        }
1030
1031        message_id
1032    }
1033
1034    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1035        let id = self.insert_message(
1036            Role::User,
1037            vec![MessageSegment::Text("Continue where you left off".into())],
1038            LoadedContext::default(),
1039            vec![],
1040            true,
1041            cx,
1042        );
1043        self.pending_checkpoint = None;
1044
1045        id
1046    }
1047
1048    pub fn insert_assistant_message(
1049        &mut self,
1050        segments: Vec<MessageSegment>,
1051        cx: &mut Context<Self>,
1052    ) -> MessageId {
1053        self.insert_message(
1054            Role::Assistant,
1055            segments,
1056            LoadedContext::default(),
1057            Vec::new(),
1058            false,
1059            cx,
1060        )
1061    }
1062
1063    pub fn insert_message(
1064        &mut self,
1065        role: Role,
1066        segments: Vec<MessageSegment>,
1067        loaded_context: LoadedContext,
1068        creases: Vec<MessageCrease>,
1069        is_hidden: bool,
1070        cx: &mut Context<Self>,
1071    ) -> MessageId {
1072        let id = self.next_message_id.post_inc();
1073        self.messages.push(Message {
1074            id,
1075            role,
1076            segments,
1077            loaded_context,
1078            creases,
1079            is_hidden,
1080            ui_only: false,
1081        });
1082        self.touch_updated_at();
1083        cx.emit(ThreadEvent::MessageAdded(id));
1084        id
1085    }
1086
1087    pub fn edit_message(
1088        &mut self,
1089        id: MessageId,
1090        new_role: Role,
1091        new_segments: Vec<MessageSegment>,
1092        creases: Vec<MessageCrease>,
1093        loaded_context: Option<LoadedContext>,
1094        checkpoint: Option<GitStoreCheckpoint>,
1095        cx: &mut Context<Self>,
1096    ) -> bool {
1097        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1098            return false;
1099        };
1100        message.role = new_role;
1101        message.segments = new_segments;
1102        message.creases = creases;
1103        if let Some(context) = loaded_context {
1104            message.loaded_context = context;
1105        }
1106        if let Some(git_checkpoint) = checkpoint {
1107            self.checkpoints_by_message.insert(
1108                id,
1109                ThreadCheckpoint {
1110                    message_id: id,
1111                    git_checkpoint,
1112                },
1113            );
1114        }
1115        self.touch_updated_at();
1116        cx.emit(ThreadEvent::MessageEdited(id));
1117        true
1118    }
1119
1120    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1121        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1122            return false;
1123        };
1124        self.messages.remove(index);
1125        self.touch_updated_at();
1126        cx.emit(ThreadEvent::MessageDeleted(id));
1127        true
1128    }
1129
1130    /// Returns the representation of this [`Thread`] in a textual form.
1131    ///
1132    /// This is the representation we use when attaching a thread as context to another thread.
1133    pub fn text(&self) -> String {
1134        let mut text = String::new();
1135
1136        for message in &self.messages {
1137            text.push_str(match message.role {
1138                language_model::Role::User => "User:",
1139                language_model::Role::Assistant => "Agent:",
1140                language_model::Role::System => "System:",
1141            });
1142            text.push('\n');
1143
1144            for segment in &message.segments {
1145                match segment {
1146                    MessageSegment::Text(content) => text.push_str(content),
1147                    MessageSegment::Thinking { text: content, .. } => {
1148                        text.push_str(&format!("<think>{}</think>", content))
1149                    }
1150                    MessageSegment::RedactedThinking(_) => {}
1151                }
1152            }
1153            text.push('\n');
1154        }
1155
1156        text
1157    }
1158
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The returned task first awaits the thread's initial project snapshot,
    /// then reads the thread entity to build a [`SerializedThread`]. The tail
    /// `read_with` call makes the task yield a `Result`; presumably it fails if
    /// the thread entity has been dropped (TODO confirm against gpui docs).
    ///
    /// Messages flagged `ui_only` are excluded. For every remaining message, its
    /// segments, tool uses, tool results, context text, creases and hidden flag
    /// are copied.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Cloned so the snapshot future can be awaited inside the spawned task.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // UI-only messages are never persisted.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // One-to-one mapping of in-memory segments to their
                        // serialized counterparts.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the loaded context's text is stored.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                // Snapshot of the latest value published on the detailed-summary channel.
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Provider and model are stored by id string; `None` when no
                // model is configured.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1245
1246    pub fn remaining_turns(&self) -> u32 {
1247        self.remaining_turns
1248    }
1249
1250    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
1251        self.remaining_turns = remaining_turns;
1252    }
1253
1254    pub fn send_to_model(
1255        &mut self,
1256        model: Arc<dyn LanguageModel>,
1257        intent: CompletionIntent,
1258        window: Option<AnyWindowHandle>,
1259        cx: &mut Context<Self>,
1260    ) {
1261        if self.remaining_turns == 0 {
1262            return;
1263        }
1264
1265        self.remaining_turns -= 1;
1266
1267        self.flush_notifications(model.clone(), intent, cx);
1268
1269        let _checkpoint = self.finalize_pending_checkpoint(cx);
1270        self.stream_completion(
1271            self.to_completion_request(model.clone(), intent, cx),
1272            model,
1273            intent,
1274            window,
1275            cx,
1276        );
1277    }
1278
1279    pub fn retry_last_completion(
1280        &mut self,
1281        window: Option<AnyWindowHandle>,
1282        cx: &mut Context<Self>,
1283    ) {
1284        // Clear any existing error state
1285        self.retry_state = None;
1286
1287        // Use the last error context if available, otherwise fall back to configured model
1288        let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() {
1289            (model, intent)
1290        } else if let Some(configured_model) = self.configured_model.as_ref() {
1291            let model = configured_model.model.clone();
1292            let intent = if self.has_pending_tool_uses() {
1293                CompletionIntent::ToolResults
1294            } else {
1295                CompletionIntent::UserPrompt
1296            };
1297            (model, intent)
1298        } else if let Some(configured_model) = self.get_or_init_configured_model(cx) {
1299            let model = configured_model.model.clone();
1300            let intent = if self.has_pending_tool_uses() {
1301                CompletionIntent::ToolResults
1302            } else {
1303                CompletionIntent::UserPrompt
1304            };
1305            (model, intent)
1306        } else {
1307            return;
1308        };
1309
1310        self.send_to_model(model, intent, window, cx);
1311    }
1312
1313    pub fn enable_burn_mode_and_retry(
1314        &mut self,
1315        window: Option<AnyWindowHandle>,
1316        cx: &mut Context<Self>,
1317    ) {
1318        self.completion_mode = CompletionMode::Burn;
1319        cx.emit(ThreadEvent::ProfileChanged);
1320        self.retry_last_completion(window, cx);
1321    }
1322
1323    pub fn used_tools_since_last_user_message(&self) -> bool {
1324        for message in self.messages.iter().rev() {
1325            if self.tool_use.message_has_tool_results(message.id) {
1326                return true;
1327            } else if message.role == Role::User {
1328                return false;
1329            }
1330        }
1331
1332        false
1333    }
1334
    /// Builds the [`LanguageModelRequest`] that replays this thread to `model`.
    ///
    /// Structure of the request:
    /// - A system prompt built from the project context comes first, if it can
    ///   be generated. On failure an error is logged, a
    ///   [`ThreadEvent::ShowError`] is emitted, and the request is built
    ///   without a system prompt.
    /// - Each non-`ui_only` message follows, with its loaded context, its
    ///   non-empty (end-trimmed) text and thinking segments, and its redacted
    ///   thinking.
    /// - Each resolved tool use is appended to its message. The matching
    ///   results go into a separate user message right after it. Tool uses
    ///   whose results are still pending are skipped.
    /// - The last request message that follows no still-pending tool use is
    ///   marked `cache: true`, which serves as the prompt-caching breakpoint.
    ///
    /// The thread's completion mode is used when the model supports burn mode;
    /// otherwise the mode is `Normal`.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        // Tool names are passed to the system prompt builder; the full tool
        // definitions are attached to the request at the end.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index into `request.messages` of the latest message eligible for the
        // cache breakpoint; overwritten as later eligible messages are pushed.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Becomes false if any tool use on this message is still pending,
            // which makes neither this message nor its results a cache target.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            // Tool results travel in their own user message directly after the
            // message that requested them.
            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1499
1500    fn to_summarize_request(
1501        &self,
1502        model: &Arc<dyn LanguageModel>,
1503        intent: CompletionIntent,
1504        added_user_message: String,
1505        cx: &App,
1506    ) -> LanguageModelRequest {
1507        let mut request = LanguageModelRequest {
1508            thread_id: None,
1509            prompt_id: None,
1510            intent: Some(intent),
1511            mode: None,
1512            messages: vec![],
1513            tools: Vec::new(),
1514            tool_choice: None,
1515            stop: Vec::new(),
1516            temperature: AgentSettings::temperature_for_model(model, cx),
1517            thinking_allowed: false,
1518        };
1519
1520        for message in &self.messages {
1521            let mut request_message = LanguageModelRequestMessage {
1522                role: message.role,
1523                content: Vec::new(),
1524                cache: false,
1525            };
1526
1527            for segment in &message.segments {
1528                match segment {
1529                    MessageSegment::Text(text) => request_message
1530                        .content
1531                        .push(MessageContent::Text(text.clone())),
1532                    MessageSegment::Thinking { .. } => {}
1533                    MessageSegment::RedactedThinking(_) => {}
1534                }
1535            }
1536
1537            if request_message.content.is_empty() {
1538                continue;
1539            }
1540
1541            request.messages.push(request_message);
1542        }
1543
1544        request.messages.push(LanguageModelRequestMessage {
1545            role: Role::User,
1546            content: vec![MessageContent::Text(added_user_message)],
1547            cache: false,
1548        });
1549
1550        request
1551    }
1552
1553    /// Insert auto-generated notifications (if any) to the thread
1554    fn flush_notifications(
1555        &mut self,
1556        model: Arc<dyn LanguageModel>,
1557        intent: CompletionIntent,
1558        cx: &mut Context<Self>,
1559    ) {
1560        match intent {
1561            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1562                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1563                    cx.emit(ThreadEvent::ToolFinished {
1564                        tool_use_id: pending_tool_use.id.clone(),
1565                        pending_tool_use: Some(pending_tool_use),
1566                    });
1567                }
1568            }
1569            CompletionIntent::ThreadSummarization
1570            | CompletionIntent::ThreadContextSummarization
1571            | CompletionIntent::CreateFile
1572            | CompletionIntent::EditFile
1573            | CompletionIntent::InlineAssist
1574            | CompletionIntent::TerminalInlineAssist
1575            | CompletionIntent::GenerateGitCommitMessage => {}
1576        };
1577    }
1578
1579    fn attach_tracked_files_state(
1580        &mut self,
1581        model: Arc<dyn LanguageModel>,
1582        cx: &mut App,
1583    ) -> Option<PendingToolUse> {
1584        // Represent notification as a simulated `project_notifications` tool call
1585        let tool_name = Arc::from("project_notifications");
1586        let tool = self.tools.read(cx).tool(&tool_name, cx)?;
1587
1588        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
1589            return None;
1590        }
1591
1592        if self
1593            .action_log
1594            .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
1595        {
1596            return None;
1597        }
1598
1599        let input = serde_json::json!({});
1600        let request = Arc::new(LanguageModelRequest::default()); // unused
1601        let window = None;
1602        let tool_result = tool.run(
1603            input,
1604            request,
1605            self.project.clone(),
1606            self.action_log.clone(),
1607            model.clone(),
1608            window,
1609            cx,
1610        );
1611
1612        let tool_use_id =
1613            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
1614
1615        let tool_use = LanguageModelToolUse {
1616            id: tool_use_id.clone(),
1617            name: tool_name.clone(),
1618            raw_input: "{}".to_string(),
1619            input: serde_json::json!({}),
1620            is_input_complete: true,
1621        };
1622
1623        let tool_output = cx.background_executor().block(tool_result.output);
1624
1625        // Attach a project_notification tool call to the latest existing
1626        // Assistant message. We cannot create a new Assistant message
1627        // because thinking models require a `thinking` block that we
1628        // cannot mock. We cannot send a notification as a normal
1629        // (non-tool-use) User message because this distracts Agent
1630        // too much.
1631        let tool_message_id = self
1632            .messages
1633            .iter()
1634            .enumerate()
1635            .rfind(|(_, message)| message.role == Role::Assistant)
1636            .map(|(_, message)| message.id)?;
1637
1638        let tool_use_metadata = ToolUseMetadata {
1639            model: model.clone(),
1640            thread_id: self.id.clone(),
1641            prompt_id: self.last_prompt_id.clone(),
1642        };
1643
1644        self.tool_use
1645            .request_tool_use(tool_message_id, tool_use, tool_use_metadata, cx);
1646
1647        self.tool_use.insert_tool_output(
1648            tool_use_id,
1649            tool_name,
1650            tool_output,
1651            self.configured_model.as_ref(),
1652            self.completion_mode,
1653        )
1654    }
1655
    /// Streams a completion for `request` from `model` and applies each streamed event to this
    /// thread.
    ///
    /// The work runs in a spawned task that is pushed onto `pending_completions` (initially in
    /// `QueueState::Sending`) under a freshly allocated completion id. While streaming:
    /// text, thinking, and redacted-thinking chunks are appended to the last assistant message
    /// (or start a new one), tool uses are registered with the tool-use state, token usage is
    /// folded into `cumulative_token_usage`, and cloud status updates adjust the pending
    /// completion's queue state, the model request usage, or `tool_use_limit_reached`
    /// (a `Failed` status ends the stream with an error).
    ///
    /// When the stream ends, the task finalizes the pending checkpoint and reacts to the
    /// outcome: it runs pending tools on `StopReason::ToolUse`, clears the agent location on
    /// other stop reasons, and removes the refused turn on `StopReason::Refusal`. On errors it
    /// shows an error and may schedule a retry. It emits `ThreadEvent::Stopped` unless a retry
    /// was scheduled, passes the recorded request and response events to `request_callback`
    /// when one is installed, and reports telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request and its response events when a request callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        // Used below to pick the model's max token count if the prompt turns out to be too large.
        let completion_mode = request
            .mode
            .unwrap_or(cloud_llm_client::CompletionMode::Normal);

        // Reset to `None` once the stream has been fully consumed.
        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, cx);
            // Snapshot so telemetry can report only the tokens used by this completion.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                // Latest usage reported for this request (each update replaces the previous one).
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message created for this request, once one exists.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        match event? {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Swap this request's previous usage out of the running total
                                // and the new value in.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet (or the last one
                                        // already has tool results), assume this chunk marks the
                                        // beginning of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet (or the last one
                                        // already has tool results), assume this chunk marks the
                                        // beginning of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the thinking text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                thread.received_chunk();

                                // Same append-or-start rule as text and thinking chunks; no
                                // streamed-content event is emitted for redacted thinking.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to this request's assistant message; create an
                                // empty one if the request hasn't produced one yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially streamed input is forwarded as a
                                // `StreamedToolUse` event below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some(tool_use.input.clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            } => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                // Status updates apply to this request's entry in
                                // `pending_completions`; they are ignored if it is gone.
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id: _,
                                            retry_after,
                                        } => {
                                            // Ends the stream; handled like any other error below.
                                            return Err(
                                                LanguageModelCompletionError::from_cloud_failure(
                                                    model.upstream_provider_name(),
                                                    code,
                                                    message,
                                                    retry_after.map(Duration::from_secs_f64),
                                                ),
                                            );
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            thread.update_model_request_usage(
                                                amount as u32,
                                                limit,
                                                cx,
                                            );
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        Ok(())
                    })??;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            // Set when an error handler schedules a retry. `Stopped` is then not emitted and the
            // last completion is not cancelled.
            let mut retry_scheduled = false;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => {
                            match stop_reason {
                                StopReason::ToolUse => {
                                    let tool_uses =
                                        thread.use_pending_tools(window, model.clone(), cx);
                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                }
                                StopReason::EndTurn | StopReason::MaxTokens => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });
                                }
                                StopReason::Refusal => {
                                    thread.project.update(cx, |project, cx| {
                                        project.set_agent_location(None, cx);
                                    });

                                    // Remove the turn that was refused.
                                    //
                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                    {
                                        let mut messages_to_remove = Vec::new();

                                        // Walk backwards, collecting messages until reaching a
                                        // user message that is either the first message or
                                        // directly follows an assistant message (inclusive).
                                        for (ix, message) in
                                            thread.messages.iter().enumerate().rev()
                                        {
                                            messages_to_remove.push(message.id);

                                            if message.role == Role::User {
                                                if ix == 0 {
                                                    break;
                                                }

                                                if let Some(prev_message) =
                                                    thread.messages.get(ix - 1)
                                                    && prev_message.role == Role::Assistant {
                                                        break;
                                                    }
                                            }
                                        }

                                        for message_id in messages_to_remove {
                                            thread.delete_message(message_id, cx);
                                        }
                                    }

                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                        header: "Language model refusal".into(),
                                        message:
                                            "Model refused to generate content for safety reasons."
                                                .into(),
                                    }));
                                }
                            }

                            // We successfully completed, so cancel any remaining retries.
                            thread.retry_state = None;
                        }
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Payment and request-limit errors get dedicated UI errors. Other
                            // completion errors either record an exceeded context window or
                            // may schedule a retry.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(completion_error) =
                                error.downcast_ref::<LanguageModelCompletionError>()
                            {
                                match &completion_error {
                                    LanguageModelCompletionError::PromptTooLarge {
                                        tokens, ..
                                    } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread
                                                .total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(
                                                    model
                                                        .max_token_count_for_mode(completion_mode)
                                                        .saturating_add(1),
                                                )
                                        });
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: tokens,
                                        });
                                        cx.notify();
                                    }
                                    _ => {
                                        if let Some(retry_strategy) =
                                            Thread::get_retry_strategy(completion_error)
                                        {
                                            log::info!(
                                                "Retrying with {:?} for language model completion error {:?}",
                                                retry_strategy,
                                                completion_error
                                            );

                                            retry_scheduled = thread
                                                .handle_retryable_error_with_delay(
                                                    completion_error,
                                                    Some(retry_strategy),
                                                    model.clone(),
                                                    intent,
                                                    window,
                                                    cx,
                                                );
                                        }
                                    }
                                }
                            }

                            if !retry_scheduled {
                                thread.cancel_last_completion(window, cx);
                            }
                        }
                    }

                    if !retry_scheduled {
                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
                    }

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    // Report only the token usage accumulated since this completion started.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
2097
2098    pub fn summarize(&mut self, cx: &mut Context<Self>) {
2099        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2100            println!("No thread summary model");
2101            return;
2102        };
2103
2104        if !model.provider.is_authenticated(cx) {
2105            return;
2106        }
2107
2108        let request = self.to_summarize_request(
2109            &model.model,
2110            CompletionIntent::ThreadSummarization,
2111            SUMMARIZE_THREAD_PROMPT.into(),
2112            cx,
2113        );
2114
2115        self.summary = ThreadSummary::Generating;
2116
2117        self.pending_summary = cx.spawn(async move |this, cx| {
2118            let result = async {
2119                let mut messages = model.model.stream_completion(request, cx).await?;
2120
2121                let mut new_summary = String::new();
2122                while let Some(event) = messages.next().await {
2123                    let Ok(event) = event else {
2124                        continue;
2125                    };
2126                    let text = match event {
2127                        LanguageModelCompletionEvent::Text(text) => text,
2128                        LanguageModelCompletionEvent::StatusUpdate(
2129                            CompletionRequestStatus::UsageUpdated { amount, limit },
2130                        ) => {
2131                            this.update(cx, |thread, cx| {
2132                                thread.update_model_request_usage(amount as u32, limit, cx);
2133                            })?;
2134                            continue;
2135                        }
2136                        _ => continue,
2137                    };
2138
2139                    let mut lines = text.lines();
2140                    new_summary.extend(lines.next());
2141
2142                    // Stop if the LLM generated multiple lines.
2143                    if lines.next().is_some() {
2144                        break;
2145                    }
2146                }
2147
2148                anyhow::Ok(new_summary)
2149            }
2150            .await;
2151
2152            this.update(cx, |this, cx| {
2153                match result {
2154                    Ok(new_summary) => {
2155                        if new_summary.is_empty() {
2156                            this.summary = ThreadSummary::Error;
2157                        } else {
2158                            this.summary = ThreadSummary::Ready(new_summary.into());
2159                        }
2160                    }
2161                    Err(err) => {
2162                        this.summary = ThreadSummary::Error;
2163                        log::error!("Failed to generate thread summary: {}", err);
2164                    }
2165                }
2166                cx.emit(ThreadEvent::SummaryGenerated);
2167            })
2168            .log_err()?;
2169
2170            Some(())
2171        });
2172    }
2173
2174    fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
2175        use LanguageModelCompletionError::*;
2176
2177        // General strategy here:
2178        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
2179        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
2180        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
2181        match error {
2182            HttpResponseError {
2183                status_code: StatusCode::TOO_MANY_REQUESTS,
2184                ..
2185            } => Some(RetryStrategy::ExponentialBackoff {
2186                initial_delay: BASE_RETRY_DELAY,
2187                max_attempts: MAX_RETRY_ATTEMPTS,
2188            }),
2189            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
2190                Some(RetryStrategy::Fixed {
2191                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2192                    max_attempts: MAX_RETRY_ATTEMPTS,
2193                })
2194            }
2195            UpstreamProviderError {
2196                status,
2197                retry_after,
2198                ..
2199            } => match *status {
2200                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
2201                    Some(RetryStrategy::Fixed {
2202                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2203                        max_attempts: MAX_RETRY_ATTEMPTS,
2204                    })
2205                }
2206                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
2207                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2208                    // Internal Server Error could be anything, retry up to 3 times.
2209                    max_attempts: 3,
2210                }),
2211                status => {
2212                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
2213                    // but we frequently get them in practice. See https://http.dev/529
2214                    if status.as_u16() == 529 {
2215                        Some(RetryStrategy::Fixed {
2216                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2217                            max_attempts: MAX_RETRY_ATTEMPTS,
2218                        })
2219                    } else {
2220                        Some(RetryStrategy::Fixed {
2221                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
2222                            max_attempts: 2,
2223                        })
2224                    }
2225                }
2226            },
2227            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
2228                delay: BASE_RETRY_DELAY,
2229                max_attempts: 3,
2230            }),
2231            ApiReadResponseError { .. }
2232            | HttpSend { .. }
2233            | DeserializeResponse { .. }
2234            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
2235                delay: BASE_RETRY_DELAY,
2236                max_attempts: 3,
2237            }),
2238            // Retrying these errors definitely shouldn't help.
2239            HttpResponseError {
2240                status_code:
2241                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
2242                ..
2243            }
2244            | AuthenticationError { .. }
2245            | PermissionError { .. }
2246            | NoApiKey { .. }
2247            | ApiEndpointNotFound { .. }
2248            | PromptTooLarge { .. } => None,
2249            // These errors might be transient, so retry them
2250            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
2251                delay: BASE_RETRY_DELAY,
2252                max_attempts: 1,
2253            }),
2254            // Retry all other 4xx and 5xx errors once.
2255            HttpResponseError { status_code, .. }
2256                if status_code.is_client_error() || status_code.is_server_error() =>
2257            {
2258                Some(RetryStrategy::Fixed {
2259                    delay: BASE_RETRY_DELAY,
2260                    max_attempts: 3,
2261                })
2262            }
2263            Other(err)
2264                if err.is::<PaymentRequiredError>()
2265                    || err.is::<ModelRequestLimitReachedError>() =>
2266            {
2267                // Retrying won't help for Payment Required or Model Request Limit errors (where
2268                // the user must upgrade to usage-based billing to get more requests, or else wait
2269                // for a significant amount of time for the request limit to reset).
2270                None
2271            }
2272            // Conservatively assume that any other errors are non-retryable
2273            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
2274                delay: BASE_RETRY_DELAY,
2275                max_attempts: 2,
2276            }),
2277        }
2278    }
2279
2280    fn handle_retryable_error_with_delay(
2281        &mut self,
2282        error: &LanguageModelCompletionError,
2283        strategy: Option<RetryStrategy>,
2284        model: Arc<dyn LanguageModel>,
2285        intent: CompletionIntent,
2286        window: Option<AnyWindowHandle>,
2287        cx: &mut Context<Self>,
2288    ) -> bool {
2289        // Store context for the Retry button
2290        self.last_error_context = Some((model.clone(), intent));
2291
2292        // Only auto-retry if Burn Mode is enabled
2293        if self.completion_mode != CompletionMode::Burn {
2294            // Show error with retry options
2295            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
2296                message: format!(
2297                    "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.",
2298                    error
2299                )
2300                .into(),
2301                can_enable_burn_mode: true,
2302            }));
2303            return false;
2304        }
2305
2306        let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else {
2307            return false;
2308        };
2309
2310        let max_attempts = match &strategy {
2311            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
2312            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
2313        };
2314
2315        let retry_state = self.retry_state.get_or_insert(RetryState {
2316            attempt: 0,
2317            max_attempts,
2318            intent,
2319        });
2320
2321        retry_state.attempt += 1;
2322        let attempt = retry_state.attempt;
2323        let max_attempts = retry_state.max_attempts;
2324        let intent = retry_state.intent;
2325
2326        if attempt <= max_attempts {
2327            let delay = match &strategy {
2328                RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
2329                    let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
2330                    Duration::from_secs(delay_secs)
2331                }
2332                RetryStrategy::Fixed { delay, .. } => *delay,
2333            };
2334
2335            // Add a transient message to inform the user
2336            let delay_secs = delay.as_secs();
2337            let retry_message = if max_attempts == 1 {
2338                format!("{error}. Retrying in {delay_secs} seconds...")
2339            } else {
2340                format!(
2341                    "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2342                    in {delay_secs} seconds..."
2343                )
2344            };
2345            log::warn!(
2346                "Retrying completion request (attempt {attempt} of {max_attempts}) \
2347                in {delay_secs} seconds: {error:?}",
2348            );
2349
2350            // Add a UI-only message instead of a regular message
2351            let id = self.next_message_id.post_inc();
2352            self.messages.push(Message {
2353                id,
2354                role: Role::System,
2355                segments: vec![MessageSegment::Text(retry_message)],
2356                loaded_context: LoadedContext::default(),
2357                creases: Vec::new(),
2358                is_hidden: false,
2359                ui_only: true,
2360            });
2361            cx.emit(ThreadEvent::MessageAdded(id));
2362
2363            // Schedule the retry
2364            let thread_handle = cx.entity().downgrade();
2365
2366            cx.spawn(async move |_thread, cx| {
2367                cx.background_executor().timer(delay).await;
2368
2369                thread_handle
2370                    .update(cx, |thread, cx| {
2371                        // Retry the completion
2372                        thread.send_to_model(model, intent, window, cx);
2373                    })
2374                    .log_err();
2375            })
2376            .detach();
2377
2378            true
2379        } else {
2380            // Max retries exceeded
2381            self.retry_state = None;
2382
2383            // Stop generating since we're giving up on retrying.
2384            self.pending_completions.clear();
2385
2386            // Show error alongside a Retry button, but no
2387            // Enable Burn Mode button (since it's already enabled)
2388            cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError {
2389                message: format!("Failed after retrying: {}", error).into(),
2390                can_enable_burn_mode: false,
2391            }));
2392
2393            false
2394        }
2395    }
2396
2397    pub fn start_generating_detailed_summary_if_needed(
2398        &mut self,
2399        thread_store: WeakEntity<ThreadStore>,
2400        cx: &mut Context<Self>,
2401    ) {
2402        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
2403            return;
2404        };
2405
2406        match &*self.detailed_summary_rx.borrow() {
2407            DetailedSummaryState::Generating { message_id, .. }
2408            | DetailedSummaryState::Generated { message_id, .. }
2409                if *message_id == last_message_id =>
2410            {
2411                // Already up-to-date
2412                return;
2413            }
2414            _ => {}
2415        }
2416
2417        let Some(ConfiguredModel { model, provider }) =
2418            LanguageModelRegistry::read_global(cx).thread_summary_model()
2419        else {
2420            return;
2421        };
2422
2423        if !provider.is_authenticated(cx) {
2424            return;
2425        }
2426
2427        let request = self.to_summarize_request(
2428            &model,
2429            CompletionIntent::ThreadContextSummarization,
2430            SUMMARIZE_THREAD_DETAILED_PROMPT.into(),
2431            cx,
2432        );
2433
2434        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
2435            message_id: last_message_id,
2436        };
2437
2438        // Replace the detailed summarization task if there is one, cancelling it. It would probably
2439        // be better to allow the old task to complete, but this would require logic for choosing
2440        // which result to prefer (the old task could complete after the new one, resulting in a
2441        // stale summary).
2442        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
2443            let stream = model.stream_completion_text(request, cx);
2444            let Some(mut messages) = stream.await.log_err() else {
2445                thread
2446                    .update(cx, |thread, _cx| {
2447                        *thread.detailed_summary_tx.borrow_mut() =
2448                            DetailedSummaryState::NotGenerated;
2449                    })
2450                    .ok()?;
2451                return None;
2452            };
2453
2454            let mut new_detailed_summary = String::new();
2455
2456            while let Some(chunk) = messages.stream.next().await {
2457                if let Some(chunk) = chunk.log_err() {
2458                    new_detailed_summary.push_str(&chunk);
2459                }
2460            }
2461
2462            thread
2463                .update(cx, |thread, _cx| {
2464                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
2465                        text: new_detailed_summary.into(),
2466                        message_id: last_message_id,
2467                    };
2468                })
2469                .ok()?;
2470
2471            // Save thread so its summary can be reused later
2472            if let Some(thread) = thread.upgrade()
2473                && let Ok(Ok(save_task)) = cx.update(|cx| {
2474                    thread_store
2475                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
2476                })
2477            {
2478                save_task.await.log_err();
2479            }
2480
2481            Some(())
2482        });
2483    }
2484
2485    pub async fn wait_for_detailed_summary_or_text(
2486        this: &Entity<Self>,
2487        cx: &mut AsyncApp,
2488    ) -> Option<SharedString> {
2489        let mut detailed_summary_rx = this
2490            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2491            .ok()?;
2492        loop {
2493            match detailed_summary_rx.recv().await? {
2494                DetailedSummaryState::Generating { .. } => {}
2495                DetailedSummaryState::NotGenerated => {
2496                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
2497                }
2498                DetailedSummaryState::Generated { text, .. } => return Some(text),
2499            }
2500        }
2501    }
2502
2503    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2504        self.detailed_summary_rx
2505            .borrow()
2506            .text()
2507            .unwrap_or_else(|| self.text().into())
2508    }
2509
2510    pub fn is_generating_detailed_summary(&self) -> bool {
2511        matches!(
2512            &*self.detailed_summary_rx.borrow(),
2513            DetailedSummaryState::Generating { .. }
2514        )
2515    }
2516
2517    pub fn use_pending_tools(
2518        &mut self,
2519        window: Option<AnyWindowHandle>,
2520        model: Arc<dyn LanguageModel>,
2521        cx: &mut Context<Self>,
2522    ) -> Vec<PendingToolUse> {
2523        let request =
2524            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2525        let pending_tool_uses = self
2526            .tool_use
2527            .pending_tool_uses()
2528            .into_iter()
2529            .filter(|tool_use| tool_use.status.is_idle())
2530            .cloned()
2531            .collect::<Vec<_>>();
2532
2533        for tool_use in pending_tool_uses.iter() {
2534            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2535        }
2536
2537        pending_tool_uses
2538    }
2539
2540    fn use_pending_tool(
2541        &mut self,
2542        tool_use: PendingToolUse,
2543        request: Arc<LanguageModelRequest>,
2544        model: Arc<dyn LanguageModel>,
2545        window: Option<AnyWindowHandle>,
2546        cx: &mut Context<Self>,
2547    ) {
2548        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2549            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2550        };
2551
2552        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2553            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2554        }
2555
2556        if tool.needs_confirmation(&tool_use.input, &self.project, cx)
2557            && !AgentSettings::get_global(cx).always_allow_tool_actions
2558        {
2559            self.tool_use.confirm_tool_use(
2560                tool_use.id,
2561                tool_use.ui_text,
2562                tool_use.input,
2563                request,
2564                tool,
2565            );
2566            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2567        } else {
2568            self.run_tool(
2569                tool_use.id,
2570                tool_use.ui_text,
2571                tool_use.input,
2572                request,
2573                tool,
2574                model,
2575                window,
2576                cx,
2577            );
2578        }
2579    }
2580
2581    pub fn handle_hallucinated_tool_use(
2582        &mut self,
2583        tool_use_id: LanguageModelToolUseId,
2584        hallucinated_tool_name: Arc<str>,
2585        window: Option<AnyWindowHandle>,
2586        cx: &mut Context<Thread>,
2587    ) {
2588        let available_tools = self.profile.enabled_tools(cx);
2589
2590        let tool_list = available_tools
2591            .iter()
2592            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2593            .collect::<Vec<_>>()
2594            .join("\n");
2595
2596        let error_message = format!(
2597            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2598            hallucinated_tool_name, tool_list
2599        );
2600
2601        let pending_tool_use = self.tool_use.insert_tool_output(
2602            tool_use_id.clone(),
2603            hallucinated_tool_name,
2604            Err(anyhow!("Missing tool call: {error_message}")),
2605            self.configured_model.as_ref(),
2606            self.completion_mode,
2607        );
2608
2609        cx.emit(ThreadEvent::MissingToolUse {
2610            tool_use_id: tool_use_id.clone(),
2611            ui_text: error_message.into(),
2612        });
2613
2614        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2615    }
2616
2617    pub fn receive_invalid_tool_json(
2618        &mut self,
2619        tool_use_id: LanguageModelToolUseId,
2620        tool_name: Arc<str>,
2621        invalid_json: Arc<str>,
2622        error: String,
2623        window: Option<AnyWindowHandle>,
2624        cx: &mut Context<Thread>,
2625    ) {
2626        log::error!("The model returned invalid input JSON: {invalid_json}");
2627
2628        let pending_tool_use = self.tool_use.insert_tool_output(
2629            tool_use_id.clone(),
2630            tool_name,
2631            Err(anyhow!("Error parsing input JSON: {error}")),
2632            self.configured_model.as_ref(),
2633            self.completion_mode,
2634        );
2635        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2636            pending_tool_use.ui_text.clone()
2637        } else {
2638            log::error!(
2639                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2640            );
2641            format!("Unknown tool {}", tool_use_id).into()
2642        };
2643
2644        cx.emit(ThreadEvent::InvalidToolInput {
2645            tool_use_id: tool_use_id.clone(),
2646            ui_text,
2647            invalid_input_json: invalid_json,
2648        });
2649
2650        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2651    }
2652
2653    pub fn run_tool(
2654        &mut self,
2655        tool_use_id: LanguageModelToolUseId,
2656        ui_text: impl Into<SharedString>,
2657        input: serde_json::Value,
2658        request: Arc<LanguageModelRequest>,
2659        tool: Arc<dyn Tool>,
2660        model: Arc<dyn LanguageModel>,
2661        window: Option<AnyWindowHandle>,
2662        cx: &mut Context<Thread>,
2663    ) {
2664        let task =
2665            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2666        self.tool_use
2667            .run_pending_tool(tool_use_id, ui_text.into(), task);
2668    }
2669
2670    fn spawn_tool_use(
2671        &mut self,
2672        tool_use_id: LanguageModelToolUseId,
2673        request: Arc<LanguageModelRequest>,
2674        input: serde_json::Value,
2675        tool: Arc<dyn Tool>,
2676        model: Arc<dyn LanguageModel>,
2677        window: Option<AnyWindowHandle>,
2678        cx: &mut Context<Thread>,
2679    ) -> Task<()> {
2680        let tool_name: Arc<str> = tool.name().into();
2681
2682        let tool_result = tool.run(
2683            input,
2684            request,
2685            self.project.clone(),
2686            self.action_log.clone(),
2687            model,
2688            window,
2689            cx,
2690        );
2691
2692        // Store the card separately if it exists
2693        if let Some(card) = tool_result.card.clone() {
2694            self.tool_use
2695                .insert_tool_result_card(tool_use_id.clone(), card);
2696        }
2697
2698        cx.spawn({
2699            async move |thread: WeakEntity<Thread>, cx| {
2700                let output = tool_result.output.await;
2701
2702                thread
2703                    .update(cx, |thread, cx| {
2704                        let pending_tool_use = thread.tool_use.insert_tool_output(
2705                            tool_use_id.clone(),
2706                            tool_name,
2707                            output,
2708                            thread.configured_model.as_ref(),
2709                            thread.completion_mode,
2710                        );
2711                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2712                    })
2713                    .ok();
2714            }
2715        })
2716    }
2717
2718    fn tool_finished(
2719        &mut self,
2720        tool_use_id: LanguageModelToolUseId,
2721        pending_tool_use: Option<PendingToolUse>,
2722        canceled: bool,
2723        window: Option<AnyWindowHandle>,
2724        cx: &mut Context<Self>,
2725    ) {
2726        if self.all_tools_finished()
2727            && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref()
2728            && !canceled
2729        {
2730            self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2731        }
2732
2733        cx.emit(ThreadEvent::ToolFinished {
2734            tool_use_id,
2735            pending_tool_use,
2736        });
2737    }
2738
2739    /// Cancels the last pending completion, if there are any pending.
2740    ///
2741    /// Returns whether a completion was canceled.
2742    pub fn cancel_last_completion(
2743        &mut self,
2744        window: Option<AnyWindowHandle>,
2745        cx: &mut Context<Self>,
2746    ) -> bool {
2747        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2748
2749        self.retry_state = None;
2750
2751        for pending_tool_use in self.tool_use.cancel_pending() {
2752            canceled = true;
2753            self.tool_finished(
2754                pending_tool_use.id.clone(),
2755                Some(pending_tool_use),
2756                true,
2757                window,
2758                cx,
2759            );
2760        }
2761
2762        if canceled {
2763            cx.emit(ThreadEvent::CompletionCanceled);
2764
2765            // When canceled, we always want to insert the checkpoint.
2766            // (We skip over finalize_pending_checkpoint, because it
2767            // would conclude we didn't have anything to insert here.)
2768            if let Some(checkpoint) = self.pending_checkpoint.take() {
2769                self.insert_checkpoint(checkpoint, cx);
2770            }
2771        } else {
2772            self.finalize_pending_checkpoint(cx);
2773        }
2774
2775        canceled
2776    }
2777
2778    /// Signals that any in-progress editing should be canceled.
2779    ///
2780    /// This method is used to notify listeners (like ActiveThread) that
2781    /// they should cancel any editing operations.
2782    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
2783        cx.emit(ThreadEvent::CancelEditing);
2784    }
2785
2786    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2787        self.message_feedback.get(&message_id).copied()
2788    }
2789
    /// Records `feedback` for `message_id` and reports it through telemetry.
    ///
    /// If identical feedback is already recorded for the message, this returns
    /// an immediately-ready `Ok(())`. Otherwise the feedback is stored and the
    /// entity notified right away. A background task then awaits a project
    /// snapshot and the serialized thread, emits an "Assistant Thread Rated"
    /// telemetry event, and flushes telemetry events.
    ///
    /// # Errors
    /// The returned task fails if awaiting the serialized thread yields an error.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Re-reporting the same rating is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Kick off the snapshot and serialization now; both are awaited inside
        // the background task below.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        // Text of the rated message; empty if the message can't be found.
        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_message_content())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A thread that can't be converted to JSON is reported as `null`
            // rather than failing the whole report.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            // NOTE(review): `telemetry::event!` appears to take these locals by
            // name (shorthand, like the explicit `message_id = …`), so renaming
            // them likely changes the reported property names — confirm against
            // the macro before refactoring.
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2846
2847    /// Create a snapshot of the current project state including git information and unsaved buffers.
2848    fn project_snapshot(
2849        project: Entity<Project>,
2850        cx: &mut Context<Self>,
2851    ) -> Task<Arc<ProjectSnapshot>> {
2852        let git_store = project.read(cx).git_store().clone();
2853        let worktree_snapshots: Vec<_> = project
2854            .read(cx)
2855            .visible_worktrees(cx)
2856            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2857            .collect();
2858
2859        cx.spawn(async move |_, _| {
2860            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2861
2862            Arc::new(ProjectSnapshot {
2863                worktree_snapshots,
2864                timestamp: Utc::now(),
2865            })
2866        })
2867    }
2868
2869    fn worktree_snapshot(
2870        worktree: Entity<project::Worktree>,
2871        git_store: Entity<GitStore>,
2872        cx: &App,
2873    ) -> Task<WorktreeSnapshot> {
2874        cx.spawn(async move |cx| {
2875            // Get worktree path and snapshot
2876            let worktree_info = cx.update(|app_cx| {
2877                let worktree = worktree.read(app_cx);
2878                let path = worktree.abs_path().to_string_lossy().into_owned();
2879                let snapshot = worktree.snapshot();
2880                (path, snapshot)
2881            });
2882
2883            let Ok((worktree_path, _snapshot)) = worktree_info else {
2884                return WorktreeSnapshot {
2885                    worktree_path: String::new(),
2886                    git_state: None,
2887                };
2888            };
2889
2890            let git_state = git_store
2891                .update(cx, |git_store, cx| {
2892                    git_store
2893                        .repositories()
2894                        .values()
2895                        .find(|repo| {
2896                            repo.read(cx)
2897                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2898                                .is_some()
2899                        })
2900                        .cloned()
2901                })
2902                .ok()
2903                .flatten()
2904                .map(|repo| {
2905                    repo.update(cx, |repo, _| {
2906                        let current_branch =
2907                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2908                        repo.send_job(None, |state, _| async move {
2909                            let RepositoryState::Local { backend, .. } = state else {
2910                                return GitState {
2911                                    remote_url: None,
2912                                    head_sha: None,
2913                                    current_branch,
2914                                    diff: None,
2915                                };
2916                            };
2917
2918                            let remote_url = backend.remote_url("origin");
2919                            let head_sha = backend.head_sha().await;
2920                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2921
2922                            GitState {
2923                                remote_url,
2924                                head_sha,
2925                                current_branch,
2926                                diff,
2927                            }
2928                        })
2929                    })
2930                });
2931
2932            let git_state = match git_state {
2933                Some(git_state) => match git_state.ok() {
2934                    Some(git_state) => git_state.await.ok(),
2935                    None => None,
2936                },
2937                None => None,
2938            };
2939
2940            WorktreeSnapshot {
2941                worktree_path,
2942                git_state,
2943            }
2944        })
2945    }
2946
2947    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2948        let mut markdown = Vec::new();
2949
2950        let summary = self.summary().or_default();
2951        writeln!(markdown, "# {summary}\n")?;
2952
2953        for message in self.messages() {
2954            writeln!(
2955                markdown,
2956                "## {role}\n",
2957                role = match message.role {
2958                    Role::User => "User",
2959                    Role::Assistant => "Agent",
2960                    Role::System => "System",
2961                }
2962            )?;
2963
2964            if !message.loaded_context.text.is_empty() {
2965                writeln!(markdown, "{}", message.loaded_context.text)?;
2966            }
2967
2968            if !message.loaded_context.images.is_empty() {
2969                writeln!(
2970                    markdown,
2971                    "\n{} images attached as context.\n",
2972                    message.loaded_context.images.len()
2973                )?;
2974            }
2975
2976            for segment in &message.segments {
2977                match segment {
2978                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2979                    MessageSegment::Thinking { text, .. } => {
2980                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2981                    }
2982                    MessageSegment::RedactedThinking(_) => {}
2983                }
2984            }
2985
2986            for tool_use in self.tool_uses_for_message(message.id, cx) {
2987                writeln!(
2988                    markdown,
2989                    "**Use Tool: {} ({})**",
2990                    tool_use.name, tool_use.id
2991                )?;
2992                writeln!(markdown, "```json")?;
2993                writeln!(
2994                    markdown,
2995                    "{}",
2996                    serde_json::to_string_pretty(&tool_use.input)?
2997                )?;
2998                writeln!(markdown, "```")?;
2999            }
3000
3001            for tool_result in self.tool_results_for_message(message.id) {
3002                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
3003                if tool_result.is_error {
3004                    write!(markdown, " (Error)")?;
3005                }
3006
3007                writeln!(markdown, "**\n")?;
3008                match &tool_result.content {
3009                    LanguageModelToolResultContent::Text(text) => {
3010                        writeln!(markdown, "{text}")?;
3011                    }
3012                    LanguageModelToolResultContent::Image(image) => {
3013                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
3014                    }
3015                }
3016
3017                if let Some(output) = tool_result.output.as_ref() {
3018                    writeln!(
3019                        markdown,
3020                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
3021                        serde_json::to_string_pretty(output)?
3022                    )?;
3023                }
3024            }
3025        }
3026
3027        Ok(String::from_utf8_lossy(&markdown).to_string())
3028    }
3029
3030    pub fn keep_edits_in_range(
3031        &mut self,
3032        buffer: Entity<language::Buffer>,
3033        buffer_range: Range<language::Anchor>,
3034        cx: &mut Context<Self>,
3035    ) {
3036        self.action_log.update(cx, |action_log, cx| {
3037            action_log.keep_edits_in_range(buffer, buffer_range, cx)
3038        });
3039    }
3040
3041    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
3042        self.action_log
3043            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
3044    }
3045
3046    pub fn reject_edits_in_ranges(
3047        &mut self,
3048        buffer: Entity<language::Buffer>,
3049        buffer_ranges: Vec<Range<language::Anchor>>,
3050        cx: &mut Context<Self>,
3051    ) -> Task<Result<()>> {
3052        self.action_log.update(cx, |action_log, cx| {
3053            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
3054        })
3055    }
3056
3057    pub fn action_log(&self) -> &Entity<ActionLog> {
3058        &self.action_log
3059    }
3060
3061    pub fn project(&self) -> &Entity<Project> {
3062        &self.project
3063    }
3064
3065    pub fn cumulative_token_usage(&self) -> TokenUsage {
3066        self.cumulative_token_usage
3067    }
3068
3069    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3070        let Some(model) = self.configured_model.as_ref() else {
3071            return TotalTokenUsage::default();
3072        };
3073
3074        let max = model
3075            .model
3076            .max_token_count_for_mode(self.completion_mode().into());
3077
3078        let index = self
3079            .messages
3080            .iter()
3081            .position(|msg| msg.id == message_id)
3082            .unwrap_or(0);
3083
3084        if index == 0 {
3085            return TotalTokenUsage { total: 0, max };
3086        }
3087
3088        let token_usage = &self
3089            .request_token_usage
3090            .get(index - 1)
3091            .cloned()
3092            .unwrap_or_default();
3093
3094        TotalTokenUsage {
3095            total: token_usage.total_tokens(),
3096            max,
3097        }
3098    }
3099
3100    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3101        let model = self.configured_model.as_ref()?;
3102
3103        let max = model
3104            .model
3105            .max_token_count_for_mode(self.completion_mode().into());
3106
3107        if let Some(exceeded_error) = &self.exceeded_window_error
3108            && model.model.id() == exceeded_error.model_id
3109        {
3110            return Some(TotalTokenUsage {
3111                total: exceeded_error.token_count,
3112                max,
3113            });
3114        }
3115
3116        let total = self
3117            .token_usage_at_last_message()
3118            .unwrap_or_default()
3119            .total_tokens();
3120
3121        Some(TotalTokenUsage { total, max })
3122    }
3123
3124    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3125        self.request_token_usage
3126            .get(self.messages.len().saturating_sub(1))
3127            .or_else(|| self.request_token_usage.last())
3128            .cloned()
3129    }
3130
3131    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3132        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3133        self.request_token_usage
3134            .resize(self.messages.len(), placeholder);
3135
3136        if let Some(last) = self.request_token_usage.last_mut() {
3137            *last = token_usage;
3138        }
3139    }
3140
3141    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3142        self.project
3143            .read(cx)
3144            .user_store()
3145            .update(cx, |user_store, cx| {
3146                user_store.update_model_request_usage(
3147                    ModelRequestUsage(RequestUsage {
3148                        amount: amount as i32,
3149                        limit,
3150                    }),
3151                    cx,
3152                )
3153            });
3154    }
3155
3156    pub fn deny_tool_use(
3157        &mut self,
3158        tool_use_id: LanguageModelToolUseId,
3159        tool_name: Arc<str>,
3160        window: Option<AnyWindowHandle>,
3161        cx: &mut Context<Self>,
3162    ) {
3163        let err = Err(anyhow::anyhow!(
3164            "Permission to run tool action denied by user"
3165        ));
3166
3167        self.tool_use.insert_tool_output(
3168            tool_use_id.clone(),
3169            tool_name,
3170            err,
3171            self.configured_model.as_ref(),
3172            self.completion_mode,
3173        );
3174        self.tool_finished(tool_use_id, None, true, window, cx);
3175    }
3176}
3177
/// Errors reported by a thread; carried to observers by [`ThreadEvent::ShowError`].
///
/// `Display` text comes from each variant's `#[error]` attribute.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displays as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    /// The model request limit was reached; `plan` is the user's plan at that time.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short `header` and a longer `message`.
    /// Displays as "Message {header}: {message}".
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
    /// A failure that can be retried. `can_enable_burn_mode` presumably tells the
    /// UI whether switching completion mode could help — confirm at the producer.
    #[error("Retryable error: {message}")]
    RetryableError {
        message: SharedString,
        can_enable_burn_mode: bool,
    },
}
3195
/// Events emitted by a [`Thread`] to its observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries an error for the UI to display.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Assistant text streamed for the message with the given id.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed for the message with the given id.
    StreamedAssistantThinking(MessageId, String),
    /// A tool call streamed from the model, with its display text and JSON input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// A tool call whose input could not be used; the raw JSON text is preserved.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped: `Ok` with the stop reason, or `Err` with the
    /// failure. The error is wrapped in `Arc` because `anyhow::Error` is not
    /// `Clone` but this enum derives `Clone`.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        // NOTE(review): the allow silences the unused-field lint for
        // `tool_use_id`; presumably kept for `Debug` output or future use — confirm
        // whether any subscriber reads it.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}
3240
// Marks `Thread` as emitting `ThreadEvent`s through gpui's `EventEmitter` trait.
impl EventEmitter<ThreadEvent> for Thread {}
3242
/// State kept for a completion request that is presumably still in flight.
struct PendingCompletion {
    /// Identifier of this completion (presumably used to match it later — confirm).
    id: usize,
    /// Queue state reported for this request.
    queue_state: QueueState,
    /// The task driving the completion. It is never read (hence the `_` prefix);
    /// it is presumably stored only to keep the task alive, since gpui tasks are
    /// cancelled when dropped — confirm.
    _task: Task<()>,
}
3248
3249#[cfg(test)]
3250mod tests {
3251    use super::*;
3252    use crate::{
3253        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3254    };
3255
    // Test-specific constants
    //
    // NOTE(review): this constant sits between `use` declarations. Its uses are
    // outside this chunk (presumably the rate-limit retry tests later in the module).
    /// Retry delay in seconds, presumably expected by the rate-limit retry tests.
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3258    use agent_settings::{AgentProfileId, AgentSettings};
3259    use assistant_tool::ToolRegistry;
3260    use assistant_tools;
3261    use fs::Fs;
3262    use futures::StreamExt;
3263    use futures::future::BoxFuture;
3264    use futures::stream::BoxStream;
3265    use gpui::TestAppContext;
3266    use http_client;
3267    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3268    use language_model::{
3269        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3270        LanguageModelProviderName, LanguageModelToolChoice,
3271    };
3272    use parking_lot::Mutex;
3273    use project::{FakeFs, Project};
3274    use prompt_store::PromptBuilder;
3275    use serde_json::json;
3276    use settings::{LanguageModelParameters, Settings, SettingsStore};
3277    use std::sync::Arc;
3278    use std::time::Duration;
3279    use util::path;
3280    use workspace::Workspace;
3281
    // Inserting a user message with an attached file stores the rendered file in
    // the message's `loaded_context`. In the completion request, that context text
    // is prepended to the message's own text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Index 0 is presumably the system prompt; the user message is at index 1.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3359
    // Contexts already attached to earlier messages are not attached again: each
    // message's `loaded_context` carries only the files added since the previous
    // message. The final section calls `new_context_for_thread` with
    // `Some(message2_id)`. Per the assertions there, that excludes file1 (attached
    // before message 2) and includes file2 onwards.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped
        // because earlier messages already carry them)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request holds the 3 user messages plus one leading message
        // (presumably the system prompt), so user message N is at index N.
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Compute new contexts as of message 2. Per the assertions below, file1
        // (attached before message 2) is excluded, while file2, file3 and the newly
        // added file4 are included.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        // The indexing below assumes `contexts` is ordered file2, file3, file4
        // (presumably insertion order — confirm).
        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3516
    // User messages inserted without any attached context keep an empty
    // `loaded_context`. Their request content is exactly the typed text, and
    // they appear in insertion order.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Index 0 is presumably the system prompt; user messages follow in order.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3595
    // After a buffer attached as context is edited, flushing notifications should
    // add exactly one `project_notifications` tool result that names the changed
    // file. A second flush with no further edits must not add another. This test
    // is currently ignored (see the `#[ignore]` note).
    #[gpui::test]
    #[ignore] // turn this test on when project_notifications tool is re-enabled
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });
        cx.run_until_parked();

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });
        cx.run_until_parked();

        // Check for the stale buffer warning
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        assert!(notification_content.contains("These files have changed since the last read:"));
        assert!(notification_content.contains("code.rs"));

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        cx.run_until_parked();

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3711
3712    fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3713        thread
3714            .messages()
3715            .flat_map(|message| {
3716                thread
3717                    .tool_results_for_message(message.id)
3718                    .into_iter()
3719                    .filter(|result| result.tool_name == tool_name.into())
3720                    .cloned()
3721                    .collect::<Vec<_>>()
3722            })
3723            .collect()
3724    }
3725
    // A freshly created thread starts on the default agent profile, built from the
    // thread store's tool set.
    // NOTE(review): despite the name, only the initial default profile is checked
    // here; storing a different per-thread profile is not exercised.
    #[gpui::test]
    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Check that we are starting with the default profile
        let profile = cx.read(|cx| thread.read(cx).profile.clone());
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
        assert_eq!(
            profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3748
    // Serializing a thread records the default profile id. Deserializing it
    // restores a profile equal to the default profile over the thread store's tools.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing this thread's id,
        // project, tools, prompt builder and project context.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3792
    // A temperature configured in `AgentSettings::model_parameters` is applied to
    // the request when the entry matches the active model by provider and model,
    // by model alone, or by provider alone. An entry that names the same model
    // under a different provider leaves the temperature unset.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(
            &fs,
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // The provider mismatch means the entry must not apply.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3887
    // Walks the summary through its lifecycle: Pending → Generating → Ready.
    // Manual `set_summary` calls are ignored while Pending or Generating and take
    // effect once Ready. Setting an empty summary falls back to
    // `ThreadSummary::DEFAULT`.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        let fs = init_test_settings(cx);

        let project = create_test_project(&fs, cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text from the fake model and finish the completion.
        cx.run_until_parked();
        fake_model.send_last_completion_stream_text_chunk("Brief");
        fake_model.send_last_completion_stream_text_chunk(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3972
3973    #[gpui::test]
3974    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3975        let fs = init_test_settings(cx);
3976
3977        let project = create_test_project(&fs, cx, json!({})).await;
3978
3979        let (_, _thread_store, thread, _context_store, model) =
3980            setup_test_environment(cx, project.clone()).await;
3981
3982        test_summarize_error(&model, &thread, cx);
3983
3984        // Now we should be able to set a summary
3985        thread.update(cx, |thread, cx| {
3986            thread.set_summary("Brief Intro", cx);
3987        });
3988
3989        thread.read_with(cx, |thread, _| {
3990            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3991            assert_eq!(thread.summary().or_default(), "Brief Intro");
3992        });
3993    }
3994
3995    #[gpui::test]
3996    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3997        let fs = init_test_settings(cx);
3998
3999        let project = create_test_project(&fs, cx, json!({})).await;
4000
4001        let (_, _thread_store, thread, _context_store, model) =
4002            setup_test_environment(cx, project.clone()).await;
4003
4004        test_summarize_error(&model, &thread, cx);
4005
4006        // Sending another message should not trigger another summarize request
4007        thread.update(cx, |thread, cx| {
4008            thread.insert_user_message(
4009                "How are you?",
4010                ContextLoadResult::default(),
4011                None,
4012                vec![],
4013                cx,
4014            );
4015            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4016        });
4017
4018        let fake_model = model.as_fake();
4019        simulate_successful_response(fake_model, cx);
4020
4021        thread.read_with(cx, |thread, _| {
4022            // State is still Error, not Generating
4023            assert!(matches!(thread.summary(), ThreadSummary::Error));
4024        });
4025
4026        // But the summarize request can be invoked manually
4027        thread.update(cx, |thread, cx| {
4028            thread.summarize(cx);
4029        });
4030
4031        thread.read_with(cx, |thread, _| {
4032            assert!(matches!(thread.summary(), ThreadSummary::Generating));
4033        });
4034
4035        cx.run_until_parked();
4036        fake_model.send_last_completion_stream_text_chunk("A successful summary");
4037        fake_model.end_last_completion_stream();
4038        cx.run_until_parked();
4039
4040        thread.read_with(cx, |thread, _| {
4041            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4042            assert_eq!(thread.summary().or_default(), "A successful summary");
4043        });
4044    }
4045
4046    // Helper to create a model that returns errors
4047    enum TestError {
4048        Overloaded,
4049        InternalServerError,
4050    }
4051
4052    struct ErrorInjector {
4053        inner: Arc<FakeLanguageModel>,
4054        error_type: TestError,
4055    }
4056
4057    impl ErrorInjector {
4058        fn new(error_type: TestError) -> Self {
4059            Self {
4060                inner: Arc::new(FakeLanguageModel::default()),
4061                error_type,
4062            }
4063        }
4064    }
4065
4066    impl LanguageModel for ErrorInjector {
4067        fn id(&self) -> LanguageModelId {
4068            self.inner.id()
4069        }
4070
4071        fn name(&self) -> LanguageModelName {
4072            self.inner.name()
4073        }
4074
4075        fn provider_id(&self) -> LanguageModelProviderId {
4076            self.inner.provider_id()
4077        }
4078
4079        fn provider_name(&self) -> LanguageModelProviderName {
4080            self.inner.provider_name()
4081        }
4082
4083        fn supports_tools(&self) -> bool {
4084            self.inner.supports_tools()
4085        }
4086
4087        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4088            self.inner.supports_tool_choice(choice)
4089        }
4090
4091        fn supports_images(&self) -> bool {
4092            self.inner.supports_images()
4093        }
4094
4095        fn telemetry_id(&self) -> String {
4096            self.inner.telemetry_id()
4097        }
4098
4099        fn max_token_count(&self) -> u64 {
4100            self.inner.max_token_count()
4101        }
4102
4103        fn count_tokens(
4104            &self,
4105            request: LanguageModelRequest,
4106            cx: &App,
4107        ) -> BoxFuture<'static, Result<u64>> {
4108            self.inner.count_tokens(request, cx)
4109        }
4110
4111        fn stream_completion(
4112            &self,
4113            _request: LanguageModelRequest,
4114            _cx: &AsyncApp,
4115        ) -> BoxFuture<
4116            'static,
4117            Result<
4118                BoxStream<
4119                    'static,
4120                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4121                >,
4122                LanguageModelCompletionError,
4123            >,
4124        > {
4125            let error = match self.error_type {
4126                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
4127                    provider: self.provider_name(),
4128                    retry_after: None,
4129                },
4130                TestError::InternalServerError => {
4131                    LanguageModelCompletionError::ApiInternalServerError {
4132                        provider: self.provider_name(),
4133                        message: "I'm a teapot orbiting the sun".to_string(),
4134                    }
4135                }
4136            };
4137            async move {
4138                let stream = futures::stream::once(async move { Err(error) });
4139                Ok(stream.boxed())
4140            }
4141            .boxed()
4142        }
4143
4144        fn as_fake(&self) -> &FakeLanguageModel {
4145            &self.inner
4146        }
4147    }
4148
4149    #[gpui::test]
4150    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4151        let fs = init_test_settings(cx);
4152
4153        let project = create_test_project(&fs, cx, json!({})).await;
4154        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4155
4156        // Enable Burn Mode to allow retries
4157        thread.update(cx, |thread, _| {
4158            thread.set_completion_mode(CompletionMode::Burn);
4159        });
4160
4161        // Create model that returns overloaded error
4162        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4163
4164        // Insert a user message
4165        thread.update(cx, |thread, cx| {
4166            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4167        });
4168
4169        // Start completion
4170        thread.update(cx, |thread, cx| {
4171            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4172        });
4173
4174        cx.run_until_parked();
4175
4176        thread.read_with(cx, |thread, _| {
4177            assert!(thread.retry_state.is_some(), "Should have retry state");
4178            let retry_state = thread.retry_state.as_ref().unwrap();
4179            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4180            assert_eq!(
4181                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4182                "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors"
4183            );
4184        });
4185
4186        // Check that a retry message was added
4187        thread.read_with(cx, |thread, _| {
4188            let mut messages = thread.messages();
4189            assert!(
4190                messages.any(|msg| {
4191                    msg.role == Role::System
4192                        && msg.ui_only
4193                        && msg.segments.iter().any(|seg| {
4194                            if let MessageSegment::Text(text) = seg {
4195                                text.contains("overloaded")
4196                                    && text
4197                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4198                            } else {
4199                                false
4200                            }
4201                        })
4202                }),
4203                "Should have added a system retry message"
4204            );
4205        });
4206
4207        let retry_count = thread.update(cx, |thread, _| {
4208            thread
4209                .messages
4210                .iter()
4211                .filter(|m| {
4212                    m.ui_only
4213                        && m.segments.iter().any(|s| {
4214                            if let MessageSegment::Text(text) = s {
4215                                text.contains("Retrying") && text.contains("seconds")
4216                            } else {
4217                                false
4218                            }
4219                        })
4220                })
4221                .count()
4222        });
4223
4224        assert_eq!(retry_count, 1, "Should have one retry message");
4225    }
4226
4227    #[gpui::test]
4228    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4229        let fs = init_test_settings(cx);
4230
4231        let project = create_test_project(&fs, cx, json!({})).await;
4232        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4233
4234        // Enable Burn Mode to allow retries
4235        thread.update(cx, |thread, _| {
4236            thread.set_completion_mode(CompletionMode::Burn);
4237        });
4238
4239        // Create model that returns internal server error
4240        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4241
4242        // Insert a user message
4243        thread.update(cx, |thread, cx| {
4244            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4245        });
4246
4247        // Start completion
4248        thread.update(cx, |thread, cx| {
4249            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4250        });
4251
4252        cx.run_until_parked();
4253
4254        // Check retry state on thread
4255        thread.read_with(cx, |thread, _| {
4256            assert!(thread.retry_state.is_some(), "Should have retry state");
4257            let retry_state = thread.retry_state.as_ref().unwrap();
4258            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4259            assert_eq!(
4260                retry_state.max_attempts, 3,
4261                "Should have correct max attempts"
4262            );
4263        });
4264
4265        // Check that a retry message was added with provider name
4266        thread.read_with(cx, |thread, _| {
4267            let mut messages = thread.messages();
4268            assert!(
4269                messages.any(|msg| {
4270                    msg.role == Role::System
4271                        && msg.ui_only
4272                        && msg.segments.iter().any(|seg| {
4273                            if let MessageSegment::Text(text) = seg {
4274                                text.contains("internal")
4275                                    && text.contains("Fake")
4276                                    && text.contains("Retrying")
4277                                    && text.contains("attempt 1 of 3")
4278                                    && text.contains("seconds")
4279                            } else {
4280                                false
4281                            }
4282                        })
4283                }),
4284                "Should have added a system retry message with provider name"
4285            );
4286        });
4287
4288        // Count retry messages
4289        let retry_count = thread.update(cx, |thread, _| {
4290            thread
4291                .messages
4292                .iter()
4293                .filter(|m| {
4294                    m.ui_only
4295                        && m.segments.iter().any(|s| {
4296                            if let MessageSegment::Text(text) = s {
4297                                text.contains("Retrying") && text.contains("seconds")
4298                            } else {
4299                                false
4300                            }
4301                        })
4302                })
4303                .count()
4304        });
4305
4306        assert_eq!(retry_count, 1, "Should have one retry message");
4307    }
4308
4309    #[gpui::test]
4310    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4311        let fs = init_test_settings(cx);
4312
4313        let project = create_test_project(&fs, cx, json!({})).await;
4314        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4315
4316        // Enable Burn Mode to allow retries
4317        thread.update(cx, |thread, _| {
4318            thread.set_completion_mode(CompletionMode::Burn);
4319        });
4320
4321        // Create model that returns internal server error
4322        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4323
4324        // Insert a user message
4325        thread.update(cx, |thread, cx| {
4326            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4327        });
4328
4329        // Track retry events and completion count
4330        // Track completion events
4331        let completion_count = Arc::new(Mutex::new(0));
4332        let completion_count_clone = completion_count.clone();
4333
4334        let _subscription = thread.update(cx, |_, cx| {
4335            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4336                if let ThreadEvent::NewRequest = event {
4337                    *completion_count_clone.lock() += 1;
4338                }
4339            })
4340        });
4341
4342        // First attempt
4343        thread.update(cx, |thread, cx| {
4344            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4345        });
4346        cx.run_until_parked();
4347
4348        // Should have scheduled first retry - count retry messages
4349        let retry_count = thread.update(cx, |thread, _| {
4350            thread
4351                .messages
4352                .iter()
4353                .filter(|m| {
4354                    m.ui_only
4355                        && m.segments.iter().any(|s| {
4356                            if let MessageSegment::Text(text) = s {
4357                                text.contains("Retrying") && text.contains("seconds")
4358                            } else {
4359                                false
4360                            }
4361                        })
4362                })
4363                .count()
4364        });
4365        assert_eq!(retry_count, 1, "Should have scheduled first retry");
4366
4367        // Check retry state
4368        thread.read_with(cx, |thread, _| {
4369            assert!(thread.retry_state.is_some(), "Should have retry state");
4370            let retry_state = thread.retry_state.as_ref().unwrap();
4371            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4372            assert_eq!(
4373                retry_state.max_attempts, 3,
4374                "Internal server errors should retry up to 3 times"
4375            );
4376        });
4377
4378        // Advance clock for first retry
4379        cx.executor().advance_clock(BASE_RETRY_DELAY);
4380        cx.run_until_parked();
4381
4382        // Advance clock for second retry
4383        cx.executor().advance_clock(BASE_RETRY_DELAY);
4384        cx.run_until_parked();
4385
4386        // Advance clock for third retry
4387        cx.executor().advance_clock(BASE_RETRY_DELAY);
4388        cx.run_until_parked();
4389
4390        // Should have completed all retries - count retry messages
4391        let retry_count = thread.update(cx, |thread, _| {
4392            thread
4393                .messages
4394                .iter()
4395                .filter(|m| {
4396                    m.ui_only
4397                        && m.segments.iter().any(|s| {
4398                            if let MessageSegment::Text(text) = s {
4399                                text.contains("Retrying") && text.contains("seconds")
4400                            } else {
4401                                false
4402                            }
4403                        })
4404                })
4405                .count()
4406        });
4407        assert_eq!(
4408            retry_count, 3,
4409            "Should have 3 retries for internal server errors"
4410        );
4411
4412        // For internal server errors, we retry 3 times and then give up
4413        // Check that retry_state is cleared after all retries
4414        thread.read_with(cx, |thread, _| {
4415            assert!(
4416                thread.retry_state.is_none(),
4417                "Retry state should be cleared after all retries"
4418            );
4419        });
4420
4421        // Verify total attempts (1 initial + 3 retries)
4422        assert_eq!(
4423            *completion_count.lock(),
4424            4,
4425            "Should have attempted once plus 3 retries"
4426        );
4427    }
4428
4429    #[gpui::test]
4430    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4431        let fs = init_test_settings(cx);
4432
4433        let project = create_test_project(&fs, cx, json!({})).await;
4434        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4435
4436        // Enable Burn Mode to allow retries
4437        thread.update(cx, |thread, _| {
4438            thread.set_completion_mode(CompletionMode::Burn);
4439        });
4440
4441        // Create model that returns overloaded error
4442        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4443
4444        // Insert a user message
4445        thread.update(cx, |thread, cx| {
4446            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4447        });
4448
4449        // Track events
4450        let stopped_with_error = Arc::new(Mutex::new(false));
4451        let stopped_with_error_clone = stopped_with_error.clone();
4452
4453        let _subscription = thread.update(cx, |_, cx| {
4454            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4455                if let ThreadEvent::Stopped(Err(_)) = event {
4456                    *stopped_with_error_clone.lock() = true;
4457                }
4458            })
4459        });
4460
4461        // Start initial completion
4462        thread.update(cx, |thread, cx| {
4463            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4464        });
4465        cx.run_until_parked();
4466
4467        // Advance through all retries
4468        for _ in 0..MAX_RETRY_ATTEMPTS {
4469            cx.executor().advance_clock(BASE_RETRY_DELAY);
4470            cx.run_until_parked();
4471        }
4472
4473        let retry_count = thread.update(cx, |thread, _| {
4474            thread
4475                .messages
4476                .iter()
4477                .filter(|m| {
4478                    m.ui_only
4479                        && m.segments.iter().any(|s| {
4480                            if let MessageSegment::Text(text) = s {
4481                                text.contains("Retrying") && text.contains("seconds")
4482                            } else {
4483                                false
4484                            }
4485                        })
4486                })
4487                .count()
4488        });
4489
4490        // After max retries, should emit Stopped(Err(...)) event
4491        assert_eq!(
4492            retry_count, MAX_RETRY_ATTEMPTS as usize,
4493            "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors"
4494        );
4495        assert!(
4496            *stopped_with_error.lock(),
4497            "Should emit Stopped(Err(...)) event after max retries exceeded"
4498        );
4499
4500        // Retry state should be cleared
4501        thread.read_with(cx, |thread, _| {
4502            assert!(
4503                thread.retry_state.is_none(),
4504                "Retry state should be cleared after max retries"
4505            );
4506
4507            // Verify we have the expected number of retry messages
4508            let retry_messages = thread
4509                .messages
4510                .iter()
4511                .filter(|msg| msg.ui_only && msg.role == Role::System)
4512                .count();
4513            assert_eq!(
4514                retry_messages, MAX_RETRY_ATTEMPTS as usize,
4515                "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors"
4516            );
4517        });
4518    }
4519
4520    #[gpui::test]
4521    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4522        let fs = init_test_settings(cx);
4523
4524        let project = create_test_project(&fs, cx, json!({})).await;
4525        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4526
4527        // Enable Burn Mode to allow retries
4528        thread.update(cx, |thread, _| {
4529            thread.set_completion_mode(CompletionMode::Burn);
4530        });
4531
4532        // We'll use a wrapper to switch behavior after first failure
4533        struct RetryTestModel {
4534            inner: Arc<FakeLanguageModel>,
4535            failed_once: Arc<Mutex<bool>>,
4536        }
4537
4538        impl LanguageModel for RetryTestModel {
4539            fn id(&self) -> LanguageModelId {
4540                self.inner.id()
4541            }
4542
4543            fn name(&self) -> LanguageModelName {
4544                self.inner.name()
4545            }
4546
4547            fn provider_id(&self) -> LanguageModelProviderId {
4548                self.inner.provider_id()
4549            }
4550
4551            fn provider_name(&self) -> LanguageModelProviderName {
4552                self.inner.provider_name()
4553            }
4554
4555            fn supports_tools(&self) -> bool {
4556                self.inner.supports_tools()
4557            }
4558
4559            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4560                self.inner.supports_tool_choice(choice)
4561            }
4562
4563            fn supports_images(&self) -> bool {
4564                self.inner.supports_images()
4565            }
4566
4567            fn telemetry_id(&self) -> String {
4568                self.inner.telemetry_id()
4569            }
4570
4571            fn max_token_count(&self) -> u64 {
4572                self.inner.max_token_count()
4573            }
4574
4575            fn count_tokens(
4576                &self,
4577                request: LanguageModelRequest,
4578                cx: &App,
4579            ) -> BoxFuture<'static, Result<u64>> {
4580                self.inner.count_tokens(request, cx)
4581            }
4582
4583            fn stream_completion(
4584                &self,
4585                request: LanguageModelRequest,
4586                cx: &AsyncApp,
4587            ) -> BoxFuture<
4588                'static,
4589                Result<
4590                    BoxStream<
4591                        'static,
4592                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4593                    >,
4594                    LanguageModelCompletionError,
4595                >,
4596            > {
4597                if !*self.failed_once.lock() {
4598                    *self.failed_once.lock() = true;
4599                    let provider = self.provider_name();
4600                    // Return error on first attempt
4601                    let stream = futures::stream::once(async move {
4602                        Err(LanguageModelCompletionError::ServerOverloaded {
4603                            provider,
4604                            retry_after: None,
4605                        })
4606                    });
4607                    async move { Ok(stream.boxed()) }.boxed()
4608                } else {
4609                    // Succeed on retry
4610                    self.inner.stream_completion(request, cx)
4611                }
4612            }
4613
4614            fn as_fake(&self) -> &FakeLanguageModel {
4615                &self.inner
4616            }
4617        }
4618
4619        let model = Arc::new(RetryTestModel {
4620            inner: Arc::new(FakeLanguageModel::default()),
4621            failed_once: Arc::new(Mutex::new(false)),
4622        });
4623
4624        // Insert a user message
4625        thread.update(cx, |thread, cx| {
4626            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4627        });
4628
4629        // Track message deletions
4630        // Track when retry completes successfully
4631        let retry_completed = Arc::new(Mutex::new(false));
4632        let retry_completed_clone = retry_completed.clone();
4633
4634        let _subscription = thread.update(cx, |_, cx| {
4635            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4636                if let ThreadEvent::StreamedCompletion = event {
4637                    *retry_completed_clone.lock() = true;
4638                }
4639            })
4640        });
4641
4642        // Start completion
4643        thread.update(cx, |thread, cx| {
4644            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4645        });
4646        cx.run_until_parked();
4647
4648        // Get the retry message ID
4649        let retry_message_id = thread.read_with(cx, |thread, _| {
4650            thread
4651                .messages()
4652                .find(|msg| msg.role == Role::System && msg.ui_only)
4653                .map(|msg| msg.id)
4654                .expect("Should have a retry message")
4655        });
4656
4657        // Wait for retry
4658        cx.executor().advance_clock(BASE_RETRY_DELAY);
4659        cx.run_until_parked();
4660
4661        // Stream some successful content
4662        let fake_model = model.as_fake();
4663        // After the retry, there should be a new pending completion
4664        let pending = fake_model.pending_completions();
4665        assert!(
4666            !pending.is_empty(),
4667            "Should have a pending completion after retry"
4668        );
4669        fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
4670        fake_model.end_completion_stream(&pending[0]);
4671        cx.run_until_parked();
4672
4673        // Check that the retry completed successfully
4674        assert!(
4675            *retry_completed.lock(),
4676            "Retry should have completed successfully"
4677        );
4678
4679        // Retry message should still exist but be marked as ui_only
4680        thread.read_with(cx, |thread, _| {
4681            let retry_msg = thread
4682                .message(retry_message_id)
4683                .expect("Retry message should still exist");
4684            assert!(retry_msg.ui_only, "Retry message should be ui_only");
4685            assert_eq!(
4686                retry_msg.role,
4687                Role::System,
4688                "Retry message should have System role"
4689            );
4690        });
4691    }
4692
4693    #[gpui::test]
4694    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4695        let fs = init_test_settings(cx);
4696
4697        let project = create_test_project(&fs, cx, json!({})).await;
4698        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4699
4700        // Enable Burn Mode to allow retries
4701        thread.update(cx, |thread, _| {
4702            thread.set_completion_mode(CompletionMode::Burn);
4703        });
4704
4705        // Create a model that fails once then succeeds
4706        struct FailOnceModel {
4707            inner: Arc<FakeLanguageModel>,
4708            failed_once: Arc<Mutex<bool>>,
4709        }
4710
4711        impl LanguageModel for FailOnceModel {
4712            fn id(&self) -> LanguageModelId {
4713                self.inner.id()
4714            }
4715
4716            fn name(&self) -> LanguageModelName {
4717                self.inner.name()
4718            }
4719
4720            fn provider_id(&self) -> LanguageModelProviderId {
4721                self.inner.provider_id()
4722            }
4723
4724            fn provider_name(&self) -> LanguageModelProviderName {
4725                self.inner.provider_name()
4726            }
4727
4728            fn supports_tools(&self) -> bool {
4729                self.inner.supports_tools()
4730            }
4731
4732            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4733                self.inner.supports_tool_choice(choice)
4734            }
4735
4736            fn supports_images(&self) -> bool {
4737                self.inner.supports_images()
4738            }
4739
4740            fn telemetry_id(&self) -> String {
4741                self.inner.telemetry_id()
4742            }
4743
4744            fn max_token_count(&self) -> u64 {
4745                self.inner.max_token_count()
4746            }
4747
4748            fn count_tokens(
4749                &self,
4750                request: LanguageModelRequest,
4751                cx: &App,
4752            ) -> BoxFuture<'static, Result<u64>> {
4753                self.inner.count_tokens(request, cx)
4754            }
4755
4756            fn stream_completion(
4757                &self,
4758                request: LanguageModelRequest,
4759                cx: &AsyncApp,
4760            ) -> BoxFuture<
4761                'static,
4762                Result<
4763                    BoxStream<
4764                        'static,
4765                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4766                    >,
4767                    LanguageModelCompletionError,
4768                >,
4769            > {
4770                if !*self.failed_once.lock() {
4771                    *self.failed_once.lock() = true;
4772                    let provider = self.provider_name();
4773                    // Return error on first attempt
4774                    let stream = futures::stream::once(async move {
4775                        Err(LanguageModelCompletionError::ServerOverloaded {
4776                            provider,
4777                            retry_after: None,
4778                        })
4779                    });
4780                    async move { Ok(stream.boxed()) }.boxed()
4781                } else {
4782                    // Succeed on retry
4783                    self.inner.stream_completion(request, cx)
4784                }
4785            }
4786        }
4787
4788        let fail_once_model = Arc::new(FailOnceModel {
4789            inner: Arc::new(FakeLanguageModel::default()),
4790            failed_once: Arc::new(Mutex::new(false)),
4791        });
4792
4793        // Insert a user message
4794        thread.update(cx, |thread, cx| {
4795            thread.insert_user_message(
4796                "Test message",
4797                ContextLoadResult::default(),
4798                None,
4799                vec![],
4800                cx,
4801            );
4802        });
4803
4804        // Start completion with fail-once model
4805        thread.update(cx, |thread, cx| {
4806            thread.send_to_model(
4807                fail_once_model.clone(),
4808                CompletionIntent::UserPrompt,
4809                None,
4810                cx,
4811            );
4812        });
4813
4814        cx.run_until_parked();
4815
4816        // Verify retry state exists after first failure
4817        thread.read_with(cx, |thread, _| {
4818            assert!(
4819                thread.retry_state.is_some(),
4820                "Should have retry state after failure"
4821            );
4822        });
4823
4824        // Wait for retry delay
4825        cx.executor().advance_clock(BASE_RETRY_DELAY);
4826        cx.run_until_parked();
4827
4828        // The retry should now use our FailOnceModel which should succeed
4829        // We need to help the FakeLanguageModel complete the stream
4830        let inner_fake = fail_once_model.inner.clone();
4831
4832        // Wait a bit for the retry to start
4833        cx.run_until_parked();
4834
4835        // Check for pending completions and complete them
4836        if let Some(pending) = inner_fake.pending_completions().first() {
4837            inner_fake.send_completion_stream_text_chunk(pending, "Success!");
4838            inner_fake.end_completion_stream(pending);
4839        }
4840        cx.run_until_parked();
4841
4842        thread.read_with(cx, |thread, _| {
4843            assert!(
4844                thread.retry_state.is_none(),
4845                "Retry state should be cleared after successful completion"
4846            );
4847
4848            let has_assistant_message = thread
4849                .messages
4850                .iter()
4851                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4852            assert!(
4853                has_assistant_message,
4854                "Should have an assistant message after successful retry"
4855            );
4856        });
4857    }
4858
4859    #[gpui::test]
4860    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4861        let fs = init_test_settings(cx);
4862
4863        let project = create_test_project(&fs, cx, json!({})).await;
4864        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4865
4866        // Enable Burn Mode to allow retries
4867        thread.update(cx, |thread, _| {
4868            thread.set_completion_mode(CompletionMode::Burn);
4869        });
4870
4871        // Create a model that returns rate limit error with retry_after
4872        struct RateLimitModel {
4873            inner: Arc<FakeLanguageModel>,
4874        }
4875
4876        impl LanguageModel for RateLimitModel {
4877            fn id(&self) -> LanguageModelId {
4878                self.inner.id()
4879            }
4880
4881            fn name(&self) -> LanguageModelName {
4882                self.inner.name()
4883            }
4884
4885            fn provider_id(&self) -> LanguageModelProviderId {
4886                self.inner.provider_id()
4887            }
4888
4889            fn provider_name(&self) -> LanguageModelProviderName {
4890                self.inner.provider_name()
4891            }
4892
4893            fn supports_tools(&self) -> bool {
4894                self.inner.supports_tools()
4895            }
4896
4897            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4898                self.inner.supports_tool_choice(choice)
4899            }
4900
4901            fn supports_images(&self) -> bool {
4902                self.inner.supports_images()
4903            }
4904
4905            fn telemetry_id(&self) -> String {
4906                self.inner.telemetry_id()
4907            }
4908
4909            fn max_token_count(&self) -> u64 {
4910                self.inner.max_token_count()
4911            }
4912
4913            fn count_tokens(
4914                &self,
4915                request: LanguageModelRequest,
4916                cx: &App,
4917            ) -> BoxFuture<'static, Result<u64>> {
4918                self.inner.count_tokens(request, cx)
4919            }
4920
4921            fn stream_completion(
4922                &self,
4923                _request: LanguageModelRequest,
4924                _cx: &AsyncApp,
4925            ) -> BoxFuture<
4926                'static,
4927                Result<
4928                    BoxStream<
4929                        'static,
4930                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4931                    >,
4932                    LanguageModelCompletionError,
4933                >,
4934            > {
4935                let provider = self.provider_name();
4936                async move {
4937                    let stream = futures::stream::once(async move {
4938                        Err(LanguageModelCompletionError::RateLimitExceeded {
4939                            provider,
4940                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4941                        })
4942                    });
4943                    Ok(stream.boxed())
4944                }
4945                .boxed()
4946            }
4947
4948            fn as_fake(&self) -> &FakeLanguageModel {
4949                &self.inner
4950            }
4951        }
4952
4953        let model = Arc::new(RateLimitModel {
4954            inner: Arc::new(FakeLanguageModel::default()),
4955        });
4956
4957        // Insert a user message
4958        thread.update(cx, |thread, cx| {
4959            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4960        });
4961
4962        // Start completion
4963        thread.update(cx, |thread, cx| {
4964            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4965        });
4966
4967        cx.run_until_parked();
4968
4969        let retry_count = thread.update(cx, |thread, _| {
4970            thread
4971                .messages
4972                .iter()
4973                .filter(|m| {
4974                    m.ui_only
4975                        && m.segments.iter().any(|s| {
4976                            if let MessageSegment::Text(text) = s {
4977                                text.contains("rate limit exceeded")
4978                            } else {
4979                                false
4980                            }
4981                        })
4982                })
4983                .count()
4984        });
4985        assert_eq!(retry_count, 1, "Should have scheduled one retry");
4986
4987        thread.read_with(cx, |thread, _| {
4988            assert!(
4989                thread.retry_state.is_some(),
4990                "Rate limit errors should set retry_state"
4991            );
4992            if let Some(retry_state) = &thread.retry_state {
4993                assert_eq!(
4994                    retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4995                    "Rate limit errors should use MAX_RETRY_ATTEMPTS"
4996                );
4997            }
4998        });
4999
5000        // Verify we have one retry message
5001        thread.read_with(cx, |thread, _| {
5002            let retry_messages = thread
5003                .messages
5004                .iter()
5005                .filter(|msg| {
5006                    msg.ui_only
5007                        && msg.segments.iter().any(|seg| {
5008                            if let MessageSegment::Text(text) = seg {
5009                                text.contains("rate limit exceeded")
5010                            } else {
5011                                false
5012                            }
5013                        })
5014                })
5015                .count();
5016            assert_eq!(
5017                retry_messages, 1,
5018                "Should have one rate limit retry message"
5019            );
5020        });
5021
5022        // Check that retry message doesn't include attempt count
5023        thread.read_with(cx, |thread, _| {
5024            let retry_message = thread
5025                .messages
5026                .iter()
5027                .find(|msg| msg.role == Role::System && msg.ui_only)
5028                .expect("Should have a retry message");
5029
5030            // Check that the message contains attempt count since we use retry_state
5031            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5032                assert!(
5033                    text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)),
5034                    "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS"
5035                );
5036                assert!(
5037                    text.contains("Retrying"),
5038                    "Rate limit retry message should contain retry text"
5039                );
5040            }
5041        });
5042    }
5043
5044    #[gpui::test]
5045    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5046        let fs = init_test_settings(cx);
5047
5048        let project = create_test_project(&fs, cx, json!({})).await;
5049        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5050
5051        // Insert a regular user message
5052        thread.update(cx, |thread, cx| {
5053            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5054        });
5055
5056        // Insert a UI-only message (like our retry notifications)
5057        thread.update(cx, |thread, cx| {
5058            let id = thread.next_message_id.post_inc();
5059            thread.messages.push(Message {
5060                id,
5061                role: Role::System,
5062                segments: vec![MessageSegment::Text(
5063                    "This is a UI-only message that should not be sent to the model".to_string(),
5064                )],
5065                loaded_context: LoadedContext::default(),
5066                creases: Vec::new(),
5067                is_hidden: true,
5068                ui_only: true,
5069            });
5070            cx.emit(ThreadEvent::MessageAdded(id));
5071        });
5072
5073        // Insert another regular message
5074        thread.update(cx, |thread, cx| {
5075            thread.insert_user_message(
5076                "How are you?",
5077                ContextLoadResult::default(),
5078                None,
5079                vec![],
5080                cx,
5081            );
5082        });
5083
5084        // Generate the completion request
5085        let request = thread.update(cx, |thread, cx| {
5086            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5087        });
5088
5089        // Verify that the request only contains non-UI-only messages
5090        // Should have system prompt + 2 user messages, but not the UI-only message
5091        let user_messages: Vec<_> = request
5092            .messages
5093            .iter()
5094            .filter(|msg| msg.role == Role::User)
5095            .collect();
5096        assert_eq!(
5097            user_messages.len(),
5098            2,
5099            "Should have exactly 2 user messages"
5100        );
5101
5102        // Verify the UI-only content is not present anywhere in the request
5103        let request_text = request
5104            .messages
5105            .iter()
5106            .flat_map(|msg| &msg.content)
5107            .filter_map(|content| match content {
5108                MessageContent::Text(text) => Some(text.as_str()),
5109                _ => None,
5110            })
5111            .collect::<String>();
5112
5113        assert!(
5114            !request_text.contains("UI-only message"),
5115            "UI-only message content should not be in the request"
5116        );
5117
5118        // Verify the thread still has all 3 messages (including UI-only)
5119        thread.read_with(cx, |thread, _| {
5120            assert_eq!(
5121                thread.messages().count(),
5122                3,
5123                "Thread should have 3 messages"
5124            );
5125            assert_eq!(
5126                thread.messages().filter(|m| m.ui_only).count(),
5127                1,
5128                "Thread should have 1 UI-only message"
5129            );
5130        });
5131
5132        // Verify that UI-only messages are not serialized
5133        let serialized = thread
5134            .update(cx, |thread, cx| thread.serialize(cx))
5135            .await
5136            .unwrap();
5137        assert_eq!(
5138            serialized.messages.len(),
5139            2,
5140            "Serialized thread should only have 2 messages (no UI-only)"
5141        );
5142    }
5143
5144    #[gpui::test]
5145    async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) {
5146        let fs = init_test_settings(cx);
5147
5148        let project = create_test_project(&fs, cx, json!({})).await;
5149        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5150
5151        // Ensure we're in Normal mode (not Burn mode)
5152        thread.update(cx, |thread, _| {
5153            thread.set_completion_mode(CompletionMode::Normal);
5154        });
5155
5156        // Track error events
5157        let error_events = Arc::new(Mutex::new(Vec::new()));
5158        let error_events_clone = error_events.clone();
5159
5160        let _subscription = thread.update(cx, |_, cx| {
5161            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
5162                if let ThreadEvent::ShowError(error) = event {
5163                    error_events_clone.lock().push(error.clone());
5164                }
5165            })
5166        });
5167
5168        // Create model that returns overloaded error
5169        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5170
5171        // Insert a user message
5172        thread.update(cx, |thread, cx| {
5173            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5174        });
5175
5176        // Start completion
5177        thread.update(cx, |thread, cx| {
5178            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5179        });
5180
5181        cx.run_until_parked();
5182
5183        // Verify no retry state was created
5184        thread.read_with(cx, |thread, _| {
5185            assert!(
5186                thread.retry_state.is_none(),
5187                "Should not have retry state in Normal mode"
5188            );
5189        });
5190
5191        // Check that a retryable error was reported
5192        let errors = error_events.lock();
5193        assert!(!errors.is_empty(), "Should have received an error event");
5194
5195        if let ThreadError::RetryableError {
5196            message: _,
5197            can_enable_burn_mode,
5198        } = &errors[0]
5199        {
5200            assert!(
5201                *can_enable_burn_mode,
5202                "Error should indicate burn mode can be enabled"
5203            );
5204        } else {
5205            panic!("Expected RetryableError, got {:?}", errors[0]);
5206        }
5207
5208        // Verify the thread is no longer generating
5209        thread.read_with(cx, |thread, _| {
5210            assert!(
5211                !thread.is_generating(),
5212                "Should not be generating after error without retry"
5213            );
5214        });
5215    }
5216
5217    #[gpui::test]
5218    async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) {
5219        let fs = init_test_settings(cx);
5220
5221        let project = create_test_project(&fs, cx, json!({})).await;
5222        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5223
5224        // Enable Burn Mode to allow retries
5225        thread.update(cx, |thread, _| {
5226            thread.set_completion_mode(CompletionMode::Burn);
5227        });
5228
5229        // Create model that returns overloaded error
5230        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5231
5232        // Insert a user message
5233        thread.update(cx, |thread, cx| {
5234            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5235        });
5236
5237        // Start completion
5238        thread.update(cx, |thread, cx| {
5239            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5240        });
5241
5242        cx.run_until_parked();
5243
5244        // Verify retry was scheduled by checking for retry message
5245        let has_retry_message = thread.read_with(cx, |thread, _| {
5246            thread.messages.iter().any(|m| {
5247                m.ui_only
5248                    && m.segments.iter().any(|s| {
5249                        if let MessageSegment::Text(text) = s {
5250                            text.contains("Retrying") && text.contains("seconds")
5251                        } else {
5252                            false
5253                        }
5254                    })
5255            })
5256        });
5257        assert!(has_retry_message, "Should have scheduled a retry");
5258
5259        // Cancel the completion before the retry happens
5260        thread.update(cx, |thread, cx| {
5261            thread.cancel_last_completion(None, cx);
5262        });
5263
5264        cx.run_until_parked();
5265
5266        // The retry should not have happened - no pending completions
5267        let fake_model = model.as_fake();
5268        assert_eq!(
5269            fake_model.pending_completions().len(),
5270            0,
5271            "Should have no pending completions after cancellation"
5272        );
5273
5274        // Verify the retry was canceled by checking retry state
5275        thread.read_with(cx, |thread, _| {
5276            if let Some(retry_state) = &thread.retry_state {
5277                panic!(
5278                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5279                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
5280                );
5281            }
5282        });
5283    }
5284
5285    fn test_summarize_error(
5286        model: &Arc<dyn LanguageModel>,
5287        thread: &Entity<Thread>,
5288        cx: &mut TestAppContext,
5289    ) {
5290        thread.update(cx, |thread, cx| {
5291            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5292            thread.send_to_model(
5293                model.clone(),
5294                CompletionIntent::ThreadSummarization,
5295                None,
5296                cx,
5297            );
5298        });
5299
5300        let fake_model = model.as_fake();
5301        simulate_successful_response(fake_model, cx);
5302
5303        thread.read_with(cx, |thread, _| {
5304            assert!(matches!(thread.summary(), ThreadSummary::Generating));
5305            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5306        });
5307
5308        // Simulate summary request ending
5309        cx.run_until_parked();
5310        fake_model.end_last_completion_stream();
5311        cx.run_until_parked();
5312
5313        // State is set to Error and default message
5314        thread.read_with(cx, |thread, _| {
5315            assert!(matches!(thread.summary(), ThreadSummary::Error));
5316            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5317        });
5318    }
5319
5320    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5321        cx.run_until_parked();
5322        fake_model.send_last_completion_stream_text_chunk("Assistant response");
5323        fake_model.end_last_completion_stream();
5324        cx.run_until_parked();
5325    }
5326
5327    fn init_test_settings(cx: &mut TestAppContext) -> Arc<dyn Fs> {
5328        let fs = FakeFs::new(cx.executor());
5329        cx.update(|cx| {
5330            let settings_store = SettingsStore::test(cx);
5331            cx.set_global(settings_store);
5332            language::init(cx);
5333            Project::init_settings(cx);
5334            AgentSettings::register(cx);
5335            prompt_store::init(cx);
5336            thread_store::init(fs.clone(), cx);
5337            workspace::init_settings(cx);
5338            language_model::init_settings(cx);
5339            theme::init(theme::LoadThemes::JustBase, cx);
5340            ToolRegistry::default_global(cx);
5341            assistant_tool::init(cx);
5342
5343            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
5344                http_client::FakeHttpClient::with_200_response(),
5345                "http://localhost".to_string(),
5346                None,
5347            ));
5348            assistant_tools::init(http_client, cx);
5349        });
5350        fs
5351    }
5352
5353    // Helper to create a test project with test files
5354    async fn create_test_project(
5355        fs: &Arc<dyn Fs>,
5356        cx: &mut TestAppContext,
5357        files: serde_json::Value,
5358    ) -> Entity<Project> {
5359        fs.as_fake().insert_tree(path!("/test"), files).await;
5360        Project::test(fs.clone(), [path!("/test").as_ref()], cx).await
5361    }
5362
5363    async fn setup_test_environment(
5364        cx: &mut TestAppContext,
5365        project: Entity<Project>,
5366    ) -> (
5367        Entity<Workspace>,
5368        Entity<ThreadStore>,
5369        Entity<Thread>,
5370        Entity<ContextStore>,
5371        Arc<dyn LanguageModel>,
5372    ) {
5373        let (workspace, cx) =
5374            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5375
5376        let thread_store = cx
5377            .update(|_, cx| {
5378                ThreadStore::load(
5379                    project.clone(),
5380                    cx.new(|_| ToolWorkingSet::default()),
5381                    None,
5382                    Arc::new(PromptBuilder::new(None).unwrap()),
5383                    cx,
5384                )
5385            })
5386            .await
5387            .unwrap();
5388
5389        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5390        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5391
5392        let provider = Arc::new(FakeLanguageModelProvider::default());
5393        let model = provider.test_model();
5394        let model: Arc<dyn LanguageModel> = Arc::new(model);
5395
5396        cx.update(|_, cx| {
5397            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5398                registry.set_default_model(
5399                    Some(ConfiguredModel {
5400                        provider: provider.clone(),
5401                        model: model.clone(),
5402                    }),
5403                    cx,
5404                );
5405                registry.set_thread_summary_model(
5406                    Some(ConfiguredModel {
5407                        provider,
5408                        model: model.clone(),
5409                    }),
5410                    cx,
5411                );
5412            })
5413        });
5414
5415        (workspace, thread_store, thread, context_store, model)
5416    }
5417
5418    async fn add_file_to_context(
5419        project: &Entity<Project>,
5420        context_store: &Entity<ContextStore>,
5421        path: &str,
5422        cx: &mut TestAppContext,
5423    ) -> Result<Entity<language::Buffer>> {
5424        let buffer_path = project
5425            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5426            .unwrap();
5427
5428        let buffer = project
5429            .update(cx, |project, cx| {
5430                project.open_buffer(buffer_path.clone(), cx)
5431            })
5432            .await
5433            .unwrap();
5434
5435        context_store.update(cx, |context_store, cx| {
5436            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5437        });
5438
5439        Ok(buffer)
5440    }
5441}