thread.rs

   1use crate::{
   2    agent_profile::AgentProfile,
   3    context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
   4    thread_store::{
   5        SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
   6        SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
   7        ThreadStore,
   8    },
   9    tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
  10};
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
  12use anyhow::{Result, anyhow};
  13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  14use chrono::{DateTime, Utc};
  15use client::{ModelRequestUsage, RequestUsage};
  16use collections::HashMap;
  17use feature_flags::{self, FeatureFlagAppExt};
  18use futures::{FutureExt, StreamExt as _, future::Shared};
  19use git::repository::DiffType;
  20use gpui::{
  21    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  22    WeakEntity, Window,
  23};
  24use language_model::{
  25    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  26    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  27    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
  28    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
  29    PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
  30};
  31use postage::stream::Stream as _;
  32use project::{
  33    Project,
  34    git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
  35};
  36use prompt_store::{ModelContext, PromptBuilder};
  37use proto::Plan;
  38use schemars::JsonSchema;
  39use serde::{Deserialize, Serialize};
  40use settings::Settings;
  41use std::{
  42    io::Write,
  43    ops::Range,
  44    sync::Arc,
  45    time::{Duration, Instant},
  46};
  47use thiserror::Error;
  48use util::{ResultExt as _, debug_panic, post_inc};
  49use uuid::Uuid;
  50use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  51
/// Maximum number of automatic retry attempts for a failed completion request.
const MAX_RETRY_ATTEMPTS: u8 = 3;
/// Base delay, in seconds, between retries of a failed completion request.
const BASE_RETRY_DELAY_SECS: u64 = 5;
  54
  55#[derive(
  56    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  57)]
  58pub struct ThreadId(Arc<str>);
  59
  60impl ThreadId {
  61    pub fn new() -> Self {
  62        Self(Uuid::new_v4().to_string().into())
  63    }
  64}
  65
  66impl std::fmt::Display for ThreadId {
  67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  68        write!(f, "{}", self.0)
  69    }
  70}
  71
  72impl From<&str> for ThreadId {
  73    fn from(value: &str) -> Self {
  74        Self(value.into())
  75    }
  76}
  77
  78/// The ID of the user prompt that initiated a request.
  79///
  80/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  81#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  82pub struct PromptId(Arc<str>);
  83
  84impl PromptId {
  85    pub fn new() -> Self {
  86        Self(Uuid::new_v4().to_string().into())
  87    }
  88}
  89
  90impl std::fmt::Display for PromptId {
  91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  92        write!(f, "{}", self.0)
  93    }
  94}
  95
  96#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  97pub struct MessageId(pub(crate) usize);
  98
  99impl MessageId {
 100    fn post_inc(&mut self) -> Self {
 101        Self(post_inc(&mut self.0))
 102    }
 103
 104    pub fn as_usize(&self) -> usize {
 105        self.0
 106    }
 107}
 108
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Range covered by the crease within the message text.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content chunks: text, thinking, and redacted thinking.
    pub segments: Vec<MessageSegment>,
    // Context attached to this message; only the flattened `text` survives
    // serialization (see `Thread::deserialize`).
    pub loaded_context: LoadedContext,
    // Creases for restoring context markers when editing this message.
    pub creases: Vec<MessageCrease>,
    // Hidden messages are skipped when detecting turn ends (see `is_turn_end`).
    pub is_hidden: bool,
    // UI-only messages are never persisted (always false after deserialization).
    pub ui_only: bool,
}
 130
 131impl Message {
 132    /// Returns whether the message contains any meaningful text that should be displayed
 133    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 134    pub fn should_display_content(&self) -> bool {
 135        self.segments.iter().all(|segment| segment.should_display())
 136    }
 137
 138    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 139        if let Some(MessageSegment::Thinking {
 140            text: segment,
 141            signature: current_signature,
 142        }) = self.segments.last_mut()
 143        {
 144            if let Some(signature) = signature {
 145                *current_signature = Some(signature);
 146            }
 147            segment.push_str(text);
 148        } else {
 149            self.segments.push(MessageSegment::Thinking {
 150                text: text.to_string(),
 151                signature,
 152            });
 153        }
 154    }
 155
 156    pub fn push_redacted_thinking(&mut self, data: String) {
 157        self.segments.push(MessageSegment::RedactedThinking(data));
 158    }
 159
 160    pub fn push_text(&mut self, text: &str) {
 161        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 162            segment.push_str(text);
 163        } else {
 164            self.segments.push(MessageSegment::Text(text.to_string()));
 165        }
 166    }
 167
 168    pub fn to_string(&self) -> String {
 169        let mut result = String::new();
 170
 171        if !self.loaded_context.text.is_empty() {
 172            result.push_str(&self.loaded_context.text);
 173        }
 174
 175        for segment in &self.segments {
 176            match segment {
 177                MessageSegment::Text(text) => result.push_str(text),
 178                MessageSegment::Thinking { text, .. } => {
 179                    result.push_str("<think>\n");
 180                    result.push_str(text);
 181                    result.push_str("\n</think>");
 182                }
 183                MessageSegment::RedactedThinking(_) => {}
 184            }
 185        }
 186
 187        result
 188    }
 189}
 190
 191#[derive(Debug, Clone, PartialEq, Eq)]
 192pub enum MessageSegment {
 193    Text(String),
 194    Thinking {
 195        text: String,
 196        signature: Option<String>,
 197    },
 198    RedactedThinking(String),
 199}
 200
 201impl MessageSegment {
 202    pub fn should_display(&self) -> bool {
 203        match self {
 204            Self::Text(text) => text.is_empty(),
 205            Self::Thinking { text, .. } => text.is_empty(),
 206            Self::RedactedThinking(_) => false,
 207        }
 208    }
 209
 210    pub fn text(&self) -> Option<&str> {
 211        match self {
 212            MessageSegment::Text(text) => Some(text),
 213            _ => None,
 214        }
 215    }
 216}
 217
/// A snapshot of the project's worktrees and git state, captured alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git repository state captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 238
/// Associates a message with a git checkpoint so the project can be restored
/// to the state it was in when that message was sent.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or message (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 250
 251pub enum LastRestoreCheckpoint {
 252    Pending {
 253        message_id: MessageId,
 254    },
 255    Error {
 256        message_id: MessageId,
 257        error: String,
 258    },
 259}
 260
 261impl LastRestoreCheckpoint {
 262    pub fn message_id(&self) -> MessageId {
 263        match self {
 264            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 265            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 266        }
 267    }
 268}
 269
 270#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 271pub enum DetailedSummaryState {
 272    #[default]
 273    NotGenerated,
 274    Generating {
 275        message_id: MessageId,
 276    },
 277    Generated {
 278        text: SharedString,
 279        message_id: MessageId,
 280    },
 281}
 282
 283impl DetailedSummaryState {
 284    fn text(&self) -> Option<SharedString> {
 285        if let Self::Generated { text, .. } = self {
 286            Some(text.clone())
 287        } else {
 288            None
 289        }
 290    }
 291}
 292
 293#[derive(Default, Debug)]
 294pub struct TotalTokenUsage {
 295    pub total: u64,
 296    pub max: u64,
 297}
 298
 299impl TotalTokenUsage {
 300    pub fn ratio(&self) -> TokenUsageRatio {
 301        #[cfg(debug_assertions)]
 302        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 303            .unwrap_or("0.8".to_string())
 304            .parse()
 305            .unwrap();
 306        #[cfg(not(debug_assertions))]
 307        let warning_threshold: f32 = 0.8;
 308
 309        // When the maximum is unknown because there is no selected model,
 310        // avoid showing the token limit warning.
 311        if self.max == 0 {
 312            TokenUsageRatio::Normal
 313        } else if self.total >= self.max {
 314            TokenUsageRatio::Exceeded
 315        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 316            TokenUsageRatio::Warning
 317        } else {
 318            TokenUsageRatio::Normal
 319        }
 320    }
 321
 322    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
 323        TotalTokenUsage {
 324            total: self.total + tokens,
 325            max: self.max,
 326        }
 327    }
 328}
 329
 330#[derive(Debug, Default, PartialEq, Eq)]
 331pub enum TokenUsageRatio {
 332    #[default]
 333    Normal,
 334    Warning,
 335    Exceeded,
 336}
 337
/// Where a pending completion currently sits relative to the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
 344
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Last time the thread was modified (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // In-flight task generating the short summary, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting progress of detailed-summary generation.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept in ascending-id order; `message()` relies on this for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    // Refreshed by `advance_prompt_id`.
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message (see `insert_checkpoint`).
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // Outcome of the most recent checkpoint restore; cleared on success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint awaiting finalization once generation ends
    // (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Snapshot of the project captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    // Set when the last request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    // Present while automatic retries of a failed completion are in progress.
    retry_state: Option<RetryState>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // When the last streamed chunk arrived; read by `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Hook invoked with each request and the events streamed back for it.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining agentic turns; `u32::MAX` (effectively unlimited) by default.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}

/// Tracks automatic retrying of a failed completion request.
#[derive(Clone, Debug)]
struct RetryState {
    attempt: u8,
    max_attempts: u8,
    intent: CompletionIntent,
}
 393
 394#[derive(Clone, Debug, PartialEq, Eq)]
 395pub enum ThreadSummary {
 396    Pending,
 397    Generating,
 398    Ready(SharedString),
 399    Error,
 400}
 401
 402impl ThreadSummary {
 403    pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
 404
 405    pub fn or_default(&self) -> SharedString {
 406        self.unwrap_or(Self::DEFAULT)
 407    }
 408
 409    pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
 410        self.ready().unwrap_or_else(|| message.into())
 411    }
 412
 413    pub fn ready(&self) -> Option<SharedString> {
 414        match self {
 415            ThreadSummary::Ready(summary) => Some(summary.clone()),
 416            ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
 417        }
 418    }
 419}
 420
