// thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  26    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
  27    StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use util::{ResultExt as _, TryFutureExt as _, post_inc};
  39use uuid::Uuid;
  40use zed_llm_client::CompletionMode;
  41
  42use crate::ThreadStore;
  43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  44use crate::thread_store::{
  45    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  46    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  47};
  48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  49
/// Globally-unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Stored as a UUID string.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  78
  79impl PromptId {
  80    pub fn new() -> Self {
  81        Self(Uuid::new_v4().to_string().into())
  82    }
  83}
  84
  85impl std::fmt::Display for PromptId {
  86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  87        write!(f, "{}", self.0)
  88    }
  89}
  90
/// Monotonically-increasing identifier of a message within a single thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  93
  94impl MessageId {
  95    fn post_inc(&mut self) -> Self {
  96        Self(post_inc(&mut self.0))
  97    }
  98}
  99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon and label used to render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when the message was sent.
    pub loaded_context: LoadedContext,
    /// Creases used when recreating an editor for this message (see [`MessageCrease`]).
    pub creases: Vec<MessageCrease>,
}
 118
 119impl Message {
 120    /// Returns whether the message contains any meaningful text that should be displayed
 121    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 122    pub fn should_display_content(&self) -> bool {
 123        self.segments.iter().all(|segment| segment.should_display())
 124    }
 125
 126    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 127        if let Some(MessageSegment::Thinking {
 128            text: segment,
 129            signature: current_signature,
 130        }) = self.segments.last_mut()
 131        {
 132            if let Some(signature) = signature {
 133                *current_signature = Some(signature);
 134            }
 135            segment.push_str(text);
 136        } else {
 137            self.segments.push(MessageSegment::Thinking {
 138                text: text.to_string(),
 139                signature,
 140            });
 141        }
 142    }
 143
 144    pub fn push_text(&mut self, text: &str) {
 145        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 146            segment.push_str(text);
 147        } else {
 148            self.segments.push(MessageSegment::Text(text.to_string()));
 149        }
 150    }
 151
 152    pub fn to_string(&self) -> String {
 153        let mut result = String::new();
 154
 155        if !self.loaded_context.text.is_empty() {
 156            result.push_str(&self.loaded_context.text);
 157        }
 158
 159        for segment in &self.segments {
 160            match segment {
 161                MessageSegment::Text(text) => result.push_str(text),
 162                MessageSegment::Thinking { text, .. } => {
 163                    result.push_str("<think>\n");
 164                    result.push_str(text);
 165                    result.push_str("\n</think>");
 166                }
 167                MessageSegment::RedactedThinking(_) => {}
 168            }
 169        }
 170
 171        result
 172    }
 173}
 174
 175#[derive(Debug, Clone, PartialEq, Eq)]
 176pub enum MessageSegment {
 177    Text(String),
 178    Thinking {
 179        text: String,
 180        signature: Option<String>,
 181    },
 182    RedactedThinking(Vec<u8>),
 183}
 184
 185impl MessageSegment {
 186    pub fn should_display(&self) -> bool {
 187        match self {
 188            Self::Text(text) => text.is_empty(),
 189            Self::Thinking { text, .. } => text.is_empty(),
 190            Self::RedactedThinking(_) => false,
 191        }
 192    }
 193}
 194
/// Snapshot of the project's worktrees at a point in time, persisted
/// alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata, or `None` if the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, when available.
    pub diff: Option<String>,
}
 215
/// Associates a message with a git checkpoint captured around the time it was
/// sent, enabling restoring the project back to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 227
 228pub enum LastRestoreCheckpoint {
 229    Pending {
 230        message_id: MessageId,
 231    },
 232    Error {
 233        message_id: MessageId,
 234        error: String,
 235    },
 236}
 237
 238impl LastRestoreCheckpoint {
 239    pub fn message_id(&self) -> MessageId {
 240        match self {
 241            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 242            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 243        }
 244    }
 245}
 246
 247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 248pub enum DetailedSummaryState {
 249    #[default]
 250    NotGenerated,
 251    Generating {
 252        message_id: MessageId,
 253    },
 254    Generated {
 255        text: SharedString,
 256        message_id: MessageId,
 257    },
 258}
 259
 260impl DetailedSummaryState {
 261    fn text(&self) -> Option<SharedString> {
 262        if let Self::Generated { text, .. } = self {
 263            Some(text.clone())
 264        } else {
 265            None
 266        }
 267    }
 268}
 269
 270#[derive(Default)]
 271pub struct TotalTokenUsage {
 272    pub total: usize,
 273    pub max: usize,
 274}
 275
 276impl TotalTokenUsage {
 277    pub fn ratio(&self) -> TokenUsageRatio {
 278        #[cfg(debug_assertions)]
 279        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 280            .unwrap_or("0.8".to_string())
 281            .parse()
 282            .unwrap();
 283        #[cfg(not(debug_assertions))]
 284        let warning_threshold: f32 = 0.8;
 285
 286        // When the maximum is unknown because there is no selected model,
 287        // avoid showing the token limit warning.
 288        if self.max == 0 {
 289            TokenUsageRatio::Normal
 290        } else if self.total >= self.max {
 291            TokenUsageRatio::Exceeded
 292        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 293            TokenUsageRatio::Warning
 294        } else {
 295            TokenUsageRatio::Normal
 296        }
 297    }
 298
 299    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 300        TotalTokenUsage {
 301            total: self.total + tokens,
 302            max: self.max,
 303        }
 304    }
 305}
 306
 307#[derive(Debug, Default, PartialEq, Eq)]
 308pub enum TokenUsageRatio {
 309    #[default]
 310    Normal,
 311    Warning,
 312    Exceeded,
 313}
 314
 315fn default_completion_mode(cx: &App) -> CompletionMode {
 316    if cx.is_staff() {
 317        CompletionMode::Max
 318    } else {
 319        CompletionMode::Normal
 320    }
 321}
 322
