thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
  19    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  20    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  21    PaymentRequiredError, Role, StopReason, TokenUsage,
  22};
  23use project::Project;
  24use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  25use prompt_store::PromptBuilder;
  26use schemars::JsonSchema;
  27use serde::{Deserialize, Serialize};
  28use settings::Settings;
  29use thiserror::Error;
  30use util::{ResultExt as _, TryFutureExt as _, post_inc};
  31use uuid::Uuid;
  32
  33use crate::context::{AssistantContext, ContextId, format_context_as_string};
  34use crate::thread_store::{
  35    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  36    SerializedToolUse, SharedProjectContext,
  37};
  38use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  39
/// Which kind of completion request is being built from a thread.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular chat turn; tool results are attached to the request messages.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  49)]
  50pub struct ThreadId(Arc<str>);
  51
  52impl ThreadId {
  53    pub fn new() -> Self {
  54        Self(Uuid::new_v4().to_string().into())
  55    }
  56}
  57
  58impl std::fmt::Display for ThreadId {
  59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  60        write!(f, "{}", self.0)
  61    }
  62}
  63
  64impl From<&str> for ThreadId {
  65    fn from(value: &str) -> Self {
  66        Self(value.into())
  67    }
  68}
  69
  70#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  71pub struct MessageId(pub(crate) usize);
  72
  73impl MessageId {
  74    fn post_inc(&mut self) -> Self {
  75        Self(post_inc(&mut self.0))
  76    }
  77}
  78
  79/// A message in a [`Thread`].
  80#[derive(Debug, Clone)]
  81pub struct Message {
  82    pub id: MessageId,
  83    pub role: Role,
  84    pub segments: Vec<MessageSegment>,
  85    pub context: String,
  86}
  87
  88impl Message {
  89    /// Returns whether the message contains any meaningful text that should be displayed
  90    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  91    pub fn should_display_content(&self) -> bool {
  92        self.segments.iter().all(|segment| segment.should_display())
  93    }
  94
  95    pub fn push_thinking(&mut self, text: &str) {
  96        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  97            segment.push_str(text);
  98        } else {
  99            self.segments
 100                .push(MessageSegment::Thinking(text.to_string()));
 101        }
 102    }
 103
 104    pub fn push_text(&mut self, text: &str) {
 105        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 106            segment.push_str(text);
 107        } else {
 108            self.segments.push(MessageSegment::Text(text.to_string()));
 109        }
 110    }
 111
 112    pub fn to_string(&self) -> String {
 113        let mut result = String::new();
 114
 115        if !self.context.is_empty() {
 116            result.push_str(&self.context);
 117        }
 118
 119        for segment in &self.segments {
 120            match segment {
 121                MessageSegment::Text(text) => result.push_str(text),
 122                MessageSegment::Thinking(text) => {
 123                    result.push_str("<think>");
 124                    result.push_str(text);
 125                    result.push_str("</think>");
 126                }
 127            }
 128        }
 129
 130        result
 131    }
 132}
 133
 134#[derive(Debug, Clone, PartialEq, Eq)]
 135pub enum MessageSegment {
 136    Text(String),
 137    Thinking(String),
 138}
 139
 140impl MessageSegment {
 141    pub fn text_mut(&mut self) -> &mut String {
 142        match self {
 143            Self::Text(text) => text,
 144            Self::Thinking(text) => text,
 145        }
 146    }
 147
 148    pub fn should_display(&self) -> bool {
 149        // We add USING_TOOL_MARKER when making a request that includes tool uses
 150        // without non-whitespace text around them, and this can cause the model
 151        // to mimic the pattern, so we consider those segments not displayable.
 152        match self {
 153            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 154            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 155        }
 156    }
 157}
 158
/// Point-in-time description of the project a thread works in. A thread keeps one as
/// its initial project snapshot and stores it when the thread is serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// Per-worktree snapshots.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes.
    /// NOTE(review): inferred from the field name; the producer is not visible here.
    pub unsaved_buffer_paths: Vec<String>,
    /// Presumably the time the snapshot was taken — confirm against the producer.
    pub timestamp: DateTime<Utc>,
}
 165
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, stored as a string.
    pub worktree_path: String,
    /// Git information for the worktree, or `None` when not available
    /// (NOTE(review): presumably when the worktree is not in a git repository — confirm).
    pub git_state: Option<GitState>,
}
 171
/// Git information captured for a worktree; every field is optional.
/// NOTE(review): field meanings below are inferred from their names; the code that
/// fills them in is not visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the repository's remote.
    pub remote_url: Option<String>,
    /// SHA of the commit `HEAD` points to.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch.
    pub current_branch: Option<String>,
    /// Textual diff of the working tree (presumably against `HEAD` — confirm).
    pub diff: Option<String>,
}
 179
/// A git checkpoint tied to a message, which the thread can be restored to.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// The message the checkpoint belongs to; a successful restore truncates the
    /// thread starting at this message.
    message_id: MessageId,
    /// The git-store checkpoint the project is restored to.
    git_checkpoint: GitStoreCheckpoint,
}
 185
/// Positive or negative feedback; a thread stores one for the whole thread and one
/// per message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 191
 192pub enum LastRestoreCheckpoint {
 193    Pending {
 194        message_id: MessageId,
 195    },
 196    Error {
 197        message_id: MessageId,
 198        error: String,
 199    },
 200}
 201
 202impl LastRestoreCheckpoint {
 203    pub fn message_id(&self) -> MessageId {
 204        match self {
 205            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 206            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 207        }
 208    }
 209}
 210
/// Progress of a thread's detailed summary. Persisted when the thread is serialized.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists.
    #[default]
    NotGenerated,
    /// A detailed summary is being produced.
    /// NOTE(review): `message_id` presumably marks the last message covered — confirm.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available in `text`; only this state exposes the text.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 223
/// Token usage measured against a limit.
/// NOTE(review): field meanings are inferred from their names; the producer is not
/// visible here.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens used.
    pub total: usize,
    /// Limit that `total` is measured against.
    pub max: usize,
    /// Coarse classification of `total` relative to `max`.
    pub ratio: TokenUsageRatio,
}
 230
