thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
  12use anyhow::{Result, anyhow};
  13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  14use chrono::{DateTime, Utc};
  15use client::{ModelRequestUsage, RequestUsage};
  16use collections::HashMap;
  17use feature_flags::{self, FeatureFlagAppExt};
  18use futures::{FutureExt, StreamExt as _, future::Shared};
  19use git::repository::DiffType;
  20use gpui::{
  21    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  22    WeakEntity, Window,
  23};
  24use language_model::{
  25    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  27    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
  28    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
  29    PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
  30};
  31use postage::stream::Stream as _;
  32use project::{
  33    Project,
  34    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  35};
  36use prompt_store::{ModelContext, PromptBuilder};
  37use proto::Plan;
  38use schemars::JsonSchema;
  39use serde::{Deserialize, Serialize};
  40use settings::Settings;
  41use std::{
  42    io::Write,
  43    ops::Range,
  44    sync::Arc,
  45    time::{Duration, Instant},
  46};
  47use thiserror::Error;
  48use util::{ResultExt as _, debug_panic, post_inc};
  49use uuid::Uuid;
  50use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  51
/// Maximum number of times a failed completion request is retried.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base retry delay in seconds — presumably scaled per attempt; TODO confirm backoff strategy.
const BASE_RETRY_DELAY_SECS: u64 = 5;
  54
/// Unique identifier for a [`Thread`] — typically a freshly generated UUID
/// string (see [`ThreadId::new`]), though any string can be converted into one.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  59
  60impl ThreadId {
  61    pub fn new() -> Self {
  62        Self(Uuid::new_v4().to_string().into())
  63    }
  64}
  65
  66impl std::fmt::Display for ThreadId {
  67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  68        write!(f, "{}", self.0)
  69    }
  70}
  71
  72impl From<&str> for ThreadId {
  73    fn from(value: &str) -> Self {
  74        Self(value.into())
  75    }
  76}
  77
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Backed by a UUID string generated in [`PromptId::new`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  83
  84impl PromptId {
  85    pub fn new() -> Self {
  86        Self(Uuid::new_v4().to_string().into())
  87    }
  88}
  89
  90impl std::fmt::Display for PromptId {
  91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  92        write!(f, "{}", self.0)
  93    }
  94}
  95
/// Monotonically increasing identifier of a message within a single thread
/// (allocated via `post_inc` from `Thread::next_message_id`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  98
impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// Returns the raw index value of this id.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
 108
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 118
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, and redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to recreate context pills when editing this message later.
    pub creases: Vec<MessageCrease>,
    /// Presumably hides the message from the conversation view — TODO confirm against UI code.
    pub is_hidden: bool,
    /// Whether the message exists only for display; such messages are not
    /// persisted (see `Thread::deserialize`).
    pub ui_only: bool,
}
 130
 131impl Message {
 132    /// Returns whether the message contains any meaningful text that should be displayed
 133    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 134    pub fn should_display_content(&self) -> bool {
 135        self.segments.iter().all(|segment| segment.should_display())
 136    }
 137
 138    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 139        if let Some(MessageSegment::Thinking {
 140            text: segment,
 141            signature: current_signature,
 142        }) = self.segments.last_mut()
 143        {
 144            if let Some(signature) = signature {
 145                *current_signature = Some(signature);
 146            }
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Thinking {
 150                text: text.to_string(),
 151                signature,
 152            });
 153        }
 154    }
 155
 156    pub fn push_redacted_thinking(&mut self, data: String) {
 157        self.segments.push(MessageSegment::RedactedThinking(data));
 158    }
 159
 160    pub fn push_text(&mut self, text: &str) {
 161        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 162            segment.push_str(text);
 163        } else {
 164            self.segments.push(MessageSegment::Text(text.to_string()));
 165        }
 166    }
 167
 168    pub fn to_string(&self) -> String {
 169        let mut result = String::new();
 170
 171        if !self.loaded_context.text.is_empty() {
 172            result.push_str(&self.loaded_context.text);
 173        }
 174
 175        for segment in &self.segments {
 176            match segment {
 177                MessageSegment::Text(text) => result.push_str(text),
 178                MessageSegment::Thinking { text, .. } => {
 179                    result.push_str("<think>\n");
 180                    result.push_str(text);
 181                    result.push_str("\n</think>");
 182                }
 183                MessageSegment::RedactedThinking(_) => {}
 184            }
 185        }
 186
 187        result
 188    }
 189}
 190
 191#[derive(Debug, Clone, PartialEq, Eq)]
 192pub enum MessageSegment {
 193    Text(String),
 194    Thinking {
 195        text: String,
 196        signature: Option<String>,
 197    },
 198    RedactedThinking(String),
 199}
 200
 201impl MessageSegment {
 202    pub fn should_display(&self) -> bool {
 203        match self {
 204            Self::Text(text) => text.is_empty(),
 205            Self::Thinking { text, .. } => text.is_empty(),
 206            Self::RedactedThinking(_) => false,
 207        }
 208    }
 209
 210    pub fn text(&self) -> Option<&str> {
 211        match self {
 212            MessageSegment::Text(text) => Some(text),
 213            _ => None,
 214        }
 215    }
 216}
 217
/// Point-in-time capture of the project's worktrees and unsaved-buffer state.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
 224
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository — TODO confirm.
    pub git_state: Option<GitState>,
}
 230
/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text at capture time, if any.
    pub diff: Option<String>,
}
 238
/// Associates a git checkpoint with the message it was taken before,
/// so the project can be rolled back to that point in the conversation.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 244
/// User feedback (thumbs up / thumbs down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 250
/// Progress of the most recent checkpoint restore: either still running
/// or failed with an error message. Cleared on success.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
 260
 261impl LastRestoreCheckpoint {
 262    pub fn message_id(&self) -> MessageId {
 263        match self {
 264            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 265            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 266        }
 267    }
 268}
 269
/// Lifecycle of the long-form thread summary: not yet generated, in flight,
/// or complete. `message_id` records the last message covered by the summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 282
 283impl DetailedSummaryState {
 284    fn text(&self) -> Option<SharedString> {
 285        if let Self::Generated { text, .. } = self {
 286            Some(text.clone())
 287        } else {
 288            None
 289        }
 290    }
 291}
 292
 293#[derive(Default, Debug)]
 294pub struct TotalTokenUsage {
 295    pub total: u64,
 296    pub max: u64,
 297}
 298
 299impl TotalTokenUsage {
 300    pub fn ratio(&self) -> TokenUsageRatio {
 301        #[cfg(debug_assertions)]
 302        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 303            .unwrap_or("0.8".to_string())
 304            .parse()
 305            .unwrap();
 306        #[cfg(not(debug_assertions))]
 307        let warning_threshold: f32 = 0.8;
 308
 309        // When the maximum is unknown because there is no selected model,
 310        // avoid showing the token limit warning.
 311        if self.max == 0 {
 312            TokenUsageRatio::Normal
 313        } else if self.total >= self.max {
 314            TokenUsageRatio::Exceeded
 315        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 316            TokenUsageRatio::Warning
 317        } else {
 318            TokenUsageRatio::Normal
 319        }
 320    }
 321
 322    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 323        TotalTokenUsage {
 324            total: self.total + tokens,
 325            max: self.max,
 326        }
 327    }
 328}
 329
 330#[derive(Debug, Default, PartialEq, Eq)]
 331pub enum TokenUsageRatio {
 332    #[default]
 333    Normal,
 334    Warning,
 335    Exceeded,
 336}
 337
/// Where a pending completion sits in the provider's request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
 344
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short, user-facing title of the thread.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages kept sorted by id; `message()` binary-searches on this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were taken before.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage per request — presumably parallel to request order; TODO confirm indexing.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Observer invoked with each request and its streamed events (for tests/telemetry — TODO confirm).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
 386
/// Tracks progress through the retry budget for a failed completion.
#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}
 393
/// Lifecycle of the short thread title: not yet requested, being generated,
/// ready, or failed.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}
 401
 402impl ThreadSummary {
 403    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 404
 405    pub fn or_default(&self) -> SharedString {
 406        self.unwrap_or(Self::DEFAULT)
 407    }
 408
 409    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 410        self.ready().unwrap_or_else(|| message.into())
 411    }
 412
 413    pub fn ready(&self) -> Option<SharedString> {
 414        match self {
 415            ThreadSummary::Ready(summary) => Some(summary.clone()),
 416            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 417        }
 418    }
 419}
 420