/// Where a pending completion stands in the upstream request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue position yet.
    Sending,
    /// The request is waiting in the queue at `position`.
    Queued { position: usize },
    /// The request has started.
    Started,
}
 329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Human-readable summary; `None` until one has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: CompletionMode,
    /// Conversation messages, kept in ascending `MessageId` order.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// Shared system-prompt context for the project.
    project_context: SharedProjectContext,
    /// Git checkpoints captured for user messages, keyed by message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight model completions; non-empty implies the thread is generating.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// State of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once the current turn completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    /// Set when the last request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    /// Whole-thread feedback, if the user provided any.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the last streaming chunk arrived; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Observer invoked with each request and its completion events
    /// (diagnostics/testing hook).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    /// Currently selected model; lazily initialized from the registry default.
    configured_model: Option<ConfiguredModel>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 377
 378impl Thread {
    /// Creates an empty thread backed by `project`, selecting the registry's
    /// default model (if any).
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off capturing a snapshot of the project as it was when
                // the thread started; consumers await the shared future later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 431
    /// Rebuilds a thread from its serialized form.
    ///
    /// Rich context handles and images attached to messages are not persisted,
    /// so restored messages carry only their flattened context text. The
    /// serialized model selection is re-resolved against the registry, falling
    /// back to the default model when unavailable.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Re-resolve the persisted model against the current registry.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        // Only the flattened context text survives a reload.
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Context handles can't be restored from disk.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 539
    /// Installs a callback invoked with each completion request sent from this
    /// thread together with the raw events it produced (diagnostics/testing
    /// hook; see uses of `request_callback`).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 547
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Generates a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if none has been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 575
 576    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 577        if self.configured_model.is_none() {
 578            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 579        }
 580        self.configured_model.clone()
 581    }
 582
 583    pub fn configured_model(&self) -> Option<ConfiguredModel> {
 584        self.configured_model.clone()
 585    }
 586
 587    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
 588        self.configured_model = model;
 589        cx.notify();
 590    }
 591
 592    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 593
 594    pub fn summary_or_default(&self) -> SharedString {
 595        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 596    }
 597
 598    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 599        let Some(current_summary) = &self.summary else {
 600            // Don't allow setting summary until generated
 601            return;
 602        };
 603
 604        let mut new_summary = new_summary.into();
 605
 606        if new_summary.is_empty() {
 607            new_summary = Self::DEFAULT_SUMMARY;
 608        }
 609
 610        if current_summary != &new_summary {
 611            self.summary = Some(new_summary);
 612            cx.emit(ThreadEvent::SummaryChanged);
 613        }
 614    }
 615
    /// The completion mode used for requests (see [`CompletionMode`]).
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 623
 624    pub fn message(&self, id: MessageId) -> Option<&Message> {
 625        let index = self
 626            .messages
 627            .binary_search_by(|message| message.id.cmp(&id))
 628            .ok()?;
 629
 630        self.messages.get(index)
 631    }
 632
 633    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
 634        self.messages.iter()
 635    }
 636
 637    pub fn is_generating(&self) -> bool {
 638        !self.pending_completions.is_empty() || !self.all_tools_finished()
 639    }
 640
 641    /// Indicates whether streaming of language model events is stale.
 642    /// When `is_generating()` is false, this method returns `None`.
 643    pub fn is_generation_stale(&self) -> Option<bool> {
 644        const STALE_THRESHOLD: u128 = 250;
 645
 646        self.last_received_chunk_at
 647            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
 648    }
 649
 650    fn received_chunk(&mut self) {
 651        self.last_received_chunk_at = Some(Instant::now());
 652    }
 653
 654    pub fn queue_state(&self) -> Option<QueueState> {
 655        self.pending_completions
 656            .first()
 657            .map(|pending_completion| pending_completion.queue_state)
 658    }
 659
 660    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 661        &self.tools
 662    }
 663
 664    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 665        self.tool_use
 666            .pending_tool_uses()
 667            .into_iter()
 668            .find(|tool_use| &tool_use.id == id)
 669    }
 670
 671    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 672        self.tool_use
 673            .pending_tool_uses()
 674            .into_iter()
 675            .filter(|tool_use| tool_use.status.needs_confirmation())
 676    }
 677
 678    pub fn has_pending_tool_uses(&self) -> bool {
 679        !self.tool_use.pending_tool_uses().is_empty()
 680    }
 681
 682    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 683        self.checkpoints_by_message.get(&id).cloned()
 684    }
 685
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced via `last_restore_checkpoint` and
    /// [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as pending so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be shown next to the message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any in-flight pending checkpoint is invalidated by the restore.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 721
    /// Resolves the checkpoint captured at the start of the turn: once
    /// generation has finished, keep it only if the repository actually
    /// changed during the turn.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; finalize later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint if the turn produced changes;
                // identical checkpoints would just be noise.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking the final checkpoint failed, err on the side of
            // keeping the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 760
 761    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 762        self.checkpoints_by_message
 763            .insert(checkpoint.message_id, checkpoint);
 764        cx.emit(ThreadEvent::CheckpointChanged);
 765        cx.notify();
 766    }
 767
 768    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 769        self.last_restore_checkpoint.as_ref()
 770    }
 771
 772    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 773        let Some(message_ix) = self
 774            .messages
 775            .iter()
 776            .rposition(|message| message.id == message_id)
 777        else {
 778            return;
 779        };
 780        for deleted_message in self.messages.drain(message_ix..) {
 781            self.checkpoints_by_message.remove(&deleted_message.id);
 782        }
 783        cx.notify();
 784    }
 785
 786    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
 787        self.messages
 788            .iter()
 789            .find(|message| message.id == id)
 790            .into_iter()
 791            .flat_map(|message| message.loaded_context.contexts.iter())
 792    }
 793
 794    pub fn is_turn_end(&self, ix: usize) -> bool {
 795        if self.messages.is_empty() {
 796            return false;
 797        }
 798
 799        if !self.is_generating() && ix == self.messages.len() - 1 {
 800            return true;
 801        }
 802
 803        let Some(message) = self.messages.get(ix) else {
 804            return false;
 805        };
 806
 807        if message.role != Role::Assistant {
 808            return false;
 809        }
 810
 811        self.messages
 812            .get(ix + 1)
 813            .and_then(|message| {
 814                self.message(message.id)
 815                    .map(|next_message| next_message.role == Role::User)
 816            })
 817            .unwrap_or(false)
 818    }
 819
    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 833
    /// Tool uses recorded for the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a single tool use, if it has one.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a tool use's result, if present.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 856
 857    /// Return tools that are both enabled and supported by the model
 858    pub fn available_tools(
 859        &self,
 860        cx: &App,
 861        model: Arc<dyn LanguageModel>,
 862    ) -> Vec<LanguageModelRequestTool> {
 863        if model.supports_tools() {
 864            self.tools()
 865                .read(cx)
 866                .enabled_tools(cx)
 867                .into_iter()
 868                .filter_map(|tool| {
 869                    // Skip tools that cannot be supported
 870                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 871                    Some(LanguageModelRequestTool {
 872                        name: tool.name(),
 873                        description: tool.description(),
 874                        input_schema,
 875                    })
 876                })
 877                .collect()
 878        } else {
 879            Vec::default()
 880        }
 881    }
 882
 883    pub fn insert_user_message(
 884        &mut self,
 885        text: impl Into<String>,
 886        loaded_context: ContextLoadResult,
 887        git_checkpoint: Option<GitStoreCheckpoint>,
 888        creases: Vec<MessageCrease>,
 889        cx: &mut Context<Self>,
 890    ) -> MessageId {
 891        if !loaded_context.referenced_buffers.is_empty() {
 892            self.action_log.update(cx, |log, cx| {
 893                for buffer in loaded_context.referenced_buffers {
 894                    log.track_buffer(buffer, cx);
 895                }
 896            });
 897        }
 898
 899        let message_id = self.insert_message(
 900            Role::User,
 901            vec![MessageSegment::Text(text.into())],
 902            loaded_context.loaded_context,
 903            creases,
 904            cx,
 905        );
 906
 907        if let Some(git_checkpoint) = git_checkpoint {
 908            self.pending_checkpoint = Some(ThreadCheckpoint {
 909                message_id,
 910                git_checkpoint,
 911            });
 912        }
 913
 914        self.auto_capture_telemetry(cx);
 915
 916        message_id
 917    }
 918
    /// Appends an assistant message with the given segments and no attached
    /// context or creases; returns the new message's id.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
 932
 933    pub fn insert_message(
 934        &mut self,
 935        role: Role,
 936        segments: Vec<MessageSegment>,
 937        loaded_context: LoadedContext,
 938        creases: Vec<MessageCrease>,
 939        cx: &mut Context<Self>,
 940    ) -> MessageId {
 941        let id = self.next_message_id.post_inc();
 942        self.messages.push(Message {
 943            id,
 944            role,
 945            segments,
 946            loaded_context,
 947            creases,
 948        });
 949        self.touch_updated_at();
 950        cx.emit(ThreadEvent::MessageAdded(id));
 951        id
 952    }
 953
 954    pub fn edit_message(
 955        &mut self,
 956        id: MessageId,
 957        new_role: Role,
 958        new_segments: Vec<MessageSegment>,
 959        loaded_context: Option<LoadedContext>,
 960        cx: &mut Context<Self>,
 961    ) -> bool {
 962        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 963            return false;
 964        };
 965        message.role = new_role;
 966        message.segments = new_segments;
 967        if let Some(context) = loaded_context {
 968            message.loaded_context = context;
 969        }
 970        self.touch_updated_at();
 971        cx.emit(ThreadEvent::MessageEdited(id));
 972        true
 973    }
 974
 975    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 976        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 977            return false;
 978        };
 979        self.messages.remove(index);
 980        self.touch_updated_at();
 981        cx.emit(ThreadEvent::MessageDeleted(id));
 982        true
 983    }
 984
 985    /// Returns the representation of this [`Thread`] in a textual form.
 986    ///
 987    /// This is the representation we use when attaching a thread as context to another thread.
 988    pub fn text(&self) -> String {
 989        let mut text = String::new();
 990
 991        for message in &self.messages {
 992            text.push_str(match message.role {
 993                language_model::Role::User => "User:",
 994                language_model::Role::Assistant => "Assistant:",
 995                language_model::Role::System => "System:",
 996            });
 997            text.push('\n');
 998
 999            for segment in &message.segments {
1000                match segment {
1001                    MessageSegment::Text(content) => text.push_str(content),
1002                    MessageSegment::Thinking { text: content, .. } => {
1003                        text.push_str(&format!("<think>{}</think>", content))
1004                    }
1005                    MessageSegment::RedactedThinking(_) => {}
1006                }
1007            }
1008            text.push('\n');
1009        }
1010
1011        text
1012    }
1013
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot (a shared future captured at
    /// thread creation) before reading the thread, then snapshots messages,
    /// tool uses/results, creases, token usage, and the configured model into a
    /// [`SerializedThread`]. Returns a task because the snapshot may still be
    /// in flight.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Each in-memory segment variant maps 1:1 onto its
                        // serialized counterpart.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and their results live in the thread's
                        // tool-use state, keyed by message id.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted, not the
                        // full loaded context.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Persist provider/model ids so the thread can restore the
                // same configured model on load.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1094
    /// Returns how many agentic turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1098
    /// Sets the remaining turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1102
1103    pub fn send_to_model(
1104        &mut self,
1105        model: Arc<dyn LanguageModel>,
1106        window: Option<AnyWindowHandle>,
1107        cx: &mut Context<Self>,
1108    ) {
1109        if self.remaining_turns == 0 {
1110            return;
1111        }
1112
1113        self.remaining_turns -= 1;
1114
1115        let request = self.to_completion_request(model.clone(), cx);
1116
1117        self.stream_completion(request, model, window, cx);
1118    }
1119
1120    pub fn used_tools_since_last_user_message(&self) -> bool {
1121        for message in self.messages.iter().rev() {
1122            if self.tool_use.message_has_tool_results(message.id) {
1123                return true;
1124            } else if message.role == Role::User {
1125                return false;
1126            }
1127        }
1128
1129        false
1130    }
1131
    /// Builds the full [`LanguageModelRequest`] for this thread: system prompt,
    /// conversation history, tool uses/results, stale-file notices, available
    /// tools, and completion mode.
    ///
    /// Emits a [`ThreadEvent::ShowError`] (and still returns a request without
    /// a system message) if the system prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt advertises the enabled tool names to the model.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system message goes first; it is marked cacheable since it is
        // stable across requests in the same thread.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // Project context is normally populated before a request is made;
            // surface the unexpected gap instead of failing silently.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) is prepended to the
            // message's own content.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty segments are dropped to avoid zero-length
                        // content entries in the request.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results are sent as a separate message immediately after
            // the assistant message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        // Append a trailing user message listing files that changed since the
        // model last read them, if any.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode)
        } else {
            Some(CompletionMode::Normal)
        };

        request
    }
