// thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
  12use anyhow::{Result, anyhow};
  13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  14use chrono::{DateTime, Utc};
  15use client::{ModelRequestUsage, RequestUsage};
  16use collections::HashMap;
  17use feature_flags::{self, FeatureFlagAppExt};
  18use futures::{FutureExt, StreamExt as _, future::Shared};
  19use git::repository::DiffType;
  20use gpui::{
  21    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  22    WeakEntity, Window,
  23};
  24use language_model::{
  25    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
  27    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  28    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
  29    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
  30    TokenUsage,
  31};
  32use postage::stream::Stream as _;
  33use project::{
  34    Project,
  35    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  36};
  37use prompt_store::{ModelContext, PromptBuilder};
  38use proto::Plan;
  39use schemars::JsonSchema;
  40use serde::{Deserialize, Serialize};
  41use settings::Settings;
  42use std::{
  43    io::Write,
  44    ops::Range,
  45    sync::Arc,
  46    time::{Duration, Instant},
  47};
  48use thiserror::Error;
  49use util::{ResultExt as _, debug_panic, post_inc};
  50use uuid::Uuid;
  51use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  52
/// Maximum number of automatic retry attempts for a failed completion
/// (see [`RetryState::max_attempts`]).
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay, in seconds, between retry attempts.
const BASE_RETRY_DELAY_SECS: u64 = 5;
  55
  56#[derive(
  57    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  58)]
  59pub struct ThreadId(Arc<str>);
  60
  61impl ThreadId {
  62    pub fn new() -> Self {
  63        Self(Uuid::new_v4().to_string().into())
  64    }
  65}
  66
  67impl std::fmt::Display for ThreadId {
  68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  69        write!(f, "{}", self.0)
  70    }
  71}
  72
  73impl From<&str> for ThreadId {
  74    fn from(value: &str) -> Self {
  75        Self(value.into())
  76    }
  77}
  78
  79/// The ID of the user prompt that initiated a request.
  80///
  81/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  82#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  83pub struct PromptId(Arc<str>);
  84
  85impl PromptId {
  86    pub fn new() -> Self {
  87        Self(Uuid::new_v4().to_string().into())
  88    }
  89}
  90
  91impl std::fmt::Display for PromptId {
  92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  93        write!(f, "{}", self.0)
  94    }
  95}
  96
  97#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  98pub struct MessageId(pub(crate) usize);
  99
 100impl MessageId {
 101    fn post_inc(&mut self) -> Self {
 102        Self(post_inc(&mut self.0))
 103    }
 104
 105    pub fn as_usize(&self) -> usize {
 106        self.0
 107    }
 108}
 109
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range covered by the crease within the message's text.
    pub range: Range<usize>,
    /// Icon shown for the collapsed crease.
    pub icon_path: SharedString,
    /// Label shown for the collapsed crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 119
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases used to restore context placeholders in editors.
    pub creases: Vec<MessageCrease>,
    /// Whether the message is hidden; hidden messages are skipped when
    /// detecting turn ends (see [`Thread::is_turn_end`]).
    pub is_hidden: bool,
    /// Whether the message exists only for UI purposes; such messages are
    /// never persisted (see [`Thread::deserialize`]).
    pub ui_only: bool,
}
 131
 132impl Message {
 133    /// Returns whether the message contains any meaningful text that should be displayed
 134    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 135    pub fn should_display_content(&self) -> bool {
 136        self.segments.iter().all(|segment| segment.should_display())
 137    }
 138
 139    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 140        if let Some(MessageSegment::Thinking {
 141            text: segment,
 142            signature: current_signature,
 143        }) = self.segments.last_mut()
 144        {
 145            if let Some(signature) = signature {
 146                *current_signature = Some(signature);
 147            }
 148            segment.push_str(text);
 149        } else {
 150            self.segments.push(MessageSegment::Thinking {
 151                text: text.to_string(),
 152                signature,
 153            });
 154        }
 155    }
 156
 157    pub fn push_redacted_thinking(&mut self, data: String) {
 158        self.segments.push(MessageSegment::RedactedThinking(data));
 159    }
 160
 161    pub fn push_text(&mut self, text: &str) {
 162        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 163            segment.push_str(text);
 164        } else {
 165            self.segments.push(MessageSegment::Text(text.to_string()));
 166        }
 167    }
 168
 169    pub fn to_string(&self) -> String {
 170        let mut result = String::new();
 171
 172        if !self.loaded_context.text.is_empty() {
 173            result.push_str(&self.loaded_context.text);
 174        }
 175
 176        for segment in &self.segments {
 177            match segment {
 178                MessageSegment::Text(text) => result.push_str(text),
 179                MessageSegment::Thinking { text, .. } => {
 180                    result.push_str("<think>\n");
 181                    result.push_str(text);
 182                    result.push_str("\n</think>");
 183                }
 184                MessageSegment::RedactedThinking(_) => {}
 185            }
 186        }
 187
 188        result
 189    }
 190}
 191
 192#[derive(Debug, Clone, PartialEq, Eq)]
 193pub enum MessageSegment {
 194    Text(String),
 195    Thinking {
 196        text: String,
 197        signature: Option<String>,
 198    },
 199    RedactedThinking(String),
 200}
 201
 202impl MessageSegment {
 203    pub fn should_display(&self) -> bool {
 204        match self {
 205            Self::Text(text) => text.is_empty(),
 206            Self::Thinking { text, .. } => text.is_empty(),
 207            Self::RedactedThinking(_) => false,
 208        }
 209    }
 210
 211    pub fn text(&self) -> Option<&str> {
 212        match self {
 213            MessageSegment::Text(text) => Some(text),
 214            _ => None,
 215        }
 216    }
 217}
 218
/// Snapshot of the project's worktrees taken alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// None when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 239
/// A git checkpoint associated with a particular message in the thread.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// Status of the most recent checkpoint-restore attempt.
pub enum LastRestoreCheckpoint {
    /// Restore is still in progress.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
 261
 262impl LastRestoreCheckpoint {
 263    pub fn message_id(&self) -> MessageId {
 264        match self {
 265            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 266            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 267        }
 268    }
 269}
 270
 271#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 272pub enum DetailedSummaryState {
 273    #[default]
 274    NotGenerated,
 275    Generating {
 276        message_id: MessageId,
 277    },
 278    Generated {
 279        text: SharedString,
 280        message_id: MessageId,
 281    },
 282}
 283
 284impl DetailedSummaryState {
 285    fn text(&self) -> Option<SharedString> {
 286        if let Self::Generated { text, .. } = self {
 287            Some(text.clone())
 288        } else {
 289            None
 290        }
 291    }
 292}
 293
 294#[derive(Default, Debug)]
 295pub struct TotalTokenUsage {
 296    pub total: u64,
 297    pub max: u64,
 298}
 299
 300impl TotalTokenUsage {
 301    pub fn ratio(&self) -> TokenUsageRatio {
 302        #[cfg(debug_assertions)]
 303        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 304            .unwrap_or("0.8".to_string())
 305            .parse()
 306            .unwrap();
 307        #[cfg(not(debug_assertions))]
 308        let warning_threshold: f32 = 0.8;
 309
 310        // When the maximum is unknown because there is no selected model,
 311        // avoid showing the token limit warning.
 312        if self.max == 0 {
 313            TokenUsageRatio::Normal
 314        } else if self.total >= self.max {
 315            TokenUsageRatio::Exceeded
 316        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 317            TokenUsageRatio::Warning
 318        } else {
 319            TokenUsageRatio::Normal
 320        }
 321    }
 322
 323    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 324        TotalTokenUsage {
 325            total: self.total + tokens,
 326            max: self.max,
 327        }
 328    }
 329}
 330
 331#[derive(Debug, Default, PartialEq, Eq)]
 332pub enum TokenUsageRatio {
 333    #[default]
 334    Normal,
 335    Warning,
 336    Exceeded,
 337}
 338
/// State of a pending completion request in the provider's send queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    // `position` presumably counts from the front of the queue — confirm
    // against the provider's status events.
    Queued { position: usize },
    Started,
}
 345
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Current summary (title) of the thread.
    summary: ThreadSummary,
    /// Background task generating the short summary.
    pending_summary: Task<Option<()>>,
    /// Background task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages ordered by ascending `MessageId`; `message()` binary-searches
    /// this invariant.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completion requests; non-empty means the thread is generating.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint-restore attempt, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created (shared future).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Total token usage accumulated across all requests.
    cumulative_token_usage: TokenUsage,
    /// Set when a request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    /// Automatic-retry bookkeeping for failed completions.
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the most recent streamed chunk arrived; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook observing each request and its completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
 387