/// Records that a request blew past the model's context window, so the UI
/// can explain why the last message failed.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 428
 429impl Thread {
    /// Creates a new, empty thread using the registry's default model and the
    /// user's default agent profile.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state in the background so thread
                // creation isn't blocked on snapshotting.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 485
    /// Reconstructs a thread from its persisted form. Transient state (pending
    /// completions, checkpoints, feedback, context handles) is not persisted
    /// and starts fresh.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Resume numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded with the thread; fall back to the
        // registry's default when it is missing or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            // Live context handles are not persisted.
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 609
    /// Installs an observer invoked with each completion request and its
    /// streamed events — presumably used by tests/telemetry; TODO confirm callers.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 617
    /// Returns this thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns the currently active agent profile.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
 625
 626    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
 627        if &id != self.profile.id() {
 628            self.profile = AgentProfile::new(id, self.tools.clone());
 629            cx.emit(ThreadEvent::ProfileChanged);
 630        }
 631    }
 632
    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Returns when the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Returns a handle to the shared project context used for system prompts.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 652
 653    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 654        if self.configured_model.is_none() {
 655            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 656        }
 657        self.configured_model.clone()
 658    }
 659
    /// Returns the currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Replaces the configured model and re-renders observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Returns the thread's title state.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
 672
 673    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 674        let current_summary = match &self.summary {
 675            ThreadSummary::Pending | ThreadSummary::Generating => return,
 676            ThreadSummary::Ready(summary) => summary,
 677            ThreadSummary::Error => &ThreadSummary::DEFAULT,
 678        };
 679
 680        let mut new_summary = new_summary.into();
 681
 682        if new_summary.is_empty() {
 683            new_summary = ThreadSummary::DEFAULT;
 684        }
 685
 686        if current_summary != &new_summary {
 687            self.summary = ThreadSummary::Ready(new_summary);
 688            cx.emit(ThreadEvent::SummaryChanged);
 689        }
 690    }
 691
    /// Returns the active completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 699
 700    pub fn message(&self, id: MessageId) -> Option<&Message> {
 701        let index = self
 702            .messages
 703            .binary_search_by(|message| message.id.cmp(&id))
 704            .ok()?;
 705
 706        self.messages.get(index)
 707    }
 708
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// Returns `None` when no chunk timestamp has been recorded — presumably
    /// equivalent to `is_generating()` being false; TODO confirm the timestamp
    /// is cleared when generation ends.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // A gap between chunks longer than this many milliseconds is "stale".
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
 725
    /// Records that a streaming chunk just arrived (resets staleness timer).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Returns the queue position of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// Returns the working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 739
    /// Finds the pending tool use with the given id, if it is still running.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Iterates over pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Returns whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// Returns the git checkpoint associated with a message, if one was taken.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 761
    /// Rolls the project back to `checkpoint` and, on success, truncates the
    /// thread at the checkpoint's message. Progress/failure is surfaced via
    /// `last_restore_checkpoint` and `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can show progress.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can display it.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the restored-to message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 797
 798    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 799        let pending_checkpoint = if self.is_generating() {
 800            return;
 801        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 802            checkpoint
 803        } else {
 804            return;
 805        };
 806
 807        self.finalize_checkpoint(pending_checkpoint, cx);
 808    }
 809
    /// Compares the pending checkpoint against the repository's current state
    /// and keeps it only when the project actually changed (or when the
    /// comparison itself fails, to err on the side of keeping it).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed".
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If we can't take a final checkpoint, keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 844
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Returns the status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 855
 856    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 857        let Some(message_ix) = self
 858            .messages
 859            .iter()
 860            .rposition(|message| message.id == message_id)
 861        else {
 862            return;
 863        };
 864        for deleted_message in self.messages.drain(message_ix..) {
 865            self.checkpoints_by_message.remove(&deleted_message.id);
 866        }
 867        cx.notify();
 868    }
 869
    /// Iterates over the context items attached to the message with `id`
    /// (empty when the message does not exist or carries no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 877
 878    pub fn is_turn_end(&self, ix: usize) -> bool {
 879        if self.messages.is_empty() {
 880            return false;
 881        }
 882
 883        if !self.is_generating() && ix == self.messages.len() - 1 {
 884            return true;
 885        }
 886
 887        let Some(message) = self.messages.get(ix) else {
 888            return false;
 889        };
 890
 891        if message.role != Role::Assistant {
 892            return false;
 893        }
 894
 895        self.messages
 896            .get(ix + 1)
 897            .and_then(|message| {
 898                self.message(message.id)
 899                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
 900            })
 901            .unwrap_or(false)
 902    }
 903
    /// Whether the tool-use limit has been reached for the current thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 907
 908    /// Returns whether all of the tool uses have finished running.
 909    pub fn all_tools_finished(&self) -> bool {
 910        // If the only pending tool uses left are the ones with errors, then
 911        // that means that we've finished running all of the pending tools.
 912        self.tool_use
 913            .pending_tool_uses()
 914            .iter()
 915            .all(|pending_tool_use| pending_tool_use.status.is_error())
 916    }
 917
 918    /// Returns whether any pending tool uses may perform edits
 919    pub fn has_pending_edit_tool_uses(&self) -> bool {
 920        self.tool_use
 921            .pending_tool_uses()
 922            .iter()
 923            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 924            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 925    }
 926
 927    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 928        self.tool_use.tool_uses_for_message(id, cx)
 929    }
 930
 931    pub fn tool_results_for_message(
 932        &self,
 933        assistant_message_id: MessageId,
 934    ) -> Vec<&LanguageModelToolResult> {
 935        self.tool_use.tool_results_for_message(assistant_message_id)
 936    }
 937
 938    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 939        self.tool_use.tool_result(id)
 940    }
 941
 942    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 943        match &self.tool_use.tool_result(id)?.content {
 944            LanguageModelToolResultContent::Text(text) => Some(text),
 945            LanguageModelToolResultContent::Image(_) => {
 946                // TODO: We should display image
 947                None
 948            }
 949        }
 950    }
 951
 952    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 953        self.tool_use.tool_result_card(id).cloned()
 954    }
 955
 956    /// Return tools that are both enabled and supported by the model
 957    pub fn available_tools(
 958        &self,
 959        cx: &App,
 960        model: Arc<dyn LanguageModel>,
 961    ) -> Vec<LanguageModelRequestTool> {
 962        if model.supports_tools() {
 963            self.profile
 964                .enabled_tools(cx)
 965                .into_iter()
 966                .filter_map(|(name, tool)| {
 967                    // Skip tools that cannot be supported
 968                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 969                    Some(LanguageModelRequestTool {
 970                        name: name.into(),
 971                        description: tool.description(),
 972                        input_schema,
 973                    })
 974                })
 975                .collect()
 976        } else {
 977            Vec::default()
 978        }
 979    }
 980
 981    pub fn insert_user_message(
 982        &mut self,
 983        text: impl Into<String>,
 984        loaded_context: ContextLoadResult,
 985        git_checkpoint: Option<GitStoreCheckpoint>,
 986        creases: Vec<MessageCrease>,
 987        cx: &mut Context<Self>,
 988    ) -> MessageId {
 989        if !loaded_context.referenced_buffers.is_empty() {
 990            self.action_log.update(cx, |log, cx| {
 991                for buffer in loaded_context.referenced_buffers {
 992                    log.buffer_read(buffer, cx);
 993                }
 994            });
 995        }
 996
 997        let message_id = self.insert_message(
 998            Role::User,
 999            vec![MessageSegment::Text(text.into())],
1000            loaded_context.loaded_context,
1001            creases,
1002            false,
1003            cx,
1004        );
1005
1006        if let Some(git_checkpoint) = git_checkpoint {
1007            self.pending_checkpoint = Some(ThreadCheckpoint {
1008                message_id,
1009                git_checkpoint,
1010            });
1011        }
1012
1013        self.auto_capture_telemetry(cx);
1014
1015        message_id
1016    }
1017
1018    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1019        let id = self.insert_message(
1020            Role::User,
1021            vec![MessageSegment::Text("Continue where you left off".into())],
1022            LoadedContext::default(),
1023            vec![],
1024            true,
1025            cx,
1026        );
1027        self.pending_checkpoint = None;
1028
1029        id
1030    }
1031
1032    pub fn insert_assistant_message(
1033        &mut self,
1034        segments: Vec<MessageSegment>,
1035        cx: &mut Context<Self>,
1036    ) -> MessageId {
1037        self.insert_message(
1038            Role::Assistant,
1039            segments,
1040            LoadedContext::default(),
1041            Vec::new(),
1042            false,
1043            cx,
1044        )
1045    }
1046
1047    pub fn insert_message(
1048        &mut self,
1049        role: Role,
1050        segments: Vec<MessageSegment>,
1051        loaded_context: LoadedContext,
1052        creases: Vec<MessageCrease>,
1053        is_hidden: bool,
1054        cx: &mut Context<Self>,
1055    ) -> MessageId {
1056        let id = self.next_message_id.post_inc();
1057        self.messages.push(Message {
1058            id,
1059            role,
1060            segments,
1061            loaded_context,
1062            creases,
1063            is_hidden,
1064            ui_only: false,
1065        });
1066        self.touch_updated_at();
1067        cx.emit(ThreadEvent::MessageAdded(id));
1068        id
1069    }
1070
1071    pub fn edit_message(
1072        &mut self,
1073        id: MessageId,
1074        new_role: Role,
1075        new_segments: Vec<MessageSegment>,
1076        creases: Vec<MessageCrease>,
1077        loaded_context: Option<LoadedContext>,
1078        checkpoint: Option<GitStoreCheckpoint>,
1079        cx: &mut Context<Self>,
1080    ) -> bool {
1081        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1082            return false;
1083        };
1084        message.role = new_role;
1085        message.segments = new_segments;
1086        message.creases = creases;
1087        if let Some(context) = loaded_context {
1088            message.loaded_context = context;
1089        }
1090        if let Some(git_checkpoint) = checkpoint {
1091            self.checkpoints_by_message.insert(
1092                id,
1093                ThreadCheckpoint {
1094                    message_id: id,
1095                    git_checkpoint,
1096                },
1097            );
1098        }
1099        self.touch_updated_at();
1100        cx.emit(ThreadEvent::MessageEdited(id));
1101        true
1102    }
1103
1104    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1105        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1106            return false;
1107        };
1108        self.messages.remove(index);
1109        self.touch_updated_at();
1110        cx.emit(ThreadEvent::MessageDeleted(id));
1111        true
1112    }
1113
1114    /// Returns the representation of this [`Thread`] in a textual form.
1115    ///
1116    /// This is the representation we use when attaching a thread as context to another thread.
1117    pub fn text(&self) -> String {
1118        let mut text = String::new();
1119
1120        for message in &self.messages {
1121            text.push_str(match message.role {
1122                language_model::Role::User => "User:",
1123                language_model::Role::Assistant => "Agent:",
1124                language_model::Role::System => "System:",
1125            });
1126            text.push('\n');
1127
1128            for segment in &message.segments {
1129                match segment {
1130                    MessageSegment::Text(content) => text.push_str(content),
1131                    MessageSegment::Thinking { text: content, .. } => {
1132                        text.push_str(&format!("<think>{}</think>", content))
1133                    }
1134                    MessageSegment::RedactedThinking(_) => {}
1135                }
1136            }
1137            text.push('\n');
1138        }
1139
1140        text
1141    }
1142
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Resolves the initial project snapshot asynchronously, then reads the
    /// thread once to build a [`SerializedThread`] mirroring its messages,
    /// tool uses/results, token usage, and configured model.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot first; everything after is a synchronous read.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    // ui_only messages never reach the model and are not persisted.
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1229
    /// Number of agent turns still allowed; `send_to_model` is a no-op at zero.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1233
    /// Sets the number of turns the agent may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1237
1238    pub fn send_to_model(
1239        &mut self,
1240        model: Arc<dyn LanguageModel>,
1241        intent: CompletionIntent,
1242        window: Option<AnyWindowHandle>,
1243        cx: &mut Context<Self>,
1244    ) {
1245        if self.remaining_turns == 0 {
1246            return;
1247        }
1248
1249        self.remaining_turns -= 1;
1250
1251        self.flush_notifications(model.clone(), intent, cx);
1252
1253        let request = self.to_completion_request(model.clone(), intent, cx);
1254
1255        self.stream_completion(request, model, intent, window, cx);
1256    }
1257
1258    pub fn used_tools_since_last_user_message(&self) -> bool {
1259        for message in self.messages.iter().rev() {
1260            if self.tool_use.message_has_tool_results(message.id) {
1261                return true;
1262            } else if message.role == Role::User {
1263                return false;
1264            }
1265        }
1266
1267        false
1268    }
1269
    /// Builds the full [`LanguageModelRequest`] for this thread: the system
    /// prompt, the conversation messages interleaved with tool uses/results,
    /// the available tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // The system prompt template is given the names of the tools the
        // model will be able to call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the shared project context. On any
        // failure we surface an error to the UI and proceed without one.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message whose tool uses are all resolved;
        // it becomes the prompt-cache anchor after the loop.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Loaded context (files, etc.) is prepended to the message content.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses go on the assistant message; their results
            // are collected into a synthetic follow-up User message.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A still-pending tool use makes this message unsuitable
                    // as a cache anchor.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1433
1434    fn to_summarize_request(
1435        &self,
1436        model: &Arc<dyn LanguageModel>,
1437        intent: CompletionIntent,
1438        added_user_message: String,
1439        cx: &App,
1440    ) -> LanguageModelRequest {
1441        let mut request = LanguageModelRequest {
1442            thread_id: None,
1443            prompt_id: None,
1444            intent: Some(intent),
1445            mode: None,
1446            messages: vec![],
1447            tools: Vec::new(),
1448            tool_choice: None,
1449            stop: Vec::new(),
1450            temperature: AgentSettings::temperature_for_model(model, cx),
1451        };
1452
1453        for message in &self.messages {
1454            let mut request_message = LanguageModelRequestMessage {
1455                role: message.role,
1456                content: Vec::new(),
1457                cache: false,
1458            };
1459
1460            for segment in &message.segments {
1461                match segment {
1462                    MessageSegment::Text(text) => request_message
1463                        .content
1464                        .push(MessageContent::Text(text.clone())),
1465                    MessageSegment::Thinking { .. } => {}
1466                    MessageSegment::RedactedThinking(_) => {}
1467                }
1468            }
1469
1470            if request_message.content.is_empty() {
1471                continue;
1472            }
1473
1474            request.messages.push(request_message);
1475        }
1476
1477        request.messages.push(LanguageModelRequestMessage {
1478            role: Role::User,
1479            content: vec![MessageContent::Text(added_user_message)],
1480            cache: false,
1481        });
1482
1483        request
1484    }
1485
1486    /// Insert auto-generated notifications (if any) to the thread
1487    fn flush_notifications(
1488        &mut self,
1489        model: Arc<dyn LanguageModel>,
1490        intent: CompletionIntent,
1491        cx: &mut Context<Self>,
1492    ) {
1493        match intent {
1494            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1495                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1496                    cx.emit(ThreadEvent::ToolFinished {
1497                        tool_use_id: pending_tool_use.id.clone(),
1498                        pending_tool_use: Some(pending_tool_use),
1499                    });
1500                }
1501            }
1502            CompletionIntent::ThreadSummarization
1503            | CompletionIntent::ThreadContextSummarization
1504            | CompletionIntent::CreateFile
1505            | CompletionIntent::EditFile
1506            | CompletionIntent::InlineAssist
1507            | CompletionIntent::TerminalInlineAssist
1508            | CompletionIntent::GenerateGitCommitMessage => {}
1509        };
1510    }
1511
    /// Simulates a `project_notifications` tool call on the most recent
    /// assistant message when there are stale buffers the agent has not yet
    /// been notified about.
    ///
    /// Returns the pending tool use whose output was just inserted, or `None`
    /// when there is nothing to report, the tool is missing/disabled, or no
    /// assistant message exists to attach to.
    fn attach_tracked_files_state(
        &mut self,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Option<PendingToolUse> {
        let action_log = self.action_log.read(cx);

        // Bail out early when there is nothing new to report.
        action_log.unnotified_stale_buffers(cx).next()?;

        // Represent notification as a simulated `project_notifications` tool call
        let tool_name = Arc::from("project_notifications");
        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
            debug_panic!("`project_notifications` tool not found");
            return None;
        };

        // Respect the profile's tool enablement configuration.
        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
            return None;
        }

        let input = serde_json::json!({});
        let request = Arc::new(LanguageModelRequest::default()); // unused
        let window = None;
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model.clone(),
            window,
            cx,
        );

        // Synthetic id; unique per call because the message count advances.
        let tool_use_id =
            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));

        let tool_use = LanguageModelToolUse {
            id: tool_use_id.clone(),
            name: tool_name.clone(),
            raw_input: "{}".to_string(),
            input: serde_json::json!({}),
            is_input_complete: true,
        };

        // NOTE(review): blocks on the tool's output future; presumably this
        // tool completes quickly — confirm before relying on it elsewhere.
        let tool_output = cx.background_executor().block(tool_result.output);

        // Attach a project_notification tool call to the latest existing
        // Assistant message. We cannot create a new Assistant message
        // because thinking models require a `thinking` block that we
        // cannot mock. We cannot send a notification as a normal
        // (non-tool-use) User message because this distracts Agent
        // too much.
        let tool_message_id = self
            .messages
            .iter()
            .enumerate()
            .rfind(|(_, message)| message.role == Role::Assistant)
            .map(|(_, message)| message.id)?;

        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: self.last_prompt_id.clone(),
        };

        self.tool_use
            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            tool_output,
            self.configured_model.as_ref(),
        );

        pending_tool_use
    }