/// Records that a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
 428
 429impl Thread {
    /// Creates a new, empty thread with a fresh id, using the globally
    /// configured default model and default agent profile.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state in the background; `shared()`
                // lets multiple consumers await the same snapshot.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            retry_state: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 485
    /// Reconstructs a thread from its serialized form.
    ///
    /// `window` is `None` in headless mode. State that is not persisted
    /// (pending tasks, checkpoints, feedback, callbacks) is reset to its
    /// defaults.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue assigning message ids after the highest persisted id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was using when saved; fall back to the
        // global default if it can no longer be selected.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            retry_state: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; structured
                    // context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                    ui_only: false, // UI-only messages are not persisted
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
 609
    /// Installs a callback invoked with each completion request and the
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 617
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches to the profile `id`, emitting `ThreadEvent::ProfileChanged`
    /// only when it differs from the current profile.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the global
    /// default model when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
 672
    /// Sets the thread summary, normalizing empty input to the default text
    /// and emitting `ThreadEvent::SummaryChanged` only on an actual change.
    /// Ignored while a summary is pending or being generated.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // A failed generation may be overwritten; compare against the
            // default so re-setting the default text emits no event.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
 691
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id via binary search; relies on `messages`
    /// being sorted by ascending id (ids are assigned monotonically through
    /// `next_message_id`).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
 712
    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): the implementation only checks `last_received_chunk_at`;
    /// the doc above holds only if that field is cleared when generation
    /// ends — confirm that invariant at the call sites.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // More than 250ms since the last chunk counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 761
    /// Restores the project to `checkpoint`'s git state, truncating the
    /// thread back to the checkpoint's message on success. Progress and
    /// failure are surfaced through `last_restore_checkpoint` and
    /// `ThreadEvent::CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-progress so observers can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be shown for this message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 797
 798    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 799        let pending_checkpoint = if self.is_generating() {
 800            return;
 801        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 802            checkpoint
 803        } else {
 804            return;
 805        };
 806
 807        self.finalize_checkpoint(pending_checkpoint, cx);
 808    }
 809
    /// Compares `pending_checkpoint` against the current git state and keeps
    /// it only when the project has changed since it was taken (or when the
    /// comparison itself fails, erring on the side of keeping it).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated as "not equal" so the
                // checkpoint is preserved.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Identical checkpoints add no value; only record a change.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 844
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint restore, if it is still pending
    /// or has failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 855
 856    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 857        let Some(message_ix) = self
 858            .messages
 859            .iter()
 860            .rposition(|message| message.id == message_id)
 861        else {
 862            return;
 863        };
 864        for deleted_message in self.messages.drain(message_ix..) {
 865            self.checkpoints_by_message.remove(&deleted_message.id);
 866        }
 867        cx.notify();
 868    }
 869
    /// Iterates the structured context attached to the message with `id`
    /// (empty when the message is not found).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }

    /// Returns whether the message at index `ix` ends a turn: either it is
    /// the last message while nothing is generating, or it is an assistant
    /// message followed by a visible (non-hidden) user message.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        // NOTE(review): `self.message(...)` re-resolves the message already
        // obtained from `get(ix + 1)`; the binary search should find the same
        // entry since ids are unique — confirm before simplifying.
        self.messages
            .get(ix + 1)
            .and_then(|message| {
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            })
            .unwrap_or(false)
    }
 903
    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 907
 908    /// Returns whether all of the tool uses have finished running.
 909    pub fn all_tools_finished(&self) -> bool {
 910        // If the only pending tool uses left are the ones with errors, then
 911        // that means that we've finished running all of the pending tools.
 912        self.tool_use
 913            .pending_tool_uses()
 914            .iter()
 915            .all(|pending_tool_use| pending_tool_use.status.is_error())
 916    }
 917
 918    /// Returns whether any pending tool uses may perform edits
 919    pub fn has_pending_edit_tool_uses(&self) -> bool {
 920        self.tool_use
 921            .pending_tool_uses()
 922            .iter()
 923            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
 924            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
 925    }
 926
    /// All tool uses recorded against the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 930
    /// Tool results recorded against the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 937
    /// The result of the tool use with `id`, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 941
 942    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 943        match &self.tool_use.tool_result(id)?.content {
 944            LanguageModelToolResultContent::Text(text) => Some(text),
 945            LanguageModelToolResultContent::Image(_) => {
 946                // TODO: We should display image
 947                None
 948            }
 949        }
 950    }
 951
    /// A clone of the UI card associated with the given tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 955
 956    /// Return tools that are both enabled and supported by the model
 957    pub fn available_tools(
 958        &self,
 959        cx: &App,
 960        model: Arc<dyn LanguageModel>,
 961    ) -> Vec<LanguageModelRequestTool> {
 962        if model.supports_tools() {
 963            self.profile
 964                .enabled_tools(cx)
 965                .into_iter()
 966                .filter_map(|(name, tool)| {
 967                    // Skip tools that cannot be supported
 968                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 969                    Some(LanguageModelRequestTool {
 970                        name: name.into(),
 971                        description: tool.description(),
 972                        input_schema,
 973                    })
 974                })
 975                .collect()
 976        } else {
 977            Vec::default()
 978        }
 979    }
 980
 981    pub fn insert_user_message(
 982        &mut self,
 983        text: impl Into<String>,
 984        loaded_context: ContextLoadResult,
 985        git_checkpoint: Option<GitStoreCheckpoint>,
 986        creases: Vec<MessageCrease>,
 987        cx: &mut Context<Self>,
 988    ) -> MessageId {
 989        if !loaded_context.referenced_buffers.is_empty() {
 990            self.action_log.update(cx, |log, cx| {
 991                for buffer in loaded_context.referenced_buffers {
 992                    log.buffer_read(buffer, cx);
 993                }
 994            });
 995        }
 996
 997        let message_id = self.insert_message(
 998            Role::User,
 999            vec![MessageSegment::Text(text.into())],
1000            loaded_context.loaded_context,
1001            creases,
1002            false,
1003            cx,
1004        );
1005
1006        if let Some(git_checkpoint) = git_checkpoint {
1007            self.pending_checkpoint = Some(ThreadCheckpoint {
1008                message_id,
1009                git_checkpoint,
1010            });
1011        }
1012
1013        self.auto_capture_telemetry(cx);
1014
1015        message_id
1016    }
1017
1018    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
1019        let id = self.insert_message(
1020            Role::User,
1021            vec![MessageSegment::Text("Continue where you left off".into())],
1022            LoadedContext::default(),
1023            vec![],
1024            true,
1025            cx,
1026        );
1027        self.pending_checkpoint = None;
1028
1029        id
1030    }
1031
1032    pub fn insert_assistant_message(
1033        &mut self,
1034        segments: Vec<MessageSegment>,
1035        cx: &mut Context<Self>,
1036    ) -> MessageId {
1037        self.insert_message(
1038            Role::Assistant,
1039            segments,
1040            LoadedContext::default(),
1041            Vec::new(),
1042            false,
1043            cx,
1044        )
1045    }
1046
1047    pub fn insert_message(
1048        &mut self,
1049        role: Role,
1050        segments: Vec<MessageSegment>,
1051        loaded_context: LoadedContext,
1052        creases: Vec<MessageCrease>,
1053        is_hidden: bool,
1054        cx: &mut Context<Self>,
1055    ) -> MessageId {
1056        let id = self.next_message_id.post_inc();
1057        self.messages.push(Message {
1058            id,
1059            role,
1060            segments,
1061            loaded_context,
1062            creases,
1063            is_hidden,
1064            ui_only: false,
1065        });
1066        self.touch_updated_at();
1067        cx.emit(ThreadEvent::MessageAdded(id));
1068        id
1069    }
1070
1071    pub fn edit_message(
1072        &mut self,
1073        id: MessageId,
1074        new_role: Role,
1075        new_segments: Vec<MessageSegment>,
1076        creases: Vec<MessageCrease>,
1077        loaded_context: Option<LoadedContext>,
1078        checkpoint: Option<GitStoreCheckpoint>,
1079        cx: &mut Context<Self>,
1080    ) -> bool {
1081        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1082            return false;
1083        };
1084        message.role = new_role;
1085        message.segments = new_segments;
1086        message.creases = creases;
1087        if let Some(context) = loaded_context {
1088            message.loaded_context = context;
1089        }
1090        if let Some(git_checkpoint) = checkpoint {
1091            self.checkpoints_by_message.insert(
1092                id,
1093                ThreadCheckpoint {
1094                    message_id: id,
1095                    git_checkpoint,
1096                },
1097            );
1098        }
1099        self.touch_updated_at();
1100        cx.emit(ThreadEvent::MessageEdited(id));
1101        true
1102    }
1103
1104    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1105        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1106            return false;
1107        };
1108        self.messages.remove(index);
1109        self.touch_updated_at();
1110        cx.emit(ThreadEvent::MessageDeleted(id));
1111        true
1112    }
1113
1114    /// Returns the representation of this [`Thread`] in a textual form.
1115    ///
1116    /// This is the representation we use when attaching a thread as context to another thread.
1117    pub fn text(&self) -> String {
1118        let mut text = String::new();
1119
1120        for message in &self.messages {
1121            text.push_str(match message.role {
1122                language_model::Role::User => "User:",
1123                language_model::Role::Assistant => "Agent:",
1124                language_model::Role::System => "System:",
1125            });
1126            text.push('\n');
1127
1128            for segment in &message.segments {
1129                match segment {
1130                    MessageSegment::Text(content) => text.push_str(content),
1131                    MessageSegment::Thinking { text: content, .. } => {
1132                        text.push_str(&format!("<think>{}</think>", content))
1133                    }
1134                    MessageSegment::RedactedThinking(_) => {}
1135                }
1136            }
1137            text.push('\n');
1138        }
1139
1140        text
1141    }
1142
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot future before reading the
    /// thread's state, then maps every non-UI-only message (with its tool
    /// uses, tool results, context text, and creases) into the serialized
    /// representation.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // ui_only messages are presentation-only and are not persisted.
                messages: this
                    .messages()
                    .filter(|message| !message.ui_only)
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        // Creases are stored as plain offset ranges plus their
                        // display metadata.
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Provider and model ids are persisted so the same model can
                // be re-selected on deserialization.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1229
    /// Number of turns the thread may still take before `send_to_model`
    /// becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1233
    /// Sets the number of turns the thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1237