/// State carried across automatic retries of a failed completion request
/// (bounded by [`MAX_RETRY_ATTEMPTS`]).
#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}
 394
 395#[derive(Clone, Debug, PartialEq, Eq)]
 396pub enum ThreadSummary {
 397    Pending,
 398    Generating,
 399    Ready(SharedString),
 400    Error,
 401}
 402
 403impl ThreadSummary {
 404    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 405
 406    pub fn or_default(&self) -> SharedString {
 407        self.unwrap_or(Self::DEFAULT)
 408    }
 409
 410    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 411        self.ready().unwrap_or_else(|| message.into())
 412    }
 413
 414    pub fn ready(&self) -> Option<SharedString> {
 415        match self {
 416            ThreadSummary::Ready(summary) => Some(summary.clone()),
 417            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 418        }
 419    }
 420}
 421
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 429
 430impl Thread {
    /// Creates a new, empty thread with a fresh ID, the global default model,
    /// and the default agent profile from settings.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project's initial state in the background; `shared()`
            // lets multiple consumers await the same snapshot.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 486
    /// Reconstructs a thread from its serialized form.
    ///
    /// The serialized model is used if it still resolves in the registry;
    /// otherwise the default model is used. Completion mode and profile fall
    /// back to global settings when absent from the serialized data.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the serialized model if it can still be selected; otherwise
        // fall back to the registry's default model.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    // Creases are restored without live context handles.
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 610
    /// Installs a callback invoked with each outgoing request and the
    /// completion events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 618
    /// Returns this thread's unique ID.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 622
    /// Returns the currently active agent profile.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
 626
 627    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
 628        if &id != self.profile.id() {
 629            self.profile = AgentProfile::new(id, self.tools.clone());
 630            cx.emit(ThreadEvent::ProfileChanged);
 631        }
 632    }
 633
    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 637
    /// Returns when the thread was last touched.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 641
    /// Marks the thread as updated now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 645
    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
 649
    /// Returns a handle to the shared project context used for system prompts.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 653
 654    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 655        if self.configured_model.is_none() {
 656            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 657        }
 658        self.configured_model.clone()
 659    }
 660
    /// Returns the currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
 664
    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
 669
    /// Returns the thread's current summary state.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
 673
 674    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 675        let current_summary = match &self.summary {
 676            ThreadSummary::Pending | ThreadSummary::Generating => return,
 677            ThreadSummary::Ready(summary) => summary,
 678            ThreadSummary::Error => &ThreadSummary::DEFAULT,
 679        };
 680
 681        let mut new_summary = new_summary.into();
 682
 683        if new_summary.is_empty() {
 684            new_summary = ThreadSummary::DEFAULT;
 685        }
 686
 687        if current_summary != &new_summary {
 688            self.summary = ThreadSummary::Ready(new_summary);
 689            cx.emit(ThreadEvent::SummaryChanged);
 690        }
 691    }
 692
    /// Returns the thread's completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
 696
    /// Sets the thread's completion mode.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 700
 701    pub fn message(&self, id: MessageId) -> Option<&Message> {
 702        let index = self
 703            .messages
 704            .binary_search_by(|message| message.id.cmp(&id))
 705            .ok()?;
 706
 707        self.messages.get(index)
 708    }
 709
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
 713
    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
 717
 718    /// Indicates whether streaming of language model events is stale.
 719    /// When `is_generating()` is false, this method returns `None`.
 720    pub fn is_generation_stale(&self) -> Option<bool> {
 721        const STALE_THRESHOLD: u128 = 250;
 722
 723        self.last_received_chunk_at
 724            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
 725    }
 726
    /// Records that a streamed chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
 730
 731    pub fn queue_state(&self) -> Option<QueueState> {
 732        self.pending_completions
 733            .first()
 734            .map(|pending_completion| pending_completion.queue_state)
 735    }
 736
    /// Returns the thread's tool working set.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 740
    /// Finds the pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }
 747
    /// Iterates over pending tool uses awaiting user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
 754
    /// Returns whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 758
 759    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 760        self.checkpoints_by_message.get(&id).cloned()
 761    }
 762
    /// Restores the project to the given git checkpoint, truncating the
    /// thread back to the checkpoint's message on success.
    ///
    /// `last_restore_checkpoint` tracks progress so the UI can show a
    /// pending or failed restore; it is cleared on success.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages at/after the restored checkpoint.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 798
 799    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 800        let pending_checkpoint = if self.is_generating() {
 801            return;
 802        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 803            checkpoint
 804        } else {
 805            return;
 806        };
 807
 808        self.finalize_checkpoint(pending_checkpoint, cx);
 809    }
 810
    /// Compares `pending_checkpoint` against the repository's current state
    /// and records it only when they differ (i.e. the turn changed files).
    /// If taking or comparing checkpoints fails, the checkpoint is kept
    /// conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Comparison failure: treat as "changed" so we keep it.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 845
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 852
    /// Returns the status of the most recent checkpoint-restore attempt.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 856
    /// Removes the message with `message_id` and everything after it,
    /// dropping checkpoints associated with the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
 870
    /// Iterates over the contexts loaded for the message with `id`
    /// (empty iterator when the message doesn't exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 878
    /// Reports whether the message at index `ix` ends a conversational turn:
    /// either it is the final message and generation has stopped, or it is an
    /// assistant message followed by a visible user message.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        // Once streaming has finished, the last message always closes the
        // turn regardless of its role.
        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        // A turn ends when the following message is a user message that is
        // actually shown in the UI (hidden/synthetic user messages don't
        // count as the start of a new turn).
        self.messages
            .get(ix + 1)
            .and_then(|message| {
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            })
            .unwrap_or(false)
    }
 904
    /// Returns whether the tool-use limit was reached during the most recent
    /// completion (the flag is reset when a new completion starts).
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 908
 909    /// Returns whether all of the tool uses have finished running.
 910    pub fn all_tools_finished(&self) -> bool {
 911        // If the only pending tool uses left are the ones with errors, then
 912        // that means that we've finished running all of the pending tools.
 913        self.tool_use
 914            .pending_tool_uses()
 915            .iter()
 916            .all(|pending_tool_use| pending_tool_use.status.is_error())
 917    }
 918
 919    /// Returns whether any pending tool uses may perform edits
 920    pub fn has_pending_edit_tool_uses(&self) -> bool {
 921        self.tool_use
 922            .pending_tool_uses()
 923            .iter()
 924            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 925            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 926    }
 927
    /// Returns the tool uses recorded for the message with the given id.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 931
    /// Returns the tool results recorded for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 938
    /// Looks up the result of a tool use by its id, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 942
 943    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 944        match &self.tool_use.tool_result(id)?.content {
 945            LanguageModelToolResultContent::Text(text) => Some(text),
 946            LanguageModelToolResultContent::Image(_) => {
 947                // TODO: We should display image
 948                None
 949            }
 950        }
 951    }
 952
    /// Returns a clone of the UI card associated with the given tool use,
    /// if one has been registered.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 956
 957    /// Return tools that are both enabled and supported by the model
 958    pub fn available_tools(
 959        &self,
 960        cx: &App,
 961        model: Arc<dyn LanguageModel>,
 962    ) -> Vec<LanguageModelRequestTool> {
 963        if model.supports_tools() {
 964            self.profile
 965                .enabled_tools(cx)
 966                .into_iter()
 967                .filter_map(|(name, tool)| {
 968                    // Skip tools that cannot be supported
 969                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 970                    Some(LanguageModelRequestTool {
 971                        name: name.into(),
 972                        description: tool.description(),
 973                        input_schema,
 974                    })
 975                })
 976                .collect()
 977        } else {
 978            Vec::default()
 979        }
 980    }
 981
 982    pub fn insert_user_message(
 983        &mut self,
 984        text: impl Into<String>,
 985        loaded_context: ContextLoadResult,
 986        git_checkpoint: Option<GitStoreCheckpoint>,
 987        creases: Vec<MessageCrease>,
 988        cx: &mut Context<Self>,
 989    ) -> MessageId {
 990        if !loaded_context.referenced_buffers.is_empty() {
 991            self.action_log.update(cx, |log, cx| {
 992                for buffer in loaded_context.referenced_buffers {
 993                    log.buffer_read(buffer, cx);
 994                }
 995            });
 996        }
 997
 998        let message_id = self.insert_message(
 999            Role::User,
1000            vec![MessageSegment::Text(text.into())],
1001            loaded_context.loaded_context,
1002            creases,
1003            false,
1004            cx,
1005        );
1006
1007        if let Some(git_checkpoint) = git_checkpoint {
1008            self.pending_checkpoint = Some(ThreadCheckpoint {
1009                message_id,
1010                git_checkpoint,
1011            });
1012        }
1013
1014        self.auto_capture_telemetry(cx);
1015
1016        message_id
1017    }
1018
    /// Inserts a hidden "Continue where you left off" user message, used to
    /// resume generation without showing a prompt in the conversation, and
    /// discards any pending checkpoint (this message is synthetic, so no
    /// checkpoint should be associated with it).
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }
1032
    /// Appends a visible assistant message composed of `segments`, with no
    /// attached context or creases. Returns the new message's id.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1047
    /// Appends a message with a freshly allocated id, bumps the thread's
    /// updated-at timestamp, and emits [`ThreadEvent::MessageAdded`].
    /// Returns the new message's id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
            ui_only: false,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1071
1072    pub fn edit_message(
1073        &mut self,
1074        id: MessageId,
1075        new_role: Role,
1076        new_segments: Vec<MessageSegment>,
1077        creases: Vec<MessageCrease>,
1078        loaded_context: Option<LoadedContext>,
1079        checkpoint: Option<GitStoreCheckpoint>,
1080        cx: &mut Context<Self>,
1081    ) -> bool {
1082        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1083            return false;
1084        };
1085        message.role = new_role;
1086        message.segments = new_segments;
1087        message.creases = creases;
1088        if let Some(context) = loaded_context {
1089            message.loaded_context = context;
1090        }
1091        if let Some(git_checkpoint) = checkpoint {
1092            self.checkpoints_by_message.insert(
1093                id,
1094                ThreadCheckpoint {
1095                    message_id: id,
1096                    git_checkpoint,
1097                },
1098            );
1099        }
1100        self.touch_updated_at();
1101        cx.emit(ThreadEvent::MessageEdited(id));
1102        true
1103    }
1104
1105    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1106        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1107            return false;
1108        };
1109        self.messages.remove(index);
1110        self.touch_updated_at();
1111        cx.emit(ThreadEvent::MessageDeleted(id));
1112        true
1113    }
1114
1115    /// Returns the representation of this [`Thread`] in a textual form.
1116    ///
1117    /// This is the representation we use when attaching a thread as context to another thread.
1118    pub fn text(&self) -> String {
1119        let mut text = String::new();
1120
1121        for message in &self.messages {
1122            text.push_str(match message.role {
1123                language_model::Role::User => "User:",
1124                language_model::Role::Assistant => "Agent:",
1125                language_model::Role::System => "System:",
1126            });
1127            text.push('\n');
1128
1129            for segment in &message.segments {
1130                match segment {
1131                    MessageSegment::Text(content) => text.push_str(content),
1132                    MessageSegment::Thinking { text: content, .. } => {
1133                        text.push_str(&format!("<think>{}</think>", content))
1134                    }
1135                    MessageSegment::RedactedThinking(_) => {}
1136                }
1137            }
1138            text.push('\n');
1139        }
1140
1141        text
1142    }
1143
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Resolves the initial project snapshot first, then reads the thread to
    /// capture its messages (excluding UI-only ones) together with their tool
    /// uses/results, token usage, and configured model/profile.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Await the snapshot before reading thread state so the result is complete.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // ui_only messages are display affordances and are not persisted.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1230
    /// Returns how many agentic turns this thread may still take.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1234
    /// Sets how many agentic turns this thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1238
    /// Sends the thread's current state to `model`, consuming one remaining
    /// turn. Does nothing when no turns remain. Pending project notifications
    /// are flushed before the completion request is built, so they are
    /// included in the request.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        self.flush_notifications(model.clone(), intent, cx);

        let request = self.to_completion_request(model.clone(), intent, cx);

        self.stream_completion(request, model, intent, window, cx);
    }
1258
1259    pub fn used_tools_since_last_user_message(&self) -> bool {
1260        for message in self.messages.iter().rev() {
1261            if self.tool_use.message_has_tool_results(message.id) {
1262                return true;
1263            } else if message.role == Role::User {
1264                return false;
1265            }
1266        }
1267
1268        false
1269    }
1270
    /// Builds the [`LanguageModelRequest`] for a completion turn: the system
    /// prompt, every non-UI-only message (with its context, tool uses, and
    /// tool results), the available tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
            thinking_allowed: true,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt is generated from the shared project context. If
        // that context isn't ready, or prompt generation fails, an error is
        // surfaced to the UI and the request proceeds without a system
        // message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching. Only
        // messages whose tool uses have all resolved are cache candidates.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty (after trimming) text segments are dropped.
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses go on the assistant message, and their
            // results are collected into a following synthetic User message.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode only applies when the model supports it; otherwise the
        // request is forced back to normal mode.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1435
1436    fn to_summarize_request(
1437        &self,
1438        model: &Arc<dyn LanguageModel>,
1439        intent: CompletionIntent,
1440        added_user_message: String,
1441        cx: &App,
1442    ) -> LanguageModelRequest {
1443        let mut request = LanguageModelRequest {
1444            thread_id: None,
1445            prompt_id: None,
1446            intent: Some(intent),
1447            mode: None,
1448            messages: vec![],
1449            tools: Vec::new(),
1450            tool_choice: None,
1451            stop: Vec::new(),
1452            temperature: AgentSettings::temperature_for_model(model, cx),
1453            thinking_allowed: false,
1454        };
1455
1456        for message in &self.messages {
1457            let mut request_message = LanguageModelRequestMessage {
1458                role: message.role,
1459                content: Vec::new(),
1460                cache: false,
1461            };
1462
1463            for segment in &message.segments {
1464                match segment {
1465                    MessageSegment::Text(text) => request_message
1466                        .content
1467                        .push(MessageContent::Text(text.clone())),
1468                    MessageSegment::Thinking { .. } => {}
1469                    MessageSegment::RedactedThinking(_) => {}
1470                }
1471            }
1472
1473            if request_message.content.is_empty() {
1474                continue;
1475            }
1476
1477            request.messages.push(request_message);
1478        }
1479
1480        request.messages.push(LanguageModelRequestMessage {
1481            role: Role::User,
1482            content: vec![MessageContent::Text(added_user_message)],
1483            cache: false,
1484        });
1485
1486        request
1487    }
1488
1489    /// Insert auto-generated notifications (if any) to the thread
1490    fn flush_notifications(
1491        &mut self,
1492        model: Arc<dyn LanguageModel>,
1493        intent: CompletionIntent,
1494        cx: &mut Context<Self>,
1495    ) {
1496        match intent {
1497            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1498                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1499                    cx.emit(ThreadEvent::ToolFinished {
1500                        tool_use_id: pending_tool_use.id.clone(),
1501                        pending_tool_use: Some(pending_tool_use),
1502                    });
1503                }
1504            }
1505            CompletionIntent::ThreadSummarization
1506            | CompletionIntent::ThreadContextSummarization
1507            | CompletionIntent::CreateFile
1508            | CompletionIntent::EditFile
1509            | CompletionIntent::InlineAssist
1510            | CompletionIntent::TerminalInlineAssist
1511            | CompletionIntent::GenerateGitCommitMessage => {}
1512        };
1513    }
1514
    /// If any buffers have gone stale without the agent being notified,
    /// fabricates a `project_notifications` tool call (and runs it for its
    /// output) attached to the most recent assistant message, so the model
    /// sees the notification on the next request. Returns the resulting
    /// pending tool use, or `None` when there is nothing to notify, the tool
    /// is unavailable/disabled, or no assistant message exists yet.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        let action_log = self.action_log.read(cx);

        // Bail out early when there are no stale buffers to report.
        action_log.unnotified_stale_buffers(cx).next()?;

        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
            debug_panic!("`project_notifications` tool not found");
            return None;
        };

        // Respect the profile's tool configuration even for simulated calls.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // NOTE(review): blocks the current thread on the tool's output;
        // presumably this tool completes quickly — confirm.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
            self.completion_mode,
        );

        pending_tool_use
    }