1589
1590    pub fn stream_completion(
1591        &mut self,
1592        request: LanguageModelRequest,
1593        model: Arc<dyn LanguageModel>,
1594        intent: CompletionIntent,
1595        window: Option<AnyWindowHandle>,
1596        cx: &mut Context<Self>,
1597    ) {
1598        self.tool_use_limit_reached = false;
1599
1600        let pending_completion_id = post_inc(&mut self.completion_count);
1601        let mut request_callback_parameters = if self.request_callback.is_some() {
1602            Some((request.clone(), Vec::new()))
1603        } else {
1604            None
1605        };
1606        let prompt_id = self.last_prompt_id.clone();
1607        let tool_use_metadata = ToolUseMetadata {
1608            model: model.clone(),
1609            thread_id: self.id.clone(),
1610            prompt_id: prompt_id.clone(),
1611        };
1612
1613        self.last_received_chunk_at = Some(Instant::now());
1614
1615        let task = cx.spawn(async move |thread, cx| {
1616            let stream_completion_future = model.stream_completion(request, &cx);
1617            let initial_token_usage =
1618                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1619            let stream_completion = async {
1620                let mut events = stream_completion_future.await?;
1621
1622                let mut stop_reason = StopReason::EndTurn;
1623                let mut current_token_usage = TokenUsage::default();
1624
1625                thread
1626                    .update(cx, |_thread, cx| {
1627                        cx.emit(ThreadEvent::NewRequest);
1628                    })
1629                    .ok();
1630
1631                let mut request_assistant_message_id = None;
1632
1633                while let Some(event) = events.next().await {
1634                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1635                        response_events
1636                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1637                    }
1638
1639                    thread.update(cx, |thread, cx| {
1640                        match event? {
1641                            LanguageModelCompletionEvent::StartMessage { .. } => {
1642                                request_assistant_message_id =
1643                                    Some(thread.insert_assistant_message(
1644                                        vec![MessageSegment::Text(String::new())],
1645                                        cx,
1646                                    ));
1647                            }
1648                            LanguageModelCompletionEvent::Stop(reason) => {
1649                                stop_reason = reason;
1650                            }
1651                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1652                                thread.update_token_usage_at_last_message(token_usage);
1653                                thread.cumulative_token_usage = thread.cumulative_token_usage
1654                                    + token_usage
1655                                    - current_token_usage;
1656                                current_token_usage = token_usage;
1657                            }
1658                            LanguageModelCompletionEvent::Text(chunk) => {
1659                                thread.received_chunk();
1660
1661                                cx.emit(ThreadEvent::ReceivedTextChunk);
1662                                if let Some(last_message) = thread.messages.last_mut() {
1663                                    if last_message.role == Role::Assistant
1664                                        && !thread.tool_use.has_tool_results(last_message.id)
1665                                    {
1666                                        last_message.push_text(&chunk);
1667                                        cx.emit(ThreadEvent::StreamedAssistantText(
1668                                            last_message.id,
1669                                            chunk,
1670                                        ));
1671                                    } else {
1672                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1673                                        // of a new Assistant response.
1674                                        //
1675                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1676                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1677                                        request_assistant_message_id =
1678                                            Some(thread.insert_assistant_message(
1679                                                vec![MessageSegment::Text(chunk.to_string())],
1680                                                cx,
1681                                            ));
1682                                    };
1683                                }
1684                            }
1685                            LanguageModelCompletionEvent::Thinking {
1686                                text: chunk,
1687                                signature,
1688                            } => {
1689                                thread.received_chunk();
1690
1691                                if let Some(last_message) = thread.messages.last_mut() {
1692                                    if last_message.role == Role::Assistant
1693                                        && !thread.tool_use.has_tool_results(last_message.id)
1694                                    {
1695                                        last_message.push_thinking(&chunk, signature);
1696                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1697                                            last_message.id,
1698                                            chunk,
1699                                        ));
1700                                    } else {
1701                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1702                                        // of a new Assistant response.
1703                                        //
1704                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1705                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1706                                        request_assistant_message_id =
1707                                            Some(thread.insert_assistant_message(
1708                                                vec![MessageSegment::Thinking {
1709                                                    text: chunk.to_string(),
1710                                                    signature,
1711                                                }],
1712                                                cx,
1713                                            ));
1714                                    };
1715                                }
1716                            }
1717                            LanguageModelCompletionEvent::RedactedThinking { data } => {
1718                                thread.received_chunk();
1719
1720                                if let Some(last_message) = thread.messages.last_mut() {
1721                                    if last_message.role == Role::Assistant
1722                                        && !thread.tool_use.has_tool_results(last_message.id)
1723                                    {
1724                                        last_message.push_redacted_thinking(data);
1725                                    } else {
1726                                        request_assistant_message_id =
1727                                            Some(thread.insert_assistant_message(
1728                                                vec![MessageSegment::RedactedThinking(data)],
1729                                                cx,
1730                                            ));
1731                                    };
1732                                }
1733                            }
1734                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1735                                let last_assistant_message_id = request_assistant_message_id
1736                                    .unwrap_or_else(|| {
1737                                        let new_assistant_message_id =
1738                                            thread.insert_assistant_message(vec![], cx);
1739                                        request_assistant_message_id =
1740                                            Some(new_assistant_message_id);
1741                                        new_assistant_message_id
1742                                    });
1743
1744                                let tool_use_id = tool_use.id.clone();
1745                                let streamed_input = if tool_use.is_input_complete {
1746                                    None
1747                                } else {
1748                                    Some((&tool_use.input).clone())
1749                                };
1750
1751                                let ui_text = thread.tool_use.request_tool_use(
1752                                    last_assistant_message_id,
1753                                    tool_use,
1754                                    tool_use_metadata.clone(),
1755                                    cx,
1756                                );
1757
1758                                if let Some(input) = streamed_input {
1759                                    cx.emit(ThreadEvent::StreamedToolUse {
1760                                        tool_use_id,
1761                                        ui_text,
1762                                        input,
1763                                    });
1764                                }
1765                            }
1766                            LanguageModelCompletionEvent::ToolUseJsonParseError {
1767                                id,
1768                                tool_name,
1769                                raw_input: invalid_input_json,
1770                                json_parse_error,
1771                            } => {
1772                                thread.receive_invalid_tool_json(
1773                                    id,
1774                                    tool_name,
1775                                    invalid_input_json,
1776                                    json_parse_error,
1777                                    window,
1778                                    cx,
1779                                );
1780                            }
1781                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1782                                if let Some(completion) = thread
1783                                    .pending_completions
1784                                    .iter_mut()
1785                                    .find(|completion| completion.id == pending_completion_id)
1786                                {
1787                                    match status_update {
1788                                        CompletionRequestStatus::Queued { position } => {
1789                                            completion.queue_state =
1790                                                QueueState::Queued { position };
1791                                        }
1792                                        CompletionRequestStatus::Started => {
1793                                            completion.queue_state = QueueState::Started;
1794                                        }
1795                                        CompletionRequestStatus::Failed {
1796                                            code,
1797                                            message,
1798                                            request_id: _,
1799                                            retry_after,
1800                                        } => {
1801                                            return Err(
1802                                                LanguageModelCompletionError::from_cloud_failure(
1803                                                    model.upstream_provider_name(),
1804                                                    code,
1805                                                    message,
1806                                                    retry_after.map(Duration::from_secs_f64),
1807                                                ),
1808                                            );
1809                                        }
1810                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
1811                                            thread.update_model_request_usage(
1812                                                amount as u32,
1813                                                limit,
1814                                                cx,
1815                                            );
1816                                        }
1817                                        CompletionRequestStatus::ToolUseLimitReached => {
1818                                            thread.tool_use_limit_reached = true;
1819                                            cx.emit(ThreadEvent::ToolUseLimitReached);
1820                                        }
1821                                    }
1822                                }
1823                            }
1824                        }
1825
1826                        thread.touch_updated_at();
1827                        cx.emit(ThreadEvent::StreamedCompletion);
1828                        cx.notify();
1829
1830                        thread.auto_capture_telemetry(cx);
1831                        Ok(())
1832                    })??;
1833
1834                    smol::future::yield_now().await;
1835                }
1836
1837                thread.update(cx, |thread, cx| {
1838                    thread.last_received_chunk_at = None;
1839                    thread
1840                        .pending_completions
1841                        .retain(|completion| completion.id != pending_completion_id);
1842
1843                    // If there is a response without tool use, summarize the message. Otherwise,
1844                    // allow two tool uses before summarizing.
1845                    if matches!(thread.summary, ThreadSummary::Pending)
1846                        && thread.messages.len() >= 2
1847                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1848                    {
1849                        thread.summarize(cx);
1850                    }
1851                })?;
1852
1853                anyhow::Ok(stop_reason)
1854            };
1855
1856            let result = stream_completion.await;
1857            let mut retry_scheduled = false;
1858
1859            thread
1860                .update(cx, |thread, cx| {
1861                    thread.finalize_pending_checkpoint(cx);
1862                    match result.as_ref() {
1863                        Ok(stop_reason) => {
1864                            match stop_reason {
1865                                StopReason::ToolUse => {
1866                                    let tool_uses =
1867                                        thread.use_pending_tools(window, model.clone(), cx);
1868                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1869                                }
1870                                StopReason::EndTurn | StopReason::MaxTokens => {
1871                                    thread.project.update(cx, |project, cx| {
1872                                        project.set_agent_location(None, cx);
1873                                    });
1874                                }
1875                                StopReason::Refusal => {
1876                                    thread.project.update(cx, |project, cx| {
1877                                        project.set_agent_location(None, cx);
1878                                    });
1879
1880                                    // Remove the turn that was refused.
1881                                    //
1882                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1883                                    {
1884                                        let mut messages_to_remove = Vec::new();
1885
1886                                        for (ix, message) in
1887                                            thread.messages.iter().enumerate().rev()
1888                                        {
1889                                            messages_to_remove.push(message.id);
1890
1891                                            if message.role == Role::User {
1892                                                if ix == 0 {
1893                                                    break;
1894                                                }
1895
1896                                                if let Some(prev_message) =
1897                                                    thread.messages.get(ix - 1)
1898                                                {
1899                                                    if prev_message.role == Role::Assistant {
1900                                                        break;
1901                                                    }
1902                                                }
1903                                            }
1904                                        }
1905
1906                                        for message_id in messages_to_remove {
1907                                            thread.delete_message(message_id, cx);
1908                                        }
1909                                    }
1910
1911                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1912                                        header: "Language model refusal".into(),
1913                                        message:
1914                                            "Model refused to generate content for safety reasons."
1915                                                .into(),
1916                                    }));
1917                                }
1918                            }
1919
1920                            // We successfully completed, so cancel any remaining retries.
1921                            thread.retry_state = None;
1922                        }
1923                        Err(error) => {
1924                            thread.project.update(cx, |project, cx| {
1925                                project.set_agent_location(None, cx);
1926                            });
1927
1928                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1929                                let error_message = error
1930                                    .chain()
1931                                    .map(|err| err.to_string())
1932                                    .collect::<Vec<_>>()
1933                                    .join("\n");
1934                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1935                                    header: "Error interacting with language model".into(),
1936                                    message: SharedString::from(error_message.clone()),
1937                                }));
1938                            }
1939
1940                            if error.is::<PaymentRequiredError>() {
1941                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1942                            } else if let Some(error) =
1943                                error.downcast_ref::<ModelRequestLimitReachedError>()
1944                            {
1945                                cx.emit(ThreadEvent::ShowError(
1946                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1947                                ));
1948                            } else if let Some(completion_error) =
1949                                error.downcast_ref::<LanguageModelCompletionError>()
1950                            {
1951                                use LanguageModelCompletionError::*;
1952                                match &completion_error {
1953                                    PromptTooLarge { tokens, .. } => {
1954                                        let tokens = tokens.unwrap_or_else(|| {
1955                                            // We didn't get an exact token count from the API, so fall back on our estimate.
1956                                            thread
1957                                                .total_token_usage()
1958                                                .map(|usage| usage.total)
1959                                                .unwrap_or(0)
1960                                                // We know the context window was exceeded in practice, so if our estimate was
1961                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1962                                                .max(model.max_token_count().saturating_add(1))
1963                                        });
1964                                        thread.exceeded_window_error = Some(ExceededWindowError {
1965                                            model_id: model.id(),
1966                                            token_count: tokens,
1967                                        });
1968                                        cx.notify();
1969                                    }
1970                                    RateLimitExceeded {
1971                                        retry_after: Some(retry_after),
1972                                        ..
1973                                    }
1974                                    | ServerOverloaded {
1975                                        retry_after: Some(retry_after),
1976                                        ..
1977                                    } => {
1978                                        thread.handle_rate_limit_error(
1979                                            &completion_error,
1980                                            *retry_after,
1981                                            model.clone(),
1982                                            intent,
1983                                            window,
1984                                            cx,
1985                                        );
1986                                        retry_scheduled = true;
1987                                    }
1988                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
1989                                        retry_scheduled = thread.handle_retryable_error(
1990                                            &completion_error,
1991                                            model.clone(),
1992                                            intent,
1993                                            window,
1994                                            cx,
1995                                        );
1996                                        if !retry_scheduled {
1997                                            emit_generic_error(error, cx);
1998                                        }
1999                                    }
2000                                    ApiInternalServerError { .. }
2001                                    | ApiReadResponseError { .. }
2002                                    | HttpSend { .. } => {
2003                                        retry_scheduled = thread.handle_retryable_error(
2004                                            &completion_error,
2005                                            model.clone(),
2006                                            intent,
2007                                            window,
2008                                            cx,
2009                                        );
2010                                        if !retry_scheduled {
2011                                            emit_generic_error(error, cx);
2012                                        }
2013                                    }
2014                                    NoApiKey { .. }
2015                                    | HttpResponseError { .. }
2016                                    | BadRequestFormat { .. }
2017                                    | AuthenticationError { .. }
2018                                    | PermissionError { .. }
2019                                    | ApiEndpointNotFound { .. }
2020                                    | SerializeRequest { .. }
2021                                    | BuildRequestBody { .. }
2022                                    | DeserializeResponse { .. }
2023                                    | Other { .. } => emit_generic_error(error, cx),
2024                                }
2025                            } else {
2026                                emit_generic_error(error, cx);
2027                            }
2028
2029                            if !retry_scheduled {
2030                                thread.cancel_last_completion(window, cx);
2031                            }
2032                        }
2033                    }
2034
2035                    if !retry_scheduled {
2036                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
2037                    }
2038
2039                    if let Some((request_callback, (request, response_events))) = thread
2040                        .request_callback
2041                        .as_mut()
2042                        .zip(request_callback_parameters.as_ref())
2043                    {
2044                        request_callback(request, response_events);
2045                    }
2046
2047                    thread.auto_capture_telemetry(cx);
2048
2049                    if let Ok(initial_usage) = initial_token_usage {
2050                        let usage = thread.cumulative_token_usage - initial_usage;
2051
2052                        telemetry::event!(
2053                            "Assistant Thread Completion",
2054                            thread_id = thread.id().to_string(),
2055                            prompt_id = prompt_id,
2056                            model = model.telemetry_id(),
2057                            model_provider = model.provider_id().to_string(),
2058                            input_tokens = usage.input_tokens,
2059                            output_tokens = usage.output_tokens,
2060                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
2061                            cache_read_input_tokens = usage.cache_read_input_tokens,
2062                        );
2063                    }
2064                })
2065                .ok();
2066        });
2067
2068        self.pending_completions.push(PendingCompletion {
2069            id: pending_completion_id,
2070            queue_state: QueueState::Sending,
2071            _task: task,
2072        });
2073    }
2074
2075    pub fn summarize(&mut self, cx: &mut Context<Self>) {
2076        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2077            println!("No thread summary model");
2078            return;
2079        };
2080
2081        if !model.provider.is_authenticated(cx) {
2082            return;
2083        }
2084
2085        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2086
2087        let request = self.to_summarize_request(
2088            &model.model,
2089            CompletionIntent::ThreadSummarization,
2090            added_user_message.into(),
2091            cx,
2092        );
2093
2094        self.summary = ThreadSummary::Generating;
2095
2096        self.pending_summary = cx.spawn(async move |this, cx| {
2097            let result = async {
2098                let mut messages = model.model.stream_completion(request, &cx).await?;
2099
2100                let mut new_summary = String::new();
2101                while let Some(event) = messages.next().await {
2102                    let Ok(event) = event else {
2103                        continue;
2104                    };
2105                    let text = match event {
2106                        LanguageModelCompletionEvent::Text(text) => text,
2107                        LanguageModelCompletionEvent::StatusUpdate(
2108                            CompletionRequestStatus::UsageUpdated { amount, limit },
2109                        ) => {
2110                            this.update(cx, |thread, cx| {
2111                                thread.update_model_request_usage(amount as u32, limit, cx);
2112                            })?;
2113                            continue;
2114                        }
2115                        _ => continue,
2116                    };
2117
2118                    let mut lines = text.lines();
2119                    new_summary.extend(lines.next());
2120
2121                    // Stop if the LLM generated multiple lines.
2122                    if lines.next().is_some() {
2123                        break;
2124                    }
2125                }
2126
2127                anyhow::Ok(new_summary)
2128            }
2129            .await;
2130
2131            this.update(cx, |this, cx| {
2132                match result {
2133                    Ok(new_summary) => {
2134                        if new_summary.is_empty() {
2135                            this.summary = ThreadSummary::Error;
2136                        } else {
2137                            this.summary = ThreadSummary::Ready(new_summary.into());
2138                        }
2139                    }
2140                    Err(err) => {
2141                        this.summary = ThreadSummary::Error;
2142                        log::error!("Failed to generate thread summary: {}", err);
2143                    }
2144                }
2145                cx.emit(ThreadEvent::SummaryGenerated);
2146            })
2147            .log_err()?;
2148
2149            Some(())
2150        });
2151    }
2152
2153    fn handle_rate_limit_error(
2154        &mut self,
2155        error: &LanguageModelCompletionError,
2156        retry_after: Duration,
2157        model: Arc<dyn LanguageModel>,
2158        intent: CompletionIntent,
2159        window: Option<AnyWindowHandle>,
2160        cx: &mut Context<Self>,
2161    ) {
2162        // For rate limit errors, we only retry once with the specified duration
2163        let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2164        log::warn!(
2165            "Retrying completion request in {} seconds: {error:?}",
2166            retry_after.as_secs(),
2167        );
2168
2169        // Add a UI-only message instead of a regular message
2170        let id = self.next_message_id.post_inc();
2171        self.messages.push(Message {
2172            id,
2173            role: Role::System,
2174            segments: vec![MessageSegment::Text(retry_message)],
2175            loaded_context: LoadedContext::default(),
2176            creases: Vec::new(),
2177            is_hidden: false,
2178            ui_only: true,
2179        });
2180        cx.emit(ThreadEvent::MessageAdded(id));
2181        // Schedule the retry
2182        let thread_handle = cx.entity().downgrade();
2183
2184        cx.spawn(async move |_thread, cx| {
2185            cx.background_executor().timer(retry_after).await;
2186
2187            thread_handle
2188                .update(cx, |thread, cx| {
2189                    // Retry the completion
2190                    thread.send_to_model(model, intent, window, cx);
2191                })
2192                .log_err();
2193        })
2194        .detach();
2195    }
2196
    /// Handles a transient completion error using the default exponential
    /// backoff schedule (no caller-supplied delay).
    ///
    /// Thin wrapper around [`Self::handle_retryable_error_with_delay`] with
    /// `custom_delay: None`. Returns `true` if a retry was scheduled, `false`
    /// if the retry budget is exhausted and the caller should surface the
    /// error to the user instead.
    fn handle_retryable_error(
        &mut self,
        error: &LanguageModelCompletionError,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
    }
