thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use agent_rules::load_worktree_rules_file;
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use fs::Fs;
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
  19    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  20    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  21    PaymentRequiredError, Role, StopReason, TokenUsage,
  22};
  23use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  24use project::{Project, Worktree};
  25use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
  26use schemars::JsonSchema;
  27use serde::{Deserialize, Serialize};
  28use settings::Settings;
  29use util::{ResultExt as _, TryFutureExt as _, post_inc};
  30use uuid::Uuid;
  31
  32use crate::context::{AssistantContext, ContextId, format_context_as_string};
  33use crate::thread_store::{
  34    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  35    SerializedToolUse,
  36};
  37use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  38
  39#[derive(Debug, Clone, Copy)]
  40pub enum RequestKind {
  41    Chat,
  42    /// Used when summarizing a thread.
  43    Summarize,
  44}
  45
  46#[derive(
  47    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  48)]
  49pub struct ThreadId(Arc<str>);
  50
  51impl ThreadId {
  52    pub fn new() -> Self {
  53        Self(Uuid::new_v4().to_string().into())
  54    }
  55}
  56
  57impl std::fmt::Display for ThreadId {
  58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  59        write!(f, "{}", self.0)
  60    }
  61}
  62
  63impl From<&str> for ThreadId {
  64    fn from(value: &str) -> Self {
  65        Self(value.into())
  66    }
  67}
  68
  69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  70pub struct MessageId(pub(crate) usize);
  71
  72impl MessageId {
  73    fn post_inc(&mut self) -> Self {
  74        Self(post_inc(&mut self.0))
  75    }
  76}
  77
  78/// A message in a [`Thread`].
  79#[derive(Debug, Clone)]
  80pub struct Message {
  81    pub id: MessageId,
  82    pub role: Role,
  83    pub segments: Vec<MessageSegment>,
  84    pub context: String,
  85}
  86
  87impl Message {
  88    /// Returns whether the message contains any meaningful text that should be displayed
  89    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  90    pub fn should_display_content(&self) -> bool {
  91        self.segments.iter().all(|segment| segment.should_display())
  92    }
  93
  94    pub fn push_thinking(&mut self, text: &str) {
  95        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  96            segment.push_str(text);
  97        } else {
  98            self.segments
  99                .push(MessageSegment::Thinking(text.to_string()));
 100        }
 101    }
 102
 103    pub fn push_text(&mut self, text: &str) {
 104        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 105            segment.push_str(text);
 106        } else {
 107            self.segments.push(MessageSegment::Text(text.to_string()));
 108        }
 109    }
 110
 111    pub fn to_string(&self) -> String {
 112        let mut result = String::new();
 113
 114        if !self.context.is_empty() {
 115            result.push_str(&self.context);
 116        }
 117
 118        for segment in &self.segments {
 119            match segment {
 120                MessageSegment::Text(text) => result.push_str(text),
 121                MessageSegment::Thinking(text) => {
 122                    result.push_str("<think>");
 123                    result.push_str(text);
 124                    result.push_str("</think>");
 125                }
 126            }
 127        }
 128
 129        result
 130    }
 131}
 132
 133#[derive(Debug, Clone, PartialEq, Eq)]
 134pub enum MessageSegment {
 135    Text(String),
 136    Thinking(String),
 137}
 138
 139impl MessageSegment {
 140    pub fn text_mut(&mut self) -> &mut String {
 141        match self {
 142            Self::Text(text) => text,
 143            Self::Thinking(text) => text,
 144        }
 145    }
 146
 147    pub fn should_display(&self) -> bool {
 148        // We add USING_TOOL_MARKER when making a request that includes tool uses
 149        // without non-whitespace text around them, and this can cause the model
 150        // to mimic the pattern, so we consider those segments not displayable.
 151        match self {
 152            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 153            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 154        }
 155    }
 156}
 157
 158#[derive(Debug, Clone, Serialize, Deserialize)]
 159pub struct ProjectSnapshot {
 160    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 161    pub unsaved_buffer_paths: Vec<String>,
 162    pub timestamp: DateTime<Utc>,
 163}
 164
 165#[derive(Debug, Clone, Serialize, Deserialize)]
 166pub struct WorktreeSnapshot {
 167    pub worktree_path: String,
 168    pub git_state: Option<GitState>,
 169}
 170
 171#[derive(Debug, Clone, Serialize, Deserialize)]
 172pub struct GitState {
 173    pub remote_url: Option<String>,
 174    pub head_sha: Option<String>,
 175    pub current_branch: Option<String>,
 176    pub diff: Option<String>,
 177}
 178
 179#[derive(Clone)]
 180pub struct ThreadCheckpoint {
 181    message_id: MessageId,
 182    git_checkpoint: GitStoreCheckpoint,
 183}
 184
 185#[derive(Copy, Clone, Debug)]
 186pub enum ThreadFeedback {
 187    Positive,
 188    Negative,
 189}
 190
 191pub enum LastRestoreCheckpoint {
 192    Pending {
 193        message_id: MessageId,
 194    },
 195    Error {
 196        message_id: MessageId,
 197        error: String,
 198    },
 199}
 200
 201impl LastRestoreCheckpoint {
 202    pub fn message_id(&self) -> MessageId {
 203        match self {
 204            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 205            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 206        }
 207    }
 208}
 209
 210#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 211pub enum DetailedSummaryState {
 212    #[default]
 213    NotGenerated,
 214    Generating {
 215        message_id: MessageId,
 216    },
 217    Generated {
 218        text: SharedString,
 219        message_id: MessageId,
 220    },
 221}
 222
 223#[derive(Default)]
 224pub struct TotalTokenUsage {
 225    pub total: usize,
 226    pub max: usize,
 227    pub ratio: TokenUsageRatio,
 228}
 229
 230#[derive(Default, PartialEq, Eq)]
 231pub enum TokenUsageRatio {
 232    #[default]
 233    Normal,
 234    Warning,
 235    Exceeded,
 236}
 237
 238/// A thread of conversation with the LLM.
 239pub struct Thread {
 240    id: ThreadId,
 241    updated_at: DateTime<Utc>,
 242    summary: Option<SharedString>,
 243    pending_summary: Task<Option<()>>,
 244    detailed_summary_state: DetailedSummaryState,
 245    messages: Vec<Message>,
 246    next_message_id: MessageId,
 247    context: BTreeMap<ContextId, AssistantContext>,
 248    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 249    system_prompt_context: Option<AssistantSystemPromptContext>,
 250    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 251    completion_count: usize,
 252    pending_completions: Vec<PendingCompletion>,
 253    project: Entity<Project>,
 254    prompt_builder: Arc<PromptBuilder>,
 255    tools: Arc<ToolWorkingSet>,
 256    tool_use: ToolUseState,
 257    action_log: Entity<ActionLog>,
 258    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 259    pending_checkpoint: Option<ThreadCheckpoint>,
 260    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 261    cumulative_token_usage: TokenUsage,
 262    feedback: Option<ThreadFeedback>,
 263}
 264
 265impl Thread {
 266    pub fn new(
 267        project: Entity<Project>,
 268        tools: Arc<ToolWorkingSet>,
 269        prompt_builder: Arc<PromptBuilder>,
 270        cx: &mut Context<Self>,
 271    ) -> Self {
 272        Self {
 273            id: ThreadId::new(),
 274            updated_at: Utc::now(),
 275            summary: None,
 276            pending_summary: Task::ready(None),
 277            detailed_summary_state: DetailedSummaryState::NotGenerated,
 278            messages: Vec::new(),
 279            next_message_id: MessageId(0),
 280            context: BTreeMap::default(),
 281            context_by_message: HashMap::default(),
 282            system_prompt_context: None,
 283            checkpoints_by_message: HashMap::default(),
 284            completion_count: 0,
 285            pending_completions: Vec::new(),
 286            project: project.clone(),
 287            prompt_builder,
 288            tools: tools.clone(),
 289            last_restore_checkpoint: None,
 290            pending_checkpoint: None,
 291            tool_use: ToolUseState::new(tools.clone()),
 292            action_log: cx.new(|_| ActionLog::new(project.clone())),
 293            initial_project_snapshot: {
 294                let project_snapshot = Self::project_snapshot(project, cx);
 295                cx.foreground_executor()
 296                    .spawn(async move { Some(project_snapshot.await) })
 297                    .shared()
 298            },
 299            cumulative_token_usage: TokenUsage::default(),
 300            feedback: None,
 301        }
 302    }
 303
 304    pub fn deserialize(
 305        id: ThreadId,
 306        serialized: SerializedThread,
 307        project: Entity<Project>,
 308        tools: Arc<ToolWorkingSet>,
 309        prompt_builder: Arc<PromptBuilder>,
 310        cx: &mut Context<Self>,
 311    ) -> Self {
 312        let next_message_id = MessageId(
 313            serialized
 314                .messages
 315                .last()
 316                .map(|message| message.id.0 + 1)
 317                .unwrap_or(0),
 318        );
 319        let tool_use =
 320            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 321
 322        Self {
 323            id,
 324            updated_at: serialized.updated_at,
 325            summary: Some(serialized.summary),
 326            pending_summary: Task::ready(None),
 327            detailed_summary_state: serialized.detailed_summary_state,
 328            messages: serialized
 329                .messages
 330                .into_iter()
 331                .map(|message| Message {
 332                    id: message.id,
 333                    role: message.role,
 334                    segments: message
 335                        .segments
 336                        .into_iter()
 337                        .map(|segment| match segment {
 338                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 339                            SerializedMessageSegment::Thinking { text } => {
 340                                MessageSegment::Thinking(text)
 341                            }
 342                        })
 343                        .collect(),
 344                    context: message.context,
 345                })
 346                .collect(),
 347            next_message_id,
 348            context: BTreeMap::default(),
 349            context_by_message: HashMap::default(),
 350            system_prompt_context: None,
 351            checkpoints_by_message: HashMap::default(),
 352            completion_count: 0,
 353            pending_completions: Vec::new(),
 354            last_restore_checkpoint: None,
 355            pending_checkpoint: None,
 356            project: project.clone(),
 357            prompt_builder,
 358            tools,
 359            tool_use,
 360            action_log: cx.new(|_| ActionLog::new(project)),
 361            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 362            cumulative_token_usage: serialized.cumulative_token_usage,
 363            feedback: None,
 364        }
 365    }
 366
 367    pub fn id(&self) -> &ThreadId {
 368        &self.id
 369    }
 370
 371    pub fn is_empty(&self) -> bool {
 372        self.messages.is_empty()
 373    }
 374
 375    pub fn updated_at(&self) -> DateTime<Utc> {
 376        self.updated_at
 377    }
 378
 379    pub fn touch_updated_at(&mut self) {
 380        self.updated_at = Utc::now();
 381    }
 382
 383    pub fn summary(&self) -> Option<SharedString> {
 384        self.summary.clone()
 385    }
 386
 387    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 388
 389    pub fn summary_or_default(&self) -> SharedString {
 390        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 391    }
 392
 393    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 394        let Some(current_summary) = &self.summary else {
 395            // Don't allow setting summary until generated
 396            return;
 397        };
 398
 399        let mut new_summary = new_summary.into();
 400
 401        if new_summary.is_empty() {
 402            new_summary = Self::DEFAULT_SUMMARY;
 403        }
 404
 405        if current_summary != &new_summary {
 406            self.summary = Some(new_summary);
 407            cx.emit(ThreadEvent::SummaryChanged);
 408        }
 409    }
 410
 411    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 412        self.latest_detailed_summary()
 413            .unwrap_or_else(|| self.text().into())
 414    }
 415
 416    fn latest_detailed_summary(&self) -> Option<SharedString> {
 417        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 418            Some(text.clone())
 419        } else {
 420            None
 421        }
 422    }
 423
 424    pub fn message(&self, id: MessageId) -> Option<&Message> {
 425        self.messages.iter().find(|message| message.id == id)
 426    }
 427
 428    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 429        self.messages.iter()
 430    }
 431
 432    pub fn is_generating(&self) -> bool {
 433        !self.pending_completions.is_empty() || !self.all_tools_finished()
 434    }
 435
 436    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 437        &self.tools
 438    }
 439
 440    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 441        self.tool_use
 442            .pending_tool_uses()
 443            .into_iter()
 444            .find(|tool_use| &tool_use.id == id)
 445    }
 446
 447    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 448        self.tool_use
 449            .pending_tool_uses()
 450            .into_iter()
 451            .filter(|tool_use| tool_use.status.needs_confirmation())
 452    }
 453
 454    pub fn has_pending_tool_uses(&self) -> bool {
 455        !self.tool_use.pending_tool_uses().is_empty()
 456    }
 457
 458    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 459        self.checkpoints_by_message.get(&id).cloned()
 460    }
 461
 462    pub fn restore_checkpoint(
 463        &mut self,
 464        checkpoint: ThreadCheckpoint,
 465        cx: &mut Context<Self>,
 466    ) -> Task<Result<()>> {
 467        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 468            message_id: checkpoint.message_id,
 469        });
 470        cx.emit(ThreadEvent::CheckpointChanged);
 471        cx.notify();
 472
 473        let git_store = self.project().read(cx).git_store().clone();
 474        let restore = git_store.update(cx, |git_store, cx| {
 475            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 476        });
 477
 478        cx.spawn(async move |this, cx| {
 479            let result = restore.await;
 480            this.update(cx, |this, cx| {
 481                if let Err(err) = result.as_ref() {
 482                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 483                        message_id: checkpoint.message_id,
 484                        error: err.to_string(),
 485                    });
 486                } else {
 487                    this.truncate(checkpoint.message_id, cx);
 488                    this.last_restore_checkpoint = None;
 489                }
 490                this.pending_checkpoint = None;
 491                cx.emit(ThreadEvent::CheckpointChanged);
 492                cx.notify();
 493            })?;
 494            result
 495        })
 496    }
 497
 498    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 499        let pending_checkpoint = if self.is_generating() {
 500            return;
 501        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 502            checkpoint
 503        } else {
 504            return;
 505        };
 506
 507        let git_store = self.project.read(cx).git_store().clone();
 508        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 509        cx.spawn(async move |this, cx| match final_checkpoint.await {
 510            Ok(final_checkpoint) => {
 511                let equal = git_store
 512                    .update(cx, |store, cx| {
 513                        store.compare_checkpoints(
 514                            pending_checkpoint.git_checkpoint.clone(),
 515                            final_checkpoint.clone(),
 516                            cx,
 517                        )
 518                    })?
 519                    .await
 520                    .unwrap_or(false);
 521
 522                if equal {
 523                    git_store
 524                        .update(cx, |store, cx| {
 525                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 526                        })?
 527                        .detach();
 528                } else {
 529                    this.update(cx, |this, cx| {
 530                        this.insert_checkpoint(pending_checkpoint, cx)
 531                    })?;
 532                }
 533
 534                git_store
 535                    .update(cx, |store, cx| {
 536                        store.delete_checkpoint(final_checkpoint, cx)
 537                    })?
 538                    .detach();
 539
 540                Ok(())
 541            }
 542            Err(_) => this.update(cx, |this, cx| {
 543                this.insert_checkpoint(pending_checkpoint, cx)
 544            }),
 545        })
 546        .detach();
 547    }
 548
 549    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 550        self.checkpoints_by_message
 551            .insert(checkpoint.message_id, checkpoint);
 552        cx.emit(ThreadEvent::CheckpointChanged);
 553        cx.notify();
 554    }
 555
 556    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 557        self.last_restore_checkpoint.as_ref()
 558    }
 559
 560    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 561        let Some(message_ix) = self
 562            .messages
 563            .iter()
 564            .rposition(|message| message.id == message_id)
 565        else {
 566            return;
 567        };
 568        for deleted_message in self.messages.drain(message_ix..) {
 569            self.context_by_message.remove(&deleted_message.id);
 570            self.checkpoints_by_message.remove(&deleted_message.id);
 571        }
 572        cx.notify();
 573    }
 574
 575    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 576        self.context_by_message
 577            .get(&id)
 578            .into_iter()
 579            .flat_map(|context| {
 580                context
 581                    .iter()
 582                    .filter_map(|context_id| self.context.get(&context_id))
 583            })
 584    }
 585
 586    /// Returns whether all of the tool uses have finished running.
 587    pub fn all_tools_finished(&self) -> bool {
 588        // If the only pending tool uses left are the ones with errors, then
 589        // that means that we've finished running all of the pending tools.
 590        self.tool_use
 591            .pending_tool_uses()
 592            .iter()
 593            .all(|tool_use| tool_use.status.is_error())
 594    }
 595
 596    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 597        self.tool_use.tool_uses_for_message(id, cx)
 598    }
 599
 600    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 601        self.tool_use.tool_results_for_message(id)
 602    }
 603
 604    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 605        self.tool_use.tool_result(id)
 606    }
 607
 608    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 609        self.tool_use.message_has_tool_results(message_id)
 610    }
 611
 612    pub fn insert_user_message(
 613        &mut self,
 614        text: impl Into<String>,
 615        context: Vec<AssistantContext>,
 616        git_checkpoint: Option<GitStoreCheckpoint>,
 617        cx: &mut Context<Self>,
 618    ) -> MessageId {
 619        let text = text.into();
 620
 621        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 622
 623        // Filter out contexts that have already been included in previous messages
 624        let new_context: Vec<_> = context
 625            .into_iter()
 626            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 627            .collect();
 628
 629        if !new_context.is_empty() {
 630            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 631                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 632                    message.context = context_string;
 633                }
 634            }
 635
 636            self.action_log.update(cx, |log, cx| {
 637                // Track all buffers added as context
 638                for ctx in &new_context {
 639                    match ctx {
 640                        AssistantContext::File(file_ctx) => {
 641                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 642                        }
 643                        AssistantContext::Directory(dir_ctx) => {
 644                            for context_buffer in &dir_ctx.context_buffers {
 645                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 646                            }
 647                        }
 648                        AssistantContext::Symbol(symbol_ctx) => {
 649                            log.buffer_added_as_context(
 650                                symbol_ctx.context_symbol.buffer.clone(),
 651                                cx,
 652                            );
 653                        }
 654                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 655                    }
 656                }
 657            });
 658        }
 659
 660        let context_ids = new_context
 661            .iter()
 662            .map(|context| context.id())
 663            .collect::<Vec<_>>();
 664        self.context.extend(
 665            new_context
 666                .into_iter()
 667                .map(|context| (context.id(), context)),
 668        );
 669        self.context_by_message.insert(message_id, context_ids);
 670
 671        if let Some(git_checkpoint) = git_checkpoint {
 672            self.pending_checkpoint = Some(ThreadCheckpoint {
 673                message_id,
 674                git_checkpoint,
 675            });
 676        }
 677        message_id
 678    }
 679
 680    pub fn insert_message(
 681        &mut self,
 682        role: Role,
 683        segments: Vec<MessageSegment>,
 684        cx: &mut Context<Self>,
 685    ) -> MessageId {
 686        let id = self.next_message_id.post_inc();
 687        self.messages.push(Message {
 688            id,
 689            role,
 690            segments,
 691            context: String::new(),
 692        });
 693        self.touch_updated_at();
 694        cx.emit(ThreadEvent::MessageAdded(id));
 695        id
 696    }
 697
 698    pub fn edit_message(
 699        &mut self,
 700        id: MessageId,
 701        new_role: Role,
 702        new_segments: Vec<MessageSegment>,
 703        cx: &mut Context<Self>,
 704    ) -> bool {
 705        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 706            return false;
 707        };
 708        message.role = new_role;
 709        message.segments = new_segments;
 710        self.touch_updated_at();
 711        cx.emit(ThreadEvent::MessageEdited(id));
 712        true
 713    }
 714
 715    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 716        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 717            return false;
 718        };
 719        self.messages.remove(index);
 720        self.context_by_message.remove(&id);
 721        self.touch_updated_at();
 722        cx.emit(ThreadEvent::MessageDeleted(id));
 723        true
 724    }
 725
 726    /// Returns the representation of this [`Thread`] in a textual form.
 727    ///
 728    /// This is the representation we use when attaching a thread as context to another thread.
 729    pub fn text(&self) -> String {
 730        let mut text = String::new();
 731
 732        for message in &self.messages {
 733            text.push_str(match message.role {
 734                language_model::Role::User => "User:",
 735                language_model::Role::Assistant => "Assistant:",
 736                language_model::Role::System => "System:",
 737            });
 738            text.push('\n');
 739
 740            for segment in &message.segments {
 741                match segment {
 742                    MessageSegment::Text(content) => text.push_str(content),
 743                    MessageSegment::Thinking(content) => {
 744                        text.push_str(&format!("<think>{}</think>", content))
 745                    }
 746                }
 747            }
 748            text.push('\n');
 749        }
 750
 751        text
 752    }
 753
 754    /// Serializes this thread into a format for storage or telemetry.
 755    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 756        let initial_project_snapshot = self.initial_project_snapshot.clone();
 757        cx.spawn(async move |this, cx| {
 758            let initial_project_snapshot = initial_project_snapshot.await;
 759            this.read_with(cx, |this, cx| SerializedThread {
 760                version: SerializedThread::VERSION.to_string(),
 761                summary: this.summary_or_default(),
 762                updated_at: this.updated_at(),
 763                messages: this
 764                    .messages()
 765                    .map(|message| SerializedMessage {
 766                        id: message.id,
 767                        role: message.role,
 768                        segments: message
 769                            .segments
 770                            .iter()
 771                            .map(|segment| match segment {
 772                                MessageSegment::Text(text) => {
 773                                    SerializedMessageSegment::Text { text: text.clone() }
 774                                }
 775                                MessageSegment::Thinking(text) => {
 776                                    SerializedMessageSegment::Thinking { text: text.clone() }
 777                                }
 778                            })
 779                            .collect(),
 780                        tool_uses: this
 781                            .tool_uses_for_message(message.id, cx)
 782                            .into_iter()
 783                            .map(|tool_use| SerializedToolUse {
 784                                id: tool_use.id,
 785                                name: tool_use.name,
 786                                input: tool_use.input,
 787                            })
 788                            .collect(),
 789                        tool_results: this
 790                            .tool_results_for_message(message.id)
 791                            .into_iter()
 792                            .map(|tool_result| SerializedToolResult {
 793                                tool_use_id: tool_result.tool_use_id.clone(),
 794                                is_error: tool_result.is_error,
 795                                content: tool_result.content.clone(),
 796                            })
 797                            .collect(),
 798                        context: message.context.clone(),
 799                    })
 800                    .collect(),
 801                initial_project_snapshot,
 802                cumulative_token_usage: this.cumulative_token_usage.clone(),
 803                detailed_summary_state: this.detailed_summary_state.clone(),
 804            })
 805        })
 806    }
 807
 808    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 809        self.system_prompt_context = Some(context);
 810    }
 811
 812    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 813        &self.system_prompt_context
 814    }
 815
 816    pub fn load_system_prompt_context(
 817        &self,
 818        cx: &App,
 819    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 820        let project = self.project.read(cx);
 821        let tasks = project
 822            .visible_worktrees(cx)
 823            .map(|worktree| {
 824                Self::load_worktree_info_for_system_prompt(
 825                    project.fs().clone(),
 826                    worktree.read(cx),
 827                    cx,
 828                )
 829            })
 830            .collect::<Vec<_>>();
 831
 832        cx.spawn(async |_cx| {
 833            let results = futures::future::join_all(tasks).await;
 834            let mut first_err = None;
 835            let worktrees = results
 836                .into_iter()
 837                .map(|(worktree, err)| {
 838                    if first_err.is_none() && err.is_some() {
 839                        first_err = err;
 840                    }
 841                    worktree
 842                })
 843                .collect::<Vec<_>>();
 844            (AssistantSystemPromptContext::new(worktrees), first_err)
 845        })
 846    }
 847
 848    fn load_worktree_info_for_system_prompt(
 849        fs: Arc<dyn Fs>,
 850        worktree: &Worktree,
 851        cx: &App,
 852    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 853        let root_name = worktree.root_name().into();
 854        let abs_path = worktree.abs_path();
 855
 856        let rules_task = load_worktree_rules_file(fs, worktree, cx);
 857        let Some(rules_task) = rules_task else {
 858            return Task::ready((
 859                WorktreeInfoForSystemPrompt {
 860                    root_name,
 861                    abs_path,
 862                    rules_file: None,
 863                },
 864                None,
 865            ));
 866        };
 867
 868        cx.spawn(async move |_| {
 869            let (rules_file, rules_file_error) = match rules_task.await {
 870                Ok(rules_file) => (Some(rules_file), None),
 871                Err(err) => (
 872                    None,
 873                    Some(ThreadError::Message {
 874                        header: "Error loading rules file".into(),
 875                        message: format!("{err}").into(),
 876                    }),
 877                ),
 878            };
 879            let worktree_info = WorktreeInfoForSystemPrompt {
 880                root_name,
 881                abs_path,
 882                rules_file,
 883            };
 884            (worktree_info, rules_file_error)
 885        })
 886    }
 887
 888    pub fn send_to_model(
 889        &mut self,
 890        model: Arc<dyn LanguageModel>,
 891        request_kind: RequestKind,
 892        cx: &mut Context<Self>,
 893    ) {
 894        let mut request = self.to_completion_request(request_kind, cx);
 895        if model.supports_tools() {
 896            request.tools = {
 897                let mut tools = Vec::new();
 898                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 899                    LanguageModelRequestTool {
 900                        name: tool.name(),
 901                        description: tool.description(),
 902                        input_schema: tool.input_schema(model.tool_input_format()),
 903                    }
 904                }));
 905
 906                tools
 907            };
 908        }
 909
 910        self.stream_completion(request, model, cx);
 911    }
 912
 913    pub fn used_tools_since_last_user_message(&self) -> bool {
 914        for message in self.messages.iter().rev() {
 915            if self.tool_use.message_has_tool_results(message.id) {
 916                return true;
 917            } else if message.role == Role::User {
 918                return false;
 919            }
 920        }
 921
 922        false
 923    }
 924
 925    pub fn to_completion_request(
 926        &self,
 927        request_kind: RequestKind,
 928        cx: &App,
 929    ) -> LanguageModelRequest {
 930        let mut request = LanguageModelRequest {
 931            messages: vec![],
 932            tools: Vec::new(),
 933            stop: Vec::new(),
 934            temperature: None,
 935        };
 936
 937        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 938            if let Some(system_prompt) = self
 939                .prompt_builder
 940                .generate_assistant_system_prompt(system_prompt_context)
 941                .context("failed to generate assistant system prompt")
 942                .log_err()
 943            {
 944                request.messages.push(LanguageModelRequestMessage {
 945                    role: Role::System,
 946                    content: vec![MessageContent::Text(system_prompt)],
 947                    cache: true,
 948                });
 949            }
 950        } else {
 951            log::error!("system_prompt_context not set.")
 952        }
 953
 954        for message in &self.messages {
 955            let mut request_message = LanguageModelRequestMessage {
 956                role: message.role,
 957                content: Vec::new(),
 958                cache: false,
 959            };
 960
 961            match request_kind {
 962                RequestKind::Chat => {
 963                    self.tool_use
 964                        .attach_tool_results(message.id, &mut request_message);
 965                }
 966                RequestKind::Summarize => {
 967                    // We don't care about tool use during summarization.
 968                    if self.tool_use.message_has_tool_results(message.id) {
 969                        continue;
 970                    }
 971                }
 972            }
 973
 974            if !message.segments.is_empty() {
 975                request_message
 976                    .content
 977                    .push(MessageContent::Text(message.to_string()));
 978            }
 979
 980            match request_kind {
 981                RequestKind::Chat => {
 982                    self.tool_use
 983                        .attach_tool_uses(message.id, &mut request_message);
 984                }
 985                RequestKind::Summarize => {
 986                    // We don't care about tool use during summarization.
 987                }
 988            };
 989
 990            request.messages.push(request_message);
 991        }
 992
 993        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 994        if let Some(last) = request.messages.last_mut() {
 995            last.cache = true;
 996        }
 997
 998        self.attached_tracked_files_state(&mut request.messages, cx);
 999