1593
1594    pub fn stream_completion(
1595        &mut self,
1596        request: LanguageModelRequest,
1597        model: Arc<dyn LanguageModel>,
1598        intent: CompletionIntent,
1599        window: Option<AnyWindowHandle>,
1600        cx: &mut Context<Self>,
1601    ) {
1602        self.tool_use_limit_reached = false;
1603
1604        let pending_completion_id = post_inc(&mut self.completion_count);
1605        let mut request_callback_parameters = if self.request_callback.is_some() {
1606            Some((request.clone(), Vec::new()))
1607        } else {
1608            None
1609        };
1610        let prompt_id = self.last_prompt_id.clone();
1611        let tool_use_metadata = ToolUseMetadata {
1612            model: model.clone(),
1613            thread_id: self.id.clone(),
1614            prompt_id: prompt_id.clone(),
1615        };
1616
1617        let completion_mode = request
1618            .mode
1619            .unwrap_or(zed_llm_client::CompletionMode::Normal);
1620
1621        self.last_received_chunk_at = Some(Instant::now());
1622
1623        let task = cx.spawn(async move |thread, cx| {
1624            let stream_completion_future = model.stream_completion(request, &cx);
1625            let initial_token_usage =
1626                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1627            let stream_completion = async {
1628                let mut events = stream_completion_future.await?;
1629
1630                let mut stop_reason = StopReason::EndTurn;
1631                let mut current_token_usage = TokenUsage::default();
1632
1633                thread
1634                    .update(cx, |_thread, cx| {
1635                        cx.emit(ThreadEvent::NewRequest);
1636                    })
1637                    .ok();
1638
1639                let mut request_assistant_message_id = None;
1640
1641                while let Some(event) = events.next().await {
1642                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1643                        response_events
1644                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1645                    }
1646
1647                    thread.update(cx, |thread, cx| {
1648                        match event? {
1649                            LanguageModelCompletionEvent::StartMessage { .. } => {
1650                                request_assistant_message_id =
1651                                    Some(thread.insert_assistant_message(
1652                                        vec![MessageSegment::Text(String::new())],
1653                                        cx,
1654                                    ));
1655                            }
1656                            LanguageModelCompletionEvent::Stop(reason) => {
1657                                stop_reason = reason;
1658                            }
1659                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1660                                thread.update_token_usage_at_last_message(token_usage);
1661                                thread.cumulative_token_usage = thread.cumulative_token_usage
1662                                    + token_usage
1663                                    - current_token_usage;
1664                                current_token_usage = token_usage;
1665                            }
1666                            LanguageModelCompletionEvent::Text(chunk) => {
1667                                thread.received_chunk();
1668
1669                                cx.emit(ThreadEvent::ReceivedTextChunk);
1670                                if let Some(last_message) = thread.messages.last_mut() {
1671                                    if last_message.role == Role::Assistant
1672                                        && !thread.tool_use.has_tool_results(last_message.id)
1673                                    {
1674                                        last_message.push_text(&chunk);
1675                                        cx.emit(ThreadEvent::StreamedAssistantText(
1676                                            last_message.id,
1677                                            chunk,
1678                                        ));
1679                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1685                                        request_assistant_message_id =
1686                                            Some(thread.insert_assistant_message(
1687                                                vec![MessageSegment::Text(chunk.to_string())],
1688                                                cx,
1689                                            ));
1690                                    };
1691                                }
1692                            }
1693                            LanguageModelCompletionEvent::Thinking {
1694                                text: chunk,
1695                                signature,
1696                            } => {
1697                                thread.received_chunk();
1698
1699                                if let Some(last_message) = thread.messages.last_mut() {
1700                                    if last_message.role == Role::Assistant
1701                                        && !thread.tool_use.has_tool_results(last_message.id)
1702                                    {
1703                                        last_message.push_thinking(&chunk, signature);
1704                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1705                                            last_message.id,
1706                                            chunk,
1707                                        ));
1708                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1714                                        request_assistant_message_id =
1715                                            Some(thread.insert_assistant_message(
1716                                                vec![MessageSegment::Thinking {
1717                                                    text: chunk.to_string(),
1718                                                    signature,
1719                                                }],
1720                                                cx,
1721                                            ));
1722                                    };
1723                                }
1724                            }
1725                            LanguageModelCompletionEvent::RedactedThinking { data } => {
1726                                thread.received_chunk();
1727
1728                                if let Some(last_message) = thread.messages.last_mut() {
1729                                    if last_message.role == Role::Assistant
1730                                        && !thread.tool_use.has_tool_results(last_message.id)
1731                                    {
1732                                        last_message.push_redacted_thinking(data);
1733                                    } else {
1734                                        request_assistant_message_id =
1735                                            Some(thread.insert_assistant_message(
1736                                                vec![MessageSegment::RedactedThinking(data)],
1737                                                cx,
1738                                            ));
1739                                    };
1740                                }
1741                            }
1742                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1743                                let last_assistant_message_id = request_assistant_message_id
1744                                    .unwrap_or_else(|| {
1745                                        let new_assistant_message_id =
1746                                            thread.insert_assistant_message(vec![], cx);
1747                                        request_assistant_message_id =
1748                                            Some(new_assistant_message_id);
1749                                        new_assistant_message_id
1750                                    });
1751
1752                                let tool_use_id = tool_use.id.clone();
1753                                let streamed_input = if tool_use.is_input_complete {
1754                                    None
1755                                } else {
1756                                    Some((&tool_use.input).clone())
1757                                };
1758
1759                                let ui_text = thread.tool_use.request_tool_use(
1760                                    last_assistant_message_id,
1761                                    tool_use,
1762                                    tool_use_metadata.clone(),
1763                                    cx,
1764                                );
1765
1766                                if let Some(input) = streamed_input {
1767                                    cx.emit(ThreadEvent::StreamedToolUse {
1768                                        tool_use_id,
1769                                        ui_text,
1770                                        input,
1771                                    });
1772                                }
1773                            }
1774                            LanguageModelCompletionEvent::ToolUseJsonParseError {
1775                                id,
1776                                tool_name,
1777                                raw_input: invalid_input_json,
1778                                json_parse_error,
1779                            } => {
1780                                thread.receive_invalid_tool_json(
1781                                    id,
1782                                    tool_name,
1783                                    invalid_input_json,
1784                                    json_parse_error,
1785                                    window,
1786                                    cx,
1787                                );
1788                            }
1789                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1790                                if let Some(completion) = thread
1791                                    .pending_completions
1792                                    .iter_mut()
1793                                    .find(|completion| completion.id == pending_completion_id)
1794                                {
1795                                    match status_update {
1796                                        CompletionRequestStatus::Queued { position } => {
1797                                            completion.queue_state =
1798                                                QueueState::Queued { position };
1799                                        }
1800                                        CompletionRequestStatus::Started => {
1801                                            completion.queue_state = QueueState::Started;
1802                                        }
1803                                        CompletionRequestStatus::Failed {
1804                                            code,
1805                                            message,
1806                                            request_id: _,
1807                                            retry_after,
1808                                        } => {
1809                                            return Err(
1810                                                LanguageModelCompletionError::from_cloud_failure(
1811                                                    model.upstream_provider_name(),
1812                                                    code,
1813                                                    message,
1814                                                    retry_after.map(Duration::from_secs_f64),
1815                                                ),
1816                                            );
1817                                        }
1818                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
1819                                            thread.update_model_request_usage(
1820                                                amount as u32,
1821                                                limit,
1822                                                cx,
1823                                            );
1824                                        }
1825                                        CompletionRequestStatus::ToolUseLimitReached => {
1826                                            thread.tool_use_limit_reached = true;
1827                                            cx.emit(ThreadEvent::ToolUseLimitReached);
1828                                        }
1829                                    }
1830                                }
1831                            }
1832                        }
1833
1834                        thread.touch_updated_at();
1835                        cx.emit(ThreadEvent::StreamedCompletion);
1836                        cx.notify();
1837
1838                        thread.auto_capture_telemetry(cx);
1839                        Ok(())
1840                    })??;
1841
1842                    smol::future::yield_now().await;
1843                }
1844
1845                thread.update(cx, |thread, cx| {
1846                    thread.last_received_chunk_at = None;
1847                    thread
1848                        .pending_completions
1849                        .retain(|completion| completion.id != pending_completion_id);
1850
1851                    // If there is a response without tool use, summarize the message. Otherwise,
1852                    // allow two tool uses before summarizing.
1853                    if matches!(thread.summary, ThreadSummary::Pending)
1854                        && thread.messages.len() >= 2
1855                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1856                    {
1857                        thread.summarize(cx);
1858                    }
1859                })?;
1860
1861                anyhow::Ok(stop_reason)
1862            };
1863
1864            let result = stream_completion.await;
1865            let mut retry_scheduled = false;
1866
1867            thread
1868                .update(cx, |thread, cx| {
1869                    thread.finalize_pending_checkpoint(cx);
1870                    match result.as_ref() {
1871                        Ok(stop_reason) => {
1872                            match stop_reason {
1873                                StopReason::ToolUse => {
1874                                    let tool_uses =
1875                                        thread.use_pending_tools(window, model.clone(), cx);
1876                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1877                                }
1878                                StopReason::EndTurn | StopReason::MaxTokens => {
1879                                    thread.project.update(cx, |project, cx| {
1880                                        project.set_agent_location(None, cx);
1881                                    });
1882                                }
1883                                StopReason::Refusal => {
1884                                    thread.project.update(cx, |project, cx| {
1885                                        project.set_agent_location(None, cx);
1886                                    });
1887
1888                                    // Remove the turn that was refused.
1889                                    //
1890                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1891                                    {
1892                                        let mut messages_to_remove = Vec::new();
1893
1894                                        for (ix, message) in
1895                                            thread.messages.iter().enumerate().rev()
1896                                        {
1897                                            messages_to_remove.push(message.id);
1898
1899                                            if message.role == Role::User {
1900                                                if ix == 0 {
1901                                                    break;
1902                                                }
1903
1904                                                if let Some(prev_message) =
1905                                                    thread.messages.get(ix - 1)
1906                                                {
1907                                                    if prev_message.role == Role::Assistant {
1908                                                        break;
1909                                                    }
1910                                                }
1911                                            }
1912                                        }
1913
1914                                        for message_id in messages_to_remove {
1915                                            thread.delete_message(message_id, cx);
1916                                        }
1917                                    }
1918
1919                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1920                                        header: "Language model refusal".into(),
1921                                        message:
1922                                            "Model refused to generate content for safety reasons."
1923                                                .into(),
1924                                    }));
1925                                }
1926                            }
1927
1928                            // We successfully completed, so cancel any remaining retries.
1929                            thread.retry_state = None;
1930                        }
1931                        Err(error) => {
1932                            thread.project.update(cx, |project, cx| {
1933                                project.set_agent_location(None, cx);
1934                            });
1935
1936                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1937                                let error_message = error
1938                                    .chain()
1939                                    .map(|err| err.to_string())
1940                                    .collect::<Vec<_>>()
1941                                    .join("\n");
1942                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1943                                    header: "Error interacting with language model".into(),
1944                                    message: SharedString::from(error_message.clone()),
1945                                }));
1946                            }
1947
1948                            if error.is::<PaymentRequiredError>() {
1949                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1950                            } else if let Some(error) =
1951                                error.downcast_ref::<ModelRequestLimitReachedError>()
1952                            {
1953                                cx.emit(ThreadEvent::ShowError(
1954                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1955                                ));
1956                            } else if let Some(completion_error) =
1957                                error.downcast_ref::<LanguageModelCompletionError>()
1958                            {
1959                                use LanguageModelCompletionError::*;
1960                                match &completion_error {
1961                                    PromptTooLarge { tokens, .. } => {
1962                                        let tokens = tokens.unwrap_or_else(|| {
1963                                            // We didn't get an exact token count from the API, so fall back on our estimate.
1964                                            thread
1965                                                .total_token_usage()
1966                                                .map(|usage| usage.total)
1967                                                .unwrap_or(0)
1968                                                // We know the context window was exceeded in practice, so if our estimate was
1969                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1970                                                .max(
1971                                                    model
1972                                                        .max_token_count_for_mode(completion_mode)
1973                                                        .saturating_add(1),
1974                                                )
1975                                        });
1976                                        thread.exceeded_window_error = Some(ExceededWindowError {
1977                                            model_id: model.id(),
1978                                            token_count: tokens,
1979                                        });
1980                                        cx.notify();
1981                                    }
1982                                    RateLimitExceeded {
1983                                        retry_after: Some(retry_after),
1984                                        ..
1985                                    }
1986                                    | ServerOverloaded {
1987                                        retry_after: Some(retry_after),
1988                                        ..
1989                                    } => {
1990                                        thread.handle_rate_limit_error(
1991                                            &completion_error,
1992                                            *retry_after,
1993                                            model.clone(),
1994                                            intent,
1995                                            window,
1996                                            cx,
1997                                        );
1998                                        retry_scheduled = true;
1999                                    }
2000                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
2001                                        retry_scheduled = thread.handle_retryable_error(
2002                                            &completion_error,
2003                                            model.clone(),
2004                                            intent,
2005                                            window,
2006                                            cx,
2007                                        );
2008                                        if !retry_scheduled {
2009                                            emit_generic_error(error, cx);
2010                                        }
2011                                    }
2012                                    ApiInternalServerError { .. }
2013                                    | ApiReadResponseError { .. }
2014                                    | HttpSend { .. } => {
2015                                        retry_scheduled = thread.handle_retryable_error(
2016                                            &completion_error,
2017                                            model.clone(),
2018                                            intent,
2019                                            window,
2020                                            cx,
2021                                        );
2022                                        if !retry_scheduled {
2023                                            emit_generic_error(error, cx);
2024                                        }
2025                                    }
2026                                    NoApiKey { .. }
2027                                    | HttpResponseError { .. }
2028                                    | BadRequestFormat { .. }
2029                                    | AuthenticationError { .. }
2030                                    | PermissionError { .. }
2031                                    | ApiEndpointNotFound { .. }
2032                                    | SerializeRequest { .. }
2033                                    | BuildRequestBody { .. }
2034                                    | DeserializeResponse { .. }
2035                                    | Other { .. } => emit_generic_error(error, cx),
2036                                }
2037                            } else {
2038                                emit_generic_error(error, cx);
2039                            }
2040
2041                            if !retry_scheduled {
2042                                thread.cancel_last_completion(window, cx);
2043                            }
2044                        }
2045                    }
2046
2047                    if !retry_scheduled {
2048                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
2049                    }
2050
2051                    if let Some((request_callback, (request, response_events))) = thread
2052                        .request_callback
2053                        .as_mut()
2054                        .zip(request_callback_parameters.as_ref())
2055                    {
2056                        request_callback(request, response_events);
2057                    }
2058
2059                    thread.auto_capture_telemetry(cx);
2060
2061                    if let Ok(initial_usage) = initial_token_usage {
2062                        let usage = thread.cumulative_token_usage - initial_usage;
2063
2064                        telemetry::event!(
2065                            "Assistant Thread Completion",
2066                            thread_id = thread.id().to_string(),
2067                            prompt_id = prompt_id,
2068                            model = model.telemetry_id(),
2069                            model_provider = model.provider_id().to_string(),
2070                            input_tokens = usage.input_tokens,
2071                            output_tokens = usage.output_tokens,
2072                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
2073                            cache_read_input_tokens = usage.cache_read_input_tokens,
2074                        );
2075                    }
2076                })
2077                .ok();
2078        });
2079
2080        self.pending_completions.push(PendingCompletion {
2081            id: pending_completion_id,
2082            queue_state: QueueState::Sending,
2083            _task: task,
2084        });
2085    }
2086
2087    pub fn summarize(&mut self, cx: &mut Context<Self>) {
2088        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2089            println!("No thread summary model");
2090            return;
2091        };
2092
2093        if !model.provider.is_authenticated(cx) {
2094            return;
2095        }
2096
2097        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2098
2099        let request = self.to_summarize_request(
2100            &model.model,
2101            CompletionIntent::ThreadSummarization,
2102            added_user_message.into(),
2103            cx,
2104        );
2105
2106        self.summary = ThreadSummary::Generating;
2107
2108        self.pending_summary = cx.spawn(async move |this, cx| {
2109            let result = async {
2110                let mut messages = model.model.stream_completion(request, &cx).await?;
2111
2112                let mut new_summary = String::new();
2113                while let Some(event) = messages.next().await {
2114                    let Ok(event) = event else {
2115                        continue;
2116                    };
2117                    let text = match event {
2118                        LanguageModelCompletionEvent::Text(text) => text,
2119                        LanguageModelCompletionEvent::StatusUpdate(
2120                            CompletionRequestStatus::UsageUpdated { amount, limit },
2121                        ) => {
2122                            this.update(cx, |thread, cx| {
2123                                thread.update_model_request_usage(amount as u32, limit, cx);
2124                            })?;
2125                            continue;
2126                        }
2127                        _ => continue,
2128                    };
2129
2130                    let mut lines = text.lines();
2131                    new_summary.extend(lines.next());
2132
2133                    // Stop if the LLM generated multiple lines.
2134                    if lines.next().is_some() {
2135                        break;
2136                    }
2137                }
2138
2139                anyhow::Ok(new_summary)
2140            }
2141            .await;
2142
2143            this.update(cx, |this, cx| {
2144                match result {
2145                    Ok(new_summary) => {
2146                        if new_summary.is_empty() {
2147                            this.summary = ThreadSummary::Error;
2148                        } else {
2149                            this.summary = ThreadSummary::Ready(new_summary.into());
2150                        }
2151                    }
2152                    Err(err) => {
2153                        this.summary = ThreadSummary::Error;
2154                        log::error!("Failed to generate thread summary: {}", err);
2155                    }
2156                }
2157                cx.emit(ThreadEvent::SummaryGenerated);
2158            })
2159            .log_err()?;
2160
2161            Some(())
2162        });
2163    }
2164
2165    fn handle_rate_limit_error(
2166        &mut self,
2167        error: &LanguageModelCompletionError,
2168        retry_after: Duration,
2169        model: Arc<dyn LanguageModel>,
2170        intent: CompletionIntent,
2171        window: Option<AnyWindowHandle>,
2172        cx: &mut Context<Self>,
2173    ) {
2174        // For rate limit errors, we only retry once with the specified duration
2175        let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2176        log::warn!(
2177            "Retrying completion request in {} seconds: {error:?}",
2178            retry_after.as_secs(),
2179        );
2180
2181        // Add a UI-only message instead of a regular message
2182        let id = self.next_message_id.post_inc();
2183        self.messages.push(Message {
2184            id,
2185            role: Role::System,
2186            segments: vec![MessageSegment::Text(retry_message)],
2187            loaded_context: LoadedContext::default(),
2188            creases: Vec::new(),
2189            is_hidden: false,
2190            ui_only: true,
2191        });
2192        cx.emit(ThreadEvent::MessageAdded(id));
2193        // Schedule the retry
2194        let thread_handle = cx.entity().downgrade();
2195
2196        cx.spawn(async move |_thread, cx| {
2197            cx.background_executor().timer(retry_after).await;
2198
2199            thread_handle
2200                .update(cx, |thread, cx| {
2201                    // Retry the completion
2202                    thread.send_to_model(model, intent, window, cx);
2203                })
2204                .log_err();
2205        })
2206        .detach();
2207    }
2208
    /// Schedules a retry for a transient completion error using the default
    /// exponential-backoff delay. Convenience wrapper around
    /// `handle_retryable_error_with_delay` with no custom delay.
    ///
    /// Returns `true` if a retry was scheduled, `false` if the retry budget
    /// is exhausted.
    fn handle_retryable_error(
        &mut self,
        error: &LanguageModelCompletionError,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
    }