2207
2208    fn handle_retryable_error_with_delay(
2209        &mut self,
2210        error: &LanguageModelCompletionError,
2211        custom_delay: Option<Duration>,
2212        model: Arc<dyn LanguageModel>,
2213        intent: CompletionIntent,
2214        window: Option<AnyWindowHandle>,
2215        cx: &mut Context<Self>,
2216    ) -> bool {
2217        let retry_state = self.retry_state.get_or_insert(RetryState {
2218            attempt: 0,
2219            max_attempts: MAX_RETRY_ATTEMPTS,
2220            intent,
2221        });
2222
2223        retry_state.attempt += 1;
2224        let attempt = retry_state.attempt;
2225        let max_attempts = retry_state.max_attempts;
2226        let intent = retry_state.intent;
2227
2228        if attempt <= max_attempts {
2229            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
2230            let delay = if let Some(custom_delay) = custom_delay {
2231                custom_delay
2232            } else {
2233                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
2234                Duration::from_secs(delay_secs)
2235            };
2236
2237            // Add a transient message to inform the user
2238            let delay_secs = delay.as_secs();
2239            let retry_message = format!(
2240                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2241                in {delay_secs} seconds..."
2242            );
2243            log::warn!(
2244                "Retrying completion request (attempt {attempt} of {max_attempts}) \
2245                in {delay_secs} seconds: {error:?}",
2246            );
2247
2248            // Add a UI-only message instead of a regular message
2249            let id = self.next_message_id.post_inc();
2250            self.messages.push(Message {
2251                id,
2252                role: Role::System,
2253                segments: vec![MessageSegment::Text(retry_message)],
2254                loaded_context: LoadedContext::default(),
2255                creases: Vec::new(),
2256                is_hidden: false,
2257                ui_only: true,
2258            });
2259            cx.emit(ThreadEvent::MessageAdded(id));
2260
2261            // Schedule the retry
2262            let thread_handle = cx.entity().downgrade();
2263
2264            cx.spawn(async move |_thread, cx| {
2265                cx.background_executor().timer(delay).await;
2266
2267                thread_handle
2268                    .update(cx, |thread, cx| {
2269                        // Retry the completion
2270                        thread.send_to_model(model, intent, window, cx);
2271                    })
2272                    .log_err();
2273            })
2274            .detach();
2275
2276            true
2277        } else {
2278            // Max retries exceeded
2279            self.retry_state = None;
2280
2281            let notification_text = if max_attempts == 1 {
2282                "Failed after retrying.".into()
2283            } else {
2284                format!("Failed after retrying {} times.", max_attempts).into()
2285            };
2286
2287            // Stop generating since we're giving up on retrying.
2288            self.pending_completions.clear();
2289
2290            cx.emit(ThreadEvent::RetriesFailed {
2291                message: notification_text,
2292            });
2293
2294            false
2295        }
2296    }
2297
    /// Starts generating a detailed summary of the thread in the background,
    /// unless one is already generated (or in flight) for the current last
    /// message.
    ///
    /// Uses the globally configured thread-summary model; does nothing when no
    /// such model is configured or its provider is not authenticated. On
    /// completion the thread is saved via `thread_store` so the summary can be
    /// reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Flip the state to Generating before spawning so observers see the
        // transition immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, roll the state back to NotGenerated.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed chunks; individual chunk errors are logged
            // and skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2387
2388    pub async fn wait_for_detailed_summary_or_text(
2389        this: &Entity<Self>,
2390        cx: &mut AsyncApp,
2391    ) -> Option<SharedString> {
2392        let mut detailed_summary_rx = this
2393            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2394            .ok()?;
2395        loop {
2396            match detailed_summary_rx.recv().await? {
2397                DetailedSummaryState::Generating { .. } => {}
2398                DetailedSummaryState::NotGenerated => {
2399                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
2400                }
2401                DetailedSummaryState::Generated { text, .. } => return Some(text),
2402            }
2403        }
2404    }
2405
2406    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2407        self.detailed_summary_rx
2408            .borrow()
2409            .text()
2410            .unwrap_or_else(|| self.text().into())
2411    }
2412
2413    pub fn is_generating_detailed_summary(&self) -> bool {
2414        matches!(
2415            &*self.detailed_summary_rx.borrow(),
2416            DetailedSummaryState::Generating { .. }
2417        )
2418    }
2419
2420    pub fn use_pending_tools(
2421        &mut self,
2422        window: Option<AnyWindowHandle>,
2423        model: Arc<dyn LanguageModel>,
2424        cx: &mut Context<Self>,
2425    ) -> Vec<PendingToolUse> {
2426        self.auto_capture_telemetry(cx);
2427        let request =
2428            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2429        let pending_tool_uses = self
2430            .tool_use
2431            .pending_tool_uses()
2432            .into_iter()
2433            .filter(|tool_use| tool_use.status.is_idle())
2434            .cloned()
2435            .collect::<Vec<_>>();
2436
2437        for tool_use in pending_tool_uses.iter() {
2438            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2439        }
2440
2441        pending_tool_uses
2442    }
2443
2444    fn use_pending_tool(
2445        &mut self,
2446        tool_use: PendingToolUse,
2447        request: Arc<LanguageModelRequest>,
2448        model: Arc<dyn LanguageModel>,
2449        window: Option<AnyWindowHandle>,
2450        cx: &mut Context<Self>,
2451    ) {
2452        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2453            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2454        };
2455
2456        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2457            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2458        }
2459
2460        if tool.needs_confirmation(&tool_use.input, cx)
2461            && !AgentSettings::get_global(cx).always_allow_tool_actions
2462        {
2463            self.tool_use.confirm_tool_use(
2464                tool_use.id,
2465                tool_use.ui_text,
2466                tool_use.input,
2467                request,
2468                tool,
2469            );
2470            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2471        } else {
2472            self.run_tool(
2473                tool_use.id,
2474                tool_use.ui_text,
2475                tool_use.input,
2476                request,
2477                tool,
2478                model,
2479                window,
2480                cx,
2481            );
2482        }
2483    }
2484
2485    pub fn handle_hallucinated_tool_use(
2486        &mut self,
2487        tool_use_id: LanguageModelToolUseId,
2488        hallucinated_tool_name: Arc<str>,
2489        window: Option<AnyWindowHandle>,
2490        cx: &mut Context<Thread>,
2491    ) {
2492        let available_tools = self.profile.enabled_tools(cx);
2493
2494        let tool_list = available_tools
2495            .iter()
2496            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2497            .collect::<Vec<_>>()
2498            .join("\n");
2499
2500        let error_message = format!(
2501            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2502            hallucinated_tool_name, tool_list
2503        );
2504
2505        let pending_tool_use = self.tool_use.insert_tool_output(
2506            tool_use_id.clone(),
2507            hallucinated_tool_name,
2508            Err(anyhow!("Missing tool call: {error_message}")),
2509            self.configured_model.as_ref(),
2510        );
2511
2512        cx.emit(ThreadEvent::MissingToolUse {
2513            tool_use_id: tool_use_id.clone(),
2514            ui_text: error_message.into(),
2515        });
2516
2517        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2518    }
2519
2520    pub fn receive_invalid_tool_json(
2521        &mut self,
2522        tool_use_id: LanguageModelToolUseId,
2523        tool_name: Arc<str>,
2524        invalid_json: Arc<str>,
2525        error: String,
2526        window: Option<AnyWindowHandle>,
2527        cx: &mut Context<Thread>,
2528    ) {
2529        log::error!("The model returned invalid input JSON: {invalid_json}");
2530
2531        let pending_tool_use = self.tool_use.insert_tool_output(
2532            tool_use_id.clone(),
2533            tool_name,
2534            Err(anyhow!("Error parsing input JSON: {error}")),
2535            self.configured_model.as_ref(),
2536        );
2537        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2538            pending_tool_use.ui_text.clone()
2539        } else {
2540            log::error!(
2541                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2542            );
2543            format!("Unknown tool {}", tool_use_id).into()
2544        };
2545
2546        cx.emit(ThreadEvent::InvalidToolInput {
2547            tool_use_id: tool_use_id.clone(),
2548            ui_text,
2549            invalid_input_json: invalid_json,
2550        });
2551
2552        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2553    }
2554
    /// Spawns execution of `tool` for the given tool use and registers the
    /// resulting task with the tool-use state so its progress can be tracked.
    ///
    /// Completion is reported asynchronously via `tool_finished` from inside
    /// the spawned task.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2571
    /// Runs `tool` and returns a task that records its output when it
    /// completes.
    ///
    /// Any tool card the tool produces is stored immediately — before the
    /// output future resolves — so the UI can render it while the tool runs.
    /// The spawned task reports the result through `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool was
                // running; in that case the result is simply discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2618