1238    pub fn send_to_model(
1239        &mut self,
1240        model: Arc<dyn LanguageModel>,
1241        intent: CompletionIntent,
1242        window: Option<AnyWindowHandle>,
1243        cx: &mut Context<Self>,
1244    ) {
1245        if self.remaining_turns == 0 {
1246            return;
1247        }
1248
1249        self.remaining_turns -= 1;
1250
1251        self.flush_notifications(model.clone(), intent, cx);
1252
1253        let request = self.to_completion_request(model.clone(), intent, cx);
1254
1255        self.stream_completion(request, model, intent, window, cx);
1256    }
1257
1258    pub fn used_tools_since_last_user_message(&self) -> bool {
1259        for message in self.messages.iter().rev() {
1260            if self.tool_use.message_has_tool_results(message.id) {
1261                return true;
1262            } else if message.role == Role::User {
1263                return false;
1264            }
1265        }
1266
1267        false
1268    }
1269
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// every non-UI-only message (context, text/thinking segments, paired
    /// tool uses and results), the available tools, and the completion mode.
    ///
    /// Emits `ThreadEvent::ShowError` (rather than failing) when the system
    /// prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            // mode and tools are filled in at the end of this function.
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // Tool names are passed to the prompt builder so the system prompt
        // can reference what's available.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the last request-message index eligible for prompt caching;
        // only the final such index is marked (see the Anthropic link below).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            // ui_only messages are for the UI only, not for the model
            if message.ui_only {
                continue;
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Trailing whitespace is stripped; fully-empty text
                        // segments are dropped.
                        let text = text.trim_end();
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // A message with any still-pending tool use must not be cached.
            let mut cache_message = true;
            // Tool results go in a separate user-role message that follows
            // the assistant message carrying the tool uses.
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn mode only applies when the model supports it; otherwise force
        // normal mode.
        request.mode = if model.supports_burn_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1433
1434    fn to_summarize_request(
1435        &self,
1436        model: &Arc<dyn LanguageModel>,
1437        intent: CompletionIntent,
1438        added_user_message: String,
1439        cx: &App,
1440    ) -> LanguageModelRequest {
1441        let mut request = LanguageModelRequest {
1442            thread_id: None,
1443            prompt_id: None,
1444            intent: Some(intent),
1445            mode: None,
1446            messages: vec![],
1447            tools: Vec::new(),
1448            tool_choice: None,
1449            stop: Vec::new(),
1450            temperature: AgentSettings::temperature_for_model(model, cx),
1451        };
1452
1453        for message in &self.messages {
1454            let mut request_message = LanguageModelRequestMessage {
1455                role: message.role,
1456                content: Vec::new(),
1457                cache: false,
1458            };
1459
1460            for segment in &message.segments {
1461                match segment {
1462                    MessageSegment::Text(text) => request_message
1463                        .content
1464                        .push(MessageContent::Text(text.clone())),
1465                    MessageSegment::Thinking { .. } => {}
1466                    MessageSegment::RedactedThinking(_) => {}
1467                }
1468            }
1469
1470            if request_message.content.is_empty() {
1471                continue;
1472            }
1473
1474            request.messages.push(request_message);
1475        }
1476
1477        request.messages.push(LanguageModelRequestMessage {
1478            role: Role::User,
1479            content: vec![MessageContent::Text(added_user_message)],
1480            cache: false,
1481        });
1482
1483        request
1484    }
1485
1486    /// Insert auto-generated notifications (if any) to the thread
1487    fn flush_notifications(
1488        &mut self,
1489        model: Arc<dyn LanguageModel>,
1490        intent: CompletionIntent,
1491        cx: &mut Context<Self>,
1492    ) {
1493        match intent {
1494            CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
1495                if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
1496                    cx.emit(ThreadEvent::ToolFinished {
1497                        tool_use_id: pending_tool_use.id.clone(),
1498                        pending_tool_use: Some(pending_tool_use),
1499                    });
1500                }
1501            }
1502            CompletionIntent::ThreadSummarization
1503            | CompletionIntent::ThreadContextSummarization
1504            | CompletionIntent::CreateFile
1505            | CompletionIntent::EditFile
1506            | CompletionIntent::InlineAssist
1507            | CompletionIntent::TerminalInlineAssist
1508            | CompletionIntent::GenerateGitCommitMessage => {}
1509        };
1510    }
1511
1512    fn attach_tracked_files_state(
1513        &mut self,
1514        model: Arc<dyn LanguageModel>,
1515        cx: &mut App,
1516    ) -> Option<PendingToolUse> {
1517        let action_log = self.action_log.read(cx);
1518
1519        action_log.stale_buffers(cx).next()?;
1520
1521        // Represent notification as a simulated `project_notifications` tool call
1522        let tool_name = Arc::from("project_notifications");
1523        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
1524            debug_panic!("`project_notifications` tool not found");
1525            return None;
1526        };
1527
1528        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
1529            return None;
1530        }
1531
1532        let input = serde_json::json!({});
1533        let request = Arc::new(LanguageModelRequest::default()); // unused
1534        let window = None;
1535        let tool_result = tool.run(
1536            input,
1537            request,
1538            self.project.clone(),
1539            self.action_log.clone(),
1540            model.clone(),
1541            window,
1542            cx,
1543        );
1544
1545        let tool_use_id =
1546            LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
1547
1548        let tool_use = LanguageModelToolUse {
1549            id: tool_use_id.clone(),
1550            name: tool_name.clone(),
1551            raw_input: "{}".to_string(),
1552            input: serde_json::json!({}),
1553            is_input_complete: true,
1554        };
1555
1556        let tool_output = cx.background_executor().block(tool_result.output);
1557
1558        // Attach a project_notification tool call to the latest existing
1559        // Assistant message. We cannot create a new Assistant message
1560        // because thinking models require a `thinking` block that we
1561        // cannot mock. We cannot send a notification as a normal
1562        // (non-tool-use) User message because this distracts Agent
1563        // too much.
1564        let tool_message_id = self
1565            .messages
1566            .iter()
1567            .enumerate()
1568            .rfind(|(_, message)| message.role == Role::Assistant)
1569            .map(|(_, message)| message.id)?;
1570
1571        let tool_use_metadata = ToolUseMetadata {
1572            model: model.clone(),
1573            thread_id: self.id.clone(),
1574            prompt_id: self.last_prompt_id.clone(),
1575        };
1576
1577        self.tool_use
1578            .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
1579
1580        let pending_tool_use = self.tool_use.insert_tool_output(
1581            tool_use_id.clone(),
1582            tool_name,
1583            tool_output,
1584            self.configured_model.as_ref(),
1585        );
1586
1587        pending_tool_use
1588    }
1589
1590    pub fn stream_completion(
1591        &mut self,
1592        request: LanguageModelRequest,
1593        model: Arc<dyn LanguageModel>,
1594        intent: CompletionIntent,
1595        window: Option<AnyWindowHandle>,
1596        cx: &mut Context<Self>,
1597    ) {
1598        self.tool_use_limit_reached = false;
1599
1600        let pending_completion_id = post_inc(&mut self.completion_count);
1601        let mut request_callback_parameters = if self.request_callback.is_some() {
1602            Some((request.clone(), Vec::new()))
1603        } else {
1604            None
1605        };
1606        let prompt_id = self.last_prompt_id.clone();
1607        let tool_use_metadata = ToolUseMetadata {
1608            model: model.clone(),
1609            thread_id: self.id.clone(),
1610            prompt_id: prompt_id.clone(),
1611        };
1612
1613        self.last_received_chunk_at = Some(Instant::now());
1614
1615        let task = cx.spawn(async move |thread, cx| {
1616            let stream_completion_future = model.stream_completion(request, &cx);
1617            let initial_token_usage =
1618                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1619            let stream_completion = async {
1620                let mut events = stream_completion_future.await?;
1621
1622                let mut stop_reason = StopReason::EndTurn;
1623                let mut current_token_usage = TokenUsage::default();
1624
1625                thread
1626                    .update(cx, |_thread, cx| {
1627                        cx.emit(ThreadEvent::NewRequest);
1628                    })
1629                    .ok();
1630
1631                let mut request_assistant_message_id = None;
1632
1633                while let Some(event) = events.next().await {
1634                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1635                        response_events
1636                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1637                    }
1638
1639                    thread.update(cx, |thread, cx| {
1640                        match event? {
1641                            LanguageModelCompletionEvent::StartMessage { .. } => {
1642                                request_assistant_message_id =
1643                                    Some(thread.insert_assistant_message(
1644                                        vec![MessageSegment::Text(String::new())],
1645                                        cx,
1646                                    ));
1647                            }
1648                            LanguageModelCompletionEvent::Stop(reason) => {
1649                                stop_reason = reason;
1650                            }
1651                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1652                                thread.update_token_usage_at_last_message(token_usage);
1653                                thread.cumulative_token_usage = thread.cumulative_token_usage
1654                                    + token_usage
1655                                    - current_token_usage;
1656                                current_token_usage = token_usage;
1657                            }
1658                            LanguageModelCompletionEvent::Text(chunk) => {
1659                                thread.received_chunk();
1660
1661                                cx.emit(ThreadEvent::ReceivedTextChunk);
1662                                if let Some(last_message) = thread.messages.last_mut() {
1663                                    if last_message.role == Role::Assistant
1664                                        && !thread.tool_use.has_tool_results(last_message.id)
1665                                    {
1666                                        last_message.push_text(&chunk);
1667                                        cx.emit(ThreadEvent::StreamedAssistantText(
1668                                            last_message.id,
1669                                            chunk,
1670                                        ));
1671                                    } else {
1672                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1673                                        // of a new Assistant response.
1674                                        //
1675                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1676                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1677                                        request_assistant_message_id =
1678                                            Some(thread.insert_assistant_message(
1679                                                vec![MessageSegment::Text(chunk.to_string())],
1680                                                cx,
1681                                            ));
1682                                    };
1683                                }
1684                            }
1685                            LanguageModelCompletionEvent::Thinking {
1686                                text: chunk,
1687                                signature,
1688                            } => {
1689                                thread.received_chunk();
1690
1691                                if let Some(last_message) = thread.messages.last_mut() {
1692                                    if last_message.role == Role::Assistant
1693                                        && !thread.tool_use.has_tool_results(last_message.id)
1694                                    {
1695                                        last_message.push_thinking(&chunk, signature);
1696                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1697                                            last_message.id,
1698                                            chunk,
1699                                        ));
1700                                    } else {
1701                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1702                                        // of a new Assistant response.
1703                                        //
1704                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1705                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1706                                        request_assistant_message_id =
1707                                            Some(thread.insert_assistant_message(
1708                                                vec![MessageSegment::Thinking {
1709                                                    text: chunk.to_string(),
1710                                                    signature,
1711                                                }],
1712                                                cx,
1713                                            ));
1714                                    };
1715                                }
1716                            }
1717                            LanguageModelCompletionEvent::RedactedThinking { data } => {
1718                                thread.received_chunk();
1719
1720                                if let Some(last_message) = thread.messages.last_mut() {
1721                                    if last_message.role == Role::Assistant
1722                                        && !thread.tool_use.has_tool_results(last_message.id)
1723                                    {
1724                                        last_message.push_redacted_thinking(data);
1725                                    } else {
1726                                        request_assistant_message_id =
1727                                            Some(thread.insert_assistant_message(
1728                                                vec![MessageSegment::RedactedThinking(data)],
1729                                                cx,
1730                                            ));
1731                                    };
1732                                }
1733                            }
1734                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1735                                let last_assistant_message_id = request_assistant_message_id
1736                                    .unwrap_or_else(|| {
1737                                        let new_assistant_message_id =
1738                                            thread.insert_assistant_message(vec![], cx);
1739                                        request_assistant_message_id =
1740                                            Some(new_assistant_message_id);
1741                                        new_assistant_message_id
1742                                    });
1743
1744                                let tool_use_id = tool_use.id.clone();
1745                                let streamed_input = if tool_use.is_input_complete {
1746                                    None
1747                                } else {
1748                                    Some((&tool_use.input).clone())
1749                                };
1750
1751                                let ui_text = thread.tool_use.request_tool_use(
1752                                    last_assistant_message_id,
1753                                    tool_use,
1754                                    tool_use_metadata.clone(),
1755                                    cx,
1756                                );
1757
1758                                if let Some(input) = streamed_input {
1759                                    cx.emit(ThreadEvent::StreamedToolUse {
1760                                        tool_use_id,
1761                                        ui_text,
1762                                        input,
1763                                    });
1764                                }
1765                            }
1766                            LanguageModelCompletionEvent::ToolUseJsonParseError {
1767                                id,
1768                                tool_name,
1769                                raw_input: invalid_input_json,
1770                                json_parse_error,
1771                            } => {
1772                                thread.receive_invalid_tool_json(
1773                                    id,
1774                                    tool_name,
1775                                    invalid_input_json,
1776                                    json_parse_error,
1777                                    window,
1778                                    cx,
1779                                );
1780                            }
1781                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1782                                if let Some(completion) = thread
1783                                    .pending_completions
1784                                    .iter_mut()
1785                                    .find(|completion| completion.id == pending_completion_id)
1786                                {
1787                                    match status_update {
1788                                        CompletionRequestStatus::Queued { position } => {
1789                                            completion.queue_state =
1790                                                QueueState::Queued { position };
1791                                        }
1792                                        CompletionRequestStatus::Started => {
1793                                            completion.queue_state = QueueState::Started;
1794                                        }
1795                                        CompletionRequestStatus::Failed {
1796                                            code,
1797                                            message,
1798                                            request_id: _,
1799                                            retry_after,
1800                                        } => {
1801                                            return Err(
1802                                                LanguageModelCompletionError::from_cloud_failure(
1803                                                    model.upstream_provider_name(),
1804                                                    code,
1805                                                    message,
1806                                                    retry_after.map(Duration::from_secs_f64),
1807                                                ),
1808                                            );
1809                                        }
1810                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
1811                                            thread.update_model_request_usage(
1812                                                amount as u32,
1813                                                limit,
1814                                                cx,
1815                                            );
1816                                        }
1817                                        CompletionRequestStatus::ToolUseLimitReached => {
1818                                            thread.tool_use_limit_reached = true;
1819                                            cx.emit(ThreadEvent::ToolUseLimitReached);
1820                                        }
1821                                    }
1822                                }
1823                            }
1824                        }
1825
1826                        thread.touch_updated_at();
1827                        cx.emit(ThreadEvent::StreamedCompletion);
1828                        cx.notify();
1829
1830                        thread.auto_capture_telemetry(cx);
1831                        Ok(())
1832                    })??;
1833
1834                    smol::future::yield_now().await;
1835                }
1836
1837                thread.update(cx, |thread, cx| {
1838                    thread.last_received_chunk_at = None;
1839                    thread
1840                        .pending_completions
1841                        .retain(|completion| completion.id != pending_completion_id);
1842
1843                    // If there is a response without tool use, summarize the message. Otherwise,
1844                    // allow two tool uses before summarizing.
1845                    if matches!(thread.summary, ThreadSummary::Pending)
1846                        && thread.messages.len() >= 2
1847                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1848                    {
1849                        thread.summarize(cx);
1850                    }
1851                })?;
1852
1853                anyhow::Ok(stop_reason)
1854            };
1855
1856            let result = stream_completion.await;
1857            let mut retry_scheduled = false;
1858
1859            thread
1860                .update(cx, |thread, cx| {
1861                    thread.finalize_pending_checkpoint(cx);
1862                    match result.as_ref() {
1863                        Ok(stop_reason) => {
1864                            match stop_reason {
1865                                StopReason::ToolUse => {
1866                                    let tool_uses =
1867                                        thread.use_pending_tools(window, model.clone(), cx);
1868                                    cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1869                                }
1870                                StopReason::EndTurn | StopReason::MaxTokens => {
1871                                    thread.project.update(cx, |project, cx| {
1872                                        project.set_agent_location(None, cx);
1873                                    });
1874                                }
1875                                StopReason::Refusal => {
1876                                    thread.project.update(cx, |project, cx| {
1877                                        project.set_agent_location(None, cx);
1878                                    });
1879
1880                                    // Remove the turn that was refused.
1881                                    //
1882                                    // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1883                                    {
1884                                        let mut messages_to_remove = Vec::new();
1885
1886                                        for (ix, message) in
1887                                            thread.messages.iter().enumerate().rev()
1888                                        {
1889                                            messages_to_remove.push(message.id);
1890
1891                                            if message.role == Role::User {
1892                                                if ix == 0 {
1893                                                    break;
1894                                                }
1895
1896                                                if let Some(prev_message) =
1897                                                    thread.messages.get(ix - 1)
1898                                                {
1899                                                    if prev_message.role == Role::Assistant {
1900                                                        break;
1901                                                    }
1902                                                }
1903                                            }
1904                                        }
1905
1906                                        for message_id in messages_to_remove {
1907                                            thread.delete_message(message_id, cx);
1908                                        }
1909                                    }
1910
1911                                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1912                                        header: "Language model refusal".into(),
1913                                        message:
1914                                            "Model refused to generate content for safety reasons."
1915                                                .into(),
1916                                    }));
1917                                }
1918                            }
1919
1920                            // We successfully completed, so cancel any remaining retries.
1921                            thread.retry_state = None;
1922                        }
1923                        Err(error) => {
1924                            thread.project.update(cx, |project, cx| {
1925                                project.set_agent_location(None, cx);
1926                            });
1927
1928                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1929                                let error_message = error
1930                                    .chain()
1931                                    .map(|err| err.to_string())
1932                                    .collect::<Vec<_>>()
1933                                    .join("\n");
1934                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1935                                    header: "Error interacting with language model".into(),
1936                                    message: SharedString::from(error_message.clone()),
1937                                }));
1938                            }
1939
1940                            if error.is::<PaymentRequiredError>() {
1941                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1942                            } else if let Some(error) =
1943                                error.downcast_ref::<ModelRequestLimitReachedError>()
1944                            {
1945                                cx.emit(ThreadEvent::ShowError(
1946                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1947                                ));
1948                            } else if let Some(completion_error) =
1949                                error.downcast_ref::<LanguageModelCompletionError>()
1950                            {
1951                                use LanguageModelCompletionError::*;
1952                                match &completion_error {
1953                                    PromptTooLarge { tokens, .. } => {
1954                                        let tokens = tokens.unwrap_or_else(|| {
1955                                            // We didn't get an exact token count from the API, so fall back on our estimate.
1956                                            thread
1957                                                .total_token_usage()
1958                                                .map(|usage| usage.total)
1959                                                .unwrap_or(0)
1960                                                // We know the context window was exceeded in practice, so if our estimate was
1961                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1962                                                .max(model.max_token_count().saturating_add(1))
1963                                        });
1964                                        thread.exceeded_window_error = Some(ExceededWindowError {
1965                                            model_id: model.id(),
1966                                            token_count: tokens,
1967                                        });
1968                                        cx.notify();
1969                                    }
1970                                    RateLimitExceeded {
1971                                        retry_after: Some(retry_after),
1972                                        ..
1973                                    }
1974                                    | ServerOverloaded {
1975                                        retry_after: Some(retry_after),
1976                                        ..
1977                                    } => {
1978                                        thread.handle_rate_limit_error(
1979                                            &completion_error,
1980                                            *retry_after,
1981                                            model.clone(),
1982                                            intent,
1983                                            window,
1984                                            cx,
1985                                        );
1986                                        retry_scheduled = true;
1987                                    }
1988                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
1989                                        retry_scheduled = thread.handle_retryable_error(
1990                                            &completion_error,
1991                                            model.clone(),
1992                                            intent,
1993                                            window,
1994                                            cx,
1995                                        );
1996                                        if !retry_scheduled {
1997                                            emit_generic_error(error, cx);
1998                                        }
1999                                    }
2000                                    ApiInternalServerError { .. }
2001                                    | ApiReadResponseError { .. }
2002                                    | HttpSend { .. } => {
2003                                        retry_scheduled = thread.handle_retryable_error(
2004                                            &completion_error,
2005                                            model.clone(),
2006                                            intent,
2007                                            window,
2008                                            cx,
2009                                        );
2010                                        if !retry_scheduled {
2011                                            emit_generic_error(error, cx);
2012                                        }
2013                                    }
2014                                    NoApiKey { .. }
2015                                    | HttpResponseError { .. }
2016                                    | BadRequestFormat { .. }
2017                                    | AuthenticationError { .. }
2018                                    | PermissionError { .. }
2019                                    | ApiEndpointNotFound { .. }
2020                                    | SerializeRequest { .. }
2021                                    | BuildRequestBody { .. }
2022                                    | DeserializeResponse { .. }
2023                                    | Other { .. } => emit_generic_error(error, cx),
2024                                }
2025                            } else {
2026                                emit_generic_error(error, cx);
2027                            }
2028
2029                            if !retry_scheduled {
2030                                thread.cancel_last_completion(window, cx);
2031                            }
2032                        }
2033                    }
2034
2035                    if !retry_scheduled {
2036                        cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
2037                    }
2038
2039                    if let Some((request_callback, (request, response_events))) = thread
2040                        .request_callback
2041                        .as_mut()
2042                        .zip(request_callback_parameters.as_ref())
2043                    {
2044                        request_callback(request, response_events);
2045                    }
2046
2047                    thread.auto_capture_telemetry(cx);
2048
2049                    if let Ok(initial_usage) = initial_token_usage {
2050                        let usage = thread.cumulative_token_usage - initial_usage;
2051
2052                        telemetry::event!(
2053                            "Assistant Thread Completion",
2054                            thread_id = thread.id().to_string(),
2055                            prompt_id = prompt_id,
2056                            model = model.telemetry_id(),
2057                            model_provider = model.provider_id().to_string(),
2058                            input_tokens = usage.input_tokens,
2059                            output_tokens = usage.output_tokens,
2060                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
2061                            cache_read_input_tokens = usage.cache_read_input_tokens,
2062                        );
2063                    }
2064                })
2065                .ok();
2066        });
2067
2068        self.pending_completions.push(PendingCompletion {
2069            id: pending_completion_id,
2070            queue_state: QueueState::Sending,
2071            _task: task,
2072        });
2073    }
2074
2075    pub fn summarize(&mut self, cx: &mut Context<Self>) {
2076        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
2077            println!("No thread summary model");
2078            return;
2079        };
2080
2081        if !model.provider.is_authenticated(cx) {
2082            return;
2083        }
2084
2085        let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
2086
2087        let request = self.to_summarize_request(
2088            &model.model,
2089            CompletionIntent::ThreadSummarization,
2090            added_user_message.into(),
2091            cx,
2092        );
2093
2094        self.summary = ThreadSummary::Generating;
2095
2096        self.pending_summary = cx.spawn(async move |this, cx| {
2097            let result = async {
2098                let mut messages = model.model.stream_completion(request, &cx).await?;
2099
2100                let mut new_summary = String::new();
2101                while let Some(event) = messages.next().await {
2102                    let Ok(event) = event else {
2103                        continue;
2104                    };
2105                    let text = match event {
2106                        LanguageModelCompletionEvent::Text(text) => text,
2107                        LanguageModelCompletionEvent::StatusUpdate(
2108                            CompletionRequestStatus::UsageUpdated { amount, limit },
2109                        ) => {
2110                            this.update(cx, |thread, cx| {
2111                                thread.update_model_request_usage(amount as u32, limit, cx);
2112                            })?;
2113                            continue;
2114                        }
2115                        _ => continue,
2116                    };
2117
2118                    let mut lines = text.lines();
2119                    new_summary.extend(lines.next());
2120
2121                    // Stop if the LLM generated multiple lines.
2122                    if lines.next().is_some() {
2123                        break;
2124                    }
2125                }
2126
2127                anyhow::Ok(new_summary)
2128            }
2129            .await;
2130
2131            this.update(cx, |this, cx| {
2132                match result {
2133                    Ok(new_summary) => {
2134                        if new_summary.is_empty() {
2135                            this.summary = ThreadSummary::Error;
2136                        } else {
2137                            this.summary = ThreadSummary::Ready(new_summary.into());
2138                        }
2139                    }
2140                    Err(err) => {
2141                        this.summary = ThreadSummary::Error;
2142                        log::error!("Failed to generate thread summary: {}", err);
2143                    }
2144                }
2145                cx.emit(ThreadEvent::SummaryGenerated);
2146            })
2147            .log_err()?;
2148
2149            Some(())
2150        });
2151    }
2152
2153    fn handle_rate_limit_error(
2154        &mut self,
2155        error: &LanguageModelCompletionError,
2156        retry_after: Duration,
2157        model: Arc<dyn LanguageModel>,
2158        intent: CompletionIntent,
2159        window: Option<AnyWindowHandle>,
2160        cx: &mut Context<Self>,
2161    ) {
2162        // For rate limit errors, we only retry once with the specified duration
2163        let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
2164        log::warn!(
2165            "Retrying completion request in {} seconds: {error:?}",
2166            retry_after.as_secs(),
2167        );
2168
2169        // Add a UI-only message instead of a regular message
2170        let id = self.next_message_id.post_inc();
2171        self.messages.push(Message {
2172            id,
2173            role: Role::System,
2174            segments: vec![MessageSegment::Text(retry_message)],
2175            loaded_context: LoadedContext::default(),
2176            creases: Vec::new(),
2177            is_hidden: false,
2178            ui_only: true,
2179        });
2180        cx.emit(ThreadEvent::MessageAdded(id));
2181        // Schedule the retry
2182        let thread_handle = cx.entity().downgrade();
2183
2184        cx.spawn(async move |_thread, cx| {
2185            cx.background_executor().timer(retry_after).await;
2186
2187            thread_handle
2188                .update(cx, |thread, cx| {
2189                    // Retry the completion
2190                    thread.send_to_model(model, intent, window, cx);
2191                })
2192                .log_err();
2193        })
2194        .detach();
2195    }
2196
    /// Handles a retryable completion error using the default exponential
    /// backoff schedule (no custom delay).
    ///
    /// Returns `true` if a retry was scheduled, `false` if the retry budget is
    /// exhausted. Thin wrapper around `handle_retryable_error_with_delay`.
    fn handle_retryable_error(
        &mut self,
        error: &LanguageModelCompletionError,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
    }