1249
1250    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1251        let mut request = LanguageModelRequest {
1252            thread_id: None,
1253            prompt_id: None,
1254            mode: None,
1255            messages: vec![],
1256            tools: Vec::new(),
1257            stop: Vec::new(),
1258            temperature: None,
1259        };
1260
1261        for message in &self.messages {
1262            let mut request_message = LanguageModelRequestMessage {
1263                role: message.role,
1264                content: Vec::new(),
1265                cache: false,
1266            };
1267
1268            for segment in &message.segments {
1269                match segment {
1270                    MessageSegment::Text(text) => request_message
1271                        .content
1272                        .push(MessageContent::Text(text.clone())),
1273                    MessageSegment::Thinking { .. } => {}
1274                    MessageSegment::RedactedThinking(_) => {}
1275                }
1276            }
1277
1278            if request_message.content.is_empty() {
1279                continue;
1280            }
1281
1282            request.messages.push(request_message);
1283        }
1284
1285        request.messages.push(LanguageModelRequestMessage {
1286            role: Role::User,
1287            content: vec![MessageContent::Text(added_user_message)],
1288            cache: false,
1289        });
1290
1291        request
1292    }
1293
1294    fn attached_tracked_files_state(
1295        &self,
1296        messages: &mut Vec<LanguageModelRequestMessage>,
1297        cx: &App,
1298    ) {
1299        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1300
1301        let mut stale_message = String::new();
1302
1303        let action_log = self.action_log.read(cx);
1304
1305        for stale_file in action_log.stale_buffers(cx) {
1306            let Some(file) = stale_file.read(cx).file() else {
1307                continue;
1308            };
1309
1310            if stale_message.is_empty() {
1311                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1312            }
1313
1314            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1315        }
1316
1317        let mut content = Vec::with_capacity(2);
1318
1319        if !stale_message.is_empty() {
1320            content.push(stale_message.into());
1321        }
1322
1323        if !content.is_empty() {
1324            let context_message = LanguageModelRequestMessage {
1325                role: Role::User,
1326                content,
1327                cache: false,
1328            };
1329
1330            messages.push(context_message);
1331        }
1332    }
1333
1334    pub fn stream_completion(
1335        &mut self,
1336        request: LanguageModelRequest,
1337        model: Arc<dyn LanguageModel>,
1338        window: Option<AnyWindowHandle>,
1339        cx: &mut Context<Self>,
1340    ) {
1341        self.tool_use_limit_reached = false;
1342
1343        let pending_completion_id = post_inc(&mut self.completion_count);
1344        let mut request_callback_parameters = if self.request_callback.is_some() {
1345            Some((request.clone(), Vec::new()))
1346        } else {
1347            None
1348        };
1349        let prompt_id = self.last_prompt_id.clone();
1350        let tool_use_metadata = ToolUseMetadata {
1351            model: model.clone(),
1352            thread_id: self.id.clone(),
1353            prompt_id: prompt_id.clone(),
1354        };
1355
1356        self.last_received_chunk_at = Some(Instant::now());
1357
1358        let task = cx.spawn(async move |thread, cx| {
1359            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1360            let initial_token_usage =
1361                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1362            let stream_completion = async {
1363                let (mut events, usage) = stream_completion_future.await?;
1364
1365                let mut stop_reason = StopReason::EndTurn;
1366                let mut current_token_usage = TokenUsage::default();
1367
1368                thread
1369                    .update(cx, |_thread, cx| {
1370                        if let Some(usage) = usage {
1371                            cx.emit(ThreadEvent::UsageUpdated(usage));
1372                        }
1373                        cx.emit(ThreadEvent::NewRequest);
1374                    })
1375                    .ok();
1376
1377                let mut request_assistant_message_id = None;
1378
1379                while let Some(event) = events.next().await {
1380                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1381                        response_events
1382                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1383                    }
1384
1385                    thread.update(cx, |thread, cx| {
1386                        let event = match event {
1387                            Ok(event) => event,
1388                            Err(LanguageModelCompletionError::BadInputJson {
1389                                id,
1390                                tool_name,
1391                                raw_input: invalid_input_json,
1392                                json_parse_error,
1393                            }) => {
1394                                thread.receive_invalid_tool_json(
1395                                    id,
1396                                    tool_name,
1397                                    invalid_input_json,
1398                                    json_parse_error,
1399                                    window,
1400                                    cx,
1401                                );
1402                                return Ok(());
1403                            }
1404                            Err(LanguageModelCompletionError::Other(error)) => {
1405                                return Err(error);
1406                            }
1407                        };
1408
1409                        match event {
1410                            LanguageModelCompletionEvent::StartMessage { .. } => {
1411                                request_assistant_message_id =
1412                                    Some(thread.insert_assistant_message(
1413                                        vec![MessageSegment::Text(String::new())],
1414                                        cx,
1415                                    ));
1416                            }
1417                            LanguageModelCompletionEvent::Stop(reason) => {
1418                                stop_reason = reason;
1419                            }
1420                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1421                                thread.update_token_usage_at_last_message(token_usage);
1422                                thread.cumulative_token_usage = thread.cumulative_token_usage
1423                                    + token_usage
1424                                    - current_token_usage;
1425                                current_token_usage = token_usage;
1426                            }
1427                            LanguageModelCompletionEvent::Text(chunk) => {
1428                                thread.received_chunk();
1429
1430                                cx.emit(ThreadEvent::ReceivedTextChunk);
1431                                if let Some(last_message) = thread.messages.last_mut() {
1432                                    if last_message.role == Role::Assistant
1433                                        && !thread.tool_use.has_tool_results(last_message.id)
1434                                    {
1435                                        last_message.push_text(&chunk);
1436                                        cx.emit(ThreadEvent::StreamedAssistantText(
1437                                            last_message.id,
1438                                            chunk,
1439                                        ));
1440                                    } else {
1441                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1442                                        // of a new Assistant response.
1443                                        //
1444                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1445                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1446                                        request_assistant_message_id =
1447                                            Some(thread.insert_assistant_message(
1448                                                vec![MessageSegment::Text(chunk.to_string())],
1449                                                cx,
1450                                            ));
1451                                    };
1452                                }
1453                            }
1454                            LanguageModelCompletionEvent::Thinking {
1455                                text: chunk,
1456                                signature,
1457                            } => {
1458                                thread.received_chunk();
1459
1460                                if let Some(last_message) = thread.messages.last_mut() {
1461                                    if last_message.role == Role::Assistant
1462                                        && !thread.tool_use.has_tool_results(last_message.id)
1463                                    {
1464                                        last_message.push_thinking(&chunk, signature);
1465                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1466                                            last_message.id,
1467                                            chunk,
1468                                        ));
1469                                    } else {
1470                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1471                                        // of a new Assistant response.
1472                                        //
1473                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1474                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1475                                        request_assistant_message_id =
1476                                            Some(thread.insert_assistant_message(
1477                                                vec![MessageSegment::Thinking {
1478                                                    text: chunk.to_string(),
1479                                                    signature,
1480                                                }],
1481                                                cx,
1482                                            ));
1483                                    };
1484                                }
1485                            }
1486                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1487                                let last_assistant_message_id = request_assistant_message_id
1488                                    .unwrap_or_else(|| {
1489                                        let new_assistant_message_id =
1490                                            thread.insert_assistant_message(vec![], cx);
1491                                        request_assistant_message_id =
1492                                            Some(new_assistant_message_id);
1493                                        new_assistant_message_id
1494                                    });
1495
1496                                let tool_use_id = tool_use.id.clone();
1497                                let streamed_input = if tool_use.is_input_complete {
1498                                    None
1499                                } else {
1500                                    Some((&tool_use.input).clone())
1501                                };
1502
1503                                let ui_text = thread.tool_use.request_tool_use(
1504                                    last_assistant_message_id,
1505                                    tool_use,
1506                                    tool_use_metadata.clone(),
1507                                    cx,
1508                                );
1509
1510                                if let Some(input) = streamed_input {
1511                                    cx.emit(ThreadEvent::StreamedToolUse {
1512                                        tool_use_id,
1513                                        ui_text,
1514                                        input,
1515                                    });
1516                                }
1517                            }
1518                            LanguageModelCompletionEvent::QueueUpdate(status) => {
1519                                if let Some(completion) = thread
1520                                    .pending_completions
1521                                    .iter_mut()
1522                                    .find(|completion| completion.id == pending_completion_id)
1523                                {
1524                                    let queue_state = match status {
1525                                        language_model::CompletionRequestStatus::Queued {
1526                                            position,
1527                                        } => Some(QueueState::Queued { position }),
1528                                        language_model::CompletionRequestStatus::Started => {
1529                                            Some(QueueState::Started)
1530                                        }
1531                                        language_model::CompletionRequestStatus::ToolUseLimitReached => {
1532                                            thread.tool_use_limit_reached = true;
1533                                            None
1534                                        }
1535                                    };
1536
1537                                    if let Some(queue_state) = queue_state {
1538                                        completion.queue_state = queue_state;
1539                                    }
1540                                }
1541                            }
1542                        }
1543
1544                        thread.touch_updated_at();
1545                        cx.emit(ThreadEvent::StreamedCompletion);
1546                        cx.notify();
1547
1548                        thread.auto_capture_telemetry(cx);
1549                        Ok(())
1550                    })??;
1551
1552                    smol::future::yield_now().await;
1553                }
1554
1555                thread.update(cx, |thread, cx| {
1556                    thread.last_received_chunk_at = None;
1557                    thread
1558                        .pending_completions
1559                        .retain(|completion| completion.id != pending_completion_id);
1560
1561                    // If there is a response without tool use, summarize the message. Otherwise,
1562                    // allow two tool uses before summarizing.
1563                    if thread.summary.is_none()
1564                        && thread.messages.len() >= 2
1565                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1566                    {
1567                        thread.summarize(cx);
1568                    }
1569                })?;
1570
1571                anyhow::Ok(stop_reason)
1572            };
1573
1574            let result = stream_completion.await;
1575
1576            thread
1577                .update(cx, |thread, cx| {
1578                    thread.finalize_pending_checkpoint(cx);
1579                    match result.as_ref() {
1580                        Ok(stop_reason) => match stop_reason {
1581                            StopReason::ToolUse => {
1582                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1583                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1584                            }
1585                            StopReason::EndTurn => {}
1586                            StopReason::MaxTokens => {}
1587                        },
1588                        Err(error) => {
1589                            if error.is::<PaymentRequiredError>() {
1590                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1591                            } else if error.is::<MaxMonthlySpendReachedError>() {
1592                                cx.emit(ThreadEvent::ShowError(
1593                                    ThreadError::MaxMonthlySpendReached,
1594                                ));
1595                            } else if let Some(error) =
1596                                error.downcast_ref::<ModelRequestLimitReachedError>()
1597                            {
1598                                cx.emit(ThreadEvent::ShowError(
1599                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1600                                ));
1601                            } else if let Some(known_error) =
1602                                error.downcast_ref::<LanguageModelKnownError>()
1603                            {
1604                                match known_error {
1605                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1606                                        tokens,
1607                                    } => {
1608                                        thread.exceeded_window_error = Some(ExceededWindowError {
1609                                            model_id: model.id(),
1610                                            token_count: *tokens,
1611                                        });
1612                                        cx.notify();
1613                                    }
1614                                }
1615                            } else {
1616                                let error_message = error
1617                                    .chain()
1618                                    .map(|err| err.to_string())
1619                                    .collect::<Vec<_>>()
1620                                    .join("\n");
1621                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1622                                    header: "Error interacting with language model".into(),
1623                                    message: SharedString::from(error_message.clone()),
1624                                }));
1625                            }
1626
1627                            thread.cancel_last_completion(window, cx);
1628                        }
1629                    }
1630                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1631
1632                    if let Some((request_callback, (request, response_events))) = thread
1633                        .request_callback
1634                        .as_mut()
1635                        .zip(request_callback_parameters.as_ref())
1636                    {
1637                        request_callback(request, response_events);
1638                    }
1639
1640                    thread.auto_capture_telemetry(cx);
1641
1642                    if let Ok(initial_usage) = initial_token_usage {
1643                        let usage = thread.cumulative_token_usage - initial_usage;
1644
1645                        telemetry::event!(
1646                            "Assistant Thread Completion",
1647                            thread_id = thread.id().to_string(),
1648                            prompt_id = prompt_id,
1649                            model = model.telemetry_id(),
1650                            model_provider = model.provider_id().to_string(),
1651                            input_tokens = usage.input_tokens,
1652                            output_tokens = usage.output_tokens,
1653                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1654                            cache_read_input_tokens = usage.cache_read_input_tokens,
1655                        );
1656                    }
1657                })
1658                .ok();
1659        });
1660
1661        self.pending_completions.push(PendingCompletion {
1662            id: pending_completion_id,
1663            queue_state: QueueState::Sending,
1664            _task: task,
1665        });
1666    }
1667
1668    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1669        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1670            return;
1671        };
1672
1673        if !model.provider.is_authenticated(cx) {
1674            return;
1675        }
1676
1677        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1678            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1679            If the conversation is about a specific subject, include it in the title. \
1680            Be descriptive. DO NOT speak in the first person.";
1681
1682        let request = self.to_summarize_request(added_user_message.into());
1683
1684        self.pending_summary = cx.spawn(async move |this, cx| {
1685            async move {
1686                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1687                let (mut messages, usage) = stream.await?;
1688
1689                if let Some(usage) = usage {
1690                    this.update(cx, |_thread, cx| {
1691                        cx.emit(ThreadEvent::UsageUpdated(usage));
1692                    })
1693                    .ok();
1694                }
1695
1696                let mut new_summary = String::new();
1697                while let Some(message) = messages.stream.next().await {
1698                    let text = message?;
1699                    let mut lines = text.lines();
1700                    new_summary.extend(lines.next());
1701
1702                    // Stop if the LLM generated multiple lines.
1703                    if lines.next().is_some() {
1704                        break;
1705                    }
1706                }
1707
1708                this.update(cx, |this, cx| {
1709                    if !new_summary.is_empty() {
1710                        this.summary = Some(new_summary.into());
1711                    }
1712
1713                    cx.emit(ThreadEvent::SummaryGenerated);
1714                })?;
1715
1716                anyhow::Ok(())
1717            }
1718            .log_err()
1719            .await
1720        });
1721    }
1722
    /// Starts generating a detailed Markdown summary of the conversation,
    /// unless one has already been generated (or is being generated) for the
    /// current last message. State transitions are published through
    /// `detailed_summary_tx`, and the thread is saved after generation
    /// succeeds so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Announce that generation is in flight for this message before
        // spawning the task, so observers see the `Generating` state.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails, reset the state so a later call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all chunks; individual chunk errors are logged and
            // skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1812
