thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use editor::display_map::CreaseMetadata;
  13use feature_flags::{self, FeatureFlagAppExt};
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{
  18    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
  19    WeakEntity,
  20};
  21use language_model::{
  22    ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  23    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  24    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  25    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  26    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
  27    StopReason, TokenUsage,
  28};
  29use postage::stream::Stream as _;
  30use project::Project;
  31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  32use prompt_store::{ModelContext, PromptBuilder};
  33use proto::Plan;
  34use schemars::JsonSchema;
  35use serde::{Deserialize, Serialize};
  36use settings::Settings;
  37use thiserror::Error;
  38use util::{ResultExt as _, TryFutureExt as _, post_inc};
  39use uuid::Uuid;
  40use zed_llm_client::CompletionMode;
  41
  42use crate::ThreadStore;
  43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
  44use crate::thread_store::{
  45    SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
  46    SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
  47};
  48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
  49
  50#[derive(
  51    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  52)]
  53pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
  73/// The ID of the user prompt that initiated a request.
  74///
  75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  76#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  77pub struct PromptId(Arc<str>);
  78
  79impl PromptId {
  80    pub fn new() -> Self {
  81        Self(Uuid::new_v4().to_string().into())
  82    }
  83}
  84
  85impl std::fmt::Display for PromptId {
  86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  87        write!(f, "{}", self.0)
  88    }
  89}
  90
  91#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  92pub struct MessageId(pub(crate) usize);
  93
  94impl MessageId {
  95    fn post_inc(&mut self) -> Self {
  96        Self(post_inc(&mut self.0))
  97    }
  98}
  99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Byte range of the crease within the message's text.
    pub range: Range<usize>,
    // Icon/label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
 108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    // Context attached to this message (flattened text plus handles/images).
    pub loaded_context: LoadedContext,
    // Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
}
 118
 119impl Message {
 120    /// Returns whether the message contains any meaningful text that should be displayed
 121    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
 122    pub fn should_display_content(&self) -> bool {
 123        self.segments.iter().all(|segment| segment.should_display())
 124    }
 125
 126    pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
 127        if let Some(MessageSegment::Thinking {
 128            text: segment,
 129            signature: current_signature,
 130        }) = self.segments.last_mut()
 131        {
 132            if let Some(signature) = signature {
 133                *current_signature = Some(signature);
 134            }
 135            segment.push_str(text);
 136        } else {
 137            self.segments.push(MessageSegment::Thinking {
 138                text: text.to_string(),
 139                signature,
 140            });
 141        }
 142    }
 143
 144    pub fn push_text(&mut self, text: &str) {
 145        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 146            segment.push_str(text);
 147        } else {
 148            self.segments.push(MessageSegment::Text(text.to_string()));
 149        }
 150    }
 151
 152    pub fn to_string(&self) -> String {
 153        let mut result = String::new();
 154
 155        if !self.loaded_context.text.is_empty() {
 156            result.push_str(&self.loaded_context.text);
 157        }
 158
 159        for segment in &self.segments {
 160            match segment {
 161                MessageSegment::Text(text) => result.push_str(text),
 162                MessageSegment::Thinking { text, .. } => {
 163                    result.push_str("<think>\n");
 164                    result.push_str(text);
 165                    result.push_str("\n</think>");
 166                }
 167                MessageSegment::RedactedThinking(_) => {}
 168            }
 169        }
 170
 171        result
 172    }
 173}
 174
 175#[derive(Debug, Clone, PartialEq, Eq)]
 176pub enum MessageSegment {
 177    Text(String),
 178    Thinking {
 179        text: String,
 180        signature: Option<String>,
 181    },
 182    RedactedThinking(Vec<u8>),
 183}
 184
 185impl MessageSegment {
 186    pub fn should_display(&self) -> bool {
 187        match self {
 188            Self::Text(text) => text.is_empty(),
 189            Self::Thinking { text, .. } => text.is_empty(),
 190            Self::RedactedThinking(_) => false,
 191        }
 192    }
 193}
 194
/// Point-in-time capture of the project's worktrees and unsaved buffers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
 201
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}
 207
/// Git repository state captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // Uncommitted-diff text, when one could be computed.
    pub diff: Option<String>,
}
 215
/// Associates a git checkpoint with the message it was taken before,
/// enabling "restore to this point in the conversation".
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 221
/// Thumbs-up/thumbs-down feedback a user can leave on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 227
 228pub enum LastRestoreCheckpoint {
 229    Pending {
 230        message_id: MessageId,
 231    },
 232    Error {
 233        message_id: MessageId,
 234        error: String,
 235    },
 236}
 237
 238impl LastRestoreCheckpoint {
 239    pub fn message_id(&self) -> MessageId {
 240        match self {
 241            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 242            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 243        }
 244    }
 245}
 246
 247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 248pub enum DetailedSummaryState {
 249    #[default]
 250    NotGenerated,
 251    Generating {
 252        message_id: MessageId,
 253    },
 254    Generated {
 255        text: SharedString,
 256        message_id: MessageId,
 257    },
 258}
 259
 260impl DetailedSummaryState {
 261    fn text(&self) -> Option<SharedString> {
 262        if let Self::Generated { text, .. } = self {
 263            Some(text.clone())
 264        } else {
 265            None
 266        }
 267    }
 268}
 269
 270#[derive(Default)]
 271pub struct TotalTokenUsage {
 272    pub total: usize,
 273    pub max: usize,
 274}
 275
 276impl TotalTokenUsage {
 277    pub fn ratio(&self) -> TokenUsageRatio {
 278        #[cfg(debug_assertions)]
 279        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 280            .unwrap_or("0.8".to_string())
 281            .parse()
 282            .unwrap();
 283        #[cfg(not(debug_assertions))]
 284        let warning_threshold: f32 = 0.8;
 285
 286        // When the maximum is unknown because there is no selected model,
 287        // avoid showing the token limit warning.
 288        if self.max == 0 {
 289            TokenUsageRatio::Normal
 290        } else if self.total >= self.max {
 291            TokenUsageRatio::Exceeded
 292        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 293            TokenUsageRatio::Warning
 294        } else {
 295            TokenUsageRatio::Normal
 296        }
 297    }
 298
 299    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 300        TotalTokenUsage {
 301            total: self.total + tokens,
 302            max: self.max,
 303        }
 304    }
 305}
 306
 307#[derive(Debug, Default, PartialEq, Eq)]
 308pub enum TokenUsageRatio {
 309    #[default]
 310    Normal,
 311    Warning,
 312    Exceeded,
 313}
 314
 315fn default_completion_mode(cx: &App) -> CompletionMode {
 316    if cx.is_staff() {
 317        CompletionMode::Max
 318    } else {
 319        CompletionMode::Normal
 320    }
 321}
 322