2619    fn tool_finished(
2620        &mut self,
2621        tool_use_id: LanguageModelToolUseId,
2622        pending_tool_use: Option<PendingToolUse>,
2623        canceled: bool,
2624        window: Option<AnyWindowHandle>,
2625        cx: &mut Context<Self>,
2626    ) {
2627        if self.all_tools_finished() {
2628            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2629                if !canceled {
2630                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2631                }
2632                self.auto_capture_telemetry(cx);
2633            }
2634        }
2635
2636        cx.emit(ThreadEvent::ToolFinished {
2637            tool_use_id,
2638            pending_tool_use,
2639        });
2640    }
2641
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also clears any scheduled retry and cancels tools that are still
    /// running.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // A pending retry counts as an in-flight completion for cancellation
        // purposes, even when no request is currently streaming.
        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();

        self.retry_state = None;

        // Finish each canceled tool with `canceled = true` so `tool_finished`
        // doesn't kick off a follow-up completion.
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2680
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. The thread itself holds no
    /// editing state; it only broadcasts the event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2688
    /// The thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    /// The feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2696
2697    pub fn report_message_feedback(
2698        &mut self,
2699        message_id: MessageId,
2700        feedback: ThreadFeedback,
2701        cx: &mut Context<Self>,
2702    ) -> Task<Result<()>> {
2703        if self.message_feedback.get(&message_id) == Some(&feedback) {
2704            return Task::ready(Ok(()));
2705        }
2706
2707        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2708        let serialized_thread = self.serialize(cx);
2709        let thread_id = self.id().clone();
2710        let client = self.project.read(cx).client();
2711
2712        let enabled_tool_names: Vec<String> = self
2713            .profile
2714            .enabled_tools(cx)
2715            .iter()
2716            .map(|(name, _)| name.clone().into())
2717            .collect();
2718
2719        self.message_feedback.insert(message_id, feedback);
2720
2721        cx.notify();
2722
2723        let message_content = self
2724            .message(message_id)
2725            .map(|msg| msg.to_string())
2726            .unwrap_or_default();
2727
2728        cx.background_spawn(async move {
2729            let final_project_snapshot = final_project_snapshot.await;
2730            let serialized_thread = serialized_thread.await?;
2731            let thread_data =
2732                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2733
2734            let rating = match feedback {
2735                ThreadFeedback::Positive => "positive",
2736                ThreadFeedback::Negative => "negative",
2737            };
2738            telemetry::event!(
2739                "Assistant Thread Rated",
2740                rating,
2741                thread_id,
2742                enabled_tool_names,
2743                message_id = message_id.0,
2744                message_content,
2745                thread_data,
2746                final_project_snapshot
2747            );
2748            client.telemetry().flush_events().await;
2749
2750            Ok(())
2751        })
2752    }
2753
2754    pub fn report_feedback(
2755        &mut self,
2756        feedback: ThreadFeedback,
2757        cx: &mut Context<Self>,
2758    ) -> Task<Result<()>> {
2759        let last_assistant_message_id = self
2760            .messages
2761            .iter()
2762            .rev()
2763            .find(|msg| msg.role == Role::Assistant)
2764            .map(|msg| msg.id);
2765
2766        if let Some(message_id) = last_assistant_message_id {
2767            self.report_message_feedback(message_id, feedback, cx)
2768        } else {
2769            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2770            let serialized_thread = self.serialize(cx);
2771            let thread_id = self.id().clone();
2772            let client = self.project.read(cx).client();
2773            self.feedback = Some(feedback);
2774            cx.notify();
2775
2776            cx.background_spawn(async move {
2777                let final_project_snapshot = final_project_snapshot.await;
2778                let serialized_thread = serialized_thread.await?;
2779                let thread_data = serde_json::to_value(serialized_thread)
2780                    .unwrap_or_else(|_| serde_json::Value::Null);
2781
2782                let rating = match feedback {
2783                    ThreadFeedback::Positive => "positive",
2784                    ThreadFeedback::Negative => "negative",
2785                };
2786                telemetry::event!(
2787                    "Assistant Thread Rated",
2788                    rating,
2789                    thread_id,
2790                    thread_data,
2791                    final_project_snapshot
2792                );
2793                client.telemetry().flush_events().await;
2794
2795                Ok(())
2796            })
2797        }
2798    }
2799
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// The per-worktree snapshots are gathered concurrently; the returned task
    /// resolves once all of them have completed.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Start one snapshot task per visible worktree so they run in parallel.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect paths of dirty buffers. Best-effort: if app state can't
            // be read anymore, the list stays empty.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2837
2838    fn worktree_snapshot(
2839        worktree: Entity<project::Worktree>,
2840        git_store: Entity<GitStore>,
2841        cx: &App,
2842    ) -> Task<WorktreeSnapshot> {
2843        cx.spawn(async move |cx| {
2844            // Get worktree path and snapshot
2845            let worktree_info = cx.update(|app_cx| {
2846                let worktree = worktree.read(app_cx);
2847                let path = worktree.abs_path().to_string_lossy().to_string();
2848                let snapshot = worktree.snapshot();
2849                (path, snapshot)
2850            });
2851
2852            let Ok((worktree_path, _snapshot)) = worktree_info else {
2853                return WorktreeSnapshot {
2854                    worktree_path: String::new(),
2855                    git_state: None,
2856                };
2857            };
2858
2859            let git_state = git_store
2860                .update(cx, |git_store, cx| {
2861                    git_store
2862                        .repositories()
2863                        .values()
2864                        .find(|repo| {
2865                            repo.read(cx)
2866                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2867                                .is_some()
2868                        })
2869                        .cloned()
2870                })
2871                .ok()
2872                .flatten()
2873                .map(|repo| {
2874                    repo.update(cx, |repo, _| {
2875                        let current_branch =
2876                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2877                        repo.send_job(None, |state, _| async move {
2878                            let RepositoryState::Local { backend, .. } = state else {
2879                                return GitState {
2880                                    remote_url: None,
2881                                    head_sha: None,
2882                                    current_branch,
2883                                    diff: None,
2884                                };
2885                            };
2886
2887                            let remote_url = backend.remote_url("origin");
2888                            let head_sha = backend.head_sha().await;
2889                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2890
2891                            GitState {
2892                                remote_url,
2893                                head_sha,
2894                                current_branch,
2895                                diff,
2896                            }
2897                        })
2898                    })
2899                });
2900
2901            let git_state = match git_state {
2902                Some(git_state) => match git_state.ok() {
2903                    Some(git_state) => git_state.await.ok(),
2904                    None => None,
2905                },
2906                None => None,
2907            };
2908
2909            WorktreeSnapshot {
2910                worktree_path,
2911                git_state,
2912            }
2913        })
2914    }
2915
    /// Renders the entire thread — summary, messages, attached context, tool
    /// uses and tool results — as a Markdown document.
    ///
    /// # Errors
    ///
    /// Returns an error if formatting fails or a tool input/output cannot be
    /// serialized to JSON.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        let summary = self.summary().or_default();
        writeln!(markdown, "# {summary}\n")?;

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Agent",
                    Role::System => "System",
                }
            )?;

            // Context text attached to the message (files, symbols, etc.).
            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images can't be inlined as text; record how many there were.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                match &tool_result.content {
                    LanguageModelToolResultContent::Text(text) => {
                        writeln!(markdown, "{text}")?;
                    }
                    LanguageModelToolResultContent::Image(image) => {
                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
                    }
                }

                if let Some(output) = tool_result.output.as_ref() {
                    writeln!(
                        markdown,
                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
                        serde_json::to_string_pretty(output)?
                    )?;
                }
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2998
    /// Marks agent edits within `buffer_range` as kept (accepted by the user).
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    /// Marks every outstanding agent edit as kept.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    /// Rejects agent edits within the given ranges of `buffer`.
    ///
    /// Returns a task that resolves once the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
3025
    /// The log of agent-made edits tracked for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3033
3034    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
3035        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
3036            return;
3037        }
3038
3039        let now = Instant::now();
3040        if let Some(last) = self.last_auto_capture_at {
3041            if now.duration_since(last).as_secs() < 10 {
3042                return;
3043            }
3044        }
3045
3046        self.last_auto_capture_at = Some(now);
3047
3048        let thread_id = self.id().clone();
3049        let github_login = self
3050            .project
3051            .read(cx)
3052            .user_store()
3053            .read(cx)
3054            .current_user()
3055            .map(|user| user.github_login.clone());
3056        let client = self.project.read(cx).client();
3057        let serialize_task = self.serialize(cx);
3058
3059        cx.background_executor()
3060            .spawn(async move {
3061                if let Ok(serialized_thread) = serialize_task.await {
3062                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
3063                        telemetry::event!(
3064                            "Agent Thread Auto-Captured",
3065                            thread_id = thread_id.to_string(),
3066                            thread_data = thread_data,
3067                            auto_capture_reason = "tracked_user",
3068                            github_login = github_login
3069                        );
3070
3071                        client.telemetry().flush_events().await;
3072                    }
3073                }
3074            })
3075            .detach();
3076    }
3077
    /// Token usage accumulated across every completed request in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3081