1813    pub async fn wait_for_detailed_summary_or_text(
1814        this: &Entity<Self>,
1815        cx: &mut AsyncApp,
1816    ) -> Option<SharedString> {
1817        let mut detailed_summary_rx = this
1818            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1819            .ok()?;
1820        loop {
1821            match detailed_summary_rx.recv().await? {
1822                DetailedSummaryState::Generating { .. } => {}
1823                DetailedSummaryState::NotGenerated => {
1824                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
1825                }
1826                DetailedSummaryState::Generated { text, .. } => return Some(text),
1827            }
1828        }
1829    }
1830
1831    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1832        self.detailed_summary_rx
1833            .borrow()
1834            .text()
1835            .unwrap_or_else(|| self.text().into())
1836    }
1837
1838    pub fn is_generating_detailed_summary(&self) -> bool {
1839        matches!(
1840            &*self.detailed_summary_rx.borrow(),
1841            DetailedSummaryState::Generating { .. }
1842        )
1843    }
1844
1845    pub fn use_pending_tools(
1846        &mut self,
1847        window: Option<AnyWindowHandle>,
1848        cx: &mut Context<Self>,
1849        model: Arc<dyn LanguageModel>,
1850    ) -> Vec<PendingToolUse> {
1851        self.auto_capture_telemetry(cx);
1852        let request = self.to_completion_request(model, cx);
1853        let messages = Arc::new(request.messages);
1854        let pending_tool_uses = self
1855            .tool_use
1856            .pending_tool_uses()
1857            .into_iter()
1858            .filter(|tool_use| tool_use.status.is_idle())
1859            .cloned()
1860            .collect::<Vec<_>>();
1861
1862        for tool_use in pending_tool_uses.iter() {
1863            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1864                if tool.needs_confirmation(&tool_use.input, cx)
1865                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1866                {
1867                    self.tool_use.confirm_tool_use(
1868                        tool_use.id.clone(),
1869                        tool_use.ui_text.clone(),
1870                        tool_use.input.clone(),
1871                        messages.clone(),
1872                        tool,
1873                    );
1874                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1875                } else {
1876                    self.run_tool(
1877                        tool_use.id.clone(),
1878                        tool_use.ui_text.clone(),
1879                        tool_use.input.clone(),
1880                        &messages,
1881                        tool,
1882                        window,
1883                        cx,
1884                    );
1885                }
1886            }
1887        }
1888
1889        pending_tool_uses
1890    }
1891
1892    pub fn receive_invalid_tool_json(
1893        &mut self,
1894        tool_use_id: LanguageModelToolUseId,
1895        tool_name: Arc<str>,
1896        invalid_json: Arc<str>,
1897        error: String,
1898        window: Option<AnyWindowHandle>,
1899        cx: &mut Context<Thread>,
1900    ) {
1901        log::error!("The model returned invalid input JSON: {invalid_json}");
1902
1903        let pending_tool_use = self.tool_use.insert_tool_output(
1904            tool_use_id.clone(),
1905            tool_name,
1906            Err(anyhow!("Error parsing input JSON: {error}")),
1907            self.configured_model.as_ref(),
1908        );
1909        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1910            pending_tool_use.ui_text.clone()
1911        } else {
1912            log::error!(
1913                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1914            );
1915            format!("Unknown tool {}", tool_use_id).into()
1916        };
1917
1918        cx.emit(ThreadEvent::InvalidToolInput {
1919            tool_use_id: tool_use_id.clone(),
1920            ui_text,
1921            invalid_input_json: invalid_json,
1922        });
1923
1924        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1925    }
1926
1927    pub fn run_tool(
1928        &mut self,
1929        tool_use_id: LanguageModelToolUseId,
1930        ui_text: impl Into<SharedString>,
1931        input: serde_json::Value,
1932        messages: &[LanguageModelRequestMessage],
1933        tool: Arc<dyn Tool>,
1934        window: Option<AnyWindowHandle>,
1935        cx: &mut Context<Thread>,
1936    ) {
1937        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1938        self.tool_use
1939            .run_pending_tool(tool_use_id, ui_text.into(), task);
1940    }
1941
1942    fn spawn_tool_use(
1943        &mut self,
1944        tool_use_id: LanguageModelToolUseId,
1945        messages: &[LanguageModelRequestMessage],
1946        input: serde_json::Value,
1947        tool: Arc<dyn Tool>,
1948        window: Option<AnyWindowHandle>,
1949        cx: &mut Context<Thread>,
1950    ) -> Task<()> {
1951        let tool_name: Arc<str> = tool.name().into();
1952
1953        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1954            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1955        } else {
1956            tool.run(
1957                input,
1958                messages,
1959                self.project.clone(),
1960                self.action_log.clone(),
1961                window,
1962                cx,
1963            )
1964        };
1965
1966        // Store the card separately if it exists
1967        if let Some(card) = tool_result.card.clone() {
1968            self.tool_use
1969                .insert_tool_result_card(tool_use_id.clone(), card);
1970        }
1971
1972        cx.spawn({
1973            async move |thread: WeakEntity<Thread>, cx| {
1974                let output = tool_result.output.await;
1975
1976                thread
1977                    .update(cx, |thread, cx| {
1978                        let pending_tool_use = thread.tool_use.insert_tool_output(
1979                            tool_use_id.clone(),
1980                            tool_name,
1981                            output,
1982                            thread.configured_model.as_ref(),
1983                        );
1984                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1985                    })
1986                    .ok();
1987            }
1988        })
1989    }
1990
1991    fn tool_finished(
1992        &mut self,
1993        tool_use_id: LanguageModelToolUseId,
1994        pending_tool_use: Option<PendingToolUse>,
1995        canceled: bool,
1996        window: Option<AnyWindowHandle>,
1997        cx: &mut Context<Self>,
1998    ) {
1999        if self.all_tools_finished() {
2000            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2001                if !canceled {
2002                    self.send_to_model(model.clone(), window, cx);
2003                }
2004                self.auto_capture_telemetry(cx);
2005            }
2006        }
2007
2008        cx.emit(ThreadEvent::ToolFinished {
2009            tool_use_id,
2010            pending_tool_use,
2011        });
2012    }
2013
2014    /// Cancels the last pending completion, if there are any pending.
2015    ///
2016    /// Returns whether a completion was canceled.
2017    pub fn cancel_last_completion(
2018        &mut self,
2019        window: Option<AnyWindowHandle>,
2020        cx: &mut Context<Self>,
2021    ) -> bool {
2022        let mut canceled = self.pending_completions.pop().is_some();
2023
2024        for pending_tool_use in self.tool_use.cancel_pending() {
2025            canceled = true;
2026            self.tool_finished(
2027                pending_tool_use.id.clone(),
2028                Some(pending_tool_use),
2029                true,
2030                window,
2031                cx,
2032            );
2033        }
2034
2035        self.finalize_pending_checkpoint(cx);
2036
2037        if canceled {
2038            cx.emit(ThreadEvent::CompletionCanceled);
2039        }
2040
2041        canceled
2042    }
2043
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the
    /// `CancelEditing` event; it does not mutate any thread state itself.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2051
    /// Returns the thread-level feedback, if any has been recorded
    /// (set by `report_feedback` when no assistant message exists).
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2055
    /// Returns the feedback recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2059