/// Where a pending completion stands in the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
 329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // User-visible title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: CompletionMode,
    // Kept in ascending-id order; `Thread::message` binary-searches on this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per user message, for restore support.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the start of the current turn, finalized when
    // generation and tool use complete.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history — presumably one entry per completion
    // request; confirm against the code that pushes to it.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional hook invoked with each request and its raw event stream.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
 369
/// Records that the last request overflowed the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 377
 378impl Thread {
    /// Creates a new, empty thread with a fresh [`ThreadId`].
    ///
    /// The configured model starts as the registry's current default, and an
    /// initial project snapshot is captured asynchronously in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot off the current call stack; consumers
                // await the shared future whenever they need it.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 431
    /// Reconstructs a thread from its serialized form.
    ///
    /// Restores messages, tool-use state, summaries, and token usage. The
    /// previously used model is re-selected when still available; otherwise
    /// the registry's default model is used.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; live
                    // context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Deserialized creases carry no live context handle.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
 539
    /// Installs a hook invoked with each completion request and the raw
    /// stream of events (or error strings) it produced.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
 547
    /// The thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 551
    /// Returns true when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 555
    /// Timestamp of the last modification to the thread.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 559
    /// Bumps the modification timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 563
    /// Starts a new prompt: subsequent requests are attributed to a fresh [`PromptId`].
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
 567
    /// The generated thread title, or `None` if not generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 571
    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 575
 576    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
 577        if self.configured_model.is_none() {
 578            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
 579        }
 580        self.configured_model.clone()
 581    }
 582
    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
 586
    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
 591
    /// Title shown before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated title, or [`Self::DEFAULT_SUMMARY`] when none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
 597
 598    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 599        let Some(current_summary) = &self.summary else {
 600            // Don't allow setting summary until generated
 601            return;
 602        };
 603
 604        let mut new_summary = new_summary.into();
 605
 606        if new_summary.is_empty() {
 607            new_summary = Self::DEFAULT_SUMMARY;
 608        }
 609
 610        if current_summary != &new_summary {
 611            self.summary = Some(new_summary);
 612            cx.emit(ThreadEvent::SummaryChanged);
 613        }
 614    }
 615
    /// The thread's current completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
 619
    /// Sets the completion mode used for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 623
 624    pub fn message(&self, id: MessageId) -> Option<&Message> {
 625        let index = self
 626            .messages
 627            .binary_search_by(|message| message.id.cmp(&id))
 628            .ok()?;
 629
 630        self.messages.get(index)
 631    }
 632
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
 636
    /// True while a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
 640
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    // NOTE(review): the implementation actually keys off
    // `last_received_chunk_at`, not `is_generating()`; confirm the field is
    // cleared when generation ends so the doc comment above holds.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Consider the stream stale after 250ms without a new chunk.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
 649
    /// Records that a streamed chunk just arrived (resets staleness timer).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
 653
    /// Queue position of the first pending completion, if one exists.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
 659
    /// The thread's tool working set.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 663
 664    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 665        self.tool_use
 666            .pending_tool_uses()
 667            .into_iter()
 668            .find(|tool_use| &tool_use.id == id)
 669    }
 670
    /// Pending tool uses that are waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
 677
    /// True when any tool use is still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 681
    /// The git checkpoint recorded for `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 685
    /// Restores the project's files to `checkpoint` and truncates the thread
    /// back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced via `last_restore_checkpoint` and
    /// `ThreadEvent::CheckpointChanged` notifications.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be shown next to the message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 721
    /// Completes the checkpoint taken at the start of the current turn.
    ///
    /// Does nothing while generation is still in progress. Otherwise compares
    /// the pending checkpoint against the current repository state and only
    /// keeps it when the two differ (i.e. the turn actually changed files).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A comparison failure is treated as "not equal" so the
                // checkpoint is kept rather than silently dropped.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If taking the final checkpoint failed, keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 760
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 767
    /// The state of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 771
    /// Removes the message with `message_id` and everything after it, along
    /// with the checkpoints associated with the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
 785
    /// Iterates over the contexts attached to the message with `id`
    /// (empty when the message does not exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
 793
 794    pub fn is_turn_end(&self, ix: usize) -> bool {
 795        if self.messages.is_empty() {
 796            return false;
 797        }
 798
 799        if !self.is_generating() && ix == self.messages.len() - 1 {
 800            return true;
 801        }
 802
 803        let Some(message) = self.messages.get(ix) else {
 804            return false;
 805        };
 806
 807        if message.role != Role::Assistant {
 808            return false;
 809        }
 810
 811        self.messages
 812            .get(ix + 1)
 813            .and_then(|message| {
 814                self.message(message.id)
 815                    .map(|next_message| next_message.role == Role::User)
 816            })
 817            .unwrap_or(false)
 818    }
 819
    /// True when the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
 823
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
 833
    /// The tool uses issued by the assistant message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 837
    /// The tool results produced for the assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
 844
    /// The result recorded for the tool use with `id`, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 848
 849    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 850        Some(&self.tool_use.tool_result(id)?.content)
 851    }
 852
    /// The UI card associated with a tool use's result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 856
 857    /// Return tools that are both enabled and supported by the model
 858    pub fn available_tools(
 859        &self,
 860        cx: &App,
 861        model: Arc<dyn LanguageModel>,
 862    ) -> Vec<LanguageModelRequestTool> {
 863        if model.supports_tools() {
 864            self.tools()
 865                .read(cx)
 866                .enabled_tools(cx)
 867                .into_iter()
 868                .filter_map(|tool| {
 869                    // Skip tools that cannot be supported
 870                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 871                    Some(LanguageModelRequestTool {
 872                        name: tool.name(),
 873                        description: tool.description(),
 874                        input_schema,
 875                    })
 876                })
 877                .collect()
 878        } else {
 879            Vec::default()
 880        }
 881    }
 882
    /// Appends a user message with its loaded context and creases, returning
    /// the new message's ID.
    ///
    /// Buffers referenced by the context start being tracked in the action
    /// log, and `git_checkpoint` (when provided) becomes the pending
    /// checkpoint attributed to this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
 918
 919    pub fn insert_assistant_message(
 920        &mut self,
 921        segments: Vec<MessageSegment>,
 922        cx: &mut Context<Self>,
 923    ) -> MessageId {
 924        self.insert_message(
 925            Role::Assistant,
 926            segments,
 927            LoadedContext::default(),
 928            Vec::new(),
 929            cx,
 930        )
 931    }
 932
 933    pub fn insert_message(
 934        &mut self,
 935        role: Role,
 936        segments: Vec<MessageSegment>,
 937        loaded_context: LoadedContext,
 938        creases: Vec<MessageCrease>,
 939        cx: &mut Context<Self>,
 940    ) -> MessageId {
 941        let id = self.next_message_id.post_inc();
 942        self.messages.push(Message {
 943            id,
 944            role,
 945            segments,
 946            loaded_context,
 947            creases,
 948        });
 949        self.touch_updated_at();
 950        cx.emit(ThreadEvent::MessageAdded(id));
 951        id
 952    }
 953
 954    pub fn edit_message(
 955        &mut self,
 956        id: MessageId,
 957        new_role: Role,
 958        new_segments: Vec<MessageSegment>,
 959        loaded_context: Option<LoadedContext>,
 960        cx: &mut Context<Self>,
 961    ) -> bool {
 962        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 963            return false;
 964        };
 965        message.role = new_role;
 966        message.segments = new_segments;
 967        if let Some(context) = loaded_context {
 968            message.loaded_context = context;
 969        }
 970        self.touch_updated_at();
 971        cx.emit(ThreadEvent::MessageEdited(id));
 972        true
 973    }
 974
 975    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 976        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 977            return false;
 978        };
 979        self.messages.remove(index);
 980        self.touch_updated_at();
 981        cx.emit(ThreadEvent::MessageDeleted(id));
 982        true
 983    }
 984
 985    /// Returns the representation of this [`Thread`] in a textual form.
 986    ///
 987    /// This is the representation we use when attaching a thread as context to another thread.
 988    pub fn text(&self) -> String {
 989        let mut text = String::new();
 990
 991        for message in &self.messages {
 992            text.push_str(match message.role {
 993                language_model::Role::User => "User:",
 994                language_model::Role::Assistant => "Assistant:",
 995                language_model::Role::System => "System:",
 996            });
 997            text.push('\n');
 998
 999            for segment in &message.segments {
1000                match segment {
1001                    MessageSegment::Text(content) => text.push_str(content),
1002                    MessageSegment::Thinking { text: content, .. } => {
1003                        text.push_str(&format!("<think>{}</think>", content))
1004                    }
1005                    MessageSegment::RedactedThinking(_) => {}
1006                }
1007            }
1008            text.push('\n');
1009        }
1010
1011        text
1012    }
1013
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared future so the spawned task can await it without
        // borrowing `self`.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot to resolve before reading thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                // Convert each in-memory message, including its segments,
                // tool uses/results, context text, and creases.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        // Creases are stored as offset ranges plus display metadata.
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Record the configured provider/model ids (if any) so the
                // thread can be restored with the same model.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1094
    /// Returns how many model turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1098
    /// Sets the remaining turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1102