/// Coarse classification of token usage relative to a limit; defaults to `Normal`.
/// NOTE(review): variant meanings are inferred from their names.
#[derive(Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is comfortably within the limit.
    #[default]
    Normal,
    /// Usage is approaching the limit.
    Warning,
    /// Usage is over the limit.
    Exceeded,
}
 238
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Unique identifier of this thread.
    id: ThreadId,
    /// Last-modified time; refreshed by `touch_updated_at`, which runs whenever a
    /// message is inserted, edited, or deleted.
    updated_at: DateTime<Utc>,
    /// Short title. `None` until a summary exists; `set_summary` is ignored while `None`.
    summary: Option<SharedString>,
    /// Task handle for summary generation; starts out as an already-finished task.
    /// NOTE(review): presumably replaced while a summary is being generated — confirm.
    pending_summary: Task<Option<()>>,
    /// Progress of the detailed summary; its text is exposed only once `Generated`.
    detailed_summary_state: DetailedSummaryState,
    /// All messages, in insertion order.
    messages: Vec<Message>,
    /// Id the next inserted message will receive.
    next_message_id: MessageId,
    /// Every context attached to this thread, keyed by id. `insert_user_message` uses
    /// it to skip contexts that an earlier message already included.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the contexts that were newly attached with each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Project context used to generate the system prompt of completion requests.
    project_context: SharedProjectContext,
    /// Restorable git checkpoints, keyed by the message they belong to.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// NOTE(review): not used in the code visible here; presumably counts completions
    /// started by this thread — confirm.
    completion_count: usize,
    /// In-flight completions; the thread counts as generating while this is non-empty.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    /// Tool set whose enabled tools are offered to models that support tools.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses and tool results recorded for this thread's messages.
    tool_use: ToolUseState,
    /// Action log for this thread's project; buffers attached as user-message
    /// context are registered with it.
    action_log: Entity<ActionLog>,
    /// Status of the latest checkpoint restore; `None` if none ran or it succeeded.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint supplied with the latest user message, held until generation
    /// finishes and `finalize_pending_checkpoint` decides whether to keep it.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot taken when the thread was created (or as loaded from storage);
    /// awaited when serializing.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage persisted with the thread (presumably accumulated across requests).
    cumulative_token_usage: TokenUsage,
    /// Feedback on the thread as a whole; not included when serializing.
    feedback: Option<ThreadFeedback>,
    /// Feedback on individual messages; not included when serializing.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// NOTE(review): presumably the time of the last automatic telemetry capture,
    /// used by `auto_capture_telemetry` — confirm.
    last_auto_capture_at: Option<Instant>,
}
 267
 268impl Thread {
 269    pub fn new(
 270        project: Entity<Project>,
 271        tools: Arc<ToolWorkingSet>,
 272        prompt_builder: Arc<PromptBuilder>,
 273        system_prompt: SharedProjectContext,
 274        cx: &mut Context<Self>,
 275    ) -> Self {
 276        Self {
 277            id: ThreadId::new(),
 278            updated_at: Utc::now(),
 279            summary: None,
 280            pending_summary: Task::ready(None),
 281            detailed_summary_state: DetailedSummaryState::NotGenerated,
 282            messages: Vec::new(),
 283            next_message_id: MessageId(0),
 284            context: BTreeMap::default(),
 285            context_by_message: HashMap::default(),
 286            project_context: system_prompt,
 287            checkpoints_by_message: HashMap::default(),
 288            completion_count: 0,
 289            pending_completions: Vec::new(),
 290            project: project.clone(),
 291            prompt_builder,
 292            tools: tools.clone(),
 293            last_restore_checkpoint: None,
 294            pending_checkpoint: None,
 295            tool_use: ToolUseState::new(tools.clone()),
 296            action_log: cx.new(|_| ActionLog::new(project.clone())),
 297            initial_project_snapshot: {
 298                let project_snapshot = Self::project_snapshot(project, cx);
 299                cx.foreground_executor()
 300                    .spawn(async move { Some(project_snapshot.await) })
 301                    .shared()
 302            },
 303            cumulative_token_usage: TokenUsage::default(),
 304            feedback: None,
 305            message_feedback: HashMap::default(),
 306            last_auto_capture_at: None,
 307        }
 308    }
 309
 310    pub fn deserialize(
 311        id: ThreadId,
 312        serialized: SerializedThread,
 313        project: Entity<Project>,
 314        tools: Arc<ToolWorkingSet>,
 315        prompt_builder: Arc<PromptBuilder>,
 316        project_context: SharedProjectContext,
 317        cx: &mut Context<Self>,
 318    ) -> Self {
 319        let next_message_id = MessageId(
 320            serialized
 321                .messages
 322                .last()
 323                .map(|message| message.id.0 + 1)
 324                .unwrap_or(0),
 325        );
 326        let tool_use =
 327            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 328
 329        Self {
 330            id,
 331            updated_at: serialized.updated_at,
 332            summary: Some(serialized.summary),
 333            pending_summary: Task::ready(None),
 334            detailed_summary_state: serialized.detailed_summary_state,
 335            messages: serialized
 336                .messages
 337                .into_iter()
 338                .map(|message| Message {
 339                    id: message.id,
 340                    role: message.role,
 341                    segments: message
 342                        .segments
 343                        .into_iter()
 344                        .map(|segment| match segment {
 345                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 346                            SerializedMessageSegment::Thinking { text } => {
 347                                MessageSegment::Thinking(text)
 348                            }
 349                        })
 350                        .collect(),
 351                    context: message.context,
 352                })
 353                .collect(),
 354            next_message_id,
 355            context: BTreeMap::default(),
 356            context_by_message: HashMap::default(),
 357            project_context,
 358            checkpoints_by_message: HashMap::default(),
 359            completion_count: 0,
 360            pending_completions: Vec::new(),
 361            last_restore_checkpoint: None,
 362            pending_checkpoint: None,
 363            project: project.clone(),
 364            prompt_builder,
 365            tools,
 366            tool_use,
 367            action_log: cx.new(|_| ActionLog::new(project)),
 368            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 369            cumulative_token_usage: serialized.cumulative_token_usage,
 370            feedback: None,
 371            message_feedback: HashMap::default(),
 372            last_auto_capture_at: None,
 373        }
 374    }
 375
 376    pub fn id(&self) -> &ThreadId {
 377        &self.id
 378    }
 379
 380    pub fn is_empty(&self) -> bool {
 381        self.messages.is_empty()
 382    }
 383
 384    pub fn updated_at(&self) -> DateTime<Utc> {
 385        self.updated_at
 386    }
 387
 388    pub fn touch_updated_at(&mut self) {
 389        self.updated_at = Utc::now();
 390    }
 391
 392    pub fn summary(&self) -> Option<SharedString> {
 393        self.summary.clone()
 394    }
 395
 396    pub fn project_context(&self) -> SharedProjectContext {
 397        self.project_context.clone()
 398    }
 399
 400    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 401
 402    pub fn summary_or_default(&self) -> SharedString {
 403        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 404    }
 405
 406    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 407        let Some(current_summary) = &self.summary else {
 408            // Don't allow setting summary until generated
 409            return;
 410        };
 411
 412        let mut new_summary = new_summary.into();
 413
 414        if new_summary.is_empty() {
 415            new_summary = Self::DEFAULT_SUMMARY;
 416        }
 417
 418        if current_summary != &new_summary {
 419            self.summary = Some(new_summary);
 420            cx.emit(ThreadEvent::SummaryChanged);
 421        }
 422    }
 423
 424    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 425        self.latest_detailed_summary()
 426            .unwrap_or_else(|| self.text().into())
 427    }
 428
 429    fn latest_detailed_summary(&self) -> Option<SharedString> {
 430        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 431            Some(text.clone())
 432        } else {
 433            None
 434        }
 435    }
 436
 437    pub fn message(&self, id: MessageId) -> Option<&Message> {
 438        self.messages.iter().find(|message| message.id == id)
 439    }
 440
 441    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 442        self.messages.iter()
 443    }
 444
 445    pub fn is_generating(&self) -> bool {
 446        !self.pending_completions.is_empty() || !self.all_tools_finished()
 447    }
 448
 449    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 450        &self.tools
 451    }
 452
 453    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 454        self.tool_use
 455            .pending_tool_uses()
 456            .into_iter()
 457            .find(|tool_use| &tool_use.id == id)
 458    }
 459
 460    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 461        self.tool_use
 462            .pending_tool_uses()
 463            .into_iter()
 464            .filter(|tool_use| tool_use.status.needs_confirmation())
 465    }
 466
 467    pub fn has_pending_tool_uses(&self) -> bool {
 468        !self.tool_use.pending_tool_uses().is_empty()
 469    }
 470
 471    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 472        self.checkpoints_by_message.get(&id).cloned()
 473    }
 474
 475    pub fn restore_checkpoint(
 476        &mut self,
 477        checkpoint: ThreadCheckpoint,
 478        cx: &mut Context<Self>,
 479    ) -> Task<Result<()>> {
 480        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 481            message_id: checkpoint.message_id,
 482        });
 483        cx.emit(ThreadEvent::CheckpointChanged);
 484        cx.notify();
 485
 486        let git_store = self.project().read(cx).git_store().clone();
 487        let restore = git_store.update(cx, |git_store, cx| {
 488            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 489        });
 490
 491        cx.spawn(async move |this, cx| {
 492            let result = restore.await;
 493            this.update(cx, |this, cx| {
 494                if let Err(err) = result.as_ref() {
 495                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 496                        message_id: checkpoint.message_id,
 497                        error: err.to_string(),
 498                    });
 499                } else {
 500                    this.truncate(checkpoint.message_id, cx);
 501                    this.last_restore_checkpoint = None;
 502                }
 503                this.pending_checkpoint = None;
 504                cx.emit(ThreadEvent::CheckpointChanged);
 505                cx.notify();
 506            })?;
 507            result
 508        })
 509    }
 510
    /// Decides whether to keep the checkpoint that was supplied with the latest user
    /// message, once the thread has stopped generating.
    ///
    /// Does nothing while the thread is still generating or when there is no pending
    /// checkpoint. Otherwise it takes the pending checkpoint and creates a fresh git
    /// checkpoint, then compares the two:
    /// - if they are equal, the pending checkpoint is deleted (there is nothing to restore);
    /// - if not, it is recorded via `insert_checkpoint`.
    ///
    /// After the comparison the fresh checkpoint is deleted. If creating the fresh
    /// checkpoint fails, the pending checkpoint is recorded as-is. The work runs in a
    /// detached task, so its errors are not reported to the caller.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated as "changed", so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // Nothing changed since the checkpoint; it has no value to restore.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The fresh checkpoint was only needed for the comparison.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Could not snapshot the current state; keep the pending checkpoint.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 561
 562    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 563        self.checkpoints_by_message
 564            .insert(checkpoint.message_id, checkpoint);
 565        cx.emit(ThreadEvent::CheckpointChanged);
 566        cx.notify();
 567    }
 568
 569    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 570        self.last_restore_checkpoint.as_ref()
 571    }
 572
 573    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 574        let Some(message_ix) = self
 575            .messages
 576            .iter()
 577            .rposition(|message| message.id == message_id)
 578        else {
 579            return;
 580        };
 581        for deleted_message in self.messages.drain(message_ix..) {
 582            self.context_by_message.remove(&deleted_message.id);
 583            self.checkpoints_by_message.remove(&deleted_message.id);
 584        }
 585        cx.notify();
 586    }
 587
 588    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 589        self.context_by_message
 590            .get(&id)
 591            .into_iter()
 592            .flat_map(|context| {
 593                context
 594                    .iter()
 595                    .filter_map(|context_id| self.context.get(&context_id))
 596            })
 597    }
 598
 599    /// Returns whether all of the tool uses have finished running.
 600    pub fn all_tools_finished(&self) -> bool {
 601        // If the only pending tool uses left are the ones with errors, then
 602        // that means that we've finished running all of the pending tools.
 603        self.tool_use
 604            .pending_tool_uses()
 605            .iter()
 606            .all(|tool_use| tool_use.status.is_error())
 607    }
 608
 609    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 610        self.tool_use.tool_uses_for_message(id, cx)
 611    }
 612
 613    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 614        self.tool_use.tool_results_for_message(id)
 615    }
 616
 617    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 618        self.tool_use.tool_result(id)
 619    }
 620
 621    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 622        self.tool_use.message_has_tool_results(message_id)
 623    }
 624
 625    pub fn insert_user_message(
 626        &mut self,
 627        text: impl Into<String>,
 628        context: Vec<AssistantContext>,
 629        git_checkpoint: Option<GitStoreCheckpoint>,
 630        cx: &mut Context<Self>,
 631    ) -> MessageId {
 632        let text = text.into();
 633
 634        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 635
 636        // Filter out contexts that have already been included in previous messages
 637        let new_context: Vec<_> = context
 638            .into_iter()
 639            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 640            .collect();
 641
 642        if !new_context.is_empty() {
 643            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 644                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 645                    message.context = context_string;
 646                }
 647            }
 648
 649            self.action_log.update(cx, |log, cx| {
 650                // Track all buffers added as context
 651                for ctx in &new_context {
 652                    match ctx {
 653                        AssistantContext::File(file_ctx) => {
 654                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 655                        }
 656                        AssistantContext::Directory(dir_ctx) => {
 657                            for context_buffer in &dir_ctx.context_buffers {
 658                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 659                            }
 660                        }
 661                        AssistantContext::Symbol(symbol_ctx) => {
 662                            log.buffer_added_as_context(
 663                                symbol_ctx.context_symbol.buffer.clone(),
 664                                cx,
 665                            );
 666                        }
 667                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 668                    }
 669                }
 670            });
 671        }
 672
 673        let context_ids = new_context
 674            .iter()
 675            .map(|context| context.id())
 676            .collect::<Vec<_>>();
 677        self.context.extend(
 678            new_context
 679                .into_iter()
 680                .map(|context| (context.id(), context)),
 681        );
 682        self.context_by_message.insert(message_id, context_ids);
 683
 684        if let Some(git_checkpoint) = git_checkpoint {
 685            self.pending_checkpoint = Some(ThreadCheckpoint {
 686                message_id,
 687                git_checkpoint,
 688            });
 689        }
 690
 691        self.auto_capture_telemetry(cx);
 692
 693        message_id
 694    }
 695
 696    pub fn insert_message(
 697        &mut self,
 698        role: Role,
 699        segments: Vec<MessageSegment>,
 700        cx: &mut Context<Self>,
 701    ) -> MessageId {
 702        let id = self.next_message_id.post_inc();
 703        self.messages.push(Message {
 704            id,
 705            role,
 706            segments,
 707            context: String::new(),
 708        });
 709        self.touch_updated_at();
 710        cx.emit(ThreadEvent::MessageAdded(id));
 711        id
 712    }
 713
 714    pub fn edit_message(
 715        &mut self,
 716        id: MessageId,
 717        new_role: Role,
 718        new_segments: Vec<MessageSegment>,
 719        cx: &mut Context<Self>,
 720    ) -> bool {
 721        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 722            return false;
 723        };
 724        message.role = new_role;
 725        message.segments = new_segments;
 726        self.touch_updated_at();
 727        cx.emit(ThreadEvent::MessageEdited(id));
 728        true
 729    }
 730
 731    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 732        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 733            return false;
 734        };
 735        self.messages.remove(index);
 736        self.context_by_message.remove(&id);
 737        self.touch_updated_at();
 738        cx.emit(ThreadEvent::MessageDeleted(id));
 739        true
 740    }
 741
 742    /// Returns the representation of this [`Thread`] in a textual form.
 743    ///
 744    /// This is the representation we use when attaching a thread as context to another thread.
 745    pub fn text(&self) -> String {
 746        let mut text = String::new();
 747
 748        for message in &self.messages {
 749            text.push_str(match message.role {
 750                language_model::Role::User => "User:",
 751                language_model::Role::Assistant => "Assistant:",
 752                language_model::Role::System => "System:",
 753            });
 754            text.push('\n');
 755
 756            for segment in &message.segments {
 757                match segment {
 758                    MessageSegment::Text(content) => text.push_str(content),
 759                    MessageSegment::Thinking(content) => {
 760                        text.push_str(&format!("<think>{}</think>", content))
 761                    }
 762                }
 763            }
 764            text.push('\n');
 765        }
 766
 767        text
 768    }
 769
 770    /// Serializes this thread into a format for storage or telemetry.
 771    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 772        let initial_project_snapshot = self.initial_project_snapshot.clone();
 773        cx.spawn(async move |this, cx| {
 774            let initial_project_snapshot = initial_project_snapshot.await;
 775            this.read_with(cx, |this, cx| SerializedThread {
 776                version: SerializedThread::VERSION.to_string(),
 777                summary: this.summary_or_default(),
 778                updated_at: this.updated_at(),
 779                messages: this
 780                    .messages()
 781                    .map(|message| SerializedMessage {
 782                        id: message.id,
 783                        role: message.role,
 784                        segments: message
 785                            .segments
 786                            .iter()
 787                            .map(|segment| match segment {
 788                                MessageSegment::Text(text) => {
 789                                    SerializedMessageSegment::Text { text: text.clone() }
 790                                }
 791                                MessageSegment::Thinking(text) => {
 792                                    SerializedMessageSegment::Thinking { text: text.clone() }
 793                                }
 794                            })
 795                            .collect(),
 796                        tool_uses: this
 797                            .tool_uses_for_message(message.id, cx)
 798                            .into_iter()
 799                            .map(|tool_use| SerializedToolUse {
 800                                id: tool_use.id,
 801                                name: tool_use.name,
 802                                input: tool_use.input,
 803                            })
 804                            .collect(),
 805                        tool_results: this
 806                            .tool_results_for_message(message.id)
 807                            .into_iter()
 808                            .map(|tool_result| SerializedToolResult {
 809                                tool_use_id: tool_result.tool_use_id.clone(),
 810                                is_error: tool_result.is_error,
 811                                content: tool_result.content.clone(),
 812                            })
 813                            .collect(),
 814                        context: message.context.clone(),
 815                    })
 816                    .collect(),
 817                initial_project_snapshot,
 818                cumulative_token_usage: this.cumulative_token_usage.clone(),
 819                detailed_summary_state: this.detailed_summary_state.clone(),
 820            })
 821        })
 822    }
 823
 824    pub fn send_to_model(
 825        &mut self,
 826        model: Arc<dyn LanguageModel>,
 827        request_kind: RequestKind,
 828        cx: &mut Context<Self>,
 829    ) {
 830        let mut request = self.to_completion_request(request_kind, cx);
 831        if model.supports_tools() {
 832            request.tools = {
 833                let mut tools = Vec::new();
 834                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 835                    LanguageModelRequestTool {
 836                        name: tool.name(),
 837                        description: tool.description(),
 838                        input_schema: tool.input_schema(model.tool_input_format()),
 839                    }
 840                }));
 841
 842                tools
 843            };
 844        }
 845
 846        self.stream_completion(request, model, cx);
 847    }
 848
 849    pub fn used_tools_since_last_user_message(&self) -> bool {
 850        for message in self.messages.iter().rev() {
 851            if self.tool_use.message_has_tool_results(message.id) {
 852                return true;
 853            } else if message.role == Role::User {
 854                return false;
 855            }
 856        }
 857
 858        false
 859    }
 860
 861    pub fn to_completion_request(
 862        &self,
 863        request_kind: RequestKind,
 864        cx: &App,
 865    ) -> LanguageModelRequest {
 866        let mut request = LanguageModelRequest {
 867            messages: vec![],
 868            tools: Vec::new(),
 869            stop: Vec::new(),
 870            temperature: None,
 871        };
 872
 873        if let Some(project_context) = self.project_context.borrow().as_ref() {
 874            if let Some(system_prompt) = self
 875                .prompt_builder
 876                .generate_assistant_system_prompt(project_context)
 877                .context("failed to generate assistant system prompt")
 878                .log_err()
 879            {
 880                request.messages.push(LanguageModelRequestMessage {
 881                    role: Role::System,
 882                    content: vec![MessageContent::Text(system_prompt)],
 883                    cache: true,
 884                });
 885            }
 886        } else {
 887            log::error!("project_context not set.")
 888        }
 889
 890        for message in &self.messages {
 891            let mut request_message = LanguageModelRequestMessage {
 892                role: message.role,
 893                content: Vec::new(),
 894                cache: false,
 895            };
 896
 897            match request_kind {
 898                RequestKind::Chat => {
 899                    self.tool_use
 900                        .attach_tool_results(message.id, &mut request_message);
 901                }
 902                RequestKind::Summarize => {
 903                    // We don't care about tool use during summarization.
 904                    if self.tool_use.message_has_tool_results(message.id) {
 905                        continue;
 906                    }
 907                }
 908            }
 909
 910            if !message.segments.is_empty() {
 911                request_message
 912                    .content
 913                    .push(MessageContent::Text(message.to_string()));
 914            }
 915
 916            match request_kind {
 917                RequestKind::Chat => {
 918                    self.tool_use
 919                        .attach_tool_uses(message.id, &mut request_message);
 920                }
 921                RequestKind::Summarize => {
 922                    // We don't care about tool use during summarization.
 923                }
 924            };
 925
 926            request.messages.push(request_message);
 927        }
 928
 929        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 930        if let Some(last) = request.messages.last_mut() {
 931            last.cache = true;
 932        }
 933
 934        self.attached_tracked_files_state(&mut request.messages, cx);
 935
 936        request
 937    }
 938
 939    fn attached_tracked_files_state(
 940        &self,
 941        messages: &mut Vec<LanguageModelRequestMessage>,
 942        cx: &App,
 943    ) {
 944        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 945
 946        let mut stale_message = String::new();
 947
 948        let action_log = self.action_log.read(cx);
 949
 950        for stale_file in action_log.stale_buffers(cx) {
 951            let Some(file) = stale_file.read(cx).file() else {
 952                continue;
 953            };
 954
 955            if stale_message.is_empty() {
 956                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
 957            }
 958
 959            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 960        }
 961
 962        let mut content = Vec::with_capacity(2);
 963
 964        if !stale_message.is_empty() {
 965            content.push(stale_message.into());
 966        }
 967
 968        if action_log.has_edited_files_since_project_diagnostics_check() {
 969            content.push(
 970                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 971                and fix all errors AND warnings you introduced! \
 972                DO NOT mention you're going to do this until you're done."
 973                    .into(),
 974            );
 975        }
 976
 977        if !content.is_empty() {
 978            let context_message = LanguageModelRequestMessage {
 979                role: Role::User,
 980                content,
 981                cache: false,
 982            };
 983
 984            messages.push(context_message);
 985        }
 986    }
 987
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// The spawned task is recorded in `pending_completions` under a fresh id. As events
    /// arrive, the thread is updated:
    /// - `StartMessage` inserts an empty assistant message;
    /// - text/thinking chunks are appended to the last message if it is an assistant message,
    ///   otherwise a new assistant message is inserted;
    /// - tool uses are recorded against the most recent assistant message (one is created if
    ///   none exists);
    /// - usage updates adjust `cumulative_token_usage`.
    ///
    /// When the stream ends cleanly, this completion's entry is removed from
    /// `pending_completions` and a title summary is requested if there is none yet and the
    /// thread has at least two messages. Afterwards, on both success and error, the pending
    /// checkpoint is finalized. On `StopReason::ToolUse`, pending tools are started and
    /// `UsePendingTools` is emitted. On error, a `ShowError` event is emitted and the last
    /// completion is canceled. `ThreadEvent::Stopped` is always emitted with the result. A
    /// telemetry event with this completion's token usage is sent when the initial usage
    /// could be read.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        // Identifies this completion in `pending_completions` so it can be removed when done.
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Snapshot usage so the telemetry event below reports only this completion's tokens.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Each update is treated as this completion's running total:
                                // swap the previously added total for the new one.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                // NOTE(review): if the thread has no messages at all, the chunk
                                // is dropped.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                // NOTE(review): as with text, the chunk is dropped if the thread
                                // has no messages at all.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent assistant message,
                                // inserting an empty one if the thread has none.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    smol::future::yield_now().await;
                }

                // The stream ended without error: drop our pending entry and request a title
                // once there is enough conversation to summarize.
                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // NOTE(review): on error this completion's entry was not removed by
                            // the `retain` above; `cancel_last_completion` pops the *last*
                            // pending entry, which is assumed to be this one.
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1174
1175    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1176        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1177            return;
1178        };
1179
1180        if !model.provider.is_authenticated(cx) {
1181            return;
1182        }
1183
1184        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1185        request.messages.push(LanguageModelRequestMessage {
1186            role: Role::User,
1187            content: vec![
1188                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1189                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1190                 If the conversation is about a specific subject, include it in the title. \
1191                 Be descriptive. DO NOT speak in the first person."
1192                    .into(),
1193            ],
1194            cache: false,
1195        });
1196
1197        self.pending_summary = cx.spawn(async move |this, cx| {
1198            async move {
1199                let stream = model.model.stream_completion_text(request, &cx);
1200                let mut messages = stream.await?;
1201
1202                let mut new_summary = String::new();
1203                while let Some(message) = messages.stream.next().await {
1204                    let text = message?;
1205                    let mut lines = text.lines();
1206                    new_summary.extend(lines.next());
1207
1208                    // Stop if the LLM generated multiple lines.
1209                    if lines.next().is_some() {
1210                        break;
1211                    }
1212                }
1213
1214                this.update(cx, |this, cx| {
1215                    if !new_summary.is_empty() {
1216                        this.summary = Some(new_summary.into());
1217                    }
1218
1219                    cx.emit(ThreadEvent::SummaryGenerated);
1220                })?;
1221
1222                anyhow::Ok(())
1223            }
1224            .log_err()
1225            .await
1226        });
1227    }
1228
1229    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1230        let last_message_id = self.messages.last().map(|message| message.id)?;
1231
1232        match &self.detailed_summary_state {
1233            DetailedSummaryState::Generating { message_id, .. }
1234            | DetailedSummaryState::Generated { message_id, .. }
1235                if *message_id == last_message_id =>
1236            {
1237                // Already up-to-date
1238                return None;
1239            }
1240            _ => {}
1241        }
1242
1243        let ConfiguredModel { model, provider } =
1244            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1245
1246        if !provider.is_authenticated(cx) {
1247            return None;
1248        }
1249
1250        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1251
1252        request.messages.push(LanguageModelRequestMessage {
1253            role: Role::User,
1254            content: vec![
1255                "Generate a detailed summary of this conversation. Include:\n\
1256                1. A brief overview of what was discussed\n\
1257                2. Key facts or information discovered\n\
1258                3. Outcomes or conclusions reached\n\
1259                4. Any action items or next steps if any\n\
1260                Format it in Markdown with headings and bullet points."
1261                    .into(),
1262            ],
1263            cache: false,
1264        });
1265
1266        let task = cx.spawn(async move |thread, cx| {
1267            let stream = model.stream_completion_text(request, &cx);
1268            let Some(mut messages) = stream.await.log_err() else {
1269                thread
1270                    .update(cx, |this, _cx| {
1271                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1272                    })
1273                    .log_err();
1274
1275                return;
1276            };
1277
1278            let mut new_detailed_summary = String::new();
1279
1280            while let Some(chunk) = messages.stream.next().await {
1281                if let Some(chunk) = chunk.log_err() {
1282                    new_detailed_summary.push_str(&chunk);
1283                }
1284            }
1285
1286            thread
1287                .update(cx, |this, _cx| {
1288                    this.detailed_summary_state = DetailedSummaryState::Generated {
1289                        text: new_detailed_summary.into(),
1290                        message_id: last_message_id,
1291                    };
1292                })
1293                .log_err();
1294        });
1295
1296        self.detailed_summary_state = DetailedSummaryState::Generating {
1297            message_id: last_message_id,
1298        };
1299
1300        Some(task)
1301    }
1302
1303    pub fn is_generating_detailed_summary(&self) -> bool {
1304        matches!(
1305            self.detailed_summary_state,
1306            DetailedSummaryState::Generating { .. }
1307        )
1308    }
1309
1310    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1311        self.auto_capture_telemetry(cx);
1312        let request = self.to_completion_request(RequestKind::Chat, cx);
1313        let messages = Arc::new(request.messages);
1314        let pending_tool_uses = self
1315            .tool_use
1316            .pending_tool_uses()
1317            .into_iter()
1318            .filter(|tool_use| tool_use.status.is_idle())
1319            .cloned()
1320            .collect::<Vec<_>>();
1321
1322        for tool_use in pending_tool_uses.iter() {
1323            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1324                if tool.needs_confirmation(&tool_use.input, cx)
1325                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1326                {
1327                    self.tool_use.confirm_tool_use(
1328                        tool_use.id.clone(),
1329                        tool_use.ui_text.clone(),
1330                        tool_use.input.clone(),
1331                        messages.clone(),
1332                        tool,
1333                    );
1334                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1335                } else {
1336                    self.run_tool(
1337                        tool_use.id.clone(),
1338                        tool_use.ui_text.clone(),
1339                        tool_use.input.clone(),
1340                        &messages,
1341                        tool,
1342                        cx,
1343                    );
1344                }
1345            }
1346        }
1347
1348        pending_tool_uses
1349    }
1350
1351    pub fn run_tool(
1352        &mut self,
1353        tool_use_id: LanguageModelToolUseId,
1354        ui_text: impl Into<SharedString>,
1355        input: serde_json::Value,
1356        messages: &[LanguageModelRequestMessage],
1357        tool: Arc<dyn Tool>,
1358        cx: &mut Context<Thread>,
1359    ) {
1360        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1361        self.tool_use
1362            .run_pending_tool(tool_use_id, ui_text.into(), task);
1363    }
1364
1365    fn spawn_tool_use(
1366        &mut self,
1367        tool_use_id: LanguageModelToolUseId,
1368        messages: &[LanguageModelRequestMessage],
1369        input: serde_json::Value,
1370        tool: Arc<dyn Tool>,
1371        cx: &mut Context<Thread>,
1372    ) -> Task<()> {
1373        let tool_name: Arc<str> = tool.name().into();
1374
1375        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1376            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1377        } else {
1378            tool.run(
1379                input,
1380                messages,
1381                self.project.clone(),
1382                self.action_log.clone(),
1383                cx,
1384            )
1385        };
1386
1387        cx.spawn({
1388            async move |thread: WeakEntity<Thread>, cx| {
1389                let output = run_tool.await;
1390
1391                thread
1392                    .update(cx, |thread, cx| {
1393                        let pending_tool_use = thread.tool_use.insert_tool_output(
1394                            tool_use_id.clone(),
1395                            tool_name,
1396                            output,
1397                            cx,
1398                        );
1399                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1400                    })
1401                    .ok();
1402            }
1403        })
1404    }
1405
1406    fn tool_finished(
1407        &mut self,
1408        tool_use_id: LanguageModelToolUseId,
1409        pending_tool_use: Option<PendingToolUse>,
1410        canceled: bool,
1411        cx: &mut Context<Self>,
1412    ) {
1413        if self.all_tools_finished() {
1414            let model_registry = LanguageModelRegistry::read_global(cx);
1415            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1416                self.attach_tool_results(cx);
1417                if !canceled {
1418                    self.send_to_model(model, RequestKind::Chat, cx);
1419                }
1420            }
1421        }
1422
1423        cx.emit(ThreadEvent::ToolFinished {
1424            tool_use_id,
1425            pending_tool_use,
1426        });
1427    }
1428
1429    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1430        // Insert a user message to contain the tool results.
1431        self.insert_user_message(
1432            // TODO: Sending up a user message without any content results in the model sending back
1433            // responses that also don't have any content. We currently don't handle this case well,
1434            // so for now we provide some text to keep the model on track.
1435            "Here are the tool results.",
1436            Vec::new(),
1437            None,
1438            cx,
1439        );
1440    }
1441
1442    /// Cancels the last pending completion, if there are any pending.
1443    ///
1444    /// Returns whether a completion was canceled.
1445    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1446        let canceled = if self.pending_completions.pop().is_some() {
1447            true
1448        } else {
1449            let mut canceled = false;
1450            for pending_tool_use in self.tool_use.cancel_pending() {
1451                canceled = true;
1452                self.tool_finished(
1453                    pending_tool_use.id.clone(),
1454                    Some(pending_tool_use),
1455                    true,
1456                    cx,
1457                );
1458            }
1459            canceled
1460        };
1461        self.finalize_pending_checkpoint(cx);
1462        canceled
1463    }
1464
    /// Returns the thread-level feedback stored on this thread, if any.
    ///
    /// NOTE(review): `report_feedback` only stores thread-level feedback when the thread has
    /// no assistant message. Otherwise the feedback is recorded per message instead (see
    /// `message_feedback`).
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1468
    /// Returns the feedback recorded for the message `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1472