2060    pub fn report_message_feedback(
2061        &mut self,
2062        message_id: MessageId,
2063        feedback: ThreadFeedback,
2064        cx: &mut Context<Self>,
2065    ) -> Task<Result<()>> {
2066        if self.message_feedback.get(&message_id) == Some(&feedback) {
2067            return Task::ready(Ok(()));
2068        }
2069
2070        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2071        let serialized_thread = self.serialize(cx);
2072        let thread_id = self.id().clone();
2073        let client = self.project.read(cx).client();
2074
2075        let enabled_tool_names: Vec<String> = self
2076            .tools()
2077            .read(cx)
2078            .enabled_tools(cx)
2079            .iter()
2080            .map(|tool| tool.name().to_string())
2081            .collect();
2082
2083        self.message_feedback.insert(message_id, feedback);
2084
2085        cx.notify();
2086
2087        let message_content = self
2088            .message(message_id)
2089            .map(|msg| msg.to_string())
2090            .unwrap_or_default();
2091
2092        cx.background_spawn(async move {
2093            let final_project_snapshot = final_project_snapshot.await;
2094            let serialized_thread = serialized_thread.await?;
2095            let thread_data =
2096                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2097
2098            let rating = match feedback {
2099                ThreadFeedback::Positive => "positive",
2100                ThreadFeedback::Negative => "negative",
2101            };
2102            telemetry::event!(
2103                "Assistant Thread Rated",
2104                rating,
2105                thread_id,
2106                enabled_tool_names,
2107                message_id = message_id.0,
2108                message_content,
2109                thread_data,
2110                final_project_snapshot
2111            );
2112            client.telemetry().flush_events().await;
2113
2114            Ok(())
2115        })
2116    }
2117
2118    pub fn report_feedback(
2119        &mut self,
2120        feedback: ThreadFeedback,
2121        cx: &mut Context<Self>,
2122    ) -> Task<Result<()>> {
2123        let last_assistant_message_id = self
2124            .messages
2125            .iter()
2126            .rev()
2127            .find(|msg| msg.role == Role::Assistant)
2128            .map(|msg| msg.id);
2129
2130        if let Some(message_id) = last_assistant_message_id {
2131            self.report_message_feedback(message_id, feedback, cx)
2132        } else {
2133            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2134            let serialized_thread = self.serialize(cx);
2135            let thread_id = self.id().clone();
2136            let client = self.project.read(cx).client();
2137            self.feedback = Some(feedback);
2138            cx.notify();
2139
2140            cx.background_spawn(async move {
2141                let final_project_snapshot = final_project_snapshot.await;
2142                let serialized_thread = serialized_thread.await?;
2143                let thread_data = serde_json::to_value(serialized_thread)
2144                    .unwrap_or_else(|_| serde_json::Value::Null);
2145
2146                let rating = match feedback {
2147                    ThreadFeedback::Positive => "positive",
2148                    ThreadFeedback::Negative => "negative",
2149                };
2150                telemetry::event!(
2151                    "Assistant Thread Rated",
2152                    rating,
2153                    thread_id,
2154                    thread_data,
2155                    final_project_snapshot
2156                );
2157                client.telemetry().flush_events().await;
2158
2159                Ok(())
2160            })
2161        }
2162    }
2163
2164    /// Create a snapshot of the current project state including git information and unsaved buffers.
2165    fn project_snapshot(
2166        project: Entity<Project>,
2167        cx: &mut Context<Self>,
2168    ) -> Task<Arc<ProjectSnapshot>> {
2169        let git_store = project.read(cx).git_store().clone();
2170        let worktree_snapshots: Vec<_> = project
2171            .read(cx)
2172            .visible_worktrees(cx)
2173            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2174            .collect();
2175
2176        cx.spawn(async move |_, cx| {
2177            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2178
2179            let mut unsaved_buffers = Vec::new();
2180            cx.update(|app_cx| {
2181                let buffer_store = project.read(app_cx).buffer_store();
2182                for buffer_handle in buffer_store.read(app_cx).buffers() {
2183                    let buffer = buffer_handle.read(app_cx);
2184                    if buffer.is_dirty() {
2185                        if let Some(file) = buffer.file() {
2186                            let path = file.path().to_string_lossy().to_string();
2187                            unsaved_buffers.push(path);
2188                        }
2189                    }
2190                }
2191            })
2192            .ok();
2193
2194            Arc::new(ProjectSnapshot {
2195                worktree_snapshots,
2196                unsaved_buffer_paths: unsaved_buffers,
2197                timestamp: Utc::now(),
2198            })
2199        })
2200    }
2201
    /// Captures a snapshot of a single worktree: its absolute path plus git
    /// metadata (remote URL, HEAD SHA, current branch, and a HEAD-to-worktree
    /// diff) when the worktree belongs to a locally-backed repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot rather
            // than failing.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose path space contains this worktree,
            // then ask it for git metadata via a repository job.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Non-local repository state: only the branch name
                            // is available.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result layers around the job's future;
            // any failure along the way degrades to `None` git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2279