1103    pub fn send_to_model(
1104        &mut self,
1105        model: Arc<dyn LanguageModel>,
1106        window: Option<AnyWindowHandle>,
1107        cx: &mut Context<Self>,
1108    ) {
1109        if self.remaining_turns == 0 {
1110            return;
1111        }
1112
1113        self.remaining_turns -= 1;
1114
1115        let request = self.to_completion_request(model.clone(), cx);
1116
1117        self.stream_completion(request, model, window, cx);
1118    }
1119
1120    pub fn used_tools_since_last_user_message(&self) -> bool {
1121        for message in self.messages.iter().rev() {
1122            if self.tool_use.message_has_tool_results(message.id) {
1123                return true;
1124            } else if message.role == Role::User {
1125                return false;
1126            }
1127        }
1128
1129        false
1130    }
1131
    /// Assembles the full [`LanguageModelRequest`] for this thread: system
    /// prompt, conversation history with tool uses/results, stale-file notes,
    /// and the tools available to `model`.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt references which tools exist, so gather them first.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Render the system prompt from the project context. On failure (or if
        // the context isn't ready yet) surface the error to the UI and continue
        // building the request without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        // The system prompt is stable across requests, so mark
                        // it as a cache point.
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Convert each thread message into a request message: attached context
        // first, then text/thinking segments, then any tool uses, followed by
        // a separate message carrying that message's tool results.
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty segments are dropped rather than sent.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results must follow the message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the newest message as a cache point as well.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        // Tell the model about files that changed since it last read them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Only send the thread's completion mode to models that support max
        // mode; otherwise force normal mode.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode)
        } else {
            Some(CompletionMode::Normal)
        };

        request
    }