1473    pub fn report_message_feedback(
1474        &mut self,
1475        message_id: MessageId,
1476        feedback: ThreadFeedback,
1477        cx: &mut Context<Self>,
1478    ) -> Task<Result<()>> {
1479        if self.message_feedback.get(&message_id) == Some(&feedback) {
1480            return Task::ready(Ok(()));
1481        }
1482
1483        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1484        let serialized_thread = self.serialize(cx);
1485        let thread_id = self.id().clone();
1486        let client = self.project.read(cx).client();
1487
1488        let enabled_tool_names: Vec<String> = self
1489            .tools()
1490            .enabled_tools(cx)
1491            .iter()
1492            .map(|tool| tool.name().to_string())
1493            .collect();
1494
1495        self.message_feedback.insert(message_id, feedback);
1496
1497        cx.notify();
1498
1499        let message_content = self
1500            .message(message_id)
1501            .map(|msg| msg.to_string())
1502            .unwrap_or_default();
1503
1504        cx.background_spawn(async move {
1505            let final_project_snapshot = final_project_snapshot.await;
1506            let serialized_thread = serialized_thread.await?;
1507            let thread_data =
1508                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1509
1510            let rating = match feedback {
1511                ThreadFeedback::Positive => "positive",
1512                ThreadFeedback::Negative => "negative",
1513            };
1514            telemetry::event!(
1515                "Assistant Thread Rated",
1516                rating,
1517                thread_id,
1518                enabled_tool_names,
1519                message_id = message_id.0,
1520                message_content,
1521                thread_data,
1522                final_project_snapshot
1523            );
1524            client.telemetry().flush_events();
1525
1526            Ok(())
1527        })
1528    }
1529
1530    pub fn report_feedback(
1531        &mut self,
1532        feedback: ThreadFeedback,
1533        cx: &mut Context<Self>,
1534    ) -> Task<Result<()>> {
1535        let last_assistant_message_id = self
1536            .messages
1537            .iter()
1538            .rev()
1539            .find(|msg| msg.role == Role::Assistant)
1540            .map(|msg| msg.id);
1541
1542        if let Some(message_id) = last_assistant_message_id {
1543            self.report_message_feedback(message_id, feedback, cx)
1544        } else {
1545            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1546            let serialized_thread = self.serialize(cx);
1547            let thread_id = self.id().clone();
1548            let client = self.project.read(cx).client();
1549            self.feedback = Some(feedback);
1550            cx.notify();
1551
1552            cx.background_spawn(async move {
1553                let final_project_snapshot = final_project_snapshot.await;
1554                let serialized_thread = serialized_thread.await?;
1555                let thread_data = serde_json::to_value(serialized_thread)
1556                    .unwrap_or_else(|_| serde_json::Value::Null);
1557
1558                let rating = match feedback {
1559                    ThreadFeedback::Positive => "positive",
1560                    ThreadFeedback::Negative => "negative",
1561                };
1562                telemetry::event!(
1563                    "Assistant Thread Rated",
1564                    rating,
1565                    thread_id,
1566                    thread_data,
1567                    final_project_snapshot
1568                );
1569                client.telemetry().flush_events();
1570
1571                Ok(())
1572            })
1573        }
1574    }
1575
1576    /// Create a snapshot of the current project state including git information and unsaved buffers.
1577    fn project_snapshot(
1578        project: Entity<Project>,
1579        cx: &mut Context<Self>,
1580    ) -> Task<Arc<ProjectSnapshot>> {
1581        let git_store = project.read(cx).git_store().clone();
1582        let worktree_snapshots: Vec<_> = project
1583            .read(cx)
1584            .visible_worktrees(cx)
1585            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1586            .collect();
1587
1588        cx.spawn(async move |_, cx| {
1589            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1590
1591            let mut unsaved_buffers = Vec::new();
1592            cx.update(|app_cx| {
1593                let buffer_store = project.read(app_cx).buffer_store();
1594                for buffer_handle in buffer_store.read(app_cx).buffers() {
1595                    let buffer = buffer_handle.read(app_cx);
1596                    if buffer.is_dirty() {
1597                        if let Some(file) = buffer.file() {
1598                            let path = file.path().to_string_lossy().to_string();
1599                            unsaved_buffers.push(path);
1600                        }
1601                    }
1602                }
1603            })
1604            .ok();
1605
1606            Arc::new(ProjectSnapshot {
1607                worktree_snapshots,
1608                unsaved_buffer_paths: unsaved_buffers,
1609                timestamp: Utc::now(),
1610            })
1611        })
1612    }
1613
1614    fn worktree_snapshot(
1615        worktree: Entity<project::Worktree>,
1616        git_store: Entity<GitStore>,
1617        cx: &App,
1618    ) -> Task<WorktreeSnapshot> {
1619        cx.spawn(async move |cx| {
1620            // Get worktree path and snapshot
1621            let worktree_info = cx.update(|app_cx| {
1622                let worktree = worktree.read(app_cx);
1623                let path = worktree.abs_path().to_string_lossy().to_string();
1624                let snapshot = worktree.snapshot();
1625                (path, snapshot)
1626            });
1627
1628            let Ok((worktree_path, _snapshot)) = worktree_info else {
1629                return WorktreeSnapshot {
1630                    worktree_path: String::new(),
1631                    git_state: None,
1632                };
1633            };
1634
1635            let git_state = git_store
1636                .update(cx, |git_store, cx| {
1637                    git_store
1638                        .repositories()
1639                        .values()
1640                        .find(|repo| {
1641                            repo.read(cx)
1642                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1643                                .is_some()
1644                        })
1645                        .cloned()
1646                })
1647                .ok()
1648                .flatten()
1649                .map(|repo| {
1650                    repo.update(cx, |repo, _| {
1651                        let current_branch =
1652                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1653                        repo.send_job(None, |state, _| async move {
1654                            let RepositoryState::Local { backend, .. } = state else {
1655                                return GitState {
1656                                    remote_url: None,
1657                                    head_sha: None,
1658                                    current_branch,
1659                                    diff: None,
1660                                };
1661                            };
1662
1663                            let remote_url = backend.remote_url("origin");
1664                            let head_sha = backend.head_sha();
1665                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1666
1667                            GitState {
1668                                remote_url,
1669                                head_sha,
1670                                current_branch,
1671                                diff,
1672                            }
1673                        })
1674                    })
1675                });
1676
1677            let git_state = match git_state {
1678                Some(git_state) => match git_state.ok() {
1679                    Some(git_state) => git_state.await.ok(),
1680                    None => None,
1681                },
1682                None => None,
1683            };
1684
1685            WorktreeSnapshot {
1686                worktree_path,
1687                git_state,
1688            }
1689        })
1690    }
1691
1692    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1693        let mut markdown = Vec::new();
1694
1695        if let Some(summary) = self.summary() {
1696            writeln!(markdown, "# {summary}\n")?;
1697        };
1698
1699        for message in self.messages() {
1700            writeln!(
1701                markdown,
1702                "## {role}\n",
1703                role = match message.role {
1704                    Role::User => "User",
1705                    Role::Assistant => "Assistant",
1706                    Role::System => "System",
1707                }
1708            )?;
1709
1710            if !message.context.is_empty() {
1711                writeln!(markdown, "{}", message.context)?;
1712            }
1713
1714            for segment in &message.segments {
1715                match segment {
1716                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1717                    MessageSegment::Thinking(text) => {
1718                        writeln!(markdown, "<think>{}</think>\n", text)?
1719                    }
1720                }
1721            }
1722
1723            for tool_use in self.tool_uses_for_message(message.id, cx) {
1724                writeln!(
1725                    markdown,
1726                    "**Use Tool: {} ({})**",
1727                    tool_use.name, tool_use.id
1728                )?;
1729                writeln!(markdown, "```json")?;
1730                writeln!(
1731                    markdown,
1732                    "{}",
1733                    serde_json::to_string_pretty(&tool_use.input)?
1734                )?;
1735                writeln!(markdown, "```")?;
1736            }
1737
1738            for tool_result in self.tool_results_for_message(message.id) {
1739                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1740                if tool_result.is_error {
1741                    write!(markdown, " (Error)")?;
1742                }
1743
1744                writeln!(markdown, "**\n")?;
1745                writeln!(markdown, "{}", tool_result.content)?;
1746            }
1747        }
1748
1749        Ok(String::from_utf8_lossy(&markdown).to_string())
1750    }
1751
1752    pub fn keep_edits_in_range(
1753        &mut self,
1754        buffer: Entity<language::Buffer>,
1755        buffer_range: Range<language::Anchor>,
1756        cx: &mut Context<Self>,
1757    ) {
1758        self.action_log.update(cx, |action_log, cx| {
1759            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1760        });
1761    }
1762
1763    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1764        self.action_log
1765            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1766    }
1767
1768    pub fn reject_edits_in_range(
1769        &mut self,
1770        buffer: Entity<language::Buffer>,
1771        buffer_range: Range<language::Anchor>,
1772        cx: &mut Context<Self>,
1773    ) -> Task<Result<()>> {
1774        self.action_log.update(cx, |action_log, cx| {
1775            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1776        })
1777    }
1778
1779    pub fn action_log(&self) -> &Entity<ActionLog> {
1780        &self.action_log
1781    }
1782
1783    pub fn project(&self) -> &Entity<Project> {
1784        &self.project
1785    }
1786
1787    pub fn cumulative_token_usage(&self) -> TokenUsage {
1788        self.cumulative_token_usage.clone()
1789    }
1790
1791    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1792        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1793            return;
1794        }
1795
1796        let now = Instant::now();
1797        if let Some(last) = self.last_auto_capture_at {
1798            if now.duration_since(last).as_secs() < 10 {
1799                return;
1800            }
1801        }
1802
1803        self.last_auto_capture_at = Some(now);
1804
1805        let thread_id = self.id().clone();
1806        let github_login = self
1807            .project
1808            .read(cx)
1809            .user_store()
1810            .read(cx)
1811            .current_user()
1812            .map(|user| user.github_login.clone());
1813        let client = self.project.read(cx).client().clone();
1814        let serialize_task = self.serialize(cx);
1815
1816        cx.background_executor()
1817            .spawn(async move {
1818                if let Ok(serialized_thread) = serialize_task.await {
1819                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1820                        telemetry::event!(
1821                            "Agent Thread Auto-Captured",
1822                            thread_id = thread_id.to_string(),
1823                            thread_data = thread_data,
1824                            auto_capture_reason = "tracked_user",
1825                            github_login = github_login
1826                        );
1827
1828                        client.telemetry().flush_events();
1829                    }
1830                }
1831            })
1832            .detach();
1833    }
1834
1835    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1836        let model_registry = LanguageModelRegistry::read_global(cx);
1837        let Some(model) = model_registry.default_model() else {
1838            return TotalTokenUsage::default();
1839        };
1840
1841        let max = model.model.max_token_count();
1842
1843        #[cfg(debug_assertions)]
1844        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1845            .unwrap_or("0.8".to_string())
1846            .parse()
1847            .unwrap();
1848        #[cfg(not(debug_assertions))]
1849        let warning_threshold: f32 = 0.8;
1850
1851        let total = self.cumulative_token_usage.total_tokens() as usize;
1852
1853        let ratio = if total >= max {
1854            TokenUsageRatio::Exceeded
1855        } else if total as f32 / max as f32 >= warning_threshold {
1856            TokenUsageRatio::Warning
1857        } else {
1858            TokenUsageRatio::Normal
1859        };
1860
1861        TotalTokenUsage { total, max, ratio }
1862    }
1863
1864    pub fn deny_tool_use(
1865        &mut self,
1866        tool_use_id: LanguageModelToolUseId,
1867        tool_name: Arc<str>,
1868        cx: &mut Context<Self>,
1869    ) {
1870        let err = Err(anyhow::anyhow!(
1871            "Permission to run tool action denied by user"
1872        ));
1873
1874        self.tool_use
1875            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1876        self.tool_finished(tool_use_id.clone(), None, true, cx);
1877    }
1878}
1879
1880#[derive(Debug, Clone, Error)]
1881pub enum ThreadError {
1882    #[error("Payment required")]
1883    PaymentRequired,
1884    #[error("Max monthly spend reached")]
1885    MaxMonthlySpendReached,
1886    #[error("Message {header}: {message}")]
1887    Message {
1888        header: SharedString,
1889        message: SharedString,
1890    },
1891}
1892
/// Events emitted by [`Thread`] through its `EventEmitter<ThreadEvent>` implementation.
///
/// NOTE(review): the per-variant notes below are inferred from variant names and payload
/// types; confirm against the emit sites.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`] to be shown to the user.
    ShowError(ThreadError),
    /// Presumably signals progress of a streaming completion; confirm.
    StreamedCompletion,
    /// Assistant text streamed for the message with the given id (presumably the new chunk).
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed for the message with the given id.
    StreamedAssistantThinking(MessageId, String),
    /// A completion stopped, with either the model's stop reason or the error that ended it.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message with the given id was added to the thread.
    MessageAdded(MessageId),
    /// The message with the given id was edited.
    MessageEdited(MessageId),
    /// The message with the given id was deleted.
    MessageDeleted(MessageId),
    /// Presumably emitted when a thread summary has been produced; confirm.
    SummaryGenerated,
    /// Presumably emitted whenever the thread summary changes; confirm.
    SummaryChanged,
    /// The listed pending tool uses should be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use has finished.
    ToolFinished {
        // The allow silences the unused-field warning.
        // NOTE(review): this field appears not to be read by consumers; confirm whether it is
        // still needed.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// Presumably emitted when the thread's checkpoint state changes; confirm.
    CheckpointChanged,
    /// A tool use is waiting for user confirmation before it can run.
    ToolConfirmationNeeded,
}
1917
// Declares that `Thread` entities emit `ThreadEvent`s.
impl EventEmitter<ThreadEvent> for Thread {}
1919
/// Bookkeeping for a completion request that is presumably still in flight.
struct PendingCompletion {
    // NOTE(review): presumably an identifier used to tell successive completions apart; confirm.
    id: usize,
    // Never read (hence the leading underscore). It is held so the task lives as long as this
    // struct. NOTE(review): dropping a gpui `Task` is expected to cancel it; confirm this is
    // relied upon.
    _task: Task<()>,
}
1924
#[cfg(test)]
mod tests {
    // Tests for how user-attached context, and buffers changed after being attached, are
    // reflected in messages and in the completion request built from a thread.
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A file attached as context is rendered into the message's `context` string and
    /// prepended to the message text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // One leading request message (presumably the system prompt; confirm) plus the user
        // message, which is the context followed immediately by the text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message is not repeated in later messages.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request holds one leading message (presumably the system prompt; confirm)
        // followed by the 3 user messages, hence 4 in total.
        assert_eq!(request.messages.len(), 4);

        // Each user message should mention only the file whose context it newly introduced
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    /// Messages inserted without context have an empty `context` string and appear in the
    /// request as their plain text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// Editing a buffer after it was attached as context makes the next request end with a
    /// user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 (inside `fn` on the first line). The exact position is
            // irrelevant: the test only needs the buffer to change after it was added as context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` and initializes/registers the settings these tests
    /// depend on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    /// Creates a project over a fake filesystem with `files` inserted under `/test`.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Opens a test workspace for `project`, loads a `ThreadStore`, creates a new thread
    /// in it, and creates an empty `ContextStore` for the project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    Arc::default(),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens the buffer at `path` (a project-relative path such as `"test/code.rs"`), adds
    /// it to `context_store`, and returns the buffer.
    ///
    /// Panics if the path cannot be resolved or opened. Only the context-store step's
    /// failure is returned as an error.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}