2280    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2281        let mut markdown = Vec::new();
2282
2283        if let Some(summary) = self.summary() {
2284            writeln!(markdown, "# {summary}\n")?;
2285        };
2286
2287        for message in self.messages() {
2288            writeln!(
2289                markdown,
2290                "## {role}\n",
2291                role = match message.role {
2292                    Role::User => "User",
2293                    Role::Assistant => "Assistant",
2294                    Role::System => "System",
2295                }
2296            )?;
2297
2298            if !message.loaded_context.text.is_empty() {
2299                writeln!(markdown, "{}", message.loaded_context.text)?;
2300            }
2301
2302            if !message.loaded_context.images.is_empty() {
2303                writeln!(
2304                    markdown,
2305                    "\n{} images attached as context.\n",
2306                    message.loaded_context.images.len()
2307                )?;
2308            }
2309
2310            for segment in &message.segments {
2311                match segment {
2312                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2313                    MessageSegment::Thinking { text, .. } => {
2314                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2315                    }
2316                    MessageSegment::RedactedThinking(_) => {}
2317                }
2318            }
2319
2320            for tool_use in self.tool_uses_for_message(message.id, cx) {
2321                writeln!(
2322                    markdown,
2323                    "**Use Tool: {} ({})**",
2324                    tool_use.name, tool_use.id
2325                )?;
2326                writeln!(markdown, "```json")?;
2327                writeln!(
2328                    markdown,
2329                    "{}",
2330                    serde_json::to_string_pretty(&tool_use.input)?
2331                )?;
2332                writeln!(markdown, "```")?;
2333            }
2334
2335            for tool_result in self.tool_results_for_message(message.id) {
2336                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2337                if tool_result.is_error {
2338                    write!(markdown, " (Error)")?;
2339                }
2340
2341                writeln!(markdown, "**\n")?;
2342                writeln!(markdown, "{}", tool_result.content)?;
2343            }
2344        }
2345
2346        Ok(String::from_utf8_lossy(&markdown).to_string())
2347    }
2348
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted by the user), delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2359
    /// Marks every outstanding agent edit as kept, delegating to the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2364
    /// Rejects the agent's edits within the given ranges of `buffer`,
    /// delegating to the action log. Returns a task that resolves once the
    /// rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2375
    /// The action log used to track and keep/reject the agent's edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2379
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2383
2384    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2385        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2386            return;
2387        }
2388
2389        let now = Instant::now();
2390        if let Some(last) = self.last_auto_capture_at {
2391            if now.duration_since(last).as_secs() < 10 {
2392                return;
2393            }
2394        }
2395
2396        self.last_auto_capture_at = Some(now);
2397
2398        let thread_id = self.id().clone();
2399        let github_login = self
2400            .project
2401            .read(cx)
2402            .user_store()
2403            .read(cx)
2404            .current_user()
2405            .map(|user| user.github_login.clone());
2406        let client = self.project.read(cx).client().clone();
2407        let serialize_task = self.serialize(cx);
2408
2409        cx.background_executor()
2410            .spawn(async move {
2411                if let Ok(serialized_thread) = serialize_task.await {
2412                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2413                        telemetry::event!(
2414                            "Agent Thread Auto-Captured",
2415                            thread_id = thread_id.to_string(),
2416                            thread_data = thread_data,
2417                            auto_capture_reason = "tracked_user",
2418                            github_login = github_login
2419                        );
2420
2421                        client.telemetry().flush_events().await;
2422                    }
2423                }
2424            })
2425            .detach();
2426    }
2427
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2431
2432    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2433        let Some(model) = self.configured_model.as_ref() else {
2434            return TotalTokenUsage::default();
2435        };
2436
2437        let max = model.model.max_token_count();
2438
2439        let index = self
2440            .messages
2441            .iter()
2442            .position(|msg| msg.id == message_id)
2443            .unwrap_or(0);
2444
2445        if index == 0 {
2446            return TotalTokenUsage { total: 0, max };
2447        }
2448
2449        let token_usage = &self
2450            .request_token_usage
2451            .get(index - 1)
2452            .cloned()
2453            .unwrap_or_default();
2454
2455        TotalTokenUsage {
2456            total: token_usage.total_tokens() as usize,
2457            max,
2458        }
2459    }
2460
2461    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2462        let model = self.configured_model.as_ref()?;
2463
2464        let max = model.model.max_token_count();
2465
2466        if let Some(exceeded_error) = &self.exceeded_window_error {
2467            if model.model.id() == exceeded_error.model_id {
2468                return Some(TotalTokenUsage {
2469                    total: exceeded_error.token_count,
2470                    max,
2471                });
2472            }
2473        }
2474
2475        let total = self
2476            .token_usage_at_last_message()
2477            .unwrap_or_default()
2478            .total_tokens() as usize;
2479
2480        Some(TotalTokenUsage { total, max })
2481    }
2482
2483    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2484        self.request_token_usage
2485            .get(self.messages.len().saturating_sub(1))
2486            .or_else(|| self.request_token_usage.last())
2487            .cloned()
2488    }
2489
2490    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2491        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2492        self.request_token_usage
2493            .resize(self.messages.len(), placeholder);
2494
2495        if let Some(last) = self.request_token_usage.last_mut() {
2496            *last = token_usage;
2497        }
2498    }
2499
2500    pub fn deny_tool_use(
2501        &mut self,
2502        tool_use_id: LanguageModelToolUseId,
2503        tool_name: Arc<str>,
2504        window: Option<AnyWindowHandle>,
2505        cx: &mut Context<Self>,
2506    ) {
2507        let err = Err(anyhow::anyhow!(
2508            "Permission to run tool action denied by user"
2509        ));
2510
2511        self.tool_use.insert_tool_output(
2512            tool_use_id.clone(),
2513            tool_name,
2514            err,
2515            self.configured_model.as_ref(),
2516        );
2517        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2518    }
2519}
2520
/// Errors surfaced to the user while running a thread's completion request.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The request was rejected because payment is required.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A general error with a header and detail text to display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2535
/// Events emitted by a `Thread` over the course of a completion request,
/// its streamed response, and any resulting tool usage.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage information was received.
    UsageUpdated(RequestUsage),
    /// A completion event was streamed.
    StreamedCompletion,
    /// A chunk of text was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant response text was streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text was streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with the given input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that could not be used as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool needs user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2575