1249
1250    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1251        let mut request = LanguageModelRequest {
1252            thread_id: None,
1253            prompt_id: None,
1254            mode: None,
1255            messages: vec![],
1256            tools: Vec::new(),
1257            stop: Vec::new(),
1258            temperature: None,
1259        };
1260
1261        for message in &self.messages {
1262            let mut request_message = LanguageModelRequestMessage {
1263                role: message.role,
1264                content: Vec::new(),
1265                cache: false,
1266            };
1267
1268            for segment in &message.segments {
1269                match segment {
1270                    MessageSegment::Text(text) => request_message
1271                        .content
1272                        .push(MessageContent::Text(text.clone())),
1273                    MessageSegment::Thinking { .. } => {}
1274                    MessageSegment::RedactedThinking(_) => {}
1275                }
1276            }
1277
1278            if request_message.content.is_empty() {
1279                continue;
1280            }
1281
1282            request.messages.push(request_message);
1283        }
1284
1285        request.messages.push(LanguageModelRequestMessage {
1286            role: Role::User,
1287            content: vec![MessageContent::Text(added_user_message)],
1288            cache: false,
1289        });
1290
1291        request
1292    }
1293
1294    fn attached_tracked_files_state(
1295        &self,
1296        messages: &mut Vec<LanguageModelRequestMessage>,
1297        cx: &App,
1298    ) {
1299        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1300
1301        let mut stale_message = String::new();
1302
1303        let action_log = self.action_log.read(cx);
1304
1305        for stale_file in action_log.stale_buffers(cx) {
1306            let Some(file) = stale_file.read(cx).file() else {
1307                continue;
1308            };
1309
1310            if stale_message.is_empty() {
1311                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1312            }
1313
1314            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1315        }
1316
1317        let mut content = Vec::with_capacity(2);
1318
1319        if !stale_message.is_empty() {
1320            content.push(stale_message.into());
1321        }
1322
1323        if !content.is_empty() {
1324            let context_message = LanguageModelRequestMessage {
1325                role: Role::User,
1326                content,
1327                cache: false,
1328            };
1329
1330            messages.push(context_message);
1331        }
1332    }
1333
1334    pub fn stream_completion(
1335        &mut self,
1336        request: LanguageModelRequest,
1337        model: Arc<dyn LanguageModel>,
1338        window: Option<AnyWindowHandle>,
1339        cx: &mut Context<Self>,
1340    ) {
1341        self.tool_use_limit_reached = false;
1342
1343        let pending_completion_id = post_inc(&mut self.completion_count);
1344        let mut request_callback_parameters = if self.request_callback.is_some() {
1345            Some((request.clone(), Vec::new()))
1346        } else {
1347            None
1348        };
1349        let prompt_id = self.last_prompt_id.clone();
1350        let tool_use_metadata = ToolUseMetadata {
1351            model: model.clone(),
1352            thread_id: self.id.clone(),
1353            prompt_id: prompt_id.clone(),
1354        };
1355
1356        self.last_received_chunk_at = Some(Instant::now());
1357
1358        let task = cx.spawn(async move |thread, cx| {
1359            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1360            let initial_token_usage =
1361                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1362            let stream_completion = async {
1363                let (mut events, usage) = stream_completion_future.await?;
1364
1365                let mut stop_reason = StopReason::EndTurn;
1366                let mut current_token_usage = TokenUsage::default();
1367
1368                thread
1369                    .update(cx, |_thread, cx| {
1370                        if let Some(usage) = usage {
1371                            cx.emit(ThreadEvent::UsageUpdated(usage));
1372                        }
1373                        cx.emit(ThreadEvent::NewRequest);
1374                    })
1375                    .ok();
1376
1377                let mut request_assistant_message_id = None;
1378
1379                while let Some(event) = events.next().await {
1380                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1381                        response_events
1382                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1383                    }
1384
1385                    thread.update(cx, |thread, cx| {
1386                        let event = match event {
1387                            Ok(event) => event,
1388                            Err(LanguageModelCompletionError::BadInputJson {
1389                                id,
1390                                tool_name,
1391                                raw_input: invalid_input_json,
1392                                json_parse_error,
1393                            }) => {
1394                                thread.receive_invalid_tool_json(
1395                                    id,
1396                                    tool_name,
1397                                    invalid_input_json,
1398                                    json_parse_error,
1399                                    window,
1400                                    cx,
1401                                );
1402                                return Ok(());
1403                            }
1404                            Err(LanguageModelCompletionError::Other(error)) => {
1405                                return Err(error);
1406                            }
1407                        };
1408
1409                        match event {
1410                            LanguageModelCompletionEvent::StartMessage { .. } => {
1411                                request_assistant_message_id =
1412                                    Some(thread.insert_assistant_message(
1413                                        vec![MessageSegment::Text(String::new())],
1414                                        cx,
1415                                    ));
1416                            }
1417                            LanguageModelCompletionEvent::Stop(reason) => {
1418                                stop_reason = reason;
1419                            }
1420                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1421                                thread.update_token_usage_at_last_message(token_usage);
1422                                thread.cumulative_token_usage = thread.cumulative_token_usage
1423                                    + token_usage
1424                                    - current_token_usage;
1425                                current_token_usage = token_usage;
1426                            }
1427                            LanguageModelCompletionEvent::Text(chunk) => {
1428                                thread.received_chunk();
1429
1430                                cx.emit(ThreadEvent::ReceivedTextChunk);
1431                                if let Some(last_message) = thread.messages.last_mut() {
1432                                    if last_message.role == Role::Assistant
1433                                        && !thread.tool_use.has_tool_results(last_message.id)
1434                                    {
1435                                        last_message.push_text(&chunk);
1436                                        cx.emit(ThreadEvent::StreamedAssistantText(
1437                                            last_message.id,
1438                                            chunk,
1439                                        ));
1440                                    } else {
1441                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1442                                        // of a new Assistant response.
1443                                        //
1444                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1445                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1446                                        request_assistant_message_id =
1447                                            Some(thread.insert_assistant_message(
1448                                                vec![MessageSegment::Text(chunk.to_string())],
1449                                                cx,
1450                                            ));
1451                                    };
1452                                }
1453                            }
1454                            LanguageModelCompletionEvent::Thinking {
1455                                text: chunk,
1456                                signature,
1457                            } => {
1458                                thread.received_chunk();
1459
1460                                if let Some(last_message) = thread.messages.last_mut() {
1461                                    if last_message.role == Role::Assistant
1462                                        && !thread.tool_use.has_tool_results(last_message.id)
1463                                    {
1464                                        last_message.push_thinking(&chunk, signature);
1465                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1466                                            last_message.id,
1467                                            chunk,
1468                                        ));
1469                                    } else {
1470                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1471                                        // of a new Assistant response.
1472                                        //
1473                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1474                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1475                                        request_assistant_message_id =
1476                                            Some(thread.insert_assistant_message(
1477                                                vec![MessageSegment::Thinking {
1478                                                    text: chunk.to_string(),
1479                                                    signature,
1480                                                }],
1481                                                cx,
1482                                            ));
1483                                    };
1484                                }
1485                            }
1486                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1487                                let last_assistant_message_id = request_assistant_message_id
1488                                    .unwrap_or_else(|| {
1489                                        let new_assistant_message_id =
1490                                            thread.insert_assistant_message(vec![], cx);
1491                                        request_assistant_message_id =
1492                                            Some(new_assistant_message_id);
1493                                        new_assistant_message_id
1494                                    });
1495
1496                                let tool_use_id = tool_use.id.clone();
1497                                let streamed_input = if tool_use.is_input_complete {
1498                                    None
1499                                } else {
1500                                    Some((&tool_use.input).clone())
1501                                };
1502
1503                                let ui_text = thread.tool_use.request_tool_use(
1504                                    last_assistant_message_id,
1505                                    tool_use,
1506                                    tool_use_metadata.clone(),
1507                                    cx,
1508                                );
1509
1510                                if let Some(input) = streamed_input {
1511                                    cx.emit(ThreadEvent::StreamedToolUse {
1512                                        tool_use_id,
1513                                        ui_text,
1514                                        input,
1515                                    });
1516                                }
1517                            }
1518                            LanguageModelCompletionEvent::QueueUpdate(status) => {
1519                                if let Some(completion) = thread
1520                                    .pending_completions
1521                                    .iter_mut()
1522                                    .find(|completion| completion.id == pending_completion_id)
1523                                {
1524                                    let queue_state = match status {
1525                                        language_model::CompletionRequestStatus::Queued {
1526                                            position,
1527                                        } => Some(QueueState::Queued { position }),
1528                                        language_model::CompletionRequestStatus::Started => {
1529                                            Some(QueueState::Started)
1530                                        }
1531                                        language_model::CompletionRequestStatus::ToolUseLimitReached => {
1532                                            thread.tool_use_limit_reached = true;
1533                                            None
1534                                        }
1535                                    };
1536
1537                                    if let Some(queue_state) = queue_state {
1538                                        completion.queue_state = queue_state;
1539                                    }
1540                                }
1541                            }
1542                        }
1543
1544                        thread.touch_updated_at();
1545                        cx.emit(ThreadEvent::StreamedCompletion);
1546                        cx.notify();
1547
1548                        thread.auto_capture_telemetry(cx);
1549                        Ok(())
1550                    })??;
1551
1552                    smol::future::yield_now().await;
1553                }
1554
1555                thread.update(cx, |thread, cx| {
1556                    thread.last_received_chunk_at = None;
1557                    thread
1558                        .pending_completions
1559                        .retain(|completion| completion.id != pending_completion_id);
1560
1561                    // If there is a response without tool use, summarize the message. Otherwise,
1562                    // allow two tool uses before summarizing.
1563                    if thread.summary.is_none()
1564                        && thread.messages.len() >= 2
1565                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1566                    {
1567                        thread.summarize(cx);
1568                    }
1569                })?;
1570
1571                anyhow::Ok(stop_reason)
1572            };
1573
1574            let result = stream_completion.await;
1575
1576            thread
1577                .update(cx, |thread, cx| {
1578                    thread.finalize_pending_checkpoint(cx);
1579                    match result.as_ref() {
1580                        Ok(stop_reason) => match stop_reason {
1581                            StopReason::ToolUse => {
1582                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1583                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1584                            }
1585                            StopReason::EndTurn | StopReason::MaxTokens  => {
1586                                thread.project.update(cx, |project, cx| {
1587                                    project.set_agent_location(None, cx);
1588                                });
1589                            }
1590                        },
1591                        Err(error) => {
1592                            thread.project.update(cx, |project, cx| {
1593                                project.set_agent_location(None, cx);
1594                            });
1595
1596                            if error.is::<PaymentRequiredError>() {
1597                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1598                            } else if error.is::<MaxMonthlySpendReachedError>() {
1599                                cx.emit(ThreadEvent::ShowError(
1600                                    ThreadError::MaxMonthlySpendReached,
1601                                ));
1602                            } else if let Some(error) =
1603                                error.downcast_ref::<ModelRequestLimitReachedError>()
1604                            {
1605                                cx.emit(ThreadEvent::ShowError(
1606                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1607                                ));
1608                            } else if let Some(known_error) =
1609                                error.downcast_ref::<LanguageModelKnownError>()
1610                            {
1611                                match known_error {
1612                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1613                                        tokens,
1614                                    } => {
1615                                        thread.exceeded_window_error = Some(ExceededWindowError {
1616                                            model_id: model.id(),
1617                                            token_count: *tokens,
1618                                        });
1619                                        cx.notify();
1620                                    }
1621                                }
1622                            } else {
1623                                let error_message = error
1624                                    .chain()
1625                                    .map(|err| err.to_string())
1626                                    .collect::<Vec<_>>()
1627                                    .join("\n");
1628                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1629                                    header: "Error interacting with language model".into(),
1630                                    message: SharedString::from(error_message.clone()),
1631                                }));
1632                            }
1633
1634                            thread.cancel_last_completion(window, cx);
1635                        }
1636                    }
1637                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1638
1639                    if let Some((request_callback, (request, response_events))) = thread
1640                        .request_callback
1641                        .as_mut()
1642                        .zip(request_callback_parameters.as_ref())
1643                    {
1644                        request_callback(request, response_events);
1645                    }
1646
1647                    thread.auto_capture_telemetry(cx);
1648
1649                    if let Ok(initial_usage) = initial_token_usage {
1650                        let usage = thread.cumulative_token_usage - initial_usage;
1651
1652                        telemetry::event!(
1653                            "Assistant Thread Completion",
1654                            thread_id = thread.id().to_string(),
1655                            prompt_id = prompt_id,
1656                            model = model.telemetry_id(),
1657                            model_provider = model.provider_id().to_string(),
1658                            input_tokens = usage.input_tokens,
1659                            output_tokens = usage.output_tokens,
1660                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1661                            cache_read_input_tokens = usage.cache_read_input_tokens,
1662                        );
1663                    }
1664                })
1665                .ok();
1666        });
1667
1668        self.pending_completions.push(PendingCompletion {
1669            id: pending_completion_id,
1670            queue_state: QueueState::Sending,
1671            _task: task,
1672        });
1673    }
1674
    /// Kicks off generation of a short (3-7 word) title for the thread using
    /// the configured thread-summary model.
    ///
    /// Does nothing if no summary model is configured or its provider is not
    /// authenticated. On success stores the title in `self.summary` and emits
    /// [`ThreadEvent::SummaryGenerated`]. Replacing `self.pending_summary`
    /// drops any previously running summarization task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so errors can be funneled through `log_err()`
            // instead of surfacing to the caller.
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                // Forward token-usage info to listeners if the provider reported it.
                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // Accumulate only the first line of the streamed response; a
                // title should be a single line.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                // An empty result keeps the previous summary (if any) but the
                // event is emitted either way so the UI can refresh.
                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1729
    /// Starts generating a detailed Markdown summary of the conversation,
    /// unless one is already generating or generated for the current last
    /// message.
    ///
    /// Progress is published through the `detailed_summary_tx`/`_rx` watch
    /// channel so observers (see `wait_for_detailed_summary_or_text`) can
    /// react to state changes. On completion the thread is saved via
    /// `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize if the thread has no messages.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
             1. A brief overview of what was discussed\n\
             2. Key facts or information discovered\n\
             3. Outcomes or conclusions reached\n\
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Mark as generating before spawning so observers see the transition
        // immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails, roll the state back to NotGenerated so
            // waiters fall back to the plain thread text.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Chunks that fail to decode are logged and skipped rather than
            // aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1819
    /// Waits until detailed-summary generation settles and returns the result.
    ///
    /// Loops on the watch channel while the state is `Generating`; returns the
    /// generated summary text once available, or the plain thread text if the
    /// state resolves to `NotGenerated`. Returns `None` if the thread entity
    /// is dropped or the channel closes.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in progress — keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1837