2219
2220    fn handle_retryable_error_with_delay(
2221        &mut self,
2222        error: &LanguageModelCompletionError,
2223        custom_delay: Option<Duration>,
2224        model: Arc<dyn LanguageModel>,
2225        intent: CompletionIntent,
2226        window: Option<AnyWindowHandle>,
2227        cx: &mut Context<Self>,
2228    ) -> bool {
2229        let retry_state = self.retry_state.get_or_insert(RetryState {
2230            attempt: 0,
2231            max_attempts: MAX_RETRY_ATTEMPTS,
2232            intent,
2233        });
2234
2235        retry_state.attempt += 1;
2236        let attempt = retry_state.attempt;
2237        let max_attempts = retry_state.max_attempts;
2238        let intent = retry_state.intent;
2239
2240        if attempt <= max_attempts {
2241            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
2242            let delay = if let Some(custom_delay) = custom_delay {
2243                custom_delay
2244            } else {
2245                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
2246                Duration::from_secs(delay_secs)
2247            };
2248
2249            // Add a transient message to inform the user
2250            let delay_secs = delay.as_secs();
2251            let retry_message = format!(
2252                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2253                in {delay_secs} seconds..."
2254            );
2255            log::warn!(
2256                "Retrying completion request (attempt {attempt} of {max_attempts}) \
2257                in {delay_secs} seconds: {error:?}",
2258            );
2259
2260            // Add a UI-only message instead of a regular message
2261            let id = self.next_message_id.post_inc();
2262            self.messages.push(Message {
2263                id,
2264                role: Role::System,
2265                segments: vec![MessageSegment::Text(retry_message)],
2266                loaded_context: LoadedContext::default(),
2267                creases: Vec::new(),
2268                is_hidden: false,
2269                ui_only: true,
2270            });
2271            cx.emit(ThreadEvent::MessageAdded(id));
2272
2273            // Schedule the retry
2274            let thread_handle = cx.entity().downgrade();
2275
2276            cx.spawn(async move |_thread, cx| {
2277                cx.background_executor().timer(delay).await;
2278
2279                thread_handle
2280                    .update(cx, |thread, cx| {
2281                        // Retry the completion
2282                        thread.send_to_model(model, intent, window, cx);
2283                    })
2284                    .log_err();
2285            })
2286            .detach();
2287
2288            true
2289        } else {
2290            // Max retries exceeded
2291            self.retry_state = None;
2292
2293            let notification_text = if max_attempts == 1 {
2294                "Failed after retrying.".into()
2295            } else {
2296                format!("Failed after retrying {} times.", max_attempts).into()
2297            };
2298
2299            // Stop generating since we're giving up on retrying.
2300            self.pending_completions.clear();
2301
2302            cx.emit(ThreadEvent::RetriesFailed {
2303                message: notification_text,
2304            });
2305
2306            false
2307        }
2308    }
2309
    /// Kicks off background generation of a detailed thread summary, unless a
    /// summary is already generated (or in flight) for the current last message.
    ///
    /// The result is published through `detailed_summary_tx`/`detailed_summary_rx`
    /// and the thread is saved via `thread_store` so the summary can be reused.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // A summary keyed to the current last message is already done or in
        // progress; avoid redundant work.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        // Summaries use the dedicated thread-summary model, not the chat model.
        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Publish the Generating state before spawning so observers see the
        // transition immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, roll back to NotGenerated so a
            // later call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed text; per-chunk errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2399
2400    pub async fn wait_for_detailed_summary_or_text(
2401        this: &Entity<Self>,
2402        cx: &mut AsyncApp,
2403    ) -> Option<SharedString> {
2404        let mut detailed_summary_rx = this
2405            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2406            .ok()?;
2407        loop {
2408            match detailed_summary_rx.recv().await? {
2409                DetailedSummaryState::Generating { .. } => {}
2410                DetailedSummaryState::NotGenerated => {
2411                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
2412                }
2413                DetailedSummaryState::Generated { text, .. } => return Some(text),
2414            }
2415        }
2416    }
2417
2418    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2419        self.detailed_summary_rx
2420            .borrow()
2421            .text()
2422            .unwrap_or_else(|| self.text().into())
2423    }
2424
2425    pub fn is_generating_detailed_summary(&self) -> bool {
2426        matches!(
2427            &*self.detailed_summary_rx.borrow(),
2428            DetailedSummaryState::Generating { .. }
2429        )
2430    }
2431
2432    pub fn use_pending_tools(
2433        &mut self,
2434        window: Option<AnyWindowHandle>,
2435        model: Arc<dyn LanguageModel>,
2436        cx: &mut Context<Self>,
2437    ) -> Vec<PendingToolUse> {
2438        self.auto_capture_telemetry(cx);
2439        let request =
2440            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2441        let pending_tool_uses = self
2442            .tool_use
2443            .pending_tool_uses()
2444            .into_iter()
2445            .filter(|tool_use| tool_use.status.is_idle())
2446            .cloned()
2447            .collect::<Vec<_>>();
2448
2449        for tool_use in pending_tool_uses.iter() {
2450            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2451        }
2452
2453        pending_tool_uses
2454    }
2455
2456    fn use_pending_tool(
2457        &mut self,
2458        tool_use: PendingToolUse,
2459        request: Arc<LanguageModelRequest>,
2460        model: Arc<dyn LanguageModel>,
2461        window: Option<AnyWindowHandle>,
2462        cx: &mut Context<Self>,
2463    ) {
2464        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2465            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2466        };
2467
2468        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2469            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2470        }
2471
2472        if tool.needs_confirmation(&tool_use.input, cx)
2473            && !AgentSettings::get_global(cx).always_allow_tool_actions
2474        {
2475            self.tool_use.confirm_tool_use(
2476                tool_use.id,
2477                tool_use.ui_text,
2478                tool_use.input,
2479                request,
2480                tool,
2481            );
2482            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2483        } else {
2484            self.run_tool(
2485                tool_use.id,
2486                tool_use.ui_text,
2487                tool_use.input,
2488                request,
2489                tool,
2490                model,
2491                window,
2492                cx,
2493            );
2494        }
2495    }
2496
2497    pub fn handle_hallucinated_tool_use(
2498        &mut self,
2499        tool_use_id: LanguageModelToolUseId,
2500        hallucinated_tool_name: Arc<str>,
2501        window: Option<AnyWindowHandle>,
2502        cx: &mut Context<Thread>,
2503    ) {
2504        let available_tools = self.profile.enabled_tools(cx);
2505
2506        let tool_list = available_tools
2507            .iter()
2508            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2509            .collect::<Vec<_>>()
2510            .join("\n");
2511
2512        let error_message = format!(
2513            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2514            hallucinated_tool_name, tool_list
2515        );
2516
2517        let pending_tool_use = self.tool_use.insert_tool_output(
2518            tool_use_id.clone(),
2519            hallucinated_tool_name,
2520            Err(anyhow!("Missing tool call: {error_message}")),
2521            self.configured_model.as_ref(),
2522            self.completion_mode,
2523        );
2524
2525        cx.emit(ThreadEvent::MissingToolUse {
2526            tool_use_id: tool_use_id.clone(),
2527            ui_text: error_message.into(),
2528        });
2529
2530        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2531    }
2532
2533    pub fn receive_invalid_tool_json(
2534        &mut self,
2535        tool_use_id: LanguageModelToolUseId,
2536        tool_name: Arc<str>,
2537        invalid_json: Arc<str>,
2538        error: String,
2539        window: Option<AnyWindowHandle>,
2540        cx: &mut Context<Thread>,
2541    ) {
2542        log::error!("The model returned invalid input JSON: {invalid_json}");
2543
2544        let pending_tool_use = self.tool_use.insert_tool_output(
2545            tool_use_id.clone(),
2546            tool_name,
2547            Err(anyhow!("Error parsing input JSON: {error}")),
2548            self.configured_model.as_ref(),
2549            self.completion_mode,
2550        );
2551        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2552            pending_tool_use.ui_text.clone()
2553        } else {
2554            log::error!(
2555                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2556            );
2557            format!("Unknown tool {}", tool_use_id).into()
2558        };
2559
2560        cx.emit(ThreadEvent::InvalidToolInput {
2561            tool_use_id: tool_use_id.clone(),
2562            ui_text,
2563            invalid_input_json: invalid_json,
2564        });
2565
2566        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2567    }
2568
2569    pub fn run_tool(
2570        &mut self,
2571        tool_use_id: LanguageModelToolUseId,
2572        ui_text: impl Into<SharedString>,
2573        input: serde_json::Value,
2574        request: Arc<LanguageModelRequest>,
2575        tool: Arc<dyn Tool>,
2576        model: Arc<dyn LanguageModel>,
2577        window: Option<AnyWindowHandle>,
2578        cx: &mut Context<Thread>,
2579    ) {
2580        let task =
2581            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2582        self.tool_use
2583            .run_pending_tool(tool_use_id, ui_text.into(), task);
2584    }
2585
    /// Starts a tool running and returns the task that waits for its output.
    ///
    /// When the output arrives, it is recorded via `insert_tool_output` and
    /// the tool use is finished (which may trigger a follow-up completion).
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Begin execution immediately; `tool_result.output` resolves later.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; in that
                // case there is nowhere to record the output.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                            thread.completion_mode,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2633