2207
2208    fn handle_retryable_error_with_delay(
2209        &mut self,
2210        error: &LanguageModelCompletionError,
2211        custom_delay: Option<Duration>,
2212        model: Arc<dyn LanguageModel>,
2213        intent: CompletionIntent,
2214        window: Option<AnyWindowHandle>,
2215        cx: &mut Context<Self>,
2216    ) -> bool {
2217        let retry_state = self.retry_state.get_or_insert(RetryState {
2218            attempt: 0,
2219            max_attempts: MAX_RETRY_ATTEMPTS,
2220            intent,
2221        });
2222
2223        retry_state.attempt += 1;
2224        let attempt = retry_state.attempt;
2225        let max_attempts = retry_state.max_attempts;
2226        let intent = retry_state.intent;
2227
2228        if attempt <= max_attempts {
2229            // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff
2230            let delay = if let Some(custom_delay) = custom_delay {
2231                custom_delay
2232            } else {
2233                let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32);
2234                Duration::from_secs(delay_secs)
2235            };
2236
2237            // Add a transient message to inform the user
2238            let delay_secs = delay.as_secs();
2239            let retry_message = format!(
2240                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
2241                in {delay_secs} seconds..."
2242            );
2243            log::warn!(
2244                "Retrying completion request (attempt {attempt} of {max_attempts}) \
2245                in {delay_secs} seconds: {error:?}",
2246            );
2247
2248            // Add a UI-only message instead of a regular message
2249            let id = self.next_message_id.post_inc();
2250            self.messages.push(Message {
2251                id,
2252                role: Role::System,
2253                segments: vec![MessageSegment::Text(retry_message)],
2254                loaded_context: LoadedContext::default(),
2255                creases: Vec::new(),
2256                is_hidden: false,
2257                ui_only: true,
2258            });
2259            cx.emit(ThreadEvent::MessageAdded(id));
2260
2261            // Schedule the retry
2262            let thread_handle = cx.entity().downgrade();
2263
2264            cx.spawn(async move |_thread, cx| {
2265                cx.background_executor().timer(delay).await;
2266
2267                thread_handle
2268                    .update(cx, |thread, cx| {
2269                        // Retry the completion
2270                        thread.send_to_model(model, intent, window, cx);
2271                    })
2272                    .log_err();
2273            })
2274            .detach();
2275
2276            true
2277        } else {
2278            // Max retries exceeded
2279            self.retry_state = None;
2280
2281            let notification_text = if max_attempts == 1 {
2282                "Failed after retrying.".into()
2283            } else {
2284                format!("Failed after retrying {} times.", max_attempts).into()
2285            };
2286
2287            // Stop generating since we're giving up on retrying.
2288            self.pending_completions.clear();
2289
2290            cx.emit(ThreadEvent::RetriesFailed {
2291                message: notification_text,
2292            });
2293
2294            false
2295        }
2296    }
2297
    /// Starts generating a detailed summary of this thread in the background,
    /// unless one already exists (or is being generated) for the current last
    /// message. On success the thread is saved via `thread_store` so the
    /// summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // Skip if a summary keyed to this exact last message is already
        // generated or in flight.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        // Summarization requires a configured, authenticated summary model.
        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Publish the "generating" state before spawning so observers never
        // see a stale state while the task runs.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, reset to NotGenerated so a later
            // call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate streamed chunks; individual chunk errors are logged
            // and skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2387
2388    pub async fn wait_for_detailed_summary_or_text(
2389        this: &Entity<Self>,
2390        cx: &mut AsyncApp,
2391    ) -> Option<SharedString> {
2392        let mut detailed_summary_rx = this
2393            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2394            .ok()?;
2395        loop {
2396            match detailed_summary_rx.recv().await? {
2397                DetailedSummaryState::Generating { .. } => {}
2398                DetailedSummaryState::NotGenerated => {
2399                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
2400                }
2401                DetailedSummaryState::Generated { text, .. } => return Some(text),
2402            }
2403        }
2404    }
2405
2406    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2407        self.detailed_summary_rx
2408            .borrow()
2409            .text()
2410            .unwrap_or_else(|| self.text().into())
2411    }
2412
2413    pub fn is_generating_detailed_summary(&self) -> bool {
2414        matches!(
2415            &*self.detailed_summary_rx.borrow(),
2416            DetailedSummaryState::Generating { .. }
2417        )
2418    }
2419
2420    pub fn use_pending_tools(
2421        &mut self,
2422        window: Option<AnyWindowHandle>,
2423        model: Arc<dyn LanguageModel>,
2424        cx: &mut Context<Self>,
2425    ) -> Vec<PendingToolUse> {
2426        self.auto_capture_telemetry(cx);
2427        let request =
2428            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2429        let pending_tool_uses = self
2430            .tool_use
2431            .pending_tool_uses()
2432            .into_iter()
2433            .filter(|tool_use| tool_use.status.is_idle())
2434            .cloned()
2435            .collect::<Vec<_>>();
2436
2437        for tool_use in pending_tool_uses.iter() {
2438            self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2439        }
2440
2441        pending_tool_uses
2442    }
2443
2444    fn use_pending_tool(
2445        &mut self,
2446        tool_use: PendingToolUse,
2447        request: Arc<LanguageModelRequest>,
2448        model: Arc<dyn LanguageModel>,
2449        window: Option<AnyWindowHandle>,
2450        cx: &mut Context<Self>,
2451    ) {
2452        let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2453            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2454        };
2455
2456        if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2457            return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2458        }
2459
2460        if tool.needs_confirmation(&tool_use.input, cx)
2461            && !AgentSettings::get_global(cx).always_allow_tool_actions
2462        {
2463            self.tool_use.confirm_tool_use(
2464                tool_use.id,
2465                tool_use.ui_text,
2466                tool_use.input,
2467                request,
2468                tool,
2469            );
2470            cx.emit(ThreadEvent::ToolConfirmationNeeded);
2471        } else {
2472            self.run_tool(
2473                tool_use.id,
2474                tool_use.ui_text,
2475                tool_use.input,
2476                request,
2477                tool,
2478                model,
2479                window,
2480                cx,
2481            );
2482        }
2483    }
2484
2485    pub fn handle_hallucinated_tool_use(
2486        &mut self,
2487        tool_use_id: LanguageModelToolUseId,
2488        hallucinated_tool_name: Arc<str>,
2489        window: Option<AnyWindowHandle>,
2490        cx: &mut Context<Thread>,
2491    ) {
2492        let available_tools = self.profile.enabled_tools(cx);
2493
2494        let tool_list = available_tools
2495            .iter()
2496            .map(|(name, tool)| format!("- {}: {}", name, tool.description()))
2497            .collect::<Vec<_>>()
2498            .join("\n");
2499
2500        let error_message = format!(
2501            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2502            hallucinated_tool_name, tool_list
2503        );
2504
2505        let pending_tool_use = self.tool_use.insert_tool_output(
2506            tool_use_id.clone(),
2507            hallucinated_tool_name,
2508            Err(anyhow!("Missing tool call: {error_message}")),
2509            self.configured_model.as_ref(),
2510        );
2511
2512        cx.emit(ThreadEvent::MissingToolUse {
2513            tool_use_id: tool_use_id.clone(),
2514            ui_text: error_message.into(),
2515        });
2516
2517        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2518    }
2519
2520    pub fn receive_invalid_tool_json(
2521        &mut self,
2522        tool_use_id: LanguageModelToolUseId,
2523        tool_name: Arc<str>,
2524        invalid_json: Arc<str>,
2525        error: String,
2526        window: Option<AnyWindowHandle>,
2527        cx: &mut Context<Thread>,
2528    ) {
2529        log::error!("The model returned invalid input JSON: {invalid_json}");
2530
2531        let pending_tool_use = self.tool_use.insert_tool_output(
2532            tool_use_id.clone(),
2533            tool_name,
2534            Err(anyhow!("Error parsing input JSON: {error}")),
2535            self.configured_model.as_ref(),
2536        );
2537        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2538            pending_tool_use.ui_text.clone()
2539        } else {
2540            log::error!(
2541                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2542            );
2543            format!("Unknown tool {}", tool_use_id).into()
2544        };
2545
2546        cx.emit(ThreadEvent::InvalidToolInput {
2547            tool_use_id: tool_use_id.clone(),
2548            ui_text,
2549            invalid_input_json: invalid_json,
2550        });
2551
2552        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2553    }
2554
2555    pub fn run_tool(
2556        &mut self,
2557        tool_use_id: LanguageModelToolUseId,
2558        ui_text: impl Into<SharedString>,
2559        input: serde_json::Value,
2560        request: Arc<LanguageModelRequest>,
2561        tool: Arc<dyn Tool>,
2562        model: Arc<dyn LanguageModel>,
2563        window: Option<AnyWindowHandle>,
2564        cx: &mut Context<Thread>,
2565    ) {
2566        let task =
2567            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2568        self.tool_use
2569            .run_pending_tool(tool_use_id, ui_text.into(), task);
2570    }
2571
    /// Runs `tool` with the given input and returns a task that waits for the
    /// tool's output, records it, and marks the tool use as finished.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Kick off the tool synchronously; `tool_result.output` resolves later.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // If the thread was dropped meanwhile, the output is discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2618