// Lets `Thread` entities emit `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2577
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    // Identifier distinguishing this in-flight completion from others.
    id: usize,
    // Where this completion currently sits relative to the queue.
    queue_state: QueueState,
    // The task driving the completion; held so the work stays alive for as
    // long as this entry exists.
    _task: Task<()>,
}
2583
2584#[cfg(test)]
2585mod tests {
2586    use super::*;
2587    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2588    use assistant_settings::AssistantSettings;
2589    use assistant_tool::ToolRegistry;
2590    use context_server::ContextServerSettings;
2591    use editor::EditorSettings;
2592    use gpui::TestAppContext;
2593    use language_model::fake_provider::FakeLanguageModel;
2594    use project::{FakeFs, Project};
2595    use prompt_store::PromptBuilder;
2596    use serde_json::json;
2597    use settings::{Settings, SettingsStore};
2598    use std::sync::Arc;
2599    use theme::ThemeSettings;
2600    use util::path;
2601    use workspace::Workspace;
2602
    /// Verifies that a file attached as context is embedded in the user
    /// message's loaded context and prepended to it in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages total; index 1 is the user message with the context
        // text prepended (index 0 is presumably the system prompt — verify
        // against request construction).
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2678
    /// Verifies that each user message embeds only context not already sent
    /// in an earlier message, and that `new_context_for_thread` honors an
    /// explicit message-id cutoff and context removals.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, everything from message 2 onward plus
        // the newly added file4 counts as "new" context again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2834
    /// Verifies that messages inserted without any context have empty context
    /// text and appear unchanged in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2912
    /// Verifies that editing a context buffer after it was read causes the
    /// next completion request to end with a "files changed" notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2999
    /// Installs a test `SettingsStore` global and registers all of the
    /// settings, prompt, thread-store, and tool globals the thread tests use.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is installed first; the registrations below
            // presumably read it — keep this ordering.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3017
    /// Creates a test project rooted at `/test` on a fake filesystem
    /// populated from the given JSON file tree.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3027
3028    async fn setup_test_environment(
3029        cx: &mut TestAppContext,
3030        project: Entity<Project>,
3031    ) -> (
3032        Entity<Workspace>,
3033        Entity<ThreadStore>,
3034        Entity<Thread>,
3035        Entity<ContextStore>,
3036        Arc<dyn LanguageModel>,
3037    ) {
3038        let (workspace, cx) =
3039            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3040
3041        let thread_store = cx
3042            .update(|_, cx| {
3043                ThreadStore::load(
3044                    project.clone(),
3045                    cx.new(|_| ToolWorkingSet::default()),
3046                    None,
3047                    Arc::new(PromptBuilder::new(None).unwrap()),
3048                    cx,
3049                )
3050            })
3051            .await
3052            .unwrap();
3053
3054        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3055        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3056
3057        let model = FakeLanguageModel::default();
3058        let model: Arc<dyn LanguageModel> = Arc::new(model);
3059
3060        (workspace, thread_store, thread, context_store, model)
3061    }
3062
3063    async fn add_file_to_context(
3064        project: &Entity<Project>,
3065        context_store: &Entity<ContextStore>,
3066        path: &str,
3067        cx: &mut TestAppContext,
3068    ) -> Result<Entity<language::Buffer>> {
3069        let buffer_path = project
3070            .read_with(cx, |project, cx| project.find_project_path(path, cx))
3071            .unwrap();
3072
3073        let buffer = project
3074            .update(cx, |project, cx| {
3075                project.open_buffer(buffer_path.clone(), cx)
3076            })
3077            .await
3078            .unwrap();
3079
3080        context_store.update(cx, |context_store, cx| {
3081            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3082        });
3083
3084        Ok(buffer)
3085    }
3086}