2634    fn tool_finished(
2635        &mut self,
2636        tool_use_id: LanguageModelToolUseId,
2637        pending_tool_use: Option<PendingToolUse>,
2638        canceled: bool,
2639        window: Option<AnyWindowHandle>,
2640        cx: &mut Context<Self>,
2641    ) {
2642        if self.all_tools_finished() {
2643            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2644                if !canceled {
2645                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2646                }
2647                self.auto_capture_telemetry(cx);
2648            }
2649        }
2650
2651        cx.emit(ThreadEvent::ToolFinished {
2652            tool_use_id,
2653            pending_tool_use,
2654        });
2655    }
2656
2657    /// Cancels the last pending completion, if there are any pending.
2658    ///
2659    /// Returns whether a completion was canceled.
2660    pub fn cancel_last_completion(
2661        &mut self,
2662        window: Option<AnyWindowHandle>,
2663        cx: &mut Context<Self>,
2664    ) -> bool {
2665        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2666
2667        self.retry_state = None;
2668
2669        for pending_tool_use in self.tool_use.cancel_pending() {
2670            canceled = true;
2671            self.tool_finished(
2672                pending_tool_use.id.clone(),
2673                Some(pending_tool_use),
2674                true,
2675                window,
2676                cx,
2677            );
2678        }
2679
2680        if canceled {
2681            cx.emit(ThreadEvent::CompletionCanceled);
2682
2683            // When canceled, we always want to insert the checkpoint.
2684            // (We skip over finalize_pending_checkpoint, because it
2685            // would conclude we didn't have anything to insert here.)
2686            if let Some(checkpoint) = self.pending_checkpoint.take() {
2687                self.insert_checkpoint(checkpoint, cx);
2688            }
2689        } else {
2690            self.finalize_pending_checkpoint(cx);
2691        }
2692
2693        canceled
2694    }
2695
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2703
    /// Returns the thread-level feedback rating, if one has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2707
    /// Returns the feedback rating recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2711
2712    pub fn report_message_feedback(
2713        &mut self,
2714        message_id: MessageId,
2715        feedback: ThreadFeedback,
2716        cx: &mut Context<Self>,
2717    ) -> Task<Result<()>> {
2718        if self.message_feedback.get(&message_id) == Some(&feedback) {
2719            return Task::ready(Ok(()));
2720        }
2721
2722        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2723        let serialized_thread = self.serialize(cx);
2724        let thread_id = self.id().clone();
2725        let client = self.project.read(cx).client();
2726
2727        let enabled_tool_names: Vec<String> = self
2728            .profile
2729            .enabled_tools(cx)
2730            .iter()
2731            .map(|(name, _)| name.clone().into())
2732            .collect();
2733
2734        self.message_feedback.insert(message_id, feedback);
2735
2736        cx.notify();
2737
2738        let message_content = self
2739            .message(message_id)
2740            .map(|msg| msg.to_string())
2741            .unwrap_or_default();
2742
2743        cx.background_spawn(async move {
2744            let final_project_snapshot = final_project_snapshot.await;
2745            let serialized_thread = serialized_thread.await?;
2746            let thread_data =
2747                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2748
2749            let rating = match feedback {
2750                ThreadFeedback::Positive => "positive",
2751                ThreadFeedback::Negative => "negative",
2752            };
2753            telemetry::event!(
2754                "Assistant Thread Rated",
2755                rating,
2756                thread_id,
2757                enabled_tool_names,
2758                message_id = message_id.0,
2759                message_content,
2760                thread_data,
2761                final_project_snapshot
2762            );
2763            client.telemetry().flush_events().await;
2764
2765            Ok(())
2766        })
2767    }
2768
2769    pub fn report_feedback(
2770        &mut self,
2771        feedback: ThreadFeedback,
2772        cx: &mut Context<Self>,
2773    ) -> Task<Result<()>> {
2774        let last_assistant_message_id = self
2775            .messages
2776            .iter()
2777            .rev()
2778            .find(|msg| msg.role == Role::Assistant)
2779            .map(|msg| msg.id);
2780
2781        if let Some(message_id) = last_assistant_message_id {
2782            self.report_message_feedback(message_id, feedback, cx)
2783        } else {
2784            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2785            let serialized_thread = self.serialize(cx);
2786            let thread_id = self.id().clone();
2787            let client = self.project.read(cx).client();
2788            self.feedback = Some(feedback);
2789            cx.notify();
2790
2791            cx.background_spawn(async move {
2792                let final_project_snapshot = final_project_snapshot.await;
2793                let serialized_thread = serialized_thread.await?;
2794                let thread_data = serde_json::to_value(serialized_thread)
2795                    .unwrap_or_else(|_| serde_json::Value::Null);
2796
2797                let rating = match feedback {
2798                    ThreadFeedback::Positive => "positive",
2799                    ThreadFeedback::Negative => "negative",
2800                };
2801                telemetry::event!(
2802                    "Assistant Thread Rated",
2803                    rating,
2804                    thread_id,
2805                    thread_data,
2806                    final_project_snapshot
2807                );
2808                client.telemetry().flush_events().await;
2809
2810                Ok(())
2811            })
2812        }
2813    }
2814
2815    /// Create a snapshot of the current project state including git information and unsaved buffers.
2816    fn project_snapshot(
2817        project: Entity<Project>,
2818        cx: &mut Context<Self>,
2819    ) -> Task<Arc<ProjectSnapshot>> {
2820        let git_store = project.read(cx).git_store().clone();
2821        let worktree_snapshots: Vec<_> = project
2822            .read(cx)
2823            .visible_worktrees(cx)
2824            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2825            .collect();
2826
2827        cx.spawn(async move |_, cx| {
2828            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2829
2830            let mut unsaved_buffers = Vec::new();
2831            cx.update(|app_cx| {
2832                let buffer_store = project.read(app_cx).buffer_store();
2833                for buffer_handle in buffer_store.read(app_cx).buffers() {
2834                    let buffer = buffer_handle.read(app_cx);
2835                    if buffer.is_dirty() {
2836                        if let Some(file) = buffer.file() {
2837                            let path = file.path().to_string_lossy().to_string();
2838                            unsaved_buffers.push(path);
2839                        }
2840                    }
2841                }
2842            })
2843            .ok();
2844
2845            Arc::new(ProjectSnapshot {
2846                worktree_snapshots,
2847                unsaved_buffer_paths: unsaved_buffers,
2848                timestamp: Utc::now(),
2849            })
2850        })
2851    }
2852
    /// Builds a [`WorktreeSnapshot`] for one worktree: its absolute path plus
    /// git metadata (remote URL, HEAD sha, branch, diff) when a local
    /// repository covers the worktree. Every failure degrades to `None`
    /// fields rather than an error.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app is gone, return an empty snapshot rather than failing.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree (if any), then ask
            // it — via a repository job — for remote/HEAD/branch/diff info.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend to query;
                            // otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested layers: missing repo, dropped entity, or a
            // failed job all collapse to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2930
    /// Renders the whole thread as a Markdown document: summary heading,
    /// one `##` section per message, then any tool uses/results attached to
    /// each message.
    ///
    /// # Errors
    /// Returns an error if formatting into the buffer or JSON-serializing a
    /// tool input/output fails.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            // Context (e.g. attached files) precedes the message body.
            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images can't be inlined here; note their count instead.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Tool invocations: name/id header plus pretty-printed JSON input.
            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        // All writes were UTF-8 strings, so lossy conversion is effectively exact.
        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
3013
3014    pub fn keep_edits_in_range(
3015        &mut self,
3016        buffer: Entity<language::Buffer>,
3017        buffer_range: Range<language::Anchor>,
3018        cx: &mut Context<Self>,
3019    ) {
3020        self.action_log.update(cx, |action_log, cx| {
3021            action_log.keep_edits_in_range(buffer, buffer_range, cx)
3022        });
3023    }
3024
3025    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
3026        self.action_log
3027            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
3028    }
3029
3030    pub fn reject_edits_in_ranges(
3031        &mut self,
3032        buffer: Entity<language::Buffer>,
3033        buffer_ranges: Vec<Range<language::Anchor>>,
3034        cx: &mut Context<Self>,
3035    ) -> Task<Result<()>> {
3036        self.action_log.update(cx, |action_log, cx| {
3037            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
3038        })
3039    }
3040
    /// Returns the action log entity associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3044
    /// Returns the project entity this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3048
3049    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
3050        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
3051            return;
3052        }
3053
3054        let now = Instant::now();
3055        if let Some(last) = self.last_auto_capture_at {
3056            if now.duration_since(last).as_secs() < 10 {
3057                return;
3058            }
3059        }
3060
3061        self.last_auto_capture_at = Some(now);
3062
3063        let thread_id = self.id().clone();
3064        let github_login = self
3065            .project
3066            .read(cx)
3067            .user_store()
3068            .read(cx)
3069            .current_user()
3070            .map(|user| user.github_login.clone());
3071        let client = self.project.read(cx).client();
3072        let serialize_task = self.serialize(cx);
3073
3074        cx.background_executor()
3075            .spawn(async move {
3076                if let Ok(serialized_thread) = serialize_task.await {
3077                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
3078                        telemetry::event!(
3079                            "Agent Thread Auto-Captured",
3080                            thread_id = thread_id.to_string(),
3081                            thread_data = thread_data,
3082                            auto_capture_reason = "tracked_user",
3083                            github_login = github_login
3084                        );
3085
3086                        client.telemetry().flush_events().await;
3087                    }
3088                }
3089            })
3090            .detach();
3091    }
3092
    /// Returns the token usage accumulated across this thread's requests.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3096