1838    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1839        self.detailed_summary_rx
1840            .borrow()
1841            .text()
1842            .unwrap_or_else(|| self.text().into())
1843    }
1844
    /// Reports whether a detailed summary is currently being generated for
    /// this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1851
    /// Dispatches every idle pending tool use: tools that require user
    /// confirmation (per settings) are queued for confirmation, all others are
    /// run immediately.
    ///
    /// Returns the list of pending tool uses that were considered. Tool uses
    /// whose tool cannot be resolved from the working set are silently left
    /// alone.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Snapshot the request messages once; they are shared (via Arc) with
        // every tool invocation below.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        // Only consider tool uses that haven't been started yet.
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // `always_allow_tool_actions` bypasses per-tool confirmation.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1898
    /// Records a tool use whose input JSON from the model failed to parse.
    ///
    /// Stores an error result for the tool use, emits
    /// [`ThreadEvent::InvalidToolInput`] so the UI can surface the bad input,
    /// and finalizes the tool use via `tool_finished` (not canceled, so the
    /// error is sent back to the model once all tools finish).
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        // Prefer the known UI text for this tool use; fall back to a generic
        // label if bookkeeping didn't have a pending entry for it.
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
1933
1934    pub fn run_tool(
1935        &mut self,
1936        tool_use_id: LanguageModelToolUseId,
1937        ui_text: impl Into<SharedString>,
1938        input: serde_json::Value,
1939        messages: &[LanguageModelRequestMessage],
1940        tool: Arc<dyn Tool>,
1941        window: Option<AnyWindowHandle>,
1942        cx: &mut Context<Thread>,
1943    ) {
1944        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1945        self.tool_use
1946            .run_pending_tool(tool_use_id, ui_text.into(), task);
1947    }
1948
    /// Spawns the asynchronous execution of a single tool use and returns the
    /// task driving it.
    ///
    /// Disabled tools short-circuit to an error result instead of running.
    /// When the tool's output resolves, it is recorded via
    /// `insert_tool_output` and `tool_finished` is called, which may resume
    /// the conversation with the model.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; in that
                // case the result is silently discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1997