2619    fn tool_finished(
2620        &mut self,
2621        tool_use_id: LanguageModelToolUseId,
2622        pending_tool_use: Option<PendingToolUse>,
2623        canceled: bool,
2624        window: Option<AnyWindowHandle>,
2625        cx: &mut Context<Self>,
2626    ) {
2627        if self.all_tools_finished() {
2628            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2629                if !canceled {
2630                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2631                }
2632                self.auto_capture_telemetry(cx);
2633            }
2634        }
2635
2636        cx.emit(ThreadEvent::ToolFinished {
2637            tool_use_id,
2638            pending_tool_use,
2639        });
2640    }
2641
2642    /// Cancels the last pending completion, if there are any pending.
2643    ///
2644    /// Returns whether a completion was canceled.
2645    pub fn cancel_last_completion(
2646        &mut self,
2647        window: Option<AnyWindowHandle>,
2648        cx: &mut Context<Self>,
2649    ) -> bool {
2650        let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some();
2651
2652        self.retry_state = None;
2653
2654        for pending_tool_use in self.tool_use.cancel_pending() {
2655            canceled = true;
2656            self.tool_finished(
2657                pending_tool_use.id.clone(),
2658                Some(pending_tool_use),
2659                true,
2660                window,
2661                cx,
2662            );
2663        }
2664
2665        if canceled {
2666            cx.emit(ThreadEvent::CompletionCanceled);
2667
2668            // When canceled, we always want to insert the checkpoint.
2669            // (We skip over finalize_pending_checkpoint, because it
2670            // would conclude we didn't have anything to insert here.)
2671            if let Some(checkpoint) = self.pending_checkpoint.take() {
2672                self.insert_checkpoint(checkpoint, cx);
2673            }
2674        } else {
2675            self.finalize_pending_checkpoint(cx);
2676        }
2677
2678        canceled
2679    }
2680
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Purely an event broadcast; no thread state is mutated here.
        cx.emit(ThreadEvent::CancelEditing);
    }
2688
    /// Returns the thread-level feedback, if any has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2692
    /// Returns the feedback reported for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2696
    /// Records `feedback` for a single message and reports it to telemetry,
    /// along with a serialized copy of the thread and a project snapshot.
    ///
    /// Returns a no-op ready task when the same feedback is already recorded
    /// for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Deduplicate: re-submitting the same rating does nothing.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|(name, _)| name.clone().into())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2753
    /// Reports thread-level feedback to telemetry.
    ///
    /// If the thread contains an assistant message, the feedback is attached
    /// to the most recent one via [`Self::report_message_feedback`]; otherwise
    /// a thread-wide rating event (without message details) is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Prefer attaching feedback to the latest assistant response.
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // Capture everything that needs `cx` before moving to the
            // background.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failures degrade to a null payload rather than
                // dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2799
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Joins one [`Self::worktree_snapshot`] task per visible worktree and
    /// records the paths of all dirty buffers at the time the task runs.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off one snapshot task per visible worktree.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect paths of buffers with unsaved changes; failures to read
            // the app context are ignored (the snapshot is best-effort).
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2837
2838    fn worktree_snapshot(
2839        worktree: Entity<project::Worktree>,
2840        git_store: Entity<GitStore>,
2841        cx: &App,
2842    ) -> Task<WorktreeSnapshot> {
2843        cx.spawn(async move |cx| {
2844            // Get worktree path and snapshot
2845            let worktree_info = cx.update(|app_cx| {
2846                let worktree = worktree.read(app_cx);
2847                let path = worktree.abs_path().to_string_lossy().to_string();
2848                let snapshot = worktree.snapshot();
2849                (path, snapshot)
2850            });
2851
2852            let Ok((worktree_path, _snapshot)) = worktree_info else {
2853                return WorktreeSnapshot {
2854                    worktree_path: String::new(),
2855                    git_state: None,
2856                };
2857            };
2858
2859            let git_state = git_store
2860                .update(cx, |git_store, cx| {
2861                    git_store
2862                        .repositories()
2863                        .values()
2864                        .find(|repo| {
2865                            repo.read(cx)
2866                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2867                                .is_some()
2868                        })
2869                        .cloned()
2870                })
2871                .ok()
2872                .flatten()
2873                .map(|repo| {
2874                    repo.update(cx, |repo, _| {
2875                        let current_branch =
2876                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
2877                        repo.send_job(None, |state, _| async move {
2878                            let RepositoryState::Local { backend, .. } = state else {
2879                                return GitState {
2880                                    remote_url: None,
2881                                    head_sha: None,
2882                                    current_branch,
2883                                    diff: None,
2884                                };
2885                            };
2886
2887                            let remote_url = backend.remote_url("origin");
2888                            let head_sha = backend.head_sha().await;
2889                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2890
2891                            GitState {
2892                                remote_url,
2893                                head_sha,
2894                                current_branch,
2895                                diff,
2896                            }
2897                        })
2898                    })
2899                });
2900
2901            let git_state = match git_state {
2902                Some(git_state) => match git_state.ok() {
2903                    Some(git_state) => git_state.await.ok(),
2904                    None => None,
2905                },
2906                None => None,
2907            };
2908
2909            WorktreeSnapshot {
2910                worktree_path,
2911                git_state,
2912            }
2913        })
2914    }
2915
2916    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2917        let mut markdown = Vec::new();
2918
2919        let summary = self.summary().or_default();
2920        writeln!(markdown, "# {summary}\n")?;
2921
2922        for message in self.messages() {
2923            writeln!(
2924                markdown,
2925                "## {role}\n",
2926                role = match message.role {
2927                    Role::User => "User",
2928                    Role::Assistant => "Agent",
2929                    Role::System => "System",
2930                }
2931            )?;
2932
2933            if !message.loaded_context.text.is_empty() {
2934                writeln!(markdown, "{}", message.loaded_context.text)?;
2935            }
2936
2937            if !message.loaded_context.images.is_empty() {
2938                writeln!(
2939                    markdown,
2940                    "\n{} images attached as context.\n",
2941                    message.loaded_context.images.len()
2942                )?;
2943            }
2944
2945            for segment in &message.segments {
2946                match segment {
2947                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2948                    MessageSegment::Thinking { text, .. } => {
2949                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2950                    }
2951                    MessageSegment::RedactedThinking(_) => {}
2952                }
2953            }
2954
2955            for tool_use in self.tool_uses_for_message(message.id, cx) {
2956                writeln!(
2957                    markdown,
2958                    "**Use Tool: {} ({})**",
2959                    tool_use.name, tool_use.id
2960                )?;
2961                writeln!(markdown, "```json")?;
2962                writeln!(
2963                    markdown,
2964                    "{}",
2965                    serde_json::to_string_pretty(&tool_use.input)?
2966                )?;
2967                writeln!(markdown, "```")?;
2968            }
2969
2970            for tool_result in self.tool_results_for_message(message.id) {
2971                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2972                if tool_result.is_error {
2973                    write!(markdown, " (Error)")?;
2974                }
2975
2976                writeln!(markdown, "**\n")?;
2977                match &tool_result.content {
2978                    LanguageModelToolResultContent::Text(text) => {
2979                        writeln!(markdown, "{text}")?;
2980                    }
2981                    LanguageModelToolResultContent::Image(image) => {
2982                        writeln!(markdown, "![Image](data:base64,{})", image.source)?;
2983                    }
2984                }
2985
2986                if let Some(output) = tool_result.output.as_ref() {
2987                    writeln!(
2988                        markdown,
2989                        "\n\nDebug Output:\n\n```json\n{}\n```\n",
2990                        serde_json::to_string_pretty(output)?
2991                    )?;
2992                }
2993            }
2994        }
2995
2996        Ok(String::from_utf8_lossy(&markdown).to_string())
2997    }
2998
    /// Marks the edits within `buffer_range` of `buffer` as accepted in the
    /// action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
3009
    /// Marks every tracked edit as accepted in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
3014
    /// Reverts the edits within `buffer_ranges` of `buffer` via the action
    /// log, returning the task that performs the rejection.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
3025
    /// The action log tracking edits proposed by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
3029
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
3033
3034    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
3035        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
3036            return;
3037        }
3038
3039        let now = Instant::now();
3040        if let Some(last) = self.last_auto_capture_at {
3041            if now.duration_since(last).as_secs() < 10 {
3042                return;
3043            }
3044        }
3045
3046        self.last_auto_capture_at = Some(now);
3047
3048        let thread_id = self.id().clone();
3049        let github_login = self
3050            .project
3051            .read(cx)
3052            .user_store()
3053            .read(cx)
3054            .current_user()
3055            .map(|user| user.github_login.clone());
3056        let client = self.project.read(cx).client();
3057        let serialize_task = self.serialize(cx);
3058
3059        cx.background_executor()
3060            .spawn(async move {
3061                if let Ok(serialized_thread) = serialize_task.await {
3062                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
3063                        telemetry::event!(
3064                            "Agent Thread Auto-Captured",
3065                            thread_id = thread_id.to_string(),
3066                            thread_data = thread_data,
3067                            auto_capture_reason = "tracked_user",
3068                            github_login = github_login
3069                        );
3070
3071                        client.telemetry().flush_events().await;
3072                    }
3073                }
3074            })
3075            .detach();
3076    }
3077
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
3081
3082    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
3083        let Some(model) = self.configured_model.as_ref() else {
3084            return TotalTokenUsage::default();
3085        };
3086
3087        let max = model.model.max_token_count();
3088
3089        let index = self
3090            .messages
3091            .iter()
3092            .position(|msg| msg.id == message_id)
3093            .unwrap_or(0);
3094
3095        if index == 0 {
3096            return TotalTokenUsage { total: 0, max };
3097        }
3098
3099        let token_usage = &self
3100            .request_token_usage
3101            .get(index - 1)
3102            .cloned()
3103            .unwrap_or_default();
3104
3105        TotalTokenUsage {
3106            total: token_usage.total_tokens(),
3107            max,
3108        }
3109    }
3110
3111    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
3112        let model = self.configured_model.as_ref()?;
3113
3114        let max = model.model.max_token_count();
3115
3116        if let Some(exceeded_error) = &self.exceeded_window_error {
3117            if model.model.id() == exceeded_error.model_id {
3118                return Some(TotalTokenUsage {
3119                    total: exceeded_error.token_count,
3120                    max,
3121                });
3122            }
3123        }
3124
3125        let total = self
3126            .token_usage_at_last_message()
3127            .unwrap_or_default()
3128            .total_tokens();
3129
3130        Some(TotalTokenUsage { total, max })
3131    }
3132
3133    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
3134        self.request_token_usage
3135            .get(self.messages.len().saturating_sub(1))
3136            .or_else(|| self.request_token_usage.last())
3137            .cloned()
3138    }
3139
3140    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
3141        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
3142        self.request_token_usage
3143            .resize(self.messages.len(), placeholder);
3144
3145        if let Some(last) = self.request_token_usage.last_mut() {
3146            *last = token_usage;
3147        }
3148    }
3149
3150    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
3151        self.project.update(cx, |project, cx| {
3152            project.user_store().update(cx, |user_store, cx| {
3153                user_store.update_model_request_usage(
3154                    ModelRequestUsage(RequestUsage {
3155                        amount: amount as i32,
3156                        limit,
3157                    }),
3158                    cx,
3159                )
3160            })
3161        });
3162    }
3163
3164    pub fn deny_tool_use(
3165        &mut self,
3166        tool_use_id: LanguageModelToolUseId,
3167        tool_name: Arc<str>,
3168        window: Option<AnyWindowHandle>,
3169        cx: &mut Context<Self>,
3170    ) {
3171        let err = Err(anyhow::anyhow!(
3172            "Permission to run tool action denied by user"
3173        ));
3174
3175        self.tool_use.insert_tool_output(
3176            tool_use_id.clone(),
3177            tool_name,
3178            err,
3179            self.configured_model.as_ref(),
3180        );
3181        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
3182    }
3183}
3184
/// User-visible errors that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has reached its model-request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and detail text for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
3197
/// Events emitted by a `Thread` over the lifetime of a completion and its
/// tool calls, consumed by observers via gpui's `EventEmitter`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// A completion event was received from the model stream.
    StreamedCompletion,
    /// A chunk of text arrived on the stream.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text belonging to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text belonging to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with its parsed input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool use that could not be resolved.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        // The raw input string that could not be parsed.
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A thread summary finished generating.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// The listed pending tool uses should now be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// The tool-use limit was reached.
    ToolUseLimitReached,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
    /// The active agent profile changed.
    ProfileChanged,
    /// All retry attempts failed; `message` describes the failure.
    RetriesFailed {
        message: SharedString,
    },
}
3245
// Lets `Thread` emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
3247
/// Bookkeeping for a completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from other completions.
    id: usize,
    queue_state: QueueState,
    // Holds the streaming task. NOTE(review): gpui tasks are typically
    // canceled when dropped — confirm that dropping this cancels the request.
    _task: Task<()>,
}
3253
3254#[cfg(test)]
3255mod tests {
3256    use super::*;
3257    use crate::{
3258        context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3259    };
3260
    // Test-specific constants.
    // NOTE(review): not referenced anywhere in this chunk — confirm it is used
    // by tests outside this view; also consider moving it below the `use`
    // block for conventional item ordering.
    const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
3263    use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3264    use assistant_tool::ToolRegistry;
3265    use assistant_tools;
3266    use futures::StreamExt;
3267    use futures::future::BoxFuture;
3268    use futures::stream::BoxStream;
3269    use gpui::TestAppContext;
3270    use http_client;
3271    use indoc::indoc;
3272    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3273    use language_model::{
3274        LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
3275        LanguageModelProviderName, LanguageModelToolChoice,
3276    };
3277    use parking_lot::Mutex;
3278    use project::{FakeFs, Project};
3279    use prompt_store::PromptBuilder;
3280    use serde_json::json;
3281    use settings::{Settings, SettingsStore};
3282    use std::sync::Arc;
3283    use std::time::Duration;
3284    use theme::ThemeSettings;
3285    use util::path;
3286    use workspace::Workspace;
3287
    /// Verifies that file context attached to a user message is rendered into
    /// both the stored `Message` and the outgoing completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Pull the single context entry back out of the store and load it.
        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request: index 0 is the system message, index 1 is
        // the user message with the context prepended to its text.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3364
    /// Verifies that each user message only carries context entries that have
    /// not already been attached to an earlier message, and that
    /// `new_context_for_thread` re-offers contexts when scoped to an earlier
    /// message id or when contexts are removed.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages (plus the system message at index 0)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Scoping to message2: everything attached at-or-after message2 (plus
        // the newly added file4) is offered again; file1 stays excluded.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3520
    /// Verifies that user messages inserted with no context produce empty
    /// `loaded_context` text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request (index 0 is the system message)
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3598
    /// Verifies that editing a context buffer after it was attached produces a
    /// `project_notifications` tool result listing the stale file — and that
    /// no such notification exists before the buffer is modified.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Add a buffer to the context. This will be a tracked buffer
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store
            .read_with(cx, |store, _| store.context().next().cloned())
            .unwrap();
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message and assistant response
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
            thread.insert_assistant_message(
                vec![MessageSegment::Text("This code prints 42.".into())],
                cx,
            );
        });

        // We shouldn't have a stale buffer notification yet
        let notification = thread.read_with(cx, |thread, _| {
            find_tool_use(thread, "project_notifications")
        });
        assert!(
            notification.is_none(),
            "Should not have stale buffer notification before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check for the stale buffer warning
        thread.update(cx, |thread, cx| {
            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let Some(notification_result) = thread.read_with(cx, |thread, _cx| {
            find_tool_use(thread, "project_notifications")
        }) else {
            panic!("Should have a `project_notifications` tool use");
        };

        let Some(notification_content) = notification_result.content.to_str() else {
            panic!("`project_notifications` should return text");
        };

        // The notification must match this text exactly, including the
        // trailing newline produced by `indoc!`.
        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]

        These files have changed since the last read:
        - code.rs
        "};
        assert_eq!(notification_content, expected_content);
    }
3684
3685    fn find_tool_use(thread: &Thread, tool_name: &str) -> Option<LanguageModelToolResult> {
3686        thread
3687            .messages()
3688            .filter_map(|message| {
3689                thread
3690                    .tool_results_for_message(message.id)
3691                    .into_iter()
3692                    .find(|result| result.tool_name == tool_name.into())
3693            })
3694            .next()
3695            .cloned()
3696    }
3697
3698    #[gpui::test]
3699    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3700        init_test_settings(cx);
3701
3702        let project = create_test_project(
3703            cx,
3704            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
3705        )
3706        .await;
3707
3708        let (_workspace, thread_store, thread, _context_store, _model) =
3709            setup_test_environment(cx, project.clone()).await;
3710
3711        // Check that we are starting with the default profile
3712        let profile = cx.read(|cx| thread.read(cx).profile.clone());
3713        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3714        assert_eq!(
3715            profile,
3716            AgentProfile::new(AgentProfileId::default(), tool_set)
3717        );
3718    }
3719
    /// Verifies that the agent profile survives a serialize/deserialize
    /// round trip of the thread.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Round-trip: rebuild the thread from its serialized form, reusing
        // the live thread's id, project, tools, and prompt builder.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3762
    /// Verifies how `model_parameters` settings select a request temperature:
    /// a parameter entry applies when its provider and/or model match the
    /// configured model, and does not apply when the provider differs.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider specified and matching -> applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model specified (any provider) -> applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider specified (any model) -> applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name but a different provider -> does NOT apply.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3856
    /// Walks the thread-summary state machine: Pending (manual set ignored) ->
    /// Generating after the first exchange (manual set still ignored) ->
    /// Ready once the summary stream completes (manual set allowed; empty
    /// manual summary falls back to the default).
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3941
3942    #[gpui::test]
3943    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3944        init_test_settings(cx);
3945
3946        let project = create_test_project(cx, json!({})).await;
3947
3948        let (_, _thread_store, thread, _context_store, model) =
3949            setup_test_environment(cx, project.clone()).await;
3950
3951        test_summarize_error(&model, &thread, cx);
3952
3953        // Now we should be able to set a summary
3954        thread.update(cx, |thread, cx| {
3955            thread.set_summary("Brief Intro", cx);
3956        });
3957
3958        thread.read_with(cx, |thread, _| {
3959            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3960            assert_eq!(thread.summary().or_default(), "Brief Intro");
3961        });
3962    }
3963
3964    #[gpui::test]
3965    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3966        init_test_settings(cx);
3967
3968        let project = create_test_project(cx, json!({})).await;
3969
3970        let (_, _thread_store, thread, _context_store, model) =
3971            setup_test_environment(cx, project.clone()).await;
3972
3973        test_summarize_error(&model, &thread, cx);
3974
3975        // Sending another message should not trigger another summarize request
3976        thread.update(cx, |thread, cx| {
3977            thread.insert_user_message(
3978                "How are you?",
3979                ContextLoadResult::default(),
3980                None,
3981                vec![],
3982                cx,
3983            );
3984            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3985        });
3986
3987        let fake_model = model.as_fake();
3988        simulate_successful_response(&fake_model, cx);
3989
3990        thread.read_with(cx, |thread, _| {
3991            // State is still Error, not Generating
3992            assert!(matches!(thread.summary(), ThreadSummary::Error));
3993        });
3994
3995        // But the summarize request can be invoked manually
3996        thread.update(cx, |thread, cx| {
3997            thread.summarize(cx);
3998        });
3999
4000        thread.read_with(cx, |thread, _| {
4001            assert!(matches!(thread.summary(), ThreadSummary::Generating));
4002        });
4003
4004        cx.run_until_parked();
4005        fake_model.stream_last_completion_response("A successful summary");
4006        fake_model.end_last_completion_stream();
4007        cx.run_until_parked();
4008
4009        thread.read_with(cx, |thread, _| {
4010            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
4011            assert_eq!(thread.summary().or_default(), "A successful summary");
4012        });
4013    }
4014
    // Helper to create a model that returns errors
    /// Selects which error `ErrorInjector` yields from `stream_completion`.
    enum TestError {
        /// Maps to `LanguageModelCompletionError::ServerOverloaded`.
        Overloaded,
        /// Maps to `LanguageModelCompletionError::ApiInternalServerError`.
        InternalServerError,
    }