3082    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3083        let Some(model) = self.configured_model.as_ref() else {
3084            return TotalTokenUsage::default();
3085        };
3086
3087        let max = model.model.max_token_count();
3088
3089        let index = self
3090            .messages
3091            .iter()
3092            .position(|msg| msg.id == message_id)
3093            .unwrap_or(0);
3094
3095        if index == 0 {
3096            return TotalTokenUsage { total: 0, max };
3097        }
3098
3099        let token_usage = &self
3100            .request_token_usage
3101            .get(index - 1)
3102            .cloned()
3103            .unwrap_or_default();
3104
3105        TotalTokenUsage {
3106            total: token_usage.total_tokens(),
3107            max,
3108        }
3109    }
3110
3111    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3112        let model = self.configured_model.as_ref()?;
3113
3114        let max = model.model.max_token_count();
3115
3116        if let Some(exceeded_error) = &self.exceeded_window_error {
3117            if model.model.id() == exceeded_error.model_id {
3118                return Some(TotalTokenUsage {
3119                    total: exceeded_error.token_count,
3120                    max,
3121                });
3122            }
3123        }
3124
3125        let total = self
3126            .token_usage_at_last_message()
3127            .unwrap_or_default()
3128            .total_tokens();
3129
3130        Some(TotalTokenUsage { total, max })
3131    }
3132
3133    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3134        self.request_token_usage
3135            .get(self.messages.len().saturating_sub(1))
3136            .or_else(|| self.request_token_usage.last())
3137            .cloned()
3138    }
3139
3140    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3141        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3142        self.request_token_usage
3143            .resize(self.messages.len(), placeholder);
3144
3145        if let Some(last) = self.request_token_usage.last_mut() {
3146            *last = token_usage;
3147        }
3148    }
3149
3150    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3151        self.project.update(cx, |project, cx| {
3152            project.user_store().update(cx, |user_store, cx| {
3153                user_store.update_model_request_usage(
3154                    ModelRequestUsage(RequestUsage {
3155                        amount: amount as i32,
3156                        limit,
3157                    }),
3158                    cx,
3159                )
3160            })
3161        });
3162    }
3163
3164    pub fn deny_tool_use(
3165        &mut self,
3166        tool_use_id: LanguageModelToolUseId,
3167        tool_name: Arc<str>,
3168        window: Option<AnyWindowHandle>,
3169        cx: &mut Context<Self>,
3170    ) {
3171        let err = Err(anyhow::anyhow!(
3172            "Permission to run tool action denied by user"
3173        ));
3174
3175        self.tool_use.insert_tool_output(
3176            tool_use_id.clone(),
3177            tool_name,
3178            err,
3179            self.configured_model.as_ref(),
3180        );
3181        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3182    }
3183}
3184
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before more requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The model-request limit for the user's plan was reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3197
/// Events emitted by a `Thread` as it streams completions, runs tools, and
/// changes state. Observers subscribe via gpui's `EventEmitter` machinery.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// A completion event was received and applied to the thread.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text streamed in for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed in for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with its (possibly partial) input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model requested a tool that is not available.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input JSON that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint changed.
    CheckpointChanged,
    /// A tool needs user confirmation before it can run.
    ToolConfirmationNeeded,
    /// The tool-use limit was reached.
    ToolUseLimitReached,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
    /// The thread's agent profile changed.
    ProfileChanged,
    /// All retry attempts for a failed request were exhausted.
    RetriesFailed {
        message: SharedString,
    },
}
3245
3246impl EventEmitter<ThreadEvent> for Thread {}
3247
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    // Identifier distinguishing this completion from others.
    id: usize,
    // Where this completion currently sits in the request queue.
    queue_state: QueueState,
    // Held so the underlying task stays alive for the lifetime of this entry.
    _task: Task<()>,
}
3253
3254#[cfg(test)]
3255mod tests {
3256    use super::*;
3257    use crate::{
3258        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3259    };
3260
    // Test-specific constants
    // NOTE(review): presumably a retry interval used by rate-limit retry tests
    // elsewhere in this module — confirm it is still referenced.
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3263    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3264    use assistant_tool::ToolRegistry;
3265    use assistant_tools;
3266    use futures::StreamExt;
3267    use futures::future::BoxFuture;
3268    use futures::stream::BoxStream;
3269    use gpui::TestAppContext;
3270    use http_client;
3271    use indoc::indoc;
3272    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3273    use language_model::{
3274        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3275        LanguageModelProviderName, LanguageModelToolChoice,
3276    };
3277    use parking_lot::Mutex;
3278    use project::{FakeFs, Project};
3279    use prompt_store::PromptBuilder;
3280    use serde_json::json;
3281    use settings::{Settings, SettingsStore};
3282    use std::sync::Arc;
3283    use std::time::Duration;
3284    use theme::ThemeSettings;
3285    use util::path;
3286    use workspace::Workspace;
3287
    /// Verifies that file context attached to a user message is rendered into
    /// the message's `<context>` block and prepended to the user's text in the
    /// completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the single attached context into the form messages consume.
        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two entries: an initial message plus the user message; the context
        // text is prepended to the user's text in the request payload.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3364
    /// Verifies that each new user message only includes context that was not
    /// already attached to an earlier message, and that
    /// `new_context_for_thread` honors an "up to message" cutoff and context
    /// removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts attached at or after that
        // message (2 and 3) plus the newly added file4 count as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3520
    /// Verifies that messages inserted without any context have an empty
    /// context block and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two entries: an initial message plus the user message, unmodified.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3598
    /// Verifies that a `project_notifications` tool result is injected once a
    /// tracked context buffer becomes stale, and that flushing a second time
    /// does not produce duplicate notifications.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });

        // We shouldn't have a stale buffer notification yet
        let notifications = thread.read_with(cx, |thread, _| {
            find_tool_uses(thread, "project_notifications")
        });
        assert!(
            notifications.is_empty(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check for the stale buffer warning
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        // Exactly one notification is expected at this point.
        let [notification] = notifications.as_slice() else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]

        These files have changed since the last read:
        - code.rs
        "};
        assert_eq!(notification_content, expected_content);

        // Insert another user message and flush notifications again
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Can you tell me more?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // There should be no new notifications (we already flushed one)
        let notifications = thread.read_with(cx, |thread, _cx| {
            find_tool_uses(thread, "project_notifications")
        });

        assert_eq!(
            notifications.len(),
            1,
            "Should still have only one notification after second flush - no duplicates"
        );
    }
3712
3713    fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
3714        thread
3715            .messages()
3716            .flat_map(|message| {
3717                thread
3718                    .tool_results_for_message(message.id)
3719                    .into_iter()
3720                    .filter(|result| result.tool_name == tool_name.into())
3721                    .cloned()
3722                    .collect::<Vec<_>>()
3723            })
3724            .collect()
3725    }
3726
3727    #[gpui::test]
3728    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3729        init_test_settings(cx);
3730
3731        let project = create_test_project(
3732            cx,
3733            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3734        )
3735        .await;
3736
3737        let (_workspace, thread_store, thread, _context_store, _model) =
3738            setup_test_environment(cx, project.clone()).await;
3739
3740        // Check that we are starting with the default profile
3741        let profile = cx.read(|cx| thread.read(cx).profile.clone());
3742        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3743        assert_eq!(
3744            profile,
3745            AgentProfile::new(AgentProfileId::default(), tool_set)
3746        );
3747    }
3748
    /// Round-trips a thread through `serialize`/`deserialize` and checks that
    /// the default agent profile is preserved.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form using the same project,
        // tools, and prompt builder.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3791
    /// Verifies that the `model_parameters` temperature setting matches on
    /// provider and/or model, and that a provider mismatch yields no
    /// temperature override.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // The provider doesn't match, so no temperature override applies.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3885
    /// Exercises the thread-summary state machine: Pending → Generating →
    /// Ready, including when manual `set_summary` calls are and are not
    /// allowed.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then end the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3970
3971    #[gpui::test]
3972    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3973        init_test_settings(cx);
3974
3975        let project = create_test_project(cx, json!({})).await;
3976
3977        let (_, _thread_store, thread, _context_store, model) =
3978            setup_test_environment(cx, project.clone()).await;
3979
3980        test_summarize_error(&model, &thread, cx);
3981
3982        // Now we should be able to set a summary
3983        thread.update(cx, |thread, cx| {
3984            thread.set_summary("Brief Intro", cx);
3985        });
3986
3987        thread.read_with(cx, |thread, _| {
3988            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3989            assert_eq!(thread.summary().or_default(), "Brief Intro");
3990        });
3991    }
3992
3993    #[gpui::test]
3994    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3995        init_test_settings(cx);
3996
3997        let project = create_test_project(cx, json!({})).await;
3998
3999        let (_, _thread_store, thread, _context_store, model) =
4000            setup_test_environment(cx, project.clone()).await;
4001
4002        test_summarize_error(&model, &thread, cx);
4003
4004        // Sending another message should not trigger another summarize request
4005        thread.update(cx, |thread, cx| {
4006            thread.insert_user_message(
4007                "How are you?",
4008                ContextLoadResult::default(),
4009                None,
4010                vec![],
4011                cx,
4012            );
4013            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4014        });
4015
4016        let fake_model = model.as_fake();
4017        simulate_successful_response(&fake_model, cx);
4018
4019        thread.read_with(cx, |thread, _| {
4020            // State is still Error, not Generating
4021            assert!(matches!(thread.summary(), ThreadSummary::Error));
4022        });
4023
4024        // But the summarize request can be invoked manually
4025        thread.update(cx, |thread, cx| {
4026            thread.summarize(cx);
4027        });
4028
4029        thread.read_with(cx, |thread, _| {
4030            assert!(matches!(thread.summary(), ThreadSummary::Generating));
4031        });
4032
4033        cx.run_until_parked();
4034        fake_model.stream_last_completion_response("A successful summary");
4035        fake_model.end_last_completion_stream();
4036        cx.run_until_parked();
4037
4038        thread.read_with(cx, |thread, _| {
4039            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4040            assert_eq!(thread.summary().or_default(), "A successful summary");
4041        });
4042    }
4043
    // Helper to create a model that returns errors
    // Selects which completion error `ErrorInjector` should produce.
    enum TestError {
        // Surfaced as `LanguageModelCompletionError::ServerOverloaded`.
        Overloaded,
        // Surfaced as `LanguageModelCompletionError::ApiInternalServerError`.
        InternalServerError,
    }
4049
    // A `LanguageModel` wrapper whose `stream_completion` always fails with
    // the configured error; all other queries delegate to `inner`.
    struct ErrorInjector {
        // Fake model providing ids, names, token counts, etc.
        inner: Arc<FakeLanguageModel>,
        // The error that every completion request will yield.
        error_type: TestError,
    }