1998    fn tool_finished(
1999        &mut self,
2000        tool_use_id: LanguageModelToolUseId,
2001        pending_tool_use: Option<PendingToolUse>,
2002        canceled: bool,
2003        window: Option<AnyWindowHandle>,
2004        cx: &mut Context<Self>,
2005    ) {
2006        if self.all_tools_finished() {
2007            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2008                if !canceled {
2009                    self.send_to_model(model.clone(), window, cx);
2010                }
2011                self.auto_capture_telemetry(cx);
2012            }
2013        }
2014
2015        cx.emit(ThreadEvent::ToolFinished {
2016            tool_use_id,
2017            pending_tool_use,
2018        });
2019    }
2020
2021    /// Cancels the last pending completion, if there are any pending.
2022    ///
2023    /// Returns whether a completion was canceled.
2024    pub fn cancel_last_completion(
2025        &mut self,
2026        window: Option<AnyWindowHandle>,
2027        cx: &mut Context<Self>,
2028    ) -> bool {
2029        let mut canceled = self.pending_completions.pop().is_some();
2030
2031        for pending_tool_use in self.tool_use.cancel_pending() {
2032            canceled = true;
2033            self.tool_finished(
2034                pending_tool_use.id.clone(),
2035                Some(pending_tool_use),
2036                true,
2037                window,
2038                cx,
2039            );
2040        }
2041
2042        self.finalize_pending_checkpoint(cx);
2043
2044        if canceled {
2045            cx.emit(ThreadEvent::CompletionCanceled);
2046        }
2047
2048        canceled
2049    }
2050
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the actual cancellation is performed by the listeners.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2058
    /// Returns the thread-level feedback rating, if one was recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2062
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2066
    /// Records feedback for a single message and reports it via telemetry,
    /// along with a serialized copy of the thread and a project snapshot.
    ///
    /// Re-submitting the same rating for the same message is a no-op. The
    /// telemetry work (snapshotting, serialization, event flush) happens on a
    /// background task; the returned task resolves when it completes.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Ignore duplicate ratings for the same message.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the telemetry event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2124
    /// Records thread-level feedback and reports it via telemetry.
    ///
    /// If the thread has an assistant message, the rating is attached to the
    /// most recent one via `report_message_feedback`; otherwise a thread-level
    /// rating event is emitted with a serialized thread and project snapshot.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet — record feedback at the thread level.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather
                // than dropping the telemetry event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2170
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty buffer paths are
    /// collected afterwards on the foreground. Failures while reading buffers
    /// are ignored (the snapshot simply omits that information).
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record the paths of all dirty (unsaved) buffers that have a
            // backing file.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2208
    /// Builds a snapshot of a single worktree: its absolute path plus, when a
    /// matching local git repository is found, the repo's remote URL, HEAD
    /// SHA, current branch, and a HEAD-to-worktree diff.
    ///
    /// Any failure along the way degrades gracefully: an unreadable worktree
    /// yields an empty path, and any git error yields `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree's path,
            // then query its state on the repo's background job queue.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; remote ones report branch only.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>>: any failure at any
            // layer collapses to None.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2286
