thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use agent_rules::load_worktree_rules_file;
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use fs::Fs;
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
  19    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  20    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  21    PaymentRequiredError, Role, StopReason, TokenUsage,
  22};
  23use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  24use project::{Project, Worktree};
  25use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
  26use schemars::JsonSchema;
  27use serde::{Deserialize, Serialize};
  28use settings::Settings;
  29use util::{ResultExt as _, TryFutureExt as _, post_inc};
  30use uuid::Uuid;
  31
  32use crate::context::{AssistantContext, ContextId, format_context_as_string};
  33use crate::thread_store::{
  34    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  35    SerializedToolUse,
  36};
  37use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  38
  39#[derive(Debug, Clone, Copy)]
  40pub enum RequestKind {
  41    Chat,
  42    /// Used when summarizing a thread.
  43    Summarize,
  44}
  45
  46#[derive(
  47    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  48)]
  49pub struct ThreadId(Arc<str>);
  50
  51impl ThreadId {
  52    pub fn new() -> Self {
  53        Self(Uuid::new_v4().to_string().into())
  54    }
  55}
  56
  57impl std::fmt::Display for ThreadId {
  58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  59        write!(f, "{}", self.0)
  60    }
  61}
  62
  63impl From<&str> for ThreadId {
  64    fn from(value: &str) -> Self {
  65        Self(value.into())
  66    }
  67}
  68
  69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  70pub struct MessageId(pub(crate) usize);
  71
  72impl MessageId {
  73    fn post_inc(&mut self) -> Self {
  74        Self(post_inc(&mut self.0))
  75    }
  76}
  77
  78/// A message in a [`Thread`].
  79#[derive(Debug, Clone)]
  80pub struct Message {
  81    pub id: MessageId,
  82    pub role: Role,
  83    pub segments: Vec<MessageSegment>,
  84    pub context: String,
  85}
  86
  87impl Message {
  88    /// Returns whether the message contains any meaningful text that should be displayed
  89    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  90    pub fn should_display_content(&self) -> bool {
  91        self.segments.iter().all(|segment| segment.should_display())
  92    }
  93
  94    pub fn push_thinking(&mut self, text: &str) {
  95        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  96            segment.push_str(text);
  97        } else {
  98            self.segments
  99                .push(MessageSegment::Thinking(text.to_string()));
 100        }
 101    }
 102
 103    pub fn push_text(&mut self, text: &str) {
 104        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 105            segment.push_str(text);
 106        } else {
 107            self.segments.push(MessageSegment::Text(text.to_string()));
 108        }
 109    }
 110
 111    pub fn to_string(&self) -> String {
 112        let mut result = String::new();
 113
 114        if !self.context.is_empty() {
 115            result.push_str(&self.context);
 116        }
 117
 118        for segment in &self.segments {
 119            match segment {
 120                MessageSegment::Text(text) => result.push_str(text),
 121                MessageSegment::Thinking(text) => {
 122                    result.push_str("<think>");
 123                    result.push_str(text);
 124                    result.push_str("</think>");
 125                }
 126            }
 127        }
 128
 129        result
 130    }
 131}
 132
 133#[derive(Debug, Clone, PartialEq, Eq)]
 134pub enum MessageSegment {
 135    Text(String),
 136    Thinking(String),
 137}
 138
 139impl MessageSegment {
 140    pub fn text_mut(&mut self) -> &mut String {
 141        match self {
 142            Self::Text(text) => text,
 143            Self::Thinking(text) => text,
 144        }
 145    }
 146
 147    pub fn should_display(&self) -> bool {
 148        // We add USING_TOOL_MARKER when making a request that includes tool uses
 149        // without non-whitespace text around them, and this can cause the model
 150        // to mimic the pattern, so we consider those segments not displayable.
 151        match self {
 152            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 153            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 154        }
 155    }
 156}
 157
 158#[derive(Debug, Clone, Serialize, Deserialize)]
 159pub struct ProjectSnapshot {
 160    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 161    pub unsaved_buffer_paths: Vec<String>,
 162    pub timestamp: DateTime<Utc>,
 163}
 164
 165#[derive(Debug, Clone, Serialize, Deserialize)]
 166pub struct WorktreeSnapshot {
 167    pub worktree_path: String,
 168    pub git_state: Option<GitState>,
 169}
 170
 171#[derive(Debug, Clone, Serialize, Deserialize)]
 172pub struct GitState {
 173    pub remote_url: Option<String>,
 174    pub head_sha: Option<String>,
 175    pub current_branch: Option<String>,
 176    pub diff: Option<String>,
 177}
 178
 179#[derive(Clone)]
 180pub struct ThreadCheckpoint {
 181    message_id: MessageId,
 182    git_checkpoint: GitStoreCheckpoint,
 183}
 184
 185#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 186pub enum ThreadFeedback {
 187    Positive,
 188    Negative,
 189}
 190
 191pub enum LastRestoreCheckpoint {
 192    Pending {
 193        message_id: MessageId,
 194    },
 195    Error {
 196        message_id: MessageId,
 197        error: String,
 198    },
 199}
 200
 201impl LastRestoreCheckpoint {
 202    pub fn message_id(&self) -> MessageId {
 203        match self {
 204            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 205            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 206        }
 207    }
 208}
 209
 210#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 211pub enum DetailedSummaryState {
 212    #[default]
 213    NotGenerated,
 214    Generating {
 215        message_id: MessageId,
 216    },
 217    Generated {
 218        text: SharedString,
 219        message_id: MessageId,
 220    },
 221}
 222
 223#[derive(Default)]
 224pub struct TotalTokenUsage {
 225    pub total: usize,
 226    pub max: usize,
 227    pub ratio: TokenUsageRatio,
 228}
 229
 230#[derive(Default, PartialEq, Eq)]
 231pub enum TokenUsageRatio {
 232    #[default]
 233    Normal,
 234    Warning,
 235    Exceeded,
 236}
 237
 238/// A thread of conversation with the LLM.
 239pub struct Thread {
 240    id: ThreadId,
 241    updated_at: DateTime<Utc>,
 242    summary: Option<SharedString>,
 243    pending_summary: Task<Option<()>>,
 244    detailed_summary_state: DetailedSummaryState,
 245    messages: Vec<Message>,
 246    next_message_id: MessageId,
 247    context: BTreeMap<ContextId, AssistantContext>,
 248    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 249    system_prompt_context: Option<AssistantSystemPromptContext>,
 250    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 251    completion_count: usize,
 252    pending_completions: Vec<PendingCompletion>,
 253    project: Entity<Project>,
 254    prompt_builder: Arc<PromptBuilder>,
 255    tools: Arc<ToolWorkingSet>,
 256    tool_use: ToolUseState,
 257    action_log: Entity<ActionLog>,
 258    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 259    pending_checkpoint: Option<ThreadCheckpoint>,
 260    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 261    cumulative_token_usage: TokenUsage,
 262    feedback: Option<ThreadFeedback>,
 263    message_feedback: HashMap<MessageId, ThreadFeedback>,
 264}
 265
 266impl Thread {
 267    pub fn new(
 268        project: Entity<Project>,
 269        tools: Arc<ToolWorkingSet>,
 270        prompt_builder: Arc<PromptBuilder>,
 271        cx: &mut Context<Self>,
 272    ) -> Self {
 273        Self {
 274            id: ThreadId::new(),
 275            updated_at: Utc::now(),
 276            summary: None,
 277            pending_summary: Task::ready(None),
 278            detailed_summary_state: DetailedSummaryState::NotGenerated,
 279            messages: Vec::new(),
 280            next_message_id: MessageId(0),
 281            context: BTreeMap::default(),
 282            context_by_message: HashMap::default(),
 283            system_prompt_context: None,
 284            checkpoints_by_message: HashMap::default(),
 285            completion_count: 0,
 286            pending_completions: Vec::new(),
 287            project: project.clone(),
 288            prompt_builder,
 289            tools: tools.clone(),
 290            last_restore_checkpoint: None,
 291            pending_checkpoint: None,
 292            tool_use: ToolUseState::new(tools.clone()),
 293            action_log: cx.new(|_| ActionLog::new(project.clone())),
 294            initial_project_snapshot: {
 295                let project_snapshot = Self::project_snapshot(project, cx);
 296                cx.foreground_executor()
 297                    .spawn(async move { Some(project_snapshot.await) })
 298                    .shared()
 299            },
 300            cumulative_token_usage: TokenUsage::default(),
 301            feedback: None,
 302            message_feedback: HashMap::default(),
 303        }
 304    }
 305
 306    pub fn deserialize(
 307        id: ThreadId,
 308        serialized: SerializedThread,
 309        project: Entity<Project>,
 310        tools: Arc<ToolWorkingSet>,
 311        prompt_builder: Arc<PromptBuilder>,
 312        cx: &mut Context<Self>,
 313    ) -> Self {
 314        let next_message_id = MessageId(
 315            serialized
 316                .messages
 317                .last()
 318                .map(|message| message.id.0 + 1)
 319                .unwrap_or(0),
 320        );
 321        let tool_use =
 322            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 323
 324        Self {
 325            id,
 326            updated_at: serialized.updated_at,
 327            summary: Some(serialized.summary),
 328            pending_summary: Task::ready(None),
 329            detailed_summary_state: serialized.detailed_summary_state,
 330            messages: serialized
 331                .messages
 332                .into_iter()
 333                .map(|message| Message {
 334                    id: message.id,
 335                    role: message.role,
 336                    segments: message
 337                        .segments
 338                        .into_iter()
 339                        .map(|segment| match segment {
 340                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 341                            SerializedMessageSegment::Thinking { text } => {
 342                                MessageSegment::Thinking(text)
 343                            }
 344                        })
 345                        .collect(),
 346                    context: message.context,
 347                })
 348                .collect(),
 349            next_message_id,
 350            context: BTreeMap::default(),
 351            context_by_message: HashMap::default(),
 352            system_prompt_context: None,
 353            checkpoints_by_message: HashMap::default(),
 354            completion_count: 0,
 355            pending_completions: Vec::new(),
 356            last_restore_checkpoint: None,
 357            pending_checkpoint: None,
 358            project: project.clone(),
 359            prompt_builder,
 360            tools,
 361            tool_use,
 362            action_log: cx.new(|_| ActionLog::new(project)),
 363            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 364            cumulative_token_usage: serialized.cumulative_token_usage,
 365            feedback: None,
 366            message_feedback: HashMap::default(),
 367        }
 368    }
 369
 370    pub fn id(&self) -> &ThreadId {
 371        &self.id
 372    }
 373
 374    pub fn is_empty(&self) -> bool {
 375        self.messages.is_empty()
 376    }
 377
 378    pub fn updated_at(&self) -> DateTime<Utc> {
 379        self.updated_at
 380    }
 381
 382    pub fn touch_updated_at(&mut self) {
 383        self.updated_at = Utc::now();
 384    }
 385
 386    pub fn summary(&self) -> Option<SharedString> {
 387        self.summary.clone()
 388    }
 389
 390    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 391
 392    pub fn summary_or_default(&self) -> SharedString {
 393        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 394    }
 395
 396    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 397        let Some(current_summary) = &self.summary else {
 398            // Don't allow setting summary until generated
 399            return;
 400        };
 401
 402        let mut new_summary = new_summary.into();
 403
 404        if new_summary.is_empty() {
 405            new_summary = Self::DEFAULT_SUMMARY;
 406        }
 407
 408        if current_summary != &new_summary {
 409            self.summary = Some(new_summary);
 410            cx.emit(ThreadEvent::SummaryChanged);
 411        }
 412    }
 413
 414    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 415        self.latest_detailed_summary()
 416            .unwrap_or_else(|| self.text().into())
 417    }
 418
 419    fn latest_detailed_summary(&self) -> Option<SharedString> {
 420        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 421            Some(text.clone())
 422        } else {
 423            None
 424        }
 425    }
 426
 427    pub fn message(&self, id: MessageId) -> Option<&Message> {
 428        self.messages.iter().find(|message| message.id == id)
 429    }
 430
 431    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 432        self.messages.iter()
 433    }
 434
 435    pub fn is_generating(&self) -> bool {
 436        !self.pending_completions.is_empty() || !self.all_tools_finished()
 437    }
 438
 439    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 440        &self.tools
 441    }
 442
 443    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 444        self.tool_use
 445            .pending_tool_uses()
 446            .into_iter()
 447            .find(|tool_use| &tool_use.id == id)
 448    }
 449
 450    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 451        self.tool_use
 452            .pending_tool_uses()
 453            .into_iter()
 454            .filter(|tool_use| tool_use.status.needs_confirmation())
 455    }
 456
 457    pub fn has_pending_tool_uses(&self) -> bool {
 458        !self.tool_use.pending_tool_uses().is_empty()
 459    }
 460
 461    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 462        self.checkpoints_by_message.get(&id).cloned()
 463    }
 464
 465    pub fn restore_checkpoint(
 466        &mut self,
 467        checkpoint: ThreadCheckpoint,
 468        cx: &mut Context<Self>,
 469    ) -> Task<Result<()>> {
 470        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 471            message_id: checkpoint.message_id,
 472        });
 473        cx.emit(ThreadEvent::CheckpointChanged);
 474        cx.notify();
 475
 476        let git_store = self.project().read(cx).git_store().clone();
 477        let restore = git_store.update(cx, |git_store, cx| {
 478            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 479        });
 480
 481        cx.spawn(async move |this, cx| {
 482            let result = restore.await;
 483            this.update(cx, |this, cx| {
 484                if let Err(err) = result.as_ref() {
 485                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 486                        message_id: checkpoint.message_id,
 487                        error: err.to_string(),
 488                    });
 489                } else {
 490                    this.truncate(checkpoint.message_id, cx);
 491                    this.last_restore_checkpoint = None;
 492                }
 493                this.pending_checkpoint = None;
 494                cx.emit(ThreadEvent::CheckpointChanged);
 495                cx.notify();
 496            })?;
 497            result
 498        })
 499    }
 500
 501    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 502        let pending_checkpoint = if self.is_generating() {
 503            return;
 504        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 505            checkpoint
 506        } else {
 507            return;
 508        };
 509
 510        let git_store = self.project.read(cx).git_store().clone();
 511        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 512        cx.spawn(async move |this, cx| match final_checkpoint.await {
 513            Ok(final_checkpoint) => {
 514                let equal = git_store
 515                    .update(cx, |store, cx| {
 516                        store.compare_checkpoints(
 517                            pending_checkpoint.git_checkpoint.clone(),
 518                            final_checkpoint.clone(),
 519                            cx,
 520                        )
 521                    })?
 522                    .await
 523                    .unwrap_or(false);
 524
 525                if equal {
 526                    git_store
 527                        .update(cx, |store, cx| {
 528                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 529                        })?
 530                        .detach();
 531                } else {
 532                    this.update(cx, |this, cx| {
 533                        this.insert_checkpoint(pending_checkpoint, cx)
 534                    })?;
 535                }
 536
 537                git_store
 538                    .update(cx, |store, cx| {
 539                        store.delete_checkpoint(final_checkpoint, cx)
 540                    })?
 541                    .detach();
 542
 543                Ok(())
 544            }
 545            Err(_) => this.update(cx, |this, cx| {
 546                this.insert_checkpoint(pending_checkpoint, cx)
 547            }),
 548        })
 549        .detach();
 550    }
 551
 552    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 553        self.checkpoints_by_message
 554            .insert(checkpoint.message_id, checkpoint);
 555        cx.emit(ThreadEvent::CheckpointChanged);
 556        cx.notify();
 557    }
 558
 559    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 560        self.last_restore_checkpoint.as_ref()
 561    }
 562
 563    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 564        let Some(message_ix) = self
 565            .messages
 566            .iter()
 567            .rposition(|message| message.id == message_id)
 568        else {
 569            return;
 570        };
 571        for deleted_message in self.messages.drain(message_ix..) {
 572            self.context_by_message.remove(&deleted_message.id);
 573            self.checkpoints_by_message.remove(&deleted_message.id);
 574        }
 575        cx.notify();
 576    }
 577
 578    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 579        self.context_by_message
 580            .get(&id)
 581            .into_iter()
 582            .flat_map(|context| {
 583                context
 584                    .iter()
 585                    .filter_map(|context_id| self.context.get(&context_id))
 586            })
 587    }
 588
 589    /// Returns whether all of the tool uses have finished running.
 590    pub fn all_tools_finished(&self) -> bool {
 591        // If the only pending tool uses left are the ones with errors, then
 592        // that means that we've finished running all of the pending tools.
 593        self.tool_use
 594            .pending_tool_uses()
 595            .iter()
 596            .all(|tool_use| tool_use.status.is_error())
 597    }
 598
 599    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 600        self.tool_use.tool_uses_for_message(id, cx)
 601    }
 602
 603    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 604        self.tool_use.tool_results_for_message(id)
 605    }
 606
 607    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 608        self.tool_use.tool_result(id)
 609    }
 610
 611    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 612        self.tool_use.message_has_tool_results(message_id)
 613    }
 614
 615    pub fn insert_user_message(
 616        &mut self,
 617        text: impl Into<String>,
 618        context: Vec<AssistantContext>,
 619        git_checkpoint: Option<GitStoreCheckpoint>,
 620        cx: &mut Context<Self>,
 621    ) -> MessageId {
 622        let text = text.into();
 623
 624        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 625
 626        // Filter out contexts that have already been included in previous messages
 627        let new_context: Vec<_> = context
 628            .into_iter()
 629            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 630            .collect();
 631
 632        if !new_context.is_empty() {
 633            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 634                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 635                    message.context = context_string;
 636                }
 637            }
 638
 639            self.action_log.update(cx, |log, cx| {
 640                // Track all buffers added as context
 641                for ctx in &new_context {
 642                    match ctx {
 643                        AssistantContext::File(file_ctx) => {
 644                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 645                        }
 646                        AssistantContext::Directory(dir_ctx) => {
 647                            for context_buffer in &dir_ctx.context_buffers {
 648                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 649                            }
 650                        }
 651                        AssistantContext::Symbol(symbol_ctx) => {
 652                            log.buffer_added_as_context(
 653                                symbol_ctx.context_symbol.buffer.clone(),
 654                                cx,
 655                            );
 656                        }
 657                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 658                    }
 659                }
 660            });
 661        }
 662
 663        let context_ids = new_context
 664            .iter()
 665            .map(|context| context.id())
 666            .collect::<Vec<_>>();
 667        self.context.extend(
 668            new_context
 669                .into_iter()
 670                .map(|context| (context.id(), context)),
 671        );
 672        self.context_by_message.insert(message_id, context_ids);
 673
 674        if let Some(git_checkpoint) = git_checkpoint {
 675            self.pending_checkpoint = Some(ThreadCheckpoint {
 676                message_id,
 677                git_checkpoint,
 678            });
 679        }
 680        message_id
 681    }
 682
 683    pub fn insert_message(
 684        &mut self,
 685        role: Role,
 686        segments: Vec<MessageSegment>,
 687        cx: &mut Context<Self>,
 688    ) -> MessageId {
 689        let id = self.next_message_id.post_inc();
 690        self.messages.push(Message {
 691            id,
 692            role,
 693            segments,
 694            context: String::new(),
 695        });
 696        self.touch_updated_at();
 697        cx.emit(ThreadEvent::MessageAdded(id));
 698        id
 699    }
 700
 701    pub fn edit_message(
 702        &mut self,
 703        id: MessageId,
 704        new_role: Role,
 705        new_segments: Vec<MessageSegment>,
 706        cx: &mut Context<Self>,
 707    ) -> bool {
 708        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 709            return false;
 710        };
 711        message.role = new_role;
 712        message.segments = new_segments;
 713        self.touch_updated_at();
 714        cx.emit(ThreadEvent::MessageEdited(id));
 715        true
 716    }
 717
 718    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 719        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 720            return false;
 721        };
 722        self.messages.remove(index);
 723        self.context_by_message.remove(&id);
 724        self.touch_updated_at();
 725        cx.emit(ThreadEvent::MessageDeleted(id));
 726        true
 727    }
 728
 729    /// Returns the representation of this [`Thread`] in a textual form.
 730    ///
 731    /// This is the representation we use when attaching a thread as context to another thread.
 732    pub fn text(&self) -> String {
 733        let mut text = String::new();
 734
 735        for message in &self.messages {
 736            text.push_str(match message.role {
 737                language_model::Role::User => "User:",
 738                language_model::Role::Assistant => "Assistant:",
 739                language_model::Role::System => "System:",
 740            });
 741            text.push('\n');
 742
 743            for segment in &message.segments {
 744                match segment {
 745                    MessageSegment::Text(content) => text.push_str(content),
 746                    MessageSegment::Thinking(content) => {
 747                        text.push_str(&format!("<think>{}</think>", content))
 748                    }
 749                }
 750            }
 751            text.push('\n');
 752        }
 753
 754        text
 755    }
 756
 757    /// Serializes this thread into a format for storage or telemetry.
 758    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 759        let initial_project_snapshot = self.initial_project_snapshot.clone();
 760        cx.spawn(async move |this, cx| {
 761            let initial_project_snapshot = initial_project_snapshot.await;
 762            this.read_with(cx, |this, cx| SerializedThread {
 763                version: SerializedThread::VERSION.to_string(),
 764                summary: this.summary_or_default(),
 765                updated_at: this.updated_at(),
 766                messages: this
 767                    .messages()
 768                    .map(|message| SerializedMessage {
 769                        id: message.id,
 770                        role: message.role,
 771                        segments: message
 772                            .segments
 773                            .iter()
 774                            .map(|segment| match segment {
 775                                MessageSegment::Text(text) => {
 776                                    SerializedMessageSegment::Text { text: text.clone() }
 777                                }
 778                                MessageSegment::Thinking(text) => {
 779                                    SerializedMessageSegment::Thinking { text: text.clone() }
 780                                }
 781                            })
 782                            .collect(),
 783                        tool_uses: this
 784                            .tool_uses_for_message(message.id, cx)
 785                            .into_iter()
 786                            .map(|tool_use| SerializedToolUse {
 787                                id: tool_use.id,
 788                                name: tool_use.name,
 789                                input: tool_use.input,
 790                            })
 791                            .collect(),
 792                        tool_results: this
 793                            .tool_results_for_message(message.id)
 794                            .into_iter()
 795                            .map(|tool_result| SerializedToolResult {
 796                                tool_use_id: tool_result.tool_use_id.clone(),
 797                                is_error: tool_result.is_error,
 798                                content: tool_result.content.clone(),
 799                            })
 800                            .collect(),
 801                        context: message.context.clone(),
 802                    })
 803                    .collect(),
 804                initial_project_snapshot,
 805                cumulative_token_usage: this.cumulative_token_usage.clone(),
 806                detailed_summary_state: this.detailed_summary_state.clone(),
 807            })
 808        })
 809    }
 810
 811    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 812        self.system_prompt_context = Some(context);
 813    }
 814
 815    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 816        &self.system_prompt_context
 817    }
 818
 819    pub fn load_system_prompt_context(
 820        &self,
 821        cx: &App,
 822    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 823        let project = self.project.read(cx);
 824        let tasks = project
 825            .visible_worktrees(cx)
 826            .map(|worktree| {
 827                Self::load_worktree_info_for_system_prompt(
 828                    project.fs().clone(),
 829                    worktree.read(cx),
 830                    cx,
 831                )
 832            })
 833            .collect::<Vec<_>>();
 834
 835        cx.spawn(async |_cx| {
 836            let results = futures::future::join_all(tasks).await;
 837            let mut first_err = None;
 838            let worktrees = results
 839                .into_iter()
 840                .map(|(worktree, err)| {
 841                    if first_err.is_none() && err.is_some() {
 842                        first_err = err;
 843                    }
 844                    worktree
 845                })
 846                .collect::<Vec<_>>();
 847            (AssistantSystemPromptContext::new(worktrees), first_err)
 848        })
 849    }
 850
 851    fn load_worktree_info_for_system_prompt(
 852        fs: Arc<dyn Fs>,
 853        worktree: &Worktree,
 854        cx: &App,
 855    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 856        let root_name = worktree.root_name().into();
 857        let abs_path = worktree.abs_path();
 858
 859        let rules_task = load_worktree_rules_file(fs, worktree, cx);
 860        let Some(rules_task) = rules_task else {
 861            return Task::ready((
 862                WorktreeInfoForSystemPrompt {
 863                    root_name,
 864                    abs_path,
 865                    rules_file: None,
 866                },
 867                None,
 868            ));
 869        };
 870
 871        cx.spawn(async move |_| {
 872            let (rules_file, rules_file_error) = match rules_task.await {
 873                Ok(rules_file) => (Some(rules_file), None),
 874                Err(err) => (
 875                    None,
 876                    Some(ThreadError::Message {
 877                        header: "Error loading rules file".into(),
 878                        message: format!("{err}").into(),
 879                    }),
 880                ),
 881            };
 882            let worktree_info = WorktreeInfoForSystemPrompt {
 883                root_name,
 884                abs_path,
 885                rules_file,
 886            };
 887            (worktree_info, rules_file_error)
 888        })
 889    }
 890
 891    pub fn send_to_model(
 892        &mut self,
 893        model: Arc<dyn LanguageModel>,
 894        request_kind: RequestKind,
 895        cx: &mut Context<Self>,
 896    ) {
 897        let mut request = self.to_completion_request(request_kind, cx);
 898        if model.supports_tools() {
 899            request.tools = {
 900                let mut tools = Vec::new();
 901                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 902                    LanguageModelRequestTool {
 903                        name: tool.name(),
 904                        description: tool.description(),
 905                        input_schema: tool.input_schema(model.tool_input_format()),
 906                    }
 907                }));
 908
 909                tools
 910            };
 911        }
 912
 913        self.stream_completion(request, model, cx);
 914    }
 915
 916    pub fn used_tools_since_last_user_message(&self) -> bool {
 917        for message in self.messages.iter().rev() {
 918            if self.tool_use.message_has_tool_results(message.id) {
 919                return true;
 920            } else if message.role == Role::User {
 921                return false;
 922            }
 923        }
 924
 925        false
 926    }
 927
 928    pub fn to_completion_request(
 929        &self,
 930        request_kind: RequestKind,
 931        cx: &App,
 932    ) -> LanguageModelRequest {
 933        let mut request = LanguageModelRequest {
 934            messages: vec![],
 935            tools: Vec::new(),
 936            stop: Vec::new(),
 937            temperature: None,
 938        };
 939
 940        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 941            if let Some(system_prompt) = self
 942                .prompt_builder
 943                .generate_assistant_system_prompt(system_prompt_context)
 944                .context("failed to generate assistant system prompt")
 945                .log_err()
 946            {
 947                request.messages.push(LanguageModelRequestMessage {
 948                    role: Role::System,
 949                    content: vec![MessageContent::Text(system_prompt)],
 950                    cache: true,
 951                });
 952            }
 953        } else {
 954            log::error!("system_prompt_context not set.")
 955        }
 956
 957        for message in &self.messages {
 958            let mut request_message = LanguageModelRequestMessage {
 959                role: message.role,
 960                content: Vec::new(),
 961                cache: false,
 962            };
 963
 964            match request_kind {
 965                RequestKind::Chat => {
 966                    self.tool_use
 967                        .attach_tool_results(message.id, &mut request_message);
 968                }
 969                RequestKind::Summarize => {
 970                    // We don't care about tool use during summarization.
 971                    if self.tool_use.message_has_tool_results(message.id) {
 972                        continue;
 973                    }
 974                }
 975            }
 976
 977            if !message.segments.is_empty() {
 978                request_message
 979                    .content
 980                    .push(MessageContent::Text(message.to_string()));
 981            }
 982
 983            match request_kind {
 984                RequestKind::Chat => {
 985                    self.tool_use
 986                        .attach_tool_uses(message.id, &mut request_message);
 987                }
 988                RequestKind::Summarize => {
 989                    // We don't care about tool use during summarization.
 990                }
 991            };
 992
 993            request.messages.push(request_message);
 994        }
 995
 996        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 997        if let Some(last) = request.messages.last_mut() {
 998            last.cache = true;
 999        }