4054
4055    impl ErrorInjector {
4056        fn new(error_type: TestError) -> Self {
4057            Self {
4058                inner: Arc::new(FakeLanguageModel::default()),
4059                error_type,
4060            }
4061        }
4062    }
4063
    // Delegates all metadata queries to `inner`; only `stream_completion`
    // deviates, yielding the configured `TestError` as its sole stream item.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        // Always fails: the request is ignored and the configured error is
        // emitted as the only item of an otherwise well-formed stream.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the error eagerly so the returned future is 'static and
            // does not borrow `self`.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                // A single-item stream carrying the injected failure.
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Lets tests reach the wrapped fake model directly.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4146
4147    #[gpui::test]
4148    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4149        init_test_settings(cx);
4150
4151        let project = create_test_project(cx, json!({})).await;
4152        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4153
4154        // Create model that returns overloaded error
4155        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4156
4157        // Insert a user message
4158        thread.update(cx, |thread, cx| {
4159            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4160        });
4161
4162        // Start completion
4163        thread.update(cx, |thread, cx| {
4164            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4165        });
4166
4167        cx.run_until_parked();
4168
4169        thread.read_with(cx, |thread, _| {
4170            assert!(thread.retry_state.is_some(), "Should have retry state");
4171            let retry_state = thread.retry_state.as_ref().unwrap();
4172            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4173            assert_eq!(
4174                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4175                "Should have default max attempts"
4176            );
4177        });
4178
4179        // Check that a retry message was added
4180        thread.read_with(cx, |thread, _| {
4181            let mut messages = thread.messages();
4182            assert!(
4183                messages.any(|msg| {
4184                    msg.role == Role::System
4185                        && msg.ui_only
4186                        && msg.segments.iter().any(|seg| {
4187                            if let MessageSegment::Text(text) = seg {
4188                                text.contains("overloaded")
4189                                    && text
4190                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4191                            } else {
4192                                false
4193                            }
4194                        })
4195                }),
4196                "Should have added a system retry message"
4197            );
4198        });
4199
4200        let retry_count = thread.update(cx, |thread, _| {
4201            thread
4202                .messages
4203                .iter()
4204                .filter(|m| {
4205                    m.ui_only
4206                        && m.segments.iter().any(|s| {
4207                            if let MessageSegment::Text(text) = s {
4208                                text.contains("Retrying") && text.contains("seconds")
4209                            } else {
4210                                false
4211                            }
4212                        })
4213                })
4214                .count()
4215        });
4216
4217        assert_eq!(retry_count, 1, "Should have one retry message");
4218    }
4219
4220    #[gpui::test]
4221    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4222        init_test_settings(cx);
4223
4224        let project = create_test_project(cx, json!({})).await;
4225        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4226
4227        // Create model that returns internal server error
4228        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4229
4230        // Insert a user message
4231        thread.update(cx, |thread, cx| {
4232            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4233        });
4234
4235        // Start completion
4236        thread.update(cx, |thread, cx| {
4237            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4238        });
4239
4240        cx.run_until_parked();
4241
4242        // Check retry state on thread
4243        thread.read_with(cx, |thread, _| {
4244            assert!(thread.retry_state.is_some(), "Should have retry state");
4245            let retry_state = thread.retry_state.as_ref().unwrap();
4246            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4247            assert_eq!(
4248                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4249                "Should have correct max attempts"
4250            );
4251        });
4252
4253        // Check that a retry message was added with provider name
4254        thread.read_with(cx, |thread, _| {
4255            let mut messages = thread.messages();
4256            assert!(
4257                messages.any(|msg| {
4258                    msg.role == Role::System
4259                        && msg.ui_only
4260                        && msg.segments.iter().any(|seg| {
4261                            if let MessageSegment::Text(text) = seg {
4262                                text.contains("internal")
4263                                    && text.contains("Fake")
4264                                    && text
4265                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4266                            } else {
4267                                false
4268                            }
4269                        })
4270                }),
4271                "Should have added a system retry message with provider name"
4272            );
4273        });
4274
4275        // Count retry messages
4276        let retry_count = thread.update(cx, |thread, _| {
4277            thread
4278                .messages
4279                .iter()
4280                .filter(|m| {
4281                    m.ui_only
4282                        && m.segments.iter().any(|s| {
4283                            if let MessageSegment::Text(text) = s {
4284                                text.contains("Retrying") && text.contains("seconds")
4285                            } else {
4286                                false
4287                            }
4288                        })
4289                })
4290                .count()
4291        });
4292
4293        assert_eq!(retry_count, 1, "Should have one retry message");
4294    }
4295
4296    #[gpui::test]
4297    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4298        init_test_settings(cx);
4299
4300        let project = create_test_project(cx, json!({})).await;
4301        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4302
4303        // Create model that returns overloaded error
4304        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4305
4306        // Insert a user message
4307        thread.update(cx, |thread, cx| {
4308            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4309        });
4310
4311        // Track retry events and completion count
4312        // Track completion events
4313        let completion_count = Arc::new(Mutex::new(0));
4314        let completion_count_clone = completion_count.clone();
4315
4316        let _subscription = thread.update(cx, |_, cx| {
4317            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4318                if let ThreadEvent::NewRequest = event {
4319                    *completion_count_clone.lock() += 1;
4320                }
4321            })
4322        });
4323
4324        // First attempt
4325        thread.update(cx, |thread, cx| {
4326            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4327        });
4328        cx.run_until_parked();
4329
4330        // Should have scheduled first retry - count retry messages
4331        let retry_count = thread.update(cx, |thread, _| {
4332            thread
4333                .messages
4334                .iter()
4335                .filter(|m| {
4336                    m.ui_only
4337                        && m.segments.iter().any(|s| {
4338                            if let MessageSegment::Text(text) = s {
4339                                text.contains("Retrying") && text.contains("seconds")
4340                            } else {
4341                                false
4342                            }
4343                        })
4344                })
4345                .count()
4346        });
4347        assert_eq!(retry_count, 1, "Should have scheduled first retry");
4348
4349        // Check retry state
4350        thread.read_with(cx, |thread, _| {
4351            assert!(thread.retry_state.is_some(), "Should have retry state");
4352            let retry_state = thread.retry_state.as_ref().unwrap();
4353            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4354        });
4355
4356        // Advance clock for first retry
4357        cx.executor()
4358            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4359        cx.run_until_parked();
4360
4361        // Should have scheduled second retry - count retry messages
4362        let retry_count = thread.update(cx, |thread, _| {
4363            thread
4364                .messages
4365                .iter()
4366                .filter(|m| {
4367                    m.ui_only
4368                        && m.segments.iter().any(|s| {
4369                            if let MessageSegment::Text(text) = s {
4370                                text.contains("Retrying") && text.contains("seconds")
4371                            } else {
4372                                false
4373                            }
4374                        })
4375                })
4376                .count()
4377        });
4378        assert_eq!(retry_count, 2, "Should have scheduled second retry");
4379
4380        // Check retry state updated
4381        thread.read_with(cx, |thread, _| {
4382            assert!(thread.retry_state.is_some(), "Should have retry state");
4383            let retry_state = thread.retry_state.as_ref().unwrap();
4384            assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4385            assert_eq!(
4386                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4387                "Should have correct max attempts"
4388            );
4389        });
4390
4391        // Advance clock for second retry (exponential backoff)
4392        cx.executor()
4393            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4394        cx.run_until_parked();
4395
4396        // Should have scheduled third retry
4397        // Count all retry messages now
4398        let retry_count = thread.update(cx, |thread, _| {
4399            thread
4400                .messages
4401                .iter()
4402                .filter(|m| {
4403                    m.ui_only
4404                        && m.segments.iter().any(|s| {
4405                            if let MessageSegment::Text(text) = s {
4406                                text.contains("Retrying") && text.contains("seconds")
4407                            } else {
4408                                false
4409                            }
4410                        })
4411                })
4412                .count()
4413        });
4414        assert_eq!(
4415            retry_count, MAX_RETRY_ATTEMPTS as usize,
4416            "Should have scheduled third retry"
4417        );
4418
4419        // Check retry state updated
4420        thread.read_with(cx, |thread, _| {
4421            assert!(thread.retry_state.is_some(), "Should have retry state");
4422            let retry_state = thread.retry_state.as_ref().unwrap();
4423            assert_eq!(
4424                retry_state.attempt, MAX_RETRY_ATTEMPTS,
4425                "Should be at max retry attempt"
4426            );
4427            assert_eq!(
4428                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4429                "Should have correct max attempts"
4430            );
4431        });
4432
4433        // Advance clock for third retry (exponential backoff)
4434        cx.executor()
4435            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4436        cx.run_until_parked();
4437
4438        // No more retries should be scheduled after clock was advanced.
4439        let retry_count = thread.update(cx, |thread, _| {
4440            thread
4441                .messages
4442                .iter()
4443                .filter(|m| {
4444                    m.ui_only
4445                        && m.segments.iter().any(|s| {
4446                            if let MessageSegment::Text(text) = s {
4447                                text.contains("Retrying") && text.contains("seconds")
4448                            } else {
4449                                false
4450                            }
4451                        })
4452                })
4453                .count()
4454        });
4455        assert_eq!(
4456            retry_count, MAX_RETRY_ATTEMPTS as usize,
4457            "Should not exceed max retries"
4458        );
4459
4460        // Final completion count should be initial + max retries
4461        assert_eq!(
4462            *completion_count.lock(),
4463            (MAX_RETRY_ATTEMPTS + 1) as usize,
4464            "Should have made initial + max retry attempts"
4465        );
4466    }
4467
4468    #[gpui::test]
4469    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4470        init_test_settings(cx);
4471
4472        let project = create_test_project(cx, json!({})).await;
4473        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4474
4475        // Create model that returns overloaded error
4476        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4477
4478        // Insert a user message
4479        thread.update(cx, |thread, cx| {
4480            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4481        });
4482
4483        // Track events
4484        let retries_failed = Arc::new(Mutex::new(false));
4485        let retries_failed_clone = retries_failed.clone();
4486
4487        let _subscription = thread.update(cx, |_, cx| {
4488            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4489                if let ThreadEvent::RetriesFailed { .. } = event {
4490                    *retries_failed_clone.lock() = true;
4491                }
4492            })
4493        });
4494
4495        // Start initial completion
4496        thread.update(cx, |thread, cx| {
4497            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4498        });
4499        cx.run_until_parked();
4500
4501        // Advance through all retries
4502        for i in 0..MAX_RETRY_ATTEMPTS {
4503            let delay = if i == 0 {
4504                BASE_RETRY_DELAY_SECS
4505            } else {
4506                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
4507            };
4508            cx.executor().advance_clock(Duration::from_secs(delay));
4509            cx.run_until_parked();
4510        }
4511
4512        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
4513        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
4514        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
4515        cx.executor()
4516            .advance_clock(Duration::from_secs(final_delay));
4517        cx.run_until_parked();
4518
4519        let retry_count = thread.update(cx, |thread, _| {
4520            thread
4521                .messages
4522                .iter()
4523                .filter(|m| {
4524                    m.ui_only
4525                        && m.segments.iter().any(|s| {
4526                            if let MessageSegment::Text(text) = s {
4527                                text.contains("Retrying") && text.contains("seconds")
4528                            } else {
4529                                false
4530                            }
4531                        })
4532                })
4533                .count()
4534        });
4535
4536        // After max retries, should emit RetriesFailed event
4537        assert_eq!(
4538            retry_count, MAX_RETRY_ATTEMPTS as usize,
4539            "Should have attempted max retries"
4540        );
4541        assert!(
4542            *retries_failed.lock(),
4543            "Should emit RetriesFailed event after max retries exceeded"
4544        );
4545
4546        // Retry state should be cleared
4547        thread.read_with(cx, |thread, _| {
4548            assert!(
4549                thread.retry_state.is_none(),
4550                "Retry state should be cleared after max retries"
4551            );
4552
4553            // Verify we have the expected number of retry messages
4554            let retry_messages = thread
4555                .messages
4556                .iter()
4557                .filter(|msg| msg.ui_only && msg.role == Role::System)
4558                .count();
4559            assert_eq!(
4560                retry_messages, MAX_RETRY_ATTEMPTS as usize,
4561                "Should have one retry message per attempt"
4562            );
4563        });
4564    }
4565
4566    #[gpui::test]
4567    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
4568        init_test_settings(cx);
4569
4570        let project = create_test_project(cx, json!({})).await;
4571        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4572
4573        // We'll use a wrapper to switch behavior after first failure
4574        struct RetryTestModel {
4575            inner: Arc<FakeLanguageModel>,
4576            failed_once: Arc<Mutex<bool>>,
4577        }
4578
4579        impl LanguageModel for RetryTestModel {
4580            fn id(&self) -> LanguageModelId {
4581                self.inner.id()
4582            }
4583
4584            fn name(&self) -> LanguageModelName {
4585                self.inner.name()
4586            }
4587
4588            fn provider_id(&self) -> LanguageModelProviderId {
4589                self.inner.provider_id()
4590            }
4591
4592            fn provider_name(&self) -> LanguageModelProviderName {
4593                self.inner.provider_name()
4594            }
4595
4596            fn supports_tools(&self) -> bool {
4597                self.inner.supports_tools()
4598            }
4599
4600            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4601                self.inner.supports_tool_choice(choice)
4602            }
4603
4604            fn supports_images(&self) -> bool {
4605                self.inner.supports_images()
4606            }
4607
4608            fn telemetry_id(&self) -> String {
4609                self.inner.telemetry_id()
4610            }
4611
4612            fn max_token_count(&self) -> u64 {
4613                self.inner.max_token_count()
4614            }
4615
4616            fn count_tokens(
4617                &self,
4618                request: LanguageModelRequest,
4619                cx: &App,
4620            ) -> BoxFuture<'static, Result<u64>> {
4621                self.inner.count_tokens(request, cx)
4622            }
4623
4624            fn stream_completion(
4625                &self,
4626                request: LanguageModelRequest,
4627                cx: &AsyncApp,
4628            ) -> BoxFuture<
4629                'static,
4630                Result<
4631                    BoxStream<
4632                        'static,
4633                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4634                    >,
4635                    LanguageModelCompletionError,
4636                >,
4637            > {
4638                if !*self.failed_once.lock() {
4639                    *self.failed_once.lock() = true;
4640                    let provider = self.provider_name();
4641                    // Return error on first attempt
4642                    let stream = futures::stream::once(async move {
4643                        Err(LanguageModelCompletionError::ServerOverloaded {
4644                            provider,
4645                            retry_after: None,
4646                        })
4647                    });
4648                    async move { Ok(stream.boxed()) }.boxed()
4649                } else {
4650                    // Succeed on retry
4651                    self.inner.stream_completion(request, cx)
4652                }
4653            }
4654
4655            fn as_fake(&self) -> &FakeLanguageModel {
4656                &self.inner
4657            }
4658        }
4659
4660        let model = Arc::new(RetryTestModel {
4661            inner: Arc::new(FakeLanguageModel::default()),
4662            failed_once: Arc::new(Mutex::new(false)),
4663        });
4664
4665        // Insert a user message
4666        thread.update(cx, |thread, cx| {
4667            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4668        });
4669
4670        // Track message deletions
4671        // Track when retry completes successfully
4672        let retry_completed = Arc::new(Mutex::new(false));
4673        let retry_completed_clone = retry_completed.clone();
4674
4675        let _subscription = thread.update(cx, |_, cx| {
4676            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4677                if let ThreadEvent::StreamedCompletion = event {
4678                    *retry_completed_clone.lock() = true;
4679                }
4680            })
4681        });
4682
4683        // Start completion
4684        thread.update(cx, |thread, cx| {
4685            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4686        });
4687        cx.run_until_parked();
4688
4689        // Get the retry message ID
4690        let retry_message_id = thread.read_with(cx, |thread, _| {
4691            thread
4692                .messages()
4693                .find(|msg| msg.role == Role::System && msg.ui_only)
4694                .map(|msg| msg.id)
4695                .expect("Should have a retry message")
4696        });
4697
4698        // Wait for retry
4699        cx.executor()
4700            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4701        cx.run_until_parked();
4702
4703        // Stream some successful content
4704        let fake_model = model.as_fake();
4705        // After the retry, there should be a new pending completion
4706        let pending = fake_model.pending_completions();
4707        assert!(
4708            !pending.is_empty(),
4709            "Should have a pending completion after retry"
4710        );
4711        fake_model.stream_completion_response(&pending[0], "Success!");
4712        fake_model.end_completion_stream(&pending[0]);
4713        cx.run_until_parked();
4714
4715        // Check that the retry completed successfully
4716        assert!(
4717            *retry_completed.lock(),
4718            "Retry should have completed successfully"
4719        );
4720
4721        // Retry message should still exist but be marked as ui_only
4722        thread.read_with(cx, |thread, _| {
4723            let retry_msg = thread
4724                .message(retry_message_id)
4725                .expect("Retry message should still exist");
4726            assert!(retry_msg.ui_only, "Retry message should be ui_only");
4727            assert_eq!(
4728                retry_msg.role,
4729                Role::System,
4730                "Retry message should have System role"
4731            );
4732        });
4733    }
4734
4735    #[gpui::test]
4736    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4737        init_test_settings(cx);
4738
4739        let project = create_test_project(cx, json!({})).await;
4740        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4741
4742        // Create a model that fails once then succeeds
4743        struct FailOnceModel {
4744            inner: Arc<FakeLanguageModel>,
4745            failed_once: Arc<Mutex<bool>>,
4746        }
4747
4748        impl LanguageModel for FailOnceModel {
4749            fn id(&self) -> LanguageModelId {
4750                self.inner.id()
4751            }
4752
4753            fn name(&self) -> LanguageModelName {
4754                self.inner.name()
4755            }
4756
4757            fn provider_id(&self) -> LanguageModelProviderId {
4758                self.inner.provider_id()
4759            }
4760
4761            fn provider_name(&self) -> LanguageModelProviderName {
4762                self.inner.provider_name()
4763            }
4764
4765            fn supports_tools(&self) -> bool {
4766                self.inner.supports_tools()
4767            }
4768
4769            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4770                self.inner.supports_tool_choice(choice)
4771            }
4772
4773            fn supports_images(&self) -> bool {
4774                self.inner.supports_images()
4775            }
4776
4777            fn telemetry_id(&self) -> String {
4778                self.inner.telemetry_id()
4779            }
4780
4781            fn max_token_count(&self) -> u64 {
4782                self.inner.max_token_count()
4783            }
4784
4785            fn count_tokens(
4786                &self,
4787                request: LanguageModelRequest,
4788                cx: &App,
4789            ) -> BoxFuture<'static, Result<u64>> {
4790                self.inner.count_tokens(request, cx)
4791            }
4792
4793            fn stream_completion(
4794                &self,
4795                request: LanguageModelRequest,
4796                cx: &AsyncApp,
4797            ) -> BoxFuture<
4798                'static,
4799                Result<
4800                    BoxStream<
4801                        'static,
4802                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4803                    >,
4804                    LanguageModelCompletionError,
4805                >,
4806            > {
4807                if !*self.failed_once.lock() {
4808                    *self.failed_once.lock() = true;
4809                    let provider = self.provider_name();
4810                    // Return error on first attempt
4811                    let stream = futures::stream::once(async move {
4812                        Err(LanguageModelCompletionError::ServerOverloaded {
4813                            provider,
4814                            retry_after: None,
4815                        })
4816                    });
4817                    async move { Ok(stream.boxed()) }.boxed()
4818                } else {
4819                    // Succeed on retry
4820                    self.inner.stream_completion(request, cx)
4821                }
4822            }
4823        }
4824
4825        let fail_once_model = Arc::new(FailOnceModel {
4826            inner: Arc::new(FakeLanguageModel::default()),
4827            failed_once: Arc::new(Mutex::new(false)),
4828        });
4829
4830        // Insert a user message
4831        thread.update(cx, |thread, cx| {
4832            thread.insert_user_message(
4833                "Test message",
4834                ContextLoadResult::default(),
4835                None,
4836                vec![],
4837                cx,
4838            );
4839        });
4840
4841        // Start completion with fail-once model
4842        thread.update(cx, |thread, cx| {
4843            thread.send_to_model(
4844                fail_once_model.clone(),
4845                CompletionIntent::UserPrompt,
4846                None,
4847                cx,
4848            );
4849        });
4850
4851        cx.run_until_parked();
4852
4853        // Verify retry state exists after first failure
4854        thread.read_with(cx, |thread, _| {
4855            assert!(
4856                thread.retry_state.is_some(),
4857                "Should have retry state after failure"
4858            );
4859        });
4860
4861        // Wait for retry delay
4862        cx.executor()
4863            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4864        cx.run_until_parked();
4865
4866        // The retry should now use our FailOnceModel which should succeed
4867        // We need to help the FakeLanguageModel complete the stream
4868        let inner_fake = fail_once_model.inner.clone();
4869
4870        // Wait a bit for the retry to start
4871        cx.run_until_parked();
4872
4873        // Check for pending completions and complete them
4874        if let Some(pending) = inner_fake.pending_completions().first() {
4875            inner_fake.stream_completion_response(pending, "Success!");
4876            inner_fake.end_completion_stream(pending);
4877        }
4878        cx.run_until_parked();
4879
4880        thread.read_with(cx, |thread, _| {
4881            assert!(
4882                thread.retry_state.is_none(),
4883                "Retry state should be cleared after successful completion"
4884            );
4885
4886            let has_assistant_message = thread
4887                .messages
4888                .iter()
4889                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4890            assert!(
4891                has_assistant_message,
4892                "Should have an assistant message after successful retry"
4893            );
4894        });
4895    }
4896
4897    #[gpui::test]
4898    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4899        init_test_settings(cx);
4900
4901        let project = create_test_project(cx, json!({})).await;
4902        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4903
4904        // Create a model that returns rate limit error with retry_after
4905        struct RateLimitModel {
4906            inner: Arc<FakeLanguageModel>,
4907        }
4908
4909        impl LanguageModel for RateLimitModel {
4910            fn id(&self) -> LanguageModelId {
4911                self.inner.id()
4912            }
4913
4914            fn name(&self) -> LanguageModelName {
4915                self.inner.name()
4916            }
4917
4918            fn provider_id(&self) -> LanguageModelProviderId {
4919                self.inner.provider_id()
4920            }
4921
4922            fn provider_name(&self) -> LanguageModelProviderName {
4923                self.inner.provider_name()
4924            }
4925
4926            fn supports_tools(&self) -> bool {
4927                self.inner.supports_tools()
4928            }
4929
4930            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4931                self.inner.supports_tool_choice(choice)
4932            }
4933
4934            fn supports_images(&self) -> bool {
4935                self.inner.supports_images()
4936            }
4937
4938            fn telemetry_id(&self) -> String {
4939                self.inner.telemetry_id()
4940            }
4941
4942            fn max_token_count(&self) -> u64 {
4943                self.inner.max_token_count()
4944            }
4945
4946            fn count_tokens(
4947                &self,
4948                request: LanguageModelRequest,
4949                cx: &App,
4950            ) -> BoxFuture<'static, Result<u64>> {
4951                self.inner.count_tokens(request, cx)
4952            }
4953
4954            fn stream_completion(
4955                &self,
4956                _request: LanguageModelRequest,
4957                _cx: &AsyncApp,
4958            ) -> BoxFuture<
4959                'static,
4960                Result<
4961                    BoxStream<
4962                        'static,
4963                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4964                    >,
4965                    LanguageModelCompletionError,
4966                >,
4967            > {
4968                let provider = self.provider_name();
4969                async move {
4970                    let stream = futures::stream::once(async move {
4971                        Err(LanguageModelCompletionError::RateLimitExceeded {
4972                            provider,
4973                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4974                        })
4975                    });
4976                    Ok(stream.boxed())
4977                }
4978                .boxed()
4979            }
4980
4981            fn as_fake(&self) -> &FakeLanguageModel {
4982                &self.inner
4983            }
4984        }
4985
4986        let model = Arc::new(RateLimitModel {
4987            inner: Arc::new(FakeLanguageModel::default()),
4988        });
4989
4990        // Insert a user message
4991        thread.update(cx, |thread, cx| {
4992            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4993        });
4994
4995        // Start completion
4996        thread.update(cx, |thread, cx| {
4997            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4998        });
4999
5000        cx.run_until_parked();
5001
5002        let retry_count = thread.update(cx, |thread, _| {
5003            thread
5004                .messages
5005                .iter()
5006                .filter(|m| {
5007                    m.ui_only
5008                        && m.segments.iter().any(|s| {
5009                            if let MessageSegment::Text(text) = s {
5010                                text.contains("rate limit exceeded")
5011                            } else {
5012                                false
5013                            }
5014                        })
5015                })
5016                .count()
5017        });
5018        assert_eq!(retry_count, 1, "Should have scheduled one retry");
5019
5020        thread.read_with(cx, |thread, _| {
5021            assert!(
5022                thread.retry_state.is_none(),
5023                "Rate limit errors should not set retry_state"
5024            );
5025        });
5026
5027        // Verify we have one retry message
5028        thread.read_with(cx, |thread, _| {
5029            let retry_messages = thread
5030                .messages
5031                .iter()
5032                .filter(|msg| {
5033                    msg.ui_only
5034                        && msg.segments.iter().any(|seg| {
5035                            if let MessageSegment::Text(text) = seg {
5036                                text.contains("rate limit exceeded")
5037                            } else {
5038                                false
5039                            }
5040                        })
5041                })
5042                .count();
5043            assert_eq!(
5044                retry_messages, 1,
5045                "Should have one rate limit retry message"
5046            );
5047        });
5048
5049        // Check that retry message doesn't include attempt count
5050        thread.read_with(cx, |thread, _| {
5051            let retry_message = thread
5052                .messages
5053                .iter()
5054                .find(|msg| msg.role == Role::System && msg.ui_only)
5055                .expect("Should have a retry message");
5056
5057            // Check that the message doesn't contain attempt count
5058            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5059                assert!(
5060                    !text.contains("attempt"),
5061                    "Rate limit retry message should not contain attempt count"
5062                );
5063                assert!(
5064                    text.contains(&format!(
5065                        "Retrying in {} seconds",
5066                        TEST_RATE_LIMIT_RETRY_SECS
5067                    )),
5068                    "Rate limit retry message should contain retry delay"
5069                );
5070            }
5071        });
5072    }
5073
5074    #[gpui::test]
5075    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5076        init_test_settings(cx);
5077
5078        let project = create_test_project(cx, json!({})).await;
5079        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5080
5081        // Insert a regular user message
5082        thread.update(cx, |thread, cx| {
5083            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5084        });
5085
5086        // Insert a UI-only message (like our retry notifications)
5087        thread.update(cx, |thread, cx| {
5088            let id = thread.next_message_id.post_inc();
5089            thread.messages.push(Message {
5090                id,
5091                role: Role::System,
5092                segments: vec![MessageSegment::Text(
5093                    "This is a UI-only message that should not be sent to the model".to_string(),
5094                )],
5095                loaded_context: LoadedContext::default(),
5096                creases: Vec::new(),
5097                is_hidden: true,
5098                ui_only: true,
5099            });
5100            cx.emit(ThreadEvent::MessageAdded(id));
5101        });
5102
5103        // Insert another regular message
5104        thread.update(cx, |thread, cx| {
5105            thread.insert_user_message(
5106                "How are you?",
5107                ContextLoadResult::default(),
5108                None,
5109                vec![],
5110                cx,
5111            );
5112        });
5113
5114        // Generate the completion request
5115        let request = thread.update(cx, |thread, cx| {
5116            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5117        });
5118
5119        // Verify that the request only contains non-UI-only messages
5120        // Should have system prompt + 2 user messages, but not the UI-only message
5121        let user_messages: Vec<_> = request
5122            .messages
5123            .iter()
5124            .filter(|msg| msg.role == Role::User)
5125            .collect();
5126        assert_eq!(
5127            user_messages.len(),
5128            2,
5129            "Should have exactly 2 user messages"
5130        );
5131
5132        // Verify the UI-only content is not present anywhere in the request
5133        let request_text = request
5134            .messages
5135            .iter()
5136            .flat_map(|msg| &msg.content)
5137            .filter_map(|content| match content {
5138                MessageContent::Text(text) => Some(text.as_str()),
5139                _ => None,
5140            })
5141            .collect::<String>();
5142
5143        assert!(
5144            !request_text.contains("UI-only message"),
5145            "UI-only message content should not be in the request"
5146        );
5147
5148        // Verify the thread still has all 3 messages (including UI-only)
5149        thread.read_with(cx, |thread, _| {
5150            assert_eq!(
5151                thread.messages().count(),
5152                3,
5153                "Thread should have 3 messages"
5154            );
5155            assert_eq!(
5156                thread.messages().filter(|m| m.ui_only).count(),
5157                1,
5158                "Thread should have 1 UI-only message"
5159            );
5160        });
5161
5162        // Verify that UI-only messages are not serialized
5163        let serialized = thread
5164            .update(cx, |thread, cx| thread.serialize(cx))
5165            .await
5166            .unwrap();
5167        assert_eq!(
5168            serialized.messages.len(),
5169            2,
5170            "Serialized thread should only have 2 messages (no UI-only)"
5171        );
5172    }
5173
5174    #[gpui::test]
5175    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5176        init_test_settings(cx);
5177
5178        let project = create_test_project(cx, json!({})).await;
5179        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5180
5181        // Create model that returns overloaded error
5182        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5183
5184        // Insert a user message
5185        thread.update(cx, |thread, cx| {
5186            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5187        });
5188
5189        // Start completion
5190        thread.update(cx, |thread, cx| {
5191            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5192        });
5193
5194        cx.run_until_parked();
5195
5196        // Verify retry was scheduled by checking for retry message
5197        let has_retry_message = thread.read_with(cx, |thread, _| {
5198            thread.messages.iter().any(|m| {
5199                m.ui_only
5200                    && m.segments.iter().any(|s| {
5201                        if let MessageSegment::Text(text) = s {
5202                            text.contains("Retrying") && text.contains("seconds")
5203                        } else {
5204                            false
5205                        }
5206                    })
5207            })
5208        });
5209        assert!(has_retry_message, "Should have scheduled a retry");
5210
5211        // Cancel the completion before the retry happens
5212        thread.update(cx, |thread, cx| {
5213            thread.cancel_last_completion(None, cx);
5214        });
5215
5216        cx.run_until_parked();
5217
5218        // The retry should not have happened - no pending completions
5219        let fake_model = model.as_fake();
5220        assert_eq!(
5221            fake_model.pending_completions().len(),
5222            0,
5223            "Should have no pending completions after cancellation"
5224        );
5225
5226        // Verify the retry was cancelled by checking retry state
5227        thread.read_with(cx, |thread, _| {
5228            if let Some(retry_state) = &thread.retry_state {
5229                panic!(
5230                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5231                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
5232                );
5233            }
5234        });
5235    }
5236
    // Drives a thread through a normal completion whose follow-up
    // summarization request ends without producing any text, and asserts the
    // summary state transitions Generating -> Error while the user-visible
    // summary falls back to the default.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        // Insert a message and request a completion with the
        // ThreadSummarization intent so a summary is generated afterwards.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        // Complete the assistant turn; this kicks off the summary request on
        // the same fake model.
        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // While the summary request is still streaming, the summary reports
        // Generating and or_default() yields the placeholder text.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        // (the stream is closed without any text having been produced)
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5271
5272    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5273        cx.run_until_parked();
5274        fake_model.stream_last_completion_response("Assistant response");
5275        fake_model.end_last_completion_stream();
5276        cx.run_until_parked();
5277    }
5278
    // Registers every global the thread tests depend on (settings store,
    // language/project settings, agent settings, prompt and thread stores,
    // theme, tool registries). Runs inside cx.update so the globals are
    // installed on the app context before any test entities are created.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must be in place first — the init/register
            // calls below read from it.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // Tools that perform HTTP get a fake client returning 200s so no
            // test touches the network.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5302