2287    pub fn to_markdown(&self, cx: &App) -> Result<String> {
2288        let mut markdown = Vec::new();
2289
2290        if let Some(summary) = self.summary() {
2291            writeln!(markdown, "# {summary}\n")?;
2292        };
2293
2294        for message in self.messages() {
2295            writeln!(
2296                markdown,
2297                "## {role}\n",
2298                role = match message.role {
2299                    Role::User => "User",
2300                    Role::Assistant => "Assistant",
2301                    Role::System => "System",
2302                }
2303            )?;
2304
2305            if !message.loaded_context.text.is_empty() {
2306                writeln!(markdown, "{}", message.loaded_context.text)?;
2307            }
2308
2309            if !message.loaded_context.images.is_empty() {
2310                writeln!(
2311                    markdown,
2312                    "\n{} images attached as context.\n",
2313                    message.loaded_context.images.len()
2314                )?;
2315            }
2316
2317            for segment in &message.segments {
2318                match segment {
2319                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2320                    MessageSegment::Thinking { text, .. } => {
2321                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2322                    }
2323                    MessageSegment::RedactedThinking(_) => {}
2324                }
2325            }
2326
2327            for tool_use in self.tool_uses_for_message(message.id, cx) {
2328                writeln!(
2329                    markdown,
2330                    "**Use Tool: {} ({})**",
2331                    tool_use.name, tool_use.id
2332                )?;
2333                writeln!(markdown, "```json")?;
2334                writeln!(
2335                    markdown,
2336                    "{}",
2337                    serde_json::to_string_pretty(&tool_use.input)?
2338                )?;
2339                writeln!(markdown, "```")?;
2340            }
2341
2342            for tool_result in self.tool_results_for_message(message.id) {
2343                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2344                if tool_result.is_error {
2345                    write!(markdown, " (Error)")?;
2346                }
2347
2348                writeln!(markdown, "**\n")?;
2349                writeln!(markdown, "{}", tool_result.content)?;
2350            }
2351        }
2352
2353        Ok(String::from_utf8_lossy(&markdown).to_string())
2354    }
2355
    /// Accepts (keeps) the agent's edits within the given range of a buffer,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2366
    /// Accepts (keeps) all of the agent's outstanding edits, delegating to the
    /// action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2371
    /// Rejects (reverts) the agent's edits within the given ranges of a
    /// buffer, delegating to the action log. Returns the task performing the
    /// revert.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2382
    /// Returns the action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2386
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2390
    /// Automatically captures a serialized copy of the thread via telemetry,
    /// gated behind the `ThreadAutoCaptureFeatureFlag` and rate-limited to at
    /// most once every 10 seconds. The serialization and upload happen on a
    /// detached background task.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Rate limit: skip if a capture happened within the last 10 seconds.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                // Serialization failures silently skip the capture.
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
2434
    /// Total token usage accumulated across all completion requests made by
    /// this thread so far.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2438