4020
    /// Test model that delegates metadata/capability queries to an inner
    /// `FakeLanguageModel` but fails every `stream_completion` call with the
    /// configured `TestError`.
    struct ErrorInjector {
        // Backing fake used for ids/names/capabilities and `as_fake()`.
        inner: Arc<FakeLanguageModel>,
        // The error every completion attempt will produce.
        error_type: TestError,
    }
4025
4026    impl ErrorInjector {
4027        fn new(error_type: TestError) -> Self {
4028            Self {
4029                inner: Arc::new(FakeLanguageModel::default()),
4030                error_type,
4031            }
4032        }
4033    }
4034
    // `LanguageModel` impl that forwards all metadata/capability queries to the
    // inner `FakeLanguageModel`; only `stream_completion` is special-cased so
    // that every request fails with the configured `TestError`.
    impl LanguageModel for ErrorInjector {
        fn id(&self) -> LanguageModelId {
            self.inner.id()
        }

        fn name(&self) -> LanguageModelName {
            self.inner.name()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.inner.provider_id()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.inner.provider_name()
        }

        fn supports_tools(&self) -> bool {
            self.inner.supports_tools()
        }

        fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
            self.inner.supports_tool_choice(choice)
        }

        fn supports_images(&self) -> bool {
            self.inner.supports_images()
        }

        fn telemetry_id(&self) -> String {
            self.inner.telemetry_id()
        }

        fn max_token_count(&self) -> u64 {
            self.inner.max_token_count()
        }

        fn count_tokens(
            &self,
            request: LanguageModelRequest,
            cx: &App,
        ) -> BoxFuture<'static, Result<u64>> {
            self.inner.count_tokens(request, cx)
        }

        // Always fails: the returned stream's single item is the configured
        // error; no completion content is ever produced.
        fn stream_completion(
            &self,
            _request: LanguageModelRequest,
            _cx: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            // Build the error before the async block: the future is 'static,
            // so it cannot capture `&self`.
            let error = match self.error_type {
                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
                    provider: self.provider_name(),
                    retry_after: None,
                },
                TestError::InternalServerError => {
                    LanguageModelCompletionError::ApiInternalServerError {
                        provider: self.provider_name(),
                        message: "I'm a teapot orbiting the sun".to_string(),
                    }
                }
            };
            async move {
                let stream = futures::stream::once(async move { Err(error) });
                Ok(stream.boxed())
            }
            .boxed()
        }

        // Exposes the inner fake so tests can drive/observe it directly.
        fn as_fake(&self) -> &FakeLanguageModel {
            &self.inner
        }
    }