1000
1001        self.attached_tracked_files_state(&mut request.messages, cx);
1002
1003        request
1004    }
1005
1006    fn attached_tracked_files_state(
1007        &self,
1008        messages: &mut Vec<LanguageModelRequestMessage>,
1009        cx: &App,
1010    ) {
1011        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1012
1013        let mut stale_message = String::new();
1014
1015        let action_log = self.action_log.read(cx);
1016
1017        for stale_file in action_log.stale_buffers(cx) {
1018            let Some(file) = stale_file.read(cx).file() else {
1019                continue;
1020            };
1021
1022            if stale_message.is_empty() {
1023                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1024            }
1025
1026            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1027        }
1028
1029        let mut content = Vec::with_capacity(2);
1030
1031        if !stale_message.is_empty() {
1032            content.push(stale_message.into());
1033        }
1034
1035        if action_log.has_edited_files_since_project_diagnostics_check() {
1036            content.push(
1037                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1038                and fix all errors AND warnings you introduced! \
1039                DO NOT mention you're going to do this until you're done."
1040                    .into(),
1041            );
1042        }
1043
1044        if !content.is_empty() {
1045            let context_message = LanguageModelRequestMessage {
1046                role: Role::User,
1047                content,
1048                cache: false,
1049            };
1050
1051            messages.push(context_message);
1052        }
1053    }
1054
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// The streaming task is registered in `pending_completions` under a fresh id
    /// (from `completion_count`) so it can be canceled. Each event is applied
    /// inside `thread.update`:
    /// - `StartMessage` inserts an empty assistant message;
    /// - text and thinking chunks are appended to the last message if it is from
    ///   the assistant, otherwise a new assistant message is inserted;
    /// - tool-use requests are recorded against the most recent assistant message
    ///   (one is inserted if none exists);
    /// - usage updates adjust `cumulative_token_usage`.
    ///
    /// On a clean end of stream, the completion's own entry is removed from
    /// `pending_completions`, and a title summary is requested if the thread has
    /// none and at least two messages. Afterwards, regardless of outcome:
    /// - the pending checkpoint is finalized;
    /// - `UsePendingTools` is emitted if the model stopped for tool use;
    /// - errors are surfaced as `ShowError` and the last completion is canceled;
    /// - `DoneStreaming` is emitted;
    /// - a telemetry event records the tokens used, if the initial usage could be
    ///   read.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Baseline used at the end to compute this completion's token usage.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Replace the previously counted usage for this completion
                                // with the new report, i.e. add only the difference.
                                // NOTE(review): this assumes each UsageUpdate carries the
                                // running total for the completion, not an increment.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                // NOTE(review): if the thread has no messages at all, the
                                // chunk is dropped.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                // NOTE(review): if the thread has no messages at all, the
                                // chunk is dropped.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the thinking text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the newest assistant message,
                                // creating an empty one if the thread has none.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Give other tasks a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                // Flatten the whole error chain, one cause per line.
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // NOTE(review): this pops the most recently pushed pending
                            // completion, which is not necessarily this one if several
                            // completions are in flight.
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Tokens consumed by this completion alone.
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1236
1237    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1238        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1239            return;
1240        };
1241
1242        if !model.provider.is_authenticated(cx) {
1243            return;
1244        }
1245
1246        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1247        request.messages.push(LanguageModelRequestMessage {
1248            role: Role::User,
1249            content: vec![
1250                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1251                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1252                 If the conversation is about a specific subject, include it in the title. \
1253                 Be descriptive. DO NOT speak in the first person."
1254                    .into(),
1255            ],
1256            cache: false,
1257        });
1258
1259        self.pending_summary = cx.spawn(async move |this, cx| {
1260            async move {
1261                let stream = model.model.stream_completion_text(request, &cx);
1262                let mut messages = stream.await?;
1263
1264                let mut new_summary = String::new();
1265                while let Some(message) = messages.stream.next().await {
1266                    let text = message?;
1267                    let mut lines = text.lines();
1268                    new_summary.extend(lines.next());
1269
1270                    // Stop if the LLM generated multiple lines.
1271                    if lines.next().is_some() {
1272                        break;
1273                    }
1274                }
1275
1276                this.update(cx, |this, cx| {
1277                    if !new_summary.is_empty() {
1278                        this.summary = Some(new_summary.into());
1279                    }
1280
1281                    cx.emit(ThreadEvent::SummaryGenerated);
1282                })?;
1283
1284                anyhow::Ok(())
1285            }
1286            .log_err()
1287            .await
1288        });
1289    }
1290
1291    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1292        let last_message_id = self.messages.last().map(|message| message.id)?;
1293
1294        match &self.detailed_summary_state {
1295            DetailedSummaryState::Generating { message_id, .. }
1296            | DetailedSummaryState::Generated { message_id, .. }
1297                if *message_id == last_message_id =>
1298            {
1299                // Already up-to-date
1300                return None;
1301            }
1302            _ => {}
1303        }
1304
1305        let ConfiguredModel { model, provider } =
1306            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1307
1308        if !provider.is_authenticated(cx) {
1309            return None;
1310        }
1311
1312        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1313
1314        request.messages.push(LanguageModelRequestMessage {
1315            role: Role::User,
1316            content: vec![
1317                "Generate a detailed summary of this conversation. Include:\n\
1318                1. A brief overview of what was discussed\n\
1319                2. Key facts or information discovered\n\
1320                3. Outcomes or conclusions reached\n\
1321                4. Any action items or next steps if any\n\
1322                Format it in Markdown with headings and bullet points."
1323                    .into(),
1324            ],
1325            cache: false,
1326        });
1327
1328        let task = cx.spawn(async move |thread, cx| {
1329            let stream = model.stream_completion_text(request, &cx);
1330            let Some(mut messages) = stream.await.log_err() else {
1331                thread
1332                    .update(cx, |this, _cx| {
1333                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1334                    })
1335                    .log_err();
1336
1337                return;
1338            };
1339
1340            let mut new_detailed_summary = String::new();
1341
1342            while let Some(chunk) = messages.stream.next().await {
1343                if let Some(chunk) = chunk.log_err() {
1344                    new_detailed_summary.push_str(&chunk);
1345                }
1346            }
1347
1348            thread
1349                .update(cx, |this, _cx| {
1350                    this.detailed_summary_state = DetailedSummaryState::Generated {
1351                        text: new_detailed_summary.into(),
1352                        message_id: last_message_id,
1353                    };
1354                })
1355                .log_err();
1356        });
1357
1358        self.detailed_summary_state = DetailedSummaryState::Generating {
1359            message_id: last_message_id,
1360        };
1361
1362        Some(task)
1363    }
1364
1365    pub fn is_generating_detailed_summary(&self) -> bool {
1366        matches!(
1367            self.detailed_summary_state,
1368            DetailedSummaryState::Generating { .. }
1369        )
1370    }
1371
1372    pub fn use_pending_tools(
1373        &mut self,
1374        cx: &mut Context<Self>,
1375    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1376        let request = self.to_completion_request(RequestKind::Chat, cx);
1377        let messages = Arc::new(request.messages);
1378        let pending_tool_uses = self
1379            .tool_use
1380            .pending_tool_uses()
1381            .into_iter()
1382            .filter(|tool_use| tool_use.status.is_idle())
1383            .cloned()
1384            .collect::<Vec<_>>();
1385
1386        for tool_use in pending_tool_uses.iter() {
1387            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1388                if tool.needs_confirmation(&tool_use.input, cx)
1389                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1390                {
1391                    self.tool_use.confirm_tool_use(
1392                        tool_use.id.clone(),
1393                        tool_use.ui_text.clone(),
1394                        tool_use.input.clone(),
1395                        messages.clone(),
1396                        tool,
1397                    );
1398                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1399                } else {
1400                    self.run_tool(
1401                        tool_use.id.clone(),
1402                        tool_use.ui_text.clone(),
1403                        tool_use.input.clone(),
1404                        &messages,
1405                        tool,
1406                        cx,
1407                    );
1408                }
1409            }
1410        }
1411
1412        pending_tool_uses
1413    }
1414
1415    pub fn run_tool(
1416        &mut self,
1417        tool_use_id: LanguageModelToolUseId,
1418        ui_text: impl Into<SharedString>,
1419        input: serde_json::Value,
1420        messages: &[LanguageModelRequestMessage],
1421        tool: Arc<dyn Tool>,
1422        cx: &mut Context<Thread>,
1423    ) {
1424        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1425        self.tool_use
1426            .run_pending_tool(tool_use_id, ui_text.into(), task);
1427    }
1428
1429    fn spawn_tool_use(
1430        &mut self,
1431        tool_use_id: LanguageModelToolUseId,
1432        messages: &[LanguageModelRequestMessage],
1433        input: serde_json::Value,
1434        tool: Arc<dyn Tool>,
1435        cx: &mut Context<Thread>,
1436    ) -> Task<()> {
1437        let tool_name: Arc<str> = tool.name().into();
1438
1439        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1440            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1441        } else {
1442            tool.run(
1443                input,
1444                messages,
1445                self.project.clone(),
1446                self.action_log.clone(),
1447                cx,
1448            )
1449        };
1450
1451        cx.spawn({
1452            async move |thread: WeakEntity<Thread>, cx| {
1453                let output = run_tool.await;
1454
1455                thread
1456                    .update(cx, |thread, cx| {
1457                        let pending_tool_use = thread.tool_use.insert_tool_output(
1458                            tool_use_id.clone(),
1459                            tool_name,
1460                            output,
1461                            cx,
1462                        );
1463
1464                        cx.emit(ThreadEvent::ToolFinished {
1465                            tool_use_id,
1466                            pending_tool_use,
1467                            canceled: false,
1468                        });
1469                    })
1470                    .ok();
1471            }
1472        })
1473    }
1474
1475    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1476        // Insert a user message to contain the tool results.
1477        self.insert_user_message(
1478            // TODO: Sending up a user message without any content results in the model sending back
1479            // responses that also don't have any content. We currently don't handle this case well,
1480            // so for now we provide some text to keep the model on track.
1481            "Here are the tool results.",
1482            Vec::new(),
1483            None,
1484            cx,
1485        );
1486    }
1487
1488    /// Cancels the last pending completion, if there are any pending.
1489    ///
1490    /// Returns whether a completion was canceled.
1491    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1492        let canceled = if self.pending_completions.pop().is_some() {
1493            true
1494        } else {
1495            let mut canceled = false;
1496            for pending_tool_use in self.tool_use.cancel_pending() {
1497                canceled = true;
1498                cx.emit(ThreadEvent::ToolFinished {
1499                    tool_use_id: pending_tool_use.id.clone(),
1500                    pending_tool_use: Some(pending_tool_use),
1501                    canceled: true,
1502                });
1503            }
1504            canceled
1505        };
1506        self.finalize_pending_checkpoint(cx);
1507        canceled
1508    }
1509
1510    pub fn feedback(&self) -> Option<ThreadFeedback> {
1511        self.feedback
1512    }
1513
1514    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1515        self.message_feedback.get(&message_id).copied()
1516    }
1517
1518    pub fn report_message_feedback(
1519        &mut self,
1520        message_id: MessageId,
1521        feedback: ThreadFeedback,
1522        cx: &mut Context<Self>,
1523    ) -> Task<Result<()>> {
1524        if self.message_feedback.get(&message_id) == Some(&feedback) {
1525            return Task::ready(Ok(()));
1526        }
1527
1528        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1529        let serialized_thread = self.serialize(cx);
1530        let thread_id = self.id().clone();
1531        let client = self.project.read(cx).client();
1532
1533        self.message_feedback.insert(message_id, feedback);
1534
1535        cx.notify();
1536
1537        let message_content = self
1538            .message(message_id)
1539            .map(|msg| msg.to_string())
1540            .unwrap_or_default();
1541
1542        cx.background_spawn(async move {
1543            let final_project_snapshot = final_project_snapshot.await;
1544            let serialized_thread = serialized_thread.await?;
1545            let thread_data =
1546                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1547
1548            let rating = match feedback {
1549                ThreadFeedback::Positive => "positive",
1550                ThreadFeedback::Negative => "negative",
1551            };
1552            telemetry::event!(
1553                "Assistant Thread Rated",
1554                rating,
1555                thread_id,
1556                message_id = message_id.0,
1557                message_content,
1558                thread_data,
1559                final_project_snapshot
1560            );
1561            client.telemetry().flush_events();
1562
1563            Ok(())
1564        })
1565    }
1566
1567    pub fn report_feedback(
1568        &mut self,
1569        feedback: ThreadFeedback,
1570        cx: &mut Context<Self>,
1571    ) -> Task<Result<()>> {
1572        let last_assistant_message_id = self
1573            .messages
1574            .iter()
1575            .rev()
1576            .find(|msg| msg.role == Role::Assistant)
1577            .map(|msg| msg.id);
1578
1579        if let Some(message_id) = last_assistant_message_id {
1580            self.report_message_feedback(message_id, feedback, cx)
1581        } else {
1582            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1583            let serialized_thread = self.serialize(cx);
1584            let thread_id = self.id().clone();
1585            let client = self.project.read(cx).client();
1586            self.feedback = Some(feedback);
1587            cx.notify();
1588
1589            cx.background_spawn(async move {
1590                let final_project_snapshot = final_project_snapshot.await;
1591                let serialized_thread = serialized_thread.await?;
1592                let thread_data = serde_json::to_value(serialized_thread)
1593                    .unwrap_or_else(|_| serde_json::Value::Null);
1594
1595                let rating = match feedback {
1596                    ThreadFeedback::Positive => "positive",
1597                    ThreadFeedback::Negative => "negative",
1598                };
1599                telemetry::event!(
1600                    "Assistant Thread Rated",
1601                    rating,
1602                    thread_id,
1603                    thread_data,
1604                    final_project_snapshot
1605                );
1606                client.telemetry().flush_events();
1607
1608                Ok(())
1609            })
1610        }
1611    }
1612
1613    /// Create a snapshot of the current project state including git information and unsaved buffers.
1614    fn project_snapshot(
1615        project: Entity<Project>,
1616        cx: &mut Context<Self>,
1617    ) -> Task<Arc<ProjectSnapshot>> {
1618        let git_store = project.read(cx).git_store().clone();
1619        let worktree_snapshots: Vec<_> = project
1620            .read(cx)
1621            .visible_worktrees(cx)
1622            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1623            .collect();
1624
1625        cx.spawn(async move |_, cx| {
1626            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1627
1628            let mut unsaved_buffers = Vec::new();
1629            cx.update(|app_cx| {
1630                let buffer_store = project.read(app_cx).buffer_store();
1631                for buffer_handle in buffer_store.read(app_cx).buffers() {
1632                    let buffer = buffer_handle.read(app_cx);
1633                    if buffer.is_dirty() {
1634                        if let Some(file) = buffer.file() {
1635                            let path = file.path().to_string_lossy().to_string();
1636                            unsaved_buffers.push(path);
1637                        }
1638                    }
1639                }
1640            })
1641            .ok();
1642
1643            Arc::new(ProjectSnapshot {
1644                worktree_snapshots,
1645                unsaved_buffer_paths: unsaved_buffers,
1646                timestamp: Utc::now(),
1647            })
1648        })
1649    }
1650
1651    fn worktree_snapshot(
1652        worktree: Entity<project::Worktree>,
1653        git_store: Entity<GitStore>,
1654        cx: &App,
1655    ) -> Task<WorktreeSnapshot> {
1656        cx.spawn(async move |cx| {
1657            // Get worktree path and snapshot
1658            let worktree_info = cx.update(|app_cx| {
1659                let worktree = worktree.read(app_cx);
1660                let path = worktree.abs_path().to_string_lossy().to_string();
1661                let snapshot = worktree.snapshot();
1662                (path, snapshot)
1663            });
1664
1665            let Ok((worktree_path, _snapshot)) = worktree_info else {
1666                return WorktreeSnapshot {
1667                    worktree_path: String::new(),
1668                    git_state: None,
1669                };
1670            };
1671
1672            let git_state = git_store
1673                .update(cx, |git_store, cx| {
1674                    git_store
1675                        .repositories()
1676                        .values()
1677                        .find(|repo| {
1678                            repo.read(cx)
1679                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1680                                .is_some()
1681                        })
1682                        .cloned()
1683                })
1684                .ok()
1685                .flatten()
1686                .map(|repo| {
1687                    repo.update(cx, |repo, _| {
1688                        let current_branch =
1689                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1690                        repo.send_job(None, |state, _| async move {
1691                            let RepositoryState::Local { backend, .. } = state else {
1692                                return GitState {
1693                                    remote_url: None,
1694                                    head_sha: None,
1695                                    current_branch,
1696                                    diff: None,
1697                                };
1698                            };
1699
1700                            let remote_url = backend.remote_url("origin");
1701                            let head_sha = backend.head_sha();
1702                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1703
1704                            GitState {
1705                                remote_url,
1706                                head_sha,
1707                                current_branch,
1708                                diff,
1709                            }
1710                        })
1711                    })
1712                });
1713
1714            let git_state = match git_state {
1715                Some(git_state) => match git_state.ok() {
1716                    Some(git_state) => git_state.await.ok(),
1717                    None => None,
1718                },
1719                None => None,
1720            };
1721
1722            WorktreeSnapshot {
1723                worktree_path,
1724                git_state,
1725            }
1726        })
1727    }
1728
1729    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1730        let mut markdown = Vec::new();
1731
1732        if let Some(summary) = self.summary() {
1733            writeln!(markdown, "# {summary}\n")?;
1734        };
1735
1736        for message in self.messages() {
1737            writeln!(
1738                markdown,
1739                "## {role}\n",
1740                role = match message.role {
1741                    Role::User => "User",
1742                    Role::Assistant => "Assistant",
1743                    Role::System => "System",
1744                }
1745            )?;
1746
1747            if !message.context.is_empty() {
1748                writeln!(markdown, "{}", message.context)?;
1749            }
1750
1751            for segment in &message.segments {
1752                match segment {
1753                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1754                    MessageSegment::Thinking(text) => {
1755                        writeln!(markdown, "<think>{}</think>\n", text)?
1756                    }
1757                }
1758            }
1759
1760            for tool_use in self.tool_uses_for_message(message.id, cx) {
1761                writeln!(
1762                    markdown,
1763                    "**Use Tool: {} ({})**",
1764                    tool_use.name, tool_use.id
1765                )?;
1766                writeln!(markdown, "```json")?;
1767                writeln!(
1768                    markdown,
1769                    "{}",
1770                    serde_json::to_string_pretty(&tool_use.input)?
1771                )?;
1772                writeln!(markdown, "```")?;
1773            }
1774
1775            for tool_result in self.tool_results_for_message(message.id) {
1776                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1777                if tool_result.is_error {
1778                    write!(markdown, " (Error)")?;
1779                }
1780
1781                writeln!(markdown, "**\n")?;
1782                writeln!(markdown, "{}", tool_result.content)?;
1783            }
1784        }
1785
1786        Ok(String::from_utf8_lossy(&markdown).to_string())
1787    }
1788
1789    pub fn keep_edits_in_range(
1790        &mut self,
1791        buffer: Entity<language::Buffer>,
1792        buffer_range: Range<language::Anchor>,
1793        cx: &mut Context<Self>,
1794    ) {
1795        self.action_log.update(cx, |action_log, cx| {
1796            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1797        });
1798    }
1799
1800    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1801        self.action_log
1802            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1803    }
1804
1805    pub fn reject_edits_in_range(
1806        &mut self,
1807        buffer: Entity<language::Buffer>,
1808        buffer_range: Range<language::Anchor>,
1809        cx: &mut Context<Self>,
1810    ) -> Task<Result<()>> {
1811        self.action_log.update(cx, |action_log, cx| {
1812            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1813        })
1814    }
1815
1816    pub fn action_log(&self) -> &Entity<ActionLog> {
1817        &self.action_log
1818    }
1819
1820    pub fn project(&self) -> &Entity<Project> {
1821        &self.project
1822    }
1823
1824    pub fn cumulative_token_usage(&self) -> TokenUsage {
1825        self.cumulative_token_usage.clone()
1826    }
1827
1828    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1829        let model_registry = LanguageModelRegistry::read_global(cx);
1830        let Some(model) = model_registry.default_model() else {
1831            return TotalTokenUsage::default();
1832        };
1833
1834        let max = model.model.max_token_count();
1835
1836        #[cfg(debug_assertions)]
1837        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1838            .unwrap_or("0.8".to_string())
1839            .parse()
1840            .unwrap();
1841        #[cfg(not(debug_assertions))]
1842        let warning_threshold: f32 = 0.8;
1843
1844        let total = self.cumulative_token_usage.total_tokens() as usize;
1845
1846        let ratio = if total >= max {
1847            TokenUsageRatio::Exceeded
1848        } else if total as f32 / max as f32 >= warning_threshold {
1849            TokenUsageRatio::Warning
1850        } else {
1851            TokenUsageRatio::Normal
1852        };
1853
1854        TotalTokenUsage { total, max, ratio }
1855    }
1856
1857    pub fn deny_tool_use(
1858        &mut self,
1859        tool_use_id: LanguageModelToolUseId,
1860        tool_name: Arc<str>,
1861        cx: &mut Context<Self>,
1862    ) {
1863        let err = Err(anyhow::anyhow!(
1864            "Permission to run tool action denied by user"
1865        ));
1866
1867        self.tool_use
1868            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1869
1870        cx.emit(ThreadEvent::ToolFinished {
1871            tool_use_id,
1872            pending_tool_use: None,
1873            canceled: true,
1874        });
1875    }
1876}
1877
/// An error a [`Thread`] reports through [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// Presumably raised when the provider returns a `PaymentRequiredError` — confirm.
    PaymentRequired,
    /// Presumably raised on a `MaxMonthlySpendReachedError` — confirm.
    MaxMonthlySpendReached,
    /// Any other error, described by a `header` and a `message` text.
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1887
/// Events a [`Thread`] emits through `cx.emit`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`]; presumably shown to the user by observers.
    ShowError(ThreadError),
    StreamedCompletion,
    /// A chunk of assistant text for the given message (presumably a streaming delta).
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant thinking for the given message (presumably a streaming delta).
    StreamedAssistantThinking(MessageId, String),
    DoneStreaming,
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools,
    /// A tool call finished. [`Thread::deny_tool_use`] emits this with
    /// `canceled: true` and `pending_tool_use: None`.
    ToolFinished {
        // NOTE(review): the allow presumably silences a dead-code warning because no
        // consumer reads this field yet — revisit once it is used.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

// Lets `Thread` entities emit `ThreadEvent`s with `cx.emit(...)`.
impl EventEmitter<ThreadEvent> for Thread {}
1914
/// Presumably tracks a completion request that is still in flight.
// NOTE(review): `id` presumably distinguishes concurrent requests, and `_task` is
// held only to keep the task alive (dropping a gpui `Task` is assumed to cancel it)
// — confirm.
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}
1919
// Tests for how user messages, their attached context, and stale-buffer
// notifications end up in completion requests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message with one attached file stores the formatted file as its
    /// context, and the completion request prepends that context to the text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[0].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message is not repeated: each
    /// message's context (and its request message) contains only the files that are
    /// new to the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 3);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[0].string_contents().contains("file1.rs"));
        assert!(!request.messages[0].string_contents().contains("file2.rs"));
        assert!(!request.messages[0].string_contents().contains("file3.rs"));

        assert!(!request.messages[1].string_contents().contains("file1.rs"));
        assert!(request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(!request.messages[2].string_contents().contains("file2.rs"));
        assert!(request.messages[2].string_contents().contains("file3.rs"));
    }

    /// Messages sent without attachments have an empty context, and each request
    /// message contains exactly the message text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[1].string_contents(),
            "Are there any good books?"
        );
    }

    /// Once a buffer that was attached as context is edited, the next completion
    /// request ends with a user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at byte offset 1 (inside `fn`). Where the edit lands doesn't
            // matter here; the buffer only needs to change after it was attached.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` as a global and registers the settings these
    /// tests rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Creates a project backed by a `FakeFs` whose `/test` root is populated from
    // `files` (a JSON tree of file names to contents).
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Opens a test workspace window for `project`, then creates a `ThreadStore`, a
    /// new `Thread` from it, and a `ContextStore` for the project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx.update(|_, cx| {
            ThreadStore::new(
                project.clone(),
                Arc::default(),
                Arc::new(PromptBuilder::new(None).unwrap()),
                cx,
            )
            .unwrap()
        });

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens the buffer at `path` (e.g. "test/code.rs") and adds it to
    /// `context_store` as file context, returning the opened buffer.
    ///
    /// Panics if the path can't be resolved or the buffer can't be opened; an error
    /// from adding the buffer to the context store is returned instead.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}