3097    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3098        let Some(model) = self.configured_model.as_ref() else {
3099            return TotalTokenUsage::default();
3100        };
3101
3102        let max = model
3103            .model
3104            .max_token_count_for_mode(self.completion_mode().into());
3105
3106        let index = self
3107            .messages
3108            .iter()
3109            .position(|msg| msg.id == message_id)
3110            .unwrap_or(0);
3111
3112        if index == 0 {
3113            return TotalTokenUsage { total: 0, max };
3114        }
3115
3116        let token_usage = &self
3117            .request_token_usage
3118            .get(index - 1)
3119            .cloned()
3120            .unwrap_or_default();
3121
3122        TotalTokenUsage {
3123            total: token_usage.total_tokens(),
3124            max,
3125        }
3126    }
3127
3128    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3129        let model = self.configured_model.as_ref()?;
3130
3131        let max = model
3132            .model
3133            .max_token_count_for_mode(self.completion_mode().into());
3134
3135        if let Some(exceeded_error) = &self.exceeded_window_error {
3136            if model.model.id() == exceeded_error.model_id {
3137                return Some(TotalTokenUsage {
3138                    total: exceeded_error.token_count,
3139                    max,
3140                });
3141            }
3142        }
3143
3144        let total = self
3145            .token_usage_at_last_message()
3146            .unwrap_or_default()
3147            .total_tokens();
3148
3149        Some(TotalTokenUsage { total, max })
3150    }
3151
3152    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3153        self.request_token_usage
3154            .get(self.messages.len().saturating_sub(1))
3155            .or_else(|| self.request_token_usage.last())
3156            .cloned()
3157    }
3158
3159    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3160        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3161        self.request_token_usage
3162            .resize(self.messages.len(), placeholder);
3163
3164        if let Some(last) = self.request_token_usage.last_mut() {
3165            *last = token_usage;
3166        }
3167    }
3168
3169    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3170        self.project.update(cx, |project, cx| {
3171            project.user_store().update(cx, |user_store, cx| {
3172                user_store.update_model_request_usage(
3173                    ModelRequestUsage(RequestUsage {
3174                        amount: amount as i32,
3175                        limit,
3176                    }),
3177                    cx,
3178                )
3179            })
3180        });
3181    }
3182
3183    pub fn deny_tool_use(
3184        &mut self,
3185        tool_use_id: LanguageModelToolUseId,
3186        tool_name: Arc<str>,
3187        window: Option<AnyWindowHandle>,
3188        cx: &mut Context<Self>,
3189    ) {
3190        let err = Err(anyhow::anyhow!(
3191            "Permission to run tool action denied by user"
3192        ));
3193
3194        self.tool_use.insert_tool_output(
3195            tool_use_id.clone(),
3196            tool_name,
3197            err,
3198            self.configured_model.as_ref(),
3199            self.completion_mode,
3200        );
3201        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3202    }
3203}
3204
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The model request limit for the user's current plan was reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and a detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3217
/// Events emitted by a `Thread` while streaming completions and running tools.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// A completion event was received and applied.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" content streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use was streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use with no corresponding result
    /// (NOTE(review): name-derived; confirm at the emit site).
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        /// The raw JSON string that failed to deserialize as tool input.
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// The tool use limit was reached.
    ToolUseLimitReached,
    /// Editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
    /// The thread switched to a different agent profile.
    ProfileChanged,
    /// Retry attempts for a failed completion were exhausted.
    RetriesFailed {
        message: SharedString,
    },
}
3265
3266impl EventEmitter<ThreadEvent> for Thread {}
3267
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    /// Identifier used to tell pending completions apart.
    id: usize,
    /// The completion's position/state in the sending queue.
    queue_state: QueueState,
    /// Keeps the spawned completion task alive for the lifetime of this entry.
    _task: Task<()>,
}
3273
3274#[cfg(test)]
3275mod tests {
3276    use super::*;
3277    use crate::{
3278        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3279    };
3280
    // Test-specific constants
    // Retry delay, in seconds, for simulated rate-limit responses —
    // presumably consumed by retry tests later in this module; verify usage.
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3283    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3284    use assistant_tool::ToolRegistry;
3285    use assistant_tools;
3286    use futures::StreamExt;
3287    use futures::future::BoxFuture;
3288    use futures::stream::BoxStream;
3289    use gpui::TestAppContext;
3290    use http_client;
3291    use indoc::indoc;
3292    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3293    use language_model::{
3294        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3295        LanguageModelProviderName, LanguageModelToolChoice,
3296    };
3297    use parking_lot::Mutex;
3298    use project::{FakeFs, Project};
3299    use prompt_store::PromptBuilder;
3300    use serde_json::json;
3301    use settings::{Settings, SettingsStore};
3302    use std::sync::Arc;
3303    use std::time::Duration;
3304    use theme::ThemeSettings;
3305    use util::path;
3306    use workspace::Workspace;
3307
    // Verifies that a file attached as context is embedded into the user
    // message's context block and rendered into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two messages: a leading non-user message plus ours — presumably the
        // system prompt occupies index 0; verify against to_completion_request.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3384
    // Verifies that each new user message only embeds context items that were
    // not already attached to an earlier message, and that passing an earlier
    // message id re-offers contexts attached from that message onward.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Passing message2_id re-offers contexts from that message onward:
        // file1 stays excluded, file2/file3 come back, and file4 is brand new.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3540
    // A user message inserted without any attached context should produce an
    // empty context block and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3618
    // After a context buffer is edited, flushing notifications should insert
    // exactly one `project_notifications` tool result warning about the stale
    // file, and a second flush must not duplicate it.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check for the stale buffer warning
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]

        These files have changed since the last read:
        - code.rs
        "};
        assert_eq!(notification_content, expected_content);

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3732
3733    fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3734        thread
3735            .messages()
3736            .flat_map(|message| {
3737                thread
3738                    .tool_results_for_message(message.id)
3739                    .into_iter()
3740                    .filter(|result| result.tool_name == tool_name.into())
3741                    .cloned()
3742                    .collect::<Vec<_>>()
3743            })
3744            .collect()
3745    }
3746
3747    #[gpui::test]
3748    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3749        init_test_settings(cx);
3750
3751        let project = create_test_project(
3752            cx,
3753            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3754        )
3755        .await;
3756
3757        let (_workspace, thread_store, thread, _context_store, _model) =
3758            setup_test_environment(cx, project.clone()).await;
3759
3760        // Check that we are starting with the default profile
3761        let profile = cx.read(|cx| thread.read(cx).profile.clone());
3762        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3763        assert_eq!(
3764            profile,
3765            AgentProfile::new(AgentProfileId::default(), tool_set)
3766        );
3767    }
3768
    // Round-trips a thread through serialize/deserialize and checks that the
    // default profile survives the trip.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        // The deserialized thread should reconstruct the same default profile.
        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3811
    // `model_parameters` in the agent settings should apply a temperature
    // override when the provider/model filters match the configured model,
    // and be ignored when they don't.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // Provider mismatch: the override must not apply.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3905
    // Walks the summary state machine: Pending -> Generating -> Ready, and
    // verifies when `set_summary` is and isn't allowed to take effect.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3990
3991    #[gpui::test]
3992    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3993        init_test_settings(cx);
3994
3995        let project = create_test_project(cx, json!({})).await;
3996
3997        let (_, _thread_store, thread, _context_store, model) =
3998            setup_test_environment(cx, project.clone()).await;
3999
4000        test_summarize_error(&model, &thread, cx);
4001
4002        // Now we should be able to set a summary
4003        thread.update(cx, |thread, cx| {
4004            thread.set_summary("Brief Intro", cx);
4005        });
4006
4007        thread.read_with(cx, |thread, _| {
4008            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4009            assert_eq!(thread.summary().or_default(), "Brief Intro");
4010        });
4011    }
4012
4013    #[gpui::test]
4014    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
4015        init_test_settings(cx);
4016
4017        let project = create_test_project(cx, json!({})).await;
4018
4019        let (_, _thread_store, thread, _context_store, model) =
4020            setup_test_environment(cx, project.clone()).await;
4021
4022        test_summarize_error(&model, &thread, cx);
4023
4024        // Sending another message should not trigger another summarize request
4025        thread.update(cx, |thread, cx| {
4026            thread.insert_user_message(
4027                "How are you?",
4028                ContextLoadResult::default(),
4029                None,
4030                vec![],
4031                cx,
4032            );
4033            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4034        });
4035
4036        let fake_model = model.as_fake();
4037        simulate_successful_response(&fake_model, cx);
4038
4039        thread.read_with(cx, |thread, _| {
4040            // State is still Error, not Generating
4041            assert!(matches!(thread.summary(), ThreadSummary::Error));
4042        });
4043
4044        // But the summarize request can be invoked manually
4045        thread.update(cx, |thread, cx| {
4046            thread.summarize(cx);
4047        });
4048
4049        thread.read_with(cx, |thread, _| {
4050            assert!(matches!(thread.summary(), ThreadSummary::Generating));
4051        });
4052
4053        cx.run_until_parked();
4054        fake_model.stream_last_completion_response("A successful summary");
4055        fake_model.end_last_completion_stream();
4056        cx.run_until_parked();
4057
4058        thread.read_with(cx, |thread, _| {
4059            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4060            assert_eq!(thread.summary().or_default(), "A successful summary");
4061        });
4062    }
4063
    // Which failure `ErrorInjector` should surface from `stream_completion`.
    // Both variants are treated as retryable by the thread's retry logic
    // (see the retry tests below).
    enum TestError {
        // Provider reports it is overloaded.
        Overloaded,
        // Provider reports an internal server error.
        InternalServerError,
    }
4069
    // Test double that delegates all metadata queries to an inner
    // `FakeLanguageModel` but always fails `stream_completion` with the
    // configured error.
    struct ErrorInjector {
        // Wrapped fake model used for ids, names, token counts, etc.
        inner: Arc<FakeLanguageModel>,
        // The error every completion attempt will produce.
        error_type: TestError,
    }
4074
4075    impl ErrorInjector {
4076        fn new(error_type: TestError) -> Self {
4077            Self {
4078                inner: Arc::new(FakeLanguageModel::default()),
4079                error_type,
4080            }
4081        }
4082    }
4083
4084    impl LanguageModel for ErrorInjector {
4085        fn id(&self) -> LanguageModelId {
4086            self.inner.id()
4087        }
4088
4089        fn name(&self) -> LanguageModelName {
4090            self.inner.name()
4091        }
4092
4093        fn provider_id(&self) -> LanguageModelProviderId {
4094            self.inner.provider_id()
4095        }
4096
4097        fn provider_name(&self) -> LanguageModelProviderName {
4098            self.inner.provider_name()
4099        }
4100
4101        fn supports_tools(&self) -> bool {
4102            self.inner.supports_tools()
4103        }
4104
4105        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4106            self.inner.supports_tool_choice(choice)
4107        }
4108
4109        fn supports_images(&self) -> bool {
4110            self.inner.supports_images()
4111        }
4112
4113        fn telemetry_id(&self) -> String {
4114            self.inner.telemetry_id()
4115        }
4116
4117        fn max_token_count(&self) -> u64 {
4118            self.inner.max_token_count()
4119        }
4120
4121        fn count_tokens(
4122            &self,
4123            request: LanguageModelRequest,
4124            cx: &App,
4125        ) -> BoxFuture<'static, Result<u64>> {
4126            self.inner.count_tokens(request, cx)
4127        }
4128
4129        fn stream_completion(
4130            &self,
4131            _request: LanguageModelRequest,
4132            _cx: &AsyncApp,
4133        ) -> BoxFuture<
4134            'static,
4135            Result<
4136                BoxStream<
4137                    'static,
4138                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4139                >,
4140                LanguageModelCompletionError,
4141            >,
4142        > {
4143            let error = match self.error_type {
4144                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
4145                    provider: self.provider_name(),
4146                    retry_after: None,
4147                },
4148                TestError::InternalServerError => {
4149                    LanguageModelCompletionError::ApiInternalServerError {
4150                        provider: self.provider_name(),
4151                        message: "I'm a teapot orbiting the sun".to_string(),
4152                    }
4153                }
4154            };
4155            async move {
4156                let stream = futures::stream::once(async move { Err(error) });
4157                Ok(stream.boxed())
4158            }
4159            .boxed()
4160        }
4161
4162        fn as_fake(&self) -> &FakeLanguageModel {
4163            &self.inner
4164        }
4165    }
4166
    // Verifies that a provider "overloaded" error arms the retry machinery:
    // retry_state is set to attempt 1 of MAX_RETRY_ATTEMPTS and a ui-only
    // System message describing the retry is inserted into the transcript.
    #[gpui::test]
    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error on every request
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        // Let the failed request and the retry scheduling settle.
        cx.run_until_parked();

        // The first failure should have armed retry_state at attempt 1.
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have default max attempts"
            );
        });

        // Check that a retry message was added
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("overloaded")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message"
            );
        });

        // Exactly one "Retrying … seconds" notice should exist at this point.
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4239
    // Same shape as test_retry_on_overloaded_error, but for an internal
    // server error: retry_state is armed and the ui-only retry notice
    // includes the failing provider's name.
    #[gpui::test]
    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns internal server error
        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        cx.run_until_parked();

        // Check retry state on thread
        thread.read_with(cx, |thread, _| {
            assert!(thread.retry_state.is_some(), "Should have retry state");
            let retry_state = thread.retry_state.as_ref().unwrap();
            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
            assert_eq!(
                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
                "Should have correct max attempts"
            );
        });

        // Check that a retry message was added with provider name
        // ("Fake" — presumably FakeLanguageModel's provider name; confirm).
        thread.read_with(cx, |thread, _| {
            let mut messages = thread.messages();
            assert!(
                messages.any(|msg| {
                    msg.role == Role::System
                        && msg.ui_only
                        && msg.segments.iter().any(|seg| {
                            if let MessageSegment::Text(text) = seg {
                                text.contains("internal")
                                    && text.contains("Fake")
                                    && text
                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
                            } else {
                                false
                            }
                        })
                }),
                "Should have added a system retry message with provider name"
            );
        });

        // Count retry messages ("Retrying … seconds" ui-only notices).
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        assert_eq!(retry_count, 1, "Should have one retry message");
    }