1000        request
1001    }
1002
1003    fn attached_tracked_files_state(
1004        &self,
1005        messages: &mut Vec<LanguageModelRequestMessage>,
1006        cx: &App,
1007    ) {
1008        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1009
1010        let mut stale_message = String::new();
1011
1012        let action_log = self.action_log.read(cx);
1013
1014        for stale_file in action_log.stale_buffers(cx) {
1015            let Some(file) = stale_file.read(cx).file() else {
1016                continue;
1017            };
1018
1019            if stale_message.is_empty() {
1020                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1021            }
1022
1023            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1024        }
1025
1026        let mut content = Vec::with_capacity(2);
1027
1028        if !stale_message.is_empty() {
1029            content.push(stale_message.into());
1030        }
1031
1032        if action_log.has_edited_files_since_project_diagnostics_check() {
1033            content.push(
1034                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1035                and fix all errors AND warnings you introduced! \
1036                DO NOT mention you're going to do this until you're done."
1037                    .into(),
1038            );
1039        }
1040
1041        if !content.is_empty() {
1042            let context_message = LanguageModelRequestMessage {
1043                role: Role::User,
1044                content,
1045                cache: false,
1046            };
1047
1048            messages.push(context_message);
1049        }
1050    }
1051
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// Spawns a task that applies each streamed event to the thread. It inserts assistant
    /// messages, appends text and thinking chunks, records tool uses, and accumulates
    /// token usage. After the stream ends (or fails), it finalizes the pending checkpoint,
    /// emits the follow-up events, and reports telemetry. The task is tracked in
    /// `pending_completions` under a freshly allocated completion id.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the per-completion delta can be reported below.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Replace this completion's previous contribution with the new value.
                                // NOTE(review): assumes each UsageUpdate carries the running total for
                                // this completion rather than an increment — confirm.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent assistant message,
                                // inserting an empty assistant message if there is none.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Let other tasks run between events.
                    smol::future::yield_now().await;
                }

                // The stream ended without error: stop tracking this completion, and generate a
                // title once the thread has at least two messages and no summary yet.
                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            // Runs on both success and failure.
            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Billing errors get dedicated variants; anything else is shown with its full
                            // error chain, one cause per line.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // NOTE(review): on this path the `retain` above did not run, so this completion's
                            // entry is still in `pending_completions`; `cancel_last_completion` pops the most
                            // recently pushed entry, presumably this one — confirm when several completions
                            // can be pending at once.
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    // Report only the usage accrued by this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1233
1234    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1235        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1236            return;
1237        };
1238
1239        if !model.provider.is_authenticated(cx) {
1240            return;
1241        }
1242
1243        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1244        request.messages.push(LanguageModelRequestMessage {
1245            role: Role::User,
1246            content: vec![
1247                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1248                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1249                 If the conversation is about a specific subject, include it in the title. \
1250                 Be descriptive. DO NOT speak in the first person."
1251                    .into(),
1252            ],
1253            cache: false,
1254        });
1255
1256        self.pending_summary = cx.spawn(async move |this, cx| {
1257            async move {
1258                let stream = model.model.stream_completion_text(request, &cx);
1259                let mut messages = stream.await?;
1260
1261                let mut new_summary = String::new();
1262                while let Some(message) = messages.stream.next().await {
1263                    let text = message?;
1264                    let mut lines = text.lines();
1265                    new_summary.extend(lines.next());
1266
1267                    // Stop if the LLM generated multiple lines.
1268                    if lines.next().is_some() {
1269                        break;
1270                    }
1271                }
1272
1273                this.update(cx, |this, cx| {
1274                    if !new_summary.is_empty() {
1275                        this.summary = Some(new_summary.into());
1276                    }
1277
1278                    cx.emit(ThreadEvent::SummaryGenerated);
1279                })?;
1280
1281                anyhow::Ok(())
1282            }
1283            .log_err()
1284            .await
1285        });
1286    }
1287
1288    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1289        let last_message_id = self.messages.last().map(|message| message.id)?;
1290
1291        match &self.detailed_summary_state {
1292            DetailedSummaryState::Generating { message_id, .. }
1293            | DetailedSummaryState::Generated { message_id, .. }
1294                if *message_id == last_message_id =>
1295            {
1296                // Already up-to-date
1297                return None;
1298            }
1299            _ => {}
1300        }
1301
1302        let ConfiguredModel { model, provider } =
1303            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1304
1305        if !provider.is_authenticated(cx) {
1306            return None;
1307        }
1308
1309        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1310
1311        request.messages.push(LanguageModelRequestMessage {
1312            role: Role::User,
1313            content: vec![
1314                "Generate a detailed summary of this conversation. Include:\n\
1315                1. A brief overview of what was discussed\n\
1316                2. Key facts or information discovered\n\
1317                3. Outcomes or conclusions reached\n\
1318                4. Any action items or next steps if any\n\
1319                Format it in Markdown with headings and bullet points."
1320                    .into(),
1321            ],
1322            cache: false,
1323        });
1324
1325        let task = cx.spawn(async move |thread, cx| {
1326            let stream = model.stream_completion_text(request, &cx);
1327            let Some(mut messages) = stream.await.log_err() else {
1328                thread
1329                    .update(cx, |this, _cx| {
1330                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1331                    })
1332                    .log_err();
1333
1334                return;
1335            };
1336
1337            let mut new_detailed_summary = String::new();
1338
1339            while let Some(chunk) = messages.stream.next().await {
1340                if let Some(chunk) = chunk.log_err() {
1341                    new_detailed_summary.push_str(&chunk);
1342                }
1343            }
1344
1345            thread
1346                .update(cx, |this, _cx| {
1347                    this.detailed_summary_state = DetailedSummaryState::Generated {
1348                        text: new_detailed_summary.into(),
1349                        message_id: last_message_id,
1350                    };
1351                })
1352                .log_err();
1353        });
1354
1355        self.detailed_summary_state = DetailedSummaryState::Generating {
1356            message_id: last_message_id,
1357        };
1358
1359        Some(task)
1360    }
1361
1362    pub fn is_generating_detailed_summary(&self) -> bool {
1363        matches!(
1364            self.detailed_summary_state,
1365            DetailedSummaryState::Generating { .. }
1366        )
1367    }
1368
1369    pub fn use_pending_tools(
1370        &mut self,
1371        cx: &mut Context<Self>,
1372    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1373        let request = self.to_completion_request(RequestKind::Chat, cx);
1374        let messages = Arc::new(request.messages);
1375        let pending_tool_uses = self
1376            .tool_use
1377            .pending_tool_uses()
1378            .into_iter()
1379            .filter(|tool_use| tool_use.status.is_idle())
1380            .cloned()
1381            .collect::<Vec<_>>();
1382
1383        for tool_use in pending_tool_uses.iter() {
1384            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1385                if tool.needs_confirmation(&tool_use.input, cx)
1386                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1387                {
1388                    self.tool_use.confirm_tool_use(
1389                        tool_use.id.clone(),
1390                        tool_use.ui_text.clone(),
1391                        tool_use.input.clone(),
1392                        messages.clone(),
1393                        tool,
1394                    );
1395                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1396                } else {
1397                    self.run_tool(
1398                        tool_use.id.clone(),
1399                        tool_use.ui_text.clone(),
1400                        tool_use.input.clone(),
1401                        &messages,
1402                        tool,
1403                        cx,
1404                    );
1405                }
1406            }
1407        }
1408
1409        pending_tool_uses
1410    }
1411
1412    pub fn run_tool(
1413        &mut self,
1414        tool_use_id: LanguageModelToolUseId,
1415        ui_text: impl Into<SharedString>,
1416        input: serde_json::Value,
1417        messages: &[LanguageModelRequestMessage],
1418        tool: Arc<dyn Tool>,
1419        cx: &mut Context<Thread>,
1420    ) {
1421        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1422        self.tool_use
1423            .run_pending_tool(tool_use_id, ui_text.into(), task);
1424    }
1425
1426    fn spawn_tool_use(
1427        &mut self,
1428        tool_use_id: LanguageModelToolUseId,
1429        messages: &[LanguageModelRequestMessage],
1430        input: serde_json::Value,
1431        tool: Arc<dyn Tool>,
1432        cx: &mut Context<Thread>,
1433    ) -> Task<()> {
1434        let tool_name: Arc<str> = tool.name().into();
1435
1436        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1437            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1438        } else {
1439            tool.run(
1440                input,
1441                messages,
1442                self.project.clone(),
1443                self.action_log.clone(),
1444                cx,
1445            )
1446        };
1447
1448        cx.spawn({
1449            async move |thread: WeakEntity<Thread>, cx| {
1450                let output = run_tool.await;
1451
1452                thread
1453                    .update(cx, |thread, cx| {
1454                        let pending_tool_use = thread.tool_use.insert_tool_output(
1455                            tool_use_id.clone(),
1456                            tool_name,
1457                            output,
1458                            cx,
1459                        );
1460
1461                        cx.emit(ThreadEvent::ToolFinished {
1462                            tool_use_id,
1463                            pending_tool_use,
1464                            canceled: false,
1465                        });
1466                    })
1467                    .ok();
1468            }
1469        })
1470    }
1471
1472    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1473        // Insert a user message to contain the tool results.
1474        self.insert_user_message(
1475            // TODO: Sending up a user message without any content results in the model sending back
1476            // responses that also don't have any content. We currently don't handle this case well,
1477            // so for now we provide some text to keep the model on track.
1478            "Here are the tool results.",
1479            Vec::new(),
1480            None,
1481            cx,
1482        );
1483    }
1484
1485    /// Cancels the last pending completion, if there are any pending.
1486    ///
1487    /// Returns whether a completion was canceled.
1488    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1489        let canceled = if self.pending_completions.pop().is_some() {
1490            true
1491        } else {
1492            let mut canceled = false;
1493            for pending_tool_use in self.tool_use.cancel_pending() {
1494                canceled = true;
1495                cx.emit(ThreadEvent::ToolFinished {
1496                    tool_use_id: pending_tool_use.id.clone(),
1497                    pending_tool_use: Some(pending_tool_use),
1498                    canceled: true,
1499                });
1500            }
1501            canceled
1502        };
1503        self.finalize_pending_checkpoint(cx);
1504        canceled
1505    }
1506
1507    /// Returns the feedback given to the thread, if any.
1508    pub fn feedback(&self) -> Option<ThreadFeedback> {
1509        self.feedback
1510    }
1511
1512    /// Reports feedback about the thread and stores it in our telemetry backend.
1513    pub fn report_feedback(
1514        &mut self,
1515        feedback: ThreadFeedback,
1516        cx: &mut Context<Self>,
1517    ) -> Task<Result<()>> {
1518        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1519        let serialized_thread = self.serialize(cx);
1520        let thread_id = self.id().clone();
1521        let client = self.project.read(cx).client();
1522        self.feedback = Some(feedback);
1523        cx.notify();
1524
1525        cx.background_spawn(async move {
1526            let final_project_snapshot = final_project_snapshot.await;
1527            let serialized_thread = serialized_thread.await?;
1528            let thread_data =
1529                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1530
1531            let rating = match feedback {
1532                ThreadFeedback::Positive => "positive",
1533                ThreadFeedback::Negative => "negative",
1534            };
1535            telemetry::event!(
1536                "Assistant Thread Rated",
1537                rating,
1538                thread_id,
1539                thread_data,
1540                final_project_snapshot
1541            );
1542            client.telemetry().flush_events();
1543
1544            Ok(())
1545        })
1546    }
1547
1548    /// Create a snapshot of the current project state including git information and unsaved buffers.
1549    fn project_snapshot(
1550        project: Entity<Project>,
1551        cx: &mut Context<Self>,
1552    ) -> Task<Arc<ProjectSnapshot>> {
1553        let git_store = project.read(cx).git_store().clone();
1554        let worktree_snapshots: Vec<_> = project
1555            .read(cx)
1556            .visible_worktrees(cx)
1557            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1558            .collect();
1559
1560        cx.spawn(async move |_, cx| {
1561            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1562
1563            let mut unsaved_buffers = Vec::new();
1564            cx.update(|app_cx| {
1565                let buffer_store = project.read(app_cx).buffer_store();
1566                for buffer_handle in buffer_store.read(app_cx).buffers() {
1567                    let buffer = buffer_handle.read(app_cx);
1568                    if buffer.is_dirty() {
1569                        if let Some(file) = buffer.file() {
1570                            let path = file.path().to_string_lossy().to_string();
1571                            unsaved_buffers.push(path);
1572                        }
1573                    }
1574                }
1575            })
1576            .ok();
1577
1578            Arc::new(ProjectSnapshot {
1579                worktree_snapshots,
1580                unsaved_buffer_paths: unsaved_buffers,
1581                timestamp: Utc::now(),
1582            })
1583        })
1584    }
1585
1586    fn worktree_snapshot(
1587        worktree: Entity<project::Worktree>,
1588        git_store: Entity<GitStore>,
1589        cx: &App,
1590    ) -> Task<WorktreeSnapshot> {
1591        cx.spawn(async move |cx| {
1592            // Get worktree path and snapshot
1593            let worktree_info = cx.update(|app_cx| {
1594                let worktree = worktree.read(app_cx);
1595                let path = worktree.abs_path().to_string_lossy().to_string();
1596                let snapshot = worktree.snapshot();
1597                (path, snapshot)
1598            });
1599
1600            let Ok((worktree_path, _snapshot)) = worktree_info else {
1601                return WorktreeSnapshot {
1602                    worktree_path: String::new(),
1603                    git_state: None,
1604                };
1605            };
1606
1607            let git_state = git_store
1608                .update(cx, |git_store, cx| {
1609                    git_store
1610                        .repositories()
1611                        .values()
1612                        .find(|repo| {
1613                            repo.read(cx)
1614                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1615                                .is_some()
1616                        })
1617                        .cloned()
1618                })
1619                .ok()
1620                .flatten()
1621                .map(|repo| {
1622                    repo.update(cx, |repo, _| {
1623                        let current_branch =
1624                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1625                        repo.send_job(None, |state, _| async move {
1626                            let RepositoryState::Local { backend, .. } = state else {
1627                                return GitState {
1628                                    remote_url: None,
1629                                    head_sha: None,
1630                                    current_branch,
1631                                    diff: None,
1632                                };
1633                            };
1634
1635                            let remote_url = backend.remote_url("origin");
1636                            let head_sha = backend.head_sha();
1637                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1638
1639                            GitState {
1640                                remote_url,
1641                                head_sha,
1642                                current_branch,
1643                                diff,
1644                            }
1645                        })
1646                    })
1647                });
1648
1649            let git_state = match git_state {
1650                Some(git_state) => match git_state.ok() {
1651                    Some(git_state) => git_state.await.ok(),
1652                    None => None,
1653                },
1654                None => None,
1655            };
1656
1657            WorktreeSnapshot {
1658                worktree_path,
1659                git_state,
1660            }
1661        })
1662    }
1663
1664    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1665        let mut markdown = Vec::new();
1666
1667        if let Some(summary) = self.summary() {
1668            writeln!(markdown, "# {summary}\n")?;
1669        };
1670
1671        for message in self.messages() {
1672            writeln!(
1673                markdown,
1674                "## {role}\n",
1675                role = match message.role {
1676                    Role::User => "User",
1677                    Role::Assistant => "Assistant",
1678                    Role::System => "System",
1679                }
1680            )?;
1681
1682            if !message.context.is_empty() {
1683                writeln!(markdown, "{}", message.context)?;
1684            }
1685
1686            for segment in &message.segments {
1687                match segment {
1688                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1689                    MessageSegment::Thinking(text) => {
1690                        writeln!(markdown, "<think>{}</think>\n", text)?
1691                    }
1692                }
1693            }
1694
1695            for tool_use in self.tool_uses_for_message(message.id, cx) {
1696                writeln!(
1697                    markdown,
1698                    "**Use Tool: {} ({})**",
1699                    tool_use.name, tool_use.id
1700                )?;
1701                writeln!(markdown, "```json")?;
1702                writeln!(
1703                    markdown,
1704                    "{}",
1705                    serde_json::to_string_pretty(&tool_use.input)?
1706                )?;
1707                writeln!(markdown, "```")?;
1708            }
1709
1710            for tool_result in self.tool_results_for_message(message.id) {
1711                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1712                if tool_result.is_error {
1713                    write!(markdown, " (Error)")?;
1714                }
1715
1716                writeln!(markdown, "**\n")?;
1717                writeln!(markdown, "{}", tool_result.content)?;
1718            }
1719        }
1720
1721        Ok(String::from_utf8_lossy(&markdown).to_string())
1722    }
1723
1724    pub fn keep_edits_in_range(
1725        &mut self,
1726        buffer: Entity<language::Buffer>,
1727        buffer_range: Range<language::Anchor>,
1728        cx: &mut Context<Self>,
1729    ) {
1730        self.action_log.update(cx, |action_log, cx| {
1731            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1732        });
1733    }
1734
1735    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1736        self.action_log
1737            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1738    }
1739
1740    pub fn reject_edits_in_range(
1741        &mut self,
1742        buffer: Entity<language::Buffer>,
1743        buffer_range: Range<language::Anchor>,
1744        cx: &mut Context<Self>,
1745    ) -> Task<Result<()>> {
1746        self.action_log.update(cx, |action_log, cx| {
1747            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1748        })
1749    }
1750
1751    pub fn action_log(&self) -> &Entity<ActionLog> {
1752        &self.action_log
1753    }
1754
1755    pub fn project(&self) -> &Entity<Project> {
1756        &self.project
1757    }
1758
1759    pub fn cumulative_token_usage(&self) -> TokenUsage {
1760        self.cumulative_token_usage.clone()
1761    }
1762
1763    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1764        let model_registry = LanguageModelRegistry::read_global(cx);
1765        let Some(model) = model_registry.default_model() else {
1766            return TotalTokenUsage::default();
1767        };
1768
1769        let max = model.model.max_token_count();
1770
1771        #[cfg(debug_assertions)]
1772        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1773            .unwrap_or("0.8".to_string())
1774            .parse()
1775            .unwrap();
1776        #[cfg(not(debug_assertions))]
1777        let warning_threshold: f32 = 0.8;
1778
1779        let total = self.cumulative_token_usage.total_tokens() as usize;
1780
1781        let ratio = if total >= max {
1782            TokenUsageRatio::Exceeded
1783        } else if total as f32 / max as f32 >= warning_threshold {
1784            TokenUsageRatio::Warning
1785        } else {
1786            TokenUsageRatio::Normal
1787        };
1788
1789        TotalTokenUsage { total, max, ratio }
1790    }
1791
1792    pub fn deny_tool_use(
1793        &mut self,
1794        tool_use_id: LanguageModelToolUseId,
1795        tool_name: Arc<str>,
1796        cx: &mut Context<Self>,
1797    ) {
1798        let err = Err(anyhow::anyhow!(
1799            "Permission to run tool action denied by user"
1800        ));
1801
1802        self.tool_use
1803            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1804
1805        cx.emit(ThreadEvent::ToolFinished {
1806            tool_use_id,
1807            pending_tool_use: None,
1808            canceled: true,
1809        });
1810    }
1811}
1812
/// Errors carried by [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    // NOTE(review): presumably raised when the provider reports a `PaymentRequiredError`;
    // confirm at the construction site.
    PaymentRequired,
    // NOTE(review): presumably raised when the provider reports a
    // `MaxMonthlySpendReachedError`; confirm at the construction site.
    MaxMonthlySpendReached,
    /// Any other error, described by a `header` and a `message` body.
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1822
/// Events emitted by a [`Thread`] through its `EventEmitter<ThreadEvent>` implementation.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`] to be shown.
    ShowError(ThreadError),
    StreamedCompletion,
    /// Streamed assistant text for the message identified by the `MessageId`.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the message identified by the `MessageId`.
    StreamedAssistantThinking(MessageId, String),
    DoneStreaming,
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools,
    /// A tool use finished. `Thread::deny_tool_use` also emits this, with `canceled: true`
    /// and no pending tool use.
    ToolFinished {
        // NOTE(review): silences the unused-field lint; confirm whether any consumer reads
        // this id, and drop the allow (or the field) accordingly.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}
1847
// Allows a `Thread` to emit `ThreadEvent`s via `cx.emit(...)`, as `deny_tool_use` does.
impl EventEmitter<ThreadEvent> for Thread {}
1849
/// Pairs an `id` with a spawned completion [`Task`].
struct PendingCompletion {
    // NOTE(review): meaning of `id` is not visible here; presumably it distinguishes
    // concurrent completions — confirm against the code that creates these.
    id: usize,
    // Never read; presumably held only so the task is not dropped (gpui cancels dropped
    // tasks) — confirm.
    _task: Task<()>,
}
1854
// Unit tests for `Thread`: how attached context is stored on messages and rendered into
// completion requests, and how edits to attached buffers are reported.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // A user message with an attached file stores the formatted context on the message, and
    // the completion request prepends that context to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // Exact expected rendering of the attached file, including surrounding newlines.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[0].string_contents(), expected_full_message);
    }

    // Context already attached to an earlier message is not repeated: each message's context
    // only contains files that no earlier message in the thread attached.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 3);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[0].string_contents().contains("file1.rs"));
        assert!(!request.messages[0].string_contents().contains("file2.rs"));
        assert!(!request.messages[0].string_contents().contains("file3.rs"));

        assert!(!request.messages[1].string_contents().contains("file1.rs"));
        assert!(request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(!request.messages[2].string_contents().contains("file2.rs"));
        assert!(request.messages[2].string_contents().contains("file3.rs"));
    }

    // Messages inserted without any context have an empty `context` and are sent to the
    // model verbatim.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[1].string_contents(),
            "Are there any good books?"
        );
    }

    // After a buffer that was attached as context is edited, the completion request ends
    // with a user message listing the changed files.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at byte offset 1 (just after the leading `f`, not at the end of
            // line 1). The exact position is irrelevant: the test only needs the buffer to be
            // modified after it was attached as context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` as a global and registers/initializes the settings
    /// used by the project, assistant, thread store, workspace, theme, context servers and
    /// editor.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files: `files` are inserted into a `FakeFs`
    // under `/test`, which is also the project's only worktree root.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Opens a test workspace window for `project`. It then creates a `ThreadStore`, a new
    /// `Thread` from it, and a new `ContextStore` tied to the workspace.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx.update(|_, cx| {
            ThreadStore::new(
                project.clone(),
                Arc::default(),
                Arc::new(PromptBuilder::new(None).unwrap()),
                cx,
            )
            .unwrap()
        });

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(workspace.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens the buffer at `path` (a project path such as `"test/code.rs"`), adds it to
    /// `context_store`, and returns the buffer.
    ///
    /// Panics if the path cannot be resolved or the buffer cannot be opened. Errors from
    /// adding the buffer to the context store are returned.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}