4117
4118    #[gpui::test]
4119    async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) {
4120        init_test_settings(cx);
4121
4122        let project = create_test_project(cx, json!({})).await;
4123        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4124
4125        // Create model that returns overloaded error
4126        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4127
4128        // Insert a user message
4129        thread.update(cx, |thread, cx| {
4130            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4131        });
4132
4133        // Start completion
4134        thread.update(cx, |thread, cx| {
4135            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4136        });
4137
4138        cx.run_until_parked();
4139
4140        thread.read_with(cx, |thread, _| {
4141            assert!(thread.retry_state.is_some(), "Should have retry state");
4142            let retry_state = thread.retry_state.as_ref().unwrap();
4143            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4144            assert_eq!(
4145                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4146                "Should have default max attempts"
4147            );
4148        });
4149
4150        // Check that a retry message was added
4151        thread.read_with(cx, |thread, _| {
4152            let mut messages = thread.messages();
4153            assert!(
4154                messages.any(|msg| {
4155                    msg.role == Role::System
4156                        && msg.ui_only
4157                        && msg.segments.iter().any(|seg| {
4158                            if let MessageSegment::Text(text) = seg {
4159                                text.contains("overloaded")
4160                                    && text
4161                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4162                            } else {
4163                                false
4164                            }
4165                        })
4166                }),
4167                "Should have added a system retry message"
4168            );
4169        });
4170
4171        let retry_count = thread.update(cx, |thread, _| {
4172            thread
4173                .messages
4174                .iter()
4175                .filter(|m| {
4176                    m.ui_only
4177                        && m.segments.iter().any(|s| {
4178                            if let MessageSegment::Text(text) = s {
4179                                text.contains("Retrying") && text.contains("seconds")
4180                            } else {
4181                                false
4182                            }
4183                        })
4184                })
4185                .count()
4186        });
4187
4188        assert_eq!(retry_count, 1, "Should have one retry message");
4189    }
4190
4191    #[gpui::test]
4192    async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) {
4193        init_test_settings(cx);
4194
4195        let project = create_test_project(cx, json!({})).await;
4196        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4197
4198        // Create model that returns internal server error
4199        let model = Arc::new(ErrorInjector::new(TestError::InternalServerError));
4200
4201        // Insert a user message
4202        thread.update(cx, |thread, cx| {
4203            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4204        });
4205
4206        // Start completion
4207        thread.update(cx, |thread, cx| {
4208            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4209        });
4210
4211        cx.run_until_parked();
4212
4213        // Check retry state on thread
4214        thread.read_with(cx, |thread, _| {
4215            assert!(thread.retry_state.is_some(), "Should have retry state");
4216            let retry_state = thread.retry_state.as_ref().unwrap();
4217            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4218            assert_eq!(
4219                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4220                "Should have correct max attempts"
4221            );
4222        });
4223
4224        // Check that a retry message was added with provider name
4225        thread.read_with(cx, |thread, _| {
4226            let mut messages = thread.messages();
4227            assert!(
4228                messages.any(|msg| {
4229                    msg.role == Role::System
4230                        && msg.ui_only
4231                        && msg.segments.iter().any(|seg| {
4232                            if let MessageSegment::Text(text) = seg {
4233                                text.contains("internal")
4234                                    && text.contains("Fake")
4235                                    && text
4236                                        .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS))
4237                            } else {
4238                                false
4239                            }
4240                        })
4241                }),
4242                "Should have added a system retry message with provider name"
4243            );
4244        });
4245
4246        // Count retry messages
4247        let retry_count = thread.update(cx, |thread, _| {
4248            thread
4249                .messages
4250                .iter()
4251                .filter(|m| {
4252                    m.ui_only
4253                        && m.segments.iter().any(|s| {
4254                            if let MessageSegment::Text(text) = s {
4255                                text.contains("Retrying") && text.contains("seconds")
4256                            } else {
4257                                false
4258                            }
4259                        })
4260                })
4261                .count()
4262        });
4263
4264        assert_eq!(retry_count, 1, "Should have one retry message");
4265    }
4266
4267    #[gpui::test]
4268    async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) {
4269        init_test_settings(cx);
4270
4271        let project = create_test_project(cx, json!({})).await;
4272        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4273
4274        // Create model that returns overloaded error
4275        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4276
4277        // Insert a user message
4278        thread.update(cx, |thread, cx| {
4279            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4280        });
4281
4282        // Track retry events and completion count
4283        // Track completion events
4284        let completion_count = Arc::new(Mutex::new(0));
4285        let completion_count_clone = completion_count.clone();
4286
4287        let _subscription = thread.update(cx, |_, cx| {
4288            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4289                if let ThreadEvent::NewRequest = event {
4290                    *completion_count_clone.lock() += 1;
4291                }
4292            })
4293        });
4294
4295        // First attempt
4296        thread.update(cx, |thread, cx| {
4297            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4298        });
4299        cx.run_until_parked();
4300
4301        // Should have scheduled first retry - count retry messages
4302        let retry_count = thread.update(cx, |thread, _| {
4303            thread
4304                .messages
4305                .iter()
4306                .filter(|m| {
4307                    m.ui_only
4308                        && m.segments.iter().any(|s| {
4309                            if let MessageSegment::Text(text) = s {
4310                                text.contains("Retrying") && text.contains("seconds")
4311                            } else {
4312                                false
4313                            }
4314                        })
4315                })
4316                .count()
4317        });
4318        assert_eq!(retry_count, 1, "Should have scheduled first retry");
4319
4320        // Check retry state
4321        thread.read_with(cx, |thread, _| {
4322            assert!(thread.retry_state.is_some(), "Should have retry state");
4323            let retry_state = thread.retry_state.as_ref().unwrap();
4324            assert_eq!(retry_state.attempt, 1, "Should be first retry attempt");
4325        });
4326
4327        // Advance clock for first retry
4328        cx.executor()
4329            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4330        cx.run_until_parked();
4331
4332        // Should have scheduled second retry - count retry messages
4333        let retry_count = thread.update(cx, |thread, _| {
4334            thread
4335                .messages
4336                .iter()
4337                .filter(|m| {
4338                    m.ui_only
4339                        && m.segments.iter().any(|s| {
4340                            if let MessageSegment::Text(text) = s {
4341                                text.contains("Retrying") && text.contains("seconds")
4342                            } else {
4343                                false
4344                            }
4345                        })
4346                })
4347                .count()
4348        });
4349        assert_eq!(retry_count, 2, "Should have scheduled second retry");
4350
4351        // Check retry state updated
4352        thread.read_with(cx, |thread, _| {
4353            assert!(thread.retry_state.is_some(), "Should have retry state");
4354            let retry_state = thread.retry_state.as_ref().unwrap();
4355            assert_eq!(retry_state.attempt, 2, "Should be second retry attempt");
4356            assert_eq!(
4357                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4358                "Should have correct max attempts"
4359            );
4360        });
4361
4362        // Advance clock for second retry (exponential backoff)
4363        cx.executor()
4364            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2));
4365        cx.run_until_parked();
4366
4367        // Should have scheduled third retry
4368        // Count all retry messages now
4369        let retry_count = thread.update(cx, |thread, _| {
4370            thread
4371                .messages
4372                .iter()
4373                .filter(|m| {
4374                    m.ui_only
4375                        && m.segments.iter().any(|s| {
4376                            if let MessageSegment::Text(text) = s {
4377                                text.contains("Retrying") && text.contains("seconds")
4378                            } else {
4379                                false
4380                            }
4381                        })
4382                })
4383                .count()
4384        });
4385        assert_eq!(
4386            retry_count, MAX_RETRY_ATTEMPTS as usize,
4387            "Should have scheduled third retry"
4388        );
4389
4390        // Check retry state updated
4391        thread.read_with(cx, |thread, _| {
4392            assert!(thread.retry_state.is_some(), "Should have retry state");
4393            let retry_state = thread.retry_state.as_ref().unwrap();
4394            assert_eq!(
4395                retry_state.attempt, MAX_RETRY_ATTEMPTS,
4396                "Should be at max retry attempt"
4397            );
4398            assert_eq!(
4399                retry_state.max_attempts, MAX_RETRY_ATTEMPTS,
4400                "Should have correct max attempts"
4401            );
4402        });
4403
4404        // Advance clock for third retry (exponential backoff)
4405        cx.executor()
4406            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4));
4407        cx.run_until_parked();
4408
4409        // No more retries should be scheduled after clock was advanced.
4410        let retry_count = thread.update(cx, |thread, _| {
4411            thread
4412                .messages
4413                .iter()
4414                .filter(|m| {
4415                    m.ui_only
4416                        && m.segments.iter().any(|s| {
4417                            if let MessageSegment::Text(text) = s {
4418                                text.contains("Retrying") && text.contains("seconds")
4419                            } else {
4420                                false
4421                            }
4422                        })
4423                })
4424                .count()
4425        });
4426        assert_eq!(
4427            retry_count, MAX_RETRY_ATTEMPTS as usize,
4428            "Should not exceed max retries"
4429        );
4430
4431        // Final completion count should be initial + max retries
4432        assert_eq!(
4433            *completion_count.lock(),
4434            (MAX_RETRY_ATTEMPTS + 1) as usize,
4435            "Should have made initial + max retry attempts"
4436        );
4437    }
4438
4439    #[gpui::test]
4440    async fn test_max_retries_exceeded(cx: &mut TestAppContext) {
4441        init_test_settings(cx);
4442
4443        let project = create_test_project(cx, json!({})).await;
4444        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4445
4446        // Create model that returns overloaded error
4447        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
4448
4449        // Insert a user message
4450        thread.update(cx, |thread, cx| {
4451            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4452        });
4453
4454        // Track events
4455        let retries_failed = Arc::new(Mutex::new(false));
4456        let retries_failed_clone = retries_failed.clone();
4457
4458        let _subscription = thread.update(cx, |_, cx| {
4459            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
4460                if let ThreadEvent::RetriesFailed { .. } = event {
4461                    *retries_failed_clone.lock() = true;
4462                }
4463            })
4464        });
4465
4466        // Start initial completion
4467        thread.update(cx, |thread, cx| {
4468            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4469        });
4470        cx.run_until_parked();
4471
4472        // Advance through all retries
4473        for i in 0..MAX_RETRY_ATTEMPTS {
4474            let delay = if i == 0 {
4475                BASE_RETRY_DELAY_SECS
4476            } else {
4477                BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1)
4478            };
4479            cx.executor().advance_clock(Duration::from_secs(delay));
4480            cx.run_until_parked();
4481        }
4482
4483        // After the 3rd retry is scheduled, we need to wait for it to execute and fail
4484        // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds)
4485        let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32);
4486        cx.executor()
4487            .advance_clock(Duration::from_secs(final_delay));
4488        cx.run_until_parked();
4489
4490        let retry_count = thread.update(cx, |thread, _| {
4491            thread
4492                .messages
4493                .iter()
4494                .filter(|m| {
4495                    m.ui_only
4496                        && m.segments.iter().any(|s| {
4497                            if let MessageSegment::Text(text) = s {
4498                                text.contains("Retrying") && text.contains("seconds")
4499                            } else {
4500                                false
4501                            }
4502                        })
4503                })
4504                .count()
4505        });
4506
4507        // After max retries, should emit RetriesFailed event
4508        assert_eq!(
4509            retry_count, MAX_RETRY_ATTEMPTS as usize,
4510            "Should have attempted max retries"
4511        );
4512        assert!(
4513            *retries_failed.lock(),
4514            "Should emit RetriesFailed event after max retries exceeded"
4515        );
4516
4517        // Retry state should be cleared
4518        thread.read_with(cx, |thread, _| {
4519            assert!(
4520                thread.retry_state.is_none(),
4521                "Retry state should be cleared after max retries"
4522            );
4523
4524            // Verify we have the expected number of retry messages
4525            let retry_messages = thread
4526                .messages
4527                .iter()
4528                .filter(|msg| msg.ui_only && msg.role == Role::System)
4529                .count();
4530            assert_eq!(
4531                retry_messages, MAX_RETRY_ATTEMPTS as usize,
4532                "Should have one retry message per attempt"
4533            );
4534        });
4535    }
4536
    #[gpui::test]
    async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;
        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;

        // We'll use a wrapper to switch behavior after first failure: the first
        // `stream_completion` fails with ServerOverloaded, and every subsequent
        // call is delegated to the inner `FakeLanguageModel`.
        struct RetryTestModel {
            inner: Arc<FakeLanguageModel>,
            // Set to true once the injected failure has been returned.
            failed_once: Arc<Mutex<bool>>,
        }

        impl LanguageModel for RetryTestModel {
            // Metadata/capability queries all delegate to the inner fake.
            fn id(&self) -> LanguageModelId {
                self.inner.id()
            }

            fn name(&self) -> LanguageModelName {
                self.inner.name()
            }

            fn provider_id(&self) -> LanguageModelProviderId {
                self.inner.provider_id()
            }

            fn provider_name(&self) -> LanguageModelProviderName {
                self.inner.provider_name()
            }

            fn supports_tools(&self) -> bool {
                self.inner.supports_tools()
            }

            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
                self.inner.supports_tool_choice(choice)
            }

            fn supports_images(&self) -> bool {
                self.inner.supports_images()
            }

            fn telemetry_id(&self) -> String {
                self.inner.telemetry_id()
            }

            fn max_token_count(&self) -> u64 {
                self.inner.max_token_count()
            }

            fn count_tokens(
                &self,
                request: LanguageModelRequest,
                cx: &App,
            ) -> BoxFuture<'static, Result<u64>> {
                self.inner.count_tokens(request, cx)
            }

            // Fails exactly once, then behaves like the inner fake model.
            fn stream_completion(
                &self,
                request: LanguageModelRequest,
                cx: &AsyncApp,
            ) -> BoxFuture<
                'static,
                Result<
                    BoxStream<
                        'static,
                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                    >,
                    LanguageModelCompletionError,
                >,
            > {
                if !*self.failed_once.lock() {
                    *self.failed_once.lock() = true;
                    let provider = self.provider_name();
                    // Return error on first attempt
                    let stream = futures::stream::once(async move {
                        Err(LanguageModelCompletionError::ServerOverloaded {
                            provider,
                            retry_after: None,
                        })
                    });
                    async move { Ok(stream.boxed()) }.boxed()
                } else {
                    // Succeed on retry
                    self.inner.stream_completion(request, cx)
                }
            }

            fn as_fake(&self) -> &FakeLanguageModel {
                &self.inner
            }
        }

        let model = Arc::new(RetryTestModel {
            inner: Arc::new(FakeLanguageModel::default()),
            failed_once: Arc::new(Mutex::new(false)),
        });

        // Insert a user message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
        });

        // Track when the retried completion streams content successfully.
        let retry_completed = Arc::new(Mutex::new(false));
        let retry_completed_clone = retry_completed.clone();

        let _subscription = thread.update(cx, |_, cx| {
            cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| {
                if let ThreadEvent::StreamedCompletion = event {
                    *retry_completed_clone.lock() = true;
                }
            })
        });

        // Start completion
        thread.update(cx, |thread, cx| {
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });
        cx.run_until_parked();

        // Get the retry message ID
        let retry_message_id = thread.read_with(cx, |thread, _| {
            thread
                .messages()
                .find(|msg| msg.role == Role::System && msg.ui_only)
                .map(|msg| msg.id)
                .expect("Should have a retry message")
        });

        // Wait for retry
        cx.executor()
            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
        cx.run_until_parked();

        // Stream some successful content
        let fake_model = model.as_fake();
        // After the retry, there should be a new pending completion
        let pending = fake_model.pending_completions();
        assert!(
            !pending.is_empty(),
            "Should have a pending completion after retry"
        );
        fake_model.stream_completion_response(&pending[0], "Success!");
        fake_model.end_completion_stream(&pending[0]);
        cx.run_until_parked();

        // Check that the retry completed successfully
        assert!(
            *retry_completed.lock(),
            "Retry should have completed successfully"
        );

        // Retry message should still exist but be marked as ui_only.
        // NOTE(review): despite the test name ("removed"), these assertions
        // expect the retry notice to REMAIN after a successful retry — confirm
        // whether the name or the behavior is the intended one.
        thread.read_with(cx, |thread, _| {
            let retry_msg = thread
                .message(retry_message_id)
                .expect("Retry message should still exist");
            assert!(retry_msg.ui_only, "Retry message should be ui_only");
            assert_eq!(
                retry_msg.role,
                Role::System,
                "Retry message should have System role"
            );
        });
    }