4315
4316    #[gpui::test]
4317    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4318        init_test_settings(cx);
4319
4320        let project = create_test_project(cx, json!({})).await;
4321        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4322
4323        // Create model that returns overloaded error
4324        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4325
4326        // Insert a user message
4327        thread.update(cx, |thread, cx| {
4328            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4329        });
4330
4331        // Track retry events and completion count
4332        // Track completion events
4333        let completion_count = Arc::new(Mutex::new(0));
4334        let completion_count_clone = completion_count.clone();
4335
4336        let _subscription = thread.update(cx, |_, cx| {
4337            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4338                if let ThreadEvent::NewRequest = event {
4339                    *completion_count_clone.lock() += 1;
4340                }
4341            })
4342        });
4343
4344        // First attempt
4345        thread.update(cx, |thread, cx| {
4346            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4347        });
4348        cx.run_until_parked();
4349
4350        // Should have scheduled first retry - count retry messages
4351        let retry_count = thread.update(cx, |thread, _| {
4352            thread
4353                .messages
4354                .iter()
4355                .filter(|m| {
4356                    m.ui_only
4357                        && m.segments.iter().any(|s| {
4358                            if let MessageSegment::Text(text) = s {
4359                                text.contains("Retrying") && text.contains("seconds")
4360                            } else {
4361                                false
4362                            }
4363                        })
4364                })
4365                .count()
4366        });
4367        assert_eq!(retry_count, 1, "Should have scheduled first retry");
4368
4369        // Check retry state
4370        thread.read_with(cx, |thread, _| {
4371            assert!(thread.retry_state.is_some(), "Should have retry state");
4372            let retry_state = thread.retry_state.as_ref().unwrap();
4373            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4374        });
4375
4376        // Advance clock for first retry
4377        cx.executor()
4378            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4379        cx.run_until_parked();
4380
4381        // Should have scheduled second retry - count retry messages
4382        let retry_count = thread.update(cx, |thread, _| {
4383            thread
4384                .messages
4385                .iter()
4386                .filter(|m| {
4387                    m.ui_only
4388                        && m.segments.iter().any(|s| {
4389                            if let MessageSegment::Text(text) = s {
4390                                text.contains("Retrying") && text.contains("seconds")
4391                            } else {
4392                                false
4393                            }
4394                        })
4395                })
4396                .count()
4397        });
4398        assert_eq!(retry_count, 2, "Should have scheduled second retry");
4399
4400        // Check retry state updated
4401        thread.read_with(cx, |thread, _| {
4402            assert!(thread.retry_state.is_some(), "Should have retry state");
4403            let retry_state = thread.retry_state.as_ref().unwrap();
4404            assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4405            assert_eq!(
4406                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4407                "Should have correct max attempts"
4408            );
4409        });
4410
4411        // Advance clock for second retry (exponential backoff)
4412        cx.executor()
4413            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4414        cx.run_until_parked();
4415
4416        // Should have scheduled third retry
4417        // Count all retry messages now
4418        let retry_count = thread.update(cx, |thread, _| {
4419            thread
4420                .messages
4421                .iter()
4422                .filter(|m| {
4423                    m.ui_only
4424                        && m.segments.iter().any(|s| {
4425                            if let MessageSegment::Text(text) = s {
4426                                text.contains("Retrying") && text.contains("seconds")
4427                            } else {
4428                                false
4429                            }
4430                        })
4431                })
4432                .count()
4433        });
4434        assert_eq!(
4435            retry_count, MAX_RETRY_ATTEMPTS as usize,
4436            "Should have scheduled third retry"
4437        );
4438
4439        // Check retry state updated
4440        thread.read_with(cx, |thread, _| {
4441            assert!(thread.retry_state.is_some(), "Should have retry state");
4442            let retry_state = thread.retry_state.as_ref().unwrap();
4443            assert_eq!(
4444                retry_state.attempt, MAX_RETRY_ATTEMPTS,
4445                "Should be at max retry attempt"
4446            );
4447            assert_eq!(
4448                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4449                "Should have correct max attempts"
4450            );
4451        });
4452
4453        // Advance clock for third retry (exponential backoff)
4454        cx.executor()
4455            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4456        cx.run_until_parked();
4457
4458        // No more retries should be scheduled after clock was advanced.
4459        let retry_count = thread.update(cx, |thread, _| {
4460            thread
4461                .messages
4462                .iter()
4463                .filter(|m| {
4464                    m.ui_only
4465                        && m.segments.iter().any(|s| {
4466                            if let MessageSegment::Text(text) = s {
4467                                text.contains("Retrying") && text.contains("seconds")
4468                            } else {
4469                                false
4470                            }
4471                        })
4472                })
4473                .count()
4474        });
4475        assert_eq!(
4476            retry_count, MAX_RETRY_ATTEMPTS as usize,
4477            "Should not exceed max retries"
4478        );
4479
4480        // Final completion count should be initial + max retries
4481        assert_eq!(
4482            *completion_count.lock(),
4483            (MAX_RETRY_ATTEMPTS + 1) as usize,
4484            "Should have made initial + max retry attempts"
4485        );
4486    }
4487
    // Drives the retry loop to exhaustion: after MAX_RETRY_ATTEMPTS failed
    // retries, a RetriesFailed event is emitted and retry_state is cleared.
    #[gpui::test]
    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // Create model that returns overloaded error on every request
        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Record whether the RetriesFailed event is ever emitted.
        let retries_failed = Arc::new(Mutex::new(false));
        let retries_failed_clone = retries_failed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::RetriesFailed { .. } = event {
                    *retries_failed_clone.lock() = true;
                }
            })
        });

        // Start initial completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Advance through all retries; the delay doubles each attempt
        // (base, base*2, base*4, ...).
        for i in 0..MAX_RETRY_ATTEMPTS {
            let delay = if i == 0 {
                BASE_RETRY_DELAY_SECS
            } else {
                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
            };
            cx.executor().advance_clock(Duration::from_secs(delay));
            cx.run_until_parked();
        }

        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
        cx.executor()
            .advance_clock(Duration::from_secs(final_delay));
        cx.run_until_parked();

        // Count the ui-only "Retrying … seconds" notices (one per attempt).
        let retry_count = thread.update(cx, |thread, _| {
            thread
                .messages
                .iter()
                .filter(|m| {
                    m.ui_only
                        && m.segments.iter().any(|s| {
                            if let MessageSegment::Text(text) = s {
                                text.contains("Retrying") && text.contains("seconds")
                            } else {
                                false
                            }
                        })
                })
                .count()
        });

        // After max retries, should emit RetriesFailed event
        assert_eq!(
            retry_count, MAX_RETRY_ATTEMPTS as usize,
            "Should have attempted max retries"
        );
        assert!(
            *retries_failed.lock(),
            "Should emit RetriesFailed event after max retries exceeded"
        );

        // Retry state should be cleared
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.retry_state.is_none(),
                "Retry state should be cleared after max retries"
            );

            // Verify we have the expected number of retry messages
            let retry_messages = thread
                .messages
                .iter()
                .filter(|msg| msg.ui_only && msg.role == Role::System)
                .count();
            assert_eq!(
                retry_messages, MAX_RETRY_ATTEMPTS as usize,
                "Should have one retry message per attempt"
            );
        });
    }