5303    // Helper to create a test project with test files
5304    async fn create_test_project(
5305        cx: &mut TestAppContext,
5306        files: serde_json::Value,
5307    ) -> Entity<Project> {
5308        let fs = FakeFs::new(cx.executor());
5309        fs.insert_tree(path!("/test"), files).await;
5310        Project::test(fs, [path!("/test").as_ref()], cx).await
5311    }
5312
    // Builds the full test fixture for thread tests: a workspace window, a
    // loaded ThreadStore, a fresh Thread, a ContextStore, and a fake language
    // model that is registered as both the default and the thread-summary
    // model in the global LanguageModelRegistry.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // Load the thread store with an empty tool working set and a default
        // prompt builder; `load` is async, so await before creating threads.
        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model globally so code paths that look up the
        // default or summary model resolve to it.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
5367
5368    async fn add_file_to_context(
5369        project: &Entity<Project>,
5370        context_store: &Entity<ContextStore>,
5371        path: &str,
5372        cx: &mut TestAppContext,
5373    ) -> Result<Entity<language::Buffer>> {
5374        let buffer_path = project
5375            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5376            .unwrap();
5377
5378        let buffer = project
5379            .update(cx, |project, cx| {
5380                project.open_buffer(buffer_path.clone(), cx)
5381            })
5382            .await
5383            .unwrap();
5384
5385        context_store.update(cx, |context_store, cx| {
5386            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5387        });
5388
5389        Ok(buffer)
5390    }
5391}