2439    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2440        let Some(model) = self.configured_model.as_ref() else {
2441            return TotalTokenUsage::default();
2442        };
2443
2444        let max = model.model.max_token_count();
2445
2446        let index = self
2447            .messages
2448            .iter()
2449            .position(|msg| msg.id == message_id)
2450            .unwrap_or(0);
2451
2452        if index == 0 {
2453            return TotalTokenUsage { total: 0, max };
2454        }
2455
2456        let token_usage = &self
2457            .request_token_usage
2458            .get(index - 1)
2459            .cloned()
2460            .unwrap_or_default();
2461
2462        TotalTokenUsage {
2463            total: token_usage.total_tokens() as usize,
2464            max,
2465        }
2466    }
2467
2468    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2469        let model = self.configured_model.as_ref()?;
2470
2471        let max = model.model.max_token_count();
2472
2473        if let Some(exceeded_error) = &self.exceeded_window_error {
2474            if model.model.id() == exceeded_error.model_id {
2475                return Some(TotalTokenUsage {
2476                    total: exceeded_error.token_count,
2477                    max,
2478                });
2479            }
2480        }
2481
2482        let total = self
2483            .token_usage_at_last_message()
2484            .unwrap_or_default()
2485            .total_tokens() as usize;
2486
2487        Some(TotalTokenUsage { total, max })
2488    }
2489
2490    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2491        self.request_token_usage
2492            .get(self.messages.len().saturating_sub(1))
2493            .or_else(|| self.request_token_usage.last())
2494            .cloned()
2495    }
2496
2497    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2498        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2499        self.request_token_usage
2500            .resize(self.messages.len(), placeholder);
2501
2502        if let Some(last) = self.request_token_usage.last_mut() {
2503            *last = token_usage;
2504        }
2505    }
2506
2507    pub fn deny_tool_use(
2508        &mut self,
2509        tool_use_id: LanguageModelToolUseId,
2510        tool_name: Arc<str>,
2511        window: Option<AnyWindowHandle>,
2512        cx: &mut Context<Self>,
2513    ) {
2514        let err = Err(anyhow::anyhow!(
2515            "Permission to run tool action denied by user"
2516        ));
2517
2518        self.tool_use.insert_tool_output(
2519            tool_use_id.clone(),
2520            tool_name,
2521            err,
2522            self.configured_model.as_ref(),
2523        );
2524        self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2525    }
2526}
2527
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are allowed.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and message to display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2542
/// Events emitted by a thread as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request usage was reported for the account.
    UsageUpdated(RequestUsage),
    /// A completion event was received and applied to the thread.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with parsed input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that could not be parsed as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Any in-progress editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2582
// Lets `Thread` entities emit `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2584
/// An in-flight completion request, kept alive by its background task.
struct PendingCompletion {
    // Identifies this completion among concurrently pending ones.
    id: usize,
    queue_state: QueueState,
    // Background task driving the completion; held so it keeps running for
    // the lifetime of this struct.
    _task: Task<()>,
}
2590
2591#[cfg(test)]
2592mod tests {
2593    use super::*;
2594    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2595    use assistant_settings::AssistantSettings;
2596    use assistant_tool::ToolRegistry;
2597    use context_server::ContextServerSettings;
2598    use editor::EditorSettings;
2599    use gpui::TestAppContext;
2600    use language_model::fake_provider::FakeLanguageModel;
2601    use project::{FakeFs, Project};
2602    use prompt_store::PromptBuilder;
2603    use serde_json::json;
2604    use settings::{Settings, SettingsStore};
2605    use std::sync::Arc;
2606    use theme::ThemeSettings;
2607    use util::path;
2608    use workspace::Workspace;
2609
    /// Attached file context should be embedded both in the inserted message
    /// and in the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages expected; index 1 is the user message carrying the
        // context block followed by the user's text (index 0 is presumably the
        // system prompt — confirm against `to_completion_request`).
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2685
    /// Each message should carry only the context items that were not already
    /// attached to an earlier message, and `new_context_for_thread` should
    /// respect the optional "from message" boundary.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With the boundary at message 2, everything attached from message 2
        // onward (files 2-4) counts as "new" again; file 1 stays excluded.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2841
    /// Messages inserted with an empty `ContextLoadResult` should carry no
    /// context text, and the completion request should contain just the plain
    /// user messages.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2919
    /// After a context buffer is edited, the next completion request should
    /// end with a "files changed since last read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 to dirty the buffer
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3006
    /// Registers the global settings and registries the thread tests need.
    /// The `SettingsStore` is installed first; the subsequent `init`/`register`
    /// calls presumably read it — keep the order when modifying.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3024
3025    // Helper to create a test project with test files
3026    async fn create_test_project(
3027        cx: &mut TestAppContext,
3028        files: serde_json::Value,
3029    ) -> Entity<Project> {
3030        let fs = FakeFs::new(cx.executor());
3031        fs.insert_tree(path!("/test"), files).await;
3032        Project::test(fs, [path!("/test").as_ref()], cx).await
3033    }
3034
    /// Creates a workspace, thread store, thread, context store, and fake
    /// language model for a test, returning them as a tuple.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        // The fake model lets tests inspect completion requests without
        // contacting a real provider.
        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3069
3070    async fn add_file_to_context(
3071        project: &Entity<Project>,
3072        context_store: &Entity<ContextStore>,
3073        path: &str,
3074        cx: &mut TestAppContext,
3075    ) -> Result<Entity<language::Buffer>> {
3076        let buffer_path = project
3077            .read_with(cx, |project, cx| project.find_project_path(path, cx))
3078            .unwrap();
3079
3080        let buffer = project
3081            .update(cx, |project, cx| {
3082                project.open_buffer(buffer_path.clone(), cx)
3083            })
3084            .await
3085            .unwrap();
3086
3087        context_store.update(cx, |context_store, cx| {
3088            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3089        });
3090
3091        Ok(buffer)
3092    }
3093}