4585
    // Exercises a fail-once-then-succeed completion: the first request errors
    // (scheduling a retry), the retry succeeds, and the ui-only retry notice
    // remains in the transcript afterwards.
    // NOTE(review): the test name says "removed" but the final assertion
    // checks the retry message still *exists* — confirm which is intended.
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // We'll use a wrapper to switch behavior after first failure:
        // fails with ServerOverloaded once, then delegates to the fake model.
        struct RetryTestModel {
            // Fake model used for metadata and for the successful retry.
            inner: Arc<FakeLanguageModel>,
            // Set to true after the first (failing) stream_completion call.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for RetryTestModel {
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track when the retried completion streams successfully.
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion (this first attempt fails and schedules a retry).
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Stream some successful content
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.stream_completion_response(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4754
4755    #[gpui::test]
4756    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4757        init_test_settings(cx);
4758
4759        let project = create_test_project(cx, json!({})).await;
4760        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4761
4762        // Create a model that fails once then succeeds
4763        struct FailOnceModel {
4764            inner: Arc<FakeLanguageModel>,
4765            failed_once: Arc<Mutex<bool>>,
4766        }
4767
4768        impl LanguageModel for FailOnceModel {
4769            fn id(&self) -> LanguageModelId {
4770                self.inner.id()
4771            }
4772
4773            fn name(&self) -> LanguageModelName {
4774                self.inner.name()
4775            }
4776
4777            fn provider_id(&self) -> LanguageModelProviderId {
4778                self.inner.provider_id()
4779            }
4780
4781            fn provider_name(&self) -> LanguageModelProviderName {
4782                self.inner.provider_name()
4783            }
4784
4785            fn supports_tools(&self) -> bool {
4786                self.inner.supports_tools()
4787            }
4788
4789            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4790                self.inner.supports_tool_choice(choice)
4791            }
4792
4793            fn supports_images(&self) -> bool {
4794                self.inner.supports_images()
4795            }
4796
4797            fn telemetry_id(&self) -> String {
4798                self.inner.telemetry_id()
4799            }
4800
4801            fn max_token_count(&self) -> u64 {
4802                self.inner.max_token_count()
4803            }
4804
4805            fn count_tokens(
4806                &self,
4807                request: LanguageModelRequest,
4808                cx: &App,
4809            ) -> BoxFuture<'static, Result<u64>> {
4810                self.inner.count_tokens(request, cx)
4811            }
4812
4813            fn stream_completion(
4814                &self,
4815                request: LanguageModelRequest,
4816                cx: &AsyncApp,
4817            ) -> BoxFuture<
4818                'static,
4819                Result<
4820                    BoxStream<
4821                        'static,
4822                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4823                    >,
4824                    LanguageModelCompletionError,
4825                >,
4826            > {
4827                if !*self.failed_once.lock() {
4828                    *self.failed_once.lock() = true;
4829                    let provider = self.provider_name();
4830                    // Return error on first attempt
4831                    let stream = futures::stream::once(async move {
4832                        Err(LanguageModelCompletionError::ServerOverloaded {
4833                            provider,
4834                            retry_after: None,
4835                        })
4836                    });
4837                    async move { Ok(stream.boxed()) }.boxed()
4838                } else {
4839                    // Succeed on retry
4840                    self.inner.stream_completion(request, cx)
4841                }
4842            }
4843        }
4844
4845        let fail_once_model = Arc::new(FailOnceModel {
4846            inner: Arc::new(FakeLanguageModel::default()),
4847            failed_once: Arc::new(Mutex::new(false)),
4848        });
4849
4850        // Insert a user message
4851        thread.update(cx, |thread, cx| {
4852            thread.insert_user_message(
4853                "Test message",
4854                ContextLoadResult::default(),
4855                None,
4856                vec![],
4857                cx,
4858            );
4859        });
4860
4861        // Start completion with fail-once model
4862        thread.update(cx, |thread, cx| {
4863            thread.send_to_model(
4864                fail_once_model.clone(),
4865                CompletionIntent::UserPrompt,
4866                None,
4867                cx,
4868            );
4869        });
4870
4871        cx.run_until_parked();
4872
4873        // Verify retry state exists after first failure
4874        thread.read_with(cx, |thread, _| {
4875            assert!(
4876                thread.retry_state.is_some(),
4877                "Should have retry state after failure"
4878            );
4879        });
4880
4881        // Wait for retry delay
4882        cx.executor()
4883            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4884        cx.run_until_parked();
4885
4886        // The retry should now use our FailOnceModel which should succeed
4887        // We need to help the FakeLanguageModel complete the stream
4888        let inner_fake = fail_once_model.inner.clone();
4889
4890        // Wait a bit for the retry to start
4891        cx.run_until_parked();
4892
4893        // Check for pending completions and complete them
4894        if let Some(pending) = inner_fake.pending_completions().first() {
4895            inner_fake.stream_completion_response(pending, "Success!");
4896            inner_fake.end_completion_stream(pending);
4897        }
4898        cx.run_until_parked();
4899
4900        thread.read_with(cx, |thread, _| {
4901            assert!(
4902                thread.retry_state.is_none(),
4903                "Retry state should be cleared after successful completion"
4904            );
4905
4906            let has_assistant_message = thread
4907                .messages
4908                .iter()
4909                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4910            assert!(
4911                has_assistant_message,
4912                "Should have an assistant message after successful retry"
4913            );
4914        });
4915    }
4916
4917    #[gpui::test]
4918    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4919        init_test_settings(cx);
4920
4921        let project = create_test_project(cx, json!({})).await;
4922        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4923
4924        // Create a model that returns rate limit error with retry_after
4925        struct RateLimitModel {
4926            inner: Arc<FakeLanguageModel>,
4927        }
4928
4929        impl LanguageModel for RateLimitModel {
4930            fn id(&self) -> LanguageModelId {
4931                self.inner.id()
4932            }
4933
4934            fn name(&self) -> LanguageModelName {
4935                self.inner.name()
4936            }
4937
4938            fn provider_id(&self) -> LanguageModelProviderId {
4939                self.inner.provider_id()
4940            }
4941
4942            fn provider_name(&self) -> LanguageModelProviderName {
4943                self.inner.provider_name()
4944            }
4945
4946            fn supports_tools(&self) -> bool {
4947                self.inner.supports_tools()
4948            }
4949
4950            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4951                self.inner.supports_tool_choice(choice)
4952            }
4953
4954            fn supports_images(&self) -> bool {
4955                self.inner.supports_images()
4956            }
4957
4958            fn telemetry_id(&self) -> String {
4959                self.inner.telemetry_id()
4960            }
4961
4962            fn max_token_count(&self) -> u64 {
4963                self.inner.max_token_count()
4964            }
4965
4966            fn count_tokens(
4967                &self,
4968                request: LanguageModelRequest,
4969                cx: &App,
4970            ) -> BoxFuture<'static, Result<u64>> {
4971                self.inner.count_tokens(request, cx)
4972            }
4973
4974            fn stream_completion(
4975                &self,
4976                _request: LanguageModelRequest,
4977                _cx: &AsyncApp,
4978            ) -> BoxFuture<
4979                'static,
4980                Result<
4981                    BoxStream<
4982                        'static,
4983                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4984                    >,
4985                    LanguageModelCompletionError,
4986                >,
4987            > {
4988                let provider = self.provider_name();
4989                async move {
4990                    let stream = futures::stream::once(async move {
4991                        Err(LanguageModelCompletionError::RateLimitExceeded {
4992                            provider,
4993                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4994                        })
4995                    });
4996                    Ok(stream.boxed())
4997                }
4998                .boxed()
4999            }
5000
5001            fn as_fake(&self) -> &FakeLanguageModel {
5002                &self.inner
5003            }
5004        }
5005
5006        let model = Arc::new(RateLimitModel {
5007            inner: Arc::new(FakeLanguageModel::default()),
5008        });
5009
5010        // Insert a user message
5011        thread.update(cx, |thread, cx| {
5012            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5013        });
5014
5015        // Start completion
5016        thread.update(cx, |thread, cx| {
5017            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5018        });
5019
5020        cx.run_until_parked();
5021
5022        let retry_count = thread.update(cx, |thread, _| {
5023            thread
5024                .messages
5025                .iter()
5026                .filter(|m| {
5027                    m.ui_only
5028                        && m.segments.iter().any(|s| {
5029                            if let MessageSegment::Text(text) = s {
5030                                text.contains("rate limit exceeded")
5031                            } else {
5032                                false
5033                            }
5034                        })
5035                })
5036                .count()
5037        });
5038        assert_eq!(retry_count, 1, "Should have scheduled one retry");
5039
5040        thread.read_with(cx, |thread, _| {
5041            assert!(
5042                thread.retry_state.is_none(),
5043                "Rate limit errors should not set retry_state"
5044            );
5045        });
5046
5047        // Verify we have one retry message
5048        thread.read_with(cx, |thread, _| {
5049            let retry_messages = thread
5050                .messages
5051                .iter()
5052                .filter(|msg| {
5053                    msg.ui_only
5054                        && msg.segments.iter().any(|seg| {
5055                            if let MessageSegment::Text(text) = seg {
5056                                text.contains("rate limit exceeded")
5057                            } else {
5058                                false
5059                            }
5060                        })
5061                })
5062                .count();
5063            assert_eq!(
5064                retry_messages, 1,
5065                "Should have one rate limit retry message"
5066            );
5067        });
5068
5069        // Check that retry message doesn't include attempt count
5070        thread.read_with(cx, |thread, _| {
5071            let retry_message = thread
5072                .messages
5073                .iter()
5074                .find(|msg| msg.role == Role::System && msg.ui_only)
5075                .expect("Should have a retry message");
5076
5077            // Check that the message doesn't contain attempt count
5078            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5079                assert!(
5080                    !text.contains("attempt"),
5081                    "Rate limit retry message should not contain attempt count"
5082                );
5083                assert!(
5084                    text.contains(&format!(
5085                        "Retrying in {} seconds",
5086                        TEST_RATE_LIMIT_RETRY_SECS
5087                    )),
5088                    "Rate limit retry message should contain retry delay"
5089                );
5090            }
5091        });
5092    }
5093
5094    #[gpui::test]
5095    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5096        init_test_settings(cx);
5097
5098        let project = create_test_project(cx, json!({})).await;
5099        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5100
5101        // Insert a regular user message
5102        thread.update(cx, |thread, cx| {
5103            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5104        });
5105
5106        // Insert a UI-only message (like our retry notifications)
5107        thread.update(cx, |thread, cx| {
5108            let id = thread.next_message_id.post_inc();
5109            thread.messages.push(Message {
5110                id,
5111                role: Role::System,
5112                segments: vec![MessageSegment::Text(
5113                    "This is a UI-only message that should not be sent to the model".to_string(),
5114                )],
5115                loaded_context: LoadedContext::default(),
5116                creases: Vec::new(),
5117                is_hidden: true,
5118                ui_only: true,
5119            });
5120            cx.emit(ThreadEvent::MessageAdded(id));
5121        });
5122
5123        // Insert another regular message
5124        thread.update(cx, |thread, cx| {
5125            thread.insert_user_message(
5126                "How are you?",
5127                ContextLoadResult::default(),
5128                None,
5129                vec![],
5130                cx,
5131            );
5132        });
5133
5134        // Generate the completion request
5135        let request = thread.update(cx, |thread, cx| {
5136            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5137        });
5138
5139        // Verify that the request only contains non-UI-only messages
5140        // Should have system prompt + 2 user messages, but not the UI-only message
5141        let user_messages: Vec<_> = request
5142            .messages
5143            .iter()
5144            .filter(|msg| msg.role == Role::User)
5145            .collect();
5146        assert_eq!(
5147            user_messages.len(),
5148            2,
5149            "Should have exactly 2 user messages"
5150        );
5151
5152        // Verify the UI-only content is not present anywhere in the request
5153        let request_text = request
5154            .messages
5155            .iter()
5156            .flat_map(|msg| &msg.content)
5157            .filter_map(|content| match content {
5158                MessageContent::Text(text) => Some(text.as_str()),
5159                _ => None,
5160            })
5161            .collect::<String>();
5162
5163        assert!(
5164            !request_text.contains("UI-only message"),
5165            "UI-only message content should not be in the request"
5166        );
5167
5168        // Verify the thread still has all 3 messages (including UI-only)
5169        thread.read_with(cx, |thread, _| {
5170            assert_eq!(
5171                thread.messages().count(),
5172                3,
5173                "Thread should have 3 messages"
5174            );
5175            assert_eq!(
5176                thread.messages().filter(|m| m.ui_only).count(),
5177                1,
5178                "Thread should have 1 UI-only message"
5179            );
5180        });
5181
5182        // Verify that UI-only messages are not serialized
5183        let serialized = thread
5184            .update(cx, |thread, cx| thread.serialize(cx))
5185            .await
5186            .unwrap();
5187        assert_eq!(
5188            serialized.messages.len(),
5189            2,
5190            "Serialized thread should only have 2 messages (no UI-only)"
5191        );
5192    }
5193
5194    #[gpui::test]
5195    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5196        init_test_settings(cx);
5197
5198        let project = create_test_project(cx, json!({})).await;
5199        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5200
5201        // Create model that returns overloaded error
5202        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5203
5204        // Insert a user message
5205        thread.update(cx, |thread, cx| {
5206            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5207        });
5208
5209        // Start completion
5210        thread.update(cx, |thread, cx| {
5211            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5212        });
5213
5214        cx.run_until_parked();
5215
5216        // Verify retry was scheduled by checking for retry message
5217        let has_retry_message = thread.read_with(cx, |thread, _| {
5218            thread.messages.iter().any(|m| {
5219                m.ui_only
5220                    && m.segments.iter().any(|s| {
5221                        if let MessageSegment::Text(text) = s {
5222                            text.contains("Retrying") && text.contains("seconds")
5223                        } else {
5224                            false
5225                        }
5226                    })
5227            })
5228        });
5229        assert!(has_retry_message, "Should have scheduled a retry");
5230
5231        // Cancel the completion before the retry happens
5232        thread.update(cx, |thread, cx| {
5233            thread.cancel_last_completion(None, cx);
5234        });
5235
5236        cx.run_until_parked();
5237
5238        // The retry should not have happened - no pending completions
5239        let fake_model = model.as_fake();
5240        assert_eq!(
5241            fake_model.pending_completions().len(),
5242            0,
5243            "Should have no pending completions after cancellation"
5244        );
5245
5246        // Verify the retry was cancelled by checking retry state
5247        thread.read_with(cx, |thread, _| {
5248            if let Some(retry_state) = &thread.retry_state {
5249                panic!(
5250                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5251                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
5252                );
5253            }
5254        });
5255    }
5256
5257    fn test_summarize_error(
5258        model: &Arc<dyn LanguageModel>,
5259        thread: &Entity<Thread>,
5260        cx: &mut TestAppContext,
5261    ) {
5262        thread.update(cx, |thread, cx| {
5263            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
5264            thread.send_to_model(
5265                model.clone(),
5266                CompletionIntent::ThreadSummarization,
5267                None,
5268                cx,
5269            );
5270        });
5271
5272        let fake_model = model.as_fake();
5273        simulate_successful_response(&fake_model, cx);
5274
5275        thread.read_with(cx, |thread, _| {
5276            assert!(matches!(thread.summary(), ThreadSummary::Generating));
5277            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5278        });
5279
5280        // Simulate summary request ending
5281        cx.run_until_parked();
5282        fake_model.end_last_completion_stream();
5283        cx.run_until_parked();
5284
5285        // State is set to Error and default message
5286        thread.read_with(cx, |thread, _| {
5287            assert!(matches!(thread.summary(), ThreadSummary::Error));
5288            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
5289        });
5290    }
5291
    // Drives the fake model through a minimal successful turn: streams one
    // response chunk and closes the stream. The surrounding `run_until_parked`
    // calls let the thread task observe the pending request before we respond,
    // and process the stream events after we finish.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
5298
    // Registers the globals (settings store, languages, agent settings, tools,
    // theme) that the thread tests depend on. Called once at the start of each
    // test; the SettingsStore global is installed first, since the subsequent
    // `register`/`init_settings` calls presumably read it — keep it first.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // The built-in tools take an HTTP client; a fake client that
            // always answers 200 keeps these tests offline.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5322
5323    // Helper to create a test project with test files
5324    async fn create_test_project(
5325        cx: &mut TestAppContext,
5326        files: serde_json::Value,
5327    ) -> Entity<Project> {
5328        let fs = FakeFs::new(cx.executor());
5329        fs.insert_tree(path!("/test"), files).await;
5330        Project::test(fs, [path!("/test").as_ref()], cx).await
5331    }
5332
    // Builds the standard fixture for thread tests: a test workspace, a
    // `ThreadStore` with an empty tool working set, one fresh thread, a
    // context store, and a `FakeLanguageModel` registered as both the default
    // completion model and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // Thread store with no prompt store (None) and a default prompt builder.
        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both roles so regular sends and
        // summarization requests are routed to it.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
5387
5388    async fn add_file_to_context(
5389        project: &Entity<Project>,
5390        context_store: &Entity<ContextStore>,
5391        path: &str,
5392        cx: &mut TestAppContext,
5393    ) -> Result<Entity<language::Buffer>> {
5394        let buffer_path = project
5395            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5396            .unwrap();
5397
5398        let buffer = project
5399            .update(cx, |project, cx| {
5400                project.open_buffer(buffer_path.clone(), cx)
5401            })
5402            .await
5403            .unwrap();
5404
5405        context_store.update(cx, |context_store, cx| {
5406            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5407        });
5408
5409        Ok(buffer)
5410    }
5411}