4705
4706    #[gpui::test]
4707    async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) {
4708        init_test_settings(cx);
4709
4710        let project = create_test_project(cx, json!({})).await;
4711        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4712
4713        // Create a model that fails once then succeeds
4714        struct FailOnceModel {
4715            inner: Arc<FakeLanguageModel>,
4716            failed_once: Arc<Mutex<bool>>,
4717        }
4718
4719        impl LanguageModel for FailOnceModel {
4720            fn id(&self) -> LanguageModelId {
4721                self.inner.id()
4722            }
4723
4724            fn name(&self) -> LanguageModelName {
4725                self.inner.name()
4726            }
4727
4728            fn provider_id(&self) -> LanguageModelProviderId {
4729                self.inner.provider_id()
4730            }
4731
4732            fn provider_name(&self) -> LanguageModelProviderName {
4733                self.inner.provider_name()
4734            }
4735
4736            fn supports_tools(&self) -> bool {
4737                self.inner.supports_tools()
4738            }
4739
4740            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4741                self.inner.supports_tool_choice(choice)
4742            }
4743
4744            fn supports_images(&self) -> bool {
4745                self.inner.supports_images()
4746            }
4747
4748            fn telemetry_id(&self) -> String {
4749                self.inner.telemetry_id()
4750            }
4751
4752            fn max_token_count(&self) -> u64 {
4753                self.inner.max_token_count()
4754            }
4755
4756            fn count_tokens(
4757                &self,
4758                request: LanguageModelRequest,
4759                cx: &App,
4760            ) -> BoxFuture<'static, Result<u64>> {
4761                self.inner.count_tokens(request, cx)
4762            }
4763
4764            fn stream_completion(
4765                &self,
4766                request: LanguageModelRequest,
4767                cx: &AsyncApp,
4768            ) -> BoxFuture<
4769                'static,
4770                Result<
4771                    BoxStream<
4772                        'static,
4773                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4774                    >,
4775                    LanguageModelCompletionError,
4776                >,
4777            > {
4778                if !*self.failed_once.lock() {
4779                    *self.failed_once.lock() = true;
4780                    let provider = self.provider_name();
4781                    // Return error on first attempt
4782                    let stream = futures::stream::once(async move {
4783                        Err(LanguageModelCompletionError::ServerOverloaded {
4784                            provider,
4785                            retry_after: None,
4786                        })
4787                    });
4788                    async move { Ok(stream.boxed()) }.boxed()
4789                } else {
4790                    // Succeed on retry
4791                    self.inner.stream_completion(request, cx)
4792                }
4793            }
4794        }
4795
4796        let fail_once_model = Arc::new(FailOnceModel {
4797            inner: Arc::new(FakeLanguageModel::default()),
4798            failed_once: Arc::new(Mutex::new(false)),
4799        });
4800
4801        // Insert a user message
4802        thread.update(cx, |thread, cx| {
4803            thread.insert_user_message(
4804                "Test message",
4805                ContextLoadResult::default(),
4806                None,
4807                vec![],
4808                cx,
4809            );
4810        });
4811
4812        // Start completion with fail-once model
4813        thread.update(cx, |thread, cx| {
4814            thread.send_to_model(
4815                fail_once_model.clone(),
4816                CompletionIntent::UserPrompt,
4817                None,
4818                cx,
4819            );
4820        });
4821
4822        cx.run_until_parked();
4823
4824        // Verify retry state exists after first failure
4825        thread.read_with(cx, |thread, _| {
4826            assert!(
4827                thread.retry_state.is_some(),
4828                "Should have retry state after failure"
4829            );
4830        });
4831
4832        // Wait for retry delay
4833        cx.executor()
4834            .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS));
4835        cx.run_until_parked();
4836
4837        // The retry should now use our FailOnceModel which should succeed
4838        // We need to help the FakeLanguageModel complete the stream
4839        let inner_fake = fail_once_model.inner.clone();
4840
4841        // Wait a bit for the retry to start
4842        cx.run_until_parked();
4843
4844        // Check for pending completions and complete them
4845        if let Some(pending) = inner_fake.pending_completions().first() {
4846            inner_fake.stream_completion_response(pending, "Success!");
4847            inner_fake.end_completion_stream(pending);
4848        }
4849        cx.run_until_parked();
4850
4851        thread.read_with(cx, |thread, _| {
4852            assert!(
4853                thread.retry_state.is_none(),
4854                "Retry state should be cleared after successful completion"
4855            );
4856
4857            let has_assistant_message = thread
4858                .messages
4859                .iter()
4860                .any(|msg| msg.role == Role::Assistant && !msg.ui_only);
4861            assert!(
4862                has_assistant_message,
4863                "Should have an assistant message after successful retry"
4864            );
4865        });
4866    }
4867
4868    #[gpui::test]
4869    async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) {
4870        init_test_settings(cx);
4871
4872        let project = create_test_project(cx, json!({})).await;
4873        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
4874
4875        // Create a model that returns rate limit error with retry_after
4876        struct RateLimitModel {
4877            inner: Arc<FakeLanguageModel>,
4878        }
4879
4880        impl LanguageModel for RateLimitModel {
4881            fn id(&self) -> LanguageModelId {
4882                self.inner.id()
4883            }
4884
4885            fn name(&self) -> LanguageModelName {
4886                self.inner.name()
4887            }
4888
4889            fn provider_id(&self) -> LanguageModelProviderId {
4890                self.inner.provider_id()
4891            }
4892
4893            fn provider_name(&self) -> LanguageModelProviderName {
4894                self.inner.provider_name()
4895            }
4896
4897            fn supports_tools(&self) -> bool {
4898                self.inner.supports_tools()
4899            }
4900
4901            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
4902                self.inner.supports_tool_choice(choice)
4903            }
4904
4905            fn supports_images(&self) -> bool {
4906                self.inner.supports_images()
4907            }
4908
4909            fn telemetry_id(&self) -> String {
4910                self.inner.telemetry_id()
4911            }
4912
4913            fn max_token_count(&self) -> u64 {
4914                self.inner.max_token_count()
4915            }
4916
4917            fn count_tokens(
4918                &self,
4919                request: LanguageModelRequest,
4920                cx: &App,
4921            ) -> BoxFuture<'static, Result<u64>> {
4922                self.inner.count_tokens(request, cx)
4923            }
4924
4925            fn stream_completion(
4926                &self,
4927                _request: LanguageModelRequest,
4928                _cx: &AsyncApp,
4929            ) -> BoxFuture<
4930                'static,
4931                Result<
4932                    BoxStream<
4933                        'static,
4934                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
4935                    >,
4936                    LanguageModelCompletionError,
4937                >,
4938            > {
4939                let provider = self.provider_name();
4940                async move {
4941                    let stream = futures::stream::once(async move {
4942                        Err(LanguageModelCompletionError::RateLimitExceeded {
4943                            provider,
4944                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
4945                        })
4946                    });
4947                    Ok(stream.boxed())
4948                }
4949                .boxed()
4950            }
4951
4952            fn as_fake(&self) -> &FakeLanguageModel {
4953                &self.inner
4954            }
4955        }
4956
4957        let model = Arc::new(RateLimitModel {
4958            inner: Arc::new(FakeLanguageModel::default()),
4959        });
4960
4961        // Insert a user message
4962        thread.update(cx, |thread, cx| {
4963            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
4964        });
4965
4966        // Start completion
4967        thread.update(cx, |thread, cx| {
4968            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
4969        });
4970
4971        cx.run_until_parked();
4972
4973        let retry_count = thread.update(cx, |thread, _| {
4974            thread
4975                .messages
4976                .iter()
4977                .filter(|m| {
4978                    m.ui_only
4979                        && m.segments.iter().any(|s| {
4980                            if let MessageSegment::Text(text) = s {
4981                                text.contains("rate limit exceeded")
4982                            } else {
4983                                false
4984                            }
4985                        })
4986                })
4987                .count()
4988        });
4989        assert_eq!(retry_count, 1, "Should have scheduled one retry");
4990
4991        thread.read_with(cx, |thread, _| {
4992            assert!(
4993                thread.retry_state.is_none(),
4994                "Rate limit errors should not set retry_state"
4995            );
4996        });
4997
4998        // Verify we have one retry message
4999        thread.read_with(cx, |thread, _| {
5000            let retry_messages = thread
5001                .messages
5002                .iter()
5003                .filter(|msg| {
5004                    msg.ui_only
5005                        && msg.segments.iter().any(|seg| {
5006                            if let MessageSegment::Text(text) = seg {
5007                                text.contains("rate limit exceeded")
5008                            } else {
5009                                false
5010                            }
5011                        })
5012                })
5013                .count();
5014            assert_eq!(
5015                retry_messages, 1,
5016                "Should have one rate limit retry message"
5017            );
5018        });
5019
5020        // Check that retry message doesn't include attempt count
5021        thread.read_with(cx, |thread, _| {
5022            let retry_message = thread
5023                .messages
5024                .iter()
5025                .find(|msg| msg.role == Role::System && msg.ui_only)
5026                .expect("Should have a retry message");
5027
5028            // Check that the message doesn't contain attempt count
5029            if let Some(MessageSegment::Text(text)) = retry_message.segments.first() {
5030                assert!(
5031                    !text.contains("attempt"),
5032                    "Rate limit retry message should not contain attempt count"
5033                );
5034                assert!(
5035                    text.contains(&format!(
5036                        "Retrying in {} seconds",
5037                        TEST_RATE_LIMIT_RETRY_SECS
5038                    )),
5039                    "Rate limit retry message should contain retry delay"
5040                );
5041            }
5042        });
5043    }
5044
5045    #[gpui::test]
5046    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
5047        init_test_settings(cx);
5048
5049        let project = create_test_project(cx, json!({})).await;
5050        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
5051
5052        // Insert a regular user message
5053        thread.update(cx, |thread, cx| {
5054            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5055        });
5056
5057        // Insert a UI-only message (like our retry notifications)
5058        thread.update(cx, |thread, cx| {
5059            let id = thread.next_message_id.post_inc();
5060            thread.messages.push(Message {
5061                id,
5062                role: Role::System,
5063                segments: vec![MessageSegment::Text(
5064                    "This is a UI-only message that should not be sent to the model".to_string(),
5065                )],
5066                loaded_context: LoadedContext::default(),
5067                creases: Vec::new(),
5068                is_hidden: true,
5069                ui_only: true,
5070            });
5071            cx.emit(ThreadEvent::MessageAdded(id));
5072        });
5073
5074        // Insert another regular message
5075        thread.update(cx, |thread, cx| {
5076            thread.insert_user_message(
5077                "How are you?",
5078                ContextLoadResult::default(),
5079                None,
5080                vec![],
5081                cx,
5082            );
5083        });
5084
5085        // Generate the completion request
5086        let request = thread.update(cx, |thread, cx| {
5087            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
5088        });
5089
5090        // Verify that the request only contains non-UI-only messages
5091        // Should have system prompt + 2 user messages, but not the UI-only message
5092        let user_messages: Vec<_> = request
5093            .messages
5094            .iter()
5095            .filter(|msg| msg.role == Role::User)
5096            .collect();
5097        assert_eq!(
5098            user_messages.len(),
5099            2,
5100            "Should have exactly 2 user messages"
5101        );
5102
5103        // Verify the UI-only content is not present anywhere in the request
5104        let request_text = request
5105            .messages
5106            .iter()
5107            .flat_map(|msg| &msg.content)
5108            .filter_map(|content| match content {
5109                MessageContent::Text(text) => Some(text.as_str()),
5110                _ => None,
5111            })
5112            .collect::<String>();
5113
5114        assert!(
5115            !request_text.contains("UI-only message"),
5116            "UI-only message content should not be in the request"
5117        );
5118
5119        // Verify the thread still has all 3 messages (including UI-only)
5120        thread.read_with(cx, |thread, _| {
5121            assert_eq!(
5122                thread.messages().count(),
5123                3,
5124                "Thread should have 3 messages"
5125            );
5126            assert_eq!(
5127                thread.messages().filter(|m| m.ui_only).count(),
5128                1,
5129                "Thread should have 1 UI-only message"
5130            );
5131        });
5132
5133        // Verify that UI-only messages are not serialized
5134        let serialized = thread
5135            .update(cx, |thread, cx| thread.serialize(cx))
5136            .await
5137            .unwrap();
5138        assert_eq!(
5139            serialized.messages.len(),
5140            2,
5141            "Serialized thread should only have 2 messages (no UI-only)"
5142        );
5143    }
5144
5145    #[gpui::test]
5146    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
5147        init_test_settings(cx);
5148
5149        let project = create_test_project(cx, json!({})).await;
5150        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
5151
5152        // Create model that returns overloaded error
5153        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
5154
5155        // Insert a user message
5156        thread.update(cx, |thread, cx| {
5157            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
5158        });
5159
5160        // Start completion
5161        thread.update(cx, |thread, cx| {
5162            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
5163        });
5164
5165        cx.run_until_parked();
5166
5167        // Verify retry was scheduled by checking for retry message
5168        let has_retry_message = thread.read_with(cx, |thread, _| {
5169            thread.messages.iter().any(|m| {
5170                m.ui_only
5171                    && m.segments.iter().any(|s| {
5172                        if let MessageSegment::Text(text) = s {
5173                            text.contains("Retrying") && text.contains("seconds")
5174                        } else {
5175                            false
5176                        }
5177                    })
5178            })
5179        });
5180        assert!(has_retry_message, "Should have scheduled a retry");
5181
5182        // Cancel the completion before the retry happens
5183        thread.update(cx, |thread, cx| {
5184            thread.cancel_last_completion(None, cx);
5185        });
5186
5187        cx.run_until_parked();
5188
5189        // The retry should not have happened - no pending completions
5190        let fake_model = model.as_fake();
5191        assert_eq!(
5192            fake_model.pending_completions().len(),
5193            0,
5194            "Should have no pending completions after cancellation"
5195        );
5196
5197        // Verify the retry was cancelled by checking retry state
5198        thread.read_with(cx, |thread, _| {
5199            if let Some(retry_state) = &thread.retry_state {
5200                panic!(
5201                    "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}",
5202                    retry_state.attempt, retry_state.max_attempts, retry_state.intent
5203                );
5204            }
5205        });
5206    }
5207
    /// Shared driver: sends a summarization request through `model` and
    /// asserts the thread summary goes `Generating` while the request is in
    /// flight, then `Error` once the stream ends without a summary. In both
    /// states `or_default()` yields the default title.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // While the summary request is still streaming, the summary reports
        // Generating and falls back to the default title.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
5242
5243    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
5244        cx.run_until_parked();
5245        fake_model.stream_last_completion_response("Assistant response");
5246        fake_model.end_last_completion_stream();
5247        cx.run_until_parked();
5248    }
5249
    /// Registers the global settings, registries, and tool crates these
    /// tests depend on. Called at the start of every test in this module.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must be installed before any of the
            // `init`/`register` calls below read or register settings.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
            assistant_tool::init(cx);

            // assistant_tools needs an HTTP client; use a stub that always
            // answers 200 so tests never touch the network.
            let http_client = Arc::new(http_client::HttpClientWithUrl::new(
                http_client::FakeHttpClient::with_200_response(),
                "http://localhost".to_string(),
                None,
            ));
            assistant_tools::init(http_client, cx);
        });
    }
5273
5274    // Helper to create a test project with test files
5275    async fn create_test_project(
5276        cx: &mut TestAppContext,
5277        files: serde_json::Value,
5278    ) -> Entity<Project> {
5279        let fs = FakeFs::new(cx.executor());
5280        fs.insert_tree(path!("/test"), files).await;
5281        Project::test(fs, [path!("/test").as_ref()], cx).await
5282    }
5283
5284    async fn setup_test_environment(
5285        cx: &mut TestAppContext,
5286        project: Entity<Project>,
5287    ) -> (
5288        Entity<Workspace>,
5289        Entity<ThreadStore>,
5290        Entity<Thread>,
5291        Entity<ContextStore>,
5292        Arc<dyn LanguageModel>,
5293    ) {
5294        let (workspace, cx) =
5295            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
5296
5297        let thread_store = cx
5298            .update(|_, cx| {
5299                ThreadStore::load(
5300                    project.clone(),
5301                    cx.new(|_| ToolWorkingSet::default()),
5302                    None,
5303                    Arc::new(PromptBuilder::new(None).unwrap()),
5304                    cx,
5305                )
5306            })
5307            .await
5308            .unwrap();
5309
5310        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
5311        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
5312
5313        let provider = Arc::new(FakeLanguageModelProvider);
5314        let model = provider.test_model();
5315        let model: Arc<dyn LanguageModel> = Arc::new(model);
5316
5317        cx.update(|_, cx| {
5318            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
5319                registry.set_default_model(
5320                    Some(ConfiguredModel {
5321                        provider: provider.clone(),
5322                        model: model.clone(),
5323                    }),
5324                    cx,
5325                );
5326                registry.set_thread_summary_model(
5327                    Some(ConfiguredModel {
5328                        provider,
5329                        model: model.clone(),
5330                    }),
5331                    cx,
5332                );
5333            })
5334        });
5335
5336        (workspace, thread_store, thread, context_store, model)
5337    }
5338
5339    async fn add_file_to_context(
5340        project: &Entity<Project>,
5341        context_store: &Entity<ContextStore>,
5342        path: &str,
5343        cx: &mut TestAppContext,
5344    ) -> Result<Entity<language::Buffer>> {
5345        let buffer_path = project
5346            .read_with(cx, |project, cx| project.find_project_path(path, cx))
5347            .unwrap();
5348
5349        let buffer = project
5350            .update(cx, |project, cx| {
5351                project.open_buffer(buffer_path.clone(), cx)
5352            })
5353            .await
5354            .unwrap();
5355
5356        context_store.update(cx, |context_store, cx| {
5357            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
5358        });
5359
5360        Ok(buffer)
5